diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index ddf3c4943..829786009 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -14,6 +14,43 @@ concurrency: cancel-in-progress: true jobs: + wasm: + name: WASM build test + timeout-minutes: 5 + runs-on: ubuntu-latest + if: ${{ false }} # disable for now + steps: + - uses: actions/checkout@v3 + + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: 1.18 + cache: true + + - name: Install Node + uses: actions/setup-node@v2 + with: + node-version: 14 + + - uses: actions/cache@v3 + with: + path: ~/.npm + key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} + restore-keys: | + ${{ runner.os }}-node- + + - name: Reconfigure Git to use HTTPS auth for repo packages + run: > + git config --global url."https://github.com/".insteadOf + ssh://git@github.com/ + + - name: Install test dependencies + working-directory: ./test/wasm + run: npm ci + + - name: Test + run: ./test-dendritejs.sh # Run golangci-lint lint: @@ -64,19 +101,12 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} + cache: true - name: Set up gotestfmt uses: gotesttools/gotestfmt-action@v2 with: # Optional: pass GITHUB_TOKEN to avoid rate limiting. token: ${{ secrets.GITHUB_TOKEN }} - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-test- - run: go test -json -v ./... 2>&1 | gotestfmt env: POSTGRES_HOST: localhost @@ -101,17 +131,17 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - - name: Install dependencies x86 - if: ${{ matrix.goarch == '386' }} - run: sudo apt update && sudo apt-get install -y gcc-multilib - uses: actions/cache@v3 with: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}- + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}- + - name: Install dependencies x86 + if: ${{ matrix.goarch == '386' }} + run: sudo apt update && sudo apt-get install -y gcc-multilib - env: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} @@ -119,6 +149,39 @@ jobs: CGO_CFLAGS: -fno-stack-protector run: go build -trimpath -v -o "bin/" ./cmd/... + # build for Windows 64-bit + build_windows: + name: Build for Windows + timeout-minutes: 10 + runs-on: ubuntu-latest + strategy: + matrix: + go: ["1.18", "1.19"] + goos: ["windows"] + goarch: ["amd64"] + steps: + - uses: actions/checkout@v3 + - name: Setup Go ${{ matrix.go }} + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} + restore-keys: | + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}- + - name: Install dependencies + run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc + - env: + GOOS: ${{ matrix.goos }} + GOARCH: ${{ matrix.goarch }} + CGO_ENABLED: 1 + CC: "/usr/bin/x86_64-w64-mingw32-gcc" + run: go build -trimpath -v -o "bin/" ./cmd/... + # Dummy step to gate other tests on without repeating the whole list initial-tests-done: name: Initial tests passed @@ -151,6 +214,8 @@ jobs: image: matrixdotorg/sytest-dendrite:latest volumes: - ${{ github.workspace }}:/src + - /root/.cache/go-build:/github/home/.cache/go-build + - /root/.cache/go-mod:/gopath/pkg/mod env: POSTGRES: ${{ matrix.postgres && 1}} API: ${{ matrix.api && 1 }} @@ -158,6 +223,14 @@ jobs: CGO_ENABLED: ${{ matrix.cgo && 1 }} steps: - uses: actions/checkout@v3 + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + /gopath/pkg/mod + key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go-sytest- - name: Run Sytest run: /bootstrap.sh dendrite working-directory: /src @@ -192,10 +265,12 @@ jobs: include: - label: PostgreSQL postgres: Postgres + cgo: 0 - label: PostgreSQL, full HTTP APIs postgres: Postgres api: full-http + cgo: 0 steps: # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path @@ -203,14 +278,12 @@ jobs: run: | echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH echo "~/go/bin" >> $GITHUB_PATH - - name: "Install Complement Dependencies" # We don't need to install Go because it is included on the Ubuntu 20.04 image: # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64 run: | sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest - - name: Run actions/checkout@v3 for dendrite uses: actions/checkout@v3 with: @@ -239,9 +312,8 @@ jobs: (wget -O - "https://github.com/globekeeper/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break done - # Build initial Dendrite image - - run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . + - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . working-directory: dendrite env: DOCKER_BUILDKIT: 1 @@ -253,9 +325,8 @@ jobs: shell: bash name: Run Complement Tests env: - COMPLEMENT_BASE_IMAGE: complement-dendrite:latest + COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} API: ${{ matrix.api && 1 }} - CGO_ENABLED: ${{ matrix.cgo && 1 }} working-directory: complement integration-tests-done: diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index c07917248..ff4d47187 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -45,6 +45,11 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} + - name: Set up gotestfmt + uses: gotesttools/gotestfmt-action@v2 + with: + # Optional: pass GITHUB_TOKEN to avoid rate limiting. + token: ${{ secrets.GITHUB_TOKEN }} - uses: actions/cache@v3 with: path: | @@ -53,12 +58,14 @@ jobs: key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go${{ matrix.go }}-test-race- - - run: go test -race ./... + - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt env: POSTGRES_HOST: localhost POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres POSTGRES_DB: dendrite + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 # Dummy step to gate other tests on without repeating the whole list initial-tests-done: diff --git a/CHANGES.md b/CHANGES.md index cdeb1dea3..f5a82cfe2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,28 @@ # Changelog +## Dendrite 0.10.8 (2022-11-29) + +### Features + +* The built-in NATS Server has been updated to version 2.9.8 +* A number of under-the-hood changes have been merged for future virtual hosting support in Dendrite (running multiple domain names on the same Dendrite deployment) + +### Fixes + +* Event auth handling of invites has been refactored, which should fix some edge cases being handled incorrectly +* Fix a bug when returning an empty protocol list, which could cause Element to display "The homeserver may be too old to support third party networks" when opening the public room directory +* The sync API will no longer filter out the user's own membership when using lazy-loading +* Dendrite will now correctly detect JetStream consumers being deleted, stopping the consumer goroutine as needed +* A panic in the federation API where the server list could go out of bounds has been fixed +* Blacklisted servers will now be excluded when querying joined servers, which improves CPU usage and performs less unnecessary outbound requests +* A database writer will now be used to assign state key NIDs when requesting NIDs that may not exist yet +* Dendrite will now correctly move local aliases for an upgraded room when the room is upgraded remotely +* Dendrite will now correctly move account data for an upgraded room when the room is upgraded remotely +* Missing state key NIDs will now be allocated on request rather than returning an error +* Guest access is now correctly denied on a number of endpoints +* Presence information will now be correctly sent for new private chats +* A number of unspecced fields have been removed from outbound `/send` transactions + ## Dendrite 0.10.7 (2022-11-04) ### Features diff --git a/Dockerfile b/Dockerfile index 3180e9626..a0d3e1bbf 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,18 +4,16 @@ RUN apk --update --no-cache add bash build-base WORKDIR /build -COPY . /build - -RUN mkdir -p bin -RUN go build -trimpath -o bin/ ./cmd/dendrite-monolith-server -RUN go build -trimpath -o bin/ ./cmd/create-account -RUN go build -trimpath -o bin/ ./cmd/generate-keys - -FROM alpine:latest +# +# The dendrite base image +# +FROM alpine:latest AS dendrite-base LABEL org.opencontainers.image.title="Dendrite (Monolith)" LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" LABEL org.opencontainers.image.licenses="Apache-2.0" +LABEL org.opencontainers.image.documentation="https://matrix-org.github.io/dendrite/" +LABEL org.opencontainers.image.vendor="The Matrix.org Foundation C.I.C." COPY --from=base /build/bin/* /usr/bin/ diff --git a/appservice/appservice.go b/appservice/appservice.go index 0c778b6ca..b3c28dbde 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -32,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) // AddInternalRoutes registers HTTP handlers for internal API calls @@ -74,7 +75,7 @@ func NewInternalAPI( // events to be sent out. for _, appservice := range base.Cfg.Derived.ApplicationServices { // Create bot account for this AS if it doesn't already exist - if err := generateAppServiceAccount(userAPI, appservice); err != nil { + if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil { logrus.WithFields(logrus.Fields{ "appservice": appservice.ID, }).WithError(err).Panicf("failed to generate bot account for appservice") @@ -101,11 +102,13 @@ func NewInternalAPI( func generateAppServiceAccount( userAPI userapi.AppserviceUserAPI, as config.ApplicationService, + serverName gomatrixserverlib.ServerName, ) error { var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeAppService, Localpart: as.SenderLocalpart, + ServerName: serverName, AppServiceID: as.ID, OnConflict: userapi.ConflictUpdate, }, &accRes) @@ -115,6 +118,7 @@ func generateAppServiceAccount( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ Localpart: as.SenderLocalpart, + ServerName: serverName, AccessToken: as.ASToken, DeviceID: &as.SenderLocalpart, DeviceDisplayName: &as.SenderLocalpart, diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index adb4e40a6..9100ebf0f 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -40,6 +40,7 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" @@ -58,6 +59,7 @@ import ( pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" + pineconeEvents "github.com/matrix-org/pinecone/router/events" pineconeSessions "github.com/matrix-org/pinecone/sessions" "github.com/matrix-org/pinecone/types" @@ -295,7 +297,12 @@ func (m *DendriteMonolith) Start() { m.logger.SetOutput(BindLogger{}) logrus.SetOutput(BindLogger{}) + pineconeEventChannel := make(chan pineconeEvents.Event) m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) + m.PineconeRouter.EnableHopLimiting() + m.PineconeRouter.EnableWakeupBroadcasts() + m.PineconeRouter.Subscribe(pineconeEventChannel) + m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil) @@ -423,6 +430,34 @@ func (m *DendriteMonolith) Start() { m.logger.Fatal(err) } }() + + go func(ch <-chan pineconeEvents.Event) { + eLog := logrus.WithField("pinecone", "events") + + for event := range ch { + switch e := event.(type) { + case pineconeEvents.PeerAdded: + case pineconeEvents.PeerRemoved: + case pineconeEvents.TreeParentUpdate: + case pineconeEvents.SnakeDescUpdate: + case pineconeEvents.TreeRootAnnUpdate: + case pineconeEvents.SnakeEntryAdded: + case pineconeEvents.SnakeEntryRemoved: + case pineconeEvents.BroadcastReceived: + eLog.Info("Broadcast received from: ", e.PeerID) + + req := &api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + } + res := &api.PerformWakeupServersResponse{} + if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + } + case pineconeEvents.BandwidthReport: + default: + } + } + }(pineconeEventChannel) } func (m *DendriteMonolith) Stop() { diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 14b28498b..79422e645 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -10,12 +10,13 @@ RUN mkdir /dendrite # Utilise Docker caching when downloading dependencies, this stops us needlessly # downloading dependencies every time. +ARG CGO RUN --mount=target=. \ --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ - go build -o /dendrite ./cmd/generate-config && \ - go build -o /dendrite ./cmd/generate-keys && \ - go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index 785090b0b..3faf43cc7 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -28,12 +28,13 @@ RUN mkdir /dendrite # Utilise Docker caching when downloading dependencies, this stops us needlessly # downloading dependencies every time. +ARG CGO RUN --mount=target=. \ --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ - go build -o /dendrite ./cmd/generate-config && \ - go build -o /dendrite ./cmd/generate-keys && \ - go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index 4017c26d5..4bad1dcc2 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/ratelimit" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -67,7 +68,9 @@ func TestLoginFromJSONReader(t *testing.T) { var userAPI fakeUserInternalAPI cfg := &config.ClientAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, RtFailedLogin: ratelimit.RtFailedLoginConfig{ Enabled: false, @@ -148,7 +151,9 @@ func TestBadLoginFromJSONReader(t *testing.T) { var userAPI fakeUserInternalAPI cfg := &config.ClientAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, cfg, nil) diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index a13520899..6cad7a26b 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -110,13 +110,19 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, JSON: jsonerror.BadJSON("A password must be supplied."), } } - localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix) + localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, JSON: jsonerror.InvalidUsername(err.Error()), } } + if !t.Config.Matrix.IsLocalServerName(domain) { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.InvalidUsername("The server name is not known."), + } + } // Squash username to all lowercase letters res := &api.QueryAccountByPasswordResponse{} localpart = strings.ToLower(localpart) @@ -129,23 +135,28 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } } - err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: localpart, PlaintextPassword: r.Password}, res) + err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ + Localpart: localpart, + ServerName: domain, + PlaintextPassword: r.Password, + }, res) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("unable to fetch account by password"), + JSON: jsonerror.Unknown("Unable to fetch account by password."), } } if !res.Exists { err = t.UserApi.QueryAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ Localpart: localpart, + ServerName: domain, PlaintextPassword: r.Password, }, res) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("unable to fetch account by password"), + JSON: jsonerror.Unknown("Unable to fetch account by password."), } } // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 8267d2222..45e2b5af9 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -49,7 +49,9 @@ func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *a func setup() *UserInteractive { cfg := &config.ClientAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } return NewUserInteractive(&fakeAccountDatabase{}, cfg) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 9088f7716..be8073c33 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap if err != nil { return util.ErrorResponse(err) } + serverName := cfg.Matrix.ServerName localpart, ok := vars["localpart"] if !ok { return util.JSONResponse{ @@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.MissingArgument("Expecting user localpart."), } } + if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil { + localpart, serverName = l, s + } request := struct { Password string `json:"password"` }{} @@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } updateReq := &userapi.PerformPasswordUpdateRequest{ Localpart: localpart, + ServerName: serverName, Password: request.Password, LogoutDevices: true, } diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index eefe8e24b..a0d80903d 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -477,7 +477,7 @@ func createRoom( SendAsServer: roomserverAPI.DoNotSendToOtherServers, }) } - if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { + if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index ce14745aa..b3c5aae45 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -77,7 +77,7 @@ func DirectoryRoom( // If we don't know it locally, do a federation query. // But don't send the query to ourselves. if !cfg.Matrix.IsLocalServerName(domain) { - fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) + fedRes, fedErr := federation.LookupRoomAlias(req.Context(), cfg.Matrix.ServerName, domain, roomAlias) if fedErr != nil { // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index b1043e994..606744767 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -74,7 +74,7 @@ func GetPostPublicRooms( serverName := gomatrixserverlib.ServerName(request.Server) if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( - req.Context(), serverName, + req.Context(), cfg.Matrix.ServerName, serverName, int(request.Limit), request.Since, request.Filter.SearchTerms, false, "", diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index e33bb3f96..d380a94a6 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -113,6 +113,7 @@ func completeAuth( DeviceID: login.DeviceID, AccessToken: token, Localpart: localpart, + ServerName: serverName, IPAddr: ipAddr, UserAgent: userAgent, }, &performRes) diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 94ba17a02..482c1f5f7 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -110,6 +110,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic ctx, rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, + device.UserDomain(), serverName, serverName, nil, @@ -322,7 +323,12 @@ func buildMembershipEvent( return nil, err } - return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return nil, err + } + + return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil) } // loadProfile lookups the profile of a given user from the database and returns diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go index 8a424a141..f593e27db 100644 --- a/clientapi/routing/notification.go +++ b/clientapi/routing/notification.go @@ -40,16 +40,17 @@ func GetNotifications( } var queryRes userapi.QueryNotificationsResponse - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") return jsonerror.InternalServerError() } err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ - Localpart: localpart, - From: req.URL.Query().Get("from"), - Limit: int(limit), - Only: req.URL.Query().Get("only"), + Localpart: localpart, + ServerName: domain, + From: req.URL.Query().Get("from"), + Limit: int(limit), + Only: req.URL.Query().Get("only"), }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed") diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 44ca153f2..800e87512 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -61,6 +61,7 @@ func Password( sessionID = util.RandomString(sessionIDLength) } var localpart string + var domain gomatrixserverlib.ServerName switch r.Auth.Type { case authtypes.LoginTypePassword: // Check if the existing password is correct. @@ -73,7 +74,7 @@ func Password( } // Get the local part. var err error - localpart, _, err = gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err = gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() @@ -119,6 +120,7 @@ func Password( } } localpart = res.Localpart + domain = res.ServerName sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeEmail) default: flows := []authtypes.Flow{ @@ -151,8 +153,9 @@ func Password( // Ask the user API to perform the password change. passwordReq := &api.PerformPasswordUpdateRequest{ - Localpart: localpart, - Password: r.NewPassword, + Localpart: localpart, + ServerName: domain, + Password: r.NewPassword, } passwordRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil { @@ -191,8 +194,9 @@ func Password( } pushersReq := &api.PerformPusherDeletionRequest{ - Localpart: localpart, - SessionID: sessionId, + Localpart: localpart, + ServerName: domain, + SessionID: sessionId, } if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 20ed5ede0..f94e45ab0 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -278,7 +278,7 @@ func updateProfile( } events, err := buildMembershipEvents( - ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, + ctx, device, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ) switch e := err.(type) { case nil: @@ -292,7 +292,7 @@ func updateProfile( return jsonerror.InternalServerError(), e } - if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError(), err } @@ -315,7 +315,7 @@ func getProfile( } if !cfg.Matrix.IsLocalServerName(domain) { - profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") + profile, fedErr := federation.LookupProfile(ctx, cfg.Matrix.ServerName, domain, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { @@ -343,6 +343,7 @@ func getProfile( func buildMembershipEvents( ctx context.Context, + device *userapi.Device, roomIDs []string, newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, evTime time.Time, rsAPI api.ClientRoomserverAPI, @@ -374,7 +375,12 @@ func buildMembershipEvents( return nil, err } - event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return nil, err + } + + event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil) if err != nil { return nil, err } diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go index 48d319ebd..548423c3c 100644 --- a/clientapi/routing/pusher.go +++ b/clientapi/routing/pusher.go @@ -31,13 +31,14 @@ func GetPushers( userAPI userapi.ClientUserAPI, ) util.JSONResponse { var queryRes userapi.QueryPushersResponse - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") return jsonerror.InternalServerError() } err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ - Localpart: localpart, + Localpart: localpart, + ServerName: domain, }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") @@ -59,7 +60,7 @@ func SetPusher( req *http.Request, device *userapi.Device, userAPI userapi.ClientUserAPI, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") return jsonerror.InternalServerError() @@ -93,6 +94,7 @@ func SetPusher( } body.Localpart = localpart + body.ServerName = domain body.SessionID = device.SessionID err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) if err != nil { diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 778a02fd4..7841b3b07 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -123,8 +123,13 @@ func SendRedaction( return jsonerror.InternalServerError() } + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return jsonerror.InternalServerError() + } + var queryRes roomserverAPI.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, @@ -132,7 +137,7 @@ func SendRedaction( } } domain := device.UserDomain() - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index a96942999..f077e48d5 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -212,9 +212,10 @@ var ( // previous parameters with the ones supplied. This mean you cannot "build up" request params. type registerRequest struct { // registration parameters - Password string `json:"password"` - Username string `json:"username"` - Admin bool `json:"admin"` + Password string `json:"password"` + Username string `json:"username"` + ServerName gomatrixserverlib.ServerName `json:"-"` + Admin bool `json:"admin"` // user-interactive auth params Auth authDict `json:"auth"` @@ -552,6 +553,12 @@ func Register( } var r registerRequest + host := gomatrixserverlib.ServerName(req.Host) + if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil { + r.ServerName = v.ServerName + } else { + r.ServerName = cfg.Matrix.ServerName + } sessionID := gjson.GetBytes(reqBody, "auth.session").String() if sessionID == "" { // Generate a new, random session ID @@ -561,6 +568,7 @@ func Register( // Some of these might end up being overwritten if the // values are specified again in the request body. r.Username = data.Username + r.ServerName = data.ServerName r.Password = data.Password r.DeviceID = data.DeviceID r.InitialDisplayName = data.InitialDisplayName @@ -572,11 +580,13 @@ func Register( JSON: response, } } - } if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } + if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { + r.Username, r.ServerName = l, d + } if req.URL.Query().Get("kind") == "guest" { return handleGuestRegistration(req, r, cfg, userAPI) } @@ -590,12 +600,15 @@ func Register( } // Auto generate a numeric username if r.Username is empty if r.Username == "" { - res := &userapi.QueryNumericLocalpartResponse{} - if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil { + nreq := &userapi.QueryNumericLocalpartRequest{ + ServerName: r.ServerName, + } + nres := &userapi.QueryNumericLocalpartResponse{} + if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") return jsonerror.InternalServerError() } - r.Username = strconv.FormatInt(res.ID, 10) + r.Username = strconv.FormatInt(nres.ID, 10) } // Is this an appservice registration? It will be if the access @@ -608,7 +621,7 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { + if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { return *resErr } case accessTokenErr == nil: @@ -621,7 +634,7 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { + if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { return *resErr } } @@ -645,16 +658,25 @@ func handleGuestRegistration( cfg *config.ClientAPI, userAPI userapi.ClientUserAPI, ) util.JSONResponse { - if cfg.RegistrationDisabled || cfg.GuestsDisabled { + registrationEnabled := !cfg.RegistrationDisabled + guestsEnabled := !cfg.GuestsDisabled + if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil { + registrationEnabled, guestsEnabled = v.RegistrationAllowed() + } + + if !registrationEnabled || !guestsEnabled { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Guest registration is disabled"), + JSON: jsonerror.Forbidden( + fmt.Sprintf("Guest registration is disabled on %q", r.ServerName), + ), } } var res userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeGuest, + ServerName: r.ServerName, }, &res) if err != nil { return util.JSONResponse{ @@ -678,6 +700,7 @@ func handleGuestRegistration( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ Localpart: res.Account.Localpart, + ServerName: res.Account.ServerName, DeviceDisplayName: r.InitialDisplayName, AccessToken: token, IPAddr: req.RemoteAddr, @@ -730,10 +753,16 @@ func handleRegistrationFlow( ) } - if cfg.RegistrationDisabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { + registrationEnabled := !cfg.RegistrationDisabled + if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil { + registrationEnabled, _ = v.RegistrationAllowed() + } + if !registrationEnabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Registration is disabled"), + JSON: jsonerror.Forbidden( + fmt.Sprintf("Registration is disabled on %q", r.ServerName), + ), } } @@ -845,7 +874,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, + req.Context(), userAPI, r.Username, r.ServerName, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, nil, ) } @@ -865,7 +894,7 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, + req.Context(), userAPI, r.Username, r.ServerName, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, threePid, ) } @@ -888,7 +917,8 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.ClientUserAPI, - username, password, appserviceID, ipAddr, userAgent, sessionID string, + username string, serverName gomatrixserverlib.ServerName, + password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, accType userapi.AccountType, @@ -911,6 +941,7 @@ func completeRegistration( err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ AppServiceID: appserviceID, Localpart: username, + ServerName: serverName, Password: password, AccountType: accType, OnConflict: userapi.ConflictAbort, @@ -969,6 +1000,7 @@ func completeRegistration( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ Localpart: username, + ServerName: serverName, AccessToken: token, DeviceDisplayName: displayName, DeviceID: deviceID, @@ -1062,13 +1094,31 @@ func RegisterAvailable( // Squash username to all lowercase letters username = strings.ToLower(username) + domain := cfg.Matrix.ServerName + host := gomatrixserverlib.ServerName(req.Host) + if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil { + domain = v.ServerName + } + if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil { + username, domain = u, l + } + for _, v := range cfg.Matrix.VirtualHosts { + if v.ServerName == domain && !v.AllowRegistration { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden( + fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)), + ), + } + } + } - if err := validateUsername(username, cfg.Matrix.ServerName); err != nil { + if err := validateUsername(username, domain); err != nil { return *err } // Check if this username is reserved by an application service - userID := userutil.MakeUserID(username, cfg.Matrix.ServerName) + userID := userutil.MakeUserID(username, domain) for _, appservice := range cfg.Derived.ApplicationServices { if appservice.OwnsNamespaceCoveringUserId(userID) { return util.JSONResponse{ @@ -1080,7 +1130,8 @@ func RegisterAvailable( res := &userapi.QueryAccountAvailabilityResponse{} err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ - Localpart: username, + Localpart: username, + ServerName: domain, }, res) if err != nil { return util.JSONResponse{ @@ -1137,5 +1188,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil) + return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType, nil) } diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 12622f293..f74e1837e 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -159,7 +159,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}", + dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminResetPassword(req, cfg, device, userAPI) }), @@ -254,7 +254,7 @@ func Setup( return JoinRoomByIDOrAlias( req, device, rsAPI, userAPI, vars["roomIDOrAlias"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) if mscCfg.Enabled("msc2753") { @@ -276,7 +276,7 @@ func Setup( v3mux.Handle("/joined_rooms", httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetJoinedRooms(req, device, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/join", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -290,7 +290,7 @@ func Setup( return JoinRoomByIDOrAlias( req, device, rsAPI, userAPI, vars["roomID"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/leave", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -304,7 +304,7 @@ func Setup( return LeaveRoomByID( req, device, rsAPI, vars["roomID"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/unpeek", httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -363,7 +363,7 @@ func Setup( return util.ErrorResponse(err) } return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -374,7 +374,7 @@ func Setup( txnID := vars["txnID"] return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID, nil, cfg, rsAPI, transactionsCache) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -383,7 +383,7 @@ func Setup( return util.ErrorResponse(err) } return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -402,7 +402,7 @@ func Setup( eventType := strings.TrimSuffix(vars["type"], "/") eventFormat := req.URL.Query().Get("format") == "event" return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -411,7 +411,7 @@ func Setup( } eventFormat := req.URL.Query().Get("format") == "event" return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -422,7 +422,7 @@ func Setup( emptyString := "" eventType := strings.TrimSuffix(vars["eventType"], "/") return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", @@ -433,7 +433,7 @@ func Setup( } stateKey := vars["stateKey"] return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/multiroom/{dataType}", @@ -588,7 +588,7 @@ func Setup( } txnID := vars["txnID"] return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) // This is only here because sytest refers to /unstable for this endpoint @@ -602,7 +602,7 @@ func Setup( } txnID := vars["txnID"] return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/account/whoami", @@ -611,7 +611,7 @@ func Setup( return *r } return Whoami(req, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/password", @@ -843,7 +843,7 @@ func Setup( return util.ErrorResponse(err) } return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method @@ -884,7 +884,7 @@ func Setup( v3mux.Handle("/thirdparty/protocols", httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Protocols(req, asAPI, device, "") - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/protocol/{protocolID}", @@ -894,7 +894,7 @@ func Setup( return util.ErrorResponse(err) } return Protocols(req, asAPI, device, vars["protocolID"]) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/user/{protocolID}", @@ -904,13 +904,13 @@ func Setup( return util.ErrorResponse(err) } return User(req, asAPI, device, vars["protocolID"], req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/user", httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return User(req, asAPI, device, "", req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/location/{protocolID}", @@ -920,13 +920,13 @@ func Setup( return util.ErrorResponse(err) } return Location(req, asAPI, device, vars["protocolID"], req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/location", httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Location(req, asAPI, device, "", req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/initialSync", @@ -1067,7 +1067,7 @@ func Setup( v3mux.Handle("/devices", httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetDevicesByLocalpart(req, userAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", @@ -1077,7 +1077,7 @@ func Setup( return util.ErrorResponse(err) } return GetDeviceByID(req, userAPI, device, vars["deviceID"]) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", @@ -1087,7 +1087,7 @@ func Setup( return util.ErrorResponse(err) } return UpdateDeviceByID(req, userAPI, device, vars["deviceID"]) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", @@ -1129,21 +1129,21 @@ func Setup( // Stub implementations for sytest v3mux.Handle("/events", - httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { + httputil.MakeAuthAPI("events", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "chunk": []interface{}{}, "start": "", "end": "", }} - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/initialSync", - httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { + httputil.MakeAuthAPI("initial_sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "end": "", }} - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/rooms/{roomId}/tags", @@ -1182,7 +1182,7 @@ func Setup( return *r } return GetCapabilities(req, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) // Key Backup Versions (Metadata) @@ -1363,7 +1363,7 @@ func Setup( postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadCrossSigningDeviceSignatures(req, keyAPI, device) - }) + }, httputil.WithAllowGuests()) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) @@ -1375,22 +1375,22 @@ func Setup( v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req, keyAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/claim", httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return ClaimKeys(req, keyAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index bb66cf6fc..90af9ac4d 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -186,6 +186,7 @@ func SendEvent( []*gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, + device.UserDomain(), domain, domain, txnAndSessionID, @@ -275,8 +276,14 @@ func generateSendEvent( return nil, &resErr } + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + resErr := jsonerror.InternalServerError() + return nil, &resErr + } + var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index a6a78061d..fb93d8783 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -231,6 +231,7 @@ func SendServerNotice( []*gomatrixserverlib.HeaderedEvent{ e.Headered(roomVersion), }, + device.UserDomain(), cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName, txnAndSessionID, @@ -286,6 +287,7 @@ func getSenderDevice( err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeUser, Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, OnConflict: userapi.ConflictUpdate, }, &accRes) if err != nil { @@ -295,8 +297,9 @@ func getSenderDevice( // Set the avatarurl for the user avatarRes := &userapi.PerformSetAvatarURLResponse{} if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ - Localpart: cfg.Matrix.ServerNotices.LocalPart, - AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, + Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, + AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, }, avatarRes); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") return nil, err @@ -308,6 +311,7 @@ func getSenderDevice( displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{ Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, DisplayName: cfg.Matrix.ServerNotices.DisplayName, }, displayNameRes); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") @@ -353,6 +357,7 @@ func getSenderDevice( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart, AccessToken: token, NoDeviceListUpdate: true, diff --git a/clientapi/routing/thirdparty.go b/clientapi/routing/thirdparty.go index e757cd411..7a62da449 100644 --- a/clientapi/routing/thirdparty.go +++ b/clientapi/routing/thirdparty.go @@ -36,9 +36,15 @@ func Protocols(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, dev return jsonerror.InternalServerError() } if !resp.Exists { + if protocol != "" { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The protocol is unknown."), + } + } return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The protocol is unknown."), + Code: http.StatusOK, + JSON: struct{}{}, } } if protocol != "" { diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index 4b7989ecb..971bfcad3 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -136,16 +136,17 @@ func CheckAndSave3PIDAssociation( } // Save the association in the database - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() } if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{ - ThreePID: address, - Localpart: localpart, - Medium: medium, + ThreePID: address, + Localpart: localpart, + ServerName: domain, + Medium: medium, }, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed") return jsonerror.InternalServerError() @@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation( func GetAssociated3PIDs( req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() @@ -169,7 +170,8 @@ func GetAssociated3PIDs( res := &api.QueryThreePIDsForLocalpartResponse{} err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{ - Localpart: localpart, + Localpart: localpart, + ServerName: domain, }, res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed") diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go index d3d1c22e4..62af9efa4 100644 --- a/clientapi/routing/userdirectory.go +++ b/clientapi/routing/userdirectory.go @@ -106,7 +106,7 @@ knownUsersLoop: continue } // TODO: We should probably cache/store this - fedProfile, fedErr := federation.LookupProfile(ctx, serverName, userID, "") + fedProfile, fedErr := federation.LookupProfile(ctx, localServerName, serverName, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 99fb8171d..1f294a032 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -359,8 +359,13 @@ func emit3PIDInviteEvent( return err } + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return err + } + queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes) if err != nil { return err } @@ -371,6 +376,7 @@ func emit3PIDInviteEvent( []*gomatrixserverlib.HeaderedEvent{ event.Headered(queryRes.RoomVersion), }, + device.UserDomain(), cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, diff --git a/clientapi/userutil/userutil_test.go b/clientapi/userutil/userutil_test.go index ccd6647b2..ee6bf8a01 100644 --- a/clientapi/userutil/userutil_test.go +++ b/clientapi/userutil/userutil_test.go @@ -30,7 +30,9 @@ var ( // TestGoodUserID checks that correct localpart is returned for a valid user ID. func TestGoodUserID(t *testing.T) { cfg := &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, } lp, _, err := ParseUsernameParam(goodUserID, cfg) @@ -47,7 +49,9 @@ func TestGoodUserID(t *testing.T) { // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. func TestWithLocalpartOnly(t *testing.T) { cfg := &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, } lp, _, err := ParseUsernameParam(localpart, cfg) @@ -64,7 +68,9 @@ func TestWithLocalpartOnly(t *testing.T) { // TestIncorrectDomain checks for error when there's server name mismatch. func TestIncorrectDomain(t *testing.T) { cfg := &config.Global{ - ServerName: invalidServerName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: invalidServerName, + }, } _, _, err := ParseUsernameParam(goodUserID, cfg) @@ -77,7 +83,9 @@ func TestIncorrectDomain(t *testing.T) { // TestBadUserID checks that ParseUsernameParam fails for invalid user ID func TestBadUserID(t *testing.T) { cfg := &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, } _, _, err := ParseUsernameParam(badUserID, cfg) diff --git a/cmd/dendrite-demo-pinecone/conn/client.go b/cmd/dendrite-demo-pinecone/conn/client.go index 27e18c2a3..a91434f62 100644 --- a/cmd/dendrite-demo-pinecone/conn/client.go +++ b/cmd/dendrite-demo-pinecone/conn/client.go @@ -101,9 +101,7 @@ func CreateFederationClient( base *base.BaseDendrite, s *pineconeSessions.Sessions, ) *gomatrixserverlib.FederationClient { return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.ServerName, - base.Cfg.Global.KeyID, - base.Cfg.Global.PrivateKey, + base.Cfg.Global.SigningIdentities(), gomatrixserverlib.WithTransport(createTransport(s)), ) } diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 6c719a1ee..421b17d56 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -37,6 +37,7 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" @@ -51,6 +52,7 @@ import ( pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" + pineconeEvents "github.com/matrix-org/pinecone/router/events" pineconeSessions "github.com/matrix-org/pinecone/sessions" "github.com/sirupsen/logrus" @@ -155,7 +157,12 @@ func main() { base := base.NewBaseDendrite(cfg, "Monolith") defer base.Close() // nolint: errcheck + pineconeEventChannel := make(chan pineconeEvents.Event) pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) + pRouter.EnableHopLimiting() + pRouter.EnableWakeupBroadcasts() + pRouter.Subscribe(pineconeEventChannel) + pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) pManager := pineconeConnections.NewConnectionManager(pRouter, nil) @@ -293,5 +300,33 @@ func main() { logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) }() + go func(ch <-chan pineconeEvents.Event) { + eLog := logrus.WithField("pinecone", "events") + + for event := range ch { + switch e := event.(type) { + case pineconeEvents.PeerAdded: + case pineconeEvents.PeerRemoved: + case pineconeEvents.TreeParentUpdate: + case pineconeEvents.SnakeDescUpdate: + case pineconeEvents.TreeRootAnnUpdate: + case pineconeEvents.SnakeEntryAdded: + case pineconeEvents.SnakeEntryRemoved: + case pineconeEvents.BroadcastReceived: + eLog.Info("Broadcast received from: ", e.PeerID) + + req := &api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + } + res := &api.PerformWakeupServersResponse{} + if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + } + case pineconeEvents.BandwidthReport: + default: + } + } + }(pineconeEventChannel) + base.WaitForShutdown() } diff --git a/cmd/dendrite-demo-pinecone/rooms/rooms.go b/cmd/dendrite-demo-pinecone/rooms/rooms.go index 0fafbedc3..0ac705cc1 100644 --- a/cmd/dendrite-demo-pinecone/rooms/rooms.go +++ b/cmd/dendrite-demo-pinecone/rooms/rooms.go @@ -58,13 +58,17 @@ func (p *PineconeRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { for _, k := range p.r.Peers() { list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} } - return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, list) + return bulkFetchPublicRoomsFromServers( + context.Background(), p.fedClient, + gomatrixserverlib.ServerName(p.r.PublicKey().String()), list, + ) } // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( ctx context.Context, fedClient *gomatrixserverlib.FederationClient, + origin gomatrixserverlib.ServerName, homeservers map[gomatrixserverlib.ServerName]struct{}, ) (publicRooms []gomatrixserverlib.PublicRoom) { limit := 200 @@ -82,7 +86,7 @@ func bulkFetchPublicRoomsFromServers( go func(homeserverDomain gomatrixserverlib.ServerName) { defer wg.Done() util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") - fres, err := fedClient.GetPublicRooms(reqctx, homeserverDomain, int(limit), "", false, "") + fres, err := fedClient.GetPublicRooms(reqctx, origin, homeserverDomain, int(limit), "", false, "") if err != nil { util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( "bulkFetchPublicRoomsFromServers: failed to query hs", diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/client.go b/cmd/dendrite-demo-yggdrasil/yggconn/client.go index 358d3725e..41a9ec123 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/client.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/client.go @@ -55,8 +55,7 @@ func (n *Node) CreateFederationClient( }, ) return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, - base.Cfg.Global.PrivateKey, + base.Cfg.Global.SigningIdentities(), gomatrixserverlib.WithTransport(tr), ) } diff --git a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go index 402b86ed3..0de64755e 100644 --- a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go +++ b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go @@ -43,13 +43,18 @@ func NewYggdrasilRoomProvider( } func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { - return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, p.node.KnownNodes()) + return bulkFetchPublicRoomsFromServers( + context.Background(), p.fedClient, + gomatrixserverlib.ServerName(p.node.DerivedServerName()), + p.node.KnownNodes(), + ) } // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( ctx context.Context, fedClient *gomatrixserverlib.FederationClient, + origin gomatrixserverlib.ServerName, homeservers []gomatrixserverlib.ServerName, ) (publicRooms []gomatrixserverlib.PublicRoom) { limit := 200 @@ -66,7 +71,7 @@ func bulkFetchPublicRoomsFromServers( go func(homeserverDomain gomatrixserverlib.ServerName) { defer wg.Done() util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") - fres, err := fedClient.GetPublicRooms(ctx, homeserverDomain, int(limit), "", false, "") + fres, err := fedClient.GetPublicRooms(ctx, origin, homeserverDomain, int(limit), "", false, "") if err != nil { util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn( "bulkFetchPublicRoomsFromServers: failed to query hs", diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index dce22472d..75446d18c 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -38,6 +38,7 @@ var ( flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github") flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.") flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done") + flagSqlite = flag.Bool("sqlite", false, "Test SQLite instead of PostgreSQL") alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+") ) @@ -49,7 +50,7 @@ const HEAD = "HEAD" // due to the error: // When using COPY with more than one source file, the destination must be a directory and end with a / // We need to run a postgres anyway, so use the dockerfile associated with Complement instead. -const Dockerfile = `FROM golang:1.18-stretch as build +const DockerfilePostgreSQL = `FROM golang:1.18-stretch as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build @@ -92,6 +93,42 @@ ENV SERVER_NAME=localhost EXPOSE 8008 8448 CMD /build/run_dendrite.sh ` +const DockerfileSQLite = `FROM golang:1.18-stretch as build +RUN apt-get update && apt-get install -y postgresql +WORKDIR /build + +# Copy the build context to the repo as this is the right dendrite code. This is different to the +# Complement Dockerfile which wgets a branch. +COPY . . + +RUN go build ./cmd/dendrite-monolith-server +RUN go build ./cmd/generate-keys +RUN go build ./cmd/generate-config +RUN ./generate-config --ci > dendrite.yaml +RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key + +# Make sure the SQLite databases are in a persistent location, we're already mapping +# the postgresql folder so let's just use that for simplicity +RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postgresql\/9.6\/main\/%g" dendrite.yaml + +# This entry script starts postgres, waits for it to be up then starts dendrite +RUN echo '\ +sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\ +PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\ +./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\ +' > run_dendrite.sh && chmod +x run_dendrite.sh + +ENV SERVER_NAME=localhost +EXPOSE 8008 8448 +CMD /build/run_dendrite.sh ` + +func dockerfile() []byte { + if *flagSqlite { + return []byte(DockerfileSQLite) + } + return []byte(DockerfilePostgreSQL) +} + const dendriteUpgradeTestLabel = "dendrite_upgrade_test" // downloadArchive downloads an arbitrary github archive of the form: @@ -150,7 +187,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, if branchOrTagName == HEAD && *flagHead != "" { log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead) // add top level Dockerfile - err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), []byte(Dockerfile), os.ModePerm) + err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), dockerfile(), os.ModePerm) if err != nil { return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err) } @@ -166,7 +203,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, // pull an archive, this contains a top-level directory which screws with the build context // which we need to fix up post download u := fmt.Sprintf("https://github.com/matrix-org/dendrite/archive/%s.tar.gz", branchOrTagName) - tarball, err = downloadArchive(httpClient, tmpDir, u, []byte(Dockerfile)) + tarball, err = downloadArchive(httpClient, tmpDir, u, dockerfile()) if err != nil { return "", fmt.Errorf("failed to download archive %s: %w", u, err) } @@ -367,7 +404,8 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string) // hit /versions to check it is up var lastErr error for i := 0; i < 500; i++ { - res, err := http.Get(versionsURL) + var res *http.Response + res, err = http.Get(versionsURL) if err != nil { lastErr = fmt.Errorf("GET %s => error: %s", versionsURL, err) time.Sleep(50 * time.Millisecond) @@ -381,18 +419,22 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string) lastErr = nil break } - if lastErr != nil { - logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{ - ShowStdout: true, - ShowStderr: true, - }) - // ignore errors when cannot get logs, it's just for debugging anyways - if err == nil { - logbody, err := io.ReadAll(logs) - if err == nil { - log.Printf("Container logs:\n\n%s\n\n", string(logbody)) + logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{ + ShowStdout: true, + ShowStderr: true, + Follow: true, + }) + // ignore errors when cannot get logs, it's just for debugging anyways + if err == nil { + go func() { + for { + if body, err := io.ReadAll(logs); err == nil && len(body) > 0 { + log.Printf("%s: %s", version, string(body)) + } else { + return + } } - } + }() } return baseURL, containerID, lastErr } diff --git a/cmd/furl/main.go b/cmd/furl/main.go index f59f9c8ce..b208ba868 100644 --- a/cmd/furl/main.go +++ b/cmd/furl/main.go @@ -48,10 +48,15 @@ func main() { panic("unexpected key block") } + serverName := gomatrixserverlib.ServerName(*requestFrom) client := gomatrixserverlib.NewFederationClient( - gomatrixserverlib.ServerName(*requestFrom), - gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), - privateKey, + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: serverName, + KeyID: gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), + PrivateKey: privateKey, + }, + }, ) u, err := url.Parse(flag.Arg(0)) @@ -79,6 +84,7 @@ func main() { req := gomatrixserverlib.NewFederationRequest( method, + serverName, gomatrixserverlib.ServerName(u.Host), u.RequestURI(), ) diff --git a/docs/installation/2_domainname.md b/docs/installation/2_domainname.md index e7b3495f7..545a2daf6 100644 --- a/docs/installation/2_domainname.md +++ b/docs/installation/2_domainname.md @@ -90,7 +90,7 @@ For example, this can be done with the following Caddy config: handle /.well-known/matrix/server { header Content-Type application/json header Access-Control-Allow-Origin * - respond `"m.server": "matrix.example.com:8448"` + respond `{"m.server": "matrix.example.com:8448"}` } handle /.well-known/matrix/client { diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 362333fc9..50d0339e4 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -21,8 +21,8 @@ type FederationInternalAPI interface { QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) - MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. PerformBroadcastEDU( @@ -30,6 +30,12 @@ type FederationInternalAPI interface { request *PerformBroadcastEDURequest, response *PerformBroadcastEDUResponse, ) error + + PerformWakeupServers( + ctx context.Context, + request *PerformWakeupServersRequest, + response *PerformWakeupServersResponse, + ) error } type ClientFederationAPI interface { @@ -60,18 +66,18 @@ type RoomserverFederationAPI interface { // containing only the server names (without information for membership events). // The response will include this server if they are joined to the room. QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error - GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) + GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError type KeyserverFederationAPI interface { - GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) - ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) - QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) + GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) + ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) + QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) } // an interface for gmsl.FederationClient - contains functions called by federationapi only. @@ -80,28 +86,28 @@ type FederationClient interface { SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) // Perform operations - LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) - Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) - MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) - SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) - MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) - SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) - SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) + LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) + Peek(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) + MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) + SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) + MakeLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) + SendLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) + SendInviteV2(ctx context.Context, origin, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) - ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) - QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) - Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) - MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) + ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) + QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) + Backfill(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) + MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) - ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) - LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) - LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) + ExchangeThirdPartyInvite(ctx context.Context, origin, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) + LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) + LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. @@ -198,8 +204,9 @@ type PerformInviteResponse struct { // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomRequest struct { - RoomID string `json:"room_id"` - ExcludeSelf bool `json:"exclude_self"` + RoomID string `json:"room_id"` + ExcludeSelf bool `json:"exclude_self"` + ExcludeBlacklisted bool `json:"exclude_blacklisted"` } // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames @@ -213,6 +220,13 @@ type PerformBroadcastEDURequest struct { type PerformBroadcastEDUResponse struct { } +type PerformWakeupServersRequest struct { + ServerNames []gomatrixserverlib.ServerName `json:"server_names"` +} + +type PerformWakeupServersResponse struct { +} + type InputPublicKeysRequest struct { Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"` } diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 7d1ae0f81..601257d4b 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { sentry.CaptureException(err) logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") @@ -189,7 +189,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { return true } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { sentry.CaptureException(err) logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index 153fc40b5..29b16f373 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -111,7 +111,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg } // send this presence to all servers who share rooms with this user. - joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { log.WithError(err).Error("failed to get joined hosts") return true diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index a42733628..d16af6626 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -18,6 +18,10 @@ import ( "context" "encoding/json" "fmt" + "strconv" + "time" + + syncAPITypes "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -34,14 +38,16 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { - ctx context.Context - cfg *config.FederationAPI - rsAPI api.FederationRoomserverAPI - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - topic string + ctx context.Context + cfg *config.FederationAPI + rsAPI api.FederationRoomserverAPI + jetstream nats.JetStreamContext + natsClient *nats.Conn + durable string + db storage.Database + queues *queue.OutgoingQueues + topic string + topicPresence string } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -49,19 +55,22 @@ func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.FederationAPI, js nats.JetStreamContext, + natsClient *nats.Conn, queues *queue.OutgoingQueues, store storage.Database, rsAPI api.FederationRoomserverAPI, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ - ctx: process.Context(), - cfg: cfg, - jetstream: js, - db: store, - queues: queues, - rsAPI: rsAPI, - durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), + ctx: process.Context(), + cfg: cfg, + jetstream: js, + natsClient: natsClient, + db: store, + queues: queues, + rsAPI: rsAPI, + durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), + topicPresence: cfg.Matrix.JetStream.Prefixed(jetstream.RequestPresence), } } @@ -146,6 +155,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee // processMessage updates the list of currently joined hosts in the room // and then sends the event to the hosts that were joined before the event. func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error { + addsStateEvents, missingEventIDs := ore.NeededStateEventIDs() // Ask the roomserver and add in the rest of the results into the set. @@ -184,6 +194,14 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew return err } + // If we added new hosts, inform them about our known presence events for this room + if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { + membership, _ := ore.Event.Membership() + if membership == gomatrixserverlib.Join { + s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) + } + } + if oldJoinedHosts == nil { // This means that there is nothing to update as this is a duplicate // message. @@ -213,6 +231,76 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew ) } +func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) { + joined := make([]gomatrixserverlib.ServerName, len(addedJoined)) + for _, added := range addedJoined { + joined = append(joined, added.ServerName) + } + + // get our locally joined users + var queryRes api.QueryMembershipsForRoomResponse + err := s.rsAPI.QueryMembershipsForRoom(s.ctx, &api.QueryMembershipsForRoomRequest{ + JoinedOnly: true, + LocalOnly: true, + RoomID: roomID, + }, &queryRes) + if err != nil { + log.WithError(err).Error("failed to calculate joined rooms for user") + return + } + + // send every presence we know about to the remote server + content := types.Presence{} + for _, ev := range queryRes.JoinEvents { + msg := nats.NewMsg(s.topicPresence) + msg.Header.Set(jetstream.UserID, ev.Sender) + + var presence *nats.Msg + presence, err = s.natsClient.RequestMsg(msg, time.Second*10) + if err != nil { + log.WithError(err).Errorf("unable to get presence") + continue + } + + statusMsg := presence.Header.Get("status_msg") + e := presence.Header.Get("error") + if e != "" { + continue + } + var lastActive int + lastActive, err = strconv.Atoi(presence.Header.Get("last_active_ts")) + if err != nil { + continue + } + + p := syncAPITypes.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(lastActive)} + + content.Push = append(content.Push, types.PresenceContent{ + CurrentlyActive: p.CurrentlyActive(), + LastActiveAgo: p.LastActiveAgo(), + Presence: presence.Header.Get("presence"), + StatusMsg: &statusMsg, + UserID: ev.Sender, + }) + } + + if len(content.Push) == 0 { + return + } + + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MPresence, + Origin: string(s.cfg.Matrix.ServerName), + } + if edu.Content, err = json.Marshal(content); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return + } + if err := s.queues.SendEDU(edu, s.cfg.Matrix.ServerName, joined); err != nil { + log.WithError(err).Error("failed to send EDU") + } +} + // joinedHostsAtEvent works out a list of matrix servers that were joined to // the room at the event (including peeking ones) // It is important to use the state at the event for sending messages because: diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 202da6c51..854251220 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -118,21 +118,19 @@ func NewInternalAPI( stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) - js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + + signingInfo := base.Cfg.Global.SigningIdentities() queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, cfg.Matrix.DisableFederation, cfg.Matrix.ServerName, federation, rsAPI, &stats, - &queue.SigningInfo{ - KeyID: cfg.Matrix.KeyID, - PrivateKey: cfg.Matrix.PrivateKey, - ServerName: cfg.Matrix.ServerName, - }, + signingInfo, ) rsConsumer := consumers.NewOutputRoomEventConsumer( - base.ProcessContext, cfg, js, queues, + base.ProcessContext, cfg, js, nats, queues, federationDB, rsAPI, ) if err = rsConsumer.Start(); err != nil { diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 7ccc02f76..cc03cdece 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -104,7 +104,7 @@ func TestMain(m *testing.M) { // Create the federation client. s.fedclient = gomatrixserverlib.NewFederationClient( - s.config.Matrix.ServerName, serverKeyID, testPriv, + s.config.Matrix.SigningIdentities(), gomatrixserverlib.WithTransport(transport), ) @@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err } // Get the keys and JSON-ify them. - keys := routing.LocalKeys(s.config) + keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host)) body, err := json.MarshalIndent(keys.JSON, "", " ") if err != nil { return nil, err diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index c37bc87c2..68a06a033 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -103,7 +103,7 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv return keys, nil } -func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { +func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { for _, r := range f.allowJoins { if r.ID == roomID { res.RoomVersion = r.Version @@ -127,7 +127,7 @@ func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName } return } -func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { +func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { @@ -283,7 +283,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) fedCli := gomatrixserverlib.NewFederationClient( - serverName, cfg.Global.KeyID, cfg.Global.PrivateKey, + cfg.Global.SigningIdentities(), gomatrixserverlib.WithSkipVerify(true), ) @@ -326,7 +326,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { t.Errorf("failed to create invite v2 request: %s", err) continue } - _, err = fedCli.SendInviteV2(context.Background(), serverName, invReq) + _, err = fedCli.SendInviteV2(context.Background(), cfg.Global.ServerName, serverName, invReq) if err == nil { t.Errorf("expected an error, got none") continue diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index 2636b7fa0..db6348ec1 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -11,13 +11,13 @@ import ( // client. func (a *FederationInternalAPI) GetEventAuth( - ctx context.Context, s gomatrixserverlib.ServerName, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (res gomatrixserverlib.RespEventAuth, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetEventAuth(ctx, s, roomVersion, roomID, eventID) + return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) }) if err != nil { return gomatrixserverlib.RespEventAuth{}, err @@ -26,12 +26,12 @@ func (a *FederationInternalAPI) GetEventAuth( } func (a *FederationInternalAPI) GetUserDevices( - ctx context.Context, s gomatrixserverlib.ServerName, userID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetUserDevices(ctx, s, userID) + return a.federation.GetUserDevices(ctx, origin, s, userID) }) if err != nil { return gomatrixserverlib.RespUserDevices{}, err @@ -40,12 +40,12 @@ func (a *FederationInternalAPI) GetUserDevices( } func (a *FederationInternalAPI) ClaimKeys( - ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.ClaimKeys(ctx, s, oneTimeKeys) + return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) }) if err != nil { return gomatrixserverlib.RespClaimKeys{}, err @@ -54,10 +54,10 @@ func (a *FederationInternalAPI) ClaimKeys( } func (a *FederationInternalAPI) QueryKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { - return a.federation.QueryKeys(ctx, s, keys) + return a.federation.QueryKeys(ctx, origin, s, keys) }) if err != nil { return gomatrixserverlib.RespQueryKeys{}, err @@ -66,12 +66,12 @@ func (a *FederationInternalAPI) QueryKeys( } func (a *FederationInternalAPI) Backfill( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) + return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) }) if err != nil { return gomatrixserverlib.Transaction{}, err @@ -80,12 +80,12 @@ func (a *FederationInternalAPI) Backfill( } func (a *FederationInternalAPI) LookupState( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespState, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) + return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) }) if err != nil { return gomatrixserverlib.RespState{}, err @@ -94,12 +94,12 @@ func (a *FederationInternalAPI) LookupState( } func (a *FederationInternalAPI) LookupStateIDs( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, ) (res gomatrixserverlib.RespStateIDs, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupStateIDs(ctx, s, roomID, eventID) + return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) }) if err != nil { return gomatrixserverlib.RespStateIDs{}, err @@ -108,13 +108,13 @@ func (a *FederationInternalAPI) LookupStateIDs( } func (a *FederationInternalAPI) LookupMissingEvents( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupMissingEvents(ctx, s, roomID, missing, roomVersion) + return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) }) if err != nil { return gomatrixserverlib.RespMissingEvents{}, err @@ -123,12 +123,12 @@ func (a *FederationInternalAPI) LookupMissingEvents( } func (a *FederationInternalAPI) GetEvent( - ctx context.Context, s gomatrixserverlib.ServerName, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetEvent(ctx, s, eventID) + return a.federation.GetEvent(ctx, origin, s, eventID) }) if err != nil { return gomatrixserverlib.Transaction{}, err @@ -151,13 +151,13 @@ func (a *FederationInternalAPI) LookupServerKeys( } func (a *FederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) + return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion) }) if err != nil { return res, err @@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) + return a.federation.MSC2946Spaces(ctx, origin, s, roomID, suggestedOnly) }) if err != nil { return res, err diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 1b61ec711..d86d07e03 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -26,6 +26,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( ) (err error) { dir, err := r.federation.LookupRoomAlias( ctx, + r.cfg.Matrix.ServerName, request.ServerName, request.RoomAlias, ) @@ -143,10 +144,16 @@ func (r *FederationInternalAPI) performJoinUsingServer( supportedVersions []gomatrixserverlib.RoomVersion, unsigned map[string]interface{}, ) error { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) + if err != nil { + return err + } + // Try to perform a make_join using the information supplied in the // request. respMakeJoin, err := r.federation.MakeJoin( ctx, + origin, serverName, roomID, userID, @@ -192,7 +199,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // Build the join event. event, err := respMakeJoin.JoinEvent.Build( time.Now(), - r.cfg.Matrix.ServerName, + origin, r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, respMakeJoin.RoomVersion, @@ -204,6 +211,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // Try to perform a send_join using the newly built event. respSendJoin, err := r.federation.SendJoin( context.Background(), + origin, serverName, event, ) @@ -246,7 +254,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( respMakeJoin.RoomVersion, r.keyRing, event, - federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), + federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName), ) if err != nil { return fmt.Errorf("respSendJoin.Check: %w", err) @@ -281,6 +289,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( if err = roomserverAPI.SendEventWithState( context.Background(), r.rsAPI, + origin, roomserverAPI.KindNew, respState, event.Headered(respMakeJoin.RoomVersion), @@ -427,6 +436,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // request. respPeek, err := r.federation.Peek( ctx, + r.cfg.Matrix.ServerName, serverName, roomID, peekID, @@ -453,7 +463,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) + authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName)) if err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) } @@ -475,7 +485,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // logrus.Warnf("got respPeek %#v", respPeek) // Send the newly returned state to the roomserver to update our local view. if err = roomserverAPI.SendEventWithState( - ctx, r.rsAPI, + ctx, r.rsAPI, r.cfg.Matrix.ServerName, roomserverAPI.KindNew, &respState, respPeek.LatestEvent.Headered(respPeek.RoomVersion), @@ -495,6 +505,11 @@ func (r *FederationInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) (err error) { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID) + if err != nil { + return err + } + // Deduplicate the server names we were provided. util.SortAndUnique(request.ServerNames) @@ -505,6 +520,7 @@ func (r *FederationInternalAPI) PerformLeave( // request. respMakeLeave, err := r.federation.MakeLeave( ctx, + origin, serverName, request.RoomID, request.UserID, @@ -546,7 +562,7 @@ func (r *FederationInternalAPI) PerformLeave( // Build the leave event. event, err := respMakeLeave.LeaveEvent.Build( time.Now(), - r.cfg.Matrix.ServerName, + origin, r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, respMakeLeave.RoomVersion, @@ -559,6 +575,7 @@ func (r *FederationInternalAPI) PerformLeave( // Try to perform a send_leave using the newly built event. err = r.federation.SendLeave( ctx, + origin, serverName, event, ) @@ -585,6 +602,11 @@ func (r *FederationInternalAPI) PerformInvite( request *api.PerformInviteRequest, response *api.PerformInviteResponse, ) (err error) { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.Event.Sender()) + if err != nil { + return err + } + if request.Event.StateKey() == nil { return errors.New("invite must be a state event") } @@ -607,7 +629,7 @@ func (r *FederationInternalAPI) PerformInvite( return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } - inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq) + inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) if err != nil { return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } @@ -648,9 +670,23 @@ func (r *FederationInternalAPI) PerformBroadcastEDU( return nil } +// PerformWakeupServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) PerformWakeupServers( + ctx context.Context, + request *api.PerformWakeupServersRequest, + response *api.PerformWakeupServersResponse, +) (err error) { + r.MarkServersAlive(request.ServerNames) + return nil +} + func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { for _, srv := range destinations { - _ = r.db.RemoveServerFromBlacklist(srv) + // Check the statistics cache for the blacklist status to prevent hitting + // the database unnecessarily. + if r.queues.IsServerBlacklisted(srv) { + _ = r.db.RemoveServerFromBlacklist(srv) + } r.queues.RetryServer(srv) } } @@ -708,7 +744,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder // FederatedAuthProvider is an auth chain provider which fetches events from the server provided func federatedAuthProvider( ctx context.Context, federation api.FederationClient, - keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, + keyRing gomatrixserverlib.JSONVerifier, origin, server gomatrixserverlib.ServerName, ) gomatrixserverlib.AuthChainProvider { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. @@ -738,7 +774,7 @@ func federatedAuthProvider( // Try to retrieve the event from the server that sent us the send // join response. - tx, txerr := federation.GetEvent(ctx, server, eventID) + tx, txerr := federation.GetEvent(ctx, origin, server, eventID) if txerr != nil { return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr) } diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index b0a76eeb7..688afa8ea 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom( request *api.QueryJoinedHostServerNamesInRoomRequest, response *api.QueryJoinedHostServerNamesInRoomResponse, ) (err error) { - joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf) + joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf, request.ExcludeBlacklisted) if err != nil { return } diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 812d3c6da..6eefdc7cd 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -23,6 +23,7 @@ const ( FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest" FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" + FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" @@ -150,18 +151,32 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU( ) } +// Handle an instruction to remove the respective servers from being blacklisted. +func (h *httpFederationInternalAPI) PerformWakeupServers( + ctx context.Context, + request *api.PerformWakeupServersRequest, + response *api.PerformWakeupServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformWakeupServers", h.federationAPIURL+FederationAPIPerformWakeupServers, + h.httpClient, ctx, request, response, + ) +} + type getUserDevices struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName UserID string } func (h *httpFederationInternalAPI) GetUserDevices( - ctx context.Context, s gomatrixserverlib.ServerName, userID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, ctx, &getUserDevices{ S: s, + Origin: origin, UserID: userID, }, ) @@ -169,52 +184,58 @@ func (h *httpFederationInternalAPI) GetUserDevices( type claimKeys struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName OneTimeKeys map[string]map[string]string } func (h *httpFederationInternalAPI) ClaimKeys( - ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, ctx, &claimKeys{ S: s, + Origin: origin, OneTimeKeys: oneTimeKeys, }, ) } type queryKeys struct { - S gomatrixserverlib.ServerName - Keys map[string][]string + S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName + Keys map[string][]string } func (h *httpFederationInternalAPI) QueryKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, ctx, &queryKeys{ - S: s, - Keys: keys, + S: s, + Origin: origin, + Keys: keys, }, ) } type backfill struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string Limit int EventIDs []string } func (h *httpFederationInternalAPI) Backfill( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (gomatrixserverlib.Transaction, error) { return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, ctx, &backfill{ S: s, + Origin: origin, RoomID: roomID, Limit: limit, EventIDs: eventIDs, @@ -224,18 +245,20 @@ func (h *httpFederationInternalAPI) Backfill( type lookupState struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string EventID string RoomVersion gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) LookupState( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (gomatrixserverlib.RespState, error) { return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, ctx, &lookupState{ S: s, + Origin: origin, RoomID: roomID, EventID: eventID, RoomVersion: roomVersion, @@ -245,17 +268,19 @@ func (h *httpFederationInternalAPI) LookupState( type lookupStateIDs struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string EventID string } func (h *httpFederationInternalAPI) LookupStateIDs( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, ) (gomatrixserverlib.RespStateIDs, error) { return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, ctx, &lookupStateIDs{ S: s, + Origin: origin, RoomID: roomID, EventID: eventID, }, @@ -264,19 +289,21 @@ func (h *httpFederationInternalAPI) LookupStateIDs( type lookupMissingEvents struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string Missing gomatrixserverlib.MissingEvents RoomVersion gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) LookupMissingEvents( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, ctx, &lookupMissingEvents{ S: s, + Origin: origin, RoomID: roomID, Missing: missing, RoomVersion: roomVersion, @@ -286,16 +313,18 @@ func (h *httpFederationInternalAPI) LookupMissingEvents( type getEvent struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName EventID string } func (h *httpFederationInternalAPI) GetEvent( - ctx context.Context, s gomatrixserverlib.ServerName, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, ) (gomatrixserverlib.Transaction, error) { return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, ctx, &getEvent{ S: s, + Origin: origin, EventID: eventID, }, ) @@ -303,19 +332,21 @@ func (h *httpFederationInternalAPI) GetEvent( type getEventAuth struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomVersion gomatrixserverlib.RoomVersion RoomID string EventID string } func (h *httpFederationInternalAPI) GetEventAuth( - ctx context.Context, s gomatrixserverlib.ServerName, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (gomatrixserverlib.RespEventAuth, error) { return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, ctx, &getEventAuth{ S: s, + Origin: origin, RoomVersion: roomVersion, RoomID: roomID, EventID: eventID, @@ -351,18 +382,20 @@ func (h *httpFederationInternalAPI) LookupServerKeys( type eventRelationships struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName Req gomatrixserverlib.MSC2836EventRelationshipsRequest RoomVer gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, ctx, &eventRelationships{ S: s, + Origin: origin, Req: r, RoomVer: roomVersion, }, @@ -371,17 +404,19 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships( type spacesReq struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName SuggestedOnly bool RoomID string } func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, + ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, ctx, &spacesReq{ S: dst, + Origin: origin, SuggestedOnly: suggestedOnly, RoomID: roomID, }, diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 7aa0e4801..21a070392 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -43,6 +43,11 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), ) + internalAPIMux.Handle( + FederationAPIPerformWakeupServers, + httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers), + ) + internalAPIMux.Handle( FederationAPIPerformJoinRequestPath, httputil.MakeInternalRPCAPI( @@ -59,7 +64,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetUserDevices", func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { - res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID) + res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID) return &res, federationClientError(err) }, ), @@ -70,7 +75,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIClaimKeys", func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { - res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys) + res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys) return &res, federationClientError(err) }, ), @@ -81,7 +86,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIQueryKeys", func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { - res, err := intAPI.QueryKeys(ctx, req.S, req.Keys) + res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys) return &res, federationClientError(err) }, ), @@ -92,7 +97,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIBackfill", func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs) + res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs) return &res, federationClientError(err) }, ), @@ -103,7 +108,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupState", func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { - res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion) + res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion) return &res, federationClientError(err) }, ), @@ -114,7 +119,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupStateIDs", func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { - res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID) + res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID) return &res, federationClientError(err) }, ), @@ -125,7 +130,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupMissingEvents", func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { - res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion) + res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion) return &res, federationClientError(err) }, ), @@ -136,7 +141,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetEvent", func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.GetEvent(ctx, req.S, req.EventID) + res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID) return &res, federationClientError(err) }, ), @@ -147,7 +152,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetEventAuth", func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { - res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID) + res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID) return &res, federationClientError(err) }, ), @@ -174,7 +179,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIMSC2836EventRelationships", func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { - res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer) + res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer) return &res, federationClientError(err) }, ), @@ -185,7 +190,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIMSC2946SpacesSummary", func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { - res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly) + res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly) return &res, federationClientError(err) }, ), diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index a638a5742..a4a87fe99 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -50,7 +50,7 @@ type destinationQueue struct { queues *OutgoingQueues db storage.Database process *process.ProcessContext - signing *SigningInfo + signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity rsAPI api.FederationRoomserverAPI client fedapi.FederationClient // federation client origin gomatrixserverlib.ServerName // origin of requests @@ -141,23 +141,44 @@ func (oq *destinationQueue) handleBackoffNotifier() { } } +// wakeQueueIfEventsPending calls wakeQueueAndNotify only if there are +// pending events or if forceWakeup is true. This prevents starting the +// queue unnecessarily. +func (oq *destinationQueue) wakeQueueIfEventsPending(forceWakeup bool) { + eventsPending := func() bool { + oq.pendingMutex.Lock() + defer oq.pendingMutex.Unlock() + return len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 + } + + // NOTE : Only wakeup and notify the queue if there are pending events + // or if forceWakeup is true. Otherwise there is no reason to start the + // queue goroutine and waste resources. + if forceWakeup || eventsPending() { + logrus.Info("Starting queue due to pending events or forceWakeup") + oq.wakeQueueAndNotify() + } +} + // wakeQueueAndNotify ensures the destination queue is running and notifies it // that there is pending work. func (oq *destinationQueue) wakeQueueAndNotify() { - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() + // NOTE : Send notification before waking queue to prevent a race + // where the queue was running and stops due to a timeout in between + // checking it and sending the notification. // Notify the queue that there are events ready to send. select { case oq.notify <- struct{}{}: default: } + + // Wake up the queue if it's asleep. + oq.wakeQueueIfNeeded() } // wakeQueueIfNeeded will wake up the destination queue if it is -// not already running. If it is running but it is backing off -// then we will interrupt the backoff, causing any federation -// requests to retry. +// not already running. func (oq *destinationQueue) wakeQueueIfNeeded() { // Clear the backingOff flag and update the backoff metrics if it was set. if oq.backingOff.CompareAndSwap(true, false) { diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index b5d0552c6..75b1b36be 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -15,7 +15,6 @@ package queue import ( - "crypto/ed25519" "encoding/json" "fmt" "sync" @@ -46,7 +45,7 @@ type OutgoingQueues struct { origin gomatrixserverlib.ServerName client fedapi.FederationClient statistics *statistics.Statistics - signing *SigningInfo + signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity queuesMutex sync.Mutex // protects the below queues map[gomatrixserverlib.ServerName]*destinationQueue } @@ -91,7 +90,7 @@ func NewOutgoingQueues( client fedapi.FederationClient, rsAPI api.FederationRoomserverAPI, statistics *statistics.Statistics, - signing *SigningInfo, + signing []*gomatrixserverlib.SigningIdentity, ) *OutgoingQueues { queues := &OutgoingQueues{ disabled: disabled, @@ -101,9 +100,12 @@ func NewOutgoingQueues( origin: origin, client: client, statistics: statistics, - signing: signing, + signing: map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity{}, queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, } + for _, identity := range signing { + queues.signing[identity.ServerName] = identity + } // Look up which servers we have pending items for and then rehydrate those queues. if !disabled { serverNames := map[gomatrixserverlib.ServerName]struct{}{} @@ -135,14 +137,6 @@ func NewOutgoingQueues( return queues } -// TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables -// around together -type SigningInfo struct { - ServerName gomatrixserverlib.ServerName - KeyID gomatrixserverlib.KeyID - PrivateKey ed25519.PrivateKey -} - type queuedPDU struct { receipt *shared.Receipt pdu *gomatrixserverlib.HeaderedEvent @@ -199,11 +193,10 @@ func (oqs *OutgoingQueues) SendEvent( log.Trace("Federation is disabled, not sending event") return nil } - if origin != oqs.origin { - // TODO: Support virtual hosting; gh issue #577. + if _, ok := oqs.signing[origin]; !ok { return fmt.Errorf( - "sendevent: unexpected server to send as: got %q expected %q", - origin, oqs.origin, + "sendevent: unexpected server to send as %q", + origin, ) } @@ -214,7 +207,9 @@ func (oqs *OutgoingQueues) SendEvent( destmap[d] = struct{}{} } delete(destmap, oqs.origin) - delete(destmap, oqs.signing.ServerName) + for local := range oqs.signing { + delete(destmap, local) + } // Check if any of the destinations are prohibited by server ACLs. for destination := range destmap { @@ -288,11 +283,10 @@ func (oqs *OutgoingQueues) SendEDU( log.Trace("Federation is disabled, not sending EDU") return nil } - if origin != oqs.origin { - // TODO: Support virtual hosting; gh issue #577. + if _, ok := oqs.signing[origin]; !ok { return fmt.Errorf( - "sendevent: unexpected server to send as: got %q expected %q", - origin, oqs.origin, + "sendevent: unexpected server to send as %q", + origin, ) } @@ -303,7 +297,9 @@ func (oqs *OutgoingQueues) SendEDU( destmap[d] = struct{}{} } delete(destmap, oqs.origin) - delete(destmap, oqs.signing.ServerName) + for local := range oqs.signing { + delete(destmap, local) + } // There is absolutely no guarantee that the EDU will have a room_id // field, as it is not required by the spec. However, if it *does* @@ -378,14 +374,24 @@ func (oqs *OutgoingQueues) SendEDU( return nil } +// IsServerBlacklisted returns whether or not the provided server is currently +// blacklisted. +func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool { + return oqs.statistics.ForServer(srv).Blacklisted() +} + // RetryServer attempts to resend events to the given server if we had given up. func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { if oqs.disabled { return } - oqs.statistics.ForServer(srv).RemoveBlacklist() + + serverStatistics := oqs.statistics.ForServer(srv) + forceWakeup := serverStatistics.Blacklisted() + serverStatistics.RemoveBlacklist() + serverStatistics.ClearBackoff() + if queue := oqs.getQueue(srv); queue != nil { - queue.statistics.ClearBackoff() - queue.wakeQueueIfNeeded() + queue.wakeQueueIfEventsPending(forceWakeup) } } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 7ef4646f7..b2ec4b836 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -350,10 +350,12 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T } rs := &stubFederationRoomServerAPI{} stats := statistics.NewStatistics(db, failuresUntilBlacklist) - signingInfo := &SigningInfo{ - KeyID: "ed21019:auto", - PrivateKey: test.PrivateKeyA, - ServerName: "localhost", + signingInfo := []*gomatrixserverlib.SigningIdentity{ + { + KeyID: "ed21019:auto", + PrivateKey: test.PrivateKeyA, + ServerName: "localhost", + }, } queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 7b9ca66f6..272f5e9d8 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -82,7 +82,8 @@ func Backfill( BackwardsExtremities: map[string][]string{ "": eIDs, }, - ServerName: request.Origin(), + ServerName: request.Origin(), + VirtualHost: request.Destination(), } if req.Limit, err = strconv.Atoi(limit); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") @@ -123,7 +124,7 @@ func Backfill( } txn := gomatrixserverlib.Transaction{ - Origin: cfg.Matrix.ServerName, + Origin: request.Destination(), PDUs: eventJSONs, OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 504204504..f424fcacd 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -140,6 +140,21 @@ func processInvite( } } + if event.StateKey() == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The invite event has no state key"), + } + } + + _, domain, err := cfg.Matrix.SplitLocalID('@', *event.StateKey()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)), + } + } + // Check that the event is signed by the server sending the request. redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) if err != nil { @@ -175,7 +190,7 @@ func processInvite( // Sign the event so that other servers will know that we have received the invite. signedEvent := event.Sign( - string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, + string(domain), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, ) // Add the invite event to the roomserver. diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 74d065e59..03809df75 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -131,10 +131,20 @@ func MakeJoin( return jsonerror.InternalServerError() } + identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound( + fmt.Sprintf("Server name %q does not exist", request.Destination()), + ), + } + } + queryRes := api.QueryLatestEventsAndStateResponse{ RoomVersion: verRes.RoomVersion, } - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 8931830f3..dc262cfde 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -134,36 +134,53 @@ func ClaimOneTimeKeys( // LocalKeys returns the local keys for the server. // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys -func LocalKeys(cfg *config.FederationAPI) util.JSONResponse { - keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) +func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse { + keys, err := localKeys(cfg, serverName) if err != nil { - return util.ErrorResponse(err) + return util.MessageResponse(http.StatusNotFound, err.Error()) } return util.JSONResponse{Code: http.StatusOK, JSON: keys} } -func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { +func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys - - keys.ServerName = cfg.Matrix.ServerName - keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil) - - publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) - - keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ - cfg.Matrix.KeyID: { - Key: gomatrixserverlib.Base64Bytes(publicKey), - }, - } - - keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} - for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys { - keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{ - VerifyKey: gomatrixserverlib.VerifyKey{ - Key: oldVerifyKey.PublicKey, - }, - ExpiredTS: oldVerifyKey.ExpiredAt, + var identity *gomatrixserverlib.SigningIdentity + var err error + if virtualHost := cfg.Matrix.VirtualHostForHTTPHost(serverName); virtualHost == nil { + if identity, err = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName); err != nil { + return nil, err } + publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) + keys.ServerName = cfg.Matrix.ServerName + keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) + keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ + cfg.Matrix.KeyID: { + Key: gomatrixserverlib.Base64Bytes(publicKey), + }, + } + keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} + for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys { + keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{ + VerifyKey: gomatrixserverlib.VerifyKey{ + Key: oldVerifyKey.PublicKey, + }, + ExpiredTS: oldVerifyKey.ExpiredAt, + } + } + } else { + if identity, err = cfg.Matrix.SigningIdentityFor(virtualHost.ServerName); err != nil { + return nil, err + } + publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey) + keys.ServerName = virtualHost.ServerName + keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod)) + keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ + virtualHost.KeyID: { + Key: gomatrixserverlib.Base64Bytes(publicKey), + }, + } + // TODO: Virtual hosts probably want to be able to specify old signing + // keys too, just in case } toSign, err := json.Marshal(keys.ServerKeyFields) @@ -172,13 +189,9 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver } keys.Raw, err = gomatrixserverlib.SignJSON( - string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, + string(identity.ServerName), identity.KeyID, identity.PrivateKey, toSign, ) - if err != nil { - return nil, err - } - - return &keys, nil + return &keys, err } func NotaryKeys( @@ -186,6 +199,14 @@ func NotaryKeys( fsAPI federationAPI.FederationInternalAPI, req *gomatrixserverlib.PublicKeyNotaryLookupRequest, ) util.JSONResponse { + serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal + if !cfg.Matrix.IsLocalServerName(serverName) { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("Server name not known"), + } + } + if req == nil { req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { @@ -201,7 +222,7 @@ func NotaryKeys( for serverName, kidToCriteria := range req.ServerKeys { var keyList []gomatrixserverlib.ServerKeys if serverName == cfg.Matrix.ServerName { - if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { + if k, err := localKeys(cfg, serverName); err == nil { keyList = append(keyList, *k) } else { return util.ErrorResponse(err) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index a67e4e28b..f1e9f49ba 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -13,6 +13,7 @@ package routing import ( + "fmt" "net/http" "time" @@ -60,8 +61,18 @@ func MakeLeave( return jsonerror.InternalServerError() } + identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound( + fmt.Sprintf("Server name %q does not exist", request.Destination()), + ), + } + } + var queryRes api.QueryLatestEventsAndStateResponse - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index f672811af..e4d2230ad 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -42,16 +41,9 @@ func GetProfile( } } - _, domain, err := gomatrixserverlib.SplitID('@', userID) + _, domain, err := cfg.Matrix.SplitLocalID('@', userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument(fmt.Sprintf("Format of user ID %q is invalid", userID)), - } - } - - if domain != cfg.Matrix.ServerName { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)), diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 316c61a14..e6dc52601 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -83,7 +83,7 @@ func RoomAliasToID( } } } else { - resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, roomAlias) + resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, cfg.Matrix.ServerName, roomAlias) if err != nil { switch x := err.(type) { case gomatrix.HTTPError: diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 9f16e5093..0a3ab7a88 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -74,7 +74,7 @@ func Setup( } localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { - return LocalKeys(cfg) + return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host)) }) notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse { diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index b3bbaa394..a146d85bd 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -197,12 +197,12 @@ type txnReq struct { // A subset of FederationClient functionality that txn requires. Useful for testing. type txnFederationClient interface { - LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) - LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, + LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } @@ -287,6 +287,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res []*gomatrixserverlib.HeaderedEvent{ event.Headered(roomVersion), }, + t.Destination, t.Origin, api.DoNotSendToOtherServers, nil, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 1c796f542..b8bfe0221 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -147,7 +147,7 @@ type txnFedClient struct { getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) } -func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( +func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) { fmt.Println("testFederationClient.LookupState", eventID) @@ -159,7 +159,7 @@ func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.Serv res = r return } -func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { +func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { fmt.Println("testFederationClient.LookupStateIDs", eventID) r, ok := c.stateIDs[eventID] if !ok { @@ -169,7 +169,7 @@ func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.S res = r return } -func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { +func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { fmt.Println("testFederationClient.GetEvent", eventID) r, ok := c.getEvent[eventID] if !ok { @@ -179,7 +179,7 @@ func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerN res = r return } -func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, +func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { return c.getMissingEvents(missing) } diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index ccde9168e..d07faef39 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -90,7 +90,17 @@ func CreateInvitesFrom3PIDInvites( } // Send all the events - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, "TODO", cfg.Matrix.ServerName, nil, false); err != nil { + if err := api.SendEvents( + req.Context(), + rsAPI, + api.KindNew, + evs, + cfg.Matrix.ServerName, // TODO: which virtual host? + "TODO", + cfg.Matrix.ServerName, + nil, + false, + ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -126,6 +136,14 @@ func ExchangeThirdPartyInvite( } } + _, senderDomain, err := cfg.Matrix.SplitLocalID('@', builder.Sender) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Invalid sender ID: " + err.Error()), + } + } + // Check that the state key is correct. _, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey) if err != nil { @@ -171,7 +189,7 @@ func ExchangeThirdPartyInvite( util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") return jsonerror.InternalServerError() } - signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq) + signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() @@ -189,6 +207,7 @@ func ExchangeThirdPartyInvite( []*gomatrixserverlib.HeaderedEvent{ inviteEvent.Headered(verRes.RoomVersion), }, + request.Destination(), request.Origin(), cfg.Matrix.ServerName, nil, @@ -341,7 +360,7 @@ func buildMembershipEvent( // them responded with an error. func sendToRemoteServer( ctx context.Context, inv invite, - federation federationAPI.FederationClient, _ *config.FederationAPI, + federation federationAPI.FederationClient, cfg *config.FederationAPI, builder gomatrixserverlib.EventBuilder, ) (err error) { remoteServers := make([]gomatrixserverlib.ServerName, 2) @@ -357,7 +376,7 @@ func sendToRemoteServer( } for _, server := range remoteServers { - err = federation.ExchangeThirdPartyInvite(ctx, server, builder) + err = federation.ExchangeThirdPartyInvite(ctx, cfg.Matrix.ServerName, server, builder) if err == nil { return } diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 09098cd1e..b15b8bfae 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -32,7 +32,7 @@ type Database interface { GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. - GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) + GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationapi/storage/postgres/joined_hosts_table.go b/federationapi/storage/postgres/joined_hosts_table.go index 5c95b72a8..9a3977560 100644 --- a/federationapi/storage/postgres/joined_hosts_table.go +++ b/federationapi/storage/postgres/joined_hosts_table.go @@ -66,14 +66,20 @@ const selectAllJoinedHostsSQL = "" + const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" +const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id = ANY($1) AND NOT EXISTS (" + + " SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" + + ");" + type joinedHostsStatements struct { - db *sql.DB - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - deleteJoinedHostsForRoomStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt - selectJoinedHostsForRoomsStmt *sql.Stmt + db *sql.DB + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + deleteJoinedHostsForRoomStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + selectJoinedHostsForRoomsStmt *sql.Stmt + selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt } func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -102,6 +108,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { return } + if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil { + return + } return } @@ -167,9 +176,13 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( } func (s *joinedHostsStatements) SelectJoinedHostsForRooms( - ctx context.Context, roomIDs []string, + ctx context.Context, roomIDs []string, excludingBlacklisted bool, ) ([]gomatrixserverlib.ServerName, error) { - rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + stmt := s.selectJoinedHostsForRoomsStmt + if excludingBlacklisted { + stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt + } + rows, err := stmt.QueryContext(ctx, pq.StringArray(roomIDs)) if err != nil { return nil, err } diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index a33fa4a43..fe84e932e 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -42,6 +42,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { return nil, err } + blacklist, err := NewPostgresBlacklistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewPostgresJoinedHostsTable(d.db) if err != nil { return nil, err @@ -58,10 +62,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } - blacklist, err := NewPostgresBlacklistTable(d.db) - if err != nil { - return nil, err - } inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) if err != nil { return nil, err diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 4fabff7d4..1e1ea9e17 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -117,15 +117,17 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) { - servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) +func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { + servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err } if excludeSelf { for i, server := range servers { if d.IsLocalServerName(server) { - servers = append(servers[:i], servers[i+1:]...) + copy(servers[i:], servers[i+1:]) + servers = servers[:len(servers)-1] + break } } } diff --git a/federationapi/storage/sqlite3/joined_hosts_table.go b/federationapi/storage/sqlite3/joined_hosts_table.go index e0e0f2873..83112c150 100644 --- a/federationapi/storage/sqlite3/joined_hosts_table.go +++ b/federationapi/storage/sqlite3/joined_hosts_table.go @@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" + const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" +const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id IN ($1) AND NOT EXISTS (" + + " SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" + + ");" + type joinedHostsStatements struct { db *sql.DB insertJoinedHostsStmt *sql.Stmt @@ -74,6 +79,7 @@ type joinedHostsStatements struct { selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic + // selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt - prepared at runtime due to variadic } func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -168,14 +174,17 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( } func (s *joinedHostsStatements) SelectJoinedHostsForRooms( - ctx context.Context, roomIDs []string, + ctx context.Context, roomIDs []string, excludingBlacklisted bool, ) ([]gomatrixserverlib.ServerName, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i := range roomIDs { iRoomIDs[i] = roomIDs[i] } - - sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + query := selectJoinedHostsForRoomsSQL + if excludingBlacklisted { + query = selectJoinedHostsForRoomsExcludingBlacklistedSQL + } + sql := strings.Replace(query, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) if err != nil { return nil, err diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index e86ac817b..d13b5defc 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -41,6 +41,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { return nil, err } + blacklist, err := NewSQLiteBlacklistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) if err != nil { return nil, err @@ -57,10 +61,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } - blacklist, err := NewSQLiteBlacklistTable(d.db) - if err != nil { - return nil, err - } outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) if err != nil { return nil, err diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 3c116a1d0..9f4e86a6e 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -58,7 +58,7 @@ type FederationJoinedHosts interface { SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) + SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error) } type FederationBlacklist interface { diff --git a/go.mod b/go.mod index e4d2cbd60..6d00e80dc 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/Masterminds/semver/v3 v3.1.1 github.com/blevesearch/bleve/v2 v2.3.4 github.com/codeclysm/extract v2.2.0+incompatible - github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d + github.com/dgraph-io/ristretto v0.1.1 github.com/docker/docker v20.10.19+incompatible github.com/docker/go-connections v0.4.0 github.com/getsentry/sentry-go v0.14.0 @@ -22,12 +22,12 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e - github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 + github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 - github.com/nats-io/nats-server/v2 v2.9.4 - github.com/nats-io/nats.go v1.19.0 + github.com/nats-io/nats-server/v2 v2.9.8 + github.com/nats-io/nats.go v1.20.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 diff --git a/go.sum b/go.sum index 1a61ff544..11e1a5de7 100644 --- a/go.sum +++ b/go.sum @@ -139,8 +139,8 @@ github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d h1:Wrc3UKTS+cffkOx0xRGFC+ZesNuTfn0ThvEC72N0krk= -github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d/go.mod h1:RAy2GVV4sTWVlNMavv3xhLsk18rxhfhDnombTe6EF5c= +github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= +github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68= @@ -346,10 +346,10 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e h1:6I34fdyiHMRCxL6GOb/G8ZyI1WWlb6ZxCF2hIGSMSCc= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= -github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= +github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matryer/is v1.4.0 h1:sosSmIWwkYITGrxZ25ULNDeKiMNzFSr4V/eqBQP0PeE= @@ -385,10 +385,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.4 h1:GvRgv1936J/zYUwMg/cqtYaJ6L+bgeIOIvPslbesdow= -github.com/nats-io/nats-server/v2 v2.9.4/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= -github.com/nats-io/nats.go v1.19.0 h1:H6j8aBnTQFoVrTGB6Xjd903UMdE7jz6DS4YkmAqgZ9Q= -github.com/nats-io/nats.go v1.19.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= +github.com/nats-io/nats-server/v2 v2.9.8 h1:jgxZsv+A3Reb3MgwxaINcNq/za8xZInKhDg9Q0cGN1o= +github.com/nats-io/nats-server/v2 v2.9.8/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= +github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM= +github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= @@ -526,7 +526,6 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -697,6 +696,7 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index d96231963..c572d8830 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -38,7 +38,8 @@ var ErrRoomNoExists = errors.New("room does not exist") // Returns an error if something else went wrong func QueryAndBuildEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, + builder *gomatrixserverlib.EventBuilder, cfg *config.Global, + identity *gomatrixserverlib.SigningIdentity, evTime time.Time, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.HeaderedEvent, error) { if queryRes == nil { @@ -50,24 +51,24 @@ func QueryAndBuildEvent( // This can pass through a ErrRoomNoExists to the caller return nil, err } - return BuildEvent(ctx, builder, cfg, evTime, eventsNeeded, queryRes) + return BuildEvent(ctx, builder, cfg, identity, evTime, eventsNeeded, queryRes) } // BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse // provided. func BuildEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, + builder *gomatrixserverlib.EventBuilder, cfg *config.Global, + identity *gomatrixserverlib.SigningIdentity, evTime time.Time, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.HeaderedEvent, error) { - err := addPrevEventsToEvent(builder, eventsNeeded, queryRes) - if err != nil { + if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil { return nil, err } event, err := builder.Build( - evTime, cfg.ServerName, cfg.KeyID, - cfg.PrivateKey, queryRes.RoomVersion, + evTime, identity.ServerName, identity.KeyID, + identity.PrivateKey, queryRes.RoomVersion, ) if err != nil { return nil, err diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 039f5aef3..9c4ea59d4 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -42,10 +42,26 @@ type BasicAuth struct { Password string `yaml:"password"` } +type AuthAPIOpts struct { + GuestAccessAllowed bool +} + +// AuthAPIOption is an option to MakeAuthAPI to add additional checks (e.g. guest access) to verify +// the user is allowed to do specific things. +type AuthAPIOption func(opts *AuthAPIOpts) + +// WithAllowGuests checks that guest users have access to this endpoint +func WithAllowGuests() AuthAPIOption { + return func(opts *AuthAPIOpts) { + opts.GuestAccessAllowed = true + } +} + // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. func MakeAuthAPI( metricsName string, userAPI userapi.QueryAcccessTokenAPI, f func(*http.Request, *userapi.Device) util.JSONResponse, + checks ...AuthAPIOption, ) http.Handler { h := func(req *http.Request) util.JSONResponse { logger := util.GetLogger(req.Context()) @@ -76,6 +92,19 @@ func MakeAuthAPI( } }() + // apply additional checks, if any + opts := AuthAPIOpts{} + for _, opt := range checks { + opt(&opts) + } + + if !opts.GuestAccessAllowed && device.AccountType == userapi.AccountTypeGuest { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.GuestAccessForbidden("Guest access not allowed"), + } + } + jsonRes := f(req, device) // do not log 4xx as errors as they are client fails, not server fails if hub != nil && jsonRes.Code >= 500 { diff --git a/internal/version.go b/internal/version.go index 85b19046e..685237b9e 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 10 - VersionPatch = 7 + VersionPatch = 8 VersionTag = "" // example: "rc1" ) diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go index 575e41281..cd911f8c6 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/keyserver/consumers/devicelistupdate.go @@ -30,12 +30,12 @@ import ( // DeviceListUpdateConsumer consumes device list updates that came in over federation. type DeviceListUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - updater *internal.DeviceListUpdater - serverName gomatrixserverlib.ServerName + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + updater *internal.DeviceListUpdater + isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. @@ -46,12 +46,12 @@ func NewDeviceListUpdateConsumer( updater *internal.DeviceListUpdater, ) *DeviceListUpdateConsumer { return &DeviceListUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - updater: updater, - serverName: cfg.Matrix.ServerName, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + updater: updater, + isLocalServerName: cfg.Matrix.IsLocalServerName, } } @@ -75,7 +75,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { return true - } else if serverName == t.serverName { + } else if t.isLocalServerName(serverName) { return true } else if serverName != origin { return true diff --git a/keyserver/consumers/signingkeyupdate.go b/keyserver/consumers/signingkeyupdate.go index 366e259b4..bcceaad15 100644 --- a/keyserver/consumers/signingkeyupdate.go +++ b/keyserver/consumers/signingkeyupdate.go @@ -31,12 +31,13 @@ import ( // SigningKeyUpdateConsumer consumes signing key updates that came in over federation. type SigningKeyUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - keyAPI *internal.KeyInternalAPI - cfg *config.KeyServer + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + keyAPI *internal.KeyInternalAPI + cfg *config.KeyServer + isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. @@ -47,12 +48,13 @@ func NewSigningKeyUpdateConsumer( keyAPI *internal.KeyInternalAPI, ) *SigningKeyUpdateConsumer { return &SigningKeyUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - keyAPI: keyAPI, - cfg: cfg, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + keyAPI: keyAPI, + cfg: cfg, + isLocalServerName: cfg.Matrix.IsLocalServerName, } } @@ -77,7 +79,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { logrus.WithError(err).Error("failed to split user id") return true - } else if serverName == t.cfg.Matrix.ServerName { + } else if t.isLocalServerName(serverName) { logrus.Warn("dropping device key update from ourself") return true } else if serverName != origin { diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 8b02f3d6c..8ff9dfc31 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -47,7 +47,6 @@ var ( ) ) -const defaultWaitTime = time.Minute const requestTimeout = time.Second * 30 func init() { @@ -97,6 +96,7 @@ type DeviceListUpdater struct { producer KeyChangeProducer fedClient fedsenderapi.KeyserverFederationAPI workerChans []chan gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // block on or timeout via a select. @@ -140,6 +140,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, + thisServer gomatrixserverlib.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -149,6 +150,7 @@ func NewDeviceListUpdater( api: api, producer: producer, fedClient: fedClient, + thisServer: thisServer, workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, @@ -436,8 +438,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go "server_name": serverName, "user_id": userID, }) - - res, err := u.fedClient.GetUserDevices(ctx, serverName, userID) + res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID) if err != nil { if errors.Is(err, context.DeadlineExceeded) { return time.Minute * 10, err @@ -454,7 +455,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go } else if e.Code >= 300 { // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. - return time.Hour, err + return hourWaitTime, err } case net.Error: // Use the default waitTime, if it's a timeout. @@ -468,7 +469,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go // This is to avoid spamming remote servers, which may not be Matrix servers anymore. if e.Code >= 300 { logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError") - return time.Hour, err + return hourWaitTime, err } default: // Something else failed diff --git a/keyserver/internal/device_list_update_default.go b/keyserver/internal/device_list_update_default.go new file mode 100644 index 000000000..7d357c951 --- /dev/null +++ b/keyserver/internal/device_list_update_default.go @@ -0,0 +1,22 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !vw + +package internal + +import "time" + +const defaultWaitTime = time.Minute +const hourWaitTime = time.Hour diff --git a/keyserver/internal/device_list_update_sytest.go b/keyserver/internal/device_list_update_sytest.go new file mode 100644 index 000000000..1c60d2eb9 --- /dev/null +++ b/keyserver/internal/device_list_update_sytest.go @@ -0,0 +1,25 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vw + +package internal + +import "time" + +// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite +// results in a one-hour wait time from a previous device so the test times out. This is fine for +// production, but makes an otherwise passing test fail. +const defaultWaitTime = time.Second +const hourWaitTime = time.Second diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 28a13a0a0..a374c9516 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -129,7 +129,13 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { _, pkey, _ := ed25519.GenerateKey(nil) fedClient := gomatrixserverlib.NewFederationClient( - gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: gomatrixserverlib.ServerName("example.test"), + KeyID: gomatrixserverlib.KeyID("ed25519:test"), + PrivateKey: pkey, + }, + }, ) fedClient.Client = *gomatrixserverlib.NewClient( gomatrixserverlib.WithTransport(&roundTripper{tripper}), @@ -147,7 +153,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost") event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -219,7 +225,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -288,7 +294,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 92ee80d81..9a08a0bb7 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -33,16 +33,17 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/dendrite/keyserver/storage" + "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" ) type KeyInternalAPI struct { - DB storage.Database - ThisServer gomatrixserverlib.ServerName - FedClient fedsenderapi.KeyserverFederationAPI - UserAPI userapi.KeyserverUserAPI - Producer *producers.KeyChange - Updater *DeviceListUpdater + DB storage.Database + Cfg *config.KeyServer + FedClient fedsenderapi.KeyserverFederationAPI + UserAPI userapi.KeyserverUserAPI + Producer *producers.KeyChange + Updater *DeviceListUpdater } func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { @@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC nested[userID] = val domainToDeviceKeys[string(serverName)] = nested } - // claim local keys - if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok { + for domain, local := range domainToDeviceKeys { + if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + continue + } + // claim local keys keys, err := a.DB.ClaimKeys(ctx, local) if err != nil { res.Error = &api.KeyError{ @@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON } } - delete(domainToDeviceKeys, string(a.ThisServer)) + delete(domainToDeviceKeys, domain) } if len(domainToDeviceKeys) > 0 { a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) @@ -142,7 +146,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( defer cancel() defer wg.Done() - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) mu.Lock() defer mu.Unlock() @@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } domain := string(serverName) // query local devices - if serverName == a.ThisServer { + if a.Cfg.Matrix.IsLocalServerName(serverName) { deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ @@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys( domains := map[string]struct{}{} for domain := range domainToDeviceKeys { - if domain == string(a.ThisServer) { + if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} } for domain := range domainToCrossSigningKeys { - if domain == string(a.ThisServer) { + if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} @@ -555,7 +559,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( if len(devKeys) == 0 { return } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) if err == nil { resultCh <- &queryKeysResp return @@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per if err != nil { continue // ignore invalid users } - if serverName != a.ThisServer { + if !a.Cfg.Matrix.IsLocalServerName(serverName) { continue // ignore remote users } if len(key.KeyJSON) == 0 { diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 9ae4f9ca3..a86c2da4e 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -53,12 +53,12 @@ func NewInternalAPI( DB: db, } ap := &internal.KeyInternalAPI{ - DB: db, - ThisServer: cfg.Matrix.ServerName, - FedClient: fedClient, - Producer: keyChangeProducer, + DB: db, + Cfg: cfg, + FedClient: fedClient, + Producer: keyChangeProducer, } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable ap.Updater = updater go func() { if err := updater.Start(); err != nil { diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go index 63281adfb..d0fe50d00 100644 --- a/keyserver/storage/postgres/stale_device_lists.go +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -43,10 +43,10 @@ const upsertStaleDeviceListSQL = "" + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" type staleDeviceListsStatements struct { upsertStaleDeviceListStmt *sql.Stmt @@ -77,7 +77,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) return err } diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go index fc2cc37c4..1e08b266c 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -43,10 +43,10 @@ const upsertStaleDeviceListSQL = "" + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" type staleDeviceListsStatements struct { db *sql.DB @@ -80,7 +80,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) return err } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index a1373a62b..01e87ec8a 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -177,6 +177,7 @@ type FederationRoomserverAPI interface { QueryBulkStateContentAPI // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error + QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 45a9ef497..88d523270 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -94,8 +94,9 @@ type TransactionID struct { // InputRoomEventsRequest is a request to InputRoomEvents type InputRoomEventsRequest struct { - InputRoomEvents []InputRoomEvent `json:"input_room_events"` - Asynchronous bool `json:"async"` + InputRoomEvents []InputRoomEvent `json:"input_room_events"` + Asynchronous bool `json:"async"` + VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` } // InputRoomEventsResponse is a response to InputRoomEvents diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 2536a4bb5..e70e5ea9c 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -148,6 +148,8 @@ type PerformBackfillRequest struct { Limit int `json:"limit"` // The server interested in the events. ServerName gomatrixserverlib.ServerName `json:"server_name"` + // Which virtual host are we doing this for? + VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` } // PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 8b031982c..252be557f 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -26,7 +26,7 @@ import ( func SendEvents( ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, events []*gomatrixserverlib.HeaderedEvent, - origin gomatrixserverlib.ServerName, + virtualHost, origin gomatrixserverlib.ServerName, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, async bool, ) error { @@ -40,14 +40,15 @@ func SendEvents( TransactionID: txnID, } } - return SendInputRoomEvents(ctx, rsAPI, ires, async) + return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async) } // SendEventWithState writes an event with the specified kind to the roomserver // with the state at the event as KindOutlier before it. Will not send any event that is // marked as `true` in haveEventIDs. func SendEventWithState( - ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, + ctx context.Context, rsAPI InputRoomEventsAPI, + virtualHost gomatrixserverlib.ServerName, kind Kind, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { @@ -85,17 +86,19 @@ func SendEventWithState( StateEventIDs: stateEventIDs, }) - return SendInputRoomEvents(ctx, rsAPI, ires, async) + return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async) } // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI InputRoomEventsAPI, + virtualHost gomatrixserverlib.ServerName, ires []InputRoomEvent, async bool, ) error { request := InputRoomEventsRequest{ InputRoomEvents: ires, Asynchronous: async, + VirtualHost: virtualHost, } var response InputRoomEventsResponse if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 175bb9310..329e6af7f 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -137,6 +137,11 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { + _, virtualHost, err := r.Cfg.Matrix.SplitLocalID('@', request.UserID) + if err != nil { + return err + } + roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) @@ -190,6 +195,16 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( sender = ev.Sender() } + _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', sender) + if err != nil { + return err + } + + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + return err + } + builder := &gomatrixserverlib.EventBuilder{ Sender: sender, RoomID: ev.RoomID(), @@ -211,12 +226,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, stateRes) + newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, stateRes) if err != nil { return err } - err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, r.ServerName, r.ServerName, nil, false) + err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false) if err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1a11586a5..1a3626609 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -87,10 +87,10 @@ func NewRoomserverAPI( Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"), ServerACLs: serverACLs, Queryer: &query.Queryer{ - DB: roomserverDB, - Cache: base.Caches, - ServerName: base.Cfg.Global.ServerName, - ServerACLs: serverACLs, + DB: roomserverDB, + Cache: base.Caches, + IsLocalServerName: base.Cfg.Global.IsLocalServerName, + ServerACLs: serverACLs, }, // perform-er structs get initialised when we have a federation sender to use } @@ -127,13 +127,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Inputer: r.Inputer, } r.Joiner = &perform.Joiner{ - ServerName: r.Cfg.Matrix.ServerName, - Cfg: r.Cfg, - DB: r.DB, - FSAPI: r.fsAPI, - RSAPI: r, - Inputer: r.Inputer, - Queryer: r.Queryer, + Cfg: r.Cfg, + DB: r.DB, + FSAPI: r.fsAPI, + RSAPI: r, + Inputer: r.Inputer, + Queryer: r.Queryer, } r.Peeker = &perform.Peeker{ ServerName: r.Cfg.Matrix.ServerName, @@ -163,10 +162,10 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio DB: r.DB, } r.Backfiller = &perform.Backfiller{ - ServerName: r.ServerName, - DB: r.DB, - FSAPI: r.fsAPI, - KeyRing: r.KeyRing, + IsLocalServerName: r.Cfg.Matrix.IsLocalServerName, + DB: r.DB, + FSAPI: r.fsAPI, + KeyRing: r.KeyRing, // Perspective servers are trusted to not lie about server keys, so we will also // prefer these servers when backfilling (assuming they are in the room) rather // than trying random servers diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go new file mode 100644 index 000000000..aa5c30e44 --- /dev/null +++ b/roomserver/internal/helpers/helpers_test.go @@ -0,0 +1,56 @@ +package helpers + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/setup/base" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + + "github.com/matrix-org/dendrite/roomserver/storage" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches) + if err != nil { + t.Fatalf("failed to create Database: %v", err) + } + return base, db, close +} + +func TestIsInvitePendingWithoutNID(t *testing.T) { + + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) + _ = bob + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + _, db, close := mustCreateDatabase(t, dbType) + defer close() + + // store all events + var authNIDs []types.EventNID + for _, x := range room.Events() { + + evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false) + assert.NoError(t, err) + authNIDs = append(authNIDs, evNID) + } + + // Alice should have no pending invites and should have a NID + pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID) + assert.NoError(t, err, "failed to get pending invites") + assert.False(t, pendingInvite, "unexpected pending invite") + + // Bob should have no pending invites and receive a new NID + pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID) + assert.NoError(t, err, "failed to get pending invites") + assert.False(t, pendingInvite, "unexpected pending invite") + }) +} diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index f5099ca11..e965691c9 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -278,7 +278,11 @@ func (w *worker) _next() { // a string, because we might want to return that to the caller if // it was a synchronous request. var errString string - if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil { + if err = w.r.processRoomEvent( + w.r.ProcessContext.Context(), + gomatrixserverlib.ServerName(msg.Header.Get("virtual_host")), + &inputRoomEvent, + ); err != nil { switch err.(type) { case types.RejectedError: // Don't send events that were rejected to Sentry @@ -358,6 +362,7 @@ func (r *Inputer) queueInputRoomEvents( if replyTo != "" { msg.Header.Set("sync", replyTo) } + msg.Header.Set("virtual_host", string(request.VirtualHost)) msg.Data, err = json.Marshal(e) if err != nil { return nil, fmt.Errorf("json.Marshal: %w", err) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 60160e8e5..10b8ee27f 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -23,6 +23,8 @@ import ( "fmt" "time" + "github.com/tidwall/gjson" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" @@ -67,6 +69,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, + virtualHost gomatrixserverlib.ServerName, input *api.InputRoomEvent, ) error { select { @@ -162,8 +165,9 @@ func (r *Inputer) processRoomEvent( if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: event.RoomID(), - ExcludeSelf: true, + RoomID: event.RoomID(), + ExcludeSelf: true, + ExcludeBlacklisted: true, } if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) @@ -198,7 +202,7 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, roomInfo, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { return fmt.Errorf("r.fetchAuthEvents: %w", err) } @@ -263,16 +267,17 @@ func (r *Inputer) processRoomEvent( // processRoomEvent. if len(serverRes.ServerNames) > 0 { missingState := missingStateReq{ - origin: input.Origin, - inputer: r, - db: r.DB, - roomInfo: roomInfo, - federation: r.FSAPI, - keys: r.KeyRing, - roomsMu: internal.NewMutexByRoom(), - servers: serverRes.ServerNames, - hadEvents: map[string]bool{}, - haveEvents: map[string]*gomatrixserverlib.Event{}, + origin: input.Origin, + virtualHost: virtualHost, + inputer: r, + db: r.DB, + roomInfo: roomInfo, + federation: r.FSAPI, + keys: r.KeyRing, + roomsMu: internal.NewMutexByRoom(), + servers: serverRes.ServerNames, + hadEvents: map[string]bool{}, + haveEvents: map[string]*gomatrixserverlib.Event{}, } var stateSnapshot *parsedRespState if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { @@ -409,6 +414,13 @@ func (r *Inputer) processRoomEvent( } } + // Handle remote room upgrades, e.g. remove published room + if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) { + if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil { + return fmt.Errorf("failed to handle remote room upgrade: %w", err) + } + } + // processing this event resulted in an event (which may not be the one we're processing) // being redacted. We are guaranteed to have both sides (the redaction/redacted event), // so notify downstream components to redact this event - they should have it if they've @@ -434,6 +446,13 @@ func (r *Inputer) processRoomEvent( return nil } +// handleRemoteRoomUpgrade updates published rooms and room aliases +func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixserverlib.Event) error { + oldRoomID := event.RoomID() + newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str + return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender()) +} + // processStateBefore works out what the state is before the event and // then checks the event auths against the state at the time. It also // tries to determine what the history visibility was of the event at @@ -539,6 +558,7 @@ func (r *Inputer) fetchAuthEvents( ctx context.Context, logger *logrus.Entry, roomInfo *types.RoomInfo, + virtualHost gomatrixserverlib.ServerName, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, known map[string]*types.Event, @@ -589,7 +609,7 @@ func (r *Inputer) fetchAuthEvents( // Request the entire auth chain for the event in question. This should // contain all of the auth events — including ones that we already know — // so we'll need to filter through those in the next section. - res, err = r.FSAPI.GetEventAuth(ctx, serverName, event.RoomVersion, event.RoomID(), event.EventID()) + res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.RoomVersion, event.RoomID(), event.EventID()) if err != nil { logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) continue diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index f788c5b3a..03ac2b38d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -41,6 +41,7 @@ func (p *parsedRespState) Events() []*gomatrixserverlib.Event { type missingStateReq struct { log *logrus.Entry + virtualHost gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName db storage.Database roomInfo *types.RoomInfo @@ -101,7 +102,7 @@ func (t *missingStateReq) processEventWithMissingState( // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -157,7 +158,7 @@ func (t *missingStateReq) processEventWithMissingState( }) } for _, ire := range outlierRoomEvents { - if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { + if err = t.inputer.processRoomEvent(ctx, t.virtualHost, &ire); err != nil { if _, ok := err.(types.RejectedError); !ok { return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) } @@ -180,7 +181,7 @@ func (t *missingStateReq) processEventWithMissingState( stateIDs = append(stateIDs, event.EventID()) } - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: backwardsExtremity.Headered(roomVersion), Origin: t.origin, @@ -199,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState( // they will automatically fast-forward based on the room state at the // extremity in the last step. for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -519,7 +520,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve var missingResp *gomatrixserverlib.RespMissingEvents for _, server := range t.servers { var m gomatrixserverlib.RespMissingEvents - if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ + if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. EarliestEvents: latestEvents, @@ -635,7 +636,7 @@ func (t *missingStateReq) lookupMissingStateViaState( span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState") defer span.Finish() - state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) + state, err := t.federation.LookupState(ctx, t.virtualHost, t.origin, roomID, eventID, roomVersion) if err != nil { return nil, err } @@ -675,7 +676,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5) for _, serverName := range t.servers { reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20) - stateIDs, err = t.federation.LookupStateIDs(reqctx, serverName, roomID, eventID) + stateIDs, err = t.federation.LookupStateIDs(reqctx, t.virtualHost, serverName, roomID, eventID) reqcancel() if err == nil { break @@ -855,7 +856,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs for _, serverName := range t.servers { reqctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() - txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) + txn, err := t.federation.GetEvent(reqctx, t.virtualHost, serverName, missingEventID) if err != nil || len(txn.PDUs) == 0 { t.log.WithError(err).WithField("missing_event_id", missingEventID).Warn("Failed to get missing /event for event ID") if errors.Is(err, context.DeadlineExceeded) { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 0406fc8f4..d42f4e45d 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -139,7 +139,12 @@ func (r *Admin) PerformAdminEvacuateRoom( return nil } - event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + continue + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -242,6 +247,15 @@ func (r *Admin) PerformAdminDownloadState( req *api.PerformAdminDownloadStateRequest, res *api.PerformAdminDownloadStateResponse, ) error { + _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', req.UserID) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.Cfg.Matrix.SplitLocalID: %s", err), + } + return nil + } + roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) if err != nil { res.Error = &api.PerformError{ @@ -273,7 +287,7 @@ func (r *Admin) PerformAdminDownloadState( for _, fwdExtremity := range fwdExtremities { var state gomatrixserverlib.RespState - state, err = r.Inputer.FSAPI.LookupState(ctx, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) + state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -331,7 +345,12 @@ func (r *Admin) PerformAdminDownloadState( Depth: depth, } - ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, queryRes) + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + return err + } + + ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 57e121ea2..069f017a9 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -37,10 +37,10 @@ import ( const maxBackfillServers = 5 type Backfiller struct { - ServerName gomatrixserverlib.ServerName - DB storage.Database - FSAPI federationAPI.RoomserverFederationAPI - KeyRing gomatrixserverlib.JSONVerifier + IsLocalServerName func(gomatrixserverlib.ServerName) bool + DB storage.Database + FSAPI federationAPI.RoomserverFederationAPI + KeyRing gomatrixserverlib.JSONVerifier // The servers which should be preferred above other servers when backfilling PreferServers []gomatrixserverlib.ServerName @@ -55,7 +55,7 @@ func (r *Backfiller) PerformBackfill( // if we are requesting the backfill then we need to do a federation hit // TODO: we could be more sensible and fetch as many events we already have then request the rest // which is what the syncapi does already. - if request.ServerName == r.ServerName { + if r.IsLocalServerName(request.ServerName) { return r.backfillViaFederation(ctx, request, response) } // someone else is requesting the backfill, try to service their request. @@ -112,16 +112,18 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform if info == nil || info.IsStub() { return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) } - requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities, r.PreferServers) + requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) // Request 100 items regardless of what the query asks for. // We don't want to go much higher than this. // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass // (so we don't need to hit /state_ids which the test has no listener for) // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( - ctx, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100) + ctx, req.VirtualHost, requester, + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, + ) if err != nil { + logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed") return err } logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) @@ -144,7 +146,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform var entries []types.StateEntry if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil { // attempt to fetch the missing events - r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) + r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs, req.VirtualHost) // try again entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true) if err != nil { @@ -173,7 +175,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just // best effort. func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - backfillRequester *backfillRequester, stateIDs []string) { + backfillRequester *backfillRequester, stateIDs []string, virtualHost gomatrixserverlib.ServerName) { servers := backfillRequester.servers @@ -198,7 +200,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue // already found } logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) - res, err := r.FSAPI.GetEvent(ctx, srv, id) + res, err := r.FSAPI.GetEvent(ctx, virtualHost, srv, id) if err != nil { logger.WithError(err).Warn("failed to get event from server") continue @@ -241,11 +243,12 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom // backfillRequester implements gomatrixserverlib.BackfillRequester type backfillRequester struct { - db storage.Database - fsAPI federationAPI.RoomserverFederationAPI - thisServer gomatrixserverlib.ServerName - preferServer map[gomatrixserverlib.ServerName]bool - bwExtrems map[string][]string + db storage.Database + fsAPI federationAPI.RoomserverFederationAPI + virtualHost gomatrixserverlib.ServerName + isLocalServerName func(gomatrixserverlib.ServerName) bool + preferServer map[gomatrixserverlib.ServerName]bool + bwExtrems map[string][]string // per-request state servers []gomatrixserverlib.ServerName @@ -255,7 +258,9 @@ type backfillRequester struct { } func newBackfillRequester( - db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, thisServer gomatrixserverlib.ServerName, + db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, + virtualHost gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName, ) *backfillRequester { preferServer := make(map[gomatrixserverlib.ServerName]bool) @@ -265,7 +270,8 @@ func newBackfillRequester( return &backfillRequester{ db: db, fsAPI: fsAPI, - thisServer: thisServer, + virtualHost: virtualHost, + isLocalServerName: isLocalServerName, eventIDToBeforeStateIDs: make(map[string][]string), eventIDMap: make(map[string]*gomatrixserverlib.Event), bwExtrems: bwExtrems, @@ -450,7 +456,7 @@ FindSuccessor: } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost) b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") @@ -477,7 +483,7 @@ FindSuccessor: } var servers []gomatrixserverlib.ServerName for server := range serverSet { - if server == b.thisServer { + if b.isLocalServerName(server) { continue } if b.preferServer[server] { // insert at the front @@ -496,10 +502,10 @@ FindSuccessor: // Backfill performs a backfill request to the given server. // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid -func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, +func (b *backfillRequester) Backfill(ctx context.Context, origin, server gomatrixserverlib.ServerName, roomID string, limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { - tx, err := b.fsAPI.Backfill(ctx, server, roomID, limit, fromEventIDs) + tx, err := b.fsAPI.Backfill(ctx, origin, server, roomID, limit, fromEventIDs) return tx, err } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 9d596ab30..4de008c66 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -39,11 +39,10 @@ import ( ) type Joiner struct { - ServerName gomatrixserverlib.ServerName - Cfg *config.RoomServer - FSAPI fsAPI.RoomserverFederationAPI - RSAPI rsAPI.RoomserverInternalAPI - DB storage.Database + Cfg *config.RoomServer + FSAPI fsAPI.RoomserverFederationAPI + RSAPI rsAPI.RoomserverInternalAPI + DB storage.Database Inputer *input.Inputer Queryer *query.Queryer @@ -197,7 +196,7 @@ func (r *Joiner) performJoinRoomByID( // Prepare the template for the join event. userID := req.UserID - _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) if err != nil { return "", "", &rsAPI.PerformError{ Code: rsAPI.PerformErrorBadRequest, @@ -283,7 +282,7 @@ func (r *Joiner) performJoinRoomByID( // locally on the homeserver. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) + event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, userDomain, &eb) switch err { case nil: @@ -410,7 +409,9 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( } func buildEvent( - ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, + ctx context.Context, db storage.Database, cfg *config.Global, + senderDomain gomatrixserverlib.ServerName, + builder *gomatrixserverlib.EventBuilder, ) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) if err != nil { @@ -438,7 +439,12 @@ func buildEvent( } } - ev, err := eventutil.BuildEvent(ctx, builder, cfg, time.Now(), &eventsNeeded, &queryRes) + identity, err := cfg.SigningIdentityFor(senderDomain) + if err != nil { + return nil, nil, err + } + + ev, err := eventutil.BuildEvent(ctx, builder, cfg, identity, time.Now(), &eventsNeeded, &queryRes) if err != nil { return nil, nil, err } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 49e4b479a..fa998e3e1 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -162,21 +162,21 @@ func (r *Leaver) performLeaveRoomByID( return nil, fmt.Errorf("eb.SetUnsigned: %w", err) } + // Get the sender domain. + _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', eb.Sender) + if serr != nil { + return nil, fmt.Errorf("sender %q is invalid", eb.Sender) + } + // We know that the user is in the room at this point so let's build // a leave event. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) + event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, senderDomain, &eb) if err != nil { return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) } - // Get the sender domain. - _, senderDomain, serr := gomatrixserverlib.SplitID('@', event.Sender()) - if serr != nil { - return nil, fmt.Errorf("sender %q is invalid", event.Sender()) - } - // Give our leave event to the roomserver input stream. The // roomserver will process the membership change and notify // downstream automatically. diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 38abe323c..02a19911c 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -60,7 +60,7 @@ func (r *Upgrader) performRoomUpgrade( ) (string, *api.PerformError) { roomID := req.RoomID userID := req.UserID - _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) if err != nil { return "", &api.PerformError{ Code: api.PerformErrorNotAllowed, @@ -558,7 +558,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user SendAsServer: api.DoNotSendToOtherServers, }) } - if err = api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { + if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err), } @@ -595,8 +595,21 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err), } } + // Get the sender domain. + _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', builder.Sender) + if serr != nil { + return nil, &api.PerformError{ + Msg: fmt.Sprintf("Failed to split user ID %q: %s", builder.Sender, err), + } + } + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + return nil, &api.PerformError{ + Msg: fmt.Sprintf("Failed to get signing identity for %q: %s", senderDomain, err), + } + } var queryRes api.QueryLatestEventsAndStateResponse - headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, evTime, r.URSAPI, &queryRes) + headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &api.PerformError{ Code: api.PerformErrorNoRoom, @@ -686,7 +699,7 @@ func (r *Upgrader) sendHeaderedEvent( Origin: serverName, SendAsServer: sendAsServer, }) - if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { + if err := api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false); err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err), } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 89f9fcbda..b1a1f9102 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -38,10 +38,10 @@ import ( ) type Queryer struct { - DB storage.Database - Cache caching.RoomServerCaches - ServerName gomatrixserverlib.ServerName - ServerACLs *acls.ServerACLs + DB storage.Database + Cache caching.RoomServerCaches + IsLocalServerName func(gomatrixserverlib.ServerName) bool + ServerACLs *acls.ServerACLs } // QueryLatestEventsAndState implements api.RoomserverInternalAPI @@ -398,7 +398,7 @@ func (r *Queryer) QueryServerJoinedToRoom( } response.RoomExists = true - if request.ServerName == r.ServerName || request.ServerName == "" { + if r.IsLocalServerName(request.ServerName) || request.ServerName == "" { response.IsInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) if err != nil { return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 4e98af853..24b5515e5 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -44,7 +44,7 @@ func Test_SharedUsers(t *testing.T) { // SetFederationAPI starts the room event input consumer rsAPI.SetFederationAPI(nil, nil) // Create the room - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 0a2a13254..a13f4f04d 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -172,4 +172,5 @@ type Database interface { ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) + UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index b6692333f..64e553a9c 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -103,6 +103,7 @@ func (d *Database) eventStateKeyNIDs( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) + eventStateKeys = util.UniqueStrings(eventStateKeys) nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) if err != nil { return nil, err @@ -110,6 +111,27 @@ func (d *Database) eventStateKeyNIDs( for eventStateKey, nid := range nids { result[eventStateKey] = nid } + // We received some nids, but are still missing some, work out which and create them + if len(eventStateKeys) > len(result) { + var nid types.EventStateKeyNID + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + for _, eventStateKey := range eventStateKeys { + if _, ok := result[eventStateKey]; ok { + continue + } + + nid, err = d.assignStateKeyNID(ctx, txn, eventStateKey) + if err != nil { + return err + } + result[eventStateKey] = nid + } + return nil + }) + if err != nil { + return nil, err + } + } return result, nil } @@ -1243,7 +1265,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } - eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) + eventStateKeyNIDMap, err := d.eventStateKeyNIDs(ctx, nil, eventStateKeys) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err) } @@ -1309,7 +1331,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ if err != nil { return nil, err } - userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs) + userNIDsMap, err := d.eventStateKeyNIDs(ctx, nil, userIDs) if err != nil { return nil, err } @@ -1386,6 +1408,36 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget }) } +func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error { + + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // un-publish old room + if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil { + return fmt.Errorf("failed to unpublish room: %w", err) + } + // publish new room + if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil { + return fmt.Errorf("failed to publish room: %w", err) + } + + // Migrate any existing room aliases + aliases, err := d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, txn, oldRoomID) + if err != nil { + return fmt.Errorf("failed to get room aliases: %w", err) + } + + for _, alias := range aliases { + if err = d.RoomAliasesTable.DeleteRoomAlias(ctx, txn, alias); err != nil { + return fmt.Errorf("failed to remove room alias: %w", err) + } + if err = d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, newRoomID, eventSender); err != nil { + return fmt.Errorf("failed to set room alias: %w", err) + } + } + return nil + }) +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package! diff --git a/setup/base/base.go b/setup/base/base.go index 6bcabef4b..e5a6a3c87 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -363,10 +363,10 @@ func (b *BaseDendrite) CreateClient() *gomatrixserverlib.Client { // CreateFederationClient creates a new federation client. Should only be called // once per component. func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { + identities := b.Cfg.Global.SigningIdentities() if b.Cfg.Global.DisableFederation { return gomatrixserverlib.NewFederationClient( - b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey, - gomatrixserverlib.WithTransport(noOpHTTPTransport), + identities, gomatrixserverlib.WithTransport(noOpHTTPTransport), ) } opts := []gomatrixserverlib.ClientOption{ @@ -378,8 +378,7 @@ func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationCli opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache)) } client := gomatrixserverlib.NewFederationClient( - b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, - b.Cfg.Global.PrivateKey, opts..., + identities, opts..., ) client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())) return client diff --git a/setup/config/config.go b/setup/config/config.go index 47fb5e812..2b438f988 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -232,6 +232,21 @@ func loadConfig( return nil, err } + for _, v := range c.Global.VirtualHosts { + if v.KeyValidityPeriod == 0 { + v.KeyValidityPeriod = c.Global.KeyValidityPeriod + } + if v.PrivateKeyPath == "" || v.PrivateKey == nil || v.KeyID == "" { + v.KeyID = c.Global.KeyID + v.PrivateKey = c.Global.PrivateKey + continue + } + privateKeyPath := absPath(basePath, v.PrivateKeyPath) + if v.KeyID, v.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { + return nil, err + } + } + for _, key := range c.Global.OldVerifyKeys { switch { case key.PrivateKeyPath != "": diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 825772827..511951fe6 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "math/rand" "strconv" "strings" @@ -11,22 +12,16 @@ import ( ) type Global struct { - // The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'. - ServerName gomatrixserverlib.ServerName `yaml:"server_name"` + // Signing identity contains the server name, private key and key ID of + // the deployment. + gomatrixserverlib.SigningIdentity `yaml:",inline"` // The secondary server names, used for virtual hosting. - SecondaryServerNames []gomatrixserverlib.ServerName `yaml:"-"` + VirtualHosts []*VirtualHost `yaml:"-"` // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` - // The private key which will be used to sign requests and events. - PrivateKey ed25519.PrivateKey `yaml:"-"` - - // An arbitrary string used to uniquely identify the PrivateKey. Must start with the - // prefix "ed25519:". - KeyID gomatrixserverlib.KeyID `yaml:"-"` - // Information about old private keys that used to be used to sign requests and // events on this domain. They will not be used but will be advertised to other // servers that ask for them to help verify old events. @@ -114,6 +109,10 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { checkNotEmpty(configErrs, "global.server_name", string(c.ServerName)) checkNotEmpty(configErrs, "global.private_key", string(c.PrivateKeyPath)) + for _, v := range c.VirtualHosts { + v.Verify(configErrs) + } + c.JetStream.Verify(configErrs, isMonolith) c.Metrics.Verify(configErrs, isMonolith) c.Sentry.Verify(configErrs, isMonolith) @@ -127,14 +126,108 @@ func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool if c.ServerName == serverName { return true } - for _, secondaryName := range c.SecondaryServerNames { - if secondaryName == serverName { + for _, v := range c.VirtualHosts { + if v.ServerName == serverName { return true } } return false } +func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib.ServerName, error) { + u, s, err := gomatrixserverlib.SplitID(sigil, id) + if err != nil { + return u, s, err + } + if !c.IsLocalServerName(s) { + return u, s, fmt.Errorf("server name %q not known", s) + } + return u, s, nil +} + +func (c *Global) VirtualHost(serverName gomatrixserverlib.ServerName) *VirtualHost { + for _, v := range c.VirtualHosts { + if v.ServerName == serverName { + return v + } + } + return nil +} + +func (c *Global) VirtualHostForHTTPHost(serverName gomatrixserverlib.ServerName) *VirtualHost { + for _, v := range c.VirtualHosts { + if v.ServerName == serverName { + return v + } + for _, h := range v.MatchHTTPHosts { + if h == serverName { + return v + } + } + } + return nil +} + +func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.SigningIdentity, error) { + for _, id := range c.SigningIdentities() { + if id.ServerName == serverName { + return id, nil + } + } + return nil, fmt.Errorf("no signing identity %q", serverName) +} + +func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { + identities := make([]*gomatrixserverlib.SigningIdentity, 0, len(c.VirtualHosts)+1) + identities = append(identities, &c.SigningIdentity) + for _, v := range c.VirtualHosts { + identities = append(identities, &v.SigningIdentity) + } + return identities +} + +type VirtualHost struct { + // Signing identity contains the server name, private key and key ID of + // the virtual host. + gomatrixserverlib.SigningIdentity `yaml:",inline"` + + // Path to the private key. If not specified, the default global private key + // will be used instead. + PrivateKeyPath Path `yaml:"private_key"` + + // How long a remote server can cache our server key for before requesting it again. + // Increasing this number will reduce the number of requests made by remote servers + // for our key, but increases the period a compromised key will be considered valid + // by remote servers. + // Defaults to 24 hours. + KeyValidityPeriod time.Duration `yaml:"key_validity_period"` + + // Match these HTTP Host headers on the `/key/v2/server` endpoint, this needs + // to match all delegated names, likely including the port number too if + // the well-known delegation includes that also. + MatchHTTPHosts []gomatrixserverlib.ServerName `yaml:"match_http_hosts"` + + // Is registration enabled on this virtual host? + AllowRegistration bool `yaml:"allow_registration"` + + // Is guest registration enabled on this virtual host? + AllowGuests bool `yaml:"allow_guests"` +} + +func (v *VirtualHost) Verify(configErrs *ConfigErrors) { + checkNotEmpty(configErrs, "virtual_host.*.server_name", string(v.ServerName)) +} + +// RegistrationAllowed returns two bools, the first states whether registration +// is allowed for this virtual host and the second states whether guests are +// allowed for this virtual host. +func (v *VirtualHost) RegistrationAllowed() (bool, bool) { + if v == nil { + return false, false + } + return v.AllowRegistration, v.AllowGuests +} + type OldVerifyKeys struct { // Path to the private key. PrivateKeyPath Path `yaml:"private_key"` diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 1ec860b04..c1ce9583f 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -2,6 +2,7 @@ package jetstream import ( "context" + "errors" "fmt" "github.com/getsentry/sentry-go" @@ -72,6 +73,9 @@ func JetStreamConsumer( // just timed out and we should try again. continue } + } else if errors.Is(err, nats.ErrConsumerDeleted) { + // The consumer was deleted so stop. + return } else { // Something else went wrong, so we'll panic. sentry.CaptureException(err) diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 98502f5cb..bc369c166 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -397,7 +397,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen serversToQuery := rc.getServersForEventID(parentID) var result *MSC2836EventRelationshipsResponse for _, srv := range serversToQuery { - res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: parentID, Direction: "down", Limit: 100, @@ -484,7 +484,7 @@ func walkThread( // MSC2836EventRelationships performs an /event_relationships request to a remote server func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { - res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: eventID, DepthFirst: rc.req.DepthFirst, Direction: rc.req.Direction, @@ -665,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo }) } // we've got the data by this point so use a background context - err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) + err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, rc.serverName, ires, false) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") } diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index bc9df0f96..56c063598 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -57,7 +57,7 @@ func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache, ) error { - clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName)) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName), httputil.WithAllowGuests()) base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) @@ -433,7 +433,7 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) *gomatrixserver if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) + res, err := w.fsAPI.MSC2946Spaces(ctx, w.thisServer, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 96ebba7ef..92f081500 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -28,22 +28,20 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // OutputKeyChangeEventConsumer consumes events that originated in the key server. type OutputKeyChangeEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - notifier *notifier.Notifier - stream streams.StreamProvider - serverName gomatrixserverlib.ServerName // our server name - rsAPI roomserverAPI.SyncRoomserverAPI + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + notifier *notifier.Notifier + stream streams.StreamProvider + rsAPI roomserverAPI.SyncRoomserverAPI } // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. @@ -59,15 +57,14 @@ func NewOutputKeyChangeEventConsumer( stream streams.StreamProvider, ) *OutputKeyChangeEventConsumer { s := &OutputKeyChangeEventConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"), - topic: topic, - db: store, - serverName: cfg.Matrix.ServerName, - rsAPI: rsAPI, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"), + topic: topic, + db: store, + rsAPI: rsAPI, + notifier: notifier, + stream: stream, } return s diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index 8aaa65730..e39d43f94 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -34,14 +34,13 @@ import ( // OutputReceiptEventConsumer consumes events that originated in the EDU server. type OutputReceiptEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream streams.StreamProvider - notifier *notifier.Notifier - serverName gomatrixserverlib.ServerName + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + stream streams.StreamProvider + notifier *notifier.Notifier } // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. @@ -55,14 +54,13 @@ func NewOutputReceiptEventConsumer( stream streams.StreamProvider, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), - durable: cfg.Matrix.JetStream.Durable("SyncAPIReceiptConsumer"), - db: store, - notifier: notifier, - stream: stream, - serverName: cfg.Matrix.ServerName, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), + durable: cfg.Matrix.JetStream.Durable("SyncAPIReceiptConsumer"), + db: store, + notifier: notifier, + stream: stream, } } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index f767615c8..1b67f5684 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -364,11 +364,7 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom // TODO: check that it's a join and not a profile change (means unmarshalling prev_content) if membership == gomatrixserverlib.Join { // check it's a local join - _, domain, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) - if err != nil { - return sp, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } - if domain != s.cfg.Matrix.ServerName { + if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil { return sp, nil } @@ -390,9 +386,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( if msg.Event.StateKey() == nil { return } - if _, serverName, err := gomatrixserverlib.SplitID('@', *msg.Event.StateKey()); err != nil { - return - } else if serverName != s.cfg.Matrix.ServerName { + if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil { return } pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 49d84cca3..356e83263 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -37,15 +37,15 @@ import ( // OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. type OutputSendToDeviceEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - keyAPI keyapi.SyncKeyAPI - serverName gomatrixserverlib.ServerName // our server name - stream streams.StreamProvider - notifier *notifier.Notifier + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + keyAPI keyapi.SyncKeyAPI + isLocalServerName func(gomatrixserverlib.ServerName) bool + stream streams.StreamProvider + notifier *notifier.Notifier } // NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer. @@ -60,15 +60,15 @@ func NewOutputSendToDeviceEventConsumer( stream streams.StreamProvider, ) *OutputSendToDeviceEventConsumer { return &OutputSendToDeviceEventConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), - durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), - db: store, - keyAPI: keyAPI, - serverName: cfg.Matrix.ServerName, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), + db: store, + keyAPI: keyAPI, + isLocalServerName: cfg.Matrix.IsLocalServerName, + notifier: notifier, + stream: stream, } } @@ -89,7 +89,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] log.WithError(err).Errorf("send-to-device: failed to split user id, dropping message") return true } - if domain != s.serverName { + if !s.isLocalServerName(domain) { log.Tracef("ignoring send-to-device event with destination %s", domain) return true } @@ -114,7 +114,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] if output.Type == "m.room_key_request" { requestingDeviceID := gjson.GetBytes(output.SendToDeviceEvent.Content, "requesting_device_id").Str _, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender) - if requestingDeviceID != "" && senderDomain != s.serverName { + if requestingDeviceID != "" && !s.isLocalServerName(senderDomain) { // Mark the requesting device as stale, if we don't know about it. if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{ UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID, diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 86cf8e736..0d740ebfc 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -528,6 +528,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] BackwardsExtremities: backwardsExtremities, Limit: limit, ServerName: r.cfg.Matrix.ServerName, + VirtualHost: r.device.UserDomain(), }, &res) if err != nil { return nil, fmt.Errorf("PerformBackfill failed: %w", err) diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index bc3ad2384..4cc1a6a85 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -51,7 +51,7 @@ func Setup( // TODO: Add AS support for all handlers below. v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingSyncRequest(req, device) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -59,7 +59,7 @@ func Setup( return util.ErrorResponse(err) } return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, rsAPI, cfg, srp, lazyLoadCache) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/event/{eventID}", httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -68,7 +68,7 @@ func Setup( return util.ErrorResponse(err) } return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, syncDB, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/filter", @@ -93,7 +93,7 @@ func Setup( v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingKeyChangeRequest(req, device) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/context/{eventId}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -108,7 +108,7 @@ func Setup( vars["roomId"], vars["eventId"], lazyLoadCache, ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}", @@ -122,7 +122,7 @@ func Setup( req, device, syncDB, rsAPI, vars["roomId"], vars["eventId"], "", "", ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}", @@ -136,7 +136,7 @@ func Setup( req, device, syncDB, rsAPI, vars["roomId"], vars["eventId"], vars["relType"], "", ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}", @@ -150,7 +150,7 @@ func Setup( req, device, syncDB, rsAPI, vars["roomId"], vars["eventId"], vars["relType"], vars["eventType"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/search", @@ -191,7 +191,7 @@ func Setup( at := req.URL.Query().Get("at") return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, false, membership, notMembership, at) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/joined_members", diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 0d15ba726..3d6e7a770 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -590,7 +590,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( isGappedIncremental := limited && incremental // We want this users membership event, keep it in the list stateKey := *event.StateKey() - if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental { + if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID { newStateEvents = append(newStateEvents, event) if !stateFilter.IncludeRedundantMembers { p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID()) diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index 977815078..16a81e833 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -87,8 +87,7 @@ func (p *ReceiptStreamProvider) IncrementalSync( } ev := gomatrixserverlib.ClientEvent{ - Type: gomatrixserverlib.MReceipt, - RoomID: roomID, + Type: gomatrixserverlib.MReceipt, } content := make(map[string]ReceiptMRead) for _, receipt := range receipts { diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index a4985dbf4..483274481 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -433,7 +433,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { beforeJoinBody := fmt.Sprintf("Before invite in a %s room", tc.historyVisibility) beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": beforeJoinBody}) eventsToSend := append(room.Events(), beforeJoinEv) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } syncUntil(t, base, aliceDev.AccessToken, false, @@ -472,7 +472,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } syncUntil(t, base, aliceDev.AccessToken, false, diff --git a/syncapi/types/types.go b/syncapi/types/types.go index bba9be4ba..c4c7b39fb 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -493,6 +493,13 @@ func (jr JoinResponse) MarshalJSON() ([]byte, error) { if jr.Ephemeral != nil && len(jr.Ephemeral.Events) == 0 { a.Ephemeral = nil } + if jr.Ephemeral != nil { + // Remove the room_id from EDUs, as this seems to cause Element Web + // to trigger notifications - https://github.com/vector-im/element-web/issues/17263 + for i := range jr.Ephemeral.Events { + jr.Ephemeral.Events[i].RoomID = "" + } + } if jr.AccountData != nil && len(jr.AccountData.Events) == 0 { a.AccountData = nil } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 08614ebb7..de19c33de 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -2,6 +2,7 @@ package types import ( "encoding/json" + "reflect" "testing" "github.com/matrix-org/gomatrixserverlib" @@ -63,3 +64,102 @@ func TestNewInviteResponse(t *testing.T) { t.Fatalf("Invite response didn't contain correct info") } } + +func TestJoinResponse_MarshalJSON(t *testing.T) { + type fields struct { + Summary *Summary + State *ClientEvents + Timeline *Timeline + Ephemeral *ClientEvents + AccountData *ClientEvents + UnreadNotifications *UnreadNotifications + } + tests := []struct { + name string + fields fields + want []byte + wantErr bool + }{ + { + name: "empty state is removed", + fields: fields{ + State: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty accountdata is removed", + fields: fields{ + AccountData: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty ephemeral is removed", + fields: fields{ + Ephemeral: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty timeline is removed", + fields: fields{ + Timeline: &Timeline{}, + }, + want: []byte("{}"), + }, + { + name: "empty summary is removed", + fields: fields{ + Summary: &Summary{}, + }, + want: []byte("{}"), + }, + { + name: "unread notifications are removed, if everything else is empty", + fields: fields{ + UnreadNotifications: &UnreadNotifications{}, + }, + want: []byte("{}"), + }, + { + name: "unread notifications are NOT removed, if state is set", + fields: fields{ + State: &ClientEvents{Events: []gomatrixserverlib.ClientEvent{{Content: []byte("{}")}}}, + UnreadNotifications: &UnreadNotifications{NotificationCount: 1}, + }, + want: []byte(`{"state":{"events":[{"content":{},"type":""}]},"unread_notifications":{"highlight_count":0,"notification_count":1}}`), + }, + { + name: "roomID is removed from EDUs", + fields: fields{ + Ephemeral: &ClientEvents{ + Events: []gomatrixserverlib.ClientEvent{ + {RoomID: "!someRandomRoomID:test", Content: []byte("{}")}, + }, + }, + }, + want: []byte(`{"ephemeral":{"events":[{"content":{},"type":""}]}}`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jr := JoinResponse{ + Summary: tt.fields.Summary, + State: tt.fields.State, + Timeline: tt.fields.Timeline, + Ephemeral: tt.fields.Ephemeral, + AccountData: tt.fields.AccountData, + UnreadNotifications: tt.fields.UnreadNotifications, + } + got, err := jr.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", string(got), string(tt.want)) + } + }) + } +} diff --git a/sytest-whitelist b/sytest-whitelist index a543ab523..75a2da635 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -758,4 +758,9 @@ Can filter rooms/{roomId}/members Current state appears in timeline in private history with many messages after AS can publish rooms in their own list AS and main public room lists are separate -AS can deactivate a user \ No newline at end of file +AS can deactivate a user +/upgrade preserves direct room state +local user has tags copied to the new room +remote user has tags copied to the new room +/upgrade moves remote aliases to the new room +Local and remote users' homeservers remove a room from their public directory on upgrade diff --git a/userapi/api/api.go b/userapi/api/api.go index c5ba2bddf..d3f5aefc8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -78,7 +78,7 @@ type ClientUserAPI interface { QueryAcccessTokenAPI LoginTokenInternalAPI UserLoginAPI - QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error + QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error @@ -335,9 +335,10 @@ type PerformAccountCreationResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformPasswordUpdateRequest struct { - Localpart string // Required: The localpart for this account. - Password string // Required: The new password to set. - LogoutDevices bool // Optional: Whether to log out all user devices. + Localpart string // Required: The localpart for this account. + ServerName gomatrixserverlib.ServerName // Required: The domain for this account. + Password string // Required: The new password to set. + LogoutDevices bool // Optional: Whether to log out all user devices. } // PerformAccountCreationResponse is the response for PerformAccountCreation @@ -518,7 +519,8 @@ const ( ) type QueryPushersRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryPushersResponse struct { @@ -526,14 +528,16 @@ type QueryPushersResponse struct { } type PerformPusherSetRequest struct { - Pusher // Anonymous field because that's how clientapi unmarshals it. - Localpart string - Append bool `json:"append"` + Pusher // Anonymous field because that's how clientapi unmarshals it. + Localpart string + ServerName gomatrixserverlib.ServerName + Append bool `json:"append"` } type PerformPusherDeletionRequest struct { - Localpart string - SessionID int64 // Pusher corresponding to this SessionID will not be deleted + Localpart string + ServerName gomatrixserverlib.ServerName + SessionID int64 } // Pusher represents a push notification subscriber @@ -571,10 +575,11 @@ type QueryPushRulesResponse struct { } type QueryNotificationsRequest struct { - Localpart string `json:"localpart"` // Required. - From string `json:"from,omitempty"` - Limit int `json:"limit,omitempty"` - Only string `json:"only,omitempty"` + Localpart string `json:"localpart"` // Required. + ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required. + From string `json:"from,omitempty"` + Limit int `json:"limit,omitempty"` + Only string `json:"only,omitempty"` } type QueryNotificationsResponse struct { @@ -601,12 +606,17 @@ type PerformSetAvatarURLResponse struct { Changed bool `json:"changed"` } +type QueryNumericLocalpartRequest struct { + ServerName gomatrixserverlib.ServerName +} + type QueryNumericLocalpartResponse struct { ID int64 } type QueryAccountAvailabilityRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryAccountAvailabilityResponse struct { @@ -614,7 +624,9 @@ type QueryAccountAvailabilityResponse struct { } type QueryAccountByPasswordRequest struct { - Localpart, PlaintextPassword string + Localpart string + ServerName gomatrixserverlib.ServerName + PlaintextPassword string } type QueryAccountByPasswordResponse struct { @@ -638,11 +650,13 @@ type QueryLocalpartForThreePIDRequest struct { } type QueryLocalpartForThreePIDResponse struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryThreePIDsForLocalpartRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryThreePIDsForLocalpartResponse struct { @@ -652,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct { type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest type PerformSaveThreePIDAssociationRequest struct { - ThreePID, Localpart, Medium string + ThreePID string + Localpart string + ServerName gomatrixserverlib.ServerName + Medium string } diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 90834f7e3..ce661770f 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet return err } -func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error { - err := t.Impl.QueryNumericLocalpart(ctx, res) +func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error { + err := t.Impl.QueryNumericLocalpart(ctx, req, res) util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res)) return err } diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 79f1bf06f..42ae72e77 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -104,7 +104,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return false } - updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) + updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) if err != nil { log.WithError(err).Error("userapi EDU consumer") return false @@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats if !updated { return true } - if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil { log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") return false } diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 97c17e188..5d8924dda 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -2,12 +2,16 @@ package consumers import ( "context" + "database/sql" "encoding/json" + "errors" "fmt" "strings" "sync" "time" + "github.com/tidwall/gjson" + "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -185,13 +189,115 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy } } +func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error { + for _, membership := range localMembers { + // Copy any existing push rules from old -> new room + if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil { + return err + } + + // preserve m.direct room state + if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil { + return err + } + + // copy existing m.tag entries, if any + if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil { + return err + } + } + return nil +} + +func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error { + pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName) + if err != nil { + return fmt.Errorf("failed to query pushrules for user: %w", err) + } + if pushRules == nil { + return nil + } + + for _, roomRule := range pushRules.Global.Room { + if roomRule.RuleID != oldRoomID { + continue + } + cpRool := *roomRule + cpRool.RuleID = newRoomID + pushRules.Global.Room = append(pushRules.Global.Room, &cpRool) + rules, err := json.Marshal(pushRules) + if err != nil { + return err + } + if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil { + return fmt.Errorf("failed to update pushrules: %w", err) + } + } + return nil +} + +// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID +func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error { + // this is most likely not a DM, so skip updating m.direct state + if roomSize > 2 { + return nil + } + // Get direct message state + directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct") + if err != nil { + return fmt.Errorf("failed to get m.direct from database: %w", err) + } + directChats := gjson.ParseBytes(directChatsRaw) + newDirectChats := make(map[string][]string) + // iterate over all userID -> roomIDs + directChats.ForEach(func(userID, roomIDs gjson.Result) bool { + var found bool + for _, roomID := range roomIDs.Array() { + newDirectChats[userID.Str] = append(newDirectChats[userID.Str], roomID.Str) + // add the new roomID to m.direct + if roomID.Str == oldRoomID { + found = true + newDirectChats[userID.Str] = append(newDirectChats[userID.Str], newRoomID) + } + } + // Only hit the database if we found the old room as a DM for this user + if found { + var data []byte + data, err = json.Marshal(newDirectChats) + if err != nil { + return true + } + if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil { + return true + } + } + return true + }) + if err != nil { + return fmt.Errorf("failed to update m.direct state") + } + return nil +} + +func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error { + tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag") + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + if tag == nil { + return nil + } + return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag) +} + func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) if err != nil { return fmt.Errorf("s.localRoomMembers: %w", err) } - if event.Type() == gomatrixserverlib.MRoomMember { + switch { + case event.Type() == gomatrixserverlib.MRoomMember: cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll) var member *localMembership member, err = newLocalMembership(&cevent) @@ -203,6 +309,15 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gom // should also be pushed to the target user. members = append(members, member) } + case event.Type() == "m.room.tombstone" && event.StateKeyEquals(""): + // Handle room upgrades + oldRoomID := event.RoomID() + newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str + if err = s.handleRoomUpgrade(ctx, oldRoomID, newRoomID, members, roomSize); err != nil { + // while inconvenient, this shouldn't stop us from sending push notifications + log.WithError(err).Errorf("UserAPI: failed to handle room upgrade for users") + } + } // TODO: run in parallel with localRoomMembers. @@ -377,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) if err != nil { - return err + return fmt.Errorf("s.evaluatePushRules: %w", err) } a, tweaks, err := pushrules.ActionsToTweaks(actions) if err != nil { - return err + return fmt.Errorf("pushrules.ActionsToTweaks: %w", err) } // TODO: support coalescing. if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { @@ -393,9 +508,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr return nil } - devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) + devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks) if err != nil { - return err + return fmt.Errorf("s.localPushDevices: %w", err) } n := &api.Notification{ @@ -412,18 +527,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr RoomID: event.RoomID(), TS: gomatrixserverlib.AsTimestamp(time.Now()), } - if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil { - return err + if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil { + return fmt.Errorf("s.db.InsertNotification: %w", err) } if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { - return err + return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err) } // We do this after InsertNotification. Thus, this should always return >=1. - userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) + userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, mem.Domain, tables.AllNotifications) if err != nil { - return err + return fmt.Errorf("s.db.GetNotificationCount: %w", err) } log.WithFields(log.Fields{ @@ -474,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr } if len(rejected) > 0 { - s.deleteRejectedPushers(ctx, rejected, mem.Localpart) + s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain) } }() @@ -491,7 +606,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * } // Get accountdata to check if the event.Sender() is ignored by mem.LocalPart - data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list") + data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list") if err != nil { return nil, err } @@ -506,7 +621,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * return nil, fmt.Errorf("user %s is ignored", sender) } } - ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart) + ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain) if err != nil { return nil, err } @@ -578,10 +693,10 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err // localPushDevices pushes to the configured devices of a local // user. The map keys are [url][format]. -func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { - pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) +func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { + pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db) if err != nil { - return nil, "", err + return nil, "", fmt.Errorf("util.GetPushDevices: %w", err) } var profileTag string @@ -676,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri } // deleteRejectedPushers deletes the pushers associated with the given devices. -func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { +func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) { log.WithFields(log.Fields{ "localpart": localpart, "app_id0": devices[0].AppID, @@ -684,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev }).Warnf("Deleting pushers rejected by the HTTP push gateway") for _, d := range devices { - if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { + if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil { log.WithFields(log.Fields{ "localpart": localpart, }).WithError(err).Errorf("Unable to delete rejected pusher") diff --git a/userapi/internal/api.go b/userapi/internal/api.go index a814a89ab..7f30449e5 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if req.DataType == "" { return fmt.Errorf("data type must not be empty") } - if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil { + if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil { util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed") return fmt.Errorf("failed to save account data: %w", err) } @@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } - deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) + deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) if err != nil { logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") return err @@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } - if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil { + if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil { logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") return err } @@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P if serverName == "" { serverName = a.Config.Matrix.ServerName } - // XXXX: Use the server name here - acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %s is not local", serverName) + } + acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists switch req.OnConflict { @@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { - return err + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil { + return fmt.Errorf("a.DB.SetDisplayName: %w", err) } postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) @@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P } func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { - if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + if !a.Config.Matrix.IsLocalServerName(req.ServerName) { + return fmt.Errorf("server name %s is not local", req.ServerName) + } + if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil { return err } if req.LogoutDevices { - if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil { + if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil { return err } } @@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe if serverName == "" { serverName = a.Config.Matrix.ServerName } - _ = serverName - // XXXX: Use the server name here + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %s is not local", serverName) + } util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") - dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) + dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err } @@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) + devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } } else { - err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs) + err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs) } if err != nil { return err @@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate( req *api.PerformLastSeenUpdateRequest, res *api.PerformLastSeenUpdateResponse, ) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("server name %s is not local", domain) + } + if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) } return nil } func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") return err } - dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("server name %s is not local", domain) + } + dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID) if err == sql.ErrNoRows { res.DeviceExists = false return nil @@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf return nil } - err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName) if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") return err @@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) } - prof, err := a.DB.GetProfileByLocalpart(ctx, local) + prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain) if err != nil { if err == sql.ErrNoRows { return nil @@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) } - devs, err := a.DB.GetDevicesByLocalpart(ctx, local) + devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain) if err != nil { return err } @@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } if req.DataType != "" { var data json.RawMessage - data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType) if err != nil { return err } @@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } return nil } - global, rooms, err := a.DB.GetAccountData(ctx, local) + global, rooms, err := a.DB.GetAccountData(ctx, local, domain) if err != nil { return err } @@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc if !a.Config.Matrix.IsLocalServerName(domain) { return nil } - acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) + acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain) if err != nil { return err } @@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe AccountType: api.AccountTypeAppService, } - localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) + localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) if err != nil { return nil, err } if localpart != "" { // AS is masquerading as another user // Verify that the user is registered - account, err := a.DB.GetAccountByLocalpart(ctx, localpart) + account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain) // Verify that the account exists and either appServiceID matches or // it belongs to the appservice user namespaces if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { @@ -631,7 +643,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a return err } - err = a.DB.DeactivateAccount(ctx, req.Localpart) + err = a.DB.DeactivateAccount(ctx, req.Localpart, serverName) res.AccountDeactivated = err == nil return err } @@ -794,7 +806,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query if req.Only == "highlight" { filter = tables.HighlightNotifications } - notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter) + notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter) if err != nil { return err } @@ -822,23 +834,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform } } if req.Pusher.Kind == "" { - return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) + return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName) } if req.Pusher.PushKeyTS == 0 { req.Pusher.PushKeyTS = int64(time.Now().Unix()) } - return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) + return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName) } func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { - pushers, err := a.DB.GetPushers(ctx, req.Localpart) + pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName) if err != nil { return err } for i := range pushers { logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) if pushers[i].SessionID != req.SessionID { - err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart) + err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName) if err != nil { return err } @@ -849,7 +861,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { var err error - res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart) + res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName) return err } @@ -875,11 +887,11 @@ func (a *UserInternalAPI) PerformPushRulesPut( } func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) } - pushRules, err := a.DB.QueryPushRules(ctx, localpart) + pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain) if err != nil { return fmt.Errorf("failed to query push rules: %w", err) } @@ -888,14 +900,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush } func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL) res.Profile = profile res.Changed = changed return err } -func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { - id, err := a.DB.GetNewNumericLocalpart(ctx) +func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error { + id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName) if err != nil { return err } @@ -905,12 +917,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { var err error - res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart) + res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName) return err } func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { - acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword) + acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword) switch err { case sql.ErrNoRows: // user does not exist return nil @@ -926,23 +938,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { - profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) + profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName) res.Profile = profile res.Changed = changed return err } func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { - localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) + localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) if err != nil { return err } res.Localpart = localpart + res.ServerName = domain return nil } func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { - r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart) + r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName) if err != nil { return err } @@ -955,7 +968,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { - return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium) + return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium) } const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index 87f25e5e2..3b211db5b 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain) } - if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { + if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil { res.Data = nil if err == sql.ErrNoRows { return nil diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index aa5d46d9f..87ae058c2 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL( func (h *httpUserInternalAPI) QueryNumericLocalpart( ctx context.Context, + request *api.QueryNumericLocalpartRequest, response *api.QueryNumericLocalpartResponse, ) error { return httputil.CallInternalRPCAPI( "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath, - h.httpClient, ctx, &struct{}{}, response, + h.httpClient, ctx, request, response, ) } diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 99148b760..661fecfae 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -15,12 +15,9 @@ package inthttp import ( - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/util" ) // nolint: gocyclo @@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), ) - // TODO: Look at the shape of this - internalAPIMux.Handle(QueryNumericLocalpartPath, - httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse { - response := api.QueryNumericLocalpartResponse{} - if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + QueryNumericLocalpartPath, + httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart), ) internalAPIMux.Handle( diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index f556ea352..51eaa9856 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -61,12 +61,12 @@ func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) err // GetAndSendNotificationData reads the database and sends data about unread // notifications to the Sync API server. func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error { - localpart, _, err := gomatrixserverlib.SplitID('@', userID) + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return err } - ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID) + ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, domain, roomID) if err != nil { return err } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 28ef26559..c22b7658f 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -29,40 +29,40 @@ import ( ) type Profile interface { - GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error) + SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) } type Account interface { // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) - GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) - GetNewNumericLocalpart(ctx context.Context) (int64, error) - CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) - GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) - DeactivateAccount(ctx context.Context, localpart string) (err error) - SetPassword(ctx context.Context, localpart string, plaintextPassword string) error + CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error) + GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) + CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) + GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) + DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) + SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error } type AccountData interface { - SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error - GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) + SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error + GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given // localpart, room ID and type. // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval - GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) - QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error) + GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error) + QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error) } type Device interface { GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) - GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked @@ -70,12 +70,12 @@ type Device interface { // an error will be returned. // If no device ID is given one is generated. // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) - UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error - RemoveDevices(ctx context.Context, localpart string, devices []string) error + CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error + RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error) } type KeyBackup interface { @@ -107,26 +107,26 @@ type OpenID interface { } type Pusher interface { - UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error - GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) - RemovePusher(ctx context.Context, appid, pushkey, localpart string) error + UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error + GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error) + RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error RemovePushers(ctx context.Context, appid, pushkey string) error } type ThreePID interface { - SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) + SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) - GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) - GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) } type Notification interface { - InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error - DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) - SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error) - GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) - GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) - GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) + InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) DeleteOldNotifications(ctx context.Context) error } diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 0b6a3af6d..2a4777d74 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -29,27 +30,28 @@ const accountDataSchema = ` CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) room_id TEXT, -- The account data type type TEXT NOT NULL, -- The account data content - content TEXT NOT NULL, - - PRIMARY KEY(localpart, room_id, type) + content TEXT NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type); ` const insertAccountDataSQL = ` - INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) - ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content + INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5) + ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = EXCLUDED.content ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4" type accountDataStatements struct { insertAccountDataStmt *sql.Stmt @@ -71,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { } func (s *accountDataStatements) InsertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) - _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) + _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content) return } func (s *accountDataStatements) SelectAccountData( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, error, ) { - rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName) if err != nil { return nil, nil, err } @@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData( } func (s *accountDataStatements) SelectAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte stmt := s.selectAccountDataByTypeStmt - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index d3d0bbcd5..9c46249a7 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/matrix-org/gomatrixserverlib" @@ -34,7 +35,8 @@ const accountsSchema = ` -- Stores data about accounts. CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When this account was first created, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The password hash for this account. Can be NULL if this is a passwordless account. @@ -48,25 +50,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( -- TODO: -- upgraded_ts, devices, any email reset stuff? ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name); ` const insertAccountSQL = "" + - "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + - "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3" const deactivateAccountSQL = "" + - "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" + "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'" + "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1" type accountsStatements struct { insertAccountStmt *sql.Stmt @@ -122,59 +126,62 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error if accountType != api.AccountTypeAppService { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType) } if err != nil { - return nil, err + return nil, fmt.Errorf("insertAccountStmt: %w", err) } return &api.Account{ Localpart: localpart, - UserID: userutil.MakeUserID(localpart, s.serverName), - ServerName: s.serverName, + UserID: userutil.MakeUserID(localpart, serverName), + ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, }, nil } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart, passwordHash string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + passwordHash string, ) (err error) { - _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) return } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -185,19 +192,17 @@ func (s *accountsStatements) SelectAccountByLocalpart( acc.AppServiceID = appserviceIDPtr.String } - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName - + acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName) return &acc, nil } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } - err = stmt.QueryRowContext(ctx).Scan(&id) + err = stmt.QueryRowContext(ctx, serverName).Scan(&id) return id + 1, err } diff --git a/userapi/storage/postgres/deltas/2022110411000000_server_names.go b/userapi/storage/postgres/deltas/2022110411000000_server_names.go new file mode 100644 index 000000000..375f775be --- /dev/null +++ b/userapi/storage/postgres/deltas/2022110411000000_server_names.go @@ -0,0 +1,82 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +var serverNamesTables = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_devices", + "userapi_notifications", + "userapi_openid_tokens", + "userapi_profiles", + "userapi_pushers", + "userapi_threepids", +} + +// These tables have a PRIMARY KEY constraint which we need to drop so +// that we can recreate a new unique index that contains the server name. +// If the new key doesn't exist (i.e. the database was created before the +// table rename migration) we'll try to drop the old one instead. +var serverNamesDropPK = map[string]string{ + "userapi_accounts": "account_accounts", + "userapi_account_datas": "account_data", + "userapi_profiles": "account_profiles", +} + +// These indices are out of date so let's drop them. They will get recreated +// automatically. +var serverNamesDropIndex = []string{ + "userapi_pusher_localpart_idx", + "userapi_pusher_app_id_pushkey_localpart_idx", + "userapi_pusher_app_id_pushkey_idx", +} + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';", + pq.QuoteIdentifier(table), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("add server name to %q error: %w", table, err) + } + } + for newTable, oldTable := range serverNamesDropPK { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;", + pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(newTable+"_pkey"), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop new PK from %q error: %w", newTable, err) + } + q = fmt.Sprintf( + "ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;", + pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(oldTable+"_pkey"), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop old PK from %q error: %w", newTable, err) + } + } + for _, index := range serverNamesDropIndex { + q := fmt.Sprintf( + "DROP INDEX IF EXISTS %s;", + pq.QuoteIdentifier(index), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop index %q error: %w", index, err) + } + } + return nil +} diff --git a/userapi/storage/postgres/deltas/2022110411000001_server_names.go b/userapi/storage/postgres/deltas/2022110411000001_server_names.go new file mode 100644 index 000000000..04a47fa7b --- /dev/null +++ b/userapi/storage/postgres/deltas/2022110411000001_server_names.go @@ -0,0 +1,28 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "UPDATE %s SET server_name = %s WHERE server_name = '';", + pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("write server names to %q error: %w", table, err) + } + } + return nil +} diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 8b7fbd6cf..2dd216189 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/lib/pq" @@ -50,6 +51,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( -- as it is smaller, makes it clearer that we only manage devices for our own users, and may make -- migration to different domain names easier. localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When this devices was first recognised on the network, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The display name, human friendlier than device_id and updatable @@ -65,39 +67,39 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( ); -- Device IDs must be unique for a given user. -CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, server_name, device_id); ` const insertDeviceSQL = "" + - "INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + + "INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + " RETURNING session_id" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" const deleteDeviceSQL = "" + - "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3" const deleteDevicesSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" type devicesStatements struct { insertDeviceStmt *sql.Stmt @@ -148,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) InsertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { - return nil, err + if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { + return nil, fmt.Errorf("insertDeviceStmt: %w", err) } return &api.Device{ ID: id, - UserID: userutil.MakeUserID(localpart, s.serverName), + UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, SessionID: sessionID, LastSeenTS: createdTimeMS, @@ -170,38 +173,45 @@ func (s *devicesStatements) InsertDevice( // deleteDevice removes a single device by id and user localpart. func (s *devicesStatements) DeleteDevice( - ctx context.Context, txn *sql.Tx, id, localpart string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) + _, err := stmt.ExecContext(ctx, id, localpart, serverName) return err } // deleteDevices removes a single or multiple devices by ids and user localpart. // Returns an error if the execution failed. func (s *devicesStatements) DeleteDevices( - ctx context.Context, txn *sql.Tx, localpart string, devices []string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) - _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) + _, err := stmt.ExecContext(ctx, localpart, serverName, pq.Array(devices)) return err } // deleteDevicesByLocalpart removes all devices for the // given user localpart. func (s *devicesStatements) DeleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID) return err } func (s *devicesStatements) UpdateDeviceName( - ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID) return err } @@ -210,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) dev.AccessToken = accessToken } return &dev, err @@ -222,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { var dev api.Device var displayName, ip sql.NullString var lastseenTS sql.NullInt64 stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) + err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) if displayName.Valid { dev.DisplayName = displayName.String } @@ -254,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName var lastseents sql.NullInt64 var displayName sql.NullString for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { @@ -266,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s if lastseents.Valid { dev.LastSeenTS = lastseents.Int64 } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } func (s *devicesStatements) SelectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID) if err != nil { return devices, err @@ -307,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart( dev.UserAgent = useragent.String } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) return err } diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index 420b1babc..a1cff2c46 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -43,6 +43,7 @@ const notificationSchema = ` CREATE TABLE IF NOT EXISTS userapi_notifications ( id BIGSERIAL PRIMARY KEY, localpart TEXT NOT NULL, + server_name TEXT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, stream_pos BIGINT NOT NULL, @@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications ( read BOOLEAN NOT NULL DEFAULT FALSE ); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id); ` const insertNotificationSQL = "" + - "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const deleteNotificationsUpToSQL = "" + - "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4" const updateNotificationReadSQL = "" + - "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1" const selectNotificationSQL = "" + - "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + - "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + - ") AND NOT read ORDER BY localpart, id LIMIT $4" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" + + "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $5" const selectNotificationCountSQL = "" + - "SELECT COUNT(DISTINCT(room_id)) FROM userapi_notifications WHERE localpart = $1 AND (" + - "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + "SELECT COUNT(DISTINCT(room_id)) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + ") AND NOT read" const selectRoomNotificationCountsSQL = "" + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + - "WHERE localpart = $1 AND room_id = $2 AND NOT read" + "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read" const cleanNotificationsSQL = "" + "DELETE FROM userapi_notifications WHERE" + @@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local if err != nil { return err } - _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs)) return err } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { return nil, 0, err @@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { - err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { - err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 06ae30d08..68d87f007 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When the token expires, as a unix timestamp (ms resolution). token_expires_at_ms BIGINT NOT NULL ); ` const insertOpenIDTokenSQL = "" + - "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" + "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { insertTokenStmt *sql.Stmt @@ -54,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, + token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS) return } @@ -69,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( token string, ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes + var localpart string + var serverName gomatrixserverlib.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( - &openIDTokenAttrs.UserID, + &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, ) + openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve token from the db") diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 2753b23d9..df4e0db63 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -23,42 +23,46 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account avatar_url TEXT ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name); ` const insertProfileSQL = "" + - "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2" const setAvatarURLSQL = "" + "UPDATE userapi_profiles AS new" + " SET avatar_url = $1" + " FROM userapi_profiles AS old" + - " WHERE new.localpart = $2" + + " WHERE new.localpart = $2 AND new.server_name = $3" + " RETURNING new.display_name, old.avatar_url <> new.avatar_url" const setDisplayNameSQL = "" + "UPDATE userapi_profiles AS new" + " SET display_name = $1" + " FROM userapi_profiles AS old" + - " WHERE new.localpart = $2" + + " WHERE new.localpart = $2 AND new.server_name = $3" + " RETURNING new.avatar_url, old.display_name <> new.display_name" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { serverNoticesLocalpart string @@ -87,18 +91,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( - &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( + &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { return nil, err @@ -107,28 +113,34 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ - Localpart: localpart, - AvatarURL: avatarURL, + Localpart: localpart, + ServerName: string(serverName), + AvatarURL: avatarURL, } var changed bool stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed) + err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed) return profile, changed, err } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, + ServerName: string(serverName), DisplayName: displayName, } var changed bool stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed) + err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed) return profile, changed, err } @@ -146,7 +158,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index 42618cf34..b2ef4966b 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( id BIGSERIAL PRIMARY KEY, -- The Matrix user ID localpart for this pusher localpart TEXT NOT NULL, + server_name TEXT NOT NULL, session_id BIGINT DEFAULT NULL, profile_tag TEXT, kind TEXT NOT NULL, @@ -46,22 +48,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( ); -- For faster retrieving by localpart. -CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name); -- Pushkey must be unique for a given app. CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); ` const insertPusherSQL = "" + - "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + - "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + - "ON CONFLICT (app_id, pushkey) DO UPDATE SET localpart = $1, session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + + "ON CONFLICT (app_id, pushkey, server_name) DO UPDATE SET localpart = $1, session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12" const selectPushersSQL = "" + - "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" const deletePusherSQL = "" + - "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4" const deletePushersByAppIdAndPushKeySQL = "" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" @@ -92,18 +94,19 @@ type pushersStatements struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) - logrus.Debugf("Created pusher %d", session_id) + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err } func (s *pushersStatements) SelectPushers( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} - rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName) if err != nil { return pushers, err @@ -140,9 +143,10 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( - ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, + ctx context.Context, txn *sql.Tx, appid, pushkey, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err } diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index de97e60b1..01f9e12e8 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -15,6 +15,8 @@ package postgres import ( + "context" + "database/sql" "fmt" "time" @@ -47,18 +49,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpUniquePusher, Down: deltas.DownUniquePusher, }) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + }) if err = m.Up(base.Context()); err != nil { return nil, err } - accountDataTable, err := NewPostgresAccountDataTable(db) - if err != nil { - return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) - } accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) } + accountDataTable, err := NewPostgresAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) + } devicesTable, err := NewPostgresDevicesTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err) @@ -99,6 +107,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, fmt.Errorf("NewPostgresStatsTable: %w", err) } + + m = sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names populate", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNamesPopulate(ctx, txn, serverName) + }, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 11af76161..f41c43122 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -33,21 +34,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids ( medium TEXT NOT NULL DEFAULT 'email', -- The localpart of the Matrix user ID associated to this 3PID localpart TEXT NOT NULL, + server_name TEXT NOT NULL, PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart); +CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart, server_name); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" + "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2" const insertThreePIDSQL = "" + - "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)" const deleteThreePIDSQL = "" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" @@ -75,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) - err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { - return "", nil + return "", "", nil } return } func (s *threepidStatements) SelectThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { return } @@ -109,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( } func (s *threepidStatements) InsertThreePID( - ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, + ctx context.Context, txn *sql.Tx, threepid, medium, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) return } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index e4242913d..4bd8a04e7 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -68,9 +68,10 @@ const ( // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword string, ) (*api.Account, error) { - hash, err := d.Accounts.SelectPasswordHash(ctx, localpart) + hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName) if err != nil { return nil, err } @@ -80,24 +81,27 @@ func (d *Database) GetAccountByPassword( if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { return nil, err } - return d.Accounts.SelectAccountByLocalpart(ctx, localpart) + return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) } // GetProfileByLocalpart returns the profile associated with the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart. func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { - return d.Profiles.SelectProfileByLocalpart(ctx, localpart) + return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName) } // SetAvatarURL updates the avatar URL of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL) return err }) return @@ -106,10 +110,12 @@ func (d *Database) SetAvatarURL( // SetDisplayName updates the display name of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName) return err }) return @@ -117,14 +123,15 @@ func (d *Database) SetDisplayName( // SetPassword sets the account password to the given hash. func (d *Database) SetPassword( - ctx context.Context, localpart, plaintextPassword string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword string, ) error { hash, err := d.hashPassword(plaintextPassword) if err != nil { return err } return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.Accounts.UpdatePassword(ctx, localpart, hash) + return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash) }) } @@ -132,21 +139,22 @@ func (d *Database) SetPassword( // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword, appserviceID string, accountType api.AccountType, ) (acc *api.Account, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { // For guest accounts, we create a new numeric local part if accountType == api.AccountTypeGuest { var numLocalpart int64 - numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn) + numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName) if err != nil { - return err + return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err) } localpart = strconv.FormatInt(numLocalpart, 10) plaintextPassword = "" appserviceID = "" } - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) + acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType) return err }) return @@ -155,7 +163,9 @@ func (d *Database) CreateAccount( // WARNING! This function assumes that the relevant mutexes have already // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { var err error var account *api.Account @@ -167,28 +177,28 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { + if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil { return nil, sqlutil.ErrUserExists } - if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { - return nil, err + if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil { + return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err) } - pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) prbs, err := json.Marshal(pushRuleSets) if err != nil { - return nil, err + return nil, fmt.Errorf("json.Marshal: %w", err) } - if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { - return nil, err + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil { + return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err) } return account, nil } func (d *Database) QueryPushRules( ctx context.Context, - localpart string, + localpart string, serverName gomatrixserverlib.ServerName, ) (*pushrules.AccountRuleSets, error) { - data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules") + data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules") if err != nil { return nil, err } @@ -196,13 +206,13 @@ func (d *Database) QueryPushRules( // If we didn't find any default push rules then we should just generate some // fresh ones. if len(data) == 0 { - pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) prbs, err := json.Marshal(pushRuleSets) if err != nil { return nil, fmt.Errorf("failed to marshal default push rules: %w", err) } err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil { + if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil { return fmt.Errorf("failed to save default push rules: %w", dbErr) } return nil @@ -225,22 +235,23 @@ func (d *Database) QueryPushRules( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, content json.RawMessage, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content) + return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content) }) } // GetAccountData returns account data related to a given localpart // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string) ( +func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ( global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error, ) { - return d.AccountDatas.SelectAccountData(ctx, localpart) + return d.AccountDatas.SelectAccountData(ctx, localpart, serverName) } // GetAccountDataByType returns account data matching a given @@ -248,18 +259,19 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, ) (data json.RawMessage, err error) { return d.AccountDatas.SelectAccountDataByType( - ctx, localpart, roomID, dataType, + ctx, localpart, serverName, roomID, dataType, ) } // GetNewNumericLocalpart generates and returns a new unused numeric localpart func (d *Database) GetNewNumericLocalpart( - ctx context.Context, + ctx context.Context, serverName gomatrixserverlib.ServerName, ) (int64, error) { - return d.Accounts.SelectNewNumericLocalpart(ctx, nil) + return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName) } func (d *Database) hashPassword(plaintext string) (hash string, err error) { @@ -276,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use") // If the third-party identifier is already part of an association, returns Err3PIDInUse. // Returns an error if there was a problem talking to the database. func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, + ctx context.Context, threepid string, + localpart string, serverName gomatrixserverlib.ServerName, + medium string, ) (err error) { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - user, err := d.ThreePIDs.SelectLocalpartForThreePID( + user, _, err := d.ThreePIDs.SelectLocalpartForThreePID( ctx, txn, threepid, medium, ) if err != nil { @@ -290,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation( return Err3PIDInUse } - return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart) + return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName) }) } @@ -313,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation( // Returns an error if there was a problem talking to the database. func (d *Database) GetLocalpartForThreePID( ctx context.Context, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) } @@ -322,16 +336,17 @@ func (d *Database) GetLocalpartForThreePID( // If no association is known for this user, returns an empty slice. // Returns an error if there was an issue talking to the database. func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart) + return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName) } // CheckAccountAvailability checks if the username/localpart is already present // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart) +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) { + _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) if err == sql.ErrNoRows { return true, nil } @@ -341,12 +356,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // GetAccountByLocalpart returns the account associated with the given localpart. // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { // try to get the account with lowercase localpart (majority) - acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart)) + acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName) if err == sql.ErrNoRows { - acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request + acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request } return acc, err } @@ -359,20 +374,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi } // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { +func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.Accounts.DeactivateAccount(ctx, localpart) + return d.Accounts.DeactivateAccount(ctx, localpart, serverName) }) } // CreateOpenIDToken persists a new token that was issued for OpenID Connect func (d *Database) CreateOpenIDToken( ctx context.Context, - token, localpart string, + token, userID string, ) (int64, error) { + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return 0, nil + } expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS) }) return expiresAtMS, err } @@ -539,16 +558,19 @@ func (d *Database) GetDeviceByAccessToken( // GetDeviceByID returns the device matching the given ID. // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { - return d.Devices.SelectDeviceByID(ctx, localpart, deviceID) + return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID) } // GetDevicesByLocalpart returns the devices matching the given localpart. func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Device, error) { - return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "") + return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -562,18 +584,18 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap // If no device ID is given one is generated. // Returns the device on success. func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string, ) (dev *api.Device, returnErr error) { if deviceID != nil && *deviceID != "" { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device - if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil { + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { return err } - dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) return err }) } else { @@ -588,7 +610,7 @@ func (d *Database) CreateDevice( returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error - dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) return err }) if returnErr == nil { @@ -614,10 +636,12 @@ func generateDeviceID() (string, error) { // UpdateDevice updates the given device with the display name. // Returns SQL error if there are problems and nil on success. func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName) + return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName) }) } @@ -626,10 +650,12 @@ func (d *Database) UpdateDevice( // If the devices don't exist, it will not return an error // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows { return err } return nil @@ -640,14 +666,16 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) (devices []api.Device, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID) if err != nil { return err } - if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows { return err } return nil @@ -656,9 +684,9 @@ func (d *Database) RemoveAllDevices( } // UpdateDeviceLastSeen updates a last seen timestamp and the ip address. -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error { +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent) + return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent) }) } @@ -706,38 +734,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( return d.LoginTokens.SelectLoginToken(ctx, token) } -func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { +func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) + return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) }) } -func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) { +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) + affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos) return err }) return } -func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) { +func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) + affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b) return err }) return } -func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { - return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter) +func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter) } -func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { - return d.Notifications.SelectCount(ctx, nil, localpart, filter) +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) { + return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter) } -func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { - return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID) +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) { + return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID) } func (d *Database) DeleteOldNotifications(ctx context.Context) error { @@ -747,7 +775,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error { } func (d *Database) UpsertPusher( - ctx context.Context, p api.Pusher, localpart string, + ctx context.Context, p api.Pusher, + localpart string, serverName gomatrixserverlib.ServerName, ) error { data, err := json.Marshal(p.Data) if err != nil { @@ -766,25 +795,26 @@ func (d *Database) UpsertPusher( p.ProfileTag, p.Language, string(data), - localpart) + localpart, + serverName) }) } // GetPushers returns the pushers matching the given localpart. func (d *Database) GetPushers( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { - return d.Pushers.SelectPushers(ctx, nil, localpart) + return d.Pushers.SelectPushers(ctx, nil, localpart, serverName) } // RemovePusher deletes one pusher // Invoked when `append` is true and `kind` is null in // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set func (d *Database) RemovePusher( - ctx context.Context, appid, pushkey, localpart string, + ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) + err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName) if err == sql.ErrNoRows { return nil } diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index af12decb3..2fbdc5732 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -28,27 +29,28 @@ const accountDataSchema = ` CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) room_id TEXT, -- The account data type type TEXT NOT NULL, -- The account data content - content TEXT NOT NULL, - - PRIMARY KEY(localpart, room_id, type) + content TEXT NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type); ` const insertAccountDataSQL = ` - INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) - ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 + INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5) + ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5 ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4" type accountDataStatements struct { db *sql.DB @@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { } func (s *accountDataStatements) InsertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, content json.RawMessage, ) error { - _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content) return err } func (s *accountDataStatements) SelectAccountData( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, error, ) { - rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName) if err != nil { return nil, nil, err } @@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData( } func (s *accountDataStatements) SelectAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte stmt := s.selectAccountDataByTypeStmt - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 671c1aa04..f4ebe2158 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -34,7 +34,8 @@ const accountsSchema = ` -- Stores data about accounts. CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When this account was first created, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The password hash for this account. Can be NULL if this is a passwordless account. @@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( -- TODO: -- upgraded_ts, devices, any email reset stuff? ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name); ` const insertAccountSQL = "" + - "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + - "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3" const deactivateAccountSQL = "" + - "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" + "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0" + "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1" type accountsStatements struct { db *sql.DB @@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, + hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt var err error if accountType != api.AccountTypeAppService { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType) } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount( return &api.Account{ Localpart: localpart, - UserID: userutil.MakeUserID(localpart, s.serverName), - ServerName: s.serverName, + UserID: userutil.MakeUserID(localpart, serverName), + ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, }, nil } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart, passwordHash string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + passwordHash string, ) (err error) { - _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) return } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart( acc.AppServiceID = appserviceIDPtr.String } - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName - + acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName) return &acc, nil } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } - err = stmt.QueryRowContext(ctx).Scan(&id) + err = stmt.QueryRowContext(ctx, serverName).Scan(&id) if err == sql.ErrNoRows { return 1, nil } diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go index 9158cb365..2de85005f 100644 --- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error { ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, + server_name TEXT NOT NULL, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT, diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index a9224db6b..636ce4efc 100644 --- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { session_id INTEGER, device_id TEXT , localpart TEXT , + server_name TEXT NOT NULL, created_ts BIGINT, display_name TEXT, last_seen_ts BIGINT, diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go index 230bc1433..471e496cd 100644 --- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, + server_name TEXT NOT NULL, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT, diff --git a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go new file mode 100644 index 000000000..c11ea6844 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go @@ -0,0 +1,108 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +var serverNamesTables = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_devices", + "userapi_notifications", + "userapi_openid_tokens", + "userapi_profiles", + "userapi_pushers", + "userapi_threepids", +} + +// These tables have a PRIMARY KEY constraint which we need to drop so +// that we can recreate a new unique index that contains the server name. +var serverNamesDropPK = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_profiles", +} + +// These indices are out of date so let's drop them. They will get recreated +// automatically. +var serverNamesDropIndex = []string{ + "userapi_pusher_localpart_idx", + "userapi_pusher_app_id_pushkey_localpart_idx", +} + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;", + pq.QuoteIdentifier(table), + ) + var c int + if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 { + continue + } + q = fmt.Sprintf( + "SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'", + pq.QuoteIdentifier(table), + ) + if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 { + logrus.Infof("Table %s already has column, skipping", table) + continue + } + if c == 0 { + q = fmt.Sprintf( + "ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';", + pq.QuoteIdentifier(table), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("add server name to %q error: %w", table, err) + } + } + } + for _, table := range serverNamesDropPK { + q := fmt.Sprintf( + "SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;", + pq.QuoteIdentifier(table), + ) + var c int + var sql string + if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 { + continue + } + q = fmt.Sprintf(` + %s; -- create temporary table + INSERT INTO %s SELECT * FROM %s; -- copy data + DROP TABLE %s; -- drop original table + ALTER TABLE %s RENAME TO %s; -- rename new table + `, + strings.Replace(sql, table, table+"_tmp", 1), // create temporary table + table+"_tmp", table, // copy data + table, // drop original table + table+"_tmp", table, // rename new table + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop PK from %q error: %w", table, err) + } + } + for _, index := range serverNamesDropIndex { + q := fmt.Sprintf( + "DROP INDEX IF EXISTS %s;", + pq.QuoteIdentifier(index), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop index %q error: %w", index, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go new file mode 100644 index 000000000..04a47fa7b --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go @@ -0,0 +1,28 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "UPDATE %s SET server_name = %s WHERE server_name = '';", + pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("write server names to %q error: %w", table, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index e53a08062..c5db34bd7 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( session_id INTEGER, device_id TEXT , localpart TEXT , + server_name TEXT NOT NULL, created_ts BIGINT, display_name TEXT, last_seen_ts BIGINT, ip TEXT, user_agent TEXT, - UNIQUE (localpart, device_id) + UNIQUE (localpart, server_name, device_id) ); ` const insertDeviceSQL = "" + - "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + "INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" const selectDevicesCountSQL = "" + "SELECT COUNT(access_token) FROM userapi_devices" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" const deleteDeviceSQL = "" + - "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3" const deleteDevicesSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" type devicesStatements struct { db *sql.DB @@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) InsertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 @@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice( return nil, err } sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { + if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } return &api.Device{ ID: id, - UserID: userutil.MakeUserID(localpart, s.serverName), + UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, SessionID: sessionID, LastSeenTS: createdTimeMS, @@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice( } func (s *devicesStatements) DeleteDevice( - ctx context.Context, txn *sql.Tx, id, localpart string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) + _, err := stmt.ExecContext(ctx, id, localpart, serverName) return err } func (s *devicesStatements) DeleteDevices( - ctx context.Context, txn *sql.Tx, localpart string, devices []string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) + orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1) prep, err := s.db.Prepare(orig) if err != nil { return err } stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) + params := make([]interface{}, len(devices)+2) params[0] = localpart + params[1] = serverName for i, v := range devices { - params[i+1] = v + params[i+2] = v } _, err = stmt.ExecContext(ctx, params...) return err } func (s *devicesStatements) DeleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID) return err } func (s *devicesStatements) UpdateDeviceName( - ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID) return err } @@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) dev.AccessToken = accessToken } return &dev, err @@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { var dev api.Device var displayName, ip sql.NullString stmt := s.selectDeviceByIDStmt var lastseenTS sql.NullInt64 - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) + err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) if displayName.Valid { dev.DisplayName = displayName.String } @@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID( } func (s *devicesStatements) SelectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID) if err != nil { return devices, err @@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( dev.UserAgent = useragent.String } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } @@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName var displayName sql.NullString var lastseents sql.NullInt64 for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { @@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s if lastseents.Valid { dev.LastSeenTS = lastseents.Int64 } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) return err } diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index d124bfbc5..049fcbd06 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -43,6 +43,7 @@ const notificationSchema = ` CREATE TABLE IF NOT EXISTS userapi_notifications ( id INTEGER PRIMARY KEY AUTOINCREMENT, localpart TEXT NOT NULL, + server_name TEXT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, stream_pos BIGINT NOT NULL, @@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications ( read BOOLEAN NOT NULL DEFAULT FALSE ); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id); ` const insertNotificationSQL = "" + - "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const deleteNotificationsUpToSQL = "" + - "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4" const updateNotificationReadSQL = "" + - "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1" const selectNotificationSQL = "" + - "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + - "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + - ") AND NOT read ORDER BY localpart, id LIMIT $4" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" + + "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $5" const selectNotificationCountSQL = "" + - "SELECT COUNT(DISTINCT(room_id)) FROM userapi_notifications WHERE localpart = $1 AND (" + - "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + "SELECT COUNT(DISTINCT(room_id)) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + ") AND NOT read" const selectRoomNotificationCountsSQL = "" + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + - "WHERE localpart = $1 AND room_id = $2 AND NOT read" + "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read" const cleanNotificationsSQL = "" + "DELETE FROM userapi_notifications WHERE" + @@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local if err != nil { return err } - _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs)) return err } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { return nil, 0, err @@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { - err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { - err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index 875f1a9a5..f06429741 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -3,6 +3,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When the token expires, as a unix timestamp (ms resolution). token_expires_at_ms BIGINT NOT NULL ); ` const insertOpenIDTokenSQL = "" + - "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" + "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { db *sql.DB @@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) ( func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, + token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS) return } @@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( token string, ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes + var localpart string + var serverName gomatrixserverlib.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( - &openIDTokenAttrs.UserID, + &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, ) + openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve token from the db") diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index b6130a1e3..867026d7a 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -23,36 +23,40 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account avatar_url TEXT ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name); ` const insertProfileSQL = "" + - "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2" const setAvatarURLSQL = "" + - "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" + " RETURNING display_name" const setDisplayNameSQL = "" + - "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" + " RETURNING avatar_url" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { db *sql.DB @@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return err } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( - &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( + &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { return nil, err @@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ - Localpart: localpart, - AvatarURL: avatarURL, + Localpart: localpart, + ServerName: string(serverName), + AvatarURL: avatarURL, } - old, err := s.SelectProfileByLocalpart(ctx, localpart) + old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) if err != nil { return old, false, err } @@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL( return old, false, nil } stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) + err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName) return profile, true, err } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, + ServerName: string(serverName), DisplayName: displayName, } - old, err := s.SelectProfileByLocalpart(ctx, localpart) + old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) if err != nil { return old, false, err } @@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName( return old, false, nil } stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) + err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL) return profile, true, err } @@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index 4de0a9f06..c9d451dc5 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( id INTEGER PRIMARY KEY AUTOINCREMENT, -- The Matrix user ID localpart for this pusher localpart TEXT NOT NULL, + server_name TEXT NOT NULL, session_id BIGINT DEFAULT NULL, profile_tag TEXT, kind TEXT NOT NULL, @@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); -- For faster retrieving by localpart. -CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name); -- Pushkey must be unique for a given user and app. -CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name); ` const insertPusherSQL = "" + - "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + - "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + - "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + + "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12" const selectPushersSQL = "" + - "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" const deletePusherSQL = "" + - "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4" const deletePushersByAppIdAndPushKeySQL = "" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" @@ -95,18 +97,19 @@ type pushersStatements struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) - logrus.Debugf("Created pusher %d", session_id) + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err } func (s *pushersStatements) SelectPushers( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} - rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) + rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName) if err != nil { return pushers, err @@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( - ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, + ctx context.Context, txn *sql.Tx, appid, pushkey, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err } diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index dd33dc0cf..85a1f7063 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -15,6 +15,8 @@ package sqlite3 import ( + "context" + "database/sql" "fmt" "time" @@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpRenameTables, Down: deltas.DownRenameTables, }) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + }) if err = m.Up(base.Context()); err != nil { return nil, err } - accountDataTable, err := NewSQLiteAccountDataTable(db) - if err != nil { - return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) - } accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) } + accountDataTable, err := NewSQLiteAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) + } devicesTable, err := NewSQLiteDevicesTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) @@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) } + + m = sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names populate", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNamesPopulate(ctx, txn, serverName) + }, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index 73af139db..2db7d5887 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids ( medium TEXT NOT NULL DEFAULT 'email', -- The localpart of the Matrix user ID associated to this 3PID localpart TEXT NOT NULL, + server_name TEXT NOT NULL, PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart); +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" + "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2" const insertThreePIDSQL = "" + - "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)" const deleteThreePIDSQL = "" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" @@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) - err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { - return "", nil + return "", "", nil } return } func (s *threepidStatements) SelectThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { return } @@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( } func (s *threepidStatements) InsertThreePID( - ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, + ctx context.Context, txn *sql.Tx, threepid, medium, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) return err } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 53c51416b..29a806e4a 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() alice := test.NewUser(t) - localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) events := room.Events() contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID())) - err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom) + err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom) assert.NoError(t, err, "unable to save account data") contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID)) - err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal) + err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal) assert.NoError(t, err, "unable to save account data") - accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read") + accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read") assert.NoError(t, err, "unable to get account data by type") assert.Equal(t, contentRoom, accountData) - globalData, roomData, err := db.GetAccountData(ctx, localpart) + globalData, roomData, err := db.GetAccountData(ctx, localpart, domain) assert.NoError(t, err) assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"]) assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"]) @@ -81,83 +81,83 @@ func Test_Accounts(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) - accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) + accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") // verify the newly create account is the same as returned by CreateAccount var accGet *api.Account - accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing") + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing") assert.NoError(t, err, "failed to get account by password") assert.Equal(t, accAlice, accGet) - accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart) + accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to get account by localpart") assert.Equal(t, accAlice, accGet) // check account availability - available, err := db.CheckAccountAvailability(ctx, aliceLocalpart) + available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, false, available) - available, err = db.CheckAccountAvailability(ctx, "unusedname") + available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, true, available) // get guest account numeric aliceLocalpart - first, err := db.GetNewNumericLocalpart(ctx) + first, err := db.GetNewNumericLocalpart(ctx, aliceDomain) assert.NoError(t, err, "failed to get new numeric localpart") // Create a new account to verify the numeric localpart is updated - _, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest) + _, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest) assert.NoError(t, err, "failed to create account") - second, err := db.GetNewNumericLocalpart(ctx) + second, err := db.GetNewNumericLocalpart(ctx, aliceDomain) assert.NoError(t, err) assert.Greater(t, second, first) // update password for alice - err = db.SetPassword(ctx, aliceLocalpart, "newPassword") + err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.NoError(t, err, "failed to update password") - accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.NoError(t, err, "failed to get account by new password") assert.Equal(t, accAlice, accGet) // deactivate account - err = db.DeactivateAccount(ctx, aliceLocalpart) + err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to deactivate account") // This should fail now, as the account is deactivated - _, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + _, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.Error(t, err, "expected an error, got none") // This should return an empty slice, as the account is deactivated and the 3pid is unbound - threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to get 3pid for account") assert.Equal(t, len(threepids), 0) - _, err = db.GetAccountByLocalpart(ctx, "unusename") + _, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain) assert.Error(t, err, "expected an error for non existent localpart") // create an empty localpart; this should never happen, but is required to test getting a numeric localpart // if there's already a user without a localpart in the database - _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser) + _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser) assert.NoError(t, err) // test getting a numeric localpart, with an existing user without a localpart - _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) + _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest) assert.NoError(t, err) // Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type - _, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser) + _, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser) assert.NoError(t, err) // Now try to create a new guest user - _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) + _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest) assert.NoError(t, err) }) } func Test_Devices(t *testing.T) { alice := test.NewUser(t) - localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) deviceID := util.RandomString(8) accessToken := util.RandomString(16) @@ -166,10 +166,10 @@ func Test_Devices(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() - deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "") + deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") - gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID) + gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields @@ -179,14 +179,14 @@ func Test_Devices(t *testing.T) { // create a device without existing device ID accessToken = util.RandomString(16) - deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "") + deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") - gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID) + gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields // Get devices - devices, err := db.GetDevicesByLocalpart(ctx, localpart) + devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get devices by localpart") assert.Equal(t, 2, len(devices)) deviceIDs := make([]string, 0, len(devices)) @@ -200,15 +200,15 @@ func Test_Devices(t *testing.T) { // Update device newName := "new display name" - err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName) + err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName) assert.NoError(t, err, "unable to update device displayname") updatedAfterTimestamp := time.Now().Unix() - err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web") + err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web") assert.NoError(t, err, "unable to update device last seen") deviceWithID.DisplayName = newName deviceWithID.LastSeenIP = "127.0.0.1" - gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID) + gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 2, len(devices)) assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName) @@ -218,20 +218,20 @@ func Test_Devices(t *testing.T) { // create one more device and remove the devices step by step newDeviceID := util.RandomString(16) accessToken = util.RandomString(16) - _, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "") + _, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create new device") - devices, err = db.GetDevicesByLocalpart(ctx, localpart) + devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 3, len(devices)) - err = db.RemoveDevices(ctx, localpart, deviceIDs) + err = db.RemoveDevices(ctx, localpart, domain, deviceIDs) assert.NoError(t, err, "unable to remove devices") - devices, err = db.GetDevicesByLocalpart(ctx, localpart) + devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 1, len(devices)) - deleted, err := db.RemoveAllDevices(ctx, localpart, "") + deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "") assert.NoError(t, err, "unable to remove all devices") assert.Equal(t, 1, len(deleted)) assert.Equal(t, newDeviceID, deleted[0].ID) @@ -369,7 +369,7 @@ func Test_OpenID(t *testing.T) { func Test_Profile(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -377,30 +377,33 @@ func Test_Profile(t *testing.T) { defer close() // create account, which also creates a profile - _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) + _, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") - gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) + gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get profile by localpart") - wantProfile := &authtypes.Profile{Localpart: aliceLocalpart} + wantProfile := &authtypes.Profile{ + Localpart: aliceLocalpart, + ServerName: string(aliceDomain), + } assert.Equal(t, wantProfile, gotProfile) // set avatar & displayname wantProfile.DisplayName = "Alice" - gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice") + gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice") assert.Equal(t, wantProfile, gotProfile) assert.NoError(t, err, "unable to set displayname") assert.True(t, changed) wantProfile.AvatarURL = "mxc://aliceAvatar" - gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) assert.True(t, changed) // Setting the same avatar again doesn't change anything wantProfile.AvatarURL = "mxc://aliceAvatar" - gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) assert.False(t, changed) @@ -415,7 +418,7 @@ func Test_Profile(t *testing.T) { func Test_Pusher(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -437,11 +440,11 @@ func Test_Pusher(t *testing.T) { ProfileTag: util.RandomString(8), Language: util.RandomString(2), } - err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart) + err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to upsert pusher") // check it was actually persisted - gotPushers, err = db.GetPushers(ctx, aliceLocalpart) + gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, i+1, len(gotPushers)) assert.Equal(t, wantPusher, gotPushers[i]) @@ -449,16 +452,16 @@ func Test_Pusher(t *testing.T) { } // remove single pusher - err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart) + err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to remove pusher") - gotPushers, err := db.GetPushers(ctx, aliceLocalpart) + gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, 1, len(gotPushers)) // remove last pusher err = db.RemovePushers(ctx, appID, pushKeys[1]) assert.NoError(t, err, "unable to remove pusher") - gotPushers, err = db.GetPushers(ctx, aliceLocalpart) + gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, 0, len(gotPushers)) }) @@ -466,7 +469,7 @@ func Test_Pusher(t *testing.T) { func Test_ThreePID(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -474,15 +477,16 @@ func Test_ThreePID(t *testing.T) { defer close() threePID := util.RandomString(8) medium := util.RandomString(8) - err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium) + err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium) assert.NoError(t, err, "unable to save threepid association") // get the stored threepid - gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium) + gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium) assert.NoError(t, err, "unable to get localpart for threepid") assert.Equal(t, aliceLocalpart, gotLocalpart) + assert.Equal(t, aliceDomain, gotDomain) - threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 1, len(threepids)) assert.Equal(t, authtypes.ThreePID{ @@ -495,7 +499,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err, "unexpected error") // verify it was deleted - threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 0, len(threepids)) }) @@ -503,7 +507,7 @@ func Test_ThreePID(t *testing.T) { func Test_Notification(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) @@ -531,34 +535,34 @@ func Test_Notification(t *testing.T) { RoomID: roomID, TS: gomatrixserverlib.AsTimestamp(ts), } - err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification) + err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") } // get notifications - count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications) + count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications) assert.NoError(t, err, "unable to get notification count") assert.Equal(t, int64(2), count) - notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications) + notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications) assert.NoError(t, err, "unable to get notifications") assert.Equal(t, int64(10), count) assert.Equal(t, 10, len(notifs)) // ... for a specific room - total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(4), total) // mark notification as read - affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true) + affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) // this should delete 2 notifications - affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8) + affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) - total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(2), total) @@ -567,7 +571,7 @@ func Test_Notification(t *testing.T) { assert.NoError(t, err) // this should now return 0 notifications - total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(0), total) }) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 5e1dd0971..e14776cf3 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -28,31 +28,31 @@ import ( ) type AccountDataTable interface { - InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error - SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) - SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) + InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error + SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) + SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error) } type AccountsTable interface { - InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) - UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) - DeactivateAccount(ctx context.Context, localpart string) (err error) - SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) - SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) - SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) + InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error) + DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) + SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error) + SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) + SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error) } type DevicesTable interface { - InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) - DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error - DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error - DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error - UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error + InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error + DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error + DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error + UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) - SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error) + SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) + SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error } type KeyBackupTable interface { @@ -79,40 +79,40 @@ type LoginTokenTable interface { } type OpenIDTable interface { - InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error) + InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error) SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) } type ProfileTable interface { - InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error - SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error) + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error + SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } type ThreePIDTable interface { - SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error) - SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) - InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) + SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) + SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) + InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) } type PusherTable interface { - InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error - SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) - DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error + SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error) + DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error } type NotificationTable interface { Clean(ctx context.Context, txn *sql.Tx) error - Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error - DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) - UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) - Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) - SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) - SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) + Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) + Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) } type StatsTable interface { diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index a547423bc..b088d15cd 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -79,6 +79,7 @@ func mustMakeAccountAndDevice( accDB tables.AccountsTable, devDB tables.DevicesTable, localpart string, + serverName gomatrixserverlib.ServerName, // nolint:unparam accType api.AccountType, userAgent string, ) { @@ -89,11 +90,11 @@ func mustMakeAccountAndDevice( appServiceID = util.RandomString(16) } - _, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType) + _, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType) if err != nil { t.Fatalf("unable to create account: %v", err) } - _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent) + _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent) if err != nil { t.Fatalf("unable to create device: %v", err) } @@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) { }) t.Run("Want Users", func(t *testing.T) { - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko") gotStats, _, err := statsDB.UserStatistics(ctx, nil) if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 2a43c0bd4..25fa75ee2 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -61,7 +61,9 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap cfg := &config.UserAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } @@ -80,14 +82,14 @@ func TestQueryProfile(t *testing.T) { // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) defer close() - _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser) + _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } - if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { + if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { t.Fatalf("failed to set avatar url: %s", err) } - if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { + if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { t.Fatalf("failed to set display name: %s", err) } @@ -164,7 +166,7 @@ func TestPasswordlessLoginFails(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) defer close() - _, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService) + _, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService) if err != nil { t.Fatalf("failed to make account: %s", err) } @@ -190,7 +192,7 @@ func TestLoginToken(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) defer close() - _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) + _, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } diff --git a/userapi/util/devices.go b/userapi/util/devices.go index cbf3bd28f..c55fc7999 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -2,10 +2,12 @@ package util import ( "context" + "fmt" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -17,10 +19,10 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { - pushers, err := db.GetPushers(ctx, localpart) +func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { + pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { - return nil, err + return nil, fmt.Errorf("db.GetPushers: %w", err) } devices := make([]*PusherDevice, 0, len(pushers)) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index ff206bd3c..fc0ab39bf 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -8,6 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -16,8 +17,8 @@ import ( // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error { - pusherDevices, err := GetPushDevices(ctx, localpart, nil, db) +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error { + pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err } @@ -26,7 +27,7 @@ func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, loc return nil } - userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications) + userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, serverName, tables.AllNotifications) if err != nil { return err }