From b2712cd2b1ca8537d97d7cd6a416120826638e5c Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 4 Nov 2022 20:58:24 +0100 Subject: [PATCH 001/124] Fix GHA release script --- .github/workflows/docker.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index f0500ccc5..846844173 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -32,7 +32,7 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV + echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) [ ${BRANCH} == "main" ] && BRANCH="" echo "BRANCH=${BRANCH}" >> $GITHUB_ENV @@ -112,7 +112,7 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV + echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) [ ${BRANCH} == "main" ] && BRANCH="" echo "BRANCH=${BRANCH}" >> $GITHUB_ENV @@ -191,7 +191,7 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV + echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) [ ${BRANCH} == "main" ] && BRANCH="" echo "BRANCH=${BRANCH}" >> $GITHUB_ENV @@ -258,7 +258,7 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV + echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) [ ${BRANCH} == "main" ] && BRANCH="" echo "BRANCH=${BRANCH}" >> $GITHUB_ENV From a7b74176e3dbcfd36525388665c64674c707e5c3 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Nov 2022 21:49:18 +0000 Subject: [PATCH 002/124] Revert Docker user change --- Dockerfile | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 499992343..a9bbce925 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,7 +24,7 @@ RUN --mount=target=. \ go build -v -ldflags="${FLAGS}" -trimpath -o /out/ ./cmd/... # -# The dendrite base image; mainly creates a user and switches to it +# The dendrite base image # FROM alpine:latest AS dendrite-base LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" @@ -32,8 +32,6 @@ LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" LABEL org.opencontainers.image.licenses="Apache-2.0" LABEL org.opencontainers.image.documentation="https://matrix-org.github.io/dendrite/" LABEL org.opencontainers.image.vendor="The Matrix.org Foundation C.I.C." -RUN addgroup dendrite && adduser dendrite -G dendrite -u 1337 -D -USER dendrite # # Builds the polylith image and only contains the polylith binary From c125203eb6e5dfeaa4290aeea8bbed9a73caf16c Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 7 Nov 2022 09:47:18 +0100 Subject: [PATCH 003/124] Handle `m.room.tombstone` events in the UserAPI (#2864) Fixes #2863 and makes ``` /upgrade preserves direct room state local user has tags copied to the new room remote user has tags copied to the new room ``` pass. --- sytest-whitelist | 5 +- userapi/consumers/roomserver.go | 117 +++++++++++++++++++++++++++++++- 2 files changed, 120 insertions(+), 2 deletions(-) diff --git a/sytest-whitelist b/sytest-whitelist index f4311d339..bb4f0a279 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -758,4 +758,7 @@ Can get rooms/{roomId}/members at a given point Can filter rooms/{roomId}/members Current state appears in timeline in private history with many messages after AS can publish rooms in their own list -AS and main public room lists are separate \ No newline at end of file +AS and main public room lists are separate +/upgrade preserves direct room state +local user has tags copied to the new room +remote user has tags copied to the new room \ No newline at end of file diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 97c17e188..b6b30a095 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -2,12 +2,16 @@ package consumers import ( "context" + "database/sql" "encoding/json" + "errors" "fmt" "strings" "sync" "time" + "github.com/tidwall/gjson" + "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -185,13 +189,115 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy } } +func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error { + for _, membership := range localMembers { + // Copy any existing push rules from old -> new room + if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { + return err + } + + // preserve m.direct room state + if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil { + return err + } + + // copy existing m.tag entries, if any + if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { + return err + } + } + return nil +} + +func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error { + pushRules, err := s.db.QueryPushRules(ctx, localpart) + if err != nil { + return fmt.Errorf("failed to query pushrules for user: %w", err) + } + if pushRules == nil { + return nil + } + + for _, roomRule := range pushRules.Global.Room { + if roomRule.RuleID != oldRoomID { + continue + } + cpRool := *roomRule + cpRool.RuleID = newRoomID + pushRules.Global.Room = append(pushRules.Global.Room, &cpRool) + rules, err := json.Marshal(pushRules) + if err != nil { + return err + } + if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil { + return fmt.Errorf("failed to update pushrules: %w", err) + } + } + return nil +} + +// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID +func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error { + // this is most likely not a DM, so skip updating m.direct state + if roomSize > 2 { + return nil + } + // Get direct message state + directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct") + if err != nil { + return fmt.Errorf("failed to get m.direct from database: %w", err) + } + directChats := gjson.ParseBytes(directChatsRaw) + newDirectChats := make(map[string][]string) + // iterate over all userID -> roomIDs + directChats.ForEach(func(userID, roomIDs gjson.Result) bool { + var found bool + for _, roomID := range roomIDs.Array() { + newDirectChats[userID.Str] = append(newDirectChats[userID.Str], roomID.Str) + // add the new roomID to m.direct + if roomID.Str == oldRoomID { + found = true + newDirectChats[userID.Str] = append(newDirectChats[userID.Str], newRoomID) + } + } + // Only hit the database if we found the old room as a DM for this user + if found { + var data []byte + data, err = json.Marshal(newDirectChats) + if err != nil { + return true + } + if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil { + return true + } + } + return true + }) + if err != nil { + return fmt.Errorf("failed to update m.direct state") + } + return nil +} + +func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error { + tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag") + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + if tag == nil { + return nil + } + return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag) +} + func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) if err != nil { return fmt.Errorf("s.localRoomMembers: %w", err) } - if event.Type() == gomatrixserverlib.MRoomMember { + switch { + case event.Type() == gomatrixserverlib.MRoomMember: cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll) var member *localMembership member, err = newLocalMembership(&cevent) @@ -203,6 +309,15 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gom // should also be pushed to the target user. members = append(members, member) } + case event.Type() == "m.room.tombstone" && event.StateKeyEquals(""): + // Handle room upgrades + oldRoomID := event.RoomID() + newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str + if err = s.handleRoomUpgrade(ctx, oldRoomID, newRoomID, members, roomSize); err != nil { + // while inconvenient, this shouldn't stop us from sending push notifications + log.WithError(err).Errorf("UserAPI: failed to handle room upgrade for users") + } + } // TODO: run in parallel with localRoomMembers. From 205a15621a67ca7b851dc8ece8cfbcce94fb8b4d Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 7 Nov 2022 15:07:47 +0100 Subject: [PATCH 004/124] Add custom build flag to satisfy Sytest --- keyserver/internal/device_list_update.go | 5 ++-- .../internal/device_list_update_default.go | 22 ++++++++++++++++ .../internal/device_list_update_sytest.go | 25 +++++++++++++++++++ .../storage/postgres/stale_device_lists.go | 6 ++--- .../storage/sqlite3/stale_device_lists.go | 6 ++--- 5 files changed, 55 insertions(+), 9 deletions(-) create mode 100644 keyserver/internal/device_list_update_default.go create mode 100644 keyserver/internal/device_list_update_sytest.go diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 8b02f3d6c..3f7c0d8b4 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -47,7 +47,6 @@ var ( ) ) -const defaultWaitTime = time.Minute const requestTimeout = time.Second * 30 func init() { @@ -454,7 +453,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go } else if e.Code >= 300 { // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. - return time.Hour, err + return hourWaitTime, err } case net.Error: // Use the default waitTime, if it's a timeout. @@ -468,7 +467,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go // This is to avoid spamming remote servers, which may not be Matrix servers anymore. if e.Code >= 300 { logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError") - return time.Hour, err + return hourWaitTime, err } default: // Something else failed diff --git a/keyserver/internal/device_list_update_default.go b/keyserver/internal/device_list_update_default.go new file mode 100644 index 000000000..7d357c951 --- /dev/null +++ b/keyserver/internal/device_list_update_default.go @@ -0,0 +1,22 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !vw + +package internal + +import "time" + +const defaultWaitTime = time.Minute +const hourWaitTime = time.Hour diff --git a/keyserver/internal/device_list_update_sytest.go b/keyserver/internal/device_list_update_sytest.go new file mode 100644 index 000000000..1c60d2eb9 --- /dev/null +++ b/keyserver/internal/device_list_update_sytest.go @@ -0,0 +1,25 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vw + +package internal + +import "time" + +// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite +// results in a one-hour wait time from a previous device so the test times out. This is fine for +// production, but makes an otherwise passing test fail. +const defaultWaitTime = time.Second +const hourWaitTime = time.Second diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go index 63281adfb..d0fe50d00 100644 --- a/keyserver/storage/postgres/stale_device_lists.go +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -43,10 +43,10 @@ const upsertStaleDeviceListSQL = "" + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" type staleDeviceListsStatements struct { upsertStaleDeviceListStmt *sql.Stmt @@ -77,7 +77,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) return err } diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go index fc2cc37c4..1e08b266c 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -43,10 +43,10 @@ const upsertStaleDeviceListSQL = "" + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" type staleDeviceListsStatements struct { db *sql.DB @@ -80,7 +80,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) return err } From a5cabdbac5678171167062fca06e60a93a21ddbf Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 9 Nov 2022 09:24:29 +0000 Subject: [PATCH 005/124] Remove unspecced fields from `Transaction` (update to matrix-org/gomatrixserverlib@715dc88) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 7dd9e0b2c..af315b47f 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e + github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 5d172e86b..ed2d97eed 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e h1:6I34fdyiHMRCxL6GOb/G8ZyI1WWlb6ZxCF2hIGSMSCc= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 h1:Bet5n+//Yh+A2SuPHD67N8jrOhC/EIKvEgisfsKhTss= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= From bdaae060ccead4594cfa88f45e9656c3a094535c Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 9 Nov 2022 14:07:29 +0000 Subject: [PATCH 006/124] Update Ristretto --- go.mod | 2 +- go.sum | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index af315b47f..47c43fbc7 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/Masterminds/semver/v3 v3.1.1 github.com/blevesearch/bleve/v2 v2.3.4 github.com/codeclysm/extract v2.2.0+incompatible - github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d + github.com/dgraph-io/ristretto v0.1.1 github.com/docker/docker v20.10.19+incompatible github.com/docker/go-connections v0.4.0 github.com/getsentry/sentry-go v0.14.0 diff --git a/go.sum b/go.sum index ed2d97eed..a805a5bf9 100644 --- a/go.sum +++ b/go.sum @@ -141,8 +141,8 @@ github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d h1:Wrc3UKTS+cffkOx0xRGFC+ZesNuTfn0ThvEC72N0krk= -github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d/go.mod h1:RAy2GVV4sTWVlNMavv3xhLsk18rxhfhDnombTe6EF5c= +github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= +github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68= @@ -526,7 +526,6 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -697,6 +696,7 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= From 503d9c7586d7c5ca91cdb78631553256628861cd Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 10 Nov 2022 10:07:19 +0000 Subject: [PATCH 007/124] Improve logging in upgrade tests --- cmd/dendrite-upgrade-tests/main.go | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index dce22472d..6630a0620 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -381,18 +381,22 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string) lastErr = nil break } - if lastErr != nil { - logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{ - ShowStdout: true, - ShowStderr: true, - }) - // ignore errors when cannot get logs, it's just for debugging anyways - if err == nil { - logbody, err := io.ReadAll(logs) - if err == nil { - log.Printf("Container logs:\n\n%s\n\n", string(logbody)) + logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{ + ShowStdout: true, + ShowStderr: true, + Follow: true, + }) + // ignore errors when cannot get logs, it's just for debugging anyways + if err == nil { + go func() { + for { + if body, err := io.ReadAll(logs); err == nil && len(body) > 0 { + log.Printf("%s: %s", version, string(body)) + } else { + return + } } - } + }() } return baseURL, containerID, lastErr } From efa50253f68f8ec1c6a9f4a8d82b329512027b00 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 10 Nov 2022 10:16:56 +0000 Subject: [PATCH 008/124] Fix lint error --- cmd/dendrite-upgrade-tests/main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 6630a0620..131ce4af9 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -367,7 +367,8 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string) // hit /versions to check it is up var lastErr error for i := 0; i < 500; i++ { - res, err := http.Get(versionsURL) + var res *http.Response + res, err = http.Get(versionsURL) if err != nil { lastErr = fmt.Errorf("GET %s => error: %s", versionsURL, err) time.Sleep(50 * time.Millisecond) From 0193549201299f5dcce919b2aeb3b1c40bdfcefa Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 11 Nov 2022 10:35:17 +0100 Subject: [PATCH 009/124] Send presence to newly added servers (#2869) This should make `New federated private chats get full presence information (SYN-115)` happy. --- federationapi/consumers/roomserver.go | 120 ++++++++++++++++++++++---- federationapi/federationapi.go | 4 +- roomserver/api/api.go | 1 + 3 files changed, 107 insertions(+), 18 deletions(-) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index a42733628..d16af6626 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -18,6 +18,10 @@ import ( "context" "encoding/json" "fmt" + "strconv" + "time" + + syncAPITypes "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -34,14 +38,16 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { - ctx context.Context - cfg *config.FederationAPI - rsAPI api.FederationRoomserverAPI - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - topic string + ctx context.Context + cfg *config.FederationAPI + rsAPI api.FederationRoomserverAPI + jetstream nats.JetStreamContext + natsClient *nats.Conn + durable string + db storage.Database + queues *queue.OutgoingQueues + topic string + topicPresence string } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -49,19 +55,22 @@ func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.FederationAPI, js nats.JetStreamContext, + natsClient *nats.Conn, queues *queue.OutgoingQueues, store storage.Database, rsAPI api.FederationRoomserverAPI, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ - ctx: process.Context(), - cfg: cfg, - jetstream: js, - db: store, - queues: queues, - rsAPI: rsAPI, - durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), + ctx: process.Context(), + cfg: cfg, + jetstream: js, + natsClient: natsClient, + db: store, + queues: queues, + rsAPI: rsAPI, + durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), + topicPresence: cfg.Matrix.JetStream.Prefixed(jetstream.RequestPresence), } } @@ -146,6 +155,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee // processMessage updates the list of currently joined hosts in the room // and then sends the event to the hosts that were joined before the event. func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error { + addsStateEvents, missingEventIDs := ore.NeededStateEventIDs() // Ask the roomserver and add in the rest of the results into the set. @@ -184,6 +194,14 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew return err } + // If we added new hosts, inform them about our known presence events for this room + if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { + membership, _ := ore.Event.Membership() + if membership == gomatrixserverlib.Join { + s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) + } + } + if oldJoinedHosts == nil { // This means that there is nothing to update as this is a duplicate // message. @@ -213,6 +231,76 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew ) } +func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) { + joined := make([]gomatrixserverlib.ServerName, len(addedJoined)) + for _, added := range addedJoined { + joined = append(joined, added.ServerName) + } + + // get our locally joined users + var queryRes api.QueryMembershipsForRoomResponse + err := s.rsAPI.QueryMembershipsForRoom(s.ctx, &api.QueryMembershipsForRoomRequest{ + JoinedOnly: true, + LocalOnly: true, + RoomID: roomID, + }, &queryRes) + if err != nil { + log.WithError(err).Error("failed to calculate joined rooms for user") + return + } + + // send every presence we know about to the remote server + content := types.Presence{} + for _, ev := range queryRes.JoinEvents { + msg := nats.NewMsg(s.topicPresence) + msg.Header.Set(jetstream.UserID, ev.Sender) + + var presence *nats.Msg + presence, err = s.natsClient.RequestMsg(msg, time.Second*10) + if err != nil { + log.WithError(err).Errorf("unable to get presence") + continue + } + + statusMsg := presence.Header.Get("status_msg") + e := presence.Header.Get("error") + if e != "" { + continue + } + var lastActive int + lastActive, err = strconv.Atoi(presence.Header.Get("last_active_ts")) + if err != nil { + continue + } + + p := syncAPITypes.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(lastActive)} + + content.Push = append(content.Push, types.PresenceContent{ + CurrentlyActive: p.CurrentlyActive(), + LastActiveAgo: p.LastActiveAgo(), + Presence: presence.Header.Get("presence"), + StatusMsg: &statusMsg, + UserID: ev.Sender, + }) + } + + if len(content.Push) == 0 { + return + } + + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MPresence, + Origin: string(s.cfg.Matrix.ServerName), + } + if edu.Content, err = json.Marshal(content); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return + } + if err := s.queues.SendEDU(edu, s.cfg.Matrix.ServerName, joined); err != nil { + log.WithError(err).Error("failed to send EDU") + } +} + // joinedHostsAtEvent works out a list of matrix servers that were joined to // the room at the event (including peeking ones) // It is important to use the state at the event for sending messages because: diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 202da6c51..e35b9c7f0 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -118,7 +118,7 @@ func NewInternalAPI( stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) - js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, @@ -132,7 +132,7 @@ func NewInternalAPI( ) rsConsumer := consumers.NewOutputRoomEventConsumer( - base.ProcessContext, cfg, js, queues, + base.ProcessContext, cfg, js, nats, queues, federationDB, rsAPI, ) if err = rsConsumer.Start(); err != nil { diff --git a/roomserver/api/api.go b/roomserver/api/api.go index a1373a62b..01e87ec8a 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -177,6 +177,7 @@ type FederationRoomserverAPI interface { QueryBulkStateContentAPI // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error + QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error From d35a5642e89a2a1b64f1c2ed1cb13e6080987b1c Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 11 Nov 2022 10:52:08 +0100 Subject: [PATCH 010/124] Deny guest access on several endpoints (#2873) Second part for guest access, this adds a `WithAllowGuests()` option to `MakeAuthAPI`, allowing guests to access the specified endpoints. Endpoints taken from the [spec](https://spec.matrix.org/v1.4/client-server-api/#client-behaviour-14) and by checking Synapse endpoints for `allow_guest=true`. --- clientapi/routing/routing.go | 68 +++++++++++++++++------------------ internal/httputil/httpapi.go | 29 +++++++++++++++ setup/mscs/msc2946/msc2946.go | 2 +- syncapi/routing/routing.go | 18 +++++----- 4 files changed, 73 insertions(+), 44 deletions(-) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index f35aa7e12..1b3ef120a 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -252,7 +252,7 @@ func Setup( return JoinRoomByIDOrAlias( req, device, rsAPI, userAPI, vars["roomIDOrAlias"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) if mscCfg.Enabled("msc2753") { @@ -274,7 +274,7 @@ func Setup( v3mux.Handle("/joined_rooms", httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetJoinedRooms(req, device, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/join", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -288,7 +288,7 @@ func Setup( return JoinRoomByIDOrAlias( req, device, rsAPI, userAPI, vars["roomID"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/leave", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -302,7 +302,7 @@ func Setup( return LeaveRoomByID( req, device, rsAPI, vars["roomID"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/unpeek", httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -361,7 +361,7 @@ func Setup( return util.ErrorResponse(err) } return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -372,7 +372,7 @@ func Setup( txnID := vars["txnID"] return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID, nil, cfg, rsAPI, transactionsCache) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -381,7 +381,7 @@ func Setup( return util.ErrorResponse(err) } return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -400,7 +400,7 @@ func Setup( eventType := strings.TrimSuffix(vars["type"], "/") eventFormat := req.URL.Query().Get("format") == "event" return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -409,7 +409,7 @@ func Setup( } eventFormat := req.URL.Query().Get("format") == "event" return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -420,7 +420,7 @@ func Setup( emptyString := "" eventType := strings.TrimSuffix(vars["eventType"], "/") return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", @@ -431,7 +431,7 @@ func Setup( } stateKey := vars["stateKey"] return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { @@ -575,7 +575,7 @@ func Setup( } txnID := vars["txnID"] return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) // This is only here because sytest refers to /unstable for this endpoint @@ -589,7 +589,7 @@ func Setup( } txnID := vars["txnID"] return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/account/whoami", @@ -598,7 +598,7 @@ func Setup( return *r } return Whoami(req, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/password", @@ -830,7 +830,7 @@ func Setup( return util.ErrorResponse(err) } return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method @@ -871,7 +871,7 @@ func Setup( v3mux.Handle("/thirdparty/protocols", httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Protocols(req, asAPI, device, "") - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/protocol/{protocolID}", @@ -881,7 +881,7 @@ func Setup( return util.ErrorResponse(err) } return Protocols(req, asAPI, device, vars["protocolID"]) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/user/{protocolID}", @@ -891,13 +891,13 @@ func Setup( return util.ErrorResponse(err) } return User(req, asAPI, device, vars["protocolID"], req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/user", httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return User(req, asAPI, device, "", req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/location/{protocolID}", @@ -907,13 +907,13 @@ func Setup( return util.ErrorResponse(err) } return Location(req, asAPI, device, vars["protocolID"], req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/location", httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Location(req, asAPI, device, "", req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/initialSync", @@ -1054,7 +1054,7 @@ func Setup( v3mux.Handle("/devices", httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetDevicesByLocalpart(req, userAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", @@ -1064,7 +1064,7 @@ func Setup( return util.ErrorResponse(err) } return GetDeviceByID(req, userAPI, device, vars["deviceID"]) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", @@ -1074,7 +1074,7 @@ func Setup( return util.ErrorResponse(err) } return UpdateDeviceByID(req, userAPI, device, vars["deviceID"]) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", @@ -1116,21 +1116,21 @@ func Setup( // Stub implementations for sytest v3mux.Handle("/events", - httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { + httputil.MakeAuthAPI("events", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "chunk": []interface{}{}, "start": "", "end": "", }} - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/initialSync", - httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { + httputil.MakeAuthAPI("initial_sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "end": "", }} - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/rooms/{roomId}/tags", @@ -1169,7 +1169,7 @@ func Setup( return *r } return GetCapabilities(req, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) // Key Backup Versions (Metadata) @@ -1350,7 +1350,7 @@ func Setup( postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadCrossSigningDeviceSignatures(req, keyAPI, device) - }) + }, httputil.WithAllowGuests()) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) @@ -1362,22 +1362,22 @@ func Setup( v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req, keyAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/claim", httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return ClaimKeys(req, keyAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 36dcaf453..4f33a3f79 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -42,10 +42,26 @@ type BasicAuth struct { Password string `yaml:"password"` } +type AuthAPIOpts struct { + GuestAccessAllowed bool +} + +// AuthAPIOption is an option to MakeAuthAPI to add additional checks (e.g. guest access) to verify +// the user is allowed to do specific things. +type AuthAPIOption func(opts *AuthAPIOpts) + +// WithAllowGuests checks that guest users have access to this endpoint +func WithAllowGuests() AuthAPIOption { + return func(opts *AuthAPIOpts) { + opts.GuestAccessAllowed = true + } +} + // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. func MakeAuthAPI( metricsName string, userAPI userapi.QueryAcccessTokenAPI, f func(*http.Request, *userapi.Device) util.JSONResponse, + checks ...AuthAPIOption, ) http.Handler { h := func(req *http.Request) util.JSONResponse { logger := util.GetLogger(req.Context()) @@ -76,6 +92,19 @@ func MakeAuthAPI( } }() + // apply additional checks, if any + opts := AuthAPIOpts{} + for _, opt := range checks { + opt(&opts) + } + + if !opts.GuestAccessAllowed && device.AccountType == userapi.AccountTypeGuest { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.GuestAccessForbidden("Guest access not allowed"), + } + } + jsonRes := f(req, device) // do not log 4xx as errors as they are client fails, not server fails if hub != nil && jsonRes.Code >= 500 { diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index bc9df0f96..d7c022544 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -57,7 +57,7 @@ func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache, ) error { - clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName)) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName), httputil.WithAllowGuests()) base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index bc3ad2384..4cc1a6a85 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -51,7 +51,7 @@ func Setup( // TODO: Add AS support for all handlers below. v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingSyncRequest(req, device) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -59,7 +59,7 @@ func Setup( return util.ErrorResponse(err) } return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, rsAPI, cfg, srp, lazyLoadCache) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/event/{eventID}", httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -68,7 +68,7 @@ func Setup( return util.ErrorResponse(err) } return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, syncDB, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/filter", @@ -93,7 +93,7 @@ func Setup( v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingKeyChangeRequest(req, device) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/context/{eventId}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -108,7 +108,7 @@ func Setup( vars["roomId"], vars["eventId"], lazyLoadCache, ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}", @@ -122,7 +122,7 @@ func Setup( req, device, syncDB, rsAPI, vars["roomId"], vars["eventId"], "", "", ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}", @@ -136,7 +136,7 @@ func Setup( req, device, syncDB, rsAPI, vars["roomId"], vars["eventId"], vars["relType"], "", ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}", @@ -150,7 +150,7 @@ func Setup( req, device, syncDB, rsAPI, vars["roomId"], vars["eventId"], vars["relType"], vars["eventType"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/search", @@ -191,7 +191,7 @@ func Setup( at := req.URL.Query().Get("at") return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, false, membership, notMembership, at) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/joined_members", From c648c671a326f2d626cf34db52cbcc9999b95bba Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 11 Nov 2022 10:52:43 +0100 Subject: [PATCH 011/124] Fix issue with missing user NIDs (#2874) This should fix #2696 and possibly other related issues regarding missing user NIDs. (https://github.com/matrix-org/dendrite/issues/2094?) --- roomserver/storage/shared/storage.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 4455ec3bf..958787070 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -110,6 +110,19 @@ func (d *Database) eventStateKeyNIDs( for eventStateKey, nid := range nids { result[eventStateKey] = nid } + // We received some nids, but are still missing some, work out which and create them + if len(eventStateKeys) < len(result) { + for _, eventStateKey := range eventStateKeys { + if _, ok := result[eventStateKey]; ok { + continue + } + nid, err := d.assignStateKeyNID(ctx, txn, eventStateKey) + if err != nil { + return result, err + } + result[eventStateKey] = nid + } + } return result, nil } @@ -1243,7 +1256,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } - eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) + eventStateKeyNIDMap, err := d.eventStateKeyNIDs(ctx, nil, eventStateKeys) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err) } @@ -1309,7 +1322,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ if err != nil { return nil, err } - userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs) + userNIDsMap, err := d.eventStateKeyNIDs(ctx, nil, userIDs) if err != nil { return nil, err } From 72ce6acf7176ac9e597e4fd6587605199e9e0d7f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 11 Nov 2022 11:21:16 +0000 Subject: [PATCH 012/124] Run upgrade tests for SQLite too (#2875) This should hopefully catch problems with database migrations in SQLite as well as PostgreSQL. --- .github/workflows/dendrite.yml | 10 ++++--- cmd/dendrite-upgrade-tests/main.go | 43 +++++++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index a37e45e46..f96cbadd4 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -231,8 +231,10 @@ jobs: ${{ runner.os }}-go-upgrade - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests - - name: Test upgrade + - name: Test upgrade (PostgreSQL) run: ./dendrite-upgrade-tests --head . + - name: Test upgrade (SQLite) + run: ./dendrite-upgrade-tests --sqlite --head . # run database upgrade tests, skipping over one version upgrade_test_direct: @@ -256,7 +258,9 @@ jobs: ${{ runner.os }}-go-upgrade - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests - - name: Test upgrade + - name: Test upgrade (PostgreSQL) + run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head . + - name: Test upgrade (SQLite) run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head . # run Sytest in different variations @@ -434,7 +438,7 @@ jobs: permissions: packages: write contents: read - security-events: write # To upload Trivy sarif files + security-events: write # To upload Trivy sarif files if: github.repository == 'matrix-org/dendrite' && github.ref_name == 'main' needs: [integration-tests-done] uses: matrix-org/dendrite/.github/workflows/docker.yml@main diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 131ce4af9..75446d18c 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -38,6 +38,7 @@ var ( flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github") flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.") flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done") + flagSqlite = flag.Bool("sqlite", false, "Test SQLite instead of PostgreSQL") alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+") ) @@ -49,7 +50,7 @@ const HEAD = "HEAD" // due to the error: // When using COPY with more than one source file, the destination must be a directory and end with a / // We need to run a postgres anyway, so use the dockerfile associated with Complement instead. -const Dockerfile = `FROM golang:1.18-stretch as build +const DockerfilePostgreSQL = `FROM golang:1.18-stretch as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build @@ -92,6 +93,42 @@ ENV SERVER_NAME=localhost EXPOSE 8008 8448 CMD /build/run_dendrite.sh ` +const DockerfileSQLite = `FROM golang:1.18-stretch as build +RUN apt-get update && apt-get install -y postgresql +WORKDIR /build + +# Copy the build context to the repo as this is the right dendrite code. This is different to the +# Complement Dockerfile which wgets a branch. +COPY . . + +RUN go build ./cmd/dendrite-monolith-server +RUN go build ./cmd/generate-keys +RUN go build ./cmd/generate-config +RUN ./generate-config --ci > dendrite.yaml +RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key + +# Make sure the SQLite databases are in a persistent location, we're already mapping +# the postgresql folder so let's just use that for simplicity +RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postgresql\/9.6\/main\/%g" dendrite.yaml + +# This entry script starts postgres, waits for it to be up then starts dendrite +RUN echo '\ +sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\ +PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\ +./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\ +' > run_dendrite.sh && chmod +x run_dendrite.sh + +ENV SERVER_NAME=localhost +EXPOSE 8008 8448 +CMD /build/run_dendrite.sh ` + +func dockerfile() []byte { + if *flagSqlite { + return []byte(DockerfileSQLite) + } + return []byte(DockerfilePostgreSQL) +} + const dendriteUpgradeTestLabel = "dendrite_upgrade_test" // downloadArchive downloads an arbitrary github archive of the form: @@ -150,7 +187,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, if branchOrTagName == HEAD && *flagHead != "" { log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead) // add top level Dockerfile - err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), []byte(Dockerfile), os.ModePerm) + err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), dockerfile(), os.ModePerm) if err != nil { return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err) } @@ -166,7 +203,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, // pull an archive, this contains a top-level directory which screws with the build context // which we need to fix up post download u := fmt.Sprintf("https://github.com/matrix-org/dendrite/archive/%s.tar.gz", branchOrTagName) - tarball, err = downloadArchive(httpClient, tmpDir, u, []byte(Dockerfile)) + tarball, err = downloadArchive(httpClient, tmpDir, u, dockerfile()) if err != nil { return "", fmt.Errorf("failed to download archive %s: %w", u, err) } From e177e0ae73d7cc34ffb9869681a6bf177f805205 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 11 Nov 2022 16:44:59 +0100 Subject: [PATCH 013/124] Fix oops, add simple UT --- roomserver/internal/helpers/helpers_test.go | 56 +++++++++++++++++++++ roomserver/storage/shared/storage.go | 2 +- 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 roomserver/internal/helpers/helpers_test.go diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go new file mode 100644 index 000000000..aa5c30e44 --- /dev/null +++ b/roomserver/internal/helpers/helpers_test.go @@ -0,0 +1,56 @@ +package helpers + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/setup/base" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + + "github.com/matrix-org/dendrite/roomserver/storage" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches) + if err != nil { + t.Fatalf("failed to create Database: %v", err) + } + return base, db, close +} + +func TestIsInvitePendingWithoutNID(t *testing.T) { + + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) + _ = bob + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + _, db, close := mustCreateDatabase(t, dbType) + defer close() + + // store all events + var authNIDs []types.EventNID + for _, x := range room.Events() { + + evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false) + assert.NoError(t, err) + authNIDs = append(authNIDs, evNID) + } + + // Alice should have no pending invites and should have a NID + pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID) + assert.NoError(t, err, "failed to get pending invites") + assert.False(t, pendingInvite, "unexpected pending invite") + + // Bob should have no pending invites and receive a new NID + pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID) + assert.NoError(t, err, "failed to get pending invites") + assert.False(t, pendingInvite, "unexpected pending invite") + }) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 958787070..08e912a00 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -111,7 +111,7 @@ func (d *Database) eventStateKeyNIDs( result[eventStateKey] = nid } // We received some nids, but are still missing some, work out which and create them - if len(eventStateKeys) < len(result) { + if len(eventStateKeys) > len(result) { for _, eventStateKey := range eventStateKeys { if _, ok := result[eventStateKey]; ok { continue From 529df30b5649e67a2f98114e6640d259cba53566 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 11 Nov 2022 16:41:37 +0000 Subject: [PATCH 014/124] Virtual hosting schema and logic changes (#2876) Note that virtual users cannot federate correctly yet. --- appservice/appservice.go | 6 +- clientapi/auth/password.go | 21 +- clientapi/routing/admin.go | 5 + clientapi/routing/login.go | 1 + clientapi/routing/notification.go | 11 +- clientapi/routing/password.go | 12 +- clientapi/routing/pusher.go | 8 +- clientapi/routing/register.go | 10 +- clientapi/routing/routing.go | 2 +- clientapi/routing/server_notices.go | 8 +- clientapi/routing/threepid.go | 14 +- federationapi/federationapi.go | 18 +- federationapi/federationapi_keys_test.go | 2 +- federationapi/queue/destinationqueue.go | 2 +- federationapi/queue/queue.go | 26 +-- federationapi/queue/queue_test.go | 10 +- federationapi/routing/keys.go | 24 ++- federationapi/routing/routing.go | 2 +- keyserver/internal/internal.go | 30 +-- keyserver/keyserver.go | 8 +- userapi/api/api.go | 55 +++-- userapi/api/api_trace.go | 4 +- userapi/consumers/clientapi.go | 4 +- userapi/consumers/roomserver.go | 58 ++--- userapi/internal/api.go | 101 +++++---- userapi/internal/api_logintoken.go | 2 +- userapi/inthttp/client.go | 3 +- userapi/inthttp/server.go | 15 +- userapi/producers/syncapi.go | 4 +- userapi/storage/interface.go | 66 +++--- .../storage/postgres/account_data_table.go | 33 +-- userapi/storage/postgres/accounts_table.go | 57 ++--- .../deltas/2022110411000000_server_names.go | 81 +++++++ .../deltas/2022110411000001_server_names.go | 28 +++ userapi/storage/postgres/devices_table.go | 88 ++++---- .../storage/postgres/notifications_table.go | 49 ++--- userapi/storage/postgres/openid_table.go | 15 +- userapi/storage/postgres/profile_table.go | 48 +++-- userapi/storage/postgres/pusher_table.go | 32 +-- userapi/storage/postgres/storage.go | 28 ++- userapi/storage/postgres/threepid_table.go | 26 ++- userapi/storage/shared/storage.go | 204 ++++++++++-------- userapi/storage/sqlite3/account_data_table.go | 33 +-- userapi/storage/sqlite3/accounts_table.go | 53 ++--- .../deltas/20200929203058_is_active.go | 1 + .../deltas/20201001204705_last_seen_ts_ip.go | 1 + .../2022021012490600_add_account_type.go | 1 + .../deltas/2022110411000000_server_names.go | 108 ++++++++++ .../deltas/2022110411000001_server_names.go | 28 +++ userapi/storage/sqlite3/devices_table.go | 92 ++++---- .../storage/sqlite3/notifications_table.go | 49 ++--- userapi/storage/sqlite3/openid_table.go | 15 +- userapi/storage/sqlite3/profile_table.go | 52 +++-- userapi/storage/sqlite3/pusher_table.go | 32 +-- userapi/storage/sqlite3/storage.go | 28 ++- userapi/storage/sqlite3/threepid_table.go | 26 ++- userapi/storage/storage_test.go | 132 ++++++------ userapi/storage/tables/interface.go | 68 +++--- userapi/storage/tables/stats_table_test.go | 17 +- userapi/userapi_test.go | 10 +- userapi/util/devices.go | 8 +- userapi/util/notify.go | 7 +- 62 files changed, 1250 insertions(+), 732 deletions(-) create mode 100644 userapi/storage/postgres/deltas/2022110411000000_server_names.go create mode 100644 userapi/storage/postgres/deltas/2022110411000001_server_names.go create mode 100644 userapi/storage/sqlite3/deltas/2022110411000000_server_names.go create mode 100644 userapi/storage/sqlite3/deltas/2022110411000001_server_names.go diff --git a/appservice/appservice.go b/appservice/appservice.go index 0c778b6ca..b3c28dbde 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -32,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) // AddInternalRoutes registers HTTP handlers for internal API calls @@ -74,7 +75,7 @@ func NewInternalAPI( // events to be sent out. for _, appservice := range base.Cfg.Derived.ApplicationServices { // Create bot account for this AS if it doesn't already exist - if err := generateAppServiceAccount(userAPI, appservice); err != nil { + if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil { logrus.WithFields(logrus.Fields{ "appservice": appservice.ID, }).WithError(err).Panicf("failed to generate bot account for appservice") @@ -101,11 +102,13 @@ func NewInternalAPI( func generateAppServiceAccount( userAPI userapi.AppserviceUserAPI, as config.ApplicationService, + serverName gomatrixserverlib.ServerName, ) error { var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeAppService, Localpart: as.SenderLocalpart, + ServerName: serverName, AppServiceID: as.ID, OnConflict: userapi.ConflictUpdate, }, &accRes) @@ -115,6 +118,7 @@ func generateAppServiceAccount( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ Localpart: as.SenderLocalpart, + ServerName: serverName, AccessToken: as.ASToken, DeviceID: &as.SenderLocalpart, DeviceDisplayName: &as.SenderLocalpart, diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 700a72f5d..4de2b443c 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -61,7 +61,7 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { r := req.(*PasswordRequest) - username := strings.ToLower(r.Username()) + username := r.Username() if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, @@ -74,32 +74,43 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, JSON: jsonerror.BadJSON("A password must be supplied."), } } - localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix) + localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, JSON: jsonerror.InvalidUsername(err.Error()), } } + if !t.Config.Matrix.IsLocalServerName(domain) { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.InvalidUsername("The server name is not known."), + } + } // Squash username to all lowercase letters res := &api.QueryAccountByPasswordResponse{} - err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res) + err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ + Localpart: strings.ToLower(localpart), + ServerName: domain, + PlaintextPassword: r.Password, + }, res) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("unable to fetch account by password"), + JSON: jsonerror.Unknown("Unable to fetch account by password."), } } if !res.Exists { err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ Localpart: localpart, + ServerName: domain, PlaintextPassword: r.Password, }, res) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("unable to fetch account by password"), + JSON: jsonerror.Unknown("Unable to fetch account by password."), } } // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 9088f7716..9ed1f0ca2 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap if err != nil { return util.ErrorResponse(err) } + serverName := cfg.Matrix.ServerName localpart, ok := vars["localpart"] if !ok { return util.JSONResponse{ @@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.MissingArgument("Expecting user localpart."), } } + if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil { + localpart, serverName = l, s + } request := struct { Password string `json:"password"` }{} @@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } updateReq := &userapi.PerformPasswordUpdateRequest{ Localpart: localpart, + ServerName: serverName, Password: request.Password, LogoutDevices: true, } diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 7f5a8c4f8..0de324da1 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -100,6 +100,7 @@ func completeAuth( DeviceID: login.DeviceID, AccessToken: token, Localpart: localpart, + ServerName: serverName, IPAddr: ipAddr, UserAgent: userAgent, }, &performRes) diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go index 8a424a141..f593e27db 100644 --- a/clientapi/routing/notification.go +++ b/clientapi/routing/notification.go @@ -40,16 +40,17 @@ func GetNotifications( } var queryRes userapi.QueryNotificationsResponse - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") return jsonerror.InternalServerError() } err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ - Localpart: localpart, - From: req.URL.Query().Get("from"), - Limit: int(limit), - Only: req.URL.Query().Get("only"), + Localpart: localpart, + ServerName: domain, + From: req.URL.Query().Get("from"), + Limit: int(limit), + Only: req.URL.Query().Get("only"), }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed") diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 6dc9af508..9772f669a 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -86,7 +86,7 @@ func Password( } // Get the local part. - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() @@ -94,8 +94,9 @@ func Password( // Ask the user API to perform the password change. passwordReq := &api.PerformPasswordUpdateRequest{ - Localpart: localpart, - Password: r.NewPassword, + Localpart: localpart, + ServerName: domain, + Password: r.NewPassword, } passwordRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil { @@ -122,8 +123,9 @@ func Password( } pushersReq := &api.PerformPusherDeletionRequest{ - Localpart: localpart, - SessionID: device.SessionID, + Localpart: localpart, + ServerName: domain, + SessionID: device.SessionID, } if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go index d6a6eb936..89ec824bf 100644 --- a/clientapi/routing/pusher.go +++ b/clientapi/routing/pusher.go @@ -31,13 +31,14 @@ func GetPushers( userAPI userapi.ClientUserAPI, ) util.JSONResponse { var queryRes userapi.QueryPushersResponse - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") return jsonerror.InternalServerError() } err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ - Localpart: localpart, + Localpart: localpart, + ServerName: domain, }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") @@ -59,7 +60,7 @@ func SetPusher( req *http.Request, device *userapi.Device, userAPI userapi.ClientUserAPI, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") return jsonerror.InternalServerError() @@ -93,6 +94,7 @@ func SetPusher( } body.Localpart = localpart + body.ServerName = domain body.SessionID = device.SessionID err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) if err != nil { diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index b9ebb0518..9dc63af84 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -588,12 +588,15 @@ func Register( } // Auto generate a numeric username if r.Username is empty if r.Username == "" { - res := &userapi.QueryNumericLocalpartResponse{} - if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil { + nreq := &userapi.QueryNumericLocalpartRequest{ + ServerName: cfg.Matrix.ServerName, // TODO: might not be right + } + nres := &userapi.QueryNumericLocalpartResponse{} + if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") return jsonerror.InternalServerError() } - r.Username = strconv.FormatInt(res.ID, 10) + r.Username = strconv.FormatInt(nres.ID, 10) } // Is this an appservice registration? It will be if the access @@ -676,6 +679,7 @@ func handleGuestRegistration( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ Localpart: res.Account.Localpart, + ServerName: res.Account.ServerName, DeviceDisplayName: r.InitialDisplayName, AccessToken: token, IPAddr: req.RemoteAddr, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 1b3ef120a..a510761eb 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -157,7 +157,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}", + dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminResetPassword(req, cfg, device, userAPI) }), diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index a6a78061d..a7acee326 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -286,6 +286,7 @@ func getSenderDevice( err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeUser, Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, OnConflict: userapi.ConflictUpdate, }, &accRes) if err != nil { @@ -295,8 +296,9 @@ func getSenderDevice( // Set the avatarurl for the user avatarRes := &userapi.PerformSetAvatarURLResponse{} if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ - Localpart: cfg.Matrix.ServerNotices.LocalPart, - AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, + Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, + AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, }, avatarRes); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") return nil, err @@ -308,6 +310,7 @@ func getSenderDevice( displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{ Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, DisplayName: cfg.Matrix.ServerNotices.DisplayName, }, displayNameRes); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") @@ -353,6 +356,7 @@ func getSenderDevice( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart, AccessToken: token, NoDeviceListUpdate: true, diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index 4b7989ecb..971bfcad3 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -136,16 +136,17 @@ func CheckAndSave3PIDAssociation( } // Save the association in the database - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() } if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{ - ThreePID: address, - Localpart: localpart, - Medium: medium, + ThreePID: address, + Localpart: localpart, + ServerName: domain, + Medium: medium, }, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed") return jsonerror.InternalServerError() @@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation( func GetAssociated3PIDs( req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() @@ -169,7 +170,8 @@ func GetAssociated3PIDs( res := &api.QueryThreePIDsForLocalpartResponse{} err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{ - Localpart: localpart, + Localpart: localpart, + ServerName: domain, }, res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed") diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index e35b9c7f0..4578e33aa 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -120,15 +120,23 @@ func NewInternalAPI( js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + signingInfo := map[gomatrixserverlib.ServerName]*queue.SigningInfo{} + for _, serverName := range append( + []gomatrixserverlib.ServerName{base.Cfg.Global.ServerName}, + base.Cfg.Global.SecondaryServerNames..., + ) { + signingInfo[serverName] = &queue.SigningInfo{ + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + ServerName: serverName, + } + } + queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, cfg.Matrix.DisableFederation, cfg.Matrix.ServerName, federation, rsAPI, &stats, - &queue.SigningInfo{ - KeyID: cfg.Matrix.KeyID, - PrivateKey: cfg.Matrix.PrivateKey, - ServerName: cfg.Matrix.ServerName, - }, + signingInfo, ) rsConsumer := consumers.NewOutputRoomEventConsumer( diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 7ccc02f76..3acaa70dc 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err } // Get the keys and JSON-ify them. - keys := routing.LocalKeys(s.config) + keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host)) body, err := json.MarshalIndent(keys.JSON, "", " ") if err != nil { return nil, err diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index a638a5742..bf04ee99a 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -50,7 +50,7 @@ type destinationQueue struct { queues *OutgoingQueues db storage.Database process *process.ProcessContext - signing *SigningInfo + signing map[gomatrixserverlib.ServerName]*SigningInfo rsAPI api.FederationRoomserverAPI client fedapi.FederationClient // federation client origin gomatrixserverlib.ServerName // origin of requests diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index b5d0552c6..68f354993 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -46,7 +46,7 @@ type OutgoingQueues struct { origin gomatrixserverlib.ServerName client fedapi.FederationClient statistics *statistics.Statistics - signing *SigningInfo + signing map[gomatrixserverlib.ServerName]*SigningInfo queuesMutex sync.Mutex // protects the below queues map[gomatrixserverlib.ServerName]*destinationQueue } @@ -91,7 +91,7 @@ func NewOutgoingQueues( client fedapi.FederationClient, rsAPI api.FederationRoomserverAPI, statistics *statistics.Statistics, - signing *SigningInfo, + signing map[gomatrixserverlib.ServerName]*SigningInfo, ) *OutgoingQueues { queues := &OutgoingQueues{ disabled: disabled, @@ -199,11 +199,10 @@ func (oqs *OutgoingQueues) SendEvent( log.Trace("Federation is disabled, not sending event") return nil } - if origin != oqs.origin { - // TODO: Support virtual hosting; gh issue #577. + if _, ok := oqs.signing[origin]; !ok { return fmt.Errorf( - "sendevent: unexpected server to send as: got %q expected %q", - origin, oqs.origin, + "sendevent: unexpected server to send as %q", + origin, ) } @@ -214,7 +213,9 @@ func (oqs *OutgoingQueues) SendEvent( destmap[d] = struct{}{} } delete(destmap, oqs.origin) - delete(destmap, oqs.signing.ServerName) + for local := range oqs.signing { + delete(destmap, local) + } // Check if any of the destinations are prohibited by server ACLs. for destination := range destmap { @@ -288,11 +289,10 @@ func (oqs *OutgoingQueues) SendEDU( log.Trace("Federation is disabled, not sending EDU") return nil } - if origin != oqs.origin { - // TODO: Support virtual hosting; gh issue #577. + if _, ok := oqs.signing[origin]; !ok { return fmt.Errorf( - "sendevent: unexpected server to send as: got %q expected %q", - origin, oqs.origin, + "sendevent: unexpected server to send as %q", + origin, ) } @@ -303,7 +303,9 @@ func (oqs *OutgoingQueues) SendEDU( destmap[d] = struct{}{} } delete(destmap, oqs.origin) - delete(destmap, oqs.signing.ServerName) + for local := range oqs.signing { + delete(destmap, local) + } // There is absolutely no guarantee that the EDU will have a room_id // field, as it is not required by the spec. However, if it *does* diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 7ef4646f7..58745c607 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -350,10 +350,12 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T } rs := &stubFederationRoomServerAPI{} stats := statistics.NewStatistics(db, failuresUntilBlacklist) - signingInfo := &SigningInfo{ - KeyID: "ed21019:auto", - PrivateKey: test.PrivateKeyA, - ServerName: "localhost", + signingInfo := map[gomatrixserverlib.ServerName]*SigningInfo{ + "localhost": { + KeyID: "ed21019:auto", + PrivateKey: test.PrivateKeyA, + ServerName: "localhost", + }, } queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 8931830f3..5650e3d53 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -16,6 +16,7 @@ package routing import ( "encoding/json" + "fmt" "net/http" "time" @@ -134,18 +135,21 @@ func ClaimOneTimeKeys( // LocalKeys returns the local keys for the server. // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys -func LocalKeys(cfg *config.FederationAPI) util.JSONResponse { - keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) +func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse { + keys, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) if err != nil { return util.ErrorResponse(err) } return util.JSONResponse{Code: http.StatusOK, JSON: keys} } -func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { +func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys + if !cfg.Matrix.IsLocalServerName(serverName) { + return nil, fmt.Errorf("server name not known") + } - keys.ServerName = cfg.Matrix.ServerName + keys.ServerName = serverName keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil) publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) @@ -172,7 +176,7 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver } keys.Raw, err = gomatrixserverlib.SignJSON( - string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, + string(serverName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, ) if err != nil { return nil, err @@ -186,6 +190,14 @@ func NotaryKeys( fsAPI federationAPI.FederationInternalAPI, req *gomatrixserverlib.PublicKeyNotaryLookupRequest, ) util.JSONResponse { + serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal + if !cfg.Matrix.IsLocalServerName(serverName) { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("Server name not known"), + } + } + if req == nil { req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { @@ -201,7 +213,7 @@ func NotaryKeys( for serverName, kidToCriteria := range req.ServerKeys { var keyList []gomatrixserverlib.ServerKeys if serverName == cfg.Matrix.ServerName { - if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { + if k, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { keyList = append(keyList, *k) } else { return util.ErrorResponse(err) diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 9f16e5093..0a3ab7a88 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -74,7 +74,7 @@ func Setup( } localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { - return LocalKeys(cfg) + return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host)) }) notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 92ee80d81..37c55c8f8 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -33,16 +33,17 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/dendrite/keyserver/storage" + "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" ) type KeyInternalAPI struct { - DB storage.Database - ThisServer gomatrixserverlib.ServerName - FedClient fedsenderapi.KeyserverFederationAPI - UserAPI userapi.KeyserverUserAPI - Producer *producers.KeyChange - Updater *DeviceListUpdater + DB storage.Database + Cfg *config.KeyServer + FedClient fedsenderapi.KeyserverFederationAPI + UserAPI userapi.KeyserverUserAPI + Producer *producers.KeyChange + Updater *DeviceListUpdater } func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { @@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC nested[userID] = val domainToDeviceKeys[string(serverName)] = nested } - // claim local keys - if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok { + for domain, local := range domainToDeviceKeys { + if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + continue + } + // claim local keys keys, err := a.DB.ClaimKeys(ctx, local) if err != nil { res.Error = &api.KeyError{ @@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON } } - delete(domainToDeviceKeys, string(a.ThisServer)) + delete(domainToDeviceKeys, domain) } if len(domainToDeviceKeys) > 0 { a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) @@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } domain := string(serverName) // query local devices - if serverName == a.ThisServer { + if a.Cfg.Matrix.IsLocalServerName(serverName) { deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ @@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys( domains := map[string]struct{}{} for domain := range domainToDeviceKeys { - if domain == string(a.ThisServer) { + if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} } for domain := range domainToCrossSigningKeys { - if domain == string(a.ThisServer) { + if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} @@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per if err != nil { continue // ignore invalid users } - if serverName != a.ThisServer { + if !a.Cfg.Matrix.IsLocalServerName(serverName) { continue // ignore remote users } if len(key.KeyJSON) == 0 { diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 9ae4f9ca3..0a4b8fde6 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -53,10 +53,10 @@ func NewInternalAPI( DB: db, } ap := &internal.KeyInternalAPI{ - DB: db, - ThisServer: cfg.Matrix.ServerName, - FedClient: fedClient, - Producer: keyChangeProducer, + DB: db, + Cfg: cfg, + FedClient: fedClient, + Producer: keyChangeProducer, } updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable ap.Updater = updater diff --git a/userapi/api/api.go b/userapi/api/api.go index 8d7f783de..d3f5aefc8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -78,7 +78,7 @@ type ClientUserAPI interface { QueryAcccessTokenAPI LoginTokenInternalAPI UserLoginAPI - QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error + QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error @@ -335,9 +335,10 @@ type PerformAccountCreationResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformPasswordUpdateRequest struct { - Localpart string // Required: The localpart for this account. - Password string // Required: The new password to set. - LogoutDevices bool // Optional: Whether to log out all user devices. + Localpart string // Required: The localpart for this account. + ServerName gomatrixserverlib.ServerName // Required: The domain for this account. + Password string // Required: The new password to set. + LogoutDevices bool // Optional: Whether to log out all user devices. } // PerformAccountCreationResponse is the response for PerformAccountCreation @@ -518,7 +519,8 @@ const ( ) type QueryPushersRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryPushersResponse struct { @@ -526,14 +528,16 @@ type QueryPushersResponse struct { } type PerformPusherSetRequest struct { - Pusher // Anonymous field because that's how clientapi unmarshals it. - Localpart string - Append bool `json:"append"` + Pusher // Anonymous field because that's how clientapi unmarshals it. + Localpart string + ServerName gomatrixserverlib.ServerName + Append bool `json:"append"` } type PerformPusherDeletionRequest struct { - Localpart string - SessionID int64 + Localpart string + ServerName gomatrixserverlib.ServerName + SessionID int64 } // Pusher represents a push notification subscriber @@ -571,10 +575,11 @@ type QueryPushRulesResponse struct { } type QueryNotificationsRequest struct { - Localpart string `json:"localpart"` // Required. - From string `json:"from,omitempty"` - Limit int `json:"limit,omitempty"` - Only string `json:"only,omitempty"` + Localpart string `json:"localpart"` // Required. + ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required. + From string `json:"from,omitempty"` + Limit int `json:"limit,omitempty"` + Only string `json:"only,omitempty"` } type QueryNotificationsResponse struct { @@ -601,12 +606,17 @@ type PerformSetAvatarURLResponse struct { Changed bool `json:"changed"` } +type QueryNumericLocalpartRequest struct { + ServerName gomatrixserverlib.ServerName +} + type QueryNumericLocalpartResponse struct { ID int64 } type QueryAccountAvailabilityRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryAccountAvailabilityResponse struct { @@ -614,7 +624,9 @@ type QueryAccountAvailabilityResponse struct { } type QueryAccountByPasswordRequest struct { - Localpart, PlaintextPassword string + Localpart string + ServerName gomatrixserverlib.ServerName + PlaintextPassword string } type QueryAccountByPasswordResponse struct { @@ -638,11 +650,13 @@ type QueryLocalpartForThreePIDRequest struct { } type QueryLocalpartForThreePIDResponse struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryThreePIDsForLocalpartRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryThreePIDsForLocalpartResponse struct { @@ -652,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct { type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest type PerformSaveThreePIDAssociationRequest struct { - ThreePID, Localpart, Medium string + ThreePID string + Localpart string + ServerName gomatrixserverlib.ServerName + Medium string } diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 90834f7e3..ce661770f 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet return err } -func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error { - err := t.Impl.QueryNumericLocalpart(ctx, res) +func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error { + err := t.Impl.QueryNumericLocalpart(ctx, req, res) util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res)) return err } diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 79f1bf06f..42ae72e77 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -104,7 +104,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return false } - updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) + updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) if err != nil { log.WithError(err).Error("userapi EDU consumer") return false @@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats if !updated { return true } - if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil { log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") return false } diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index b6b30a095..5d8924dda 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -192,25 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error { for _, membership := range localMembers { // Copy any existing push rules from old -> new room - if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { + if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil { return err } // preserve m.direct room state - if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil { + if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil { return err } // copy existing m.tag entries, if any - if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { + if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil { return err } } return nil } -func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error { - pushRules, err := s.db.QueryPushRules(ctx, localpart) +func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error { + pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName) if err != nil { return fmt.Errorf("failed to query pushrules for user: %w", err) } @@ -229,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, if err != nil { return err } - if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil { + if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil { return fmt.Errorf("failed to update pushrules: %w", err) } } @@ -237,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, } // updateMDirect copies the "is_direct" flag from oldRoomID to newROomID -func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error { +func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error { // this is most likely not a DM, so skip updating m.direct state if roomSize > 2 { return nil } // Get direct message state - directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct") + directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct") if err != nil { return fmt.Errorf("failed to get m.direct from database: %w", err) } @@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, if err != nil { return true } - if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil { + if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil { return true } } @@ -279,15 +279,15 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, return nil } -func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error { - tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag") +func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error { + tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag") if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } if tag == nil { return nil } - return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag) + return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag) } func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { @@ -492,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) if err != nil { - return err + return fmt.Errorf("s.evaluatePushRules: %w", err) } a, tweaks, err := pushrules.ActionsToTweaks(actions) if err != nil { - return err + return fmt.Errorf("pushrules.ActionsToTweaks: %w", err) } // TODO: support coalescing. if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { @@ -508,9 +508,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr return nil } - devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) + devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks) if err != nil { - return err + return fmt.Errorf("s.localPushDevices: %w", err) } n := &api.Notification{ @@ -527,18 +527,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr RoomID: event.RoomID(), TS: gomatrixserverlib.AsTimestamp(time.Now()), } - if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil { - return err + if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil { + return fmt.Errorf("s.db.InsertNotification: %w", err) } if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { - return err + return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err) } // We do this after InsertNotification. Thus, this should always return >=1. - userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) + userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, mem.Domain, tables.AllNotifications) if err != nil { - return err + return fmt.Errorf("s.db.GetNotificationCount: %w", err) } log.WithFields(log.Fields{ @@ -589,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr } if len(rejected) > 0 { - s.deleteRejectedPushers(ctx, rejected, mem.Localpart) + s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain) } }() @@ -606,7 +606,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * } // Get accountdata to check if the event.Sender() is ignored by mem.LocalPart - data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list") + data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list") if err != nil { return nil, err } @@ -621,7 +621,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * return nil, fmt.Errorf("user %s is ignored", sender) } } - ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart) + ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain) if err != nil { return nil, err } @@ -693,10 +693,10 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err // localPushDevices pushes to the configured devices of a local // user. The map keys are [url][format]. -func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { - pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) +func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { + pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db) if err != nil { - return nil, "", err + return nil, "", fmt.Errorf("util.GetPushDevices: %w", err) } var profileTag string @@ -791,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri } // deleteRejectedPushers deletes the pushers associated with the given devices. -func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { +func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) { log.WithFields(log.Fields{ "localpart": localpart, "app_id0": devices[0].AppID, @@ -799,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev }).Warnf("Deleting pushers rejected by the HTTP push gateway") for _, d := range devices { - if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { + if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil { log.WithFields(log.Fields{ "localpart": localpart, }).WithError(err).Errorf("Unable to delete rejected pusher") diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 9ca76965d..3f256457e 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if req.DataType == "" { return fmt.Errorf("data type must not be empty") } - if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil { + if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil { util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed") return fmt.Errorf("failed to save account data: %w", err) } @@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } - deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) + deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) if err != nil { logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") return err @@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } - if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil { + if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil { logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") return err } @@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P if serverName == "" { serverName = a.Config.Matrix.ServerName } - // XXXX: Use the server name here - acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %s is not local", serverName) + } + acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists switch req.OnConflict { @@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { - return err + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil { + return fmt.Errorf("a.DB.SetDisplayName: %w", err) } postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) @@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P } func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { - if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + if !a.Config.Matrix.IsLocalServerName(req.ServerName) { + return fmt.Errorf("server name %s is not local", req.ServerName) + } + if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil { return err } if req.LogoutDevices { - if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil { + if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil { return err } } @@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe if serverName == "" { serverName = a.Config.Matrix.ServerName } - _ = serverName - // XXXX: Use the server name here + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %s is not local", serverName) + } util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") - dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) + dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err } @@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) + devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } } else { - err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs) + err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs) } if err != nil { return err @@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate( req *api.PerformLastSeenUpdateRequest, res *api.PerformLastSeenUpdateResponse, ) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("server name %s is not local", domain) + } + if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) } return nil } func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") return err } - dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("server name %s is not local", domain) + } + dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID) if err == sql.ErrNoRows { res.DeviceExists = false return nil @@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf return nil } - err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName) if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") return err @@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) } - prof, err := a.DB.GetProfileByLocalpart(ctx, local) + prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain) if err != nil { if err == sql.ErrNoRows { return nil @@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) } - devs, err := a.DB.GetDevicesByLocalpart(ctx, local) + devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain) if err != nil { return err } @@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } if req.DataType != "" { var data json.RawMessage - data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType) if err != nil { return err } @@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } return nil } - global, rooms, err := a.DB.GetAccountData(ctx, local) + global, rooms, err := a.DB.GetAccountData(ctx, local, domain) if err != nil { return err } @@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc if !a.Config.Matrix.IsLocalServerName(domain) { return nil } - acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) + acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain) if err != nil { return err } @@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe AccountType: api.AccountTypeAppService, } - localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) + localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) if err != nil { return nil, err } if localpart != "" { // AS is masquerading as another user // Verify that the user is registered - account, err := a.DB.GetAccountByLocalpart(ctx, localpart) + account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain) // Verify that the account exists and either appServiceID matches or // it belongs to the appservice user namespaces if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { @@ -620,7 +632,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a return err } - err := a.DB.DeactivateAccount(ctx, req.Localpart) + err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName) res.AccountDeactivated = err == nil return err } @@ -783,7 +795,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query if req.Only == "highlight" { filter = tables.HighlightNotifications } - notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter) + notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter) if err != nil { return err } @@ -811,23 +823,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform } } if req.Pusher.Kind == "" { - return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) + return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName) } if req.Pusher.PushKeyTS == 0 { req.Pusher.PushKeyTS = int64(time.Now().Unix()) } - return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) + return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName) } func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { - pushers, err := a.DB.GetPushers(ctx, req.Localpart) + pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName) if err != nil { return err } for i := range pushers { logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) if pushers[i].SessionID != req.SessionID { - err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart) + err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName) if err != nil { return err } @@ -838,7 +850,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { var err error - res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart) + res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName) return err } @@ -864,11 +876,11 @@ func (a *UserInternalAPI) PerformPushRulesPut( } func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) } - pushRules, err := a.DB.QueryPushRules(ctx, localpart) + pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain) if err != nil { return fmt.Errorf("failed to query push rules: %w", err) } @@ -877,14 +889,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush } func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL) res.Profile = profile res.Changed = changed return err } -func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { - id, err := a.DB.GetNewNumericLocalpart(ctx) +func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error { + id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName) if err != nil { return err } @@ -894,12 +906,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { var err error - res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart) + res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName) return err } func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { - acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword) + acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword) switch err { case sql.ErrNoRows: // user does not exist return nil @@ -915,23 +927,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { - profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) + profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName) res.Profile = profile res.Changed = changed return err } func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { - localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) + localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) if err != nil { return err } res.Localpart = localpart + res.ServerName = domain return nil } func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { - r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart) + r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName) if err != nil { return err } @@ -944,7 +957,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { - return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium) + return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium) } const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index 87f25e5e2..3b211db5b 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain) } - if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { + if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil { res.Data = nil if err == sql.ErrNoRows { return nil diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index aa5d46d9f..87ae058c2 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL( func (h *httpUserInternalAPI) QueryNumericLocalpart( ctx context.Context, + request *api.QueryNumericLocalpartRequest, response *api.QueryNumericLocalpartResponse, ) error { return httputil.CallInternalRPCAPI( "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath, - h.httpClient, ctx, &struct{}{}, response, + h.httpClient, ctx, request, response, ) } diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 99148b760..661fecfae 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -15,12 +15,9 @@ package inthttp import ( - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/util" ) // nolint: gocyclo @@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), ) - // TODO: Look at the shape of this - internalAPIMux.Handle(QueryNumericLocalpartPath, - httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse { - response := api.QueryNumericLocalpartResponse{} - if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + QueryNumericLocalpartPath, + httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart), ) internalAPIMux.Handle( diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index f556ea352..51eaa9856 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -61,12 +61,12 @@ func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) err // GetAndSendNotificationData reads the database and sends data about unread // notifications to the Sync API server. func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error { - localpart, _, err := gomatrixserverlib.SplitID('@', userID) + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return err } - ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID) + ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, domain, roomID) if err != nil { return err } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 28ef26559..c22b7658f 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -29,40 +29,40 @@ import ( ) type Profile interface { - GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error) + SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) } type Account interface { // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) - GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) - GetNewNumericLocalpart(ctx context.Context) (int64, error) - CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) - GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) - DeactivateAccount(ctx context.Context, localpart string) (err error) - SetPassword(ctx context.Context, localpart string, plaintextPassword string) error + CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error) + GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) + CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) + GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) + DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) + SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error } type AccountData interface { - SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error - GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) + SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error + GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given // localpart, room ID and type. // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval - GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) - QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error) + GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error) + QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error) } type Device interface { GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) - GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked @@ -70,12 +70,12 @@ type Device interface { // an error will be returned. // If no device ID is given one is generated. // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) - UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error - RemoveDevices(ctx context.Context, localpart string, devices []string) error + CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error + RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error) } type KeyBackup interface { @@ -107,26 +107,26 @@ type OpenID interface { } type Pusher interface { - UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error - GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) - RemovePusher(ctx context.Context, appid, pushkey, localpart string) error + UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error + GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error) + RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error RemovePushers(ctx context.Context, appid, pushkey string) error } type ThreePID interface { - SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) + SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) - GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) - GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) } type Notification interface { - InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error - DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) - SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error) - GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) - GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) - GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) + InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) DeleteOldNotifications(ctx context.Context) error } diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 0b6a3af6d..2a4777d74 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -29,27 +30,28 @@ const accountDataSchema = ` CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) room_id TEXT, -- The account data type type TEXT NOT NULL, -- The account data content - content TEXT NOT NULL, - - PRIMARY KEY(localpart, room_id, type) + content TEXT NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type); ` const insertAccountDataSQL = ` - INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) - ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content + INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5) + ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = EXCLUDED.content ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4" type accountDataStatements struct { insertAccountDataStmt *sql.Stmt @@ -71,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { } func (s *accountDataStatements) InsertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) - _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) + _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content) return } func (s *accountDataStatements) SelectAccountData( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, error, ) { - rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName) if err != nil { return nil, nil, err } @@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData( } func (s *accountDataStatements) SelectAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte stmt := s.selectAccountDataByTypeStmt - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 7c309eb4f..31a996527 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/matrix-org/gomatrixserverlib" @@ -34,7 +35,8 @@ const accountsSchema = ` -- Stores data about accounts. CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When this account was first created, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The password hash for this account. Can be NULL if this is a passwordless account. @@ -48,25 +50,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( -- TODO: -- upgraded_ts, devices, any email reset stuff? ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name); ` const insertAccountSQL = "" + - "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + - "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3" const deactivateAccountSQL = "" + - "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" + "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'" + "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1" type accountsStatements struct { insertAccountStmt *sql.Stmt @@ -117,59 +121,62 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error if accountType != api.AccountTypeAppService { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType) } if err != nil { - return nil, err + return nil, fmt.Errorf("insertAccountStmt: %w", err) } return &api.Account{ Localpart: localpart, - UserID: userutil.MakeUserID(localpart, s.serverName), - ServerName: s.serverName, + UserID: userutil.MakeUserID(localpart, serverName), + ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, }, nil } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart, passwordHash string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + passwordHash string, ) (err error) { - _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) return } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -180,19 +187,17 @@ func (s *accountsStatements) SelectAccountByLocalpart( acc.AppServiceID = appserviceIDPtr.String } - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName - + acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName) return &acc, nil } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } - err = stmt.QueryRowContext(ctx).Scan(&id) + err = stmt.QueryRowContext(ctx, serverName).Scan(&id) return id + 1, err } diff --git a/userapi/storage/postgres/deltas/2022110411000000_server_names.go b/userapi/storage/postgres/deltas/2022110411000000_server_names.go new file mode 100644 index 000000000..279e1e5f1 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022110411000000_server_names.go @@ -0,0 +1,81 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +var serverNamesTables = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_devices", + "userapi_notifications", + "userapi_openid_tokens", + "userapi_profiles", + "userapi_pushers", + "userapi_threepids", +} + +// These tables have a PRIMARY KEY constraint which we need to drop so +// that we can recreate a new unique index that contains the server name. +// If the new key doesn't exist (i.e. the database was created before the +// table rename migration) we'll try to drop the old one instead. +var serverNamesDropPK = map[string]string{ + "userapi_accounts": "account_accounts", + "userapi_account_datas": "account_data", + "userapi_profiles": "account_profiles", +} + +// These indices are out of date so let's drop them. They will get recreated +// automatically. +var serverNamesDropIndex = []string{ + "userapi_pusher_localpart_idx", + "userapi_pusher_app_id_pushkey_localpart_idx", +} + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';", + pq.QuoteIdentifier(table), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("add server name to %q error: %w", table, err) + } + } + for newTable, oldTable := range serverNamesDropPK { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;", + pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(newTable+"_pkey"), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop new PK from %q error: %w", newTable, err) + } + q = fmt.Sprintf( + "ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;", + pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(oldTable+"_pkey"), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop old PK from %q error: %w", newTable, err) + } + } + for _, index := range serverNamesDropIndex { + q := fmt.Sprintf( + "DROP INDEX IF EXISTS %s;", + pq.QuoteIdentifier(index), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop index %q error: %w", index, err) + } + } + return nil +} diff --git a/userapi/storage/postgres/deltas/2022110411000001_server_names.go b/userapi/storage/postgres/deltas/2022110411000001_server_names.go new file mode 100644 index 000000000..04a47fa7b --- /dev/null +++ b/userapi/storage/postgres/deltas/2022110411000001_server_names.go @@ -0,0 +1,28 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "UPDATE %s SET server_name = %s WHERE server_name = '';", + pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("write server names to %q error: %w", table, err) + } + } + return nil +} diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 8b7fbd6cf..2dd216189 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/lib/pq" @@ -50,6 +51,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( -- as it is smaller, makes it clearer that we only manage devices for our own users, and may make -- migration to different domain names easier. localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When this devices was first recognised on the network, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The display name, human friendlier than device_id and updatable @@ -65,39 +67,39 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( ); -- Device IDs must be unique for a given user. -CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, server_name, device_id); ` const insertDeviceSQL = "" + - "INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + + "INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + " RETURNING session_id" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" const deleteDeviceSQL = "" + - "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3" const deleteDevicesSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" type devicesStatements struct { insertDeviceStmt *sql.Stmt @@ -148,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) InsertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { - return nil, err + if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { + return nil, fmt.Errorf("insertDeviceStmt: %w", err) } return &api.Device{ ID: id, - UserID: userutil.MakeUserID(localpart, s.serverName), + UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, SessionID: sessionID, LastSeenTS: createdTimeMS, @@ -170,38 +173,45 @@ func (s *devicesStatements) InsertDevice( // deleteDevice removes a single device by id and user localpart. func (s *devicesStatements) DeleteDevice( - ctx context.Context, txn *sql.Tx, id, localpart string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) + _, err := stmt.ExecContext(ctx, id, localpart, serverName) return err } // deleteDevices removes a single or multiple devices by ids and user localpart. // Returns an error if the execution failed. func (s *devicesStatements) DeleteDevices( - ctx context.Context, txn *sql.Tx, localpart string, devices []string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) - _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) + _, err := stmt.ExecContext(ctx, localpart, serverName, pq.Array(devices)) return err } // deleteDevicesByLocalpart removes all devices for the // given user localpart. func (s *devicesStatements) DeleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID) return err } func (s *devicesStatements) UpdateDeviceName( - ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID) return err } @@ -210,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) dev.AccessToken = accessToken } return &dev, err @@ -222,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { var dev api.Device var displayName, ip sql.NullString var lastseenTS sql.NullInt64 stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) + err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) if displayName.Valid { dev.DisplayName = displayName.String } @@ -254,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName var lastseents sql.NullInt64 var displayName sql.NullString for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { @@ -266,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s if lastseents.Valid { dev.LastSeenTS = lastseents.Int64 } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } func (s *devicesStatements) SelectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID) if err != nil { return devices, err @@ -307,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart( dev.UserAgent = useragent.String } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) return err } diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index 24a30b2f5..dc64b1e79 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -43,6 +43,7 @@ const notificationSchema = ` CREATE TABLE IF NOT EXISTS userapi_notifications ( id BIGSERIAL PRIMARY KEY, localpart TEXT NOT NULL, + server_name TEXT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, stream_pos BIGINT NOT NULL, @@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications ( read BOOLEAN NOT NULL DEFAULT FALSE ); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id); ` const insertNotificationSQL = "" + - "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const deleteNotificationsUpToSQL = "" + - "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4" const updateNotificationReadSQL = "" + - "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1" const selectNotificationSQL = "" + - "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + - "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + - ") AND NOT read ORDER BY localpart, id LIMIT $4" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" + + "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $5" const selectNotificationCountSQL = "" + - "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + - "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + ") AND NOT read" const selectRoomNotificationCountsSQL = "" + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + - "WHERE localpart = $1 AND room_id = $2 AND NOT read" + "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read" const cleanNotificationsSQL = "" + "DELETE FROM userapi_notifications WHERE" + @@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local if err != nil { return err } - _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs)) return err } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { return nil, 0, err @@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { - err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { - err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 06ae30d08..68d87f007 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When the token expires, as a unix timestamp (ms resolution). token_expires_at_ms BIGINT NOT NULL ); ` const insertOpenIDTokenSQL = "" + - "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" + "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { insertTokenStmt *sql.Stmt @@ -54,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, + token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS) return } @@ -69,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( token string, ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes + var localpart string + var serverName gomatrixserverlib.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( - &openIDTokenAttrs.UserID, + &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, ) + openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve token from the db") diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 2753b23d9..df4e0db63 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -23,42 +23,46 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account avatar_url TEXT ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name); ` const insertProfileSQL = "" + - "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2" const setAvatarURLSQL = "" + "UPDATE userapi_profiles AS new" + " SET avatar_url = $1" + " FROM userapi_profiles AS old" + - " WHERE new.localpart = $2" + + " WHERE new.localpart = $2 AND new.server_name = $3" + " RETURNING new.display_name, old.avatar_url <> new.avatar_url" const setDisplayNameSQL = "" + "UPDATE userapi_profiles AS new" + " SET display_name = $1" + " FROM userapi_profiles AS old" + - " WHERE new.localpart = $2" + + " WHERE new.localpart = $2 AND new.server_name = $3" + " RETURNING new.avatar_url, old.display_name <> new.display_name" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { serverNoticesLocalpart string @@ -87,18 +91,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( - &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( + &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { return nil, err @@ -107,28 +113,34 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ - Localpart: localpart, - AvatarURL: avatarURL, + Localpart: localpart, + ServerName: string(serverName), + AvatarURL: avatarURL, } var changed bool stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed) + err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed) return profile, changed, err } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, + ServerName: string(serverName), DisplayName: displayName, } var changed bool stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed) + err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed) return profile, changed, err } @@ -146,7 +158,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index 6fb714fba..707b3bd2b 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( id BIGSERIAL PRIMARY KEY, -- The Matrix user ID localpart for this pusher localpart TEXT NOT NULL, + server_name TEXT NOT NULL, session_id BIGINT DEFAULT NULL, profile_tag TEXT, kind TEXT NOT NULL, @@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); -- For faster retrieving by localpart. -CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name); -- Pushkey must be unique for a given user and app. -CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name); ` const insertPusherSQL = "" + - "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + - "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + - "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + + "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12" const selectPushersSQL = "" + - "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" const deletePusherSQL = "" + - "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4" const deletePushersByAppIdAndPushKeySQL = "" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" @@ -95,18 +97,19 @@ type pushersStatements struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) - logrus.Debugf("Created pusher %d", session_id) + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err } func (s *pushersStatements) SelectPushers( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} - rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName) if err != nil { return pushers, err @@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( - ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, + ctx context.Context, txn *sql.Tx, appid, pushkey, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err } diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index c059e3e60..92dc48081 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -15,6 +15,8 @@ package postgres import ( + "context" + "database/sql" "fmt" "time" @@ -43,18 +45,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpRenameTables, Down: deltas.DownRenameTables, }) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + }) if err = m.Up(base.Context()); err != nil { return nil, err } - accountDataTable, err := NewPostgresAccountDataTable(db) - if err != nil { - return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) - } accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) } + accountDataTable, err := NewPostgresAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) + } devicesTable, err := NewPostgresDevicesTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err) @@ -95,6 +103,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, fmt.Errorf("NewPostgresStatsTable: %w", err) } + + m = sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names populate", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNamesPopulate(ctx, txn, serverName) + }, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 11af76161..f41c43122 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -33,21 +34,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids ( medium TEXT NOT NULL DEFAULT 'email', -- The localpart of the Matrix user ID associated to this 3PID localpart TEXT NOT NULL, + server_name TEXT NOT NULL, PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart); +CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart, server_name); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" + "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2" const insertThreePIDSQL = "" + - "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)" const deleteThreePIDSQL = "" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" @@ -75,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) - err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { - return "", nil + return "", "", nil } return } func (s *threepidStatements) SelectThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { return } @@ -109,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( } func (s *threepidStatements) InsertThreePID( - ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, + ctx context.Context, txn *sql.Tx, threepid, medium, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) return } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index f8b8d02c9..f549dcef9 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -68,9 +68,10 @@ const ( // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword string, ) (*api.Account, error) { - hash, err := d.Accounts.SelectPasswordHash(ctx, localpart) + hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName) if err != nil { return nil, err } @@ -80,24 +81,27 @@ func (d *Database) GetAccountByPassword( if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { return nil, err } - return d.Accounts.SelectAccountByLocalpart(ctx, localpart) + return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) } // GetProfileByLocalpart returns the profile associated with the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart. func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { - return d.Profiles.SelectProfileByLocalpart(ctx, localpart) + return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName) } // SetAvatarURL updates the avatar URL of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL) return err }) return @@ -106,10 +110,12 @@ func (d *Database) SetAvatarURL( // SetDisplayName updates the display name of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName) return err }) return @@ -117,14 +123,15 @@ func (d *Database) SetDisplayName( // SetPassword sets the account password to the given hash. func (d *Database) SetPassword( - ctx context.Context, localpart, plaintextPassword string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword string, ) error { hash, err := d.hashPassword(plaintextPassword) if err != nil { return err } return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.Accounts.UpdatePassword(ctx, localpart, hash) + return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash) }) } @@ -132,21 +139,22 @@ func (d *Database) SetPassword( // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword, appserviceID string, accountType api.AccountType, ) (acc *api.Account, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { // For guest accounts, we create a new numeric local part if accountType == api.AccountTypeGuest { var numLocalpart int64 - numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn) + numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName) if err != nil { - return err + return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err) } localpart = strconv.FormatInt(numLocalpart, 10) plaintextPassword = "" appserviceID = "" } - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) + acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType) return err }) return @@ -155,7 +163,9 @@ func (d *Database) CreateAccount( // WARNING! This function assumes that the relevant mutexes have already // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { var err error var account *api.Account @@ -167,28 +177,28 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { + if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil { return nil, sqlutil.ErrUserExists } - if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { - return nil, err + if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil { + return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err) } - pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) prbs, err := json.Marshal(pushRuleSets) if err != nil { - return nil, err + return nil, fmt.Errorf("json.Marshal: %w", err) } - if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { - return nil, err + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil { + return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err) } return account, nil } func (d *Database) QueryPushRules( ctx context.Context, - localpart string, + localpart string, serverName gomatrixserverlib.ServerName, ) (*pushrules.AccountRuleSets, error) { - data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules") + data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules") if err != nil { return nil, err } @@ -196,13 +206,13 @@ func (d *Database) QueryPushRules( // If we didn't find any default push rules then we should just generate some // fresh ones. if len(data) == 0 { - pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) prbs, err := json.Marshal(pushRuleSets) if err != nil { return nil, fmt.Errorf("failed to marshal default push rules: %w", err) } err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil { + if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil { return fmt.Errorf("failed to save default push rules: %w", dbErr) } return nil @@ -225,22 +235,23 @@ func (d *Database) QueryPushRules( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, content json.RawMessage, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content) + return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content) }) } // GetAccountData returns account data related to a given localpart // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string) ( +func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ( global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error, ) { - return d.AccountDatas.SelectAccountData(ctx, localpart) + return d.AccountDatas.SelectAccountData(ctx, localpart, serverName) } // GetAccountDataByType returns account data matching a given @@ -248,18 +259,19 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, ) (data json.RawMessage, err error) { return d.AccountDatas.SelectAccountDataByType( - ctx, localpart, roomID, dataType, + ctx, localpart, serverName, roomID, dataType, ) } // GetNewNumericLocalpart generates and returns a new unused numeric localpart func (d *Database) GetNewNumericLocalpart( - ctx context.Context, + ctx context.Context, serverName gomatrixserverlib.ServerName, ) (int64, error) { - return d.Accounts.SelectNewNumericLocalpart(ctx, nil) + return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName) } func (d *Database) hashPassword(plaintext string) (hash string, err error) { @@ -276,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use") // If the third-party identifier is already part of an association, returns Err3PIDInUse. // Returns an error if there was a problem talking to the database. func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, + ctx context.Context, threepid string, + localpart string, serverName gomatrixserverlib.ServerName, + medium string, ) (err error) { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - user, err := d.ThreePIDs.SelectLocalpartForThreePID( + user, _, err := d.ThreePIDs.SelectLocalpartForThreePID( ctx, txn, threepid, medium, ) if err != nil { @@ -290,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation( return Err3PIDInUse } - return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart) + return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName) }) } @@ -313,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation( // Returns an error if there was a problem talking to the database. func (d *Database) GetLocalpartForThreePID( ctx context.Context, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) } @@ -322,16 +336,17 @@ func (d *Database) GetLocalpartForThreePID( // If no association is known for this user, returns an empty slice. // Returns an error if there was an issue talking to the database. func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart) + return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName) } // CheckAccountAvailability checks if the username/localpart is already present // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart) +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) { + _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) if err == sql.ErrNoRows { return true, nil } @@ -341,12 +356,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // GetAccountByLocalpart returns the account associated with the given localpart. // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { // try to get the account with lowercase localpart (majority) - acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart)) + acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName) if err == sql.ErrNoRows { - acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request + acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request } return acc, err } @@ -359,20 +374,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi } // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { +func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.Accounts.DeactivateAccount(ctx, localpart) + return d.Accounts.DeactivateAccount(ctx, localpart, serverName) }) } // CreateOpenIDToken persists a new token that was issued for OpenID Connect func (d *Database) CreateOpenIDToken( ctx context.Context, - token, localpart string, + token, userID string, ) (int64, error) { + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return 0, nil + } expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS) }) return expiresAtMS, err } @@ -539,16 +558,19 @@ func (d *Database) GetDeviceByAccessToken( // GetDeviceByID returns the device matching the given ID. // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { - return d.Devices.SelectDeviceByID(ctx, localpart, deviceID) + return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID) } // GetDevicesByLocalpart returns the devices matching the given localpart. func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Device, error) { - return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "") + return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -562,18 +584,18 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap // If no device ID is given one is generated. // Returns the device on success. func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string, ) (dev *api.Device, returnErr error) { if deviceID != nil { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device - if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil { + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { return err } - dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) return err }) } else { @@ -588,7 +610,7 @@ func (d *Database) CreateDevice( returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error - dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) return err }) if returnErr == nil { @@ -614,10 +636,12 @@ func generateDeviceID() (string, error) { // UpdateDevice updates the given device with the display name. // Returns SQL error if there are problems and nil on success. func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName) + return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName) }) } @@ -626,10 +650,12 @@ func (d *Database) UpdateDevice( // If the devices don't exist, it will not return an error // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows { return err } return nil @@ -640,14 +666,16 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) (devices []api.Device, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID) if err != nil { return err } - if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows { return err } return nil @@ -656,9 +684,9 @@ func (d *Database) RemoveAllDevices( } // UpdateDeviceLastSeen updates a last seen timestamp and the ip address. -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error { +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent) + return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent) }) } @@ -706,38 +734,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( return d.LoginTokens.SelectLoginToken(ctx, token) } -func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { +func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) + return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) }) } -func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) { +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) + affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos) return err }) return } -func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) { +func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) + affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b) return err }) return } -func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { - return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter) +func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter) } -func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { - return d.Notifications.SelectCount(ctx, nil, localpart, filter) +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) { + return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter) } -func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { - return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID) +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) { + return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID) } func (d *Database) DeleteOldNotifications(ctx context.Context) error { @@ -747,7 +775,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error { } func (d *Database) UpsertPusher( - ctx context.Context, p api.Pusher, localpart string, + ctx context.Context, p api.Pusher, + localpart string, serverName gomatrixserverlib.ServerName, ) error { data, err := json.Marshal(p.Data) if err != nil { @@ -766,25 +795,26 @@ func (d *Database) UpsertPusher( p.ProfileTag, p.Language, string(data), - localpart) + localpart, + serverName) }) } // GetPushers returns the pushers matching the given localpart. func (d *Database) GetPushers( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { - return d.Pushers.SelectPushers(ctx, nil, localpart) + return d.Pushers.SelectPushers(ctx, nil, localpart, serverName) } // RemovePusher deletes one pusher // Invoked when `append` is true and `kind` is null in // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set func (d *Database) RemovePusher( - ctx context.Context, appid, pushkey, localpart string, + ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) + err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName) if err == sql.ErrNoRows { return nil } diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index af12decb3..2fbdc5732 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -28,27 +29,28 @@ const accountDataSchema = ` CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) room_id TEXT, -- The account data type type TEXT NOT NULL, -- The account data content - content TEXT NOT NULL, - - PRIMARY KEY(localpart, room_id, type) + content TEXT NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type); ` const insertAccountDataSQL = ` - INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) - ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 + INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5) + ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5 ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4" type accountDataStatements struct { db *sql.DB @@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { } func (s *accountDataStatements) InsertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, content json.RawMessage, ) error { - _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content) return err } func (s *accountDataStatements) SelectAccountData( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, error, ) { - rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName) if err != nil { return nil, nil, err } @@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData( } func (s *accountDataStatements) SelectAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte stmt := s.selectAccountDataByTypeStmt - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 671c1aa04..f4ebe2158 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -34,7 +34,8 @@ const accountsSchema = ` -- Stores data about accounts. CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When this account was first created, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The password hash for this account. Can be NULL if this is a passwordless account. @@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( -- TODO: -- upgraded_ts, devices, any email reset stuff? ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name); ` const insertAccountSQL = "" + - "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + - "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3" const deactivateAccountSQL = "" + - "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" + "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0" + "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1" type accountsStatements struct { db *sql.DB @@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, + hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt var err error if accountType != api.AccountTypeAppService { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType) } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount( return &api.Account{ Localpart: localpart, - UserID: userutil.MakeUserID(localpart, s.serverName), - ServerName: s.serverName, + UserID: userutil.MakeUserID(localpart, serverName), + ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, }, nil } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart, passwordHash string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + passwordHash string, ) (err error) { - _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) return } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart( acc.AppServiceID = appserviceIDPtr.String } - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName - + acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName) return &acc, nil } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } - err = stmt.QueryRowContext(ctx).Scan(&id) + err = stmt.QueryRowContext(ctx, serverName).Scan(&id) if err == sql.ErrNoRows { return 1, nil } diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go index 9158cb365..2de85005f 100644 --- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error { ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, + server_name TEXT NOT NULL, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT, diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index a9224db6b..636ce4efc 100644 --- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { session_id INTEGER, device_id TEXT , localpart TEXT , + server_name TEXT NOT NULL, created_ts BIGINT, display_name TEXT, last_seen_ts BIGINT, diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go index 230bc1433..471e496cd 100644 --- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, + server_name TEXT NOT NULL, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT, diff --git a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go new file mode 100644 index 000000000..c11ea6844 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go @@ -0,0 +1,108 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +var serverNamesTables = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_devices", + "userapi_notifications", + "userapi_openid_tokens", + "userapi_profiles", + "userapi_pushers", + "userapi_threepids", +} + +// These tables have a PRIMARY KEY constraint which we need to drop so +// that we can recreate a new unique index that contains the server name. +var serverNamesDropPK = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_profiles", +} + +// These indices are out of date so let's drop them. They will get recreated +// automatically. +var serverNamesDropIndex = []string{ + "userapi_pusher_localpart_idx", + "userapi_pusher_app_id_pushkey_localpart_idx", +} + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;", + pq.QuoteIdentifier(table), + ) + var c int + if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 { + continue + } + q = fmt.Sprintf( + "SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'", + pq.QuoteIdentifier(table), + ) + if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 { + logrus.Infof("Table %s already has column, skipping", table) + continue + } + if c == 0 { + q = fmt.Sprintf( + "ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';", + pq.QuoteIdentifier(table), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("add server name to %q error: %w", table, err) + } + } + } + for _, table := range serverNamesDropPK { + q := fmt.Sprintf( + "SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;", + pq.QuoteIdentifier(table), + ) + var c int + var sql string + if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 { + continue + } + q = fmt.Sprintf(` + %s; -- create temporary table + INSERT INTO %s SELECT * FROM %s; -- copy data + DROP TABLE %s; -- drop original table + ALTER TABLE %s RENAME TO %s; -- rename new table + `, + strings.Replace(sql, table, table+"_tmp", 1), // create temporary table + table+"_tmp", table, // copy data + table, // drop original table + table+"_tmp", table, // rename new table + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop PK from %q error: %w", table, err) + } + } + for _, index := range serverNamesDropIndex { + q := fmt.Sprintf( + "DROP INDEX IF EXISTS %s;", + pq.QuoteIdentifier(index), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop index %q error: %w", index, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go new file mode 100644 index 000000000..04a47fa7b --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go @@ -0,0 +1,28 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "UPDATE %s SET server_name = %s WHERE server_name = '';", + pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("write server names to %q error: %w", table, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index e53a08062..c5db34bd7 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( session_id INTEGER, device_id TEXT , localpart TEXT , + server_name TEXT NOT NULL, created_ts BIGINT, display_name TEXT, last_seen_ts BIGINT, ip TEXT, user_agent TEXT, - UNIQUE (localpart, device_id) + UNIQUE (localpart, server_name, device_id) ); ` const insertDeviceSQL = "" + - "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + "INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" const selectDevicesCountSQL = "" + "SELECT COUNT(access_token) FROM userapi_devices" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" const deleteDeviceSQL = "" + - "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3" const deleteDevicesSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" type devicesStatements struct { db *sql.DB @@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) InsertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 @@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice( return nil, err } sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { + if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } return &api.Device{ ID: id, - UserID: userutil.MakeUserID(localpart, s.serverName), + UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, SessionID: sessionID, LastSeenTS: createdTimeMS, @@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice( } func (s *devicesStatements) DeleteDevice( - ctx context.Context, txn *sql.Tx, id, localpart string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) + _, err := stmt.ExecContext(ctx, id, localpart, serverName) return err } func (s *devicesStatements) DeleteDevices( - ctx context.Context, txn *sql.Tx, localpart string, devices []string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) + orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1) prep, err := s.db.Prepare(orig) if err != nil { return err } stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) + params := make([]interface{}, len(devices)+2) params[0] = localpart + params[1] = serverName for i, v := range devices { - params[i+1] = v + params[i+2] = v } _, err = stmt.ExecContext(ctx, params...) return err } func (s *devicesStatements) DeleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID) return err } func (s *devicesStatements) UpdateDeviceName( - ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID) return err } @@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) dev.AccessToken = accessToken } return &dev, err @@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { var dev api.Device var displayName, ip sql.NullString stmt := s.selectDeviceByIDStmt var lastseenTS sql.NullInt64 - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) + err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) if displayName.Valid { dev.DisplayName = displayName.String } @@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID( } func (s *devicesStatements) SelectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID) if err != nil { return devices, err @@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( dev.UserAgent = useragent.String } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } @@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName var displayName sql.NullString var lastseents sql.NullInt64 for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { @@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s if lastseents.Valid { dev.LastSeenTS = lastseents.Int64 } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) return err } diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index a35ec7be5..ef39d027c 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -43,6 +43,7 @@ const notificationSchema = ` CREATE TABLE IF NOT EXISTS userapi_notifications ( id INTEGER PRIMARY KEY AUTOINCREMENT, localpart TEXT NOT NULL, + server_name TEXT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, stream_pos BIGINT NOT NULL, @@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications ( read BOOLEAN NOT NULL DEFAULT FALSE ); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id); ` const insertNotificationSQL = "" + - "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const deleteNotificationsUpToSQL = "" + - "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4" const updateNotificationReadSQL = "" + - "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1" const selectNotificationSQL = "" + - "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + - "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + - ") AND NOT read ORDER BY localpart, id LIMIT $4" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" + + "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $5" const selectNotificationCountSQL = "" + - "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + - "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + ") AND NOT read" const selectRoomNotificationCountsSQL = "" + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + - "WHERE localpart = $1 AND room_id = $2 AND NOT read" + "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read" const cleanNotificationsSQL = "" + "DELETE FROM userapi_notifications WHERE" + @@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local if err != nil { return err } - _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs)) return err } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { return nil, 0, err @@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { - err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { - err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index 875f1a9a5..f06429741 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -3,6 +3,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When the token expires, as a unix timestamp (ms resolution). token_expires_at_ms BIGINT NOT NULL ); ` const insertOpenIDTokenSQL = "" + - "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" + "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { db *sql.DB @@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) ( func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, + token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS) return } @@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( token string, ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes + var localpart string + var serverName gomatrixserverlib.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( - &openIDTokenAttrs.UserID, + &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, ) + openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve token from the db") diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index b6130a1e3..867026d7a 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -23,36 +23,40 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account avatar_url TEXT ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name); ` const insertProfileSQL = "" + - "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2" const setAvatarURLSQL = "" + - "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" + " RETURNING display_name" const setDisplayNameSQL = "" + - "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" + " RETURNING avatar_url" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { db *sql.DB @@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return err } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( - &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( + &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { return nil, err @@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ - Localpart: localpart, - AvatarURL: avatarURL, + Localpart: localpart, + ServerName: string(serverName), + AvatarURL: avatarURL, } - old, err := s.SelectProfileByLocalpart(ctx, localpart) + old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) if err != nil { return old, false, err } @@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL( return old, false, nil } stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) + err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName) return profile, true, err } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, + ServerName: string(serverName), DisplayName: displayName, } - old, err := s.SelectProfileByLocalpart(ctx, localpart) + old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) if err != nil { return old, false, err } @@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName( return old, false, nil } stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) + err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL) return profile, true, err } @@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index 4de0a9f06..c9d451dc5 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( id INTEGER PRIMARY KEY AUTOINCREMENT, -- The Matrix user ID localpart for this pusher localpart TEXT NOT NULL, + server_name TEXT NOT NULL, session_id BIGINT DEFAULT NULL, profile_tag TEXT, kind TEXT NOT NULL, @@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); -- For faster retrieving by localpart. -CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name); -- Pushkey must be unique for a given user and app. -CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name); ` const insertPusherSQL = "" + - "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + - "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + - "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + + "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12" const selectPushersSQL = "" + - "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" const deletePusherSQL = "" + - "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4" const deletePushersByAppIdAndPushKeySQL = "" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" @@ -95,18 +97,19 @@ type pushersStatements struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) - logrus.Debugf("Created pusher %d", session_id) + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err } func (s *pushersStatements) SelectPushers( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} - rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) + rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName) if err != nil { return pushers, err @@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( - ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, + ctx context.Context, txn *sql.Tx, appid, pushkey, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err } diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index dd33dc0cf..85a1f7063 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -15,6 +15,8 @@ package sqlite3 import ( + "context" + "database/sql" "fmt" "time" @@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpRenameTables, Down: deltas.DownRenameTables, }) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + }) if err = m.Up(base.Context()); err != nil { return nil, err } - accountDataTable, err := NewSQLiteAccountDataTable(db) - if err != nil { - return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) - } accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) } + accountDataTable, err := NewSQLiteAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) + } devicesTable, err := NewSQLiteDevicesTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) @@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) } + + m = sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names populate", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNamesPopulate(ctx, txn, serverName) + }, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index 73af139db..2db7d5887 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids ( medium TEXT NOT NULL DEFAULT 'email', -- The localpart of the Matrix user ID associated to this 3PID localpart TEXT NOT NULL, + server_name TEXT NOT NULL, PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart); +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" + "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2" const insertThreePIDSQL = "" + - "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)" const deleteThreePIDSQL = "" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" @@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) - err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { - return "", nil + return "", "", nil } return } func (s *threepidStatements) SelectThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { return } @@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( } func (s *threepidStatements) InsertThreePID( - ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, + ctx context.Context, txn *sql.Tx, threepid, medium, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) return err } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 354f085fc..23aafff03 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() alice := test.NewUser(t) - localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) events := room.Events() contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID())) - err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom) + err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom) assert.NoError(t, err, "unable to save account data") contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID)) - err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal) + err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal) assert.NoError(t, err, "unable to save account data") - accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read") + accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read") assert.NoError(t, err, "unable to get account data by type") assert.Equal(t, contentRoom, accountData) - globalData, roomData, err := db.GetAccountData(ctx, localpart) + globalData, roomData, err := db.GetAccountData(ctx, localpart, domain) assert.NoError(t, err) assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"]) assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"]) @@ -81,78 +81,78 @@ func Test_Accounts(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) - accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) + accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") // verify the newly create account is the same as returned by CreateAccount var accGet *api.Account - accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing") + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing") assert.NoError(t, err, "failed to get account by password") assert.Equal(t, accAlice, accGet) - accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart) + accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to get account by localpart") assert.Equal(t, accAlice, accGet) // check account availability - available, err := db.CheckAccountAvailability(ctx, aliceLocalpart) + available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, false, available) - available, err = db.CheckAccountAvailability(ctx, "unusedname") + available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, true, available) // get guest account numeric aliceLocalpart - first, err := db.GetNewNumericLocalpart(ctx) + first, err := db.GetNewNumericLocalpart(ctx, aliceDomain) assert.NoError(t, err, "failed to get new numeric localpart") // Create a new account to verify the numeric localpart is updated - _, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest) + _, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest) assert.NoError(t, err, "failed to create account") - second, err := db.GetNewNumericLocalpart(ctx) + second, err := db.GetNewNumericLocalpart(ctx, aliceDomain) assert.NoError(t, err) assert.Greater(t, second, first) // update password for alice - err = db.SetPassword(ctx, aliceLocalpart, "newPassword") + err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.NoError(t, err, "failed to update password") - accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.NoError(t, err, "failed to get account by new password") assert.Equal(t, accAlice, accGet) // deactivate account - err = db.DeactivateAccount(ctx, aliceLocalpart) + err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to deactivate account") // This should fail now, as the account is deactivated - _, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + _, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.Error(t, err, "expected an error, got none") - _, err = db.GetAccountByLocalpart(ctx, "unusename") + _, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain) assert.Error(t, err, "expected an error for non existent localpart") // create an empty localpart; this should never happen, but is required to test getting a numeric localpart // if there's already a user without a localpart in the database - _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser) + _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser) assert.NoError(t, err) // test getting a numeric localpart, with an existing user without a localpart - _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) + _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest) assert.NoError(t, err) // Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type - _, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser) + _, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser) assert.NoError(t, err) // Now try to create a new guest user - _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) + _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest) assert.NoError(t, err) }) } func Test_Devices(t *testing.T) { alice := test.NewUser(t) - localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) deviceID := util.RandomString(8) accessToken := util.RandomString(16) @@ -161,10 +161,10 @@ func Test_Devices(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() - deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "") + deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") - gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID) + gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields @@ -174,14 +174,14 @@ func Test_Devices(t *testing.T) { // create a device without existing device ID accessToken = util.RandomString(16) - deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "") + deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") - gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID) + gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields // Get devices - devices, err := db.GetDevicesByLocalpart(ctx, localpart) + devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get devices by localpart") assert.Equal(t, 2, len(devices)) deviceIDs := make([]string, 0, len(devices)) @@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) { // Update device newName := "new display name" - err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName) + err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName) assert.NoError(t, err, "unable to update device displayname") updatedAfterTimestamp := time.Now().Unix() - err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web") + err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web") assert.NoError(t, err, "unable to update device last seen") deviceWithID.DisplayName = newName deviceWithID.LastSeenIP = "127.0.0.1" - gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID) + gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 2, len(devices)) assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName) @@ -213,20 +213,20 @@ func Test_Devices(t *testing.T) { // create one more device and remove the devices step by step newDeviceID := util.RandomString(16) accessToken = util.RandomString(16) - _, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "") + _, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create new device") - devices, err = db.GetDevicesByLocalpart(ctx, localpart) + devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 3, len(devices)) - err = db.RemoveDevices(ctx, localpart, deviceIDs) + err = db.RemoveDevices(ctx, localpart, domain, deviceIDs) assert.NoError(t, err, "unable to remove devices") - devices, err = db.GetDevicesByLocalpart(ctx, localpart) + devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 1, len(devices)) - deleted, err := db.RemoveAllDevices(ctx, localpart, "") + deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "") assert.NoError(t, err, "unable to remove all devices") assert.Equal(t, 1, len(deleted)) assert.Equal(t, newDeviceID, deleted[0].ID) @@ -364,7 +364,7 @@ func Test_OpenID(t *testing.T) { func Test_Profile(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -372,30 +372,33 @@ func Test_Profile(t *testing.T) { defer close() // create account, which also creates a profile - _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) + _, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") - gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) + gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get profile by localpart") - wantProfile := &authtypes.Profile{Localpart: aliceLocalpart} + wantProfile := &authtypes.Profile{ + Localpart: aliceLocalpart, + ServerName: string(aliceDomain), + } assert.Equal(t, wantProfile, gotProfile) // set avatar & displayname wantProfile.DisplayName = "Alice" - gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice") + gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice") assert.Equal(t, wantProfile, gotProfile) assert.NoError(t, err, "unable to set displayname") assert.True(t, changed) wantProfile.AvatarURL = "mxc://aliceAvatar" - gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) assert.True(t, changed) // Setting the same avatar again doesn't change anything wantProfile.AvatarURL = "mxc://aliceAvatar" - gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) assert.False(t, changed) @@ -410,7 +413,7 @@ func Test_Profile(t *testing.T) { func Test_Pusher(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -432,11 +435,11 @@ func Test_Pusher(t *testing.T) { ProfileTag: util.RandomString(8), Language: util.RandomString(2), } - err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart) + err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to upsert pusher") // check it was actually persisted - gotPushers, err = db.GetPushers(ctx, aliceLocalpart) + gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, i+1, len(gotPushers)) assert.Equal(t, wantPusher, gotPushers[i]) @@ -444,16 +447,16 @@ func Test_Pusher(t *testing.T) { } // remove single pusher - err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart) + err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to remove pusher") - gotPushers, err := db.GetPushers(ctx, aliceLocalpart) + gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, 1, len(gotPushers)) // remove last pusher err = db.RemovePushers(ctx, appID, pushKeys[1]) assert.NoError(t, err, "unable to remove pusher") - gotPushers, err = db.GetPushers(ctx, aliceLocalpart) + gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, 0, len(gotPushers)) }) @@ -461,7 +464,7 @@ func Test_Pusher(t *testing.T) { func Test_ThreePID(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -469,15 +472,16 @@ func Test_ThreePID(t *testing.T) { defer close() threePID := util.RandomString(8) medium := util.RandomString(8) - err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium) + err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium) assert.NoError(t, err, "unable to save threepid association") // get the stored threepid - gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium) + gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium) assert.NoError(t, err, "unable to get localpart for threepid") assert.Equal(t, aliceLocalpart, gotLocalpart) + assert.Equal(t, aliceDomain, gotDomain) - threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 1, len(threepids)) assert.Equal(t, authtypes.ThreePID{ @@ -490,7 +494,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err, "unexpected error") // verify it was deleted - threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 0, len(threepids)) }) @@ -498,7 +502,7 @@ func Test_ThreePID(t *testing.T) { func Test_Notification(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) @@ -526,34 +530,34 @@ func Test_Notification(t *testing.T) { RoomID: roomID, TS: gomatrixserverlib.AsTimestamp(ts), } - err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification) + err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") } // get notifications - count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications) + count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications) assert.NoError(t, err, "unable to get notification count") assert.Equal(t, int64(10), count) - notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications) + notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications) assert.NoError(t, err, "unable to get notifications") assert.Equal(t, int64(10), count) assert.Equal(t, 10, len(notifs)) // ... for a specific room - total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(4), total) // mark notification as read - affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true) + affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) // this should delete 2 notifications - affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8) + affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) - total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(2), total) @@ -562,7 +566,7 @@ func Test_Notification(t *testing.T) { assert.NoError(t, err) // this should now return 0 notifications - total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(0), total) }) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 5e1dd0971..e14776cf3 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -28,31 +28,31 @@ import ( ) type AccountDataTable interface { - InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error - SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) - SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) + InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error + SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) + SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error) } type AccountsTable interface { - InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) - UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) - DeactivateAccount(ctx context.Context, localpart string) (err error) - SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) - SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) - SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) + InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error) + DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) + SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error) + SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) + SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error) } type DevicesTable interface { - InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) - DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error - DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error - DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error - UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error + InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error + DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error + DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error + UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) - SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error) + SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) + SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error } type KeyBackupTable interface { @@ -79,40 +79,40 @@ type LoginTokenTable interface { } type OpenIDTable interface { - InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error) + InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error) SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) } type ProfileTable interface { - InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error - SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error) + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error + SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } type ThreePIDTable interface { - SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error) - SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) - InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) + SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) + SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) + InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) } type PusherTable interface { - InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error - SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) - DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error + SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error) + DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error } type NotificationTable interface { Clean(ctx context.Context, txn *sql.Tx) error - Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error - DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) - UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) - Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) - SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) - SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) + Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) + Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) } type StatsTable interface { diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index a547423bc..b088d15cd 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -79,6 +79,7 @@ func mustMakeAccountAndDevice( accDB tables.AccountsTable, devDB tables.DevicesTable, localpart string, + serverName gomatrixserverlib.ServerName, // nolint:unparam accType api.AccountType, userAgent string, ) { @@ -89,11 +90,11 @@ func mustMakeAccountAndDevice( appServiceID = util.RandomString(16) } - _, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType) + _, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType) if err != nil { t.Fatalf("unable to create account: %v", err) } - _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent) + _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent) if err != nil { t.Fatalf("unable to create device: %v", err) } @@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) { }) t.Run("Want Users", func(t *testing.T) { - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko") gotStats, _, err := statsDB.UserStatistics(ctx, nil) if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 2a43c0bd4..23acfa6f3 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -80,14 +80,14 @@ func TestQueryProfile(t *testing.T) { // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) defer close() - _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser) + _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } - if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { + if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { t.Fatalf("failed to set avatar url: %s", err) } - if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { + if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { t.Fatalf("failed to set display name: %s", err) } @@ -164,7 +164,7 @@ func TestPasswordlessLoginFails(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) defer close() - _, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService) + _, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService) if err != nil { t.Fatalf("failed to make account: %s", err) } @@ -190,7 +190,7 @@ func TestLoginToken(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) defer close() - _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) + _, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } diff --git a/userapi/util/devices.go b/userapi/util/devices.go index cbf3bd28f..c55fc7999 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -2,10 +2,12 @@ package util import ( "context" + "fmt" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -17,10 +19,10 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { - pushers, err := db.GetPushers(ctx, localpart) +func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { + pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { - return nil, err + return nil, fmt.Errorf("db.GetPushers: %w", err) } devices := make([]*PusherDevice, 0, len(pushers)) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index ff206bd3c..fc0ab39bf 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -8,6 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -16,8 +17,8 @@ import ( // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error { - pusherDevices, err := GetPushDevices(ctx, localpart, nil, db) +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error { + pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err } @@ -26,7 +27,7 @@ func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, loc return nil } - userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications) + userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, serverName, tables.AllNotifications) if err != nil { return err } From 1e79b0557e177986262f365f6601fcdf56480a07 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 14 Nov 2022 13:06:27 +0100 Subject: [PATCH 015/124] Use a writer to assign state key NIDs (#2877) --- roomserver/storage/shared/storage.go | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 08e912a00..9bf438a23 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -103,6 +103,7 @@ func (d *Database) eventStateKeyNIDs( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) + eventStateKeys = util.UniqueStrings(eventStateKeys) nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) if err != nil { return nil, err @@ -112,15 +113,23 @@ func (d *Database) eventStateKeyNIDs( } // We received some nids, but are still missing some, work out which and create them if len(eventStateKeys) > len(result) { - for _, eventStateKey := range eventStateKeys { - if _, ok := result[eventStateKey]; ok { - continue + var nid types.EventStateKeyNID + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + for _, eventStateKey := range eventStateKeys { + if _, ok := result[eventStateKey]; ok { + continue + } + + nid, err = d.assignStateKeyNID(ctx, txn, eventStateKey) + if err != nil { + return err + } + result[eventStateKey] = nid } - nid, err := d.assignStateKeyNID(ctx, txn, eventStateKey) - if err != nil { - return result, err - } - result[eventStateKey] = nid + return nil + }) + if err != nil { + return nil, err } } return result, nil From 858a4af2244986356576b1cf97572275a8bc001f Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 14 Nov 2022 13:06:41 +0100 Subject: [PATCH 016/124] Try to optimize CI (#2867) Try to optimize CI by using caches --- .github/workflows/dendrite.yml | 80 ++++++++------------- build/scripts/Complement.Dockerfile | 7 +- build/scripts/ComplementPostgres.Dockerfile | 7 +- 3 files changed, 38 insertions(+), 56 deletions(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index f96cbadd4..fa4282384 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -26,22 +26,14 @@ jobs: uses: actions/setup-go@v3 with: go-version: 1.18 - - - uses: actions/cache@v2 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-wasm-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go-wasm + cache: true - name: Install Node uses: actions/setup-node@v2 with: node-version: 14 - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: ~/.npm key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} @@ -109,19 +101,12 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} + cache: true - name: Set up gotestfmt uses: gotesttools/gotestfmt-action@v2 with: # Optional: pass GITHUB_TOKEN to avoid rate limiting. token: ${{ secrets.GITHUB_TOKEN }} - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-test- - run: go test -json -v ./... 2>&1 | gotestfmt env: POSTGRES_HOST: localhost @@ -146,17 +131,17 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - - name: Install dependencies x86 - if: ${{ matrix.goarch == '386' }} - run: sudo apt update && sudo apt-get install -y gcc-multilib - uses: actions/cache@v3 with: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}- + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}- + - name: Install dependencies x86 + if: ${{ matrix.goarch == '386' }} + run: sudo apt update && sudo apt-get install -y gcc-multilib - env: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} @@ -180,16 +165,16 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - - name: Install dependencies - run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc - uses: actions/cache@v3 with: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }} + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}- + - name: Install dependencies + run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc - env: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} @@ -221,14 +206,7 @@ jobs: uses: actions/setup-go@v3 with: go-version: "1.18" - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go-upgrade + cache: true - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests - name: Test upgrade (PostgreSQL) @@ -248,14 +226,7 @@ jobs: uses: actions/setup-go@v3 with: go-version: "1.18" - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go-upgrade + cache: true - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests - name: Test upgrade (PostgreSQL) @@ -295,6 +266,8 @@ jobs: image: matrixdotorg/sytest-dendrite:latest volumes: - ${{ github.workspace }}:/src + - /root/.cache/go-build:/github/home/.cache/go-build + - /root/.cache/go-mod:/gopath/pkg/mod env: POSTGRES: ${{ matrix.postgres && 1}} API: ${{ matrix.api && 1 }} @@ -302,6 +275,14 @@ jobs: CGO_ENABLED: ${{ matrix.cgo && 1 }} steps: - uses: actions/checkout@v3 + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + /gopath/pkg/mod + key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go-sytest- - name: Run Sytest run: /bootstrap.sh dendrite working-directory: /src @@ -336,12 +317,14 @@ jobs: matrix: include: - label: SQLite native + cgo: 0 - label: SQLite Cgo cgo: 1 - label: SQLite native, full HTTP APIs api: full-http + cgo: 0 - label: SQLite Cgo, full HTTP APIs api: full-http @@ -349,10 +332,12 @@ jobs: - label: PostgreSQL postgres: Postgres + cgo: 0 - label: PostgreSQL, full HTTP APIs postgres: Postgres api: full-http + cgo: 0 steps: # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path @@ -360,14 +345,12 @@ jobs: run: | echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH echo "~/go/bin" >> $GITHUB_PATH - - name: "Install Complement Dependencies" # We don't need to install Go because it is included on the Ubuntu 20.04 image: # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64 run: | sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest - - name: Run actions/checkout@v3 for dendrite uses: actions/checkout@v3 with: @@ -393,12 +376,10 @@ jobs: if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then continue fi - (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break done - # Build initial Dendrite image - - run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . + - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . working-directory: dendrite env: DOCKER_BUILDKIT: 1 @@ -410,9 +391,8 @@ jobs: shell: bash name: Run Complement Tests env: - COMPLEMENT_BASE_IMAGE: complement-dendrite:latest + COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} API: ${{ matrix.api && 1 }} - CGO_ENABLED: ${{ matrix.cgo && 1 }} working-directory: complement integration-tests-done: diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 14b28498b..79422e645 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -10,12 +10,13 @@ RUN mkdir /dendrite # Utilise Docker caching when downloading dependencies, this stops us needlessly # downloading dependencies every time. +ARG CGO RUN --mount=target=. \ --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ - go build -o /dendrite ./cmd/generate-config && \ - go build -o /dendrite ./cmd/generate-keys && \ - go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index 785090b0b..3faf43cc7 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -28,12 +28,13 @@ RUN mkdir /dendrite # Utilise Docker caching when downloading dependencies, this stops us needlessly # downloading dependencies every time. +ARG CGO RUN --mount=target=. \ --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ - go build -o /dendrite ./cmd/generate-config && \ - go build -o /dendrite ./cmd/generate-keys && \ - go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem From 2a77a910eb16333e56b85a7fdb3d6254acc9b6d7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 14 Nov 2022 13:07:13 +0100 Subject: [PATCH 017/124] Handle remote room upgrades (#2866) Makes the following tests pass ``` /upgrade moves remote aliases to the new room Local and remote users' homeservers remove a room from their public directory on upgrade ``` --- roomserver/internal/input/input_events.go | 16 ++++++++++++ roomserver/storage/interface.go | 1 + roomserver/storage/shared/storage.go | 30 +++++++++++++++++++++++ sytest-whitelist | 4 ++- 4 files changed, 50 insertions(+), 1 deletion(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 60160e8e5..682aa2b12 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -23,6 +23,8 @@ import ( "fmt" "time" + "github.com/tidwall/gjson" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" @@ -409,6 +411,13 @@ func (r *Inputer) processRoomEvent( } } + // Handle remote room upgrades, e.g. remove published room + if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) { + if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil { + return fmt.Errorf("failed to handle remote room upgrade: %w", err) + } + } + // processing this event resulted in an event (which may not be the one we're processing) // being redacted. We are guaranteed to have both sides (the redaction/redacted event), // so notify downstream components to redact this event - they should have it if they've @@ -434,6 +443,13 @@ func (r *Inputer) processRoomEvent( return nil } +// handleRemoteRoomUpgrade updates published rooms and room aliases +func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixserverlib.Event) error { + oldRoomID := event.RoomID() + newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str + return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender()) +} + // processStateBefore works out what the state is before the event and // then checks the event auths against the state at the time. It also // tries to determine what the history visibility was of the event at diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index c39a8cbba..06db4b2d8 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -172,4 +172,5 @@ type Database interface { ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) + UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 9bf438a23..16898bcb1 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1408,6 +1408,36 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget }) } +func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error { + + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // un-publish old room + if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil { + return fmt.Errorf("failed to unpublish room: %w", err) + } + // publish new room + if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil { + return fmt.Errorf("failed to publish room: %w", err) + } + + // Migrate any existing room aliases + aliases, err := d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, txn, oldRoomID) + if err != nil { + return fmt.Errorf("failed to get room aliases: %w", err) + } + + for _, alias := range aliases { + if err = d.RoomAliasesTable.DeleteRoomAlias(ctx, txn, alias); err != nil { + return fmt.Errorf("failed to remove room alias: %w", err) + } + if err = d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, newRoomID, eventSender); err != nil { + return fmt.Errorf("failed to set room alias: %w", err) + } + } + return nil + }) +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package! diff --git a/sytest-whitelist b/sytest-whitelist index bb4f0a279..49ffb8fe8 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -761,4 +761,6 @@ AS can publish rooms in their own list AS and main public room lists are separate /upgrade preserves direct room state local user has tags copied to the new room -remote user has tags copied to the new room \ No newline at end of file +remote user has tags copied to the new room +/upgrade moves remote aliases to the new room +Local and remote users' homeservers remove a room from their public directory on upgrade \ No newline at end of file From f4ee3977340c84d321767d347795b1dcd05ac459 Mon Sep 17 00:00:00 2001 From: Omar Kotb Date: Mon, 14 Nov 2022 19:15:39 +0200 Subject: [PATCH 018/124] Fix Caddy config well-known delegation example (#2879) Signed-off-by: Omar Kotb Signed-off-by: Omar Kotb --- docs/installation/2_domainname.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installation/2_domainname.md b/docs/installation/2_domainname.md index e7b3495f7..545a2daf6 100644 --- a/docs/installation/2_domainname.md +++ b/docs/installation/2_domainname.md @@ -90,7 +90,7 @@ For example, this can be done with the following Caddy config: handle /.well-known/matrix/server { header Content-Type application/json header Access-Control-Allow-Origin * - respond `"m.server": "matrix.example.com:8448"` + respond `{"m.server": "matrix.example.com:8448"}` } handle /.well-known/matrix/client { From 6650712a1c0dec282b47b7ba14bc8c2e06a385d8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Nov 2022 15:05:23 +0000 Subject: [PATCH 019/124] Federation fixes for virtual hosting --- clientapi/routing/admin.go | 2 +- clientapi/routing/createroom.go | 2 +- clientapi/routing/directory.go | 2 +- clientapi/routing/directory_public.go | 2 +- clientapi/routing/membership.go | 8 +- clientapi/routing/profile.go | 14 +++- clientapi/routing/redaction.go | 9 +- clientapi/routing/register.go | 39 ++++++--- clientapi/routing/sendevent.go | 9 +- clientapi/routing/server_notices.go | 1 + clientapi/routing/userdirectory.go | 2 +- clientapi/threepid/invites.go | 8 +- cmd/dendrite-demo-pinecone/conn/client.go | 4 +- cmd/dendrite-demo-pinecone/rooms/rooms.go | 8 +- cmd/dendrite-demo-yggdrasil/yggconn/client.go | 3 +- .../yggrooms/yggrooms.go | 9 +- cmd/furl/main.go | 12 ++- federationapi/api/api.go | 54 ++++++------ federationapi/federationapi.go | 12 +-- federationapi/federationapi_keys_test.go | 2 +- federationapi/federationapi_test.go | 8 +- federationapi/internal/federationclient.go | 44 +++++----- federationapi/internal/perform.go | 38 +++++++-- federationapi/inthttp/client.go | 52 ++++++++---- federationapi/inthttp/server.go | 22 ++--- federationapi/queue/destinationqueue.go | 2 +- federationapi/queue/queue.go | 18 ++-- federationapi/queue/queue_test.go | 4 +- federationapi/routing/backfill.go | 5 +- federationapi/routing/invite.go | 17 +++- federationapi/routing/join.go | 12 ++- federationapi/routing/keys.go | 67 +++++++++------ federationapi/routing/leave.go | 13 ++- federationapi/routing/profile.go | 10 +-- federationapi/routing/query.go | 2 +- federationapi/routing/send.go | 9 +- federationapi/routing/send_test.go | 8 +- federationapi/routing/threepid.go | 27 +++++- go.mod | 2 +- go.sum | 4 +- internal/eventutil/events.go | 15 ++-- keyserver/consumers/devicelistupdate.go | 26 +++--- keyserver/consumers/signingkeyupdate.go | 28 ++++--- keyserver/internal/device_list_update.go | 6 +- keyserver/internal/device_list_update_test.go | 14 +++- keyserver/internal/internal.go | 4 +- keyserver/keyserver.go | 2 +- roomserver/api/input.go | 5 +- roomserver/api/perform.go | 2 + roomserver/api/wrapper.go | 11 ++- roomserver/internal/alias.go | 19 ++++- roomserver/internal/api.go | 29 ++++--- roomserver/internal/input/input.go | 7 +- roomserver/internal/input/input_events.go | 27 +++--- roomserver/internal/input/input_missing.go | 17 ++-- roomserver/internal/perform/perform_admin.go | 25 +++++- .../internal/perform/perform_backfill.go | 50 ++++++----- roomserver/internal/perform/perform_join.go | 24 ++++-- roomserver/internal/perform/perform_leave.go | 14 ++-- .../internal/perform/perform_upgrade.go | 21 ++++- roomserver/internal/query/query.go | 10 +-- roomserver/roomserver_test.go | 2 +- setup/base/base.go | 7 +- setup/config/config.go | 15 ++++ setup/config/config_global.go | 82 ++++++++++++++++++- setup/mscs/msc2836/msc2836.go | 6 +- setup/mscs/msc2946/msc2946.go | 2 +- syncapi/consumers/keychange.go | 35 ++++---- syncapi/consumers/receipts.go | 30 ++++--- syncapi/consumers/roomserver.go | 10 +-- syncapi/consumers/sendtodevice.go | 40 ++++----- syncapi/routing/messages.go | 1 + syncapi/syncapi_test.go | 4 +- 73 files changed, 736 insertions(+), 420 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 9ed1f0ca2..be8073c33 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -110,7 +110,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.MissingArgument("Expecting user localpart."), } } - if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil { + if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil { localpart, serverName = l, s } request := struct { diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index eefe8e24b..a0d80903d 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -477,7 +477,7 @@ func createRoom( SendAsServer: roomserverAPI.DoNotSendToOtherServers, }) } - if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { + if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index ce14745aa..b3c5aae45 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -77,7 +77,7 @@ func DirectoryRoom( // If we don't know it locally, do a federation query. // But don't send the query to ourselves. if !cfg.Matrix.IsLocalServerName(domain) { - fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) + fedRes, fedErr := federation.LookupRoomAlias(req.Context(), cfg.Matrix.ServerName, domain, roomAlias) if fedErr != nil { // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index b1043e994..606744767 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -74,7 +74,7 @@ func GetPostPublicRooms( serverName := gomatrixserverlib.ServerName(request.Server) if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( - req.Context(), serverName, + req.Context(), cfg.Matrix.ServerName, serverName, int(request.Limit), request.Since, request.Filter.SearchTerms, false, "", diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 94ba17a02..482c1f5f7 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -110,6 +110,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic ctx, rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, + device.UserDomain(), serverName, serverName, nil, @@ -322,7 +323,12 @@ func buildMembershipEvent( return nil, err } - return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return nil, err + } + + return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil) } // loadProfile lookups the profile of a given user from the database and returns diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 4d9e1f8a5..92a75fc78 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -284,7 +284,7 @@ func updateProfile( } events, err := buildMembershipEvents( - ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, + ctx, device, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ) switch e := err.(type) { case nil: @@ -298,7 +298,7 @@ func updateProfile( return jsonerror.InternalServerError(), e } - if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError(), err } @@ -321,7 +321,7 @@ func getProfile( } if !cfg.Matrix.IsLocalServerName(domain) { - profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") + profile, fedErr := federation.LookupProfile(ctx, cfg.Matrix.ServerName, domain, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { @@ -349,6 +349,7 @@ func getProfile( func buildMembershipEvents( ctx context.Context, + device *userapi.Device, roomIDs []string, newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, evTime time.Time, rsAPI api.ClientRoomserverAPI, @@ -380,7 +381,12 @@ func buildMembershipEvents( return nil, err } - event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return nil, err + } + + event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil) if err != nil { return nil, err } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 778a02fd4..7841b3b07 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -123,8 +123,13 @@ func SendRedaction( return jsonerror.InternalServerError() } + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return jsonerror.InternalServerError() + } + var queryRes roomserverAPI.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, @@ -132,7 +137,7 @@ func SendRedaction( } } domain := device.UserDomain() - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 9dc63af84..a92513b8b 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -211,9 +211,10 @@ var ( // previous parameters with the ones supplied. This mean you cannot "build up" request params. type registerRequest struct { // registration parameters - Password string `json:"password"` - Username string `json:"username"` - Admin bool `json:"admin"` + Password string `json:"password"` + Username string `json:"username"` + ServerName gomatrixserverlib.ServerName `json:"-"` + Admin bool `json:"admin"` // user-interactive auth params Auth authDict `json:"auth"` @@ -570,11 +571,14 @@ func Register( JSON: response, } } - } if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } + r.ServerName = cfg.Matrix.ServerName + if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { + r.Username, r.ServerName = l, d + } if req.URL.Query().Get("kind") == "guest" { return handleGuestRegistration(req, r, cfg, userAPI) } @@ -589,7 +593,7 @@ func Register( // Auto generate a numeric username if r.Username is empty if r.Username == "" { nreq := &userapi.QueryNumericLocalpartRequest{ - ServerName: cfg.Matrix.ServerName, // TODO: might not be right + ServerName: r.ServerName, } nres := &userapi.QueryNumericLocalpartResponse{} if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { @@ -609,7 +613,7 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { + if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { return *resErr } case accessTokenErr == nil: @@ -622,7 +626,7 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { + if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { return *resErr } } @@ -1023,13 +1027,27 @@ func RegisterAvailable( // Squash username to all lowercase letters username = strings.ToLower(username) + domain := cfg.Matrix.ServerName + if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil { + username, domain = u, l + } + for _, v := range cfg.Matrix.VirtualHosts { + if v.ServerName == domain && !v.AllowRegistration { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden( + fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)), + ), + } + } + } - if err := validateUsername(username, cfg.Matrix.ServerName); err != nil { + if err := validateUsername(username, domain); err != nil { return *err } // Check if this username is reserved by an application service - userID := userutil.MakeUserID(username, cfg.Matrix.ServerName) + userID := userutil.MakeUserID(username, domain) for _, appservice := range cfg.Derived.ApplicationServices { if appservice.OwnsNamespaceCoveringUserId(userID) { return util.JSONResponse{ @@ -1041,7 +1059,8 @@ func RegisterAvailable( res := &userapi.QueryAccountAvailabilityResponse{} err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ - Localpart: username, + Localpart: username, + ServerName: domain, }, res) if err != nil { return util.JSONResponse{ diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index bb66cf6fc..90af9ac4d 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -186,6 +186,7 @@ func SendEvent( []*gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, + device.UserDomain(), domain, domain, txnAndSessionID, @@ -275,8 +276,14 @@ func generateSendEvent( return nil, &resErr } + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + resErr := jsonerror.InternalServerError() + return nil, &resErr + } + var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index a7acee326..fb93d8783 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -231,6 +231,7 @@ func SendServerNotice( []*gomatrixserverlib.HeaderedEvent{ e.Headered(roomVersion), }, + device.UserDomain(), cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName, txnAndSessionID, diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go index d3d1c22e4..62af9efa4 100644 --- a/clientapi/routing/userdirectory.go +++ b/clientapi/routing/userdirectory.go @@ -106,7 +106,7 @@ knownUsersLoop: continue } // TODO: We should probably cache/store this - fedProfile, fedErr := federation.LookupProfile(ctx, serverName, userID, "") + fedProfile, fedErr := federation.LookupProfile(ctx, localServerName, serverName, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 99fb8171d..1f294a032 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -359,8 +359,13 @@ func emit3PIDInviteEvent( return err } + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return err + } + queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes) if err != nil { return err } @@ -371,6 +376,7 @@ func emit3PIDInviteEvent( []*gomatrixserverlib.HeaderedEvent{ event.Headered(queryRes.RoomVersion), }, + device.UserDomain(), cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, diff --git a/cmd/dendrite-demo-pinecone/conn/client.go b/cmd/dendrite-demo-pinecone/conn/client.go index 27e18c2a3..a91434f62 100644 --- a/cmd/dendrite-demo-pinecone/conn/client.go +++ b/cmd/dendrite-demo-pinecone/conn/client.go @@ -101,9 +101,7 @@ func CreateFederationClient( base *base.BaseDendrite, s *pineconeSessions.Sessions, ) *gomatrixserverlib.FederationClient { return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.ServerName, - base.Cfg.Global.KeyID, - base.Cfg.Global.PrivateKey, + base.Cfg.Global.SigningIdentities(), gomatrixserverlib.WithTransport(createTransport(s)), ) } diff --git a/cmd/dendrite-demo-pinecone/rooms/rooms.go b/cmd/dendrite-demo-pinecone/rooms/rooms.go index 0fafbedc3..0ac705cc1 100644 --- a/cmd/dendrite-demo-pinecone/rooms/rooms.go +++ b/cmd/dendrite-demo-pinecone/rooms/rooms.go @@ -58,13 +58,17 @@ func (p *PineconeRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { for _, k := range p.r.Peers() { list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} } - return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, list) + return bulkFetchPublicRoomsFromServers( + context.Background(), p.fedClient, + gomatrixserverlib.ServerName(p.r.PublicKey().String()), list, + ) } // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( ctx context.Context, fedClient *gomatrixserverlib.FederationClient, + origin gomatrixserverlib.ServerName, homeservers map[gomatrixserverlib.ServerName]struct{}, ) (publicRooms []gomatrixserverlib.PublicRoom) { limit := 200 @@ -82,7 +86,7 @@ func bulkFetchPublicRoomsFromServers( go func(homeserverDomain gomatrixserverlib.ServerName) { defer wg.Done() util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") - fres, err := fedClient.GetPublicRooms(reqctx, homeserverDomain, int(limit), "", false, "") + fres, err := fedClient.GetPublicRooms(reqctx, origin, homeserverDomain, int(limit), "", false, "") if err != nil { util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( "bulkFetchPublicRoomsFromServers: failed to query hs", diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/client.go b/cmd/dendrite-demo-yggdrasil/yggconn/client.go index 358d3725e..41a9ec123 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/client.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/client.go @@ -55,8 +55,7 @@ func (n *Node) CreateFederationClient( }, ) return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, - base.Cfg.Global.PrivateKey, + base.Cfg.Global.SigningIdentities(), gomatrixserverlib.WithTransport(tr), ) } diff --git a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go index 402b86ed3..0de64755e 100644 --- a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go +++ b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go @@ -43,13 +43,18 @@ func NewYggdrasilRoomProvider( } func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { - return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, p.node.KnownNodes()) + return bulkFetchPublicRoomsFromServers( + context.Background(), p.fedClient, + gomatrixserverlib.ServerName(p.node.DerivedServerName()), + p.node.KnownNodes(), + ) } // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( ctx context.Context, fedClient *gomatrixserverlib.FederationClient, + origin gomatrixserverlib.ServerName, homeservers []gomatrixserverlib.ServerName, ) (publicRooms []gomatrixserverlib.PublicRoom) { limit := 200 @@ -66,7 +71,7 @@ func bulkFetchPublicRoomsFromServers( go func(homeserverDomain gomatrixserverlib.ServerName) { defer wg.Done() util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") - fres, err := fedClient.GetPublicRooms(ctx, homeserverDomain, int(limit), "", false, "") + fres, err := fedClient.GetPublicRooms(ctx, origin, homeserverDomain, int(limit), "", false, "") if err != nil { util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn( "bulkFetchPublicRoomsFromServers: failed to query hs", diff --git a/cmd/furl/main.go b/cmd/furl/main.go index f59f9c8ce..b208ba868 100644 --- a/cmd/furl/main.go +++ b/cmd/furl/main.go @@ -48,10 +48,15 @@ func main() { panic("unexpected key block") } + serverName := gomatrixserverlib.ServerName(*requestFrom) client := gomatrixserverlib.NewFederationClient( - gomatrixserverlib.ServerName(*requestFrom), - gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), - privateKey, + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: serverName, + KeyID: gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), + PrivateKey: privateKey, + }, + }, ) u, err := url.Parse(flag.Arg(0)) @@ -79,6 +84,7 @@ func main() { req := gomatrixserverlib.NewFederationRequest( method, + serverName, gomatrixserverlib.ServerName(u.Host), u.RequestURI(), ) diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 362333fc9..e34c9e8b6 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -21,8 +21,8 @@ type FederationInternalAPI interface { QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) - MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. PerformBroadcastEDU( @@ -60,18 +60,18 @@ type RoomserverFederationAPI interface { // containing only the server names (without information for membership events). // The response will include this server if they are joined to the room. QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error - GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) + GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError type KeyserverFederationAPI interface { - GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) - ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) - QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) + GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) + ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) + QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) } // an interface for gmsl.FederationClient - contains functions called by federationapi only. @@ -80,28 +80,28 @@ type FederationClient interface { SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) // Perform operations - LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) - Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) - MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) - SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) - MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) - SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) - SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) + LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) + Peek(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) + MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) + SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) + MakeLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) + SendLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) + SendInviteV2(ctx context.Context, origin, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) - ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) - QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) - Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) - MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) + ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) + QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) + Backfill(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) + MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) - ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) - LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) - LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) + ExchangeThirdPartyInvite(ctx context.Context, origin, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) + LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) + LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 4578e33aa..854251220 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -120,17 +120,7 @@ func NewInternalAPI( js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - signingInfo := map[gomatrixserverlib.ServerName]*queue.SigningInfo{} - for _, serverName := range append( - []gomatrixserverlib.ServerName{base.Cfg.Global.ServerName}, - base.Cfg.Global.SecondaryServerNames..., - ) { - signingInfo[serverName] = &queue.SigningInfo{ - KeyID: cfg.Matrix.KeyID, - PrivateKey: cfg.Matrix.PrivateKey, - ServerName: serverName, - } - } + signingInfo := base.Cfg.Global.SigningIdentities() queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 3acaa70dc..cc03cdece 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -104,7 +104,7 @@ func TestMain(m *testing.M) { // Create the federation client. s.fedclient = gomatrixserverlib.NewFederationClient( - s.config.Matrix.ServerName, serverKeyID, testPriv, + s.config.Matrix.SigningIdentities(), gomatrixserverlib.WithTransport(transport), ) diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index c37bc87c2..68a06a033 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -103,7 +103,7 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv return keys, nil } -func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { +func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { for _, r := range f.allowJoins { if r.ID == roomID { res.RoomVersion = r.Version @@ -127,7 +127,7 @@ func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName } return } -func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { +func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { @@ -283,7 +283,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) fedCli := gomatrixserverlib.NewFederationClient( - serverName, cfg.Global.KeyID, cfg.Global.PrivateKey, + cfg.Global.SigningIdentities(), gomatrixserverlib.WithSkipVerify(true), ) @@ -326,7 +326,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { t.Errorf("failed to create invite v2 request: %s", err) continue } - _, err = fedCli.SendInviteV2(context.Background(), serverName, invReq) + _, err = fedCli.SendInviteV2(context.Background(), cfg.Global.ServerName, serverName, invReq) if err == nil { t.Errorf("expected an error, got none") continue diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index 2636b7fa0..db6348ec1 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -11,13 +11,13 @@ import ( // client. func (a *FederationInternalAPI) GetEventAuth( - ctx context.Context, s gomatrixserverlib.ServerName, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (res gomatrixserverlib.RespEventAuth, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetEventAuth(ctx, s, roomVersion, roomID, eventID) + return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) }) if err != nil { return gomatrixserverlib.RespEventAuth{}, err @@ -26,12 +26,12 @@ func (a *FederationInternalAPI) GetEventAuth( } func (a *FederationInternalAPI) GetUserDevices( - ctx context.Context, s gomatrixserverlib.ServerName, userID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetUserDevices(ctx, s, userID) + return a.federation.GetUserDevices(ctx, origin, s, userID) }) if err != nil { return gomatrixserverlib.RespUserDevices{}, err @@ -40,12 +40,12 @@ func (a *FederationInternalAPI) GetUserDevices( } func (a *FederationInternalAPI) ClaimKeys( - ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.ClaimKeys(ctx, s, oneTimeKeys) + return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) }) if err != nil { return gomatrixserverlib.RespClaimKeys{}, err @@ -54,10 +54,10 @@ func (a *FederationInternalAPI) ClaimKeys( } func (a *FederationInternalAPI) QueryKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { - return a.federation.QueryKeys(ctx, s, keys) + return a.federation.QueryKeys(ctx, origin, s, keys) }) if err != nil { return gomatrixserverlib.RespQueryKeys{}, err @@ -66,12 +66,12 @@ func (a *FederationInternalAPI) QueryKeys( } func (a *FederationInternalAPI) Backfill( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) + return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) }) if err != nil { return gomatrixserverlib.Transaction{}, err @@ -80,12 +80,12 @@ func (a *FederationInternalAPI) Backfill( } func (a *FederationInternalAPI) LookupState( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespState, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) + return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) }) if err != nil { return gomatrixserverlib.RespState{}, err @@ -94,12 +94,12 @@ func (a *FederationInternalAPI) LookupState( } func (a *FederationInternalAPI) LookupStateIDs( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, ) (res gomatrixserverlib.RespStateIDs, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupStateIDs(ctx, s, roomID, eventID) + return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) }) if err != nil { return gomatrixserverlib.RespStateIDs{}, err @@ -108,13 +108,13 @@ func (a *FederationInternalAPI) LookupStateIDs( } func (a *FederationInternalAPI) LookupMissingEvents( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupMissingEvents(ctx, s, roomID, missing, roomVersion) + return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) }) if err != nil { return gomatrixserverlib.RespMissingEvents{}, err @@ -123,12 +123,12 @@ func (a *FederationInternalAPI) LookupMissingEvents( } func (a *FederationInternalAPI) GetEvent( - ctx context.Context, s gomatrixserverlib.ServerName, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetEvent(ctx, s, eventID) + return a.federation.GetEvent(ctx, origin, s, eventID) }) if err != nil { return gomatrixserverlib.Transaction{}, err @@ -151,13 +151,13 @@ func (a *FederationInternalAPI) LookupServerKeys( } func (a *FederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) + return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion) }) if err != nil { return res, err @@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) + return a.federation.MSC2946Spaces(ctx, origin, s, roomID, suggestedOnly) }) if err != nil { return res, err diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 1b61ec711..603e2a9c9 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -26,6 +26,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( ) (err error) { dir, err := r.federation.LookupRoomAlias( ctx, + r.cfg.Matrix.ServerName, request.ServerName, request.RoomAlias, ) @@ -143,10 +144,16 @@ func (r *FederationInternalAPI) performJoinUsingServer( supportedVersions []gomatrixserverlib.RoomVersion, unsigned map[string]interface{}, ) error { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) + if err != nil { + return err + } + // Try to perform a make_join using the information supplied in the // request. respMakeJoin, err := r.federation.MakeJoin( ctx, + origin, serverName, roomID, userID, @@ -192,7 +199,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // Build the join event. event, err := respMakeJoin.JoinEvent.Build( time.Now(), - r.cfg.Matrix.ServerName, + origin, r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, respMakeJoin.RoomVersion, @@ -204,6 +211,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // Try to perform a send_join using the newly built event. respSendJoin, err := r.federation.SendJoin( context.Background(), + origin, serverName, event, ) @@ -246,7 +254,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( respMakeJoin.RoomVersion, r.keyRing, event, - federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), + federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName), ) if err != nil { return fmt.Errorf("respSendJoin.Check: %w", err) @@ -281,6 +289,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( if err = roomserverAPI.SendEventWithState( context.Background(), r.rsAPI, + origin, roomserverAPI.KindNew, respState, event.Headered(respMakeJoin.RoomVersion), @@ -427,6 +436,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // request. respPeek, err := r.federation.Peek( ctx, + r.cfg.Matrix.ServerName, serverName, roomID, peekID, @@ -453,7 +463,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) + authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName)) if err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) } @@ -475,7 +485,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // logrus.Warnf("got respPeek %#v", respPeek) // Send the newly returned state to the roomserver to update our local view. if err = roomserverAPI.SendEventWithState( - ctx, r.rsAPI, + ctx, r.rsAPI, r.cfg.Matrix.ServerName, roomserverAPI.KindNew, &respState, respPeek.LatestEvent.Headered(respPeek.RoomVersion), @@ -495,6 +505,11 @@ func (r *FederationInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) (err error) { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID) + if err != nil { + return err + } + // Deduplicate the server names we were provided. util.SortAndUnique(request.ServerNames) @@ -505,6 +520,7 @@ func (r *FederationInternalAPI) PerformLeave( // request. respMakeLeave, err := r.federation.MakeLeave( ctx, + origin, serverName, request.RoomID, request.UserID, @@ -546,7 +562,7 @@ func (r *FederationInternalAPI) PerformLeave( // Build the leave event. event, err := respMakeLeave.LeaveEvent.Build( time.Now(), - r.cfg.Matrix.ServerName, + origin, r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, respMakeLeave.RoomVersion, @@ -559,6 +575,7 @@ func (r *FederationInternalAPI) PerformLeave( // Try to perform a send_leave using the newly built event. err = r.federation.SendLeave( ctx, + origin, serverName, event, ) @@ -585,6 +602,11 @@ func (r *FederationInternalAPI) PerformInvite( request *api.PerformInviteRequest, response *api.PerformInviteResponse, ) (err error) { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.Event.Sender()) + if err != nil { + return err + } + if request.Event.StateKey() == nil { return errors.New("invite must be a state event") } @@ -607,7 +629,7 @@ func (r *FederationInternalAPI) PerformInvite( return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } - inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq) + inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) if err != nil { return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } @@ -708,7 +730,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder // FederatedAuthProvider is an auth chain provider which fetches events from the server provided func federatedAuthProvider( ctx context.Context, federation api.FederationClient, - keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, + keyRing gomatrixserverlib.JSONVerifier, origin, server gomatrixserverlib.ServerName, ) gomatrixserverlib.AuthChainProvider { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. @@ -738,7 +760,7 @@ func federatedAuthProvider( // Try to retrieve the event from the server that sent us the send // join response. - tx, txerr := federation.GetEvent(ctx, server, eventID) + tx, txerr := federation.GetEvent(ctx, origin, server, eventID) if txerr != nil { return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr) } diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 812d3c6da..6c37a1f5a 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -152,16 +152,18 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU( type getUserDevices struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName UserID string } func (h *httpFederationInternalAPI) GetUserDevices( - ctx context.Context, s gomatrixserverlib.ServerName, userID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, ctx, &getUserDevices{ S: s, + Origin: origin, UserID: userID, }, ) @@ -169,52 +171,58 @@ func (h *httpFederationInternalAPI) GetUserDevices( type claimKeys struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName OneTimeKeys map[string]map[string]string } func (h *httpFederationInternalAPI) ClaimKeys( - ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, ctx, &claimKeys{ S: s, + Origin: origin, OneTimeKeys: oneTimeKeys, }, ) } type queryKeys struct { - S gomatrixserverlib.ServerName - Keys map[string][]string + S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName + Keys map[string][]string } func (h *httpFederationInternalAPI) QueryKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, ctx, &queryKeys{ - S: s, - Keys: keys, + S: s, + Origin: origin, + Keys: keys, }, ) } type backfill struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string Limit int EventIDs []string } func (h *httpFederationInternalAPI) Backfill( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (gomatrixserverlib.Transaction, error) { return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, ctx, &backfill{ S: s, + Origin: origin, RoomID: roomID, Limit: limit, EventIDs: eventIDs, @@ -224,18 +232,20 @@ func (h *httpFederationInternalAPI) Backfill( type lookupState struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string EventID string RoomVersion gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) LookupState( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (gomatrixserverlib.RespState, error) { return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, ctx, &lookupState{ S: s, + Origin: origin, RoomID: roomID, EventID: eventID, RoomVersion: roomVersion, @@ -245,17 +255,19 @@ func (h *httpFederationInternalAPI) LookupState( type lookupStateIDs struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string EventID string } func (h *httpFederationInternalAPI) LookupStateIDs( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, ) (gomatrixserverlib.RespStateIDs, error) { return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, ctx, &lookupStateIDs{ S: s, + Origin: origin, RoomID: roomID, EventID: eventID, }, @@ -264,19 +276,21 @@ func (h *httpFederationInternalAPI) LookupStateIDs( type lookupMissingEvents struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string Missing gomatrixserverlib.MissingEvents RoomVersion gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) LookupMissingEvents( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, ctx, &lookupMissingEvents{ S: s, + Origin: origin, RoomID: roomID, Missing: missing, RoomVersion: roomVersion, @@ -286,16 +300,18 @@ func (h *httpFederationInternalAPI) LookupMissingEvents( type getEvent struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName EventID string } func (h *httpFederationInternalAPI) GetEvent( - ctx context.Context, s gomatrixserverlib.ServerName, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, ) (gomatrixserverlib.Transaction, error) { return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, ctx, &getEvent{ S: s, + Origin: origin, EventID: eventID, }, ) @@ -303,19 +319,21 @@ func (h *httpFederationInternalAPI) GetEvent( type getEventAuth struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomVersion gomatrixserverlib.RoomVersion RoomID string EventID string } func (h *httpFederationInternalAPI) GetEventAuth( - ctx context.Context, s gomatrixserverlib.ServerName, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (gomatrixserverlib.RespEventAuth, error) { return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, ctx, &getEventAuth{ S: s, + Origin: origin, RoomVersion: roomVersion, RoomID: roomID, EventID: eventID, @@ -351,18 +369,20 @@ func (h *httpFederationInternalAPI) LookupServerKeys( type eventRelationships struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName Req gomatrixserverlib.MSC2836EventRelationshipsRequest RoomVer gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, ctx, &eventRelationships{ S: s, + Origin: origin, Req: r, RoomVer: roomVersion, }, @@ -371,17 +391,19 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships( type spacesReq struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName SuggestedOnly bool RoomID string } func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, + ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, ctx, &spacesReq{ S: dst, + Origin: origin, SuggestedOnly: suggestedOnly, RoomID: roomID, }, diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 7aa0e4801..7b3edb2a7 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -59,7 +59,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetUserDevices", func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { - res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID) + res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID) return &res, federationClientError(err) }, ), @@ -70,7 +70,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIClaimKeys", func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { - res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys) + res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys) return &res, federationClientError(err) }, ), @@ -81,7 +81,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIQueryKeys", func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { - res, err := intAPI.QueryKeys(ctx, req.S, req.Keys) + res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys) return &res, federationClientError(err) }, ), @@ -92,7 +92,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIBackfill", func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs) + res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs) return &res, federationClientError(err) }, ), @@ -103,7 +103,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupState", func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { - res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion) + res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion) return &res, federationClientError(err) }, ), @@ -114,7 +114,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupStateIDs", func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { - res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID) + res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID) return &res, federationClientError(err) }, ), @@ -125,7 +125,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupMissingEvents", func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { - res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion) + res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion) return &res, federationClientError(err) }, ), @@ -136,7 +136,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetEvent", func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.GetEvent(ctx, req.S, req.EventID) + res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID) return &res, federationClientError(err) }, ), @@ -147,7 +147,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetEventAuth", func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { - res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID) + res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID) return &res, federationClientError(err) }, ), @@ -174,7 +174,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIMSC2836EventRelationships", func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { - res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer) + res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer) return &res, federationClientError(err) }, ), @@ -185,7 +185,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIMSC2946SpacesSummary", func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { - res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly) + res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly) return &res, federationClientError(err) }, ), diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index bf04ee99a..63fc59c89 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -50,7 +50,7 @@ type destinationQueue struct { queues *OutgoingQueues db storage.Database process *process.ProcessContext - signing map[gomatrixserverlib.ServerName]*SigningInfo + signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity rsAPI api.FederationRoomserverAPI client fedapi.FederationClient // federation client origin gomatrixserverlib.ServerName // origin of requests diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 68f354993..31124e0c1 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -15,7 +15,6 @@ package queue import ( - "crypto/ed25519" "encoding/json" "fmt" "sync" @@ -46,7 +45,7 @@ type OutgoingQueues struct { origin gomatrixserverlib.ServerName client fedapi.FederationClient statistics *statistics.Statistics - signing map[gomatrixserverlib.ServerName]*SigningInfo + signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity queuesMutex sync.Mutex // protects the below queues map[gomatrixserverlib.ServerName]*destinationQueue } @@ -91,7 +90,7 @@ func NewOutgoingQueues( client fedapi.FederationClient, rsAPI api.FederationRoomserverAPI, statistics *statistics.Statistics, - signing map[gomatrixserverlib.ServerName]*SigningInfo, + signing []*gomatrixserverlib.SigningIdentity, ) *OutgoingQueues { queues := &OutgoingQueues{ disabled: disabled, @@ -101,9 +100,12 @@ func NewOutgoingQueues( origin: origin, client: client, statistics: statistics, - signing: signing, + signing: map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity{}, queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, } + for _, identity := range signing { + queues.signing[identity.ServerName] = identity + } // Look up which servers we have pending items for and then rehydrate those queues. if !disabled { serverNames := map[gomatrixserverlib.ServerName]struct{}{} @@ -135,14 +137,6 @@ func NewOutgoingQueues( return queues } -// TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables -// around together -type SigningInfo struct { - ServerName gomatrixserverlib.ServerName - KeyID gomatrixserverlib.KeyID - PrivateKey ed25519.PrivateKey -} - type queuedPDU struct { receipt *shared.Receipt pdu *gomatrixserverlib.HeaderedEvent diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 58745c607..b2ec4b836 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -350,8 +350,8 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T } rs := &stubFederationRoomServerAPI{} stats := statistics.NewStatistics(db, failuresUntilBlacklist) - signingInfo := map[gomatrixserverlib.ServerName]*SigningInfo{ - "localhost": { + signingInfo := []*gomatrixserverlib.SigningIdentity{ + { KeyID: "ed21019:auto", PrivateKey: test.PrivateKeyA, ServerName: "localhost", diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 7b9ca66f6..272f5e9d8 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -82,7 +82,8 @@ func Backfill( BackwardsExtremities: map[string][]string{ "": eIDs, }, - ServerName: request.Origin(), + ServerName: request.Origin(), + VirtualHost: request.Destination(), } if req.Limit, err = strconv.Atoi(limit); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") @@ -123,7 +124,7 @@ func Backfill( } txn := gomatrixserverlib.Transaction{ - Origin: cfg.Matrix.ServerName, + Origin: request.Destination(), PDUs: eventJSONs, OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 504204504..f424fcacd 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -140,6 +140,21 @@ func processInvite( } } + if event.StateKey() == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The invite event has no state key"), + } + } + + _, domain, err := cfg.Matrix.SplitLocalID('@', *event.StateKey()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)), + } + } + // Check that the event is signed by the server sending the request. redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) if err != nil { @@ -175,7 +190,7 @@ func processInvite( // Sign the event so that other servers will know that we have received the invite. signedEvent := event.Sign( - string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, + string(domain), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, ) // Add the invite event to the roomserver. diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 74d065e59..03809df75 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -131,10 +131,20 @@ func MakeJoin( return jsonerror.InternalServerError() } + identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound( + fmt.Sprintf("Server name %q does not exist", request.Destination()), + ), + } + } + queryRes := api.QueryLatestEventsAndStateResponse{ RoomVersion: verRes.RoomVersion, } - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 5650e3d53..4fd3720f1 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -16,7 +16,6 @@ package routing import ( "encoding/json" - "fmt" "net/http" "time" @@ -136,38 +135,52 @@ func ClaimOneTimeKeys( // LocalKeys returns the local keys for the server. // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse { - keys, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) + keys, err := localKeys(cfg, serverName) if err != nil { - return util.ErrorResponse(err) + return util.MessageResponse(http.StatusNotFound, err.Error()) } return util.JSONResponse{Code: http.StatusOK, JSON: keys} } -func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { +func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys - if !cfg.Matrix.IsLocalServerName(serverName) { - return nil, fmt.Errorf("server name not known") - } - - keys.ServerName = serverName - keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil) - - publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) - - keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ - cfg.Matrix.KeyID: { - Key: gomatrixserverlib.Base64Bytes(publicKey), - }, - } - - keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} - for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys { - keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{ - VerifyKey: gomatrixserverlib.VerifyKey{ - Key: oldVerifyKey.PublicKey, - }, - ExpiredTS: oldVerifyKey.ExpiredAt, + var virtualHost *config.VirtualHost + for _, v := range cfg.Matrix.VirtualHosts { + if v.ServerName != serverName { + continue } + virtualHost = v + } + + if virtualHost == nil { + publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) + keys.ServerName = cfg.Matrix.ServerName + keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) + keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ + cfg.Matrix.KeyID: { + Key: gomatrixserverlib.Base64Bytes(publicKey), + }, + } + keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} + for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys { + keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{ + VerifyKey: gomatrixserverlib.VerifyKey{ + Key: oldVerifyKey.PublicKey, + }, + ExpiredTS: oldVerifyKey.ExpiredAt, + } + } + } else { + publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey) + keys.ServerName = virtualHost.ServerName + keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod)) + keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ + virtualHost.KeyID: { + Key: gomatrixserverlib.Base64Bytes(publicKey), + }, + } + // TODO: Virtual hosts probably want to be able to specify old signing + // keys too, just in case } toSign, err := json.Marshal(keys.ServerKeyFields) @@ -213,7 +226,7 @@ func NotaryKeys( for serverName, kidToCriteria := range req.ServerKeys { var keyList []gomatrixserverlib.ServerKeys if serverName == cfg.Matrix.ServerName { - if k, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { + if k, err := localKeys(cfg, serverName); err == nil { keyList = append(keyList, *k) } else { return util.ErrorResponse(err) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index a67e4e28b..f1e9f49ba 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -13,6 +13,7 @@ package routing import ( + "fmt" "net/http" "time" @@ -60,8 +61,18 @@ func MakeLeave( return jsonerror.InternalServerError() } + identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound( + fmt.Sprintf("Server name %q does not exist", request.Destination()), + ), + } + } + var queryRes api.QueryLatestEventsAndStateResponse - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index f672811af..e4d2230ad 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -42,16 +41,9 @@ func GetProfile( } } - _, domain, err := gomatrixserverlib.SplitID('@', userID) + _, domain, err := cfg.Matrix.SplitLocalID('@', userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument(fmt.Sprintf("Format of user ID %q is invalid", userID)), - } - } - - if domain != cfg.Matrix.ServerName { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)), diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 316c61a14..e6dc52601 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -83,7 +83,7 @@ func RoomAliasToID( } } } else { - resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, roomAlias) + resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, cfg.Matrix.ServerName, roomAlias) if err != nil { switch x := err.(type) { case gomatrix.HTTPError: diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index b3bbaa394..a146d85bd 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -197,12 +197,12 @@ type txnReq struct { // A subset of FederationClient functionality that txn requires. Useful for testing. type txnFederationClient interface { - LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) - LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, + LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } @@ -287,6 +287,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res []*gomatrixserverlib.HeaderedEvent{ event.Headered(roomVersion), }, + t.Destination, t.Origin, api.DoNotSendToOtherServers, nil, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 1c796f542..b8bfe0221 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -147,7 +147,7 @@ type txnFedClient struct { getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) } -func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( +func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) { fmt.Println("testFederationClient.LookupState", eventID) @@ -159,7 +159,7 @@ func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.Serv res = r return } -func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { +func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { fmt.Println("testFederationClient.LookupStateIDs", eventID) r, ok := c.stateIDs[eventID] if !ok { @@ -169,7 +169,7 @@ func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.S res = r return } -func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { +func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { fmt.Println("testFederationClient.GetEvent", eventID) r, ok := c.getEvent[eventID] if !ok { @@ -179,7 +179,7 @@ func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerN res = r return } -func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, +func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { return c.getMissingEvents(missing) } diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index ccde9168e..d07faef39 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -90,7 +90,17 @@ func CreateInvitesFrom3PIDInvites( } // Send all the events - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, "TODO", cfg.Matrix.ServerName, nil, false); err != nil { + if err := api.SendEvents( + req.Context(), + rsAPI, + api.KindNew, + evs, + cfg.Matrix.ServerName, // TODO: which virtual host? + "TODO", + cfg.Matrix.ServerName, + nil, + false, + ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -126,6 +136,14 @@ func ExchangeThirdPartyInvite( } } + _, senderDomain, err := cfg.Matrix.SplitLocalID('@', builder.Sender) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Invalid sender ID: " + err.Error()), + } + } + // Check that the state key is correct. _, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey) if err != nil { @@ -171,7 +189,7 @@ func ExchangeThirdPartyInvite( util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") return jsonerror.InternalServerError() } - signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq) + signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() @@ -189,6 +207,7 @@ func ExchangeThirdPartyInvite( []*gomatrixserverlib.HeaderedEvent{ inviteEvent.Headered(verRes.RoomVersion), }, + request.Destination(), request.Origin(), cfg.Matrix.ServerName, nil, @@ -341,7 +360,7 @@ func buildMembershipEvent( // them responded with an error. func sendToRemoteServer( ctx context.Context, inv invite, - federation federationAPI.FederationClient, _ *config.FederationAPI, + federation federationAPI.FederationClient, cfg *config.FederationAPI, builder gomatrixserverlib.EventBuilder, ) (err error) { remoteServers := make([]gomatrixserverlib.ServerName, 2) @@ -357,7 +376,7 @@ func sendToRemoteServer( } for _, server := range remoteServers { - err = federation.ExchangeThirdPartyInvite(ctx, server, builder) + err = federation.ExchangeThirdPartyInvite(ctx, cfg.Matrix.ServerName, server, builder) if err == nil { return } diff --git a/go.mod b/go.mod index 47c43fbc7..fce538b4c 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index a805a5bf9..224d5aa86 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 h1:Bet5n+//Yh+A2SuPHD67N8jrOhC/EIKvEgisfsKhTss= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf h1:OPM8urHlGC8m1MfusZg6Oi9zxJiDNuSVh3GQuvZeOHw= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index d96231963..c572d8830 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -38,7 +38,8 @@ var ErrRoomNoExists = errors.New("room does not exist") // Returns an error if something else went wrong func QueryAndBuildEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, + builder *gomatrixserverlib.EventBuilder, cfg *config.Global, + identity *gomatrixserverlib.SigningIdentity, evTime time.Time, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.HeaderedEvent, error) { if queryRes == nil { @@ -50,24 +51,24 @@ func QueryAndBuildEvent( // This can pass through a ErrRoomNoExists to the caller return nil, err } - return BuildEvent(ctx, builder, cfg, evTime, eventsNeeded, queryRes) + return BuildEvent(ctx, builder, cfg, identity, evTime, eventsNeeded, queryRes) } // BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse // provided. func BuildEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, + builder *gomatrixserverlib.EventBuilder, cfg *config.Global, + identity *gomatrixserverlib.SigningIdentity, evTime time.Time, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.HeaderedEvent, error) { - err := addPrevEventsToEvent(builder, eventsNeeded, queryRes) - if err != nil { + if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil { return nil, err } event, err := builder.Build( - evTime, cfg.ServerName, cfg.KeyID, - cfg.PrivateKey, queryRes.RoomVersion, + evTime, identity.ServerName, identity.KeyID, + identity.PrivateKey, queryRes.RoomVersion, ) if err != nil { return nil, err diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go index 575e41281..cd911f8c6 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/keyserver/consumers/devicelistupdate.go @@ -30,12 +30,12 @@ import ( // DeviceListUpdateConsumer consumes device list updates that came in over federation. type DeviceListUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - updater *internal.DeviceListUpdater - serverName gomatrixserverlib.ServerName + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + updater *internal.DeviceListUpdater + isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. @@ -46,12 +46,12 @@ func NewDeviceListUpdateConsumer( updater *internal.DeviceListUpdater, ) *DeviceListUpdateConsumer { return &DeviceListUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - updater: updater, - serverName: cfg.Matrix.ServerName, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + updater: updater, + isLocalServerName: cfg.Matrix.IsLocalServerName, } } @@ -75,7 +75,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { return true - } else if serverName == t.serverName { + } else if t.isLocalServerName(serverName) { return true } else if serverName != origin { return true diff --git a/keyserver/consumers/signingkeyupdate.go b/keyserver/consumers/signingkeyupdate.go index 366e259b4..bcceaad15 100644 --- a/keyserver/consumers/signingkeyupdate.go +++ b/keyserver/consumers/signingkeyupdate.go @@ -31,12 +31,13 @@ import ( // SigningKeyUpdateConsumer consumes signing key updates that came in over federation. type SigningKeyUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - keyAPI *internal.KeyInternalAPI - cfg *config.KeyServer + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + keyAPI *internal.KeyInternalAPI + cfg *config.KeyServer + isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. @@ -47,12 +48,13 @@ func NewSigningKeyUpdateConsumer( keyAPI *internal.KeyInternalAPI, ) *SigningKeyUpdateConsumer { return &SigningKeyUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - keyAPI: keyAPI, - cfg: cfg, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + keyAPI: keyAPI, + cfg: cfg, + isLocalServerName: cfg.Matrix.IsLocalServerName, } } @@ -77,7 +79,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { logrus.WithError(err).Error("failed to split user id") return true - } else if serverName == t.cfg.Matrix.ServerName { + } else if t.isLocalServerName(serverName) { logrus.Warn("dropping device key update from ourself") return true } else if serverName != origin { diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 3f7c0d8b4..8ff9dfc31 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -96,6 +96,7 @@ type DeviceListUpdater struct { producer KeyChangeProducer fedClient fedsenderapi.KeyserverFederationAPI workerChans []chan gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // block on or timeout via a select. @@ -139,6 +140,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, + thisServer gomatrixserverlib.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -148,6 +150,7 @@ func NewDeviceListUpdater( api: api, producer: producer, fedClient: fedClient, + thisServer: thisServer, workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, @@ -435,8 +438,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go "server_name": serverName, "user_id": userID, }) - - res, err := u.fedClient.GetUserDevices(ctx, serverName, userID) + res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID) if err != nil { if errors.Is(err, context.DeadlineExceeded) { return time.Minute * 10, err diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 28a13a0a0..a374c9516 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -129,7 +129,13 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { _, pkey, _ := ed25519.GenerateKey(nil) fedClient := gomatrixserverlib.NewFederationClient( - gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: gomatrixserverlib.ServerName("example.test"), + KeyID: gomatrixserverlib.KeyID("ed25519:test"), + PrivateKey: pkey, + }, + }, ) fedClient.Client = *gomatrixserverlib.NewClient( gomatrixserverlib.WithTransport(&roundTripper{tripper}), @@ -147,7 +153,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost") event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -219,7 +225,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -288,7 +294,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 37c55c8f8..9a08a0bb7 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -146,7 +146,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( defer cancel() defer wg.Done() - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) mu.Lock() defer mu.Unlock() @@ -559,7 +559,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( if len(devKeys) == 0 { return } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) if err == nil { resultCh <- &queryKeysResp return diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 0a4b8fde6..a86c2da4e 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -58,7 +58,7 @@ func NewInternalAPI( FedClient: fedClient, Producer: keyChangeProducer, } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable ap.Updater = updater go func() { if err := updater.Start(); err != nil { diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 45a9ef497..88d523270 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -94,8 +94,9 @@ type TransactionID struct { // InputRoomEventsRequest is a request to InputRoomEvents type InputRoomEventsRequest struct { - InputRoomEvents []InputRoomEvent `json:"input_room_events"` - Asynchronous bool `json:"async"` + InputRoomEvents []InputRoomEvent `json:"input_room_events"` + Asynchronous bool `json:"async"` + VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` } // InputRoomEventsResponse is a response to InputRoomEvents diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 2536a4bb5..e70e5ea9c 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -148,6 +148,8 @@ type PerformBackfillRequest struct { Limit int `json:"limit"` // The server interested in the events. ServerName gomatrixserverlib.ServerName `json:"server_name"` + // Which virtual host are we doing this for? + VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` } // PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 8b031982c..252be557f 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -26,7 +26,7 @@ import ( func SendEvents( ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, events []*gomatrixserverlib.HeaderedEvent, - origin gomatrixserverlib.ServerName, + virtualHost, origin gomatrixserverlib.ServerName, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, async bool, ) error { @@ -40,14 +40,15 @@ func SendEvents( TransactionID: txnID, } } - return SendInputRoomEvents(ctx, rsAPI, ires, async) + return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async) } // SendEventWithState writes an event with the specified kind to the roomserver // with the state at the event as KindOutlier before it. Will not send any event that is // marked as `true` in haveEventIDs. func SendEventWithState( - ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, + ctx context.Context, rsAPI InputRoomEventsAPI, + virtualHost gomatrixserverlib.ServerName, kind Kind, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { @@ -85,17 +86,19 @@ func SendEventWithState( StateEventIDs: stateEventIDs, }) - return SendInputRoomEvents(ctx, rsAPI, ires, async) + return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async) } // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI InputRoomEventsAPI, + virtualHost gomatrixserverlib.ServerName, ires []InputRoomEvent, async bool, ) error { request := InputRoomEventsRequest{ InputRoomEvents: ires, Asynchronous: async, + VirtualHost: virtualHost, } var response InputRoomEventsResponse if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 175bb9310..329e6af7f 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -137,6 +137,11 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { + _, virtualHost, err := r.Cfg.Matrix.SplitLocalID('@', request.UserID) + if err != nil { + return err + } + roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) @@ -190,6 +195,16 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( sender = ev.Sender() } + _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', sender) + if err != nil { + return err + } + + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + return err + } + builder := &gomatrixserverlib.EventBuilder{ Sender: sender, RoomID: ev.RoomID(), @@ -211,12 +226,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, stateRes) + newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, stateRes) if err != nil { return err } - err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, r.ServerName, r.ServerName, nil, false) + err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false) if err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1a11586a5..1a3626609 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -87,10 +87,10 @@ func NewRoomserverAPI( Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"), ServerACLs: serverACLs, Queryer: &query.Queryer{ - DB: roomserverDB, - Cache: base.Caches, - ServerName: base.Cfg.Global.ServerName, - ServerACLs: serverACLs, + DB: roomserverDB, + Cache: base.Caches, + IsLocalServerName: base.Cfg.Global.IsLocalServerName, + ServerACLs: serverACLs, }, // perform-er structs get initialised when we have a federation sender to use } @@ -127,13 +127,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Inputer: r.Inputer, } r.Joiner = &perform.Joiner{ - ServerName: r.Cfg.Matrix.ServerName, - Cfg: r.Cfg, - DB: r.DB, - FSAPI: r.fsAPI, - RSAPI: r, - Inputer: r.Inputer, - Queryer: r.Queryer, + Cfg: r.Cfg, + DB: r.DB, + FSAPI: r.fsAPI, + RSAPI: r, + Inputer: r.Inputer, + Queryer: r.Queryer, } r.Peeker = &perform.Peeker{ ServerName: r.Cfg.Matrix.ServerName, @@ -163,10 +162,10 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio DB: r.DB, } r.Backfiller = &perform.Backfiller{ - ServerName: r.ServerName, - DB: r.DB, - FSAPI: r.fsAPI, - KeyRing: r.KeyRing, + IsLocalServerName: r.Cfg.Matrix.IsLocalServerName, + DB: r.DB, + FSAPI: r.fsAPI, + KeyRing: r.KeyRing, // Perspective servers are trusted to not lie about server keys, so we will also // prefer these servers when backfilling (assuming they are in the room) rather // than trying random servers diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index f5099ca11..e965691c9 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -278,7 +278,11 @@ func (w *worker) _next() { // a string, because we might want to return that to the caller if // it was a synchronous request. var errString string - if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil { + if err = w.r.processRoomEvent( + w.r.ProcessContext.Context(), + gomatrixserverlib.ServerName(msg.Header.Get("virtual_host")), + &inputRoomEvent, + ); err != nil { switch err.(type) { case types.RejectedError: // Don't send events that were rejected to Sentry @@ -358,6 +362,7 @@ func (r *Inputer) queueInputRoomEvents( if replyTo != "" { msg.Header.Set("sync", replyTo) } + msg.Header.Set("virtual_host", string(request.VirtualHost)) msg.Data, err = json.Marshal(e) if err != nil { return nil, fmt.Errorf("json.Marshal: %w", err) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 682aa2b12..72c285f8e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -69,6 +69,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, + virtualHost gomatrixserverlib.ServerName, input *api.InputRoomEvent, ) error { select { @@ -200,7 +201,7 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, roomInfo, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { return fmt.Errorf("r.fetchAuthEvents: %w", err) } @@ -265,16 +266,17 @@ func (r *Inputer) processRoomEvent( // processRoomEvent. if len(serverRes.ServerNames) > 0 { missingState := missingStateReq{ - origin: input.Origin, - inputer: r, - db: r.DB, - roomInfo: roomInfo, - federation: r.FSAPI, - keys: r.KeyRing, - roomsMu: internal.NewMutexByRoom(), - servers: serverRes.ServerNames, - hadEvents: map[string]bool{}, - haveEvents: map[string]*gomatrixserverlib.Event{}, + origin: input.Origin, + virtualHost: virtualHost, + inputer: r, + db: r.DB, + roomInfo: roomInfo, + federation: r.FSAPI, + keys: r.KeyRing, + roomsMu: internal.NewMutexByRoom(), + servers: serverRes.ServerNames, + hadEvents: map[string]bool{}, + haveEvents: map[string]*gomatrixserverlib.Event{}, } var stateSnapshot *parsedRespState if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { @@ -555,6 +557,7 @@ func (r *Inputer) fetchAuthEvents( ctx context.Context, logger *logrus.Entry, roomInfo *types.RoomInfo, + virtualHost gomatrixserverlib.ServerName, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, known map[string]*types.Event, @@ -605,7 +608,7 @@ func (r *Inputer) fetchAuthEvents( // Request the entire auth chain for the event in question. This should // contain all of the auth events — including ones that we already know — // so we'll need to filter through those in the next section. - res, err = r.FSAPI.GetEventAuth(ctx, serverName, event.RoomVersion, event.RoomID(), event.EventID()) + res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.RoomVersion, event.RoomID(), event.EventID()) if err != nil { logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) continue diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index f788c5b3a..03ac2b38d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -41,6 +41,7 @@ func (p *parsedRespState) Events() []*gomatrixserverlib.Event { type missingStateReq struct { log *logrus.Entry + virtualHost gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName db storage.Database roomInfo *types.RoomInfo @@ -101,7 +102,7 @@ func (t *missingStateReq) processEventWithMissingState( // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -157,7 +158,7 @@ func (t *missingStateReq) processEventWithMissingState( }) } for _, ire := range outlierRoomEvents { - if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { + if err = t.inputer.processRoomEvent(ctx, t.virtualHost, &ire); err != nil { if _, ok := err.(types.RejectedError); !ok { return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) } @@ -180,7 +181,7 @@ func (t *missingStateReq) processEventWithMissingState( stateIDs = append(stateIDs, event.EventID()) } - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: backwardsExtremity.Headered(roomVersion), Origin: t.origin, @@ -199,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState( // they will automatically fast-forward based on the room state at the // extremity in the last step. for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -519,7 +520,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve var missingResp *gomatrixserverlib.RespMissingEvents for _, server := range t.servers { var m gomatrixserverlib.RespMissingEvents - if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ + if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. EarliestEvents: latestEvents, @@ -635,7 +636,7 @@ func (t *missingStateReq) lookupMissingStateViaState( span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState") defer span.Finish() - state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) + state, err := t.federation.LookupState(ctx, t.virtualHost, t.origin, roomID, eventID, roomVersion) if err != nil { return nil, err } @@ -675,7 +676,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5) for _, serverName := range t.servers { reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20) - stateIDs, err = t.federation.LookupStateIDs(reqctx, serverName, roomID, eventID) + stateIDs, err = t.federation.LookupStateIDs(reqctx, t.virtualHost, serverName, roomID, eventID) reqcancel() if err == nil { break @@ -855,7 +856,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs for _, serverName := range t.servers { reqctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() - txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) + txn, err := t.federation.GetEvent(reqctx, t.virtualHost, serverName, missingEventID) if err != nil || len(txn.PDUs) == 0 { t.log.WithError(err).WithField("missing_event_id", missingEventID).Warn("Failed to get missing /event for event ID") if errors.Is(err, context.DeadlineExceeded) { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 0406fc8f4..d42f4e45d 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -139,7 +139,12 @@ func (r *Admin) PerformAdminEvacuateRoom( return nil } - event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + continue + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -242,6 +247,15 @@ func (r *Admin) PerformAdminDownloadState( req *api.PerformAdminDownloadStateRequest, res *api.PerformAdminDownloadStateResponse, ) error { + _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', req.UserID) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.Cfg.Matrix.SplitLocalID: %s", err), + } + return nil + } + roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) if err != nil { res.Error = &api.PerformError{ @@ -273,7 +287,7 @@ func (r *Admin) PerformAdminDownloadState( for _, fwdExtremity := range fwdExtremities { var state gomatrixserverlib.RespState - state, err = r.Inputer.FSAPI.LookupState(ctx, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) + state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -331,7 +345,12 @@ func (r *Admin) PerformAdminDownloadState( Depth: depth, } - ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, queryRes) + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + return err + } + + ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 57e121ea2..069f017a9 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -37,10 +37,10 @@ import ( const maxBackfillServers = 5 type Backfiller struct { - ServerName gomatrixserverlib.ServerName - DB storage.Database - FSAPI federationAPI.RoomserverFederationAPI - KeyRing gomatrixserverlib.JSONVerifier + IsLocalServerName func(gomatrixserverlib.ServerName) bool + DB storage.Database + FSAPI federationAPI.RoomserverFederationAPI + KeyRing gomatrixserverlib.JSONVerifier // The servers which should be preferred above other servers when backfilling PreferServers []gomatrixserverlib.ServerName @@ -55,7 +55,7 @@ func (r *Backfiller) PerformBackfill( // if we are requesting the backfill then we need to do a federation hit // TODO: we could be more sensible and fetch as many events we already have then request the rest // which is what the syncapi does already. - if request.ServerName == r.ServerName { + if r.IsLocalServerName(request.ServerName) { return r.backfillViaFederation(ctx, request, response) } // someone else is requesting the backfill, try to service their request. @@ -112,16 +112,18 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform if info == nil || info.IsStub() { return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) } - requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities, r.PreferServers) + requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) // Request 100 items regardless of what the query asks for. // We don't want to go much higher than this. // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass // (so we don't need to hit /state_ids which the test has no listener for) // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( - ctx, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100) + ctx, req.VirtualHost, requester, + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, + ) if err != nil { + logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed") return err } logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) @@ -144,7 +146,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform var entries []types.StateEntry if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil { // attempt to fetch the missing events - r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) + r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs, req.VirtualHost) // try again entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true) if err != nil { @@ -173,7 +175,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just // best effort. func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - backfillRequester *backfillRequester, stateIDs []string) { + backfillRequester *backfillRequester, stateIDs []string, virtualHost gomatrixserverlib.ServerName) { servers := backfillRequester.servers @@ -198,7 +200,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue // already found } logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) - res, err := r.FSAPI.GetEvent(ctx, srv, id) + res, err := r.FSAPI.GetEvent(ctx, virtualHost, srv, id) if err != nil { logger.WithError(err).Warn("failed to get event from server") continue @@ -241,11 +243,12 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom // backfillRequester implements gomatrixserverlib.BackfillRequester type backfillRequester struct { - db storage.Database - fsAPI federationAPI.RoomserverFederationAPI - thisServer gomatrixserverlib.ServerName - preferServer map[gomatrixserverlib.ServerName]bool - bwExtrems map[string][]string + db storage.Database + fsAPI federationAPI.RoomserverFederationAPI + virtualHost gomatrixserverlib.ServerName + isLocalServerName func(gomatrixserverlib.ServerName) bool + preferServer map[gomatrixserverlib.ServerName]bool + bwExtrems map[string][]string // per-request state servers []gomatrixserverlib.ServerName @@ -255,7 +258,9 @@ type backfillRequester struct { } func newBackfillRequester( - db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, thisServer gomatrixserverlib.ServerName, + db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, + virtualHost gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName, ) *backfillRequester { preferServer := make(map[gomatrixserverlib.ServerName]bool) @@ -265,7 +270,8 @@ func newBackfillRequester( return &backfillRequester{ db: db, fsAPI: fsAPI, - thisServer: thisServer, + virtualHost: virtualHost, + isLocalServerName: isLocalServerName, eventIDToBeforeStateIDs: make(map[string][]string), eventIDMap: make(map[string]*gomatrixserverlib.Event), bwExtrems: bwExtrems, @@ -450,7 +456,7 @@ FindSuccessor: } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost) b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") @@ -477,7 +483,7 @@ FindSuccessor: } var servers []gomatrixserverlib.ServerName for server := range serverSet { - if server == b.thisServer { + if b.isLocalServerName(server) { continue } if b.preferServer[server] { // insert at the front @@ -496,10 +502,10 @@ FindSuccessor: // Backfill performs a backfill request to the given server. // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid -func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, +func (b *backfillRequester) Backfill(ctx context.Context, origin, server gomatrixserverlib.ServerName, roomID string, limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { - tx, err := b.fsAPI.Backfill(ctx, server, roomID, limit, fromEventIDs) + tx, err := b.fsAPI.Backfill(ctx, origin, server, roomID, limit, fromEventIDs) return tx, err } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 9d596ab30..4de008c66 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -39,11 +39,10 @@ import ( ) type Joiner struct { - ServerName gomatrixserverlib.ServerName - Cfg *config.RoomServer - FSAPI fsAPI.RoomserverFederationAPI - RSAPI rsAPI.RoomserverInternalAPI - DB storage.Database + Cfg *config.RoomServer + FSAPI fsAPI.RoomserverFederationAPI + RSAPI rsAPI.RoomserverInternalAPI + DB storage.Database Inputer *input.Inputer Queryer *query.Queryer @@ -197,7 +196,7 @@ func (r *Joiner) performJoinRoomByID( // Prepare the template for the join event. userID := req.UserID - _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) if err != nil { return "", "", &rsAPI.PerformError{ Code: rsAPI.PerformErrorBadRequest, @@ -283,7 +282,7 @@ func (r *Joiner) performJoinRoomByID( // locally on the homeserver. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) + event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, userDomain, &eb) switch err { case nil: @@ -410,7 +409,9 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( } func buildEvent( - ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, + ctx context.Context, db storage.Database, cfg *config.Global, + senderDomain gomatrixserverlib.ServerName, + builder *gomatrixserverlib.EventBuilder, ) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) if err != nil { @@ -438,7 +439,12 @@ func buildEvent( } } - ev, err := eventutil.BuildEvent(ctx, builder, cfg, time.Now(), &eventsNeeded, &queryRes) + identity, err := cfg.SigningIdentityFor(senderDomain) + if err != nil { + return nil, nil, err + } + + ev, err := eventutil.BuildEvent(ctx, builder, cfg, identity, time.Now(), &eventsNeeded, &queryRes) if err != nil { return nil, nil, err } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 49e4b479a..fa998e3e1 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -162,21 +162,21 @@ func (r *Leaver) performLeaveRoomByID( return nil, fmt.Errorf("eb.SetUnsigned: %w", err) } + // Get the sender domain. + _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', eb.Sender) + if serr != nil { + return nil, fmt.Errorf("sender %q is invalid", eb.Sender) + } + // We know that the user is in the room at this point so let's build // a leave event. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) + event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, senderDomain, &eb) if err != nil { return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) } - // Get the sender domain. - _, senderDomain, serr := gomatrixserverlib.SplitID('@', event.Sender()) - if serr != nil { - return nil, fmt.Errorf("sender %q is invalid", event.Sender()) - } - // Give our leave event to the roomserver input stream. The // roomserver will process the membership change and notify // downstream automatically. diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 38abe323c..02a19911c 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -60,7 +60,7 @@ func (r *Upgrader) performRoomUpgrade( ) (string, *api.PerformError) { roomID := req.RoomID userID := req.UserID - _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) if err != nil { return "", &api.PerformError{ Code: api.PerformErrorNotAllowed, @@ -558,7 +558,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user SendAsServer: api.DoNotSendToOtherServers, }) } - if err = api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { + if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err), } @@ -595,8 +595,21 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err), } } + // Get the sender domain. + _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', builder.Sender) + if serr != nil { + return nil, &api.PerformError{ + Msg: fmt.Sprintf("Failed to split user ID %q: %s", builder.Sender, err), + } + } + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + return nil, &api.PerformError{ + Msg: fmt.Sprintf("Failed to get signing identity for %q: %s", senderDomain, err), + } + } var queryRes api.QueryLatestEventsAndStateResponse - headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, evTime, r.URSAPI, &queryRes) + headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &api.PerformError{ Code: api.PerformErrorNoRoom, @@ -686,7 +699,7 @@ func (r *Upgrader) sendHeaderedEvent( Origin: serverName, SendAsServer: sendAsServer, }) - if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { + if err := api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false); err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err), } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 8850e5c46..d8456fb43 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -37,10 +37,10 @@ import ( ) type Queryer struct { - DB storage.Database - Cache caching.RoomServerCaches - ServerName gomatrixserverlib.ServerName - ServerACLs *acls.ServerACLs + DB storage.Database + Cache caching.RoomServerCaches + IsLocalServerName func(gomatrixserverlib.ServerName) bool + ServerACLs *acls.ServerACLs } // QueryLatestEventsAndState implements api.RoomserverInternalAPI @@ -392,7 +392,7 @@ func (r *Queryer) QueryServerJoinedToRoom( } response.RoomExists = true - if request.ServerName == r.ServerName || request.ServerName == "" { + if r.IsLocalServerName(request.ServerName) || request.ServerName == "" { response.IsInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) if err != nil { return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 4e98af853..24b5515e5 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -44,7 +44,7 @@ func Test_SharedUsers(t *testing.T) { // SetFederationAPI starts the room event input consumer rsAPI.SetFederationAPI(nil, nil) // Create the room - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } diff --git a/setup/base/base.go b/setup/base/base.go index 2e3a3a195..14edadd96 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -364,10 +364,10 @@ func (b *BaseDendrite) CreateClient() *gomatrixserverlib.Client { // CreateFederationClient creates a new federation client. Should only be called // once per component. func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { + identities := b.Cfg.Global.SigningIdentities() if b.Cfg.Global.DisableFederation { return gomatrixserverlib.NewFederationClient( - b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey, - gomatrixserverlib.WithTransport(noOpHTTPTransport), + identities, gomatrixserverlib.WithTransport(noOpHTTPTransport), ) } opts := []gomatrixserverlib.ClientOption{ @@ -379,8 +379,7 @@ func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationCli opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache)) } client := gomatrixserverlib.NewFederationClient( - b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, - b.Cfg.Global.PrivateKey, opts..., + identities, opts..., ) client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())) return client diff --git a/setup/config/config.go b/setup/config/config.go index e99852ec9..918bcbe3b 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -231,6 +231,21 @@ func loadConfig( return nil, err } + for _, v := range c.Global.VirtualHosts { + if v.KeyValidityPeriod == 0 { + v.KeyValidityPeriod = c.Global.KeyValidityPeriod + } + if v.PrivateKeyPath == "" { + v.KeyID = c.Global.KeyID + v.PrivateKey = c.Global.PrivateKey + continue + } + privateKeyPath := absPath(basePath, v.PrivateKeyPath) + if v.KeyID, v.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { + return nil, err + } + } + for _, key := range c.Global.OldVerifyKeys { switch { case key.PrivateKeyPath != "": diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 825772827..f2fdd021e 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "math/rand" "strconv" "strings" @@ -15,7 +16,7 @@ type Global struct { ServerName gomatrixserverlib.ServerName `yaml:"server_name"` // The secondary server names, used for virtual hosting. - SecondaryServerNames []gomatrixserverlib.ServerName `yaml:"-"` + VirtualHosts []*VirtualHost `yaml:"virtual_hosts"` // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` @@ -114,6 +115,10 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { checkNotEmpty(configErrs, "global.server_name", string(c.ServerName)) checkNotEmpty(configErrs, "global.private_key", string(c.PrivateKeyPath)) + for _, v := range c.VirtualHosts { + v.Verify(configErrs) + } + c.JetStream.Verify(configErrs, isMonolith) c.Metrics.Verify(configErrs, isMonolith) c.Sentry.Verify(configErrs, isMonolith) @@ -127,14 +132,85 @@ func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool if c.ServerName == serverName { return true } - for _, secondaryName := range c.SecondaryServerNames { - if secondaryName == serverName { + for _, v := range c.VirtualHosts { + if v.ServerName == serverName { return true } } return false } +func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib.ServerName, error) { + u, s, err := gomatrixserverlib.SplitID(sigil, id) + if err != nil { + return u, s, err + } + if !c.IsLocalServerName(s) { + return u, s, fmt.Errorf("server name %q not known", s) + } + return u, s, nil +} + +func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.SigningIdentity, error) { + for _, id := range c.SigningIdentities() { + if id.ServerName == serverName { + return id, nil + } + } + return nil, fmt.Errorf("no signing identity %q", serverName) +} + +func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { + identities := make([]*gomatrixserverlib.SigningIdentity, 0, len(c.VirtualHosts)+1) + identities = append(identities, &gomatrixserverlib.SigningIdentity{ + ServerName: c.ServerName, + KeyID: c.KeyID, + PrivateKey: c.PrivateKey, + }) + for _, v := range c.VirtualHosts { + identities = append(identities, v.SigningIdentity()) + } + return identities +} + +type VirtualHost struct { + // The server name of the virtual host. + ServerName gomatrixserverlib.ServerName `yaml:"server_name"` + + // The key ID of the private key. If not specified, the default global key ID + // will be used instead. + KeyID gomatrixserverlib.KeyID `yaml:"key_id"` + + // Path to the private key. If not specified, the default global private key + // will be used instead. + PrivateKeyPath Path `yaml:"private_key"` + + // The private key itself. + PrivateKey ed25519.PrivateKey `yaml:"-"` + + // How long a remote server can cache our server key for before requesting it again. + // Increasing this number will reduce the number of requests made by remote servers + // for our key, but increases the period a compromised key will be considered valid + // by remote servers. + // Defaults to 24 hours. + KeyValidityPeriod time.Duration `yaml:"key_validity_period"` + + // Is registration enabled on this virtual host? + AllowRegistration bool `json:"allow_registration"` +} + +func (v *VirtualHost) Verify(configErrs *ConfigErrors) { + checkNotEmpty(configErrs, "virtual_host.*.server_name", string(v.ServerName)) +} + +func (v *VirtualHost) SigningIdentity() *gomatrixserverlib.SigningIdentity { + return &gomatrixserverlib.SigningIdentity{ + ServerName: v.ServerName, + KeyID: v.KeyID, + PrivateKey: v.PrivateKey, + } +} + type OldVerifyKeys struct { // Path to the private key. PrivateKeyPath Path `yaml:"private_key"` diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 98502f5cb..bc369c166 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -397,7 +397,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen serversToQuery := rc.getServersForEventID(parentID) var result *MSC2836EventRelationshipsResponse for _, srv := range serversToQuery { - res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: parentID, Direction: "down", Limit: 100, @@ -484,7 +484,7 @@ func walkThread( // MSC2836EventRelationships performs an /event_relationships request to a remote server func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { - res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: eventID, DepthFirst: rc.req.DepthFirst, Direction: rc.req.Direction, @@ -665,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo }) } // we've got the data by this point so use a background context - err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) + err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, rc.serverName, ires, false) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") } diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index d7c022544..56c063598 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -433,7 +433,7 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) *gomatrixserver if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) + res, err := w.fsAPI.MSC2946Spaces(ctx, w.thisServer, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 96ebba7ef..92f081500 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -28,22 +28,20 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // OutputKeyChangeEventConsumer consumes events that originated in the key server. type OutputKeyChangeEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - notifier *notifier.Notifier - stream streams.StreamProvider - serverName gomatrixserverlib.ServerName // our server name - rsAPI roomserverAPI.SyncRoomserverAPI + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + notifier *notifier.Notifier + stream streams.StreamProvider + rsAPI roomserverAPI.SyncRoomserverAPI } // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. @@ -59,15 +57,14 @@ func NewOutputKeyChangeEventConsumer( stream streams.StreamProvider, ) *OutputKeyChangeEventConsumer { s := &OutputKeyChangeEventConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"), - topic: topic, - db: store, - serverName: cfg.Matrix.ServerName, - rsAPI: rsAPI, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"), + topic: topic, + db: store, + rsAPI: rsAPI, + notifier: notifier, + stream: stream, } return s diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index 8aaa65730..e39d43f94 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -34,14 +34,13 @@ import ( // OutputReceiptEventConsumer consumes events that originated in the EDU server. type OutputReceiptEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream streams.StreamProvider - notifier *notifier.Notifier - serverName gomatrixserverlib.ServerName + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + stream streams.StreamProvider + notifier *notifier.Notifier } // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. @@ -55,14 +54,13 @@ func NewOutputReceiptEventConsumer( stream streams.StreamProvider, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), - durable: cfg.Matrix.JetStream.Durable("SyncAPIReceiptConsumer"), - db: store, - notifier: notifier, - stream: stream, - serverName: cfg.Matrix.ServerName, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), + durable: cfg.Matrix.JetStream.Durable("SyncAPIReceiptConsumer"), + db: store, + notifier: notifier, + stream: stream, } } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index f767615c8..1b67f5684 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -364,11 +364,7 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom // TODO: check that it's a join and not a profile change (means unmarshalling prev_content) if membership == gomatrixserverlib.Join { // check it's a local join - _, domain, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) - if err != nil { - return sp, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } - if domain != s.cfg.Matrix.ServerName { + if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil { return sp, nil } @@ -390,9 +386,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( if msg.Event.StateKey() == nil { return } - if _, serverName, err := gomatrixserverlib.SplitID('@', *msg.Event.StateKey()); err != nil { - return - } else if serverName != s.cfg.Matrix.ServerName { + if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil { return } pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 49d84cca3..356e83263 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -37,15 +37,15 @@ import ( // OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. type OutputSendToDeviceEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - keyAPI keyapi.SyncKeyAPI - serverName gomatrixserverlib.ServerName // our server name - stream streams.StreamProvider - notifier *notifier.Notifier + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + keyAPI keyapi.SyncKeyAPI + isLocalServerName func(gomatrixserverlib.ServerName) bool + stream streams.StreamProvider + notifier *notifier.Notifier } // NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer. @@ -60,15 +60,15 @@ func NewOutputSendToDeviceEventConsumer( stream streams.StreamProvider, ) *OutputSendToDeviceEventConsumer { return &OutputSendToDeviceEventConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), - durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), - db: store, - keyAPI: keyAPI, - serverName: cfg.Matrix.ServerName, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), + db: store, + keyAPI: keyAPI, + isLocalServerName: cfg.Matrix.IsLocalServerName, + notifier: notifier, + stream: stream, } } @@ -89,7 +89,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] log.WithError(err).Errorf("send-to-device: failed to split user id, dropping message") return true } - if domain != s.serverName { + if !s.isLocalServerName(domain) { log.Tracef("ignoring send-to-device event with destination %s", domain) return true } @@ -114,7 +114,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] if output.Type == "m.room_key_request" { requestingDeviceID := gjson.GetBytes(output.SendToDeviceEvent.Content, "requesting_device_id").Str _, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender) - if requestingDeviceID != "" && senderDomain != s.serverName { + if requestingDeviceID != "" && !s.isLocalServerName(senderDomain) { // Mark the requesting device as stale, if we don't know about it. if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{ UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID, diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 86cf8e736..0d740ebfc 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -528,6 +528,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] BackwardsExtremities: backwardsExtremities, Limit: limit, ServerName: r.cfg.Matrix.ServerName, + VirtualHost: r.device.UserDomain(), }, &res) if err != nil { return nil, fmt.Errorf("PerformBackfill failed: %w", err) diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index a4985dbf4..483274481 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -433,7 +433,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { beforeJoinBody := fmt.Sprintf("Before invite in a %s room", tc.historyVisibility) beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": beforeJoinBody}) eventsToSend := append(room.Events(), beforeJoinEv) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } syncUntil(t, base, aliceDev.AccessToken, false, @@ -472,7 +472,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } syncUntil(t, base, aliceDev.AccessToken, false, From 5c9aed6af90dc77a7d82a4d16a165715cdd560ca Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Nov 2022 15:11:08 +0000 Subject: [PATCH 020/124] Update to matrix-org/gomatrixserverlib@900369e --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index fce538b4c..1889bf891 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf + github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 224d5aa86..321f00d4a 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf h1:OPM8urHlGC8m1MfusZg6Oi9zxJiDNuSVh3GQuvZeOHw= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 h1:VapUpY3oSbEGhfSpnnTKh7bz6AA5R4tVB5lwdlcA6jE= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= From 9b8bb55430e658c8752c0b093416eed6d977cb0d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Nov 2022 17:21:16 +0000 Subject: [PATCH 021/124] Don't get blacklisted hosts when querying joined servers (#2880) Otherwise we just waste time/CPU. --- federationapi/api/api.go | 5 +-- federationapi/consumers/keychange.go | 4 +-- federationapi/consumers/presence.go | 2 +- federationapi/internal/query.go | 2 +- federationapi/storage/interface.go | 2 +- .../storage/postgres/joined_hosts_table.go | 31 +++++++++++++------ federationapi/storage/postgres/storage.go | 8 ++--- federationapi/storage/shared/storage.go | 4 +-- .../storage/sqlite3/joined_hosts_table.go | 15 +++++++-- federationapi/storage/sqlite3/storage.go | 8 ++--- federationapi/storage/tables/interface.go | 2 +- roomserver/internal/input/input_events.go | 5 +-- 12 files changed, 56 insertions(+), 32 deletions(-) diff --git a/federationapi/api/api.go b/federationapi/api/api.go index e34c9e8b6..f4be53b9a 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -198,8 +198,9 @@ type PerformInviteResponse struct { // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomRequest struct { - RoomID string `json:"room_id"` - ExcludeSelf bool `json:"exclude_self"` + RoomID string `json:"room_id"` + ExcludeSelf bool `json:"exclude_self"` + ExcludeBlacklisted bool `json:"exclude_blacklisted"` } // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 7d1ae0f81..601257d4b 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { sentry.CaptureException(err) logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") @@ -189,7 +189,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { return true } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { sentry.CaptureException(err) logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index 153fc40b5..29b16f373 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -111,7 +111,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg } // send this presence to all servers who share rooms with this user. - joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { log.WithError(err).Error("failed to get joined hosts") return true diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index b0a76eeb7..688afa8ea 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom( request *api.QueryJoinedHostServerNamesInRoomRequest, response *api.QueryJoinedHostServerNamesInRoomResponse, ) (err error) { - joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf) + joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf, request.ExcludeBlacklisted) if err != nil { return } diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 09098cd1e..b15b8bfae 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -32,7 +32,7 @@ type Database interface { GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. - GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) + GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationapi/storage/postgres/joined_hosts_table.go b/federationapi/storage/postgres/joined_hosts_table.go index 5c95b72a8..9a3977560 100644 --- a/federationapi/storage/postgres/joined_hosts_table.go +++ b/federationapi/storage/postgres/joined_hosts_table.go @@ -66,14 +66,20 @@ const selectAllJoinedHostsSQL = "" + const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" +const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id = ANY($1) AND NOT EXISTS (" + + " SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" + + ");" + type joinedHostsStatements struct { - db *sql.DB - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - deleteJoinedHostsForRoomStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt - selectJoinedHostsForRoomsStmt *sql.Stmt + db *sql.DB + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + deleteJoinedHostsForRoomStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + selectJoinedHostsForRoomsStmt *sql.Stmt + selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt } func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -102,6 +108,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { return } + if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil { + return + } return } @@ -167,9 +176,13 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( } func (s *joinedHostsStatements) SelectJoinedHostsForRooms( - ctx context.Context, roomIDs []string, + ctx context.Context, roomIDs []string, excludingBlacklisted bool, ) ([]gomatrixserverlib.ServerName, error) { - rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + stmt := s.selectJoinedHostsForRoomsStmt + if excludingBlacklisted { + stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt + } + rows, err := stmt.QueryContext(ctx, pq.StringArray(roomIDs)) if err != nil { return nil, err } diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index a33fa4a43..fe84e932e 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -42,6 +42,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { return nil, err } + blacklist, err := NewPostgresBlacklistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewPostgresJoinedHostsTable(d.db) if err != nil { return nil, err @@ -58,10 +62,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } - blacklist, err := NewPostgresBlacklistTable(d.db) - if err != nil { - return nil, err - } inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) if err != nil { return nil, err diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 4fabff7d4..7d6e8a684 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -117,8 +117,8 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) { - servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) +func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { + servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err } diff --git a/federationapi/storage/sqlite3/joined_hosts_table.go b/federationapi/storage/sqlite3/joined_hosts_table.go index e0e0f2873..83112c150 100644 --- a/federationapi/storage/sqlite3/joined_hosts_table.go +++ b/federationapi/storage/sqlite3/joined_hosts_table.go @@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" + const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" +const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id IN ($1) AND NOT EXISTS (" + + " SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" + + ");" + type joinedHostsStatements struct { db *sql.DB insertJoinedHostsStmt *sql.Stmt @@ -74,6 +79,7 @@ type joinedHostsStatements struct { selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic + // selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt - prepared at runtime due to variadic } func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -168,14 +174,17 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( } func (s *joinedHostsStatements) SelectJoinedHostsForRooms( - ctx context.Context, roomIDs []string, + ctx context.Context, roomIDs []string, excludingBlacklisted bool, ) ([]gomatrixserverlib.ServerName, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i := range roomIDs { iRoomIDs[i] = roomIDs[i] } - - sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + query := selectJoinedHostsForRoomsSQL + if excludingBlacklisted { + query = selectJoinedHostsForRoomsExcludingBlacklistedSQL + } + sql := strings.Replace(query, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) if err != nil { return nil, err diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index e86ac817b..d13b5defc 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -41,6 +41,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { return nil, err } + blacklist, err := NewSQLiteBlacklistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) if err != nil { return nil, err @@ -57,10 +61,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } - blacklist, err := NewSQLiteBlacklistTable(d.db) - if err != nil { - return nil, err - } outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) if err != nil { return nil, err diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 3c116a1d0..9f4e86a6e 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -58,7 +58,7 @@ type FederationJoinedHosts interface { SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) + SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error) } type FederationBlacklist interface { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 72c285f8e..10b8ee27f 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -165,8 +165,9 @@ func (r *Inputer) processRoomEvent( if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: event.RoomID(), - ExcludeSelf: true, + RoomID: event.RoomID(), + ExcludeSelf: true, + ExcludeBlacklisted: true, } if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) From deddf686b9e9be236ca4e3cc265c7cfba38ed4ee Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 09:16:07 +0000 Subject: [PATCH 022/124] Tweak `/key/v2/server` --- federationapi/routing/keys.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 4fd3720f1..693f5b9fe 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -146,10 +146,10 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam var keys gomatrixserverlib.ServerKeys var virtualHost *config.VirtualHost for _, v := range cfg.Matrix.VirtualHosts { - if v.ServerName != serverName { - continue + if v.ServerName == serverName { + virtualHost = v + break } - virtualHost = v } if virtualHost == nil { @@ -188,14 +188,15 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam return nil, err } - keys.Raw, err = gomatrixserverlib.SignJSON( - string(serverName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, - ) + identity, err := cfg.Matrix.SigningIdentityFor(serverName) if err != nil { return nil, err } - return &keys, nil + keys.Raw, err = gomatrixserverlib.SignJSON( + string(identity.ServerName), identity.KeyID, identity.PrivateKey, toSign, + ) + return &keys, err } func NotaryKeys( From d558da1c875704b101dbb26d56e08cd0f6369d35 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 09:34:09 +0000 Subject: [PATCH 023/124] Virtual host server name workaround --- federationapi/routing/keys.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 693f5b9fe..8194c9905 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -16,6 +16,7 @@ package routing import ( "encoding/json" + "net" "net/http" "time" @@ -190,7 +191,16 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam identity, err := cfg.Matrix.SigningIdentityFor(serverName) if err != nil { - return nil, err + // TODO: This is a bit of a hack because the Host header can contain a port + // number if it's specified in the well-known file. Try getting a signing + // identity without it to see if that helps. + var h string + if h, _, err = net.SplitHostPort(string(serverName)); err == nil { + identity, err = cfg.Matrix.SigningIdentityFor(gomatrixserverlib.ServerName(h)) + } + if err != nil { + return nil, err + } } keys.Raw, err = gomatrixserverlib.SignJSON( From a2f72dd966fb53b265ed0df0d792027af704a08b Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 09:39:19 +0000 Subject: [PATCH 024/124] Fix slice out of bounds in federation API --- federationapi/storage/shared/storage.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 7d6e8a684..f86ae2e0d 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -125,7 +125,9 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, if excludeSelf { for i, server := range servers { if d.IsLocalServerName(server) { - servers = append(servers[:i], servers[i+1:]...) + copy(servers[:i], servers[i+1:]) + servers = servers[:len(servers)-1] + break } } } From 1e714bc3b641c09d958e78060ff1a7b90c79b500 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 10:05:59 +0000 Subject: [PATCH 025/124] Update to NATS Server 2.9.6 and nats.go 1.20.0 --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 1889bf891..97dcdb293 100644 --- a/go.mod +++ b/go.mod @@ -26,8 +26,8 @@ require ( github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 - github.com/nats-io/nats-server/v2 v2.9.4 - github.com/nats-io/nats.go v1.19.0 + github.com/nats-io/nats-server/v2 v2.9.6 + github.com/nats-io/nats.go v1.20.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 diff --git a/go.sum b/go.sum index 321f00d4a..25331b2f6 100644 --- a/go.sum +++ b/go.sum @@ -385,10 +385,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.4 h1:GvRgv1936J/zYUwMg/cqtYaJ6L+bgeIOIvPslbesdow= -github.com/nats-io/nats-server/v2 v2.9.4/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= -github.com/nats-io/nats.go v1.19.0 h1:H6j8aBnTQFoVrTGB6Xjd903UMdE7jz6DS4YkmAqgZ9Q= -github.com/nats-io/nats.go v1.19.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= +github.com/nats-io/nats-server/v2 v2.9.6 h1:RTtK+rv/4CcliOuqGsy58g7MuWkBaWmF5TUNwuUo9Uw= +github.com/nats-io/nats-server/v2 v2.9.6/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= +github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM= +github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= From a916b041b16e3369c4784b648343c241f4fa6494 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 10:28:22 +0000 Subject: [PATCH 026/124] Detect consumer being deleted in `JetStreamConsumer` --- setup/jetstream/helpers.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 1ec860b04..c1ce9583f 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -2,6 +2,7 @@ package jetstream import ( "context" + "errors" "fmt" "github.com/getsentry/sentry-go" @@ -72,6 +73,9 @@ func JetStreamConsumer( // just timed out and we should try again. continue } + } else if errors.Is(err, nats.ErrConsumerDeleted) { + // The consumer was deleted so stop. + return } else { // Something else went wrong, so we'll panic. sentry.CaptureException(err) From 163dabc498a2b45e9b3f8fee6e24114247192bfa Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 15:10:33 +0000 Subject: [PATCH 027/124] Fix bug in a2f72dd9 --- federationapi/storage/shared/storage.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index f86ae2e0d..1e1ea9e17 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -125,7 +125,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, if excludeSelf { for i, server := range servers { if d.IsLocalServerName(server) { - copy(servers[:i], servers[i+1:]) + copy(servers[i:], servers[i+1:]) servers = servers[:len(servers)-1] break } From df76a172344facfa2d03a910fc4d5b1a7b02dd20 Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 16 Nov 2022 22:02:25 +0000 Subject: [PATCH 028/124] Add test code coverage reporting (#2871) --- .github/workflows/schedules.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index c07917248..e304c2a59 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -53,12 +53,14 @@ jobs: key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go${{ matrix.go }}-test-race- - - run: go test -race ./... + - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt env: POSTGRES_HOST: localhost POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres POSTGRES_DB: dendrite + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 # Dummy step to gate other tests on without repeating the whole list initial-tests-done: From 607819f42507d6a3b18ef7c44f98ed8f862a7f78 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Nov 2022 09:26:56 +0000 Subject: [PATCH 029/124] Fix `/key/v2/server`, add HTTP `Host` matching --- federationapi/routing/keys.go | 31 ++++++++++++++----------------- setup/config/config_global.go | 5 +++++ 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 8194c9905..b2ef1dba4 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -16,7 +16,6 @@ package routing import ( "encoding/json" - "net" "net/http" "time" @@ -146,14 +145,26 @@ func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys var virtualHost *config.VirtualHost +loop: for _, v := range cfg.Matrix.VirtualHosts { if v.ServerName == serverName { virtualHost = v - break + break loop + } + for _, httpHost := range v.MatchHTTPHosts { + if httpHost == serverName { + virtualHost = v + break loop + } } } - if virtualHost == nil { + identity, err := cfg.Matrix.SigningIdentityFor(serverName) + if err != nil { + identity, _ = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName) + } + + if identity.ServerName == serverName { publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) keys.ServerName = cfg.Matrix.ServerName keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) @@ -189,20 +200,6 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam return nil, err } - identity, err := cfg.Matrix.SigningIdentityFor(serverName) - if err != nil { - // TODO: This is a bit of a hack because the Host header can contain a port - // number if it's specified in the well-known file. Try getting a signing - // identity without it to see if that helps. - var h string - if h, _, err = net.SplitHostPort(string(serverName)); err == nil { - identity, err = cfg.Matrix.SigningIdentityFor(gomatrixserverlib.ServerName(h)) - } - if err != nil { - return nil, err - } - } - keys.Raw, err = gomatrixserverlib.SignJSON( string(identity.ServerName), identity.KeyID, identity.PrivateKey, toSign, ) diff --git a/setup/config/config_global.go b/setup/config/config_global.go index f2fdd021e..722230d9a 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -195,6 +195,11 @@ type VirtualHost struct { // Defaults to 24 hours. KeyValidityPeriod time.Duration `yaml:"key_validity_period"` + // Match these HTTP Host headers on the `/key/v2/server` endpoint, this needs + // to match all delegated names, likely including the port number too if + // the well-known delegation includes that also. + MatchHTTPHosts []gomatrixserverlib.ServerName `yaml:"match_http_hosts"` + // Is registration enabled on this virtual host? AllowRegistration bool `json:"allow_registration"` } From 16325203af267685e2d664524a3e2976ca1a4bbf Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Nov 2022 09:32:19 +0000 Subject: [PATCH 030/124] Try that again --- federationapi/routing/keys.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index b2ef1dba4..ee25ffbb2 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -159,12 +159,12 @@ loop: } } - identity, err := cfg.Matrix.SigningIdentityFor(serverName) - if err != nil { - identity, _ = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName) - } - - if identity.ServerName == serverName { + var identity *gomatrixserverlib.SigningIdentity + var err error + if virtualHost == nil { + if identity, err = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName); err != nil { + return nil, err + } publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) keys.ServerName = cfg.Matrix.ServerName keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) @@ -183,6 +183,9 @@ loop: } } } else { + if identity, err = cfg.Matrix.SigningIdentityFor(virtualHost.ServerName); err != nil { + return nil, err + } publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey) keys.ServerName = virtualHost.ServerName keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod)) From ffd8e21ce52bbf542e0e2ed74f032af6a163c56c Mon Sep 17 00:00:00 2001 From: devonh Date: Thu, 17 Nov 2022 15:30:23 +0000 Subject: [PATCH 031/124] Fix nightly code coverage (#2881) --- .github/workflows/schedules.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index e304c2a59..ff4d47187 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -45,6 +45,11 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} + - name: Set up gotestfmt + uses: gotesttools/gotestfmt-action@v2 + with: + # Optional: pass GITHUB_TOKEN to avoid rate limiting. + token: ${{ secrets.GITHUB_TOKEN }} - uses: actions/cache@v3 with: path: | From a8e7ffc7ab147ebced766da8e0e1ebb1d75f846a Mon Sep 17 00:00:00 2001 From: devonh Date: Fri, 18 Nov 2022 00:29:23 +0000 Subject: [PATCH 032/124] Add p2p wakeup broadcast handling to pinecone demos (#2841) Adds wakeup broadcast handling to the pinecone demos. This will reset their blacklist status and interrupt any ongoing federation queue backoffs currently in progress for this peer. The end result is that any queued events will quickly be sent to the peer if they had disconnected while attempting to send events to them. --- build/gobind-pinecone/monolith.go | 35 +++++++++++++++++++++++++ cmd/dendrite-demo-pinecone/main.go | 35 +++++++++++++++++++++++++ federationapi/api/api.go | 13 +++++++++ federationapi/internal/perform.go | 16 ++++++++++- federationapi/inthttp/client.go | 13 +++++++++ federationapi/inthttp/server.go | 5 ++++ federationapi/queue/destinationqueue.go | 31 ++++++++++++++++++---- federationapi/queue/queue.go | 16 ++++++++--- go.mod | 4 ++- go.sum | 16 +++++++++-- 10 files changed, 172 insertions(+), 12 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index adb4e40a6..9100ebf0f 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -40,6 +40,7 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" @@ -58,6 +59,7 @@ import ( pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" + pineconeEvents "github.com/matrix-org/pinecone/router/events" pineconeSessions "github.com/matrix-org/pinecone/sessions" "github.com/matrix-org/pinecone/types" @@ -295,7 +297,12 @@ func (m *DendriteMonolith) Start() { m.logger.SetOutput(BindLogger{}) logrus.SetOutput(BindLogger{}) + pineconeEventChannel := make(chan pineconeEvents.Event) m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) + m.PineconeRouter.EnableHopLimiting() + m.PineconeRouter.EnableWakeupBroadcasts() + m.PineconeRouter.Subscribe(pineconeEventChannel) + m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil) @@ -423,6 +430,34 @@ func (m *DendriteMonolith) Start() { m.logger.Fatal(err) } }() + + go func(ch <-chan pineconeEvents.Event) { + eLog := logrus.WithField("pinecone", "events") + + for event := range ch { + switch e := event.(type) { + case pineconeEvents.PeerAdded: + case pineconeEvents.PeerRemoved: + case pineconeEvents.TreeParentUpdate: + case pineconeEvents.SnakeDescUpdate: + case pineconeEvents.TreeRootAnnUpdate: + case pineconeEvents.SnakeEntryAdded: + case pineconeEvents.SnakeEntryRemoved: + case pineconeEvents.BroadcastReceived: + eLog.Info("Broadcast received from: ", e.PeerID) + + req := &api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + } + res := &api.PerformWakeupServersResponse{} + if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + } + case pineconeEvents.BandwidthReport: + default: + } + } + }(pineconeEventChannel) } func (m *DendriteMonolith) Stop() { diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 6c719a1ee..421b17d56 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -37,6 +37,7 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" @@ -51,6 +52,7 @@ import ( pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" + pineconeEvents "github.com/matrix-org/pinecone/router/events" pineconeSessions "github.com/matrix-org/pinecone/sessions" "github.com/sirupsen/logrus" @@ -155,7 +157,12 @@ func main() { base := base.NewBaseDendrite(cfg, "Monolith") defer base.Close() // nolint: errcheck + pineconeEventChannel := make(chan pineconeEvents.Event) pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) + pRouter.EnableHopLimiting() + pRouter.EnableWakeupBroadcasts() + pRouter.Subscribe(pineconeEventChannel) + pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) pManager := pineconeConnections.NewConnectionManager(pRouter, nil) @@ -293,5 +300,33 @@ func main() { logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) }() + go func(ch <-chan pineconeEvents.Event) { + eLog := logrus.WithField("pinecone", "events") + + for event := range ch { + switch e := event.(type) { + case pineconeEvents.PeerAdded: + case pineconeEvents.PeerRemoved: + case pineconeEvents.TreeParentUpdate: + case pineconeEvents.SnakeDescUpdate: + case pineconeEvents.TreeRootAnnUpdate: + case pineconeEvents.SnakeEntryAdded: + case pineconeEvents.SnakeEntryRemoved: + case pineconeEvents.BroadcastReceived: + eLog.Info("Broadcast received from: ", e.PeerID) + + req := &api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + } + res := &api.PerformWakeupServersResponse{} + if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + } + case pineconeEvents.BandwidthReport: + default: + } + } + }(pineconeEventChannel) + base.WaitForShutdown() } diff --git a/federationapi/api/api.go b/federationapi/api/api.go index f4be53b9a..50d0339e4 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -30,6 +30,12 @@ type FederationInternalAPI interface { request *PerformBroadcastEDURequest, response *PerformBroadcastEDUResponse, ) error + + PerformWakeupServers( + ctx context.Context, + request *PerformWakeupServersRequest, + response *PerformWakeupServersResponse, + ) error } type ClientFederationAPI interface { @@ -214,6 +220,13 @@ type PerformBroadcastEDURequest struct { type PerformBroadcastEDUResponse struct { } +type PerformWakeupServersRequest struct { + ServerNames []gomatrixserverlib.ServerName `json:"server_names"` +} + +type PerformWakeupServersResponse struct { +} + type InputPublicKeysRequest struct { Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"` } diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 603e2a9c9..d86d07e03 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -670,9 +670,23 @@ func (r *FederationInternalAPI) PerformBroadcastEDU( return nil } +// PerformWakeupServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) PerformWakeupServers( + ctx context.Context, + request *api.PerformWakeupServersRequest, + response *api.PerformWakeupServersResponse, +) (err error) { + r.MarkServersAlive(request.ServerNames) + return nil +} + func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { for _, srv := range destinations { - _ = r.db.RemoveServerFromBlacklist(srv) + // Check the statistics cache for the blacklist status to prevent hitting + // the database unnecessarily. + if r.queues.IsServerBlacklisted(srv) { + _ = r.db.RemoveServerFromBlacklist(srv) + } r.queues.RetryServer(srv) } } diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 6c37a1f5a..6eefdc7cd 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -23,6 +23,7 @@ const ( FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest" FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" + FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" @@ -150,6 +151,18 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU( ) } +// Handle an instruction to remove the respective servers from being blacklisted. +func (h *httpFederationInternalAPI) PerformWakeupServers( + ctx context.Context, + request *api.PerformWakeupServersRequest, + response *api.PerformWakeupServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformWakeupServers", h.federationAPIURL+FederationAPIPerformWakeupServers, + h.httpClient, ctx, request, response, + ) +} + type getUserDevices struct { S gomatrixserverlib.ServerName Origin gomatrixserverlib.ServerName diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 7b3edb2a7..21a070392 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -43,6 +43,11 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), ) + internalAPIMux.Handle( + FederationAPIPerformWakeupServers, + httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers), + ) + internalAPIMux.Handle( FederationAPIPerformJoinRequestPath, httputil.MakeInternalRPCAPI( diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 63fc59c89..a4a87fe99 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -141,23 +141,44 @@ func (oq *destinationQueue) handleBackoffNotifier() { } } +// wakeQueueIfEventsPending calls wakeQueueAndNotify only if there are +// pending events or if forceWakeup is true. This prevents starting the +// queue unnecessarily. +func (oq *destinationQueue) wakeQueueIfEventsPending(forceWakeup bool) { + eventsPending := func() bool { + oq.pendingMutex.Lock() + defer oq.pendingMutex.Unlock() + return len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 + } + + // NOTE : Only wakeup and notify the queue if there are pending events + // or if forceWakeup is true. Otherwise there is no reason to start the + // queue goroutine and waste resources. + if forceWakeup || eventsPending() { + logrus.Info("Starting queue due to pending events or forceWakeup") + oq.wakeQueueAndNotify() + } +} + // wakeQueueAndNotify ensures the destination queue is running and notifies it // that there is pending work. func (oq *destinationQueue) wakeQueueAndNotify() { - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() + // NOTE : Send notification before waking queue to prevent a race + // where the queue was running and stops due to a timeout in between + // checking it and sending the notification. // Notify the queue that there are events ready to send. select { case oq.notify <- struct{}{}: default: } + + // Wake up the queue if it's asleep. + oq.wakeQueueIfNeeded() } // wakeQueueIfNeeded will wake up the destination queue if it is -// not already running. If it is running but it is backing off -// then we will interrupt the backoff, causing any federation -// requests to retry. +// not already running. func (oq *destinationQueue) wakeQueueIfNeeded() { // Clear the backingOff flag and update the backoff metrics if it was set. if oq.backingOff.CompareAndSwap(true, false) { diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 31124e0c1..75b1b36be 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -374,14 +374,24 @@ func (oqs *OutgoingQueues) SendEDU( return nil } +// IsServerBlacklisted returns whether or not the provided server is currently +// blacklisted. +func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool { + return oqs.statistics.ForServer(srv).Blacklisted() +} + // RetryServer attempts to resend events to the given server if we had given up. func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { if oqs.disabled { return } - oqs.statistics.ForServer(srv).RemoveBlacklist() + + serverStatistics := oqs.statistics.ForServer(srv) + forceWakeup := serverStatistics.Blacklisted() + serverStatistics.RemoveBlacklist() + serverStatistics.ClearBackoff() + if queue := oqs.getQueue(srv); queue != nil { - queue.statistics.ClearBackoff() - queue.wakeQueueIfNeeded() + queue.wakeQueueIfEventsPending(forceWakeup) } } diff --git a/go.mod b/go.mod index 97dcdb293..adf319127 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 - github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 + github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 github.com/nats-io/nats-server/v2 v2.9.6 @@ -109,6 +109,8 @@ require ( github.com/nats-io/jwt/v2 v2.3.0 // indirect github.com/nats-io/nkeys v0.3.0 // indirect github.com/nats-io/nuid v1.0.1 // indirect + github.com/nxadm/tail v1.4.8 // indirect + github.com/onsi/ginkgo v1.16.5 // indirect github.com/onsi/ginkgo/v2 v2.3.0 // indirect github.com/onsi/gomega v1.22.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect diff --git a/go.sum b/go.sum index 25331b2f6..73741136d 100644 --- a/go.sum +++ b/go.sum @@ -335,8 +335,12 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= +github.com/lucas-clemente/quic-go v0.29.2 h1:O8Mt0O6LpvEW+wfC40vZdcw0DngwYzoxq5xULZNzSI8= +github.com/lucas-clemente/quic-go v0.29.2/go.mod h1:g6/h9YMmLuU54tL1gW25uIi3VlBp3uv+sBihplIuskE= github.com/lucas-clemente/quic-go v0.30.0 h1:nwLW0h8ahVQ5EPTIM7uhl/stHqQDea15oRlYKZmw2O0= github.com/lucas-clemente/quic-go v0.30.0/go.mod h1:ssOrRsOmdxa768Wr78vnh2B8JozgLsMzG/g+0qEC7uk= +github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= @@ -350,8 +354,10 @@ github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 h1:VapUpY3oSbEGhfSpnnTKh7bz6AA5R4tVB5lwdlcA6jE= github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= -github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= +github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6 h1:nAT5w41Q9uWTSnpKW55/hBwP91j2IFYPDRs0jJ8TyFI= +github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= +github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d h1:RRJg8TP3bYYtVyFIMv/ebJR+ZTWv7Xtci5zk1D3GD10= +github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= @@ -402,6 +408,12 @@ github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79/go.mod h1:E26fwEtRNigB github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.3.0 h1:kUMoxMoQG3ogk/QWyKh3zibV7BKZ+xBpWil1cTylVqc= github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= From 8299da590542a982437ad9dd30115d23c3d9d075 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 18 Nov 2022 13:24:02 +0000 Subject: [PATCH 033/124] Fix registration for virtual hosting --- clientapi/auth/login_test.go | 9 +++- clientapi/auth/user_interactive_test.go | 4 +- clientapi/routing/register.go | 52 ++++++++++++++---- clientapi/userutil/userutil_test.go | 16 ++++-- federationapi/routing/keys.go | 17 +----- go.mod | 4 +- go.sum | 16 +----- setup/config/config.go | 2 +- setup/config/config_global.go | 72 ++++++++++++++----------- userapi/userapi_test.go | 4 +- 10 files changed, 113 insertions(+), 83 deletions(-) diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index 5085f0170..b79c573aa 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -66,7 +67,9 @@ func TestLoginFromJSONReader(t *testing.T) { var userAPI fakeUserInternalAPI cfg := &config.ClientAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) @@ -144,7 +147,9 @@ func TestBadLoginFromJSONReader(t *testing.T) { var userAPI fakeUserInternalAPI cfg := &config.ClientAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 001b1a6d4..5d97b31ce 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -47,7 +47,9 @@ func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *a func setup() *UserInteractive { cfg := &config.ClientAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } return NewUserInteractive(&fakeAccountDatabase{}, cfg) diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index a92513b8b..801000f61 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -551,6 +551,12 @@ func Register( } var r registerRequest + host := gomatrixserverlib.ServerName(req.Host) + if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil { + r.ServerName = v.ServerName + } else { + r.ServerName = cfg.Matrix.ServerName + } sessionID := gjson.GetBytes(reqBody, "auth.session").String() if sessionID == "" { // Generate a new, random session ID @@ -560,6 +566,7 @@ func Register( // Some of these might end up being overwritten if the // values are specified again in the request body. r.Username = data.Username + r.ServerName = data.ServerName r.Password = data.Password r.DeviceID = data.DeviceID r.InitialDisplayName = data.InitialDisplayName @@ -575,7 +582,6 @@ func Register( if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } - r.ServerName = cfg.Matrix.ServerName if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { r.Username, r.ServerName = l, d } @@ -650,16 +656,25 @@ func handleGuestRegistration( cfg *config.ClientAPI, userAPI userapi.ClientUserAPI, ) util.JSONResponse { - if cfg.RegistrationDisabled || cfg.GuestsDisabled { + registrationEnabled := !cfg.RegistrationDisabled + guestsEnabled := !cfg.GuestsDisabled + if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil { + registrationEnabled, guestsEnabled = v.RegistrationAllowed() + } + + if !registrationEnabled || !guestsEnabled { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Guest registration is disabled"), + JSON: jsonerror.Forbidden( + fmt.Sprintf("Guest registration is disabled on %q", r.ServerName), + ), } } var res userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeGuest, + ServerName: r.ServerName, }, &res) if err != nil { return util.JSONResponse{ @@ -736,10 +751,16 @@ func handleRegistrationFlow( ) } - if cfg.RegistrationDisabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { + registrationEnabled := !cfg.RegistrationDisabled + if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil { + registrationEnabled, _ = v.RegistrationAllowed() + } + if !registrationEnabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Registration is disabled"), + JSON: jsonerror.Forbidden( + fmt.Sprintf("Registration is disabled on %q", r.ServerName), + ), } } @@ -827,8 +848,9 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, + req.Context(), userAPI, r.Username, r.ServerName, "", appserviceID, req.RemoteAddr, + req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, + userapi.AccountTypeAppService, ) } @@ -846,8 +868,9 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, + req.Context(), userAPI, r.Username, r.ServerName, r.Password, "", req.RemoteAddr, + req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, + userapi.AccountTypeUser, ) } sessions.addParams(sessionID, r) @@ -869,7 +892,8 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.ClientUserAPI, - username, password, appserviceID, ipAddr, userAgent, sessionID string, + username string, serverName gomatrixserverlib.ServerName, + password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, accType userapi.AccountType, @@ -891,6 +915,7 @@ func completeRegistration( err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ AppServiceID: appserviceID, Localpart: username, + ServerName: serverName, Password: password, AccountType: accType, OnConflict: userapi.ConflictAbort, @@ -934,6 +959,7 @@ func completeRegistration( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ Localpart: username, + ServerName: serverName, AccessToken: token, DeviceDisplayName: displayName, DeviceID: deviceID, @@ -1028,6 +1054,10 @@ func RegisterAvailable( // Squash username to all lowercase letters username = strings.ToLower(username) domain := cfg.Matrix.ServerName + host := gomatrixserverlib.ServerName(req.Host) + if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil { + domain = v.ServerName + } if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil { username, domain = u, l } @@ -1117,5 +1147,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/userutil/userutil_test.go b/clientapi/userutil/userutil_test.go index ccd6647b2..ee6bf8a01 100644 --- a/clientapi/userutil/userutil_test.go +++ b/clientapi/userutil/userutil_test.go @@ -30,7 +30,9 @@ var ( // TestGoodUserID checks that correct localpart is returned for a valid user ID. func TestGoodUserID(t *testing.T) { cfg := &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, } lp, _, err := ParseUsernameParam(goodUserID, cfg) @@ -47,7 +49,9 @@ func TestGoodUserID(t *testing.T) { // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. func TestWithLocalpartOnly(t *testing.T) { cfg := &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, } lp, _, err := ParseUsernameParam(localpart, cfg) @@ -64,7 +68,9 @@ func TestWithLocalpartOnly(t *testing.T) { // TestIncorrectDomain checks for error when there's server name mismatch. func TestIncorrectDomain(t *testing.T) { cfg := &config.Global{ - ServerName: invalidServerName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: invalidServerName, + }, } _, _, err := ParseUsernameParam(goodUserID, cfg) @@ -77,7 +83,9 @@ func TestIncorrectDomain(t *testing.T) { // TestBadUserID checks that ParseUsernameParam fails for invalid user ID func TestBadUserID(t *testing.T) { cfg := &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, } _, _, err := ParseUsernameParam(badUserID, cfg) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index ee25ffbb2..dc262cfde 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -144,24 +144,9 @@ func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys - var virtualHost *config.VirtualHost -loop: - for _, v := range cfg.Matrix.VirtualHosts { - if v.ServerName == serverName { - virtualHost = v - break loop - } - for _, httpHost := range v.MatchHTTPHosts { - if httpHost == serverName { - virtualHost = v - break loop - } - } - } - var identity *gomatrixserverlib.SigningIdentity var err error - if virtualHost == nil { + if virtualHost := cfg.Matrix.VirtualHostForHTTPHost(serverName); virtualHost == nil { if identity, err = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName); err != nil { return nil, err } diff --git a/go.mod b/go.mod index adf319127..7961a4b08 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 @@ -109,8 +109,6 @@ require ( github.com/nats-io/jwt/v2 v2.3.0 // indirect github.com/nats-io/nkeys v0.3.0 // indirect github.com/nats-io/nuid v1.0.1 // indirect - github.com/nxadm/tail v1.4.8 // indirect - github.com/onsi/ginkgo v1.16.5 // indirect github.com/onsi/ginkgo/v2 v2.3.0 // indirect github.com/onsi/gomega v1.22.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect diff --git a/go.sum b/go.sum index 73741136d..c5184ff06 100644 --- a/go.sum +++ b/go.sum @@ -335,12 +335,8 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= -github.com/lucas-clemente/quic-go v0.29.2 h1:O8Mt0O6LpvEW+wfC40vZdcw0DngwYzoxq5xULZNzSI8= -github.com/lucas-clemente/quic-go v0.29.2/go.mod h1:g6/h9YMmLuU54tL1gW25uIi3VlBp3uv+sBihplIuskE= github.com/lucas-clemente/quic-go v0.30.0 h1:nwLW0h8ahVQ5EPTIM7uhl/stHqQDea15oRlYKZmw2O0= github.com/lucas-clemente/quic-go v0.30.0/go.mod h1:ssOrRsOmdxa768Wr78vnh2B8JozgLsMzG/g+0qEC7uk= -github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= @@ -352,10 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 h1:VapUpY3oSbEGhfSpnnTKh7bz6AA5R4tVB5lwdlcA6jE= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6 h1:nAT5w41Q9uWTSnpKW55/hBwP91j2IFYPDRs0jJ8TyFI= -github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 h1:S2TNN7C00CZlE1Af31LzxkOsAEkFt0RYZ7/3VdR1D5U= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d h1:RRJg8TP3bYYtVyFIMv/ebJR+ZTWv7Xtci5zk1D3GD10= github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= @@ -408,12 +402,6 @@ github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79/go.mod h1:E26fwEtRNigB github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= -github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= -github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= -github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= -github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.3.0 h1:kUMoxMoQG3ogk/QWyKh3zibV7BKZ+xBpWil1cTylVqc= github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= diff --git a/setup/config/config.go b/setup/config/config.go index 918bcbe3b..7e7ed1aa1 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -235,7 +235,7 @@ func loadConfig( if v.KeyValidityPeriod == 0 { v.KeyValidityPeriod = c.Global.KeyValidityPeriod } - if v.PrivateKeyPath == "" { + if v.PrivateKeyPath == "" || v.PrivateKey == nil || v.KeyID == "" { v.KeyID = c.Global.KeyID v.PrivateKey = c.Global.PrivateKey continue diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 722230d9a..801c68450 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -12,8 +12,9 @@ import ( ) type Global struct { - // The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'. - ServerName gomatrixserverlib.ServerName `yaml:"server_name"` + // Signing identity contains the server name, private key and key ID of + // the deployment. + gomatrixserverlib.SigningIdentity `yaml:",inline"` // The secondary server names, used for virtual hosting. VirtualHosts []*VirtualHost `yaml:"virtual_hosts"` @@ -21,13 +22,6 @@ type Global struct { // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` - // The private key which will be used to sign requests and events. - PrivateKey ed25519.PrivateKey `yaml:"-"` - - // An arbitrary string used to uniquely identify the PrivateKey. Must start with the - // prefix "ed25519:". - KeyID gomatrixserverlib.KeyID `yaml:"-"` - // Information about old private keys that used to be used to sign requests and // events on this domain. They will not be used but will be advertised to other // servers that ask for them to help verify old events. @@ -151,6 +145,29 @@ func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib. return u, s, nil } +func (c *Global) VirtualHost(serverName gomatrixserverlib.ServerName) *VirtualHost { + for _, v := range c.VirtualHosts { + if v.ServerName == serverName { + return v + } + } + return nil +} + +func (c *Global) VirtualHostForHTTPHost(serverName gomatrixserverlib.ServerName) *VirtualHost { + for _, v := range c.VirtualHosts { + if v.ServerName == serverName { + return v + } + for _, h := range v.MatchHTTPHosts { + if h == serverName { + return v + } + } + } + return nil +} + func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.SigningIdentity, error) { for _, id := range c.SigningIdentities() { if id.ServerName == serverName { @@ -162,32 +179,22 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { identities := make([]*gomatrixserverlib.SigningIdentity, 0, len(c.VirtualHosts)+1) - identities = append(identities, &gomatrixserverlib.SigningIdentity{ - ServerName: c.ServerName, - KeyID: c.KeyID, - PrivateKey: c.PrivateKey, - }) + identities = append(identities, &c.SigningIdentity) for _, v := range c.VirtualHosts { - identities = append(identities, v.SigningIdentity()) + identities = append(identities, &v.SigningIdentity) } return identities } type VirtualHost struct { - // The server name of the virtual host. - ServerName gomatrixserverlib.ServerName `yaml:"server_name"` - - // The key ID of the private key. If not specified, the default global key ID - // will be used instead. - KeyID gomatrixserverlib.KeyID `yaml:"key_id"` + // Signing identity contains the server name, private key and key ID of + // the virtual host. + gomatrixserverlib.SigningIdentity `yaml:",inline"` // Path to the private key. If not specified, the default global private key // will be used instead. PrivateKeyPath Path `yaml:"private_key"` - // The private key itself. - PrivateKey ed25519.PrivateKey `yaml:"-"` - // How long a remote server can cache our server key for before requesting it again. // Increasing this number will reduce the number of requests made by remote servers // for our key, but increases the period a compromised key will be considered valid @@ -201,19 +208,24 @@ type VirtualHost struct { MatchHTTPHosts []gomatrixserverlib.ServerName `yaml:"match_http_hosts"` // Is registration enabled on this virtual host? - AllowRegistration bool `json:"allow_registration"` + AllowRegistration bool `yaml:"allow_registration"` + + // Is guest registration enabled on this virtual host? + AllowGuests bool `yaml:"allow_guests"` } func (v *VirtualHost) Verify(configErrs *ConfigErrors) { checkNotEmpty(configErrs, "virtual_host.*.server_name", string(v.ServerName)) } -func (v *VirtualHost) SigningIdentity() *gomatrixserverlib.SigningIdentity { - return &gomatrixserverlib.SigningIdentity{ - ServerName: v.ServerName, - KeyID: v.KeyID, - PrivateKey: v.PrivateKey, +// RegistrationAllowed returns two bools, the first states whether registration +// is allowed for this virtual host and the second states whether guests are +// allowed for this virtual host. +func (v *VirtualHost) RegistrationAllowed() (bool, bool) { + if v == nil { + return false, false } + return v.AllowRegistration, v.AllowGuests } type OldVerifyKeys struct { diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 23acfa6f3..25fa75ee2 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -61,7 +61,9 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap cfg := &config.UserAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } From 7ad87eace3b88b3aaa88dc8a0acec9b64ffcc52c Mon Sep 17 00:00:00 2001 From: devonh Date: Fri, 18 Nov 2022 19:37:13 +0000 Subject: [PATCH 034/124] Update pinecone version (#2884) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 7961a4b08..7b3804a6e 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 - github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d + github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 github.com/nats-io/nats-server/v2 v2.9.6 diff --git a/go.sum b/go.sum index c5184ff06..14a8b0e88 100644 --- a/go.sum +++ b/go.sum @@ -350,8 +350,8 @@ github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 h1:S2TNN7C00CZlE1Af31LzxkOsAEkFt0RYZ7/3VdR1D5U= github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d h1:RRJg8TP3bYYtVyFIMv/ebJR+ZTWv7Xtci5zk1D3GD10= -github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= +github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= +github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= From 31f56ac3f43b692a56718fd1acfcb5a46e485c57 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 22 Nov 2022 21:38:27 +0000 Subject: [PATCH 035/124] Never filter out a user's own membership when using LL (#2887) --- syncapi/streams/stream_pdu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 65ca8e2a3..dd7845574 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -588,7 +588,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( isGappedIncremental := limited && incremental // We want this users membership event, keep it in the list stateKey := *event.StateKey() - if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental { + if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID { newStateEvents = append(newStateEvents, event) if !stateFilter.IncludeRedundantMembers { p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID()) From 5e4b461e0158daa76872956729b73ccdccefbf10 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 28 Nov 2022 11:26:03 +0100 Subject: [PATCH 036/124] Return empty JSON if we don't have any protocols to return (#2892) This should help with Element reporting `The homeserver may be too old to support third party networks.` --- clientapi/routing/thirdparty.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/clientapi/routing/thirdparty.go b/clientapi/routing/thirdparty.go index e757cd411..7a62da449 100644 --- a/clientapi/routing/thirdparty.go +++ b/clientapi/routing/thirdparty.go @@ -36,9 +36,15 @@ func Protocols(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, dev return jsonerror.InternalServerError() } if !resp.Exists { + if protocol != "" { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The protocol is unknown."), + } + } return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The protocol is unknown."), + Code: http.StatusOK, + JSON: struct{}{}, } } if protocol != "" { From f6f1445cfaccc7c4540775d9cb7083522d7c27d2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 09:58:22 +0000 Subject: [PATCH 037/124] Tweak event auth logging and cases (update to matrix-org/gomatrixserverlib@8835f6d) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 7b3804a6e..fd662f6e5 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 14a8b0e88..f72aa739e 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 h1:S2TNN7C00CZlE1Af31LzxkOsAEkFt0RYZ7/3VdR1D5U= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= From 1ed5fb5e98220c2a96f1743490d7fef821bd9114 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 10:37:57 +0000 Subject: [PATCH 038/124] Update NATS Server to 2.9.8 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index fd662f6e5..d3eb4890a 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 - github.com/nats-io/nats-server/v2 v2.9.6 + github.com/nats-io/nats-server/v2 v2.9.8 github.com/nats-io/nats.go v1.20.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index f72aa739e..ad9372c84 100644 --- a/go.sum +++ b/go.sum @@ -385,8 +385,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.6 h1:RTtK+rv/4CcliOuqGsy58g7MuWkBaWmF5TUNwuUo9Uw= -github.com/nats-io/nats-server/v2 v2.9.6/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= +github.com/nats-io/nats-server/v2 v2.9.8 h1:jgxZsv+A3Reb3MgwxaINcNq/za8xZInKhDg9Q0cGN1o= +github.com/nats-io/nats-server/v2 v2.9.8/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM= github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= From 1990c154e920a350654d0ae6e02071950e595a01 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 11:11:08 +0000 Subject: [PATCH 039/124] Update configuration --- setup/config/config_global.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 801c68450..511951fe6 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -17,7 +17,7 @@ type Global struct { gomatrixserverlib.SigningIdentity `yaml:",inline"` // The secondary server names, used for virtual hosting. - VirtualHosts []*VirtualHost `yaml:"virtual_hosts"` + VirtualHosts []*VirtualHost `yaml:"-"` // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` From f8d1dc521d401b78163840d8ff978cb2cd2718d7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 29 Nov 2022 15:46:28 +0100 Subject: [PATCH 040/124] Fix `m.receipt`s causing notifications (#2893) Fixes https://github.com/matrix-org/dendrite/issues/2353 --- syncapi/streams/stream_receipt.go | 3 +- syncapi/types/types.go | 7 +++ syncapi/types/types_test.go | 100 ++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 2 deletions(-) diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index 977815078..16a81e833 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -87,8 +87,7 @@ func (p *ReceiptStreamProvider) IncrementalSync( } ev := gomatrixserverlib.ClientEvent{ - Type: gomatrixserverlib.MReceipt, - RoomID: roomID, + Type: gomatrixserverlib.MReceipt, } content := make(map[string]ReceiptMRead) for _, receipt := range receipts { diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 295187acc..9fbadc06c 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -480,6 +480,13 @@ func (jr JoinResponse) MarshalJSON() ([]byte, error) { if jr.Ephemeral != nil && len(jr.Ephemeral.Events) == 0 { a.Ephemeral = nil } + if jr.Ephemeral != nil { + // Remove the room_id from EDUs, as this seems to cause Element Web + // to trigger notifications - https://github.com/vector-im/element-web/issues/17263 + for i := range jr.Ephemeral.Events { + jr.Ephemeral.Events[i].RoomID = "" + } + } if jr.AccountData != nil && len(jr.AccountData.Events) == 0 { a.AccountData = nil } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 19fcfc150..74246d964 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -2,6 +2,7 @@ package types import ( "encoding/json" + "reflect" "testing" "github.com/matrix-org/gomatrixserverlib" @@ -63,3 +64,102 @@ func TestNewInviteResponse(t *testing.T) { t.Fatalf("Invite response didn't contain correct info") } } + +func TestJoinResponse_MarshalJSON(t *testing.T) { + type fields struct { + Summary *Summary + State *ClientEvents + Timeline *Timeline + Ephemeral *ClientEvents + AccountData *ClientEvents + UnreadNotifications *UnreadNotifications + } + tests := []struct { + name string + fields fields + want []byte + wantErr bool + }{ + { + name: "empty state is removed", + fields: fields{ + State: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty accountdata is removed", + fields: fields{ + AccountData: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty ephemeral is removed", + fields: fields{ + Ephemeral: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty timeline is removed", + fields: fields{ + Timeline: &Timeline{}, + }, + want: []byte("{}"), + }, + { + name: "empty summary is removed", + fields: fields{ + Summary: &Summary{}, + }, + want: []byte("{}"), + }, + { + name: "unread notifications are removed, if everything else is empty", + fields: fields{ + UnreadNotifications: &UnreadNotifications{}, + }, + want: []byte("{}"), + }, + { + name: "unread notifications are NOT removed, if state is set", + fields: fields{ + State: &ClientEvents{Events: []gomatrixserverlib.ClientEvent{{Content: []byte("{}")}}}, + UnreadNotifications: &UnreadNotifications{NotificationCount: 1}, + }, + want: []byte(`{"state":{"events":[{"content":{},"type":""}]},"unread_notifications":{"highlight_count":0,"notification_count":1}}`), + }, + { + name: "roomID is removed from EDUs", + fields: fields{ + Ephemeral: &ClientEvents{ + Events: []gomatrixserverlib.ClientEvent{ + {RoomID: "!someRandomRoomID:test", Content: []byte("{}")}, + }, + }, + }, + want: []byte(`{"ephemeral":{"events":[{"content":{},"type":""}]}}`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jr := JoinResponse{ + Summary: tt.fields.Summary, + State: tt.fields.State, + Timeline: tt.fields.Timeline, + Ephemeral: tt.fields.Ephemeral, + AccountData: tt.fields.AccountData, + UnreadNotifications: tt.fields.UnreadNotifications, + } + got, err := jr.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", string(got), string(tt.want)) + } + }) + } +} From ed497aa8b2d24b5d5ac00df773dab5087ddcf5ca Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 16:26:33 +0000 Subject: [PATCH 041/124] Version 0.10.8 --- .github/workflows/docker.yml | 48 ++++++++++++++++++------------------ CHANGES.md | 23 +++++++++++++++++ internal/version.go | 2 +- 3 files changed, 48 insertions(+), 25 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 846844173..2e17539d8 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -68,18 +68,6 @@ jobs: ${{ env.DOCKER_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} - - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@master - with: - image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} - format: "sarif" - output: "trivy-results.sarif" - - - name: Upload Trivy scan results to GitHub Security tab - uses: github/codeql-action/upload-sarif@v2 - with: - sarif_file: "trivy-results.sarif" - - name: Build release monolith image if: github.event_name == 'release' # Only for GitHub releases id: docker_build_monolith_release @@ -98,6 +86,18 @@ jobs: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ env.RELEASE_VERSION }} + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@master + with: + image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} + format: "sarif" + output: "trivy-results.sarif" + + - name: Upload Trivy scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: "trivy-results.sarif" + polylith: name: Polylith image runs-on: ubuntu-latest @@ -148,18 +148,6 @@ jobs: ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@master - with: - image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - format: "sarif" - output: "trivy-results.sarif" - - - name: Upload Trivy scan results to GitHub Security tab - uses: github/codeql-action/upload-sarif@v2 - with: - sarif_file: "trivy-results.sarif" - - name: Build release polylith image if: github.event_name == 'release' # Only for GitHub releases id: docker_build_polylith_release @@ -178,6 +166,18 @@ jobs: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@master + with: + image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} + format: "sarif" + output: "trivy-results.sarif" + + - name: Upload Trivy scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: "trivy-results.sarif" + demo-pinecone: name: Pinecone demo image runs-on: ubuntu-latest diff --git a/CHANGES.md b/CHANGES.md index cdeb1dea3..f5a82cfe2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,28 @@ # Changelog +## Dendrite 0.10.8 (2022-11-29) + +### Features + +* The built-in NATS Server has been updated to version 2.9.8 +* A number of under-the-hood changes have been merged for future virtual hosting support in Dendrite (running multiple domain names on the same Dendrite deployment) + +### Fixes + +* Event auth handling of invites has been refactored, which should fix some edge cases being handled incorrectly +* Fix a bug when returning an empty protocol list, which could cause Element to display "The homeserver may be too old to support third party networks" when opening the public room directory +* The sync API will no longer filter out the user's own membership when using lazy-loading +* Dendrite will now correctly detect JetStream consumers being deleted, stopping the consumer goroutine as needed +* A panic in the federation API where the server list could go out of bounds has been fixed +* Blacklisted servers will now be excluded when querying joined servers, which improves CPU usage and performs less unnecessary outbound requests +* A database writer will now be used to assign state key NIDs when requesting NIDs that may not exist yet +* Dendrite will now correctly move local aliases for an upgraded room when the room is upgraded remotely +* Dendrite will now correctly move account data for an upgraded room when the room is upgraded remotely +* Missing state key NIDs will now be allocated on request rather than returning an error +* Guest access is now correctly denied on a number of endpoints +* Presence information will now be correctly sent for new private chats +* A number of unspecced fields have been removed from outbound `/send` transactions + ## Dendrite 0.10.7 (2022-11-04) ### Features diff --git a/internal/version.go b/internal/version.go index 85b19046e..685237b9e 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 10 - VersionPatch = 7 + VersionPatch = 8 VersionTag = "" // example: "rc1" ) From ac5f3f025eaf157d2a2bb15adcd19a5732325608 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 30 Nov 2022 12:40:36 +0100 Subject: [PATCH 042/124] Calculate correct room member count for push rule evaluation (#2894) Fixes a bug where we would return only the local member count, which could result in wrongly calculated push rules. --- userapi/consumers/roomserver.go | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 5d8924dda..3ce5af621 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -385,7 +385,6 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s req := &rsapi.QueryMembershipsForRoomRequest{ RoomID: roomID, JoinedOnly: true, - LocalOnly: true, } var res rsapi.QueryMembershipsForRoomResponse @@ -396,8 +395,23 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s } var members []*localMembership - var ntotal int for _, event := range res.JoinEvents { + // Filter out invalid join events + if event.StateKey == nil { + continue + } + if *event.StateKey == "" { + continue + } + _, serverName, err := gomatrixserverlib.SplitID('@', *event.StateKey) + if err != nil { + log.WithError(err).Error("failed to get servername from statekey") + continue + } + // Only get memberships for our server + if serverName != s.serverName { + continue + } member, err := newLocalMembership(&event) if err != nil { log.WithError(err).Errorf("Parsing MemberContent") @@ -410,11 +424,10 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s continue } - ntotal++ members = append(members, member) } - return members, ntotal, nil + return members, len(res.JoinEvents), nil } // roomName returns the name in the event (if type==m.room.name), or @@ -641,7 +654,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * if rule == nil { // SPEC: If no rules match an event, the homeserver MUST NOT // notify the Push Gateway for that event. - return nil, err + return nil, nil } log.WithFields(log.Fields{ From f009e541816cbae1519fede2b40b0af68020b3ab Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 30 Nov 2022 12:54:37 +0000 Subject: [PATCH 043/124] Push rule evaluation tweaks (#2897) This tweaks push rule evaluation: 1. to be more strict around pattern matching and to not match empty patterns 3. to bail if we come across a `dont_notify`, since cycles after that are wasted 4. refactors `ActionsToTweaks` to make a bit more sense --- internal/pushrules/evaluate.go | 13 ++++++++++++ internal/pushrules/evaluate_test.go | 13 ++++++------ internal/pushrules/util.go | 31 +++++++++++++++++------------ internal/pushrules/util_test.go | 1 + 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index df22cb042..4ff9939a6 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -145,6 +145,11 @@ func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec Evalua } func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) { + // It doesn't make sense for an empty pattern to match anything. + if pattern == "" { + return false, nil + } + re, err := globToRegexp(pattern) if err != nil { return false, err @@ -154,12 +159,20 @@ func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, if err = json.Unmarshal(event.JSON(), &eventMap); err != nil { return false, fmt.Errorf("parsing event: %w", err) } + // From the spec: + // "If the property specified by key is completely absent from + // the event, or does not have a string value, then the condition + // will not match, even if pattern is *." v, err := lookupMapPath(strings.Split(key, "."), eventMap) if err != nil { // An unknown path is a benign error that shouldn't stop rule // processing. It's just a non-match. return false, nil } + if _, ok := v.(string); !ok { + // A non-string never matches. + return false, nil + } return re.MatchString(fmt.Sprint(v)), nil } diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index eabd02415..c5d5abd2a 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -111,7 +111,10 @@ func TestConditionMatches(t *testing.T) { {"empty", Condition{}, `{}`, false}, {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, - {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true}, + // Neither of these should match because `content` is not a full string match, + // and `content.body` is not a string value. + {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, false}, + {"eventBodyMatch", Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3"}, `{"content":{"body": 3}}`, false}, {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, @@ -137,7 +140,7 @@ func TestConditionMatches(t *testing.T) { t.Fatalf("conditionMatches failed: %v", err) } if got != tst.Want { - t.Errorf("conditionMatches: got %v, want %v", got, tst.Want) + t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.Want, tst.Name) } }) } @@ -161,9 +164,7 @@ func TestPatternMatches(t *testing.T) { }{ {"empty", "", "", `{}`, false}, - // Note that an empty pattern contains no wildcard characters, - // which implicitly means "*". - {"patternEmpty", "content", "", `{"content":{}}`, true}, + {"patternEmpty", "content", "", `{"content":{}}`, false}, {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true}, {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true}, @@ -178,7 +179,7 @@ func TestPatternMatches(t *testing.T) { t.Fatalf("patternMatches failed: %v", err) } if got != tst.Want { - t.Errorf("patternMatches: got %v, want %v", got, tst.Want) + t.Errorf("patternMatches: got %v, want %v on %s", got, tst.Want, tst.Name) } }) } diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go index 8ab4eab94..fb9c05be2 100644 --- a/internal/pushrules/util.go +++ b/internal/pushrules/util.go @@ -11,22 +11,27 @@ import ( // kind and a tweaks map. Returns a nil map if it would have been // empty. func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) { - kind := UnknownAction - tweaks := map[string]interface{}{} + var kind ActionKind + var tweaks map[string]interface{} for _, a := range as { - if a.Kind == SetTweakAction { - tweaks[string(a.Tweak)] = a.Value - continue - } - if kind != UnknownAction { - return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) - } - kind = a.Kind - } + switch a.Kind { + case DontNotifyAction: + // Don't bother processing any further + return DontNotifyAction, nil, nil - if len(tweaks) == 0 { - tweaks = nil + case SetTweakAction: + if tweaks == nil { + tweaks = map[string]interface{}{} + } + tweaks[string(a.Tweak)] = a.Value + + default: + if kind != UnknownAction { + return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) + } + kind = a.Kind + } } return kind, tweaks, nil diff --git a/internal/pushrules/util_test.go b/internal/pushrules/util_test.go index a951c55a2..89f8243d9 100644 --- a/internal/pushrules/util_test.go +++ b/internal/pushrules/util_test.go @@ -17,6 +17,7 @@ func TestActionsToTweaks(t *testing.T) { {"empty", nil, UnknownAction, nil}, {"zero", []*Action{{}}, UnknownAction, nil}, {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil}, + {"onlyPrimaryDontNotify", []*Action{{Kind: DontNotifyAction}}, DontNotifyAction, nil}, {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}}, {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}}, { From 6f000e980155d6651b01c9b4dbdb316e9f501a45 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 1 Dec 2022 10:14:26 +0000 Subject: [PATCH 044/124] Make `create-account` more verbose --- cmd/create-account/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index c8e239f29..15b043ed5 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -177,7 +177,7 @@ func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, a defer regResp.Body.Close() // nolint: errcheck if regResp.StatusCode < 200 || regResp.StatusCode >= 300 { body, _ = io.ReadAll(regResp.Body) - return "", fmt.Errorf(gjson.GetBytes(body, "error").Str) + return "", fmt.Errorf("got HTTP %d error from server: %s", regResp.StatusCode, string(body)) } r, err := io.ReadAll(regResp.Body) if err != nil { From 1be0afa1810ecbfb8348a81c8b49ef2a15114055 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 1 Dec 2022 10:24:17 +0000 Subject: [PATCH 045/124] Expose `/_dendrite` and `/_synapse` on the P2P demo HTTP muxes --- build/gobind-pinecone/monolith.go | 2 ++ build/gobind-yggdrasil/monolith.go | 2 ++ cmd/dendrite-demo-pinecone/main.go | 2 ++ cmd/dendrite-demo-yggdrasil/main.go | 2 ++ 4 files changed, 8 insertions(+) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 9100ebf0f..b2fb70ce4 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -382,6 +382,8 @@ func (m *DendriteMonolith) Start() { httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) httpRouter.HandleFunc("/pinecone", m.PineconeRouter.ManholeHandler) pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 248b6c324..4cbe983ec 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -196,6 +196,8 @@ func (m *DendriteMonolith) Start() { httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) yggRouter := mux.NewRouter() yggRouter.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 421b17d56..2ceb7c17b 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -270,6 +270,8 @@ func main() { pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + pMux.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + pMux.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) pHTTP := pQUIC.Protocol("matrix").HTTP() pHTTP.Mux().Handle(users.PublicURL, pMux) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 1226496c3..e1a04abf5 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -198,6 +198,8 @@ func main() { httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) embed.Embed(httpRouter, *instancePort, "Yggdrasil Demo") yggRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() From 934056f21fb91b59c9a9c4413318a9387abf658e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 1 Dec 2022 10:45:15 +0000 Subject: [PATCH 046/124] Fix `dendrite-demo-pinecone`, `/_dendrite` namespace setup --- build/gobind-pinecone/monolith.go | 1 + build/gobind-yggdrasil/monolith.go | 1 + cmd/dendrite-demo-pinecone/main.go | 5 +++-- cmd/dendrite-demo-yggdrasil/main.go | 1 + setup/base/base.go | 34 ++++++++++++++++------------- 5 files changed, 25 insertions(+), 17 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index b2fb70ce4..e8ed8fe85 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -336,6 +336,7 @@ func (m *DendriteMonolith) Start() { } base := base.NewBaseDendrite(cfg, "Monolith") + base.ConfigureAdminEndpoints() defer base.Close() // nolint: errcheck federation := conn.CreateFederationClient(base, m.PineconeQUIC) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 4cbe983ec..9a3ac5d7b 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -150,6 +150,7 @@ func (m *DendriteMonolith) Start() { } base := base.NewBaseDendrite(cfg, "Monolith") + base.ConfigureAdminEndpoints() m.processContext = base.ProcessContext defer base.Close() // nolint: errcheck diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 2ceb7c17b..2f647a41b 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -155,6 +155,7 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) base := base.NewBaseDendrite(cfg, "Monolith") + base.ConfigureAdminEndpoints() defer base.Close() // nolint: errcheck pineconeEventChannel := make(chan pineconeEvents.Event) @@ -248,6 +249,8 @@ func main() { httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) httpRouter.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { c, err := wsUpgrader.Upgrade(w, r, nil) if err != nil { @@ -270,8 +273,6 @@ func main() { pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - pMux.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) - pMux.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) pHTTP := pQUIC.Protocol("matrix").HTTP() pHTTP.Mux().Handle(users.PublicURL, pMux) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index e1a04abf5..5dd61b1b7 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -144,6 +144,7 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) base := base.NewBaseDendrite(cfg, "Monolith") + base.ConfigureAdminEndpoints() defer base.Close() // nolint: errcheck ygg, err := yggconn.Setup(sk, *instanceName, ".", *instancePeer, *instanceListen) diff --git a/setup/base/base.go b/setup/base/base.go index 14edadd96..d3adbf53f 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -413,6 +413,24 @@ func (b *BaseDendrite) configureHTTPErrors() { b.PublicClientAPIMux.MethodNotAllowedHandler = http.HandlerFunc(clientNotFoundHandler) } +func (b *BaseDendrite) ConfigureAdminEndpoints() { + b.DendriteAdminMux.HandleFunc("/monitor/up", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + }) + b.DendriteAdminMux.HandleFunc("/monitor/health", func(w http.ResponseWriter, r *http.Request) { + if isDegraded, reasons := b.ProcessContext.IsDegraded(); isDegraded { + w.WriteHeader(503) + _ = json.NewEncoder(w).Encode(struct { + Warnings []string `json:"warnings"` + }{ + Warnings: reasons, + }) + return + } + w.WriteHeader(200) + }) +} + // SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on // ApiMux under /api/ and adds a prometheus handler under /metrics. func (b *BaseDendrite) SetupAndServeHTTP( @@ -463,21 +481,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( internalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), b.Cfg.Global.Metrics.BasicAuth)) } - b.DendriteAdminMux.HandleFunc("/monitor/up", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - }) - b.DendriteAdminMux.HandleFunc("/monitor/health", func(w http.ResponseWriter, r *http.Request) { - if isDegraded, reasons := b.ProcessContext.IsDegraded(); isDegraded { - w.WriteHeader(503) - _ = json.NewEncoder(w).Encode(struct { - Warnings []string `json:"warnings"` - }{ - Warnings: reasons, - }) - return - } - w.WriteHeader(200) - }) + b.ConfigureAdminEndpoints() var clientHandler http.Handler clientHandler = b.PublicClientAPIMux From 9a46d8d95c937bdb356f9e373424e1a37aa37f04 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 2 Dec 2022 11:44:20 +0100 Subject: [PATCH 047/124] Test and CI related changes (#2896) In an attempt to: - make on-boarding a bit easier (`go test ./...` should now not need additional postgres setup) - get code coverage faster, not only scheduled at night - test the `create-account` binary --- .github/workflows/dendrite.yml | 75 ++++++++++++++++++++++++-- .github/workflows/schedules.yaml | 82 ++++------------------------- cmd/dendrite-upgrade-tests/main.go | 43 +++++++++++++++ docs/CONTRIBUTING.md | 15 +++++- internal/pushgateway/client.go | 5 +- internal/pushgateway/client_test.go | 54 +++++++++++++++++++ 6 files changed, 197 insertions(+), 77 deletions(-) create mode 100644 internal/pushgateway/client_test.go diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index fa4282384..593012ef3 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -68,7 +68,7 @@ jobs: # run go test with different go versions test: - timeout-minutes: 5 + timeout-minutes: 10 name: Unit tests (Go ${{ matrix.go }}) runs-on: ubuntu-latest # Service containers to run with `container-job` @@ -94,14 +94,22 @@ jobs: strategy: fail-fast: false matrix: - go: ["1.18", "1.19"] + go: ["1.19"] steps: - uses: actions/checkout@v3 - name: Setup go uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - cache: true + - uses: actions/cache@v3 + # manually set up caches, as they otherwise clash with different steps using setup-go with cache=true + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go${{ matrix.go }}-unit-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go${{ matrix.go }}-unit- - name: Set up gotestfmt uses: gotesttools/gotestfmt-action@v2 with: @@ -194,6 +202,66 @@ jobs: with: jobs: ${{ toJSON(needs) }} + # run go test with different go versions + integration: + timeout-minutes: 20 + needs: initial-tests-done + name: Integration tests (Go ${{ matrix.go }}) + runs-on: ubuntu-latest + # Service containers to run with `container-job` + services: + # Label used to access the service container + postgres: + # Docker Hub image + image: postgres:13-alpine + # Provide the password for postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: dendrite + ports: + # Maps tcp port 5432 on service container to the host + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + strategy: + fail-fast: false + matrix: + go: ["1.19"] + steps: + - uses: actions/checkout@v3 + - name: Setup go + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Set up gotestfmt + uses: gotesttools/gotestfmt-action@v2 + with: + # Optional: pass GITHUB_TOKEN to avoid rate limiting. + token: ${{ secrets.GITHUB_TOKEN }} + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go${{ matrix.go }}-test-race- + - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt + env: + POSTGRES_HOST: localhost + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: dendrite + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + flags: unittests + # run database upgrade tests upgrade_test: name: Upgrade tests @@ -404,6 +472,7 @@ jobs: upgrade_test_direct, sytest, complement, + integration ] runs-on: ubuntu-latest if: ${{ !cancelled() }} # Run this even if prior jobs were skipped diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index ff4d47187..d2a1f6e1f 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -10,79 +10,9 @@ concurrency: cancel-in-progress: true jobs: - # run go test with different go versions - test: - timeout-minutes: 20 - name: Unit tests (Go ${{ matrix.go }}) - runs-on: ubuntu-latest - # Service containers to run with `container-job` - services: - # Label used to access the service container - postgres: - # Docker Hub image - image: postgres:13-alpine - # Provide the password for postgres - env: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: dendrite - ports: - # Maps tcp port 5432 on service container to the host - - 5432:5432 - # Set health checks to wait until postgres has started - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - strategy: - fail-fast: false - matrix: - go: ["1.18", "1.19"] - steps: - - uses: actions/checkout@v3 - - name: Setup go - uses: actions/setup-go@v3 - with: - go-version: ${{ matrix.go }} - - name: Set up gotestfmt - uses: gotesttools/gotestfmt-action@v2 - with: - # Optional: pass GITHUB_TOKEN to avoid rate limiting. - token: ${{ secrets.GITHUB_TOKEN }} - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-test-race- - - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt - env: - POSTGRES_HOST: localhost - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: dendrite - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 - - # Dummy step to gate other tests on without repeating the whole list - initial-tests-done: - name: Initial tests passed - needs: [test] - runs-on: ubuntu-latest - if: ${{ !cancelled() }} # Run this even if prior jobs were skipped - steps: - - name: Check initial tests passed - uses: re-actors/alls-green@release/v1 - with: - jobs: ${{ toJSON(needs) }} - # run Sytest in different variations sytest: timeout-minutes: 60 - needs: initial-tests-done name: "Sytest (${{ matrix.label }})" runs-on: ubuntu-latest strategy: @@ -104,13 +34,23 @@ jobs: image: matrixdotorg/sytest-dendrite:latest volumes: - ${{ github.workspace }}:/src + - /root/.cache/go-build:/github/home/.cache/go-build + - /root/.cache/go-mod:/gopath/pkg/mod env: POSTGRES: ${{ matrix.postgres && 1}} API: ${{ matrix.api && 1 }} SYTEST_BRANCH: ${{ github.head_ref }} RACE_DETECTION: 1 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + /gopath/pkg/mod + key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go-sytest- - name: Run Sytest run: /bootstrap.sh dendrite working-directory: /src diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 75446d18c..39b9320cb 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "io" + "io/ioutil" "log" "net/http" "os" @@ -61,6 +62,7 @@ COPY . . RUN go build ./cmd/dendrite-monolith-server RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config +RUN go build ./cmd/create-account RUN ./generate-config --ci > dendrite.yaml RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key @@ -104,6 +106,7 @@ COPY . . RUN go build ./cmd/dendrite-monolith-server RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config +RUN go build ./cmd/create-account RUN ./generate-config --ci > dendrite.yaml RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key @@ -458,6 +461,46 @@ func loadAndRunTests(dockerClient *client.Client, volumeName, v string, branchTo if err = runTests(csAPIURL, v); err != nil { return fmt.Errorf("failed to run tests on version %s: %s", v, err) } + + err = testCreateAccount(dockerClient, v, containerID) + if err != nil { + return err + } + return nil +} + +// test that create-account is working +func testCreateAccount(dockerClient *client.Client, v string, containerID string) error { + createUser := strings.ToLower("createaccountuser-" + v) + log.Printf("%s: Creating account %s with create-account\n", v, createUser) + + respID, err := dockerClient.ContainerExecCreate(context.Background(), containerID, types.ExecConfig{ + AttachStderr: true, + AttachStdout: true, + Cmd: []string{ + "/build/create-account", + "-username", createUser, + "-password", "someRandomPassword", + }, + }) + if err != nil { + return fmt.Errorf("failed to ContainerExecCreate: %w", err) + } + + response, err := dockerClient.ContainerExecAttach(context.Background(), respID.ID, types.ExecStartCheck{}) + if err != nil { + return fmt.Errorf("failed to attach to container: %w", err) + } + defer response.Close() + + data, err := ioutil.ReadAll(response.Reader) + if err != nil { + return err + } + + if !bytes.Contains(data, []byte("AccessToken")) { + return fmt.Errorf("failed to create-account: %s", string(data)) + } return nil } diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 6ba05f46f..262a93a7c 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -75,7 +75,20 @@ comment. Please avoid doing this if you can. We also have unit tests which we run via: ```bash -go test --race ./... +DENDRITE_TEST_SKIP_NODB=1 go test --race ./... +``` + +This only runs SQLite database tests. If you wish to execute Postgres tests as well, you'll either need to +have Postgres installed locally (`createdb` will be used) or have a remote/containerized Postgres instance +available. + +To configure the connection to a remote Postgres, you can use the following enviroment variables: + +```bash +POSTGRES_USER=postgres +POSTGERS_PASSWORD=yourPostgresPassword +POSTGRES_HOST=localhost +POSTGRES_DB=postgres # the superuser database to use ``` In general, we like submissions that come with tests. Anything that proves that the diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go index 95f5afd90..259239b87 100644 --- a/internal/pushgateway/client.go +++ b/internal/pushgateway/client.go @@ -9,6 +9,8 @@ import ( "net/http" "time" + "github.com/matrix-org/dendrite/internal" + "github.com/opentracing/opentracing-go" ) @@ -50,8 +52,7 @@ func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, return err } - //nolint:errcheck - defer hresp.Body.Close() + defer internal.CloseAndLogIfError(ctx, hresp.Body, "failed to close response body") if hresp.StatusCode == http.StatusOK { return json.NewDecoder(hresp.Body).Decode(resp) diff --git a/internal/pushgateway/client_test.go b/internal/pushgateway/client_test.go new file mode 100644 index 000000000..bd0dca470 --- /dev/null +++ b/internal/pushgateway/client_test.go @@ -0,0 +1,54 @@ +package pushgateway + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestNotify(t *testing.T) { + wantResponse := NotifyResponse{ + Rejected: []string{"testing"}, + } + + var i = 0 + + svr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // /notify only accepts POST requests + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusNotImplemented) + return + } + + if i != 0 { // error path + w.WriteHeader(http.StatusBadRequest) + return + } + + // happy path + json.NewEncoder(w).Encode(wantResponse) + })) + defer svr.Close() + + cl := NewHTTPClient(true) + gotResponse := NotifyResponse{} + + // Test happy path + err := cl.Notify(context.Background(), svr.URL, &NotifyRequest{}, &gotResponse) + if err != nil { + t.Errorf("failed to notify client") + } + if !reflect.DeepEqual(gotResponse, wantResponse) { + t.Errorf("expected response %+v, got %+v", wantResponse, gotResponse) + } + + // Test error path + i++ + err = cl.Notify(context.Background(), svr.URL, &NotifyRequest{}, &gotResponse) + if err == nil { + t.Errorf("expected notifying the pushgateway to fail, but it succeeded") + } +} From b65f89e61e95b295e46ac3ade3c860b56126fa90 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 2 Dec 2022 16:42:23 +0100 Subject: [PATCH 048/124] Add tests for the AS internal API (#2898) --- appservice/appservice_test.go | 224 ++++++++++++++++++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 appservice/appservice_test.go diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go new file mode 100644 index 000000000..5a3a9aef7 --- /dev/null +++ b/appservice/appservice_test.go @@ -0,0 +1,224 @@ +package appservice_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "regexp" + "strings" + "testing" + + "github.com/gorilla/mux" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/appservice/inthttp" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi" + + "github.com/matrix-org/dendrite/test/testrig" +) + +func TestAppserviceInternalAPI(t *testing.T) { + + // Set expected results + existingProtocol := "irc" + wantLocationResponse := []api.ASLocationResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} + wantUserResponse := []api.ASUserResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} + wantProtocolResponse := api.ASProtocolResponse{Instances: []api.ProtocolInstance{{Fields: []byte("{}")}}} + wantProtocolResult := map[string]api.ASProtocolResponse{ + existingProtocol: wantProtocolResponse, + } + + // create a dummy AS url, handling some cases + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "location"): + // Check if we've got an existing protocol, if so, return a proper response. + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantLocationResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode([]api.ASLocationResponse{}); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + case strings.Contains(r.URL.Path, "user"): + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantUserResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode([]api.UserResponse{}); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + case strings.Contains(r.URL.Path, "protocol"): + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantProtocolResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode(nil); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + default: + t.Logf("hit location: %s", r.URL.Path) + } + })) + + // TODO: use test.WithAllDatabases + // only one DBType, since appservice.AddInternalRoutes complains about multiple prometheus counters added + base, closeBase := testrig.CreateBaseDendrite(t, test.DBTypeSQLite) + defer closeBase() + + // Create a dummy application service + base.Cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ + { + ID: "someID", + URL: srv.URL, + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, + "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, + }, + Protocols: []string{existingProtocol}, + }, + } + + // Create required internal APIs + rsAPI := roomserver.NewInternalAPI(base) + usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) + + // The test cases to run + runCases := func(t *testing.T, testAPI api.AppServiceInternalAPI) { + t.Run("UserIDExists", func(t *testing.T) { + testUserIDExists(t, testAPI, "@as-testing:test", true) + testUserIDExists(t, testAPI, "@as1-testing:test", false) + }) + + t.Run("AliasExists", func(t *testing.T) { + testAliasExists(t, testAPI, "@asroom-testing:test", true) + testAliasExists(t, testAPI, "@asroom1-testing:test", false) + }) + + t.Run("Locations", func(t *testing.T) { + testLocations(t, testAPI, existingProtocol, wantLocationResponse) + testLocations(t, testAPI, "abc", nil) + }) + + t.Run("User", func(t *testing.T) { + testUser(t, testAPI, existingProtocol, wantUserResponse) + testUser(t, testAPI, "abc", nil) + }) + + t.Run("Protocols", func(t *testing.T) { + testProtocol(t, testAPI, existingProtocol, wantProtocolResult) + testProtocol(t, testAPI, existingProtocol, wantProtocolResult) // tests the cache + testProtocol(t, testAPI, "", wantProtocolResult) // tests getting all protocols + testProtocol(t, testAPI, "abc", nil) + }) + } + + // Finally execute the tests + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + appservice.AddInternalRoutes(router, asAPI) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + + asHTTPApi, err := inthttp.NewAppserviceClient(apiURL, &http.Client{}) + if err != nil { + t.Fatalf("failed to create HTTP client: %s", err) + } + runCases(t, asHTTPApi) + }) + + t.Run("Monolith", func(t *testing.T) { + runCases(t, asAPI) + }) + +} + +func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) { + ctx := context.Background() + userResp := &api.UserIDExistsResponse{} + + if err := asAPI.UserIDExists(ctx, &api.UserIDExistsRequest{ + UserID: userID, + }, userResp); err != nil { + t.Errorf("failed to get userID: %s", err) + } + if userResp.UserIDExists != wantExists { + t.Errorf("unexpected result for UserIDExists(%s): %v, expected %v", userID, userResp.UserIDExists, wantExists) + } +} + +func testAliasExists(t *testing.T, asAPI api.AppServiceInternalAPI, alias string, wantExists bool) { + ctx := context.Background() + aliasResp := &api.RoomAliasExistsResponse{} + + if err := asAPI.RoomAliasExists(ctx, &api.RoomAliasExistsRequest{ + Alias: alias, + }, aliasResp); err != nil { + t.Errorf("failed to get alias: %s", err) + } + if aliasResp.AliasExists != wantExists { + t.Errorf("unexpected result for RoomAliasExists(%s): %v, expected %v", alias, aliasResp.AliasExists, wantExists) + } +} + +func testLocations(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASLocationResponse) { + ctx := context.Background() + locationResp := &api.LocationResponse{} + + if err := asAPI.Locations(ctx, &api.LocationRequest{ + Protocol: proto, + }, locationResp); err != nil { + t.Errorf("failed to get locations: %s", err) + } + if !reflect.DeepEqual(locationResp.Locations, wantResult) { + t.Errorf("unexpected result for Locations(%s): %+v, expected %+v", proto, locationResp.Locations, wantResult) + } +} + +func testUser(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASUserResponse) { + ctx := context.Background() + userResp := &api.UserResponse{} + + if err := asAPI.User(ctx, &api.UserRequest{ + Protocol: proto, + }, userResp); err != nil { + t.Errorf("failed to get user: %s", err) + } + if !reflect.DeepEqual(userResp.Users, wantResult) { + t.Errorf("unexpected result for User(%s): %+v, expected %+v", proto, userResp.Users, wantResult) + } +} + +func testProtocol(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult map[string]api.ASProtocolResponse) { + ctx := context.Background() + protoResp := &api.ProtocolResponse{} + + if err := asAPI.Protocols(ctx, &api.ProtocolRequest{ + Protocol: proto, + }, protoResp); err != nil { + t.Errorf("failed to get Protocols: %s", err) + } + if !reflect.DeepEqual(protoResp.Protocols, wantResult) { + t.Errorf("unexpected result for Protocols(%s): %+v, expected %+v", proto, protoResp.Protocols[proto], wantResult) + } +} From e245a26f6bcb4d134015f49f621b6d639a78707f Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 5 Dec 2022 13:53:36 +0100 Subject: [PATCH 049/124] Enable/Disable internal metrics (#2899) Basically enables us to use `test.WithAllDatabases` when testing internal HTTP APIs, as this would otherwise result in Prometheus complaining about already registered metric names. --- appservice/appservice.go | 7 +- appservice/inthttp/server.go | 12 +-- cmd/dendrite-monolith-server/main.go | 13 +-- .../personalities/appservice.go | 2 +- .../personalities/federationapi.go | 2 +- .../personalities/keyserver.go | 2 +- .../personalities/roomserver.go | 2 +- .../personalities/userapi.go | 2 +- federationapi/federationapi.go | 4 +- federationapi/inthttp/server.go | 46 +++++------ internal/httputil/httpapi.go | 15 ++-- internal/httputil/internalapi.go | 10 +-- keyserver/inthttp/server.go | 24 +++--- keyserver/keyserver.go | 4 +- roomserver/inthttp/server.go | 80 +++++++++---------- roomserver/roomserver.go | 7 +- userapi/inthttp/server.go | 71 ++++++++-------- userapi/inthttp/server_logintoken.go | 9 ++- userapi/userapi.go | 4 +- userapi/userapi_test.go | 2 +- 20 files changed, 164 insertions(+), 154 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index b3c28dbde..753850de7 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -24,6 +24,8 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/inthttp" @@ -32,12 +34,11 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) // AddInternalRoutes registers HTTP handlers for internal API calls -func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceInternalAPI) { - inthttp.AddRoutes(queryAPI, router) +func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceInternalAPI, enableMetrics bool) { + inthttp.AddRoutes(queryAPI, router, enableMetrics) } // NewInternalAPI returns a concerete implementation of the internal API. Callers diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go index ccf5c83d8..b70fad673 100644 --- a/appservice/inthttp/server.go +++ b/appservice/inthttp/server.go @@ -8,29 +8,29 @@ import ( ) // AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux. -func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) { +func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router, enableMetrics bool) { internalAPIMux.Handle( AppServiceRoomAliasExistsPath, - httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", a.RoomAliasExists), + httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", enableMetrics, a.RoomAliasExists), ) internalAPIMux.Handle( AppServiceUserIDExistsPath, - httputil.MakeInternalRPCAPI("AppserviceUserIDExists", a.UserIDExists), + httputil.MakeInternalRPCAPI("AppserviceUserIDExists", enableMetrics, a.UserIDExists), ) internalAPIMux.Handle( AppServiceProtocolsPath, - httputil.MakeInternalRPCAPI("AppserviceProtocols", a.Protocols), + httputil.MakeInternalRPCAPI("AppserviceProtocols", enableMetrics, a.Protocols), ) internalAPIMux.Handle( AppServiceLocationsPath, - httputil.MakeInternalRPCAPI("AppserviceLocations", a.Locations), + httputil.MakeInternalRPCAPI("AppserviceLocations", enableMetrics, a.Locations), ) internalAPIMux.Handle( AppServiceUserPath, - httputil.MakeInternalRPCAPI("AppserviceUser", a.User), + httputil.MakeInternalRPCAPI("AppserviceUser", enableMetrics, a.User), ) } diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index ff980dc1c..2d2f32b00 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -18,6 +18,8 @@ import ( "flag" "os" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/keyserver" @@ -29,7 +31,6 @@ import ( "github.com/matrix-org/dendrite/setup/mscs" "github.com/matrix-org/dendrite/userapi" uapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/sirupsen/logrus" ) var ( @@ -75,7 +76,7 @@ func main() { // call functions directly on the impl unless running in HTTP mode rsAPI := rsImpl if base.UseHTTPAPIs { - roomserver.AddInternalRoutes(base.InternalAPIMux, rsImpl) + roomserver.AddInternalRoutes(base.InternalAPIMux, rsImpl, base.EnableMetrics) rsAPI = base.RoomserverHTTPClient() } if traceInternal { @@ -89,7 +90,7 @@ func main() { ) fsImplAPI := fsAPI if base.UseHTTPAPIs { - federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI) + federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI, base.EnableMetrics) fsAPI = base.FederationAPIHTTPClient() } keyRing := fsAPI.KeyRing() @@ -97,7 +98,7 @@ func main() { keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) keyAPI := keyImpl if base.UseHTTPAPIs { - keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI) + keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics) keyAPI = base.KeyServerHTTPClient() } @@ -105,7 +106,7 @@ func main() { userImpl := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) userAPI := userImpl if base.UseHTTPAPIs { - userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) + userapi.AddInternalRoutes(base.InternalAPIMux, userAPI, base.EnableMetrics) userAPI = base.UserAPIClient() } if traceInternal { @@ -119,7 +120,7 @@ func main() { // before the listeners are up. asAPI := appservice.NewInternalAPI(base, userImpl, rsAPI) if base.UseHTTPAPIs { - appservice.AddInternalRoutes(base.InternalAPIMux, asAPI) + appservice.AddInternalRoutes(base.InternalAPIMux, asAPI, base.EnableMetrics) asAPI = base.AppserviceHTTPClient() } diff --git a/cmd/dendrite-polylith-multi/personalities/appservice.go b/cmd/dendrite-polylith-multi/personalities/appservice.go index 4f74434a4..0547d57f0 100644 --- a/cmd/dendrite-polylith-multi/personalities/appservice.go +++ b/cmd/dendrite-polylith-multi/personalities/appservice.go @@ -26,7 +26,7 @@ func Appservice(base *base.BaseDendrite, cfg *config.Dendrite) { rsAPI := base.RoomserverHTTPClient() intAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - appservice.AddInternalRoutes(base.InternalAPIMux, intAPI) + appservice.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics) base.SetupAndServeHTTP( base.Cfg.AppServiceAPI.InternalAPI.Listen, // internal listener diff --git a/cmd/dendrite-polylith-multi/personalities/federationapi.go b/cmd/dendrite-polylith-multi/personalities/federationapi.go index 6377ce9e3..48da42fbf 100644 --- a/cmd/dendrite-polylith-multi/personalities/federationapi.go +++ b/cmd/dendrite-polylith-multi/personalities/federationapi.go @@ -34,7 +34,7 @@ func FederationAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { rsAPI, fsAPI, keyAPI, nil, ) - federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI) + federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI, base.EnableMetrics) base.SetupAndServeHTTP( base.Cfg.FederationAPI.InternalAPI.Listen, diff --git a/cmd/dendrite-polylith-multi/personalities/keyserver.go b/cmd/dendrite-polylith-multi/personalities/keyserver.go index f8aa57b86..d2924b892 100644 --- a/cmd/dendrite-polylith-multi/personalities/keyserver.go +++ b/cmd/dendrite-polylith-multi/personalities/keyserver.go @@ -25,7 +25,7 @@ func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) intAPI.SetUserAPI(base.UserAPIClient()) - keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI) + keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics) base.SetupAndServeHTTP( base.Cfg.KeyServer.InternalAPI.Listen, // internal listener diff --git a/cmd/dendrite-polylith-multi/personalities/roomserver.go b/cmd/dendrite-polylith-multi/personalities/roomserver.go index 1deb51ce0..974559bd2 100644 --- a/cmd/dendrite-polylith-multi/personalities/roomserver.go +++ b/cmd/dendrite-polylith-multi/personalities/roomserver.go @@ -26,7 +26,7 @@ func RoomServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(fsAPI, fsAPI.KeyRing()) rsAPI.SetAppserviceAPI(asAPI) - roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI) + roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI, base.EnableMetrics) base.SetupAndServeHTTP( base.Cfg.RoomServer.InternalAPI.Listen, // internal listener diff --git a/cmd/dendrite-polylith-multi/personalities/userapi.go b/cmd/dendrite-polylith-multi/personalities/userapi.go index 3fe5a43d7..1bc88cb5f 100644 --- a/cmd/dendrite-polylith-multi/personalities/userapi.go +++ b/cmd/dendrite-polylith-multi/personalities/userapi.go @@ -27,7 +27,7 @@ func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { base.PushGatewayHTTPClient(), ) - userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) + userapi.AddInternalRoutes(base.InternalAPIMux, userAPI, base.EnableMetrics) base.SetupAndServeHTTP( base.Cfg.UserAPI.InternalAPI.Listen, // internal listener diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 854251220..87eb751f5 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -43,8 +43,8 @@ import ( // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.FederationInternalAPI) { - inthttp.AddRoutes(intAPI, router) +func AddInternalRoutes(router *mux.Router, intAPI api.FederationInternalAPI, enableMetrics bool) { + inthttp.AddRoutes(intAPI, router, enableMetrics) } // AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 21a070392..9068dc400 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -17,41 +17,41 @@ import ( // AddRoutes adds the FederationInternalAPI handlers to the http.ServeMux. // nolint:gocyclo -func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { +func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router, enableMetrics bool) { internalAPIMux.Handle( FederationAPIQueryJoinedHostServerNamesInRoomPath, - httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", intAPI.QueryJoinedHostServerNamesInRoom), + httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", enableMetrics, intAPI.QueryJoinedHostServerNamesInRoom), ) internalAPIMux.Handle( FederationAPIPerformInviteRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", intAPI.PerformInvite), + httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", enableMetrics, intAPI.PerformInvite), ) internalAPIMux.Handle( FederationAPIPerformLeaveRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", intAPI.PerformLeave), + httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", enableMetrics, intAPI.PerformLeave), ) internalAPIMux.Handle( FederationAPIPerformDirectoryLookupRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", intAPI.PerformDirectoryLookup), + httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", enableMetrics, intAPI.PerformDirectoryLookup), ) internalAPIMux.Handle( FederationAPIPerformBroadcastEDUPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), + httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", enableMetrics, intAPI.PerformBroadcastEDU), ) internalAPIMux.Handle( FederationAPIPerformWakeupServers, - httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers), + httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", enableMetrics, intAPI.PerformWakeupServers), ) internalAPIMux.Handle( FederationAPIPerformJoinRequestPath, httputil.MakeInternalRPCAPI( - "FederationAPIPerformJoinRequest", + "FederationAPIPerformJoinRequest", enableMetrics, func(ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse) error { intAPI.PerformJoin(ctx, req, res) return nil @@ -62,7 +62,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIGetUserDevicesPath, httputil.MakeInternalProxyAPI( - "FederationAPIGetUserDevices", + "FederationAPIGetUserDevices", enableMetrics, func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID) return &res, federationClientError(err) @@ -73,7 +73,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIClaimKeysPath, httputil.MakeInternalProxyAPI( - "FederationAPIClaimKeys", + "FederationAPIClaimKeys", enableMetrics, func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys) return &res, federationClientError(err) @@ -84,7 +84,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIQueryKeysPath, httputil.MakeInternalProxyAPI( - "FederationAPIQueryKeys", + "FederationAPIQueryKeys", enableMetrics, func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys) return &res, federationClientError(err) @@ -95,7 +95,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIBackfillPath, httputil.MakeInternalProxyAPI( - "FederationAPIBackfill", + "FederationAPIBackfill", enableMetrics, func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs) return &res, federationClientError(err) @@ -106,7 +106,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPILookupStatePath, httputil.MakeInternalProxyAPI( - "FederationAPILookupState", + "FederationAPILookupState", enableMetrics, func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion) return &res, federationClientError(err) @@ -117,7 +117,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPILookupStateIDsPath, httputil.MakeInternalProxyAPI( - "FederationAPILookupStateIDs", + "FederationAPILookupStateIDs", enableMetrics, func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID) return &res, federationClientError(err) @@ -128,7 +128,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPILookupMissingEventsPath, httputil.MakeInternalProxyAPI( - "FederationAPILookupMissingEvents", + "FederationAPILookupMissingEvents", enableMetrics, func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion) return &res, federationClientError(err) @@ -139,7 +139,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIGetEventPath, httputil.MakeInternalProxyAPI( - "FederationAPIGetEvent", + "FederationAPIGetEvent", enableMetrics, func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID) return &res, federationClientError(err) @@ -150,7 +150,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIGetEventAuthPath, httputil.MakeInternalProxyAPI( - "FederationAPIGetEventAuth", + "FederationAPIGetEventAuth", enableMetrics, func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID) return &res, federationClientError(err) @@ -160,13 +160,13 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIQueryServerKeysPath, - httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", intAPI.QueryServerKeys), + httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", enableMetrics, intAPI.QueryServerKeys), ) internalAPIMux.Handle( FederationAPILookupServerKeysPath, httputil.MakeInternalProxyAPI( - "FederationAPILookupServerKeys", + "FederationAPILookupServerKeys", enableMetrics, func(ctx context.Context, req *lookupServerKeys) (*[]gomatrixserverlib.ServerKeys, error) { res, err := intAPI.LookupServerKeys(ctx, req.S, req.KeyRequests) return &res, federationClientError(err) @@ -177,7 +177,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIEventRelationshipsPath, httputil.MakeInternalProxyAPI( - "FederationAPIMSC2836EventRelationships", + "FederationAPIMSC2836EventRelationships", enableMetrics, func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer) return &res, federationClientError(err) @@ -188,7 +188,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPISpacesSummaryPath, httputil.MakeInternalProxyAPI( - "FederationAPIMSC2946SpacesSummary", + "FederationAPIMSC2946SpacesSummary", enableMetrics, func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly) return &res, federationClientError(err) @@ -198,7 +198,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { // TODO: Look at this shape internalAPIMux.Handle(FederationAPIQueryPublicKeyPath, - httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", enableMetrics, func(req *http.Request) util.JSONResponse { request := api.QueryPublicKeysRequest{} response := api.QueryPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -215,7 +215,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { // TODO: Look at this shape internalAPIMux.Handle(FederationAPIInputPublicKeyPath, - httputil.MakeInternalAPI("FederationAPIInputPublicKeys", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("FederationAPIInputPublicKeys", enableMetrics, func(req *http.Request) util.JSONResponse { request := api.InputPublicKeysRequest{} response := api.InputPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 4f33a3f79..127d1fac7 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -24,16 +24,17 @@ import ( "strings" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" - opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // BasicAuth is used for authorization on /metrics handlers @@ -227,7 +228,7 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) // This is used for APIs that are internal to dendrite. // If we are passed a tracing context in the request headers then we use that // as the parent of any tracing spans we create. -func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler { +func MakeInternalAPI(metricsName string, enableMetrics bool, f func(*http.Request) util.JSONResponse) http.Handler { h := util.MakeJSONAPI(util.NewJSONRequestHandler(f)) withSpan := func(w http.ResponseWriter, req *http.Request) { carrier := opentracing.HTTPHeadersCarrier(req.Header) @@ -246,6 +247,10 @@ func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse h.ServeHTTP(w, req) } + if !enableMetrics { + return http.HandlerFunc(withSpan) + } + return promhttp.InstrumentHandlerCounter( promauto.NewCounterVec( prometheus.CounterOpts{ diff --git a/internal/httputil/internalapi.go b/internal/httputil/internalapi.go index 385092d9c..22f436e38 100644 --- a/internal/httputil/internalapi.go +++ b/internal/httputil/internalapi.go @@ -22,7 +22,7 @@ import ( "reflect" "github.com/matrix-org/util" - opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go" ) type InternalAPIError struct { @@ -34,8 +34,8 @@ func (e InternalAPIError) Error() string { return fmt.Sprintf("internal API returned %q error: %s", e.Type, e.Message) } -func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype, *restype) error) http.Handler { - return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse { +func MakeInternalRPCAPI[reqtype, restype any](metricsName string, enableMetrics bool, f func(context.Context, *reqtype, *restype) error) http.Handler { + return MakeInternalAPI(metricsName, enableMetrics, func(req *http.Request) util.JSONResponse { var request reqtype var response restype if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -57,8 +57,8 @@ func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context }) } -func MakeInternalProxyAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype) (*restype, error)) http.Handler { - return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse { +func MakeInternalProxyAPI[reqtype, restype any](metricsName string, enableMetrics bool, f func(context.Context, *reqtype) (*restype, error)) http.Handler { + return MakeInternalAPI(metricsName, enableMetrics, func(req *http.Request) util.JSONResponse { var request reqtype if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index 7af0ff6e5..443269f73 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -21,59 +21,59 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" ) -func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { +func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI, enableMetrics bool) { internalAPIMux.Handle( PerformClaimKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", s.PerformClaimKeys), + httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", enableMetrics, s.PerformClaimKeys), ) internalAPIMux.Handle( PerformDeleteKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", s.PerformDeleteKeys), + httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", enableMetrics, s.PerformDeleteKeys), ) internalAPIMux.Handle( PerformUploadKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", s.PerformUploadKeys), + httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", enableMetrics, s.PerformUploadKeys), ) internalAPIMux.Handle( PerformUploadDeviceKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", s.PerformUploadDeviceKeys), + httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", enableMetrics, s.PerformUploadDeviceKeys), ) internalAPIMux.Handle( PerformUploadDeviceSignaturesPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", s.PerformUploadDeviceSignatures), + httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", enableMetrics, s.PerformUploadDeviceSignatures), ) internalAPIMux.Handle( QueryKeysPath, - httputil.MakeInternalRPCAPI("KeyserverQueryKeys", s.QueryKeys), + httputil.MakeInternalRPCAPI("KeyserverQueryKeys", enableMetrics, s.QueryKeys), ) internalAPIMux.Handle( QueryOneTimeKeysPath, - httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", s.QueryOneTimeKeys), + httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", enableMetrics, s.QueryOneTimeKeys), ) internalAPIMux.Handle( QueryDeviceMessagesPath, - httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", s.QueryDeviceMessages), + httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", enableMetrics, s.QueryDeviceMessages), ) internalAPIMux.Handle( QueryKeyChangesPath, - httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", s.QueryKeyChanges), + httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", enableMetrics, s.QueryKeyChanges), ) internalAPIMux.Handle( QuerySignaturesPath, - httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", s.QuerySignatures), + httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", enableMetrics, s.QuerySignatures), ) internalAPIMux.Handle( PerformMarkAsStalePath, - httputil.MakeInternalRPCAPI("KeyserverMarkAsStale", s.PerformMarkAsStaleIfNeeded), + httputil.MakeInternalRPCAPI("KeyserverMarkAsStale", enableMetrics, s.PerformMarkAsStaleIfNeeded), ) } diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index a86c2da4e..5360c06fd 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -32,8 +32,8 @@ import ( // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { - inthttp.AddRoutes(router, intAPI) +func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI, enableMetrics bool) { + inthttp.AddRoutes(router, intAPI, enableMetrics) } // NewInternalAPI returns a concerete implementation of the internal API. Callers diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 4d37e90b5..6e7c2d985 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -9,198 +9,198 @@ import ( // AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux. // nolint: gocyclo -func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { +func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMetrics bool) { internalAPIMux.Handle( RoomserverInputRoomEventsPath, - httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", r.InputRoomEvents), + httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", enableMetrics, r.InputRoomEvents), ) internalAPIMux.Handle( RoomserverPerformInvitePath, - httputil.MakeInternalRPCAPI("RoomserverPerformInvite", r.PerformInvite), + httputil.MakeInternalRPCAPI("RoomserverPerformInvite", enableMetrics, r.PerformInvite), ) internalAPIMux.Handle( RoomserverPerformJoinPath, - httputil.MakeInternalRPCAPI("RoomserverPerformJoin", r.PerformJoin), + httputil.MakeInternalRPCAPI("RoomserverPerformJoin", enableMetrics, r.PerformJoin), ) internalAPIMux.Handle( RoomserverPerformLeavePath, - httputil.MakeInternalRPCAPI("RoomserverPerformLeave", r.PerformLeave), + httputil.MakeInternalRPCAPI("RoomserverPerformLeave", enableMetrics, r.PerformLeave), ) internalAPIMux.Handle( RoomserverPerformPeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformPeek", r.PerformPeek), + httputil.MakeInternalRPCAPI("RoomserverPerformPeek", enableMetrics, r.PerformPeek), ) internalAPIMux.Handle( RoomserverPerformInboundPeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", r.PerformInboundPeek), + httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", enableMetrics, r.PerformInboundPeek), ) internalAPIMux.Handle( RoomserverPerformUnpeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", r.PerformUnpeek), + httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", enableMetrics, r.PerformUnpeek), ) internalAPIMux.Handle( RoomserverPerformRoomUpgradePath, - httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", r.PerformRoomUpgrade), + httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", enableMetrics, r.PerformRoomUpgrade), ) internalAPIMux.Handle( RoomserverPerformPublishPath, - httputil.MakeInternalRPCAPI("RoomserverPerformPublish", r.PerformPublish), + httputil.MakeInternalRPCAPI("RoomserverPerformPublish", enableMetrics, r.PerformPublish), ) internalAPIMux.Handle( RoomserverPerformAdminEvacuateRoomPath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", r.PerformAdminEvacuateRoom), + httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", enableMetrics, r.PerformAdminEvacuateRoom), ) internalAPIMux.Handle( RoomserverPerformAdminEvacuateUserPath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser), + httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", enableMetrics, r.PerformAdminEvacuateUser), ) internalAPIMux.Handle( RoomserverPerformAdminDownloadStatePath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", r.PerformAdminDownloadState), + httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", enableMetrics, r.PerformAdminDownloadState), ) internalAPIMux.Handle( RoomserverQueryPublishedRoomsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms), + httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", enableMetrics, r.QueryPublishedRooms), ) internalAPIMux.Handle( RoomserverQueryLatestEventsAndStatePath, - httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", r.QueryLatestEventsAndState), + httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", enableMetrics, r.QueryLatestEventsAndState), ) internalAPIMux.Handle( RoomserverQueryStateAfterEventsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", r.QueryStateAfterEvents), + httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", enableMetrics, r.QueryStateAfterEvents), ) internalAPIMux.Handle( RoomserverQueryEventsByIDPath, - httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", r.QueryEventsByID), + httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", enableMetrics, r.QueryEventsByID), ) internalAPIMux.Handle( RoomserverQueryMembershipForUserPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", r.QueryMembershipForUser), + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", enableMetrics, r.QueryMembershipForUser), ) internalAPIMux.Handle( RoomserverQueryMembershipsForRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", r.QueryMembershipsForRoom), + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", enableMetrics, r.QueryMembershipsForRoom), ) internalAPIMux.Handle( RoomserverQueryServerJoinedToRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", r.QueryServerJoinedToRoom), + httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", enableMetrics, r.QueryServerJoinedToRoom), ) internalAPIMux.Handle( RoomserverQueryServerAllowedToSeeEventPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", r.QueryServerAllowedToSeeEvent), + httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", enableMetrics, r.QueryServerAllowedToSeeEvent), ) internalAPIMux.Handle( RoomserverQueryMissingEventsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", r.QueryMissingEvents), + httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", enableMetrics, r.QueryMissingEvents), ) internalAPIMux.Handle( RoomserverQueryStateAndAuthChainPath, - httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", r.QueryStateAndAuthChain), + httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", enableMetrics, r.QueryStateAndAuthChain), ) internalAPIMux.Handle( RoomserverPerformBackfillPath, - httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", r.PerformBackfill), + httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", enableMetrics, r.PerformBackfill), ) internalAPIMux.Handle( RoomserverPerformForgetPath, - httputil.MakeInternalRPCAPI("RoomserverPerformForget", r.PerformForget), + httputil.MakeInternalRPCAPI("RoomserverPerformForget", enableMetrics, r.PerformForget), ) internalAPIMux.Handle( RoomserverQueryRoomVersionCapabilitiesPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", r.QueryRoomVersionCapabilities), + httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", enableMetrics, r.QueryRoomVersionCapabilities), ) internalAPIMux.Handle( RoomserverQueryRoomVersionForRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", r.QueryRoomVersionForRoom), + httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", enableMetrics, r.QueryRoomVersionForRoom), ) internalAPIMux.Handle( RoomserverSetRoomAliasPath, - httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", r.SetRoomAlias), + httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", enableMetrics, r.SetRoomAlias), ) internalAPIMux.Handle( RoomserverGetRoomIDForAliasPath, - httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", r.GetRoomIDForAlias), + httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", enableMetrics, r.GetRoomIDForAlias), ) internalAPIMux.Handle( RoomserverGetAliasesForRoomIDPath, - httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", r.GetAliasesForRoomID), + httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", enableMetrics, r.GetAliasesForRoomID), ) internalAPIMux.Handle( RoomserverRemoveRoomAliasPath, - httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", r.RemoveRoomAlias), + httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", enableMetrics, r.RemoveRoomAlias), ) internalAPIMux.Handle( RoomserverQueryCurrentStatePath, - httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", r.QueryCurrentState), + httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", enableMetrics, r.QueryCurrentState), ) internalAPIMux.Handle( RoomserverQueryRoomsForUserPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", r.QueryRoomsForUser), + httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", enableMetrics, r.QueryRoomsForUser), ) internalAPIMux.Handle( RoomserverQueryBulkStateContentPath, - httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", r.QueryBulkStateContent), + httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", enableMetrics, r.QueryBulkStateContent), ) internalAPIMux.Handle( RoomserverQuerySharedUsersPath, - httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", r.QuerySharedUsers), + httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", enableMetrics, r.QuerySharedUsers), ) internalAPIMux.Handle( RoomserverQueryKnownUsersPath, - httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", r.QueryKnownUsers), + httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", enableMetrics, r.QueryKnownUsers), ) internalAPIMux.Handle( RoomserverQueryServerBannedFromRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", r.QueryServerBannedFromRoom), + httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", enableMetrics, r.QueryServerBannedFromRoom), ) internalAPIMux.Handle( RoomserverQueryAuthChainPath, - httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", r.QueryAuthChain), + httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", enableMetrics, r.QueryAuthChain), ) internalAPIMux.Handle( RoomserverQueryRestrictedJoinAllowed, - httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed), + httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", enableMetrics, r.QueryRestrictedJoinAllowed), ) internalAPIMux.Handle( RoomserverQueryMembershipAtEventPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent), + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent), ) } diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 1f707735b..0f6b48bf9 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -16,18 +16,19 @@ package roomserver import ( "github.com/gorilla/mux" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal" "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/base" - "github.com/sirupsen/logrus" ) // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI) { - inthttp.AddRoutes(intAPI, router) +func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI, enableMetrics bool) { + inthttp.AddRoutes(intAPI, router, enableMetrics) } // NewInternalAPI returns a concerete implementation of the internal API. Callers diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 661fecfae..f0579079f 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -16,176 +16,177 @@ package inthttp import ( "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" ) // nolint: gocyclo -func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { - addRoutesLoginToken(internalAPIMux, s) +func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics bool) { + addRoutesLoginToken(internalAPIMux, s, enableMetrics) internalAPIMux.Handle( PerformAccountCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", s.PerformAccountCreation), + httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", enableMetrics, s.PerformAccountCreation), ) internalAPIMux.Handle( PerformPasswordUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", s.PerformPasswordUpdate), + httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", enableMetrics, s.PerformPasswordUpdate), ) internalAPIMux.Handle( PerformDeviceCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", s.PerformDeviceCreation), + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", enableMetrics, s.PerformDeviceCreation), ) internalAPIMux.Handle( PerformLastSeenUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", s.PerformLastSeenUpdate), + httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", enableMetrics, s.PerformLastSeenUpdate), ) internalAPIMux.Handle( PerformDeviceUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", s.PerformDeviceUpdate), + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", enableMetrics, s.PerformDeviceUpdate), ) internalAPIMux.Handle( PerformDeviceDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", s.PerformDeviceDeletion), + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", enableMetrics, s.PerformDeviceDeletion), ) internalAPIMux.Handle( PerformAccountDeactivationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", s.PerformAccountDeactivation), + httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", enableMetrics, s.PerformAccountDeactivation), ) internalAPIMux.Handle( PerformOpenIDTokenCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", s.PerformOpenIDTokenCreation), + httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", enableMetrics, s.PerformOpenIDTokenCreation), ) internalAPIMux.Handle( QueryProfilePath, - httputil.MakeInternalRPCAPI("UserAPIQueryProfile", s.QueryProfile), + httputil.MakeInternalRPCAPI("UserAPIQueryProfile", enableMetrics, s.QueryProfile), ) internalAPIMux.Handle( QueryAccessTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", s.QueryAccessToken), + httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", enableMetrics, s.QueryAccessToken), ) internalAPIMux.Handle( QueryDevicesPath, - httputil.MakeInternalRPCAPI("UserAPIQueryDevices", s.QueryDevices), + httputil.MakeInternalRPCAPI("UserAPIQueryDevices", enableMetrics, s.QueryDevices), ) internalAPIMux.Handle( QueryAccountDataPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", s.QueryAccountData), + httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", enableMetrics, s.QueryAccountData), ) internalAPIMux.Handle( QueryDeviceInfosPath, - httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", s.QueryDeviceInfos), + httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", enableMetrics, s.QueryDeviceInfos), ) internalAPIMux.Handle( QuerySearchProfilesPath, - httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", s.QuerySearchProfiles), + httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", enableMetrics, s.QuerySearchProfiles), ) internalAPIMux.Handle( QueryOpenIDTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", s.QueryOpenIDToken), + httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", enableMetrics, s.QueryOpenIDToken), ) internalAPIMux.Handle( InputAccountDataPath, - httputil.MakeInternalRPCAPI("UserAPIInputAccountData", s.InputAccountData), + httputil.MakeInternalRPCAPI("UserAPIInputAccountData", enableMetrics, s.InputAccountData), ) internalAPIMux.Handle( QueryKeyBackupPath, - httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", s.QueryKeyBackup), + httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", enableMetrics, s.QueryKeyBackup), ) internalAPIMux.Handle( PerformKeyBackupPath, - httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", s.PerformKeyBackup), + httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", enableMetrics, s.PerformKeyBackup), ) internalAPIMux.Handle( QueryNotificationsPath, - httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", s.QueryNotifications), + httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", enableMetrics, s.QueryNotifications), ) internalAPIMux.Handle( PerformPusherSetPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", s.PerformPusherSet), + httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", enableMetrics, s.PerformPusherSet), ) internalAPIMux.Handle( PerformPusherDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", s.PerformPusherDeletion), + httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", enableMetrics, s.PerformPusherDeletion), ) internalAPIMux.Handle( QueryPushersPath, - httputil.MakeInternalRPCAPI("UserAPIQueryPushers", s.QueryPushers), + httputil.MakeInternalRPCAPI("UserAPIQueryPushers", enableMetrics, s.QueryPushers), ) internalAPIMux.Handle( PerformPushRulesPutPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", s.PerformPushRulesPut), + httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", enableMetrics, s.PerformPushRulesPut), ) internalAPIMux.Handle( QueryPushRulesPath, - httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", s.QueryPushRules), + httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", enableMetrics, s.QueryPushRules), ) internalAPIMux.Handle( PerformSetAvatarURLPath, - httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), + httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", enableMetrics, s.SetAvatarURL), ) internalAPIMux.Handle( QueryNumericLocalpartPath, - httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart), + httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", enableMetrics, s.QueryNumericLocalpart), ) internalAPIMux.Handle( QueryAccountAvailabilityPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", s.QueryAccountAvailability), + httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", enableMetrics, s.QueryAccountAvailability), ) internalAPIMux.Handle( QueryAccountByPasswordPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", s.QueryAccountByPassword), + httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", enableMetrics, s.QueryAccountByPassword), ) internalAPIMux.Handle( PerformSetDisplayNamePath, - httputil.MakeInternalRPCAPI("UserAPISetDisplayName", s.SetDisplayName), + httputil.MakeInternalRPCAPI("UserAPISetDisplayName", enableMetrics, s.SetDisplayName), ) internalAPIMux.Handle( QueryLocalpartForThreePIDPath, - httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", s.QueryLocalpartForThreePID), + httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", enableMetrics, s.QueryLocalpartForThreePID), ) internalAPIMux.Handle( QueryThreePIDsForLocalpartPath, - httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", s.QueryThreePIDsForLocalpart), + httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", enableMetrics, s.QueryThreePIDsForLocalpart), ) internalAPIMux.Handle( PerformForgetThreePIDPath, - httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", s.PerformForgetThreePID), + httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", enableMetrics, s.PerformForgetThreePID), ) internalAPIMux.Handle( PerformSaveThreePIDAssociationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", s.PerformSaveThreePIDAssociation), + httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), ) } diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go index b57348413..dc116428b 100644 --- a/userapi/inthttp/server_logintoken.go +++ b/userapi/inthttp/server_logintoken.go @@ -16,24 +16,25 @@ package inthttp import ( "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" ) // addRoutesLoginToken adds routes for all login token API calls. -func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { +func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics bool) { internalAPIMux.Handle( PerformLoginTokenCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", s.PerformLoginTokenCreation), + httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", enableMetrics, s.PerformLoginTokenCreation), ) internalAPIMux.Handle( PerformLoginTokenDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", s.PerformLoginTokenDeletion), + httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", enableMetrics, s.PerformLoginTokenDeletion), ) internalAPIMux.Handle( QueryLoginTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", s.QueryLoginToken), + httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", enableMetrics, s.QueryLoginToken), ) } diff --git a/userapi/userapi.go b/userapi/userapi.go index e46a8e76e..183ca3123 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -37,8 +37,8 @@ import ( // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { - inthttp.AddRoutes(router, intAPI) +func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI, enableMetrics bool) { + inthttp.AddRoutes(router, intAPI, enableMetrics) } // NewInternalAPI returns a concerete implementation of the internal API. Callers diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 25fa75ee2..60dd730fd 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -144,7 +144,7 @@ func TestQueryProfile(t *testing.T) { t.Run("HTTP API", func(t *testing.T) { router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI) + userapi.AddInternalRoutes(router, userAPI, false) apiURL, cancel := test.ListenAndServe(t, router, false) defer cancel() httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) From 07e8ed13f61fb04f6cac5a6d42263a9ddba49d32 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 5 Dec 2022 15:09:59 +0100 Subject: [PATCH 050/124] Fix CI and test.WithAllDatabases --- appservice/appservice_test.go | 79 +++++++++++++++++------------------ 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 5a3a9aef7..72910d8d1 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -77,32 +77,6 @@ func TestAppserviceInternalAPI(t *testing.T) { } })) - // TODO: use test.WithAllDatabases - // only one DBType, since appservice.AddInternalRoutes complains about multiple prometheus counters added - base, closeBase := testrig.CreateBaseDendrite(t, test.DBTypeSQLite) - defer closeBase() - - // Create a dummy application service - base.Cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ - { - ID: "someID", - URL: srv.URL, - ASToken: "", - HSToken: "", - SenderLocalpart: "senderLocalPart", - NamespaceMap: map[string][]config.ApplicationServiceNamespace{ - "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, - "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, - }, - Protocols: []string{existingProtocol}, - }, - } - - // Create required internal APIs - rsAPI := roomserver.NewInternalAPI(base) - usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) - asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) - // The test cases to run runCases := func(t *testing.T, testAPI api.AppServiceInternalAPI) { t.Run("UserIDExists", func(t *testing.T) { @@ -133,24 +107,49 @@ func TestAppserviceInternalAPI(t *testing.T) { }) } - // Finally execute the tests - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - appservice.AddInternalRoutes(router, asAPI) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, closeBase := testrig.CreateBaseDendrite(t, test.DBTypeSQLite) + defer closeBase() - asHTTPApi, err := inthttp.NewAppserviceClient(apiURL, &http.Client{}) - if err != nil { - t.Fatalf("failed to create HTTP client: %s", err) + // Create a dummy application service + base.Cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ + { + ID: "someID", + URL: srv.URL, + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, + "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, + }, + Protocols: []string{existingProtocol}, + }, } - runCases(t, asHTTPApi) - }) - t.Run("Monolith", func(t *testing.T) { - runCases(t, asAPI) - }) + // Create required internal APIs + rsAPI := roomserver.NewInternalAPI(base) + usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) + // Finally execute the tests + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + appservice.AddInternalRoutes(router, asAPI, base.EnableMetrics) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + + asHTTPApi, err := inthttp.NewAppserviceClient(apiURL, &http.Client{}) + if err != nil { + t.Fatalf("failed to create HTTP client: %s", err) + } + runCases(t, asHTTPApi) + }) + + t.Run("Monolith", func(t *testing.T) { + runCases(t, asAPI) + }) + }) } func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) { From 0e6d94757b81b8e52479706ee5f857fc34023a5b Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 5 Dec 2022 15:24:36 +0100 Subject: [PATCH 051/124] Enforce coverage --- .github/codecov.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 .github/codecov.yaml diff --git a/.github/codecov.yaml b/.github/codecov.yaml new file mode 100644 index 000000000..e6a38a8be --- /dev/null +++ b/.github/codecov.yaml @@ -0,0 +1,13 @@ +flag_management: + default_rules: + carryforward: true + +coverage: + status: + project: + default: + target: 75% + threshold: 0% + base: auto + flags: + - unittests \ No newline at end of file From 3dc06bea81d58d26a2d52ffb04ceb04a8c70e928 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 5 Dec 2022 15:49:11 +0100 Subject: [PATCH 052/124] Differentiate between project and patch --- .github/codecov.yaml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/codecov.yaml b/.github/codecov.yaml index e6a38a8be..78122c990 100644 --- a/.github/codecov.yaml +++ b/.github/codecov.yaml @@ -5,6 +5,13 @@ flag_management: coverage: status: project: + default: + target: auto + threshold: 0% + base: auto + flags: + - unittests + patch: default: target: 75% threshold: 0% From b99349b18c28a1c27b5bd5df30853a3b7c689d02 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 5 Dec 2022 16:00:02 +0100 Subject: [PATCH 053/124] Use test.WithAllDatabases --- appservice/appservice_test.go | 2 +- userapi/userapi_test.go | 55 ++++++++++++++++++----------------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 72910d8d1..83c551fea 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -108,7 +108,7 @@ func TestAppserviceInternalAPI(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, closeBase := testrig.CreateBaseDendrite(t, test.DBTypeSQLite) + base, closeBase := testrig.CreateBaseDendrite(t, dbType) defer closeBase() // Create a dummy application service diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 60dd730fd..8a19af195 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -27,14 +27,13 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" - "github.com/matrix-org/dendrite/userapi/inthttp" - - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage" ) @@ -79,19 +78,6 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) - defer close() - _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) - if err != nil { - t.Fatalf("failed to make account: %s", err) - } - if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { - t.Fatalf("failed to set avatar url: %s", err) - } - if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { - t.Fatalf("failed to set display name: %s", err) - } testCases := []struct { req api.QueryProfileRequest @@ -142,19 +128,34 @@ func TestQueryProfile(t *testing.T) { } } - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI, false) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() - httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() + _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) if err != nil { - t.Fatalf("failed to create HTTP client") + t.Fatalf("failed to make account: %s", err) } - runCases(httpAPI, true) - }) - t.Run("Monolith", func(t *testing.T) { - runCases(userAPI, false) + if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { + t.Fatalf("failed to set avatar url: %s", err) + } + if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { + t.Fatalf("failed to set display name: %s", err) + } + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + userapi.AddInternalRoutes(router, userAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) + if err != nil { + t.Fatalf("failed to create HTTP client") + } + runCases(httpAPI, true) + }) + t.Run("Monolith", func(t *testing.T) { + runCases(userAPI, false) + }) }) } From 75834783055b0c70f8b411d9c3741e57461832f0 Mon Sep 17 00:00:00 2001 From: kegsay Date: Mon, 5 Dec 2022 16:54:01 +0000 Subject: [PATCH 054/124] Update contributing guidelines (#2904) --- docs/CONTRIBUTING.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 262a93a7c..21b0e7ab1 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -9,6 +9,28 @@ permalink: /development/contributing Everyone is welcome to contribute to Dendrite! We aim to make it as easy as possible to get started. + ## Contribution types + +We are a small team maintaining a large project. As a result, we cannot merge every feature, even if it +is bug-free and useful, because we then commit to maintaining it indefinitely. We will always accept: + - bug fixes + - security fixes (please responsibly disclose via security@matrix.org *before* creating pull requests) + +We will accept the following with caveats: + - documentation fixes, provided they do not add additional instructions which can end up going out-of-date, + e.g example configs, shell commands. + - performance fixes, provided they do not add significantly more maintenance burden. + - additional functionality on existing features, provided the functionality is small and maintainable. + - additional functionality that, in its absence, would impact the ecosystem e.g spam and abuse mitigations + - test-only changes, provided they help improve coverage or test tricky code. + +The following items are at risk of not being accepted: + - Configuration or CLI changes, particularly ones which increase the overall configuration surface. + +The following items are unlikely to be accepted into a main Dendrite release for now: + - New MSC implementations. + - New features which are not in the specification. + ## Sign off We require that everyone who contributes to the project signs off their contributions From ded43e0f2d07adc399da5962454d9873021b59ac Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 6 Dec 2022 13:27:33 +0100 Subject: [PATCH 055/124] Fix issue with sending presence events to invalid servers --- federationapi/consumers/roomserver.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index d16af6626..0c1080afa 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -232,7 +232,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) { - joined := make([]gomatrixserverlib.ServerName, len(addedJoined)) + joined := make([]gomatrixserverlib.ServerName, 0, len(addedJoined)) for _, added := range addedJoined { joined = append(joined, added.ServerName) } From ba2ffb7da9b86b64dc8091d90645ab80cf2831db Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 6 Dec 2022 18:16:17 +0000 Subject: [PATCH 056/124] Repeatable reads for `/sync` (#2783) This puts repeatable reads into all sync streams. Co-authored-by: kegsay --- syncapi/storage/shared/storage_consumer.go | 42 +++++++++------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 23f53d11f..f2064fb89 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -57,31 +57,23 @@ type Database struct { } func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) { - return d.NewDatabaseTransaction(ctx) - - /* - TODO: Repeatable read is probably the right thing to do here, - but it seems to cause some problems with the invite tests, so - need to investigate that further. - - txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ - // Set the isolation level so that we see a snapshot of the database. - // In PostgreSQL repeatable read transactions will see a snapshot taken - // at the first query, and since the transaction is read-only it can't - // run into any serialisation errors. - // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ - Isolation: sql.LevelRepeatableRead, - ReadOnly: true, - }) - if err != nil { - return nil, err - } - return &DatabaseTransaction{ - Database: d, - ctx: ctx, - txn: txn, - }, nil - */ + txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ + // Set the isolation level so that we see a snapshot of the database. + // In PostgreSQL repeatable read transactions will see a snapshot taken + // at the first query, and since the transaction is read-only it can't + // run into any serialisation errors. + // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }) + if err != nil { + return nil, err + } + return &DatabaseTransaction{ + Database: d, + ctx: ctx, + txn: txn, + }, nil } func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) { From 27a1dea5225c828ad787098aadafa83ed73c4940 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 8 Dec 2022 08:24:06 +0100 Subject: [PATCH 057/124] Fix issue with multiple/duplicate log entries during tests (#2906) --- internal/log.go | 5 +++++ internal/log_unix.go | 13 +++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/internal/log.go b/internal/log.go index a171555ab..d7e852c81 100644 --- a/internal/log.go +++ b/internal/log.go @@ -33,6 +33,11 @@ import ( "github.com/matrix-org/dendrite/setup/config" ) +// logrus is using a global variable when we're using `logrus.AddHook` +// this unfortunately results in us adding the same hook multiple times. +// This map ensures we only ever add one level hook. +var stdLevelLogAdded = make(map[logrus.Level]bool) + type utcFormatter struct { logrus.Formatter } diff --git a/internal/log_unix.go b/internal/log_unix.go index 75332af73..b38e7c2e8 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -22,16 +22,16 @@ import ( "log/syslog" "github.com/MFAshby/stdemuxerhook" - "github.com/matrix-org/dendrite/setup/config" "github.com/sirupsen/logrus" lSyslog "github.com/sirupsen/logrus/hooks/syslog" + + "github.com/matrix-org/dendrite/setup/config" ) // SetupHookLogging configures the logging hooks defined in the configuration. // If something fails here it means that the logging was improperly configured, // so we just exit with the error func SetupHookLogging(hooks []config.LogrusHook, componentName string) { - stdLogAdded := false for _, hook := range hooks { // Check we received a proper logging level level, err := logrus.ParseLevel(hook.Level) @@ -54,14 +54,11 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) { setupSyslogHook(hook, level, componentName) case "std": setupStdLogHook(level) - stdLogAdded = true default: logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type) } } - if !stdLogAdded { - setupStdLogHook(logrus.InfoLevel) - } + setupStdLogHook(logrus.InfoLevel) // Hooks are now configured for stdout/err, so throw away the default logger output logrus.SetOutput(io.Discard) } @@ -88,7 +85,11 @@ func checkSyslogHookParams(params map[string]interface{}) { } func setupStdLogHook(level logrus.Level) { + if stdLevelLogAdded[level] { + return + } logrus.AddHook(&logLevelHook{level, stdemuxerhook.New(logrus.StandardLogger())}) + stdLevelLogAdded[level] = true } func setupSyslogHook(hook config.LogrusHook, level logrus.Level, componentName string) { From 0351618ff4e7d569e14a165be59a1a7e9e979684 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 8 Dec 2022 08:24:24 +0100 Subject: [PATCH 058/124] Add UserAPI util tests (#2907) This adds some `userapi/util` tests. --- test/testrig/base.go | 2 +- userapi/util/notify_test.go | 119 ++++++++++++++++++++++++++++ userapi/util/phonehomestats.go | 10 +-- userapi/util/phonehomestats_test.go | 84 ++++++++++++++++++++ 4 files changed, 208 insertions(+), 7 deletions(-) create mode 100644 userapi/util/notify_test.go create mode 100644 userapi/util/phonehomestats_test.go diff --git a/test/testrig/base.go b/test/testrig/base.go index 15fb5c370..7bc26a5c5 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -108,7 +108,7 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat cfg.Global.JetStream.InMemory = true cfg.SyncAPI.Fulltext.InMemory = true cfg.FederationAPI.KeyPerspectives = nil - base := base.NewBaseDendrite(cfg, "Tests") + base := base.NewBaseDendrite(cfg, "Tests", base.DisableMetrics) js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) return base, js, jc } diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go new file mode 100644 index 000000000..f1d20259c --- /dev/null +++ b/userapi/util/notify_test.go @@ -0,0 +1,119 @@ +package util_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + userUtil "github.com/matrix-org/dendrite/userapi/util" +) + +func TestNotifyUserCountsAsync(t *testing.T) { + alice := test.NewUser(t) + aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) + if err != nil { + t.Error(err) + } + ctx := context.Background() + + // Create a test room, just used to provide events + room := test.NewRoom(t, alice) + dummyEvent := room.Events()[len(room.Events())-1] + + appID := util.RandomString(8) + pushKey := util.RandomString(8) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + receivedRequest := make(chan bool, 1) + // create a test server which responds to our /notify call + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data pushgateway.NotifyRequest + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + t.Error(err) + } + notification := data.Notification + // Validate the request + if notification.Counts == nil { + t.Fatal("no unread notification counts in request") + } + if unread := notification.Counts.Unread; unread != 1 { + t.Errorf("expected one unread notification, got %d", unread) + } + + if len(notification.Devices) == 0 { + t.Fatal("expected devices in request") + } + + // We only created one push device, so access it directly + device := notification.Devices[0] + if device.AppID != appID { + t.Errorf("unexpected app_id: %s, want %s", device.AppID, appID) + } + if device.PushKey != pushKey { + t.Errorf("unexpected push_key: %s, want %s", device.PushKey, pushKey) + } + + // Return empty result, otherwise the call is handled as failed + if _, err := w.Write([]byte("{}")); err != nil { + t.Error(err) + } + close(receivedRequest) + })) + defer srv.Close() + + // Create DB and Dendrite base + connStr, close := test.PrepareDBConnectionString(t, dbType) + defer close() + base, _, _ := testrig.Base(nil) + defer base.Close() + db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, "test", bcrypt.MinCost, 0, 0, "") + if err != nil { + t.Error(err) + } + + // Prepare pusher with our test server URL + if err := db.UpsertPusher(ctx, api.Pusher{ + Kind: api.HTTPKind, + AppID: appID, + PushKey: pushKey, + Data: map[string]interface{}{ + "url": srv.URL, + }, + }, aliceLocalpart, serverName); err != nil { + t.Error(err) + } + + // Insert a dummy event + if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ + Event: gomatrixserverlib.HeaderedToClientEvent(dummyEvent, gomatrixserverlib.FormatAll), + }); err != nil { + t.Error(err) + } + + // Notify the user about a new notification + if err := userUtil.NotifyUserCountsAsync(ctx, pushgateway.NewHTTPClient(true), aliceLocalpart, serverName, db); err != nil { + t.Error(err) + } + select { + case <-time.After(time.Second * 5): + t.Error("timed out waiting for response") + case <-receivedRequest: + } + }) + +} diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go index 6f36568c9..42c8f5d7c 100644 --- a/userapi/util/phonehomestats.go +++ b/userapi/util/phonehomestats.go @@ -97,12 +97,10 @@ func (p *phoneHomeStats) collect() { // configuration information p.stats["federation_disabled"] = p.cfg.Global.DisableFederation - p.stats["nats_embedded"] = true - p.stats["nats_in_memory"] = p.cfg.Global.JetStream.InMemory - if len(p.cfg.Global.JetStream.Addresses) > 0 { - p.stats["nats_embedded"] = false - p.stats["nats_in_memory"] = false // probably - } + natsEmbedded := len(p.cfg.Global.JetStream.Addresses) == 0 + p.stats["nats_embedded"] = natsEmbedded + p.stats["nats_in_memory"] = p.cfg.Global.JetStream.InMemory && natsEmbedded + if len(p.cfg.Logging) > 0 { p.stats["log_level"] = p.cfg.Logging[0].Level } else { diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go new file mode 100644 index 000000000..6e62210e8 --- /dev/null +++ b/userapi/util/phonehomestats_test.go @@ -0,0 +1,84 @@ +package util + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/storage" +) + +func TestCollect(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + b, _, _ := testrig.Base(nil) + connStr, closeDB := test.PrepareDBConnectionString(t, dbType) + defer closeDB() + db, err := storage.NewUserAPIDatabase(b, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, "localhost", bcrypt.MinCost, 1000, 1000, "") + if err != nil { + t.Error(err) + } + + receivedRequest := make(chan struct{}, 1) + // create a test server which responds to our call + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + t.Error(err) + } + defer r.Body.Close() + if _, err := w.Write([]byte("{}")); err != nil { + t.Error(err) + } + + // verify the received data matches our expectations + dbEngine, ok := data["database_engine"] + if !ok { + t.Errorf("missing database_engine in JSON request: %+v", data) + } + version, ok := data["version"] + if !ok { + t.Errorf("missing version in JSON request: %+v", data) + } + if version != internal.VersionString() { + t.Errorf("unexpected version: %q, expected %q", version, internal.VersionString()) + } + switch { + case dbType == test.DBTypeSQLite && dbEngine != "SQLite": + t.Errorf("unexpected database_engine: %s", dbEngine) + case dbType == test.DBTypePostgres && dbEngine != "Postgres": + t.Errorf("unexpected database_engine: %s", dbEngine) + } + close(receivedRequest) + })) + defer srv.Close() + + b.Cfg.Global.ReportStats.Endpoint = srv.URL + stats := phoneHomeStats{ + prevData: timestampToRUUsage{}, + serverName: "localhost", + startTime: time.Now(), + cfg: b.Cfg, + db: db, + isMonolith: false, + client: &http.Client{Timeout: time.Second}, + } + + stats.collect() + + select { + case <-time.After(time.Second * 5): + t.Error("timed out waiting for response") + case <-receivedRequest: + } + }) +} From c136a450d5196cf22a91419f493bb73c29481122 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 8 Dec 2022 08:25:03 +0100 Subject: [PATCH 059/124] Fix newly joined users presence (#2854) Fixes #2803 Also refactors the presence stream to not hit the database for every user, instead queries all users at once now. --- syncapi/consumers/presence.go | 12 +- syncapi/storage/interface.go | 4 +- syncapi/storage/postgres/presence_table.go | 38 +++-- syncapi/storage/shared/storage_consumer.go | 4 +- syncapi/storage/shared/storage_sync.go | 4 +- syncapi/storage/sqlite3/presence_table.go | 48 +++++-- syncapi/storage/tables/interface.go | 2 +- syncapi/storage/tables/presence_table_test.go | 136 ++++++++++++++++++ syncapi/streams/stream_presence.go | 80 +++++++---- syncapi/sync/requestpool.go | 6 +- syncapi/sync/requestpool_test.go | 4 +- 11 files changed, 263 insertions(+), 75 deletions(-) create mode 100644 syncapi/storage/tables/presence_table_test.go diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 145059c2d..6e3150c29 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -78,7 +78,7 @@ func (s *PresenceConsumer) Start() error { // Normal NATS subscription, used by Request/Reply _, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) { userID := msg.Header.Get(jetstream.UserID) - presence, err := s.db.GetPresence(context.Background(), userID) + presences, err := s.db.GetPresences(context.Background(), []string{userID}) m := &nats.Msg{ Header: nats.Header{}, } @@ -89,10 +89,12 @@ func (s *PresenceConsumer) Start() error { } return } - if presence == nil { - presence = &types.PresenceInternal{ - UserID: userID, - } + + presence := &types.PresenceInternal{ + UserID: userID, + } + if len(presences) > 0 { + presence = presences[0] } deviceRes := api.QueryDevicesResponse{} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 97c2ced49..75afbce15 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -106,7 +106,7 @@ type DatabaseTransaction interface { SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) - GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) } @@ -186,7 +186,7 @@ type Database interface { } type Presence interface { - GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) } diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index 7194afea6..a3f7c5213 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -19,10 +19,12 @@ import ( "database/sql" "time" + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const presenceSchema = ` @@ -63,9 +65,9 @@ const upsertPresenceFromSyncSQL = "" + " RETURNING id" const selectPresenceForUserSQL = "" + - "SELECT presence, status_msg, last_active_ts" + + "SELECT user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE user_id = $1 LIMIT 1" + " WHERE user_id = ANY($1)" const selectMaxPresenceSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" @@ -119,20 +121,28 @@ func (p *presenceStatements) UpsertPresence( return } -// GetPresenceForUser returns the current presence of a user. -func (p *presenceStatements) GetPresenceForUser( +// GetPresenceForUsers returns the current presence for a list of users. +// If the user doesn't have a presence status yet, it is omitted from the response. +func (p *presenceStatements) GetPresenceForUsers( ctx context.Context, txn *sql.Tx, - userID string, -) (*types.PresenceInternal, error) { - result := &types.PresenceInternal{ - UserID: userID, - } + userIDs []string, +) ([]*types.PresenceInternal, error) { + result := make([]*types.PresenceInternal, 0, len(userIDs)) stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) - err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) - if err == sql.ErrNoRows { - return nil, nil + rows, err := stmt.QueryContext(ctx, pq.Array(userIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed") + + for rows.Next() { + presence := &types.PresenceInternal{} + if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil { + return nil, err + } + presence.ClientFields.Presence = presence.Presence.String() + result = append(result, presence) } - result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index f2064fb89..df2338cf8 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -564,8 +564,8 @@ func (d *Database) UpdatePresence(ctx context.Context, userID string, presence t return pos, err } -func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, nil, userID) +func (d *Database) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUsers(ctx, nil, userIDs) } func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index c3763521c..77afa0290 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -596,8 +596,8 @@ func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx contex return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) } -func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, d.txn, userID) +func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs) } func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index b61a825df..7641de92f 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -17,12 +17,14 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const presenceSchema = ` @@ -62,9 +64,9 @@ const upsertPresenceFromSyncSQL = "" + " RETURNING id" const selectPresenceForUserSQL = "" + - "SELECT presence, status_msg, last_active_ts" + + "SELECT user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE user_id = $1 LIMIT 1" + " WHERE user_id IN ($1)" const selectMaxPresenceSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" @@ -134,20 +136,38 @@ func (p *presenceStatements) UpsertPresence( return } -// GetPresenceForUser returns the current presence of a user. -func (p *presenceStatements) GetPresenceForUser( +// GetPresenceForUsers returns the current presence for a list of users. +// If the user doesn't have a presence status yet, it is omitted from the response. +func (p *presenceStatements) GetPresenceForUsers( ctx context.Context, txn *sql.Tx, - userID string, -) (*types.PresenceInternal, error) { - result := &types.PresenceInternal{ - UserID: userID, + userIDs []string, +) ([]*types.PresenceInternal, error) { + qry := strings.Replace(selectPresenceForUserSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + prepStmt, err := p.db.Prepare(qry) + if err != nil { + return nil, err } - stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) - err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) - if err == sql.ErrNoRows { - return nil, nil + defer internal.CloseAndLogIfError(ctx, prepStmt, "GetPresenceForUsers: stmt.close() failed") + + params := make([]interface{}, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + rows, err := sqlutil.TxStmt(txn, prepStmt).QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed") + result := make([]*types.PresenceInternal, 0, len(userIDs)) + for rows.Next() { + presence := &types.PresenceInternal{} + if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil { + return nil, err + } + presence.ClientFields.Presence = presence.Presence.String() + result = append(result, presence) } - result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 2c4f04ec2..a0574b257 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -207,7 +207,7 @@ type Ignores interface { type Presence interface { UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) - GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) + GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) } diff --git a/syncapi/storage/tables/presence_table_test.go b/syncapi/storage/tables/presence_table_test.go new file mode 100644 index 000000000..dce0c695a --- /dev/null +++ b/syncapi/storage/tables/presence_table_test.go @@ -0,0 +1,136 @@ +package tables_test + +import ( + "context" + "database/sql" + "reflect" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Presence + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresPresenceTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to prepare stream stmts: %s", err) + } + tab, err = sqlite3.NewSqlitePresenceTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, close +} + +func TestPresence(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + ctx := context.Background() + + statusMsg := "Hello World!" + timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + + var txn *sql.Tx + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustPresenceTable(t, dbType) + defer closeDB() + + // Insert some presences + pos, err := tab.UpsertPresence(ctx, txn, alice.ID, &statusMsg, types.PresenceOnline, timestamp, false) + if err != nil { + t.Error(err) + } + wantPos := types.StreamPosition(1) + if pos != wantPos { + t.Errorf("expected pos to be %d, got %d", wantPos, pos) + } + pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, false) + if err != nil { + t.Error(err) + } + wantPos = 2 + if pos != wantPos { + t.Errorf("expected pos to be %d, got %d", wantPos, pos) + } + + // verify the expected max presence ID + maxPos, err := tab.GetMaxPresenceID(ctx, txn) + if err != nil { + t.Error(err) + } + if maxPos != wantPos { + t.Errorf("expected max pos to be %d, got %d", wantPos, maxPos) + } + + // This should increment the position + pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, true) + if err != nil { + t.Error(err) + } + wantPos = pos + if wantPos <= maxPos { + t.Errorf("expected pos to be %d incremented, got %d", wantPos, pos) + } + + // This should return only Bobs status + presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10}) + if err != nil { + t.Error(err) + } + + if c := len(presences); c > 1 { + t.Errorf("expected only one presence, got %d", c) + } + + // Validate the response + wantPresence := &types.PresenceInternal{ + UserID: bob.ID, + Presence: types.PresenceOnline, + StreamPos: wantPos, + LastActiveTS: timestamp, + ClientFields: types.PresenceClientResponse{ + LastActiveAgo: 0, + Presence: types.PresenceOnline.String(), + StatusMsg: &statusMsg, + }, + } + if !reflect.DeepEqual(wantPresence, presences[bob.ID]) { + t.Errorf("unexpected presence result:\n%+v, want\n%+v", presences[bob.ID], wantPresence) + } + + // Try getting presences for existing and non-existing users + getUsers := []string{alice.ID, bob.ID, "@doesntexist:test"} + presencesForUsers, err := tab.GetPresenceForUsers(ctx, nil, getUsers) + if err != nil { + t.Error(err) + } + + if len(presencesForUsers) >= len(getUsers) { + t.Errorf("expected less presences, but they are the same/more as requested: %d >= %d", len(presencesForUsers), len(getUsers)) + } + }) + +} diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 030b7c5d5..445e46b3a 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -17,6 +17,7 @@ package streams import ( "context" "encoding/json" + "fmt" "sync" "github.com/matrix-org/gomatrixserverlib" @@ -70,39 +71,25 @@ func (p *PresenceStreamProvider) IncrementalSync( return from } - if len(presences) == 0 { + getPresenceForUsers, err := p.getNeededUsersFromRequest(ctx, req, presences) + if err != nil { + req.Log.WithError(err).Error("getNeededUsersFromRequest failed") + return from + } + + // Got no presence between range and no presence to get from the database + if len(getPresenceForUsers) == 0 && len(presences) == 0 { return to } - // add newly joined rooms user presences - newlyJoined := joinedRooms(req.Response, req.Device.UserID) - if len(newlyJoined) > 0 { - // TODO: Check if this is working better than before. - if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { - req.Log.WithError(err).Error("unable to refresh notifier lists") - return from - } - NewlyJoinedLoop: - for _, roomID := range newlyJoined { - roomUsers := p.notifier.JoinedUsers(roomID) - for i := range roomUsers { - // we already got a presence from this user - if _, ok := presences[roomUsers[i]]; ok { - continue - } - // Bear in mind that this might return nil, but at least populating - // a nil means that there's a map entry so we won't repeat this call. - presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i]) - if err != nil { - req.Log.WithError(err).Error("unable to query presence for user") - _ = snapshot.Rollback() - return from - } - if len(presences) > req.Filter.Presence.Limit { - break NewlyJoinedLoop - } - } - } + dbPresences, err := snapshot.GetPresences(ctx, getPresenceForUsers) + if err != nil { + req.Log.WithError(err).Error("unable to query presence for user") + _ = snapshot.Rollback() + return from + } + for _, presence := range dbPresences { + presences[presence.UserID] = presence } lastPos := from @@ -164,6 +151,39 @@ func (p *PresenceStreamProvider) IncrementalSync( return lastPos } +func (p *PresenceStreamProvider) getNeededUsersFromRequest(ctx context.Context, req *types.SyncRequest, presences map[string]*types.PresenceInternal) ([]string, error) { + getPresenceForUsers := []string{} + // Add presence for users which newly joined a room + for userID := range req.MembershipChanges { + if _, ok := presences[userID]; ok { + continue + } + getPresenceForUsers = append(getPresenceForUsers, userID) + } + + // add newly joined rooms user presences + newlyJoined := joinedRooms(req.Response, req.Device.UserID) + if len(newlyJoined) == 0 { + return getPresenceForUsers, nil + } + + // TODO: Check if this is working better than before. + if err := p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { + return getPresenceForUsers, fmt.Errorf("unable to refresh notifier lists: %w", err) + } + for _, roomID := range newlyJoined { + roomUsers := p.notifier.JoinedUsers(roomID) + for i := range roomUsers { + // we already got a presence from this user + if _, ok := presences[roomUsers[i]]; ok { + continue + } + getPresenceForUsers = append(getPresenceForUsers, roomUsers[i]) + } + } + return getPresenceForUsers, nil +} + func joinedRooms(res *types.Response, userID string) []string { var roomIDs []string for roomID, join := range res.Rooms.Join { diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 29d92b293..b086567b8 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -145,12 +145,12 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user } // ensure we also send the current status_msg to federated servers and not nil - dbPresence, err := db.GetPresence(context.Background(), userID) + dbPresence, err := db.GetPresences(context.Background(), []string{userID}) if err != nil && err != sql.ErrNoRows { return } - if dbPresence != nil { - newPresence.ClientFields = dbPresence.ClientFields + if len(dbPresence) > 0 && dbPresence[0] != nil { + newPresence.ClientFields = dbPresence[0].ClientFields } newPresence.ClientFields.Presence = presenceID.String() diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index 3e5769d8c..faa0b49c6 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -29,8 +29,8 @@ func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence typ return 0, nil } -func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return &types.PresenceInternal{}, nil +func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) { + return []*types.PresenceInternal{}, nil } func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { From 8846de7312d447cd2866236493b6ae172fbebfa9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 8 Dec 2022 10:19:55 +0000 Subject: [PATCH 060/124] Bump nokogiri from 1.13.9 to 1.13.10 in /docs (#2909) Bumps [nokogiri](https://github.com/sparklemotion/nokogiri) from 1.13.9 to 1.13.10.
Release notes

Sourced from nokogiri's releases.

1.13.10 / 2022-12-07

Security

  • [CRuby] Address CVE-2022-23476, unchecked return value from xmlTextReaderExpand. See GHSA-qv4q-mr5r-qprj for more information.

Improvements

  • [CRuby] XML::Reader#attribute_hash now returns nil on parse errors. This restores the behavior of #attributes from v1.13.7 and earlier. [#2715]

sha256 checksums:

777ce2e80f64772e91459b943e531dfef387e768f2255f9bc7a1655f254bbaa1
nokogiri-1.13.10-aarch64-linux.gem
b432ff47c51386e07f7e275374fe031c1349e37eaef2216759063bc5fa5624aa
nokogiri-1.13.10-arm64-darwin.gem
73ac581ddcb680a912e92da928ffdbac7b36afd3368418f2cee861b96e8c830b
nokogiri-1.13.10-java.gem
916aa17e624611dddbf2976ecce1b4a80633c6378f8465cff0efab022ebc2900
nokogiri-1.13.10-x64-mingw-ucrt.gem
0f85a1ad8c2b02c166a6637237133505b71a05f1bb41b91447005449769bced0
nokogiri-1.13.10-x64-mingw32.gem
91fa3a8724a1ce20fccbd718dafd9acbde099258183ac486992a61b00bb17020
nokogiri-1.13.10-x86-linux.gem
d6663f5900ccd8f72d43660d7f082565b7ffcaade0b9a59a74b3ef8791034168
nokogiri-1.13.10-x86-mingw32.gem
81755fc4b8130ef9678c76a2e5af3db7a0a6664b3cba7d9fe8ef75e7d979e91b
nokogiri-1.13.10-x86_64-darwin.gem
51d5246705dedad0a09b374d09cc193e7383a5dd32136a690a3cd56e95adf0a3
nokogiri-1.13.10-x86_64-linux.gem
d3ee00f26c151763da1691c7fc6871ddd03e532f74f85101f5acedc2d099e958
nokogiri-1.13.10.gem
Changelog

Sourced from nokogiri's changelog.

1.13.10 / 2022-12-07

Security

  • [CRuby] Address CVE-2022-23476, unchecked return value from xmlTextReaderExpand. See GHSA-qv4q-mr5r-qprj for more information.

Improvements

  • [CRuby] XML::Reader#attribute_hash now returns nil on parse errors. This restores the behavior of #attributes from v1.13.7 and earlier. [#2715]
Commits
  • 4c80121 version bump to v1.13.10
  • 85410e3 Merge pull request #2715 from sparklemotion/flavorjones-fix-reader-error-hand...
  • 9fe0761 fix(cruby): XML::Reader#attribute_hash returns nil on error
  • 3b9c736 Merge pull request #2717 from sparklemotion/flavorjones-lock-psych-to-fix-bui...
  • 2efa87b test: skip large cdata test on system libxml2
  • 3187d67 dep(dev): pin psych to v4 until v5 builds in CI
  • a16b4bf style(rubocop): disable Minitest/EmptyLineBeforeAssertionMethods
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=nokogiri&package-manager=bundler&previous-version=1.13.9&new-version=1.13.10)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) - `@dependabot use these labels` will set the current labels as the default for future PRs for this repo and language - `@dependabot use these reviewers` will set the current reviewers as the default for future PRs for this repo and language - `@dependabot use these assignees` will set the current assignees as the default for future PRs for this repo and language - `@dependabot use this milestone` will set the current milestone as the default for future PRs for this repo and language You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/matrix-org/dendrite/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/Gemfile.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index c7ba43711..509a8cbcf 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -231,9 +231,9 @@ GEM jekyll-seo-tag (~> 2.1) minitest (5.15.0) multipart-post (2.1.1) - nokogiri (1.13.9-arm64-darwin) + nokogiri (1.13.10-arm64-darwin) racc (~> 1.4) - nokogiri (1.13.9-x86_64-linux) + nokogiri (1.13.10-x86_64-linux) racc (~> 1.4) octokit (4.22.0) faraday (>= 0.9) @@ -241,7 +241,7 @@ GEM pathutil (0.16.2) forwardable-extended (~> 2.6) public_suffix (4.0.7) - racc (1.6.0) + racc (1.6.1) rb-fsevent (0.11.1) rb-inotify (0.10.1) ffi (~> 1.0) From aaf4e5c8654463cc5431d57db9163ad9ed558f53 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 9 Dec 2022 18:45:42 +0100 Subject: [PATCH 061/124] Use older sytest-dendrite image --- .github/workflows/dendrite.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 593012ef3..2c04005d2 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -331,7 +331,8 @@ jobs: postgres: postgres api: full-http container: - image: matrixdotorg/sytest-dendrite:latest + # Temporary for debugging to see if this image is working better. + image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73 volumes: - ${{ github.workspace }}:/src - /root/.cache/go-build:/github/home/.cache/go-build From 7d2344049d0780e92b071939addc41219ce94f5a Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 12 Dec 2022 08:20:59 +0100 Subject: [PATCH 062/124] Cleanup stale device lists for users we don't share a room with anymore (#2857) The stale device lists table might contain entries for users we don't share a room with anymore. This now asks the roomserver about left users and removes those entries from the table. Co-authored-by: Neil Alexander --- build/dendritejs-pinecone/main.go | 6 +- build/gobind-pinecone/monolith.go | 2 +- build/gobind-yggdrasil/monolith.go | 2 +- cmd/dendrite-demo-pinecone/main.go | 2 +- cmd/dendrite-demo-yggdrasil/main.go | 5 +- cmd/dendrite-monolith-server/main.go | 2 +- .../personalities/keyserver.go | 3 +- keyserver/internal/device_list_update.go | 29 +++++- keyserver/internal/device_list_update_test.go | 86 ++++++++++++++++- keyserver/keyserver.go | 12 ++- keyserver/keyserver_test.go | 29 ++++++ keyserver/storage/interface.go | 5 + .../storage/postgres/stale_device_lists.go | 33 +++++-- keyserver/storage/shared/storage.go | 10 ++ .../storage/sqlite3/stale_device_lists.go | 44 +++++++-- keyserver/storage/tables/interface.go | 1 + .../storage/tables/stale_device_lists_test.go | 94 ++++++++++++++++++ roomserver/api/api.go | 5 + roomserver/api/api_trace.go | 6 ++ roomserver/api/query.go | 12 +++ roomserver/internal/query/query.go | 6 ++ roomserver/inthttp/client.go | 8 ++ roomserver/inthttp/server.go | 5 + roomserver/roomserver_test.go | 77 ++++++++++++++- roomserver/storage/interface.go | 1 + .../storage/postgres/membership_table.go | 34 ++++++- roomserver/storage/shared/storage.go | 37 +++++++ roomserver/storage/shared/storage_test.go | 96 +++++++++++++++++++ .../storage/sqlite3/membership_table.go | 47 ++++++++- roomserver/storage/tables/interface.go | 1 + .../storage/tables/membership_table_test.go | 6 ++ 31 files changed, 666 insertions(+), 40 deletions(-) create mode 100644 keyserver/keyserver_test.go create mode 100644 keyserver/storage/tables/stale_device_lists_test.go create mode 100644 roomserver/storage/shared/storage_test.go diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go index e070173aa..f44a77488 100644 --- a/build/dendritejs-pinecone/main.go +++ b/build/dendritejs-pinecone/main.go @@ -180,14 +180,14 @@ func startup() { base := base.NewBaseDendrite(cfg, "Monolith") defer base.Close() // nolint: errcheck + rsAPI := roomserver.NewInternalAPI(base) + federation := conn.CreateFederationClient(base, pSessions) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - rsAPI := roomserver.NewInternalAPI(base) - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index e8ed8fe85..b8f8111d2 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -350,7 +350,7 @@ func (m *DendriteMonolith) Start() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(m.userAPI) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 9a3ac5d7b..8c2d0a006 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -165,7 +165,7 @@ func (m *DendriteMonolith) Start() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 2f647a41b..3f627b41d 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -213,7 +213,7 @@ func main() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsComponent) userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 5dd61b1b7..3ea4a08b0 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -157,11 +157,12 @@ func main() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - rsComponent := roomserver.NewInternalAPI( base, ) + + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsComponent) + rsAPI := rsComponent userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 2d2f32b00..6836b6426 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -95,7 +95,7 @@ func main() { } keyRing := fsAPI.KeyRing() - keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) keyAPI := keyImpl if base.UseHTTPAPIs { keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics) diff --git a/cmd/dendrite-polylith-multi/personalities/keyserver.go b/cmd/dendrite-polylith-multi/personalities/keyserver.go index d2924b892..ad0bd0e54 100644 --- a/cmd/dendrite-polylith-multi/personalities/keyserver.go +++ b/cmd/dendrite-polylith-multi/personalities/keyserver.go @@ -22,7 +22,8 @@ import ( func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { fsAPI := base.FederationAPIHTTPClient() - intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + rsAPI := base.RoomserverHTTPClient() + intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) intAPI.SetUserAPI(base.UserAPIClient()) keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics) diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 8ff9dfc31..c7bf8da53 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -24,6 +24,8 @@ import ( "sync" "time" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -102,6 +104,7 @@ type DeviceListUpdater struct { // block on or timeout via a select. userIDToChan map[string]chan bool userIDToChanMu *sync.Mutex + rsAPI rsapi.KeyserverRoomserverAPI } // DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. @@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface { // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error } type DeviceListUpdaterAPI interface { @@ -140,7 +145,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, - thisServer gomatrixserverlib.ServerName, + rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -154,6 +159,7 @@ func NewDeviceListUpdater( workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, + rsAPI: rsAPI, } } @@ -168,7 +174,7 @@ func (u *DeviceListUpdater) Start() error { go u.worker(ch) } - staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{}) + staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) if err != nil { return err } @@ -186,6 +192,25 @@ func (u *DeviceListUpdater) Start() error { return nil } +// CleanUp removes stale device entries for users we don't share a room with anymore +func (u *DeviceListUpdater) CleanUp() error { + staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + if err != nil { + return err + } + + res := rsapi.QueryLeftUsersResponse{} + if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil { + return err + } + + if len(res.LeftUsers) == 0 { + return nil + } + logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers)) + return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers) +} + func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { u.mu.Lock() defer u.mu.Unlock() diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index a374c9516..60a2c2f30 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -30,7 +30,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/storage" + roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" ) var ( @@ -53,6 +58,10 @@ type mockDeviceListUpdaterDatabase struct { mu sync.Mutex // protect staleUsers } +func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error { + return nil +} + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { @@ -153,7 +162,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost") event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -225,7 +234,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -239,6 +248,7 @@ func TestUpdateNoPrevID(t *testing.T) { UserID: remoteUserID, } err := updater.Update(ctx, event) + if err != nil { t.Fatalf("Update returned an error: %s", err) } @@ -294,7 +304,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -349,3 +359,73 @@ func TestDebounce(t *testing.T) { t.Errorf("user %s is marked as stale", userID) } } + +func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { + t.Helper() + + base, _, _ := testrig.Base(nil) + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + if err != nil { + t.Fatal(err) + } + + return db, clearDB +} + +type mockKeyserverRoomserverAPI struct { + leftUsers []string +} + +func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { + res.LeftUsers = m.leftUsers + return nil +} + +func TestDeviceListUpdater_CleanUp(t *testing.T) { + processCtx := process.NewProcessContext() + + alice := test.NewUser(t) + bob := test.NewUser(t) + + // Bob is not joined to any of our rooms + rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}} + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clearDB := mustCreateKeyserverDB(t, dbType) + defer clearDB() + + // This should not get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil { + t.Error(err) + } + + // this one should get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil { + t.Error(err) + } + + updater := NewDeviceListUpdater(processCtx, db, nil, + nil, nil, + 0, rsAPI, "test") + if err := updater.CleanUp(); err != nil { + t.Error(err) + } + + // check that we still have Alice in our stale list + staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Error(err) + } + + // There should only be Alice + wantCount := 1 + if count := len(staleUsers); count != wantCount { + t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count) + } + + if staleUsers[0] != alice.ID { + t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID) + } + }) +} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 5360c06fd..275576773 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -18,6 +18,8 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/consumers" @@ -40,6 +42,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI, enableMetr // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, + rsAPI rsapi.KeyserverRoomserverAPI, ) api.KeyInternalAPI { js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) @@ -47,6 +50,7 @@ func NewInternalAPI( if err != nil { logrus.WithError(err).Panicf("failed to connect to key server database") } + keyChangeProducer := &producers.KeyChange{ Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)), JetStream: js, @@ -58,8 +62,14 @@ func NewInternalAPI( FedClient: fedClient, Producer: keyChangeProducer, } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable ap.Updater = updater + + // Remove users which we don't share a room with anymore + if err := updater.CleanUp(); err != nil { + logrus.WithError(err).Error("failed to cleanup stale device lists") + } + go func() { if err := updater.Start(); err != nil { logrus.WithError(err).Panicf("failed to start device list updater") diff --git a/keyserver/keyserver_test.go b/keyserver/keyserver_test.go new file mode 100644 index 000000000..159b280f5 --- /dev/null +++ b/keyserver/keyserver_test.go @@ -0,0 +1,29 @@ +package keyserver + +import ( + "context" + "testing" + + roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +type mockKeyserverRoomserverAPI struct { + leftUsers []string +} + +func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { + res.LeftUsers = m.leftUsers + return nil +} + +// Merely tests that we can create an internal keyserver API +func Test_NewInternalAPI(t *testing.T) { + rsAPI := &mockKeyserverRoomserverAPI{} + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, closeBase := testrig.CreateBaseDendrite(t, dbType) + defer closeBase() + _ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + }) +} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 242e16a06..c6a8f44cd 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -85,4 +85,9 @@ type Database interface { StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + + DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, + ) error } diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go index d0fe50d00..248ddfb45 100644 --- a/keyserver/storage/postgres/stale_device_lists.go +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -19,6 +19,10 @@ import ( "database/sql" "time" + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" @@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)" + type staleDeviceListsStatements struct { upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt + deleteStaleDeviceListsStmt *sql.Stmt } func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { @@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, + }.Prepare(db) } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { @@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt) + _, err := stmt.ExecContext(ctx, pq.Array(userIDs)) + return err +} + func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for rows.Next() { diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5beeed0f1..54dd6ddc9 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -249,3 +249,13 @@ func (d *Database) StoreCrossSigningSigsForTarget( return nil }) } + +// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. +func (d *Database) DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) + }) +} diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go index 1e08b266c..fd76a6e3b 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -17,8 +17,11 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" @@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)" + type staleDeviceListsStatements struct { db *sql.DB upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt + // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime } func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { @@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime + }.Prepare(db) } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { @@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + stmt, err := s.db.Prepare(qry) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed") + stmt = sqlutil.TxStmt(txn, stmt) + + params := make([]any, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for rows.Next() { diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 37a010a7c..24da1125e 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -56,6 +56,7 @@ type KeyChanges interface { type StaleDeviceLists interface { InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error } type CrossSigningKeys interface { diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/keyserver/storage/tables/stale_device_lists_test.go new file mode 100644 index 000000000..76d3baddd --- /dev/null +++ b/keyserver/storage/tables/stale_device_lists_test.go @@ -0,0 +1,94 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/config" + + "github.com/matrix-org/dendrite/keyserver/storage/postgres" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/test" +) + +func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, nil) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresStaleDeviceListsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db) + } + if err != nil { + t.Fatalf("failed to create new table: %s", err) + } + return tab, close +} + +func TestStaleDeviceLists(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := "@charlie:localhost" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateTable(t, dbType) + defer closeDB() + + if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + + // Query one server + wantStaleUsers := []string{alice.ID, bob.ID} + gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Query all servers + wantStaleUsers = []string{alice.ID, bob.ID, charlie} + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Delete stale devices + deleteUsers := []string{alice.ID, bob.ID} + if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil { + t.Fatalf("failed to delete stale device lists: %s", err) + } + + // Verify we don't get anything back after deleting + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + + if gotCount := len(gotStaleUsers); gotCount > 0 { + t.Fatalf("expected no stale users, got %d", gotCount) + } + }) +} diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 01e87ec8a..420ef278a 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -17,6 +17,7 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI + KeyserverRoomserverAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -199,3 +200,7 @@ type FederationRoomserverAPI interface { // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error } + +type KeyserverRoomserverAPI interface { + QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error +} diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 342a3904c..b23263d17 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -19,6 +19,12 @@ type RoomserverInternalAPITrace struct { Impl RoomserverInternalAPI } +func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error { + err := t.Impl.QueryLeftUsers(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryLeftUsers req=%+v res=%+v", js(req), js(res)) + return err +} + func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { t.Impl.SetFederationAPI(fsAPI, keyRing) } diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b62907f3c..76f8298ca 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -447,3 +447,15 @@ type QueryMembershipAtEventResponse struct { // do not have known state will return an empty array here. Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` } + +// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a +// a room with anymore. This is used to cleanup stale device list entries, where we would +// otherwise keep on trying to get device lists. +type QueryLeftUsersRequest struct { + StaleDeviceListUsers []string `json:"user_ids"` +} + +// QueryLeftUsersResponse is the response to QueryLeftUsersRequest. +type QueryLeftUsersResponse struct { + LeftUsers []string `json:"user_ids"` +} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index d8456fb43..69d841dda 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -805,6 +805,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS return nil } +func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error { + var err error + res.LeftUsers, err = r.DB.GetLeftUsers(ctx, req.StaleDeviceListUsers) + return err +} + func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") if err != nil { diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 1bd1b3fb7..8a2e0a03c 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -63,6 +63,7 @@ const ( RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent" + RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers" ) type httpRoomserverInternalAPI struct { @@ -553,3 +554,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, h.httpClient, ctx, request, response, ) } + +func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error { + return httputil.CallInternalRPCAPI( + "RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath, + h.httpClient, ctx, request, response, + ) +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 6e7c2d985..4d21909b7 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -203,4 +203,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe RoomserverQueryMembershipAtEventPath, httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent), ) + + internalAPIMux.Handle( + RoomserverQueryLeftMembersPath, + httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", enableMetrics, r.QueryLeftUsers), + ) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 24b5515e5..518bb3722 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -2,20 +2,27 @@ package roomserver_test import ( "context" + "net/http" "testing" + "time" + "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/gomatrixserverlib" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { + t.Helper() base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches) + db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches) if err != nil { t.Fatalf("failed to create Database: %v", err) } @@ -67,3 +74,69 @@ func Test_SharedUsers(t *testing.T) { } }) } + +func Test_QueryLeftUsers(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite and join Bob + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, _, close := mustCreateDatabase(t, dbType) + defer close() + + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // Query the left users, there should only be "@idontexist:test", + // as Alice and Bob are still joined. + res := &api.QueryLeftUsersResponse{} + leftUserID := "@idontexist:test" + getLeftUsersList := []string{alice.ID, bob.ID, leftUserID} + + testCase := func(rsAPI api.RoomserverInternalAPI) { + if err := rsAPI.QueryLeftUsers(ctx, &api.QueryLeftUsersRequest{StaleDeviceListUsers: getLeftUsersList}, res); err != nil { + t.Fatalf("unable to query left users: %v", err) + } + wantCount := 1 + if count := len(res.LeftUsers); count > wantCount { + t.Fatalf("unexpected left users count: want %d, got %d", wantCount, count) + } + if res.LeftUsers[0] != leftUserID { + t.Fatalf("unexpected left users : want %s, got %s", leftUserID, res.LeftUsers[0]) + } + } + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + roomserver.AddInternalRoutes(router, rsAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + httpAPI, err := inthttp.NewRoomserverClient(apiURL, &http.Client{Timeout: time.Second * 5}, nil) + if err != nil { + t.Fatalf("failed to create HTTP client") + } + testCase(httpAPI) + }) + t.Run("Monolith", func(t *testing.T) { + testCase(rsAPI) + // also test tracing + traceAPI := &api.RoomserverInternalAPITrace{Impl: rsAPI} + testCase(traceAPI) + }) + + }) +} diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 06db4b2d8..92bc2e66f 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -172,5 +172,6 @@ type Database interface { ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) + GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 0150534e1..d774b7892 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -21,12 +21,13 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const membershipSchema = ` @@ -157,6 +158,12 @@ const selectServerInRoomSQL = "" + " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" +const selectJoinedUsersSQL = ` +SELECT DISTINCT target_nid +FROM roomserver_membership m +WHERE membership_nid > $1 AND target_nid = ANY($2) +` + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -174,6 +181,7 @@ type membershipStatements struct { selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt deleteMembershipStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt } func CreateMembershipTable(db *sql.DB) error { @@ -209,9 +217,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL}, {&s.deleteMembershipStmt, deleteMembershipSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, }.Prepare(db) } +func (s *membershipStatements) SelectJoinedUsers( + ctx context.Context, txn *sql.Tx, + targetUserNIDs []types.EventStateKeyNID, +) ([]types.EventStateKeyNID, error) { + result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs)) + + stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt) + rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed") + var targetNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&targetNID); err != nil { + return nil, err + } + result = append(result, targetNID) + } + + return result, rows.Err() +} + func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 16898bcb1..725cc5bc7 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1365,6 +1365,43 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ return result, nil } +// GetLeftUsers calculates users we (the server) don't share a room with anymore. +func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) { + // Get the userNID for all users with a stale device list + stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs) + if err != nil { + return nil, err + } + + userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)) + userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap)) + // Create a map from userNID -> userID + for userID, nid := range stateKeyNIDMap { + userNIDs = append(userNIDs, nid) + userNIDtoUserID[nid] = userID + } + + // Get all users whose membership is still join, knock or invite. + stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs) + if err != nil { + return nil, err + } + + // Remove joined users from the "user with stale devices" list, which contains left AND joined users + for _, joinedUser := range stillJoinedUsersNIDs { + delete(userNIDtoUserID, joinedUser) + } + + // The users still in our userNIDtoUserID map are the users we don't share a room with anymore, + // and the return value we are looking for. + leftUsers := make([]string, 0, len(userNIDtoUserID)) + for _, userID := range userNIDtoUserID { + leftUsers = append(leftUsers, userID) + } + + return leftUsers, nil +} + // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID) diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go new file mode 100644 index 000000000..58724340c --- /dev/null +++ b/roomserver/storage/shared/storage_test.go @@ -0,0 +1,96 @@ +package shared_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Database, func()) { + t.Helper() + + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + base, _, _ := testrig.Base(nil) + dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)} + + db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + + var membershipTable tables.Membership + var stateKeyTable tables.EventStateKeys + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateEventStateKeysTable(db) + assert.NoError(t, err) + err = postgres.CreateMembershipTable(db) + assert.NoError(t, err) + membershipTable, err = postgres.PrepareMembershipTable(db) + assert.NoError(t, err) + stateKeyTable, err = postgres.PrepareEventStateKeysTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateEventStateKeysTable(db) + assert.NoError(t, err) + err = sqlite3.CreateMembershipTable(db) + assert.NoError(t, err) + membershipTable, err = sqlite3.PrepareMembershipTable(db) + assert.NoError(t, err) + stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db) + } + assert.NoError(t, err) + + return &shared.Database{ + DB: db, + EventStateKeysTable: stateKeyTable, + MembershipTable: membershipTable, + Writer: sqlutil.NewExclusiveWriter(), + }, func() { + err := base.Close() + assert.NoError(t, err) + clearDB() + err = db.Close() + assert.NoError(t, err) + } +} + +func Test_GetLeftUsers(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + // Create dummy entries + for _, user := range []*test.User{alice, bob, charlie} { + nid, err := db.EventStateKeysTable.InsertEventStateKeyNID(ctx, nil, user.ID) + assert.NoError(t, err) + err = db.MembershipTable.InsertMembership(ctx, nil, 1, nid, true) + assert.NoError(t, err) + // We must update the membership with a non-zero event NID or it will get filtered out in later queries + membershipNID := tables.MembershipStateLeaveOrBan + if user == alice { + membershipNID = tables.MembershipStateJoin + } + _, err = db.MembershipTable.UpdateMembership(ctx, nil, 1, nid, nid, membershipNID, 1, false) + assert.NoError(t, err) + } + + // Now try to get the left users, this should be Bob and Charlie, since they have a "leave" membership + expectedUserIDs := []string{bob.ID, charlie.ID} + leftUsers, err := db.GetLeftUsers(context.Background(), []string{alice.ID, bob.ID, charlie.ID}) + assert.NoError(t, err) + assert.ElementsMatch(t, expectedUserIDs, leftUsers) + }) +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index cd149f0ed..8a60b359f 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -21,12 +21,13 @@ import ( "fmt" "strings" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const membershipSchema = ` @@ -133,6 +134,12 @@ const selectServerInRoomSQL = "" + const deleteMembershipSQL = "" + "DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2" +const selectJoinedUsersSQL = ` +SELECT DISTINCT target_nid +FROM roomserver_membership m +WHERE membership_nid > $1 AND target_nid IN ($2) +` + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -149,6 +156,7 @@ type membershipStatements struct { selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt deleteMembershipStmt *sql.Stmt + // selectJoinedUsersStmt *sql.Stmt // Prepared at runtime } func CreateMembershipTable(db *sql.DB) error { @@ -412,3 +420,40 @@ func (s *membershipStatements) DeleteMembership( ) return err } + +func (s *membershipStatements) SelectJoinedUsers( + ctx context.Context, txn *sql.Tx, + targetUserNIDs []types.EventStateKeyNID, +) ([]types.EventStateKeyNID, error) { + result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs)) + + qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1) + + stmt, err := s.db.Prepare(qry) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed") + + params := make([]any, len(targetUserNIDs)+1) + params[0] = tables.MembershipStateLeaveOrBan + for i := range targetUserNIDs { + params[i+1] = targetUserNIDs[i] + } + + stmt = sqlutil.TxStmt(txn, stmt) + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed") + var targetNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&targetNID); err != nil { + return nil, err + } + result = append(result, targetNID) + } + + return result, rows.Err() +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 50d27c756..80fcf72dd 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -144,6 +144,7 @@ type Membership interface { SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error + SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error) } type Published interface { diff --git a/roomserver/storage/tables/membership_table_test.go b/roomserver/storage/tables/membership_table_test.go index c9541d9d2..c4524ee44 100644 --- a/roomserver/storage/tables/membership_table_test.go +++ b/roomserver/storage/tables/membership_table_test.go @@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) { knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2) assert.NoError(t, err) assert.Equal(t, 1, len(knownUsers)) + + // get users we share a room with, given their userNID + joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs) + assert.NoError(t, err) + // Only userNIDs[0] is actually joined, so we only expect this userNID + assert.Equal(t, userNIDs[:1], joinedUsers) }) } From 76db8e90defdfb9e61f6caea8a312c5d60bcc005 Mon Sep 17 00:00:00 2001 From: Kento Okamoto Date: Mon, 12 Dec 2022 08:46:37 -0800 Subject: [PATCH 063/124] Dendrite Documentation Fix (#2913) ### Pull Request Checklist I was reading through the Dendrite documentation on https://matrix-org.github.io/dendrite/development/contributing and noticed the installation link leads to a 404 error. This link works fine if it is viewed directly from [docs/CONTRIBUTING.md](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md) but this might not be very obvious to new contributors who are reading through the [contribution page](https://matrix-org.github.io/dendrite/development/contributing) directly. This PR is mainly a small re-organization of the online documentation mainly in the [Development](https://matrix-org.github.io/dendrite/development) tab along with any links throughout the doc that may be impacted by the change. This does not contain any Go unit tests as this does not actually touch core dendrite functionality. * [ ] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Kento Okamoto ` --- docs/FAQ.md | 4 ++-- docs/{ => development}/CONTRIBUTING.md | 6 +++--- docs/{ => development}/PROFILING.md | 0 docs/{ => development}/coverage.md | 0 docs/{ => development}/sytest.md | 0 docs/{ => development}/tracing/opentracing.md | 0 docs/{ => development}/tracing/setup.md | 0 7 files changed, 5 insertions(+), 5 deletions(-) rename docs/{ => development}/CONTRIBUTING.md (97%) rename docs/{ => development}/PROFILING.md (100%) rename docs/{ => development}/coverage.md (100%) rename docs/{ => development}/sytest.md (100%) rename docs/{ => development}/tracing/opentracing.md (100%) rename docs/{ => development}/tracing/setup.md (100%) diff --git a/docs/FAQ.md b/docs/FAQ.md index ca72b151d..816130515 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -91,7 +91,7 @@ Please use PostgreSQL wherever possible, especially if you are planning to run a ## Dendrite is using a lot of CPU Generally speaking, you should expect to see some CPU spikes, particularly if you are joining or participating in large rooms. However, constant/sustained high CPU usage is not expected - if you are experiencing that, please join `#dendrite-dev:matrix.org` and let us know what you were doing when the -CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](PROFILING.md) then that would +CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](development/PROFILING.md) then that would be a huge help too, as that will help us to understand where the CPU time is going. ## Dendrite is using a lot of RAM @@ -99,7 +99,7 @@ be a huge help too, as that will help us to understand where the CPU time is goi As above with CPU usage, some memory spikes are expected if Dendrite is doing particularly heavy work at a given instant. However, if it is using more RAM than you expect for a long time, that's probably not expected. Join `#dendrite-dev:matrix.org` and let us know what you were doing when the memory usage -ballooned, or file a GitHub issue if you can. If you can take a [memory profile](PROFILING.md) then that +ballooned, or file a GitHub issue if you can. If you can take a [memory profile](development/PROFILING.md) then that would be a huge help too, as that will help us to understand where the memory usage is happening. ## Dendrite is running out of PostgreSQL database connections diff --git a/docs/CONTRIBUTING.md b/docs/development/CONTRIBUTING.md similarity index 97% rename from docs/CONTRIBUTING.md rename to docs/development/CONTRIBUTING.md index 21b0e7ab1..2aec4c363 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/development/CONTRIBUTING.md @@ -9,7 +9,7 @@ permalink: /development/contributing Everyone is welcome to contribute to Dendrite! We aim to make it as easy as possible to get started. - ## Contribution types +## Contribution types We are a small team maintaining a large project. As a result, we cannot merge every feature, even if it is bug-free and useful, because we then commit to maintaining it indefinitely. We will always accept: @@ -57,7 +57,7 @@ to do so for future contributions. ## Getting up and running -See the [Installation](installation) section for information on how to build an +See the [Installation](../installation) section for information on how to build an instance of Dendrite. You will likely need this in order to test your changes. ## Code style @@ -151,7 +151,7 @@ significant amount of CPU and RAM. Once the code builds, run [Sytest](https://github.com/matrix-org/sytest) according to the guide in -[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/sytest.md#using-a-sytest-docker-image) +[docs/development/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/development/sytest.md#using-a-sytest-docker-image) so you can see whether something is being broken and whether there are newly passing tests. diff --git a/docs/PROFILING.md b/docs/development/PROFILING.md similarity index 100% rename from docs/PROFILING.md rename to docs/development/PROFILING.md diff --git a/docs/coverage.md b/docs/development/coverage.md similarity index 100% rename from docs/coverage.md rename to docs/development/coverage.md diff --git a/docs/sytest.md b/docs/development/sytest.md similarity index 100% rename from docs/sytest.md rename to docs/development/sytest.md diff --git a/docs/tracing/opentracing.md b/docs/development/tracing/opentracing.md similarity index 100% rename from docs/tracing/opentracing.md rename to docs/development/tracing/opentracing.md diff --git a/docs/tracing/setup.md b/docs/development/tracing/setup.md similarity index 100% rename from docs/tracing/setup.md rename to docs/development/tracing/setup.md From d3db542fbf5b35377586567b21bc5c28872167a1 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 10:56:20 +0100 Subject: [PATCH 064/124] Add federation peeking table tests (#2920) As the title says, adds tests for inbound/outbound peeking federation table tests. Also removes some unused code --- federationapi/queue/queue_test.go | 22 --- federationapi/storage/interface.go | 3 - .../storage/postgres/inbound_peeks_table.go | 32 ++-- .../storage/postgres/outbound_peeks_table.go | 31 ++-- .../storage/postgres/queue_edus_table.go | 21 --- .../storage/postgres/queue_pdus_table.go | 23 --- federationapi/storage/shared/storage_edus.go | 9 - federationapi/storage/shared/storage_pdus.go | 9 - .../storage/sqlite3/inbound_peeks_table.go | 32 ++-- .../storage/sqlite3/outbound_peeks_table.go | 31 ++-- .../storage/sqlite3/queue_edus_table.go | 21 --- .../storage/sqlite3/queue_pdus_table.go | 23 --- federationapi/storage/storage_test.go | 166 ++++++++++++++++++ .../tables/inbound_peeks_table_test.go | 148 ++++++++++++++++ federationapi/storage/tables/interface.go | 2 - .../tables/outbound_peeks_table_test.go | 147 ++++++++++++++++ 16 files changed, 503 insertions(+), 217 deletions(-) create mode 100644 federationapi/storage/tables/inbound_peeks_table_test.go create mode 100644 federationapi/storage/tables/outbound_peeks_table_test.go diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index b2ec4b836..c317edc21 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -221,28 +221,6 @@ func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverl return nil } -func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var count int64 - if pdus, ok := d.associatedPDUs[serverName]; ok { - count = int64(len(pdus)) - } - return count, nil -} - -func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var count int64 - if edus, ok := d.associatedEDUs[serverName]; ok { - count = int64(len(edus)) - } - return count, nil -} - func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index b15b8bfae..276cd9a50 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -45,9 +45,6 @@ type Database interface { CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error - GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) diff --git a/federationapi/storage/postgres/inbound_peeks_table.go b/federationapi/storage/postgres/inbound_peeks_table.go index df5c60761..ad2afcb15 100644 --- a/federationapi/storage/postgres/inbound_peeks_table.go +++ b/federationapi/storage/postgres/inbound_peeks_table.go @@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectInboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER by creation_ts" const renewInboundPeekSQL = "" + "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteInboundPeekSQL = "" + - "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteInboundPeeksSQL = "" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" @@ -74,25 +74,15 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er return } - if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { - return - } - if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil { - return - } - if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil { - return - } - if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertInboundPeekStmt, insertInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeeksStmt, selectInboundPeeksSQL}, + {&s.renewInboundPeekStmt, renewInboundPeekSQL}, + {&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL}, + {&s.deleteInboundPeekStmt, deleteInboundPeekSQL}, + }.Prepare(db) } func (s *inboundPeeksStatements) InsertInboundPeek( diff --git a/federationapi/storage/postgres/outbound_peeks_table.go b/federationapi/storage/postgres/outbound_peeks_table.go index c22d893f7..5df684318 100644 --- a/federationapi/storage/postgres/outbound_peeks_table.go +++ b/federationapi/storage/postgres/outbound_peeks_table.go @@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewOutboundPeekSQL = "" + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteOutboundPeeksSQL = "" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" @@ -74,25 +74,14 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err return } - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertOutboundPeekStmt, insertOutboundPeekSQL}, + {&s.selectOutboundPeekStmt, selectOutboundPeekSQL}, + {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL}, + {&s.renewOutboundPeekStmt, renewOutboundPeekSQL}, + {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL}, + {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL}, + }.Prepare(db) } func (s *outboundPeeksStatements) InsertOutboundPeek( diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go index d6507e13b..8870dc88d 100644 --- a/federationapi/storage/postgres/queue_edus_table.go +++ b/federationapi/storage/postgres/queue_edus_table.go @@ -62,10 +62,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_edus" + " WHERE json_nid = $1" -const selectQueueEDUCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_edus" + - " WHERE server_name = $1" - const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" @@ -81,7 +77,6 @@ type queueEDUsStatements struct { deleteQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt - selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt selectExpiredEDUsStmt *sql.Stmt deleteExpiredEDUsStmt *sql.Stmt @@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error { {&s.deleteQueueEDUStmt, deleteQueueEDUSQL}, {&s.selectQueueEDUStmt, selectQueueEDUSQL}, {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, - {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, @@ -186,21 +180,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( return count, err } -func (s *queueEDUsStatements) SelectQueueEDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { diff --git a/federationapi/storage/postgres/queue_pdus_table.go b/federationapi/storage/postgres/queue_pdus_table.go index 38ac5a6eb..3b0bef9af 100644 --- a/federationapi/storage/postgres/queue_pdus_table.go +++ b/federationapi/storage/postgres/queue_pdus_table.go @@ -58,10 +58,6 @@ const selectQueuePDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" -const selectQueuePDUsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_pdus" + - " WHERE server_name = $1" - const selectQueuePDUServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" @@ -71,7 +67,6 @@ type queuePDUsStatements struct { deleteQueuePDUsStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt selectQueuePDUReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt selectQueuePDUServerNamesStmt *sql.Stmt } @@ -95,9 +90,6 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil { - return - } if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil { return } @@ -146,21 +138,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) SelectQueuePDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index c796d2f8f..be8355f31 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -162,15 +162,6 @@ func (d *Database) CleanEDUs( }) } -// GetPendingEDUCount returns the number of EDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingEDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName) -} - // GetPendingServerNames returns the server names that have EDUs // waiting to be sent. func (d *Database) GetPendingEDUServerNames( diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index dc37d7507..da4cb979d 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -141,15 +141,6 @@ func (d *Database) CleanPDUs( }) } -// GetPendingPDUCount returns the number of PDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingPDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName) -} - // GetPendingServerNames returns the server names that have PDUs // waiting to be sent. func (d *Database) GetPendingPDUServerNames( diff --git a/federationapi/storage/sqlite3/inbound_peeks_table.go b/federationapi/storage/sqlite3/inbound_peeks_table.go index ad3c4a6dd..8c3567934 100644 --- a/federationapi/storage/sqlite3/inbound_peeks_table.go +++ b/federationapi/storage/sqlite3/inbound_peeks_table.go @@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectInboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewInboundPeekSQL = "" + "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteInboundPeekSQL = "" + - "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteInboundPeeksSQL = "" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" @@ -74,25 +74,15 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro return } - if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { - return - } - if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil { - return - } - if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil { - return - } - if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertInboundPeekStmt, insertInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeeksStmt, selectInboundPeeksSQL}, + {&s.renewInboundPeekStmt, renewInboundPeekSQL}, + {&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL}, + {&s.deleteInboundPeekStmt, deleteInboundPeekSQL}, + }.Prepare(db) } func (s *inboundPeeksStatements) InsertInboundPeek( diff --git a/federationapi/storage/sqlite3/outbound_peeks_table.go b/federationapi/storage/sqlite3/outbound_peeks_table.go index e29026fab..33f452b68 100644 --- a/federationapi/storage/sqlite3/outbound_peeks_table.go +++ b/federationapi/storage/sqlite3/outbound_peeks_table.go @@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewOutboundPeekSQL = "" + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteOutboundPeeksSQL = "" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" @@ -74,25 +74,14 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er return } - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertOutboundPeekStmt, insertOutboundPeekSQL}, + {&s.selectOutboundPeekStmt, selectOutboundPeekSQL}, + {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL}, + {&s.renewOutboundPeekStmt, renewOutboundPeekSQL}, + {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL}, + {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL}, + }.Prepare(db) } func (s *outboundPeeksStatements) InsertOutboundPeek( diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go index 8e7e7901f..0dc914328 100644 --- a/federationapi/storage/sqlite3/queue_edus_table.go +++ b/federationapi/storage/sqlite3/queue_edus_table.go @@ -63,10 +63,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_edus" + " WHERE json_nid = $1" -const selectQueueEDUCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_edus" + - " WHERE server_name = $1" - const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" @@ -82,7 +78,6 @@ type queueEDUsStatements struct { // deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt - selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt selectExpiredEDUsStmt *sql.Stmt deleteExpiredEDUsStmt *sql.Stmt @@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error { {&s.insertQueueEDUStmt, insertQueueEDUSQL}, {&s.selectQueueEDUStmt, selectQueueEDUSQL}, {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, - {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, @@ -198,21 +192,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( return count, err } -func (s *queueEDUsStatements) SelectQueueEDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { diff --git a/federationapi/storage/sqlite3/queue_pdus_table.go b/federationapi/storage/sqlite3/queue_pdus_table.go index e818585a5..aee8b03d6 100644 --- a/federationapi/storage/sqlite3/queue_pdus_table.go +++ b/federationapi/storage/sqlite3/queue_pdus_table.go @@ -66,10 +66,6 @@ const selectQueuePDUsReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" -const selectQueuePDUsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_pdus" + - " WHERE server_name = $1" - const selectQueuePDUsServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" @@ -79,7 +75,6 @@ type queuePDUsStatements struct { selectQueueNextTransactionIDStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt selectQueueServerNamesStmt *sql.Stmt // deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic } @@ -107,9 +102,6 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { - return - } if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil { return } @@ -179,21 +171,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) SelectQueuePDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index f7408fa9f..14efa2655 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -2,10 +2,12 @@ package storage_test import ( "context" + "reflect" "testing" "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/federationapi/storage" @@ -80,3 +82,167 @@ func TestExpireEDUs(t *testing.T) { assert.Equal(t, 2, len(data)) }) } + +func TestOutboundPeeking(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + + // Add outbound peek + if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if outboundPeek1.PeekID != peekID { + t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) + } + if outboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) + } + if outboundPeek1.ServerName != serverName { + t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) + } + if outboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) + } + // Renew the peek + if err = db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(outboundPeek1, outboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if outboundPeek1.ServerName != outboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) + } + if outboundPeek1.RoomID != outboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) + } + + // insert some peeks + peekIDs := []string{peekID} + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) + } + for i := range outboundPeeks { + if outboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i]) + } + } + }) +} + +func TestInboundPeeking(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + + // Add inbound peek + if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if inboundPeek1.PeekID != peekID { + t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID) + } + if inboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID) + } + if inboundPeek1.ServerName != serverName { + t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName) + } + if inboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval) + } + // Renew the peek + if err = db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(inboundPeek1, inboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if inboundPeek1.ServerName != inboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName) + } + if inboundPeek1.RoomID != inboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID) + } + + // insert some peeks + peekIDs := []string{peekID} + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) + } + for i := range inboundPeeks { + if inboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i]) + } + } + }) +} diff --git a/federationapi/storage/tables/inbound_peeks_table_test.go b/federationapi/storage/tables/inbound_peeks_table_test.go new file mode 100644 index 000000000..3a76a8576 --- /dev/null +++ b/federationapi/storage/tables/inbound_peeks_table_test.go @@ -0,0 +1,148 @@ +package tables_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + var tab tables.FederationInboundPeeks + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresInboundPeeksTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteInboundPeeksTable(db) + } + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + return tab, close +} + +func TestInboundPeeksTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateInboundpeeksTable(t, dbType) + defer closeDB() + + // Insert a peek + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + if err := tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + inboundPeek1, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if inboundPeek1.PeekID != peekID { + t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID) + } + if inboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID) + } + if inboundPeek1.ServerName != serverName { + t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName) + } + if inboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval) + } + + // Renew the peek + if err = tab.RenewInboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + inboundPeek2, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(inboundPeek1, inboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if inboundPeek1.ServerName != inboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName) + } + if inboundPeek1.RoomID != inboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID) + } + + // delete the peek + if err = tab.DeleteInboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil { + t.Fatal(err) + } + + // There should be no peek anymore + peek, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if peek != nil { + t.Fatalf("got a peek which should be deleted: %+v", peek) + } + + // insert some peeks + var peekIDs []string + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + inboundPeeks, err := tab.SelectInboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) + } + for i := range inboundPeeks { + if inboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("") + } + } + + // And delete them again + if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil { + t.Fatal(err) + } + + // they should be gone now + inboundPeeks, err = tab.SelectInboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) > 0 { + t.Fatal("got inbound peeks which should be deleted") + } + + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 9f4e86a6e..2b36edb46 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -28,7 +28,6 @@ type FederationQueuePDUs interface { InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) } @@ -38,7 +37,6 @@ type FederationQueueEDUs interface { DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error) DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error diff --git a/federationapi/storage/tables/outbound_peeks_table_test.go b/federationapi/storage/tables/outbound_peeks_table_test.go new file mode 100644 index 000000000..dad6b9825 --- /dev/null +++ b/federationapi/storage/tables/outbound_peeks_table_test.go @@ -0,0 +1,147 @@ +package tables_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + var tab tables.FederationOutboundPeeks + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresOutboundPeeksTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteOutboundPeeksTable(db) + } + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + return tab, close +} + +func TestOutboundPeeksTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateOutboundpeeksTable(t, dbType) + defer closeDB() + + // Insert a peek + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + outboundPeek1, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if outboundPeek1.PeekID != peekID { + t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) + } + if outboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) + } + if outboundPeek1.ServerName != serverName { + t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) + } + if outboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) + } + + // Renew the peek + if err = tab.RenewOutboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + outboundPeek2, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(outboundPeek1, outboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if outboundPeek1.ServerName != outboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) + } + if outboundPeek1.RoomID != outboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) + } + + // delete the peek + if err = tab.DeleteOutboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil { + t.Fatal(err) + } + + // There should be no peek anymore + peek, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if peek != nil { + t.Fatalf("got a peek which should be deleted: %+v", peek) + } + + // insert some peeks + var peekIDs []string + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + outboundPeeks, err := tab.SelectOutboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) + } + for i := range outboundPeeks { + if outboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("") + } + } + + // And delete them again + if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil { + t.Fatal(err) + } + + // they should be gone now + outboundPeeks, err = tab.SelectOutboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) > 0 { + t.Fatal("got outbound peeks which should be deleted") + } + }) +} From beea2432e6144a98370138f8d3f6334c19a044bb Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 11:31:54 +0100 Subject: [PATCH 065/124] Fix flakey test --- federationapi/storage/tables/inbound_peeks_table_test.go | 9 +++++---- .../storage/tables/outbound_peeks_table_test.go | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/federationapi/storage/tables/inbound_peeks_table_test.go b/federationapi/storage/tables/inbound_peeks_table_test.go index 3a76a8576..e5d898b3a 100644 --- a/federationapi/storage/tables/inbound_peeks_table_test.go +++ b/federationapi/storage/tables/inbound_peeks_table_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) { @@ -124,11 +125,11 @@ func TestInboundPeeksTable(t *testing.T) { if len(inboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) } - for i := range inboundPeeks { - if inboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("") - } + gotPeekIDs := make([]string, 0, len(inboundPeeks)) + for _, p := range inboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) // And delete them again if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil { diff --git a/federationapi/storage/tables/outbound_peeks_table_test.go b/federationapi/storage/tables/outbound_peeks_table_test.go index dad6b9825..a460af09d 100644 --- a/federationapi/storage/tables/outbound_peeks_table_test.go +++ b/federationapi/storage/tables/outbound_peeks_table_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) { @@ -124,11 +125,11 @@ func TestOutboundPeeksTable(t *testing.T) { if len(outboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) } - for i := range outboundPeeks { - if outboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("") - } + gotPeekIDs := make([]string, 0, len(outboundPeeks)) + for _, p := range outboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) // And delete them again if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil { From d1d2d16738a248846ea4367fe2b33485d56db6cd Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 11:54:03 +0100 Subject: [PATCH 066/124] Fix reset password endpoint (#2921) Fixes the admin password reset endpoint. It was using a wrong variable, so could not detect the user. Adds some more checks to validate we can actually change the password. --- clientapi/admin_test.go | 134 ++++++++++++++++++++++++++++++ clientapi/clientapi.go | 5 +- clientapi/routing/admin.go | 38 +++++++-- clientapi/routing/password.go | 3 +- clientapi/routing/register.go | 24 +----- clientapi/routing/routing.go | 14 +++- docs/administration/4_adminapi.md | 7 +- internal/httputil/httpapi.go | 6 +- internal/validate.go | 44 ++++++++++ test/user.go | 4 +- 10 files changed, 237 insertions(+), 42 deletions(-) create mode 100644 clientapi/admin_test.go create mode 100644 internal/validate.go diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go new file mode 100644 index 000000000..0d973f350 --- /dev/null +++ b/clientapi/admin_test.go @@ -0,0 +1,134 @@ +package clientapi + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/tidwall/gjson" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestAdminResetPassword(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + vhUser := &test.User{ID: "@vhuser:vh1"} + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + // add a vhost + base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, + }) + + rsAPI := roomserver.NewInternalAPI(base) + // Needed for changing the password/login + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil) + + // Create the users in the userapi and login + accessTokens := map[*test.User]string{ + aliceAdmin: "", + bob: "", + vhUser: "", + } + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + } + + testCases := []struct { + name string + requestingUser *test.User + userID string + requestOpt test.HTTPRequestOpt + wantOK bool + withHeader bool + }{ + {name: "Missing auth", requestingUser: bob, wantOK: false, userID: bob.ID}, + {name: "Bob is denied access", requestingUser: bob, wantOK: false, withHeader: true, userID: bob.ID}, + {name: "Alice is allowed access", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(8), + })}, + {name: "missing userID does not call function", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: ""}, // this 404s + {name: "rejects empty password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": "", + })}, + {name: "rejects unknown server name", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:localhost", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "rejects unknown user", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "allows changing password for different vhost", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: vhUser.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(8), + })}, + {name: "rejects existing user, missing body", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID}, + {name: "rejects invalid userID", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "!notauserid:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "rejects invalid json", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, `{invalidJSON}`)}, + {name: "rejects too weak password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(6), + })}, + {name: "rejects too long password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(513), + })}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID) + if tc.requestOpt != nil { + req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID, tc.requestOpt) + } + + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser]) + } + + rec := httptest.NewRecorder() + base.DendriteAdminMux.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 080d4d9fa..62ffa6155 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -57,10 +57,7 @@ func AddPublicRoutes( } routing.Setup( - base.PublicClientAPIMux, - base.PublicWellKnownAPIMux, - base.SynapseAdminMux, - base.DendriteAdminMux, + base, cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, syncProducer, transactionsCache, fsAPI, keyAPI, diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index be8073c33..8419622df 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -7,6 +7,7 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -98,20 +99,40 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi } func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { + if req.Body == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Unknown("Missing request body"), + } + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - serverName := cfg.Matrix.ServerName - localpart, ok := vars["localpart"] - if !ok { + var localpart string + userID := vars["userID"] + localpart, serverName, err := cfg.Matrix.SplitLocalID('@', userID) + if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting user localpart."), + JSON: jsonerror.InvalidArgumentValue(err.Error()), } } - if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil { - localpart, serverName = l, s + accAvailableResp := &userapi.QueryAccountAvailabilityResponse{} + if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ + Localpart: localpart, + ServerName: serverName, + }, accAvailableResp); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.InternalAPIError(req.Context(), err), + } + } + if accAvailableResp.Available { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.Unknown("User does not exist"), + } } request := struct { Password string `json:"password"` @@ -128,6 +149,11 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.MissingArgument("Expecting non-empty password."), } } + + if resErr := internal.ValidatePassword(request.Password); resErr != nil { + return *resErr + } + updateReq := &userapi.PerformPasswordUpdateRequest{ Localpart: localpart, ServerName: serverName, diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 9772f669a..cd88b025a 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -7,6 +7,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -81,7 +82,7 @@ func Password( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. - if resErr = validatePassword(r.NewPassword); resErr != nil { + if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil { return *resErr } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 801000f61..4abbcdf9e 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -30,6 +30,7 @@ import ( "sync" "time" + "github.com/matrix-org/dendrite/internal" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" @@ -60,8 +61,6 @@ var ( ) const ( - minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based - maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain sessionIDLength = 24 ) @@ -315,23 +314,6 @@ func validateApplicationServiceUsername(localpart string, domain gomatrixserverl return nil } -// validatePassword returns an error response if the password is invalid -func validatePassword(password string) *util.JSONResponse { - // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if len(password) > maxPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)), - } - } else if len(password) > 0 && len(password) < minPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), - } - } - return nil -} - // validateRecaptcha returns an error response if the captcha response is invalid func validateRecaptcha( cfg *config.ClientAPI, @@ -636,7 +618,7 @@ func Register( return *resErr } } - if resErr := validatePassword(r.Password); resErr != nil { + if resErr := internal.ValidatePassword(r.Password); resErr != nil { return *resErr } @@ -1138,7 +1120,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { return *resErr } - if resErr := validatePassword(ssrr.Password); resErr != nil { + if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil { return *resErr } deviceID := "shared_secret_registration" diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index a510761eb..69b46214c 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -49,7 +50,7 @@ import ( // applied: // nolint: gocyclo func Setup( - publicAPIMux, wkMux, synapseAdminRouter, dendriteAdminRouter *mux.Router, + base *base.BaseDendrite, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, @@ -63,7 +64,14 @@ func Setup( extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, natsClient *nats.Conn, ) { - prometheus.MustRegister(amtRegUsers, sendEventDuration) + publicAPIMux := base.PublicClientAPIMux + wkMux := base.PublicWellKnownAPIMux + synapseAdminRouter := base.SynapseAdminMux + dendriteAdminRouter := base.DendriteAdminMux + + if base.EnableMetrics { + prometheus.MustRegister(amtRegUsers, sendEventDuration) + } rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg) @@ -631,7 +639,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { vars := mux.Vars(req) return AuthFallback(w, req, vars["authType"], cfg) }), diff --git a/docs/administration/4_adminapi.md b/docs/administration/4_adminapi.md index 56e19a8b4..c521cbc90 100644 --- a/docs/administration/4_adminapi.md +++ b/docs/administration/4_adminapi.md @@ -44,7 +44,9 @@ This endpoint will instruct Dendrite to part the given local `userID` in the URL all rooms which they are currently joined. A JSON body will be returned containing the room IDs of all affected rooms. -## POST `/_dendrite/admin/resetPassword/{localpart}` +## POST `/_dendrite/admin/resetPassword/{userID}` + +Reset the password of a local user. Request body format: @@ -54,9 +56,6 @@ Request body format: } ``` -Reset the password of a local user. The `localpart` is the username only, i.e. if -the full user ID is `@alice:domain.com` then the local part is `alice`. - ## GET `/_dendrite/admin/fulltext/reindex` This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately. diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 127d1fac7..383913c60 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -198,7 +198,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // MakeHTMLAPI adds Span metrics to the HTML Handler function // This is used to serve HTML alongside JSON error messages -func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { +func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { span := opentracing.StartSpan(metricsName) defer span.Finish() @@ -211,6 +211,10 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) } } + if !enableMetrics { + return http.HandlerFunc(withSpan) + } + return promhttp.InstrumentHandlerCounter( promauto.NewCounterVec( prometheus.CounterOpts{ diff --git a/internal/validate.go b/internal/validate.go new file mode 100644 index 000000000..fc685ad50 --- /dev/null +++ b/internal/validate.go @@ -0,0 +1,44 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "fmt" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/util" +) + +const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based + +const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + +// ValidatePassword returns an error response if the password is invalid +func ValidatePassword(password string) *util.JSONResponse { + // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + if len(password) > maxPasswordLength { + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)), + } + } else if len(password) > 0 && len(password) < minPasswordLength { + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), + } + } + return nil +} diff --git a/test/user.go b/test/user.go index 692eae351..95a8f83e6 100644 --- a/test/user.go +++ b/test/user.go @@ -47,7 +47,7 @@ var ( type User struct { ID string - accountType api.AccountType + AccountType api.AccountType // key ID and private key of the server who has this user, if known. keyID gomatrixserverlib.KeyID privKey ed25519.PrivateKey @@ -66,7 +66,7 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve func WithAccountType(accountType api.AccountType) UserOpt { return func(u *User) { - u.accountType = accountType + u.AccountType = accountType } } From 09dff951d6be1fee1cc7c6872e98eb27e81fc778 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 13:04:32 +0100 Subject: [PATCH 067/124] More flakey tests --- federationapi/storage/storage_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 14efa2655..5b57d40d4 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -157,11 +157,11 @@ func TestOutboundPeeking(t *testing.T) { if len(outboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) } - for i := range outboundPeeks { - if outboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i]) - } + gotPeekIDs := make([]string, 0, len(outboundPeeks)) + for _, p := range outboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } @@ -239,10 +239,10 @@ func TestInboundPeeking(t *testing.T) { if len(inboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) } - for i := range inboundPeeks { - if inboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i]) - } + gotPeekIDs := make([]string, 0, len(inboundPeeks)) + for _, p := range inboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } From 5eed31fea330f5f0500384c98272b9a75a44fba4 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 13:05:59 +0100 Subject: [PATCH 068/124] Handle guest access [1/2?] (#2872) Needs https://github.com/matrix-org/sytest/pull/1315, as otherwise the membership events aren't persisted yet when hitting `/state` after kicking guest users. Makes the following tests pass: ``` Guest users denied access over federation if guest access prohibited Guest users are kicked from guest_access rooms on revocation of guest_access Guest users are kicked from guest_access rooms on revocation of guest_access over federation ``` Todo (in a follow up PR): - Restrict access to CS API Endpoints as per https://spec.matrix.org/v1.4/client-server-api/#client-behaviour-14 Co-authored-by: kegsay --- clientapi/clientapi.go | 3 +- clientapi/routing/joinroom.go | 10 +- clientapi/routing/joinroom_test.go | 158 ++++++++++++++++++++ roomserver/api/perform.go | 1 + roomserver/internal/api.go | 20 ++- roomserver/internal/input/input.go | 4 + roomserver/internal/input/input_events.go | 105 +++++++++++++ roomserver/internal/perform/perform_join.go | 23 +++ roomserver/roomserver_test.go | 141 +++++++++++++---- setup/config/config_global.go | 2 +- setup/config/config_test.go | 54 +++++++ sytest-blacklist | 5 +- sytest-whitelist | 5 +- test/room.go | 22 ++- userapi/api/api.go | 10 ++ userapi/api/api_trace.go | 6 + userapi/internal/api.go | 5 + userapi/inthttp/client.go | 12 ++ userapi/inthttp/server.go | 5 + userapi/userapi_test.go | 61 ++++++++ 20 files changed, 607 insertions(+), 45 deletions(-) create mode 100644 clientapi/routing/joinroom_test.go diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 62ffa6155..2d17e0928 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,6 +15,8 @@ package clientapi import ( + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/producers" @@ -26,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index c50e552bd..e371d9214 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias( joinReq := roomserverAPI.PerformJoinRequest{ RoomIDOrAlias: roomIDOrAlias, UserID: device.UserID, + IsGuest: device.AccountType == api.AccountTypeGuest, Content: map[string]interface{}{}, } joinRes := roomserverAPI.PerformJoinResponse{} @@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias( if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { done <- jsonerror.InternalAPIError(req.Context(), err) } else if joinRes.Error != nil { - done <- joinRes.Error.JSONResponse() + if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest { + done <- util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg), + } + } else { + done <- joinRes.Error.JSONResponse() + } } else { done <- util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go new file mode 100644 index 000000000..9e8208e6d --- /dev/null +++ b/clientapi/routing/joinroom_test.go @@ -0,0 +1,158 @@ +package routing + +import ( + "bytes" + "context" + "net/http" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestJoinRoomByIDOrAlias(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc + + // Create the users in the userapi + for _, u := range []*test.User{alice, bob, charlie} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + } + + aliceDev := &uapi.Device{UserID: alice.ID} + bobDev := &uapi.Device{UserID: bob.ID} + charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest} + + // create a room with disabled guest access and invite Bob + resp := createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + RoomAliasName: "alias", + Invite: []string{bob.ID}, + GuestCanJoin: false, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crResp, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // create a room with guest access enabled and invite Charlie + resp = createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + Invite: []string{charlie.ID}, + GuestCanJoin: true, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // Dummy request + body := &bytes.Buffer{} + req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + device *uapi.Device + roomID string + wantHTTP200 bool + }{ + { + name: "User can join successfully by alias", + device: bobDev, + roomID: crResp.RoomAlias, + wantHTTP200: true, + }, + { + name: "User can join successfully by roomID", + device: bobDev, + roomID: crResp.RoomID, + wantHTTP200: true, + }, + { + name: "join is forbidden if user is guest", + device: charlieDev, + roomID: crResp.RoomID, + }, + { + name: "room does not exist", + device: aliceDev, + roomID: "!doesnotexist:test", + }, + { + name: "user from different server", + device: &uapi.Device{UserID: "@wrong:server"}, + roomID: crResp.RoomAlias, + }, + { + name: "user doesn't exist locally", + device: &uapi.Device{UserID: "@doesnotexist:test"}, + roomID: crResp.RoomAlias, + }, + { + name: "invalid room ID", + device: aliceDev, + roomID: "invalidRoomID", + }, + { + name: "roomAlias does not exist", + device: aliceDev, + roomID: "#doesnotexist:test", + }, + { + name: "room with guest_access event", + device: charlieDev, + roomID: crRespWithGuestAccess.RoomID, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID) + if tc.wantHTTP200 && !joinResp.Is2xx() { + t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp) + } + }) + } + }) +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index e70e5ea9c..e789b9568 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -78,6 +78,7 @@ const ( type PerformJoinRequest struct { RoomIDOrAlias string `json:"room_id_or_alias"` UserID string `json:"user_id"` + IsGuest bool `json:"is_guest"` Content map[string]interface{} `json:"content"` ServerNames []gomatrixserverlib.ServerName `json:"server_names"` Unsigned map[string]interface{} `json:"unsigned"` diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1a3626609..451b37696 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -4,6 +4,10 @@ import ( "context" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/caching" @@ -19,9 +23,6 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" ) // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI @@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.fsAPI = fsAPI r.KeyRing = keyRing + identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName) + if err != nil { + logrus.Panic(err) + } + r.Inputer = &input.Inputer{ Cfg: &r.Base.Cfg.RoomServer, Base: r.Base, @@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio JetStream: r.JetStream, NATSClient: r.NATSClient, Durable: nats.Durable(r.Durable), - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, + SigningIdentity: identity, FSAPI: fsAPI, KeyRing: keyRing, ACLs: r.ServerACLs, @@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Queryer: r.Queryer, } r.Peeker = &perform.Peeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Inputer: r.Inputer, } r.Unpeeker = &perform.Unpeeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { r.Leaver.UserAPI = userAPI + r.Inputer.UserAPI = userAPI } func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index e965691c9..941311030 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -23,6 +23,8 @@ import ( "sync" "time" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -79,6 +81,7 @@ type Inputer struct { JetStream nats.JetStreamContext Durable nats.SubOpt ServerName gomatrixserverlib.ServerName + SigningIdentity *gomatrixserverlib.SigningIdentity FSAPI fedapi.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs @@ -87,6 +90,7 @@ type Inputer struct { workers sync.Map // room ID -> *worker Queryer *query.Queryer + UserAPI userapi.RoomserverUserAPI } // If a room consumer is inactive for a while then we will allow NATS diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 10b8ee27f..4179fc1ef 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -19,6 +19,7 @@ package input import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "time" @@ -31,6 +32,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + userAPI "github.com/matrix-org/dendrite/userapi/api" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" @@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent( } } + // If guest_access changed and is not can_join, kick all guest users. + if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { + if err = r.kickGuests(ctx, event, roomInfo); err != nil { + logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation") + } + } + // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) @@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState( succeeded = true return nil } + +// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited. +func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error { + membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) + if err != nil { + return err + } + + memberEvents, err := r.DB.Events(ctx, membershipNIDs) + if err != nil { + return err + } + + inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) + latestReq := &api.QueryLatestEventsAndStateRequest{ + RoomID: event.RoomID(), + } + latestRes := &api.QueryLatestEventsAndStateResponse{} + if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { + return err + } + + prevEvents := latestRes.LatestEvents + for _, memberEvent := range memberEvents { + if memberEvent.StateKey() == nil { + continue + } + + localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) + if err != nil { + continue + } + + accountRes := &userAPI.QueryAccountByLocalpartResponse{} + if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: senderDomain, + }, accountRes); err != nil { + return err + } + if accountRes.Account == nil { + continue + } + + if accountRes.Account.AccountType != userAPI.AccountTypeGuest { + continue + } + + var memberContent gomatrixserverlib.MemberContent + if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { + return err + } + memberContent.Membership = gomatrixserverlib.Leave + + stateKey := *memberEvent.StateKey() + fledglingEvent := &gomatrixserverlib.EventBuilder{ + RoomID: event.RoomID(), + Type: gomatrixserverlib.MRoomMember, + StateKey: &stateKey, + Sender: stateKey, + PrevEvents: prevEvents, + } + + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { + return err + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) + if err != nil { + return err + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes) + if err != nil { + return err + } + + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + Origin: senderDomain, + SendAsServer: string(senderDomain), + }) + prevEvents = []gomatrixserverlib.EventReference{ + event.EventReference(), + } + } + + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: inputEvents, + Asynchronous: true, // Needs to be async, as we otherwise create a deadlock + } + inputRes := &api.InputRoomEventsResponse{} + return r.InputRoomEvents(ctx, inputReq, inputRes) +} diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 4de008c66..fc7ba940c 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -16,6 +16,7 @@ package perform import ( "context" + "database/sql" "errors" "fmt" "strings" @@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID( } } + // If a guest is trying to join a room, check that the room has a m.room.guest_access event + if req.IsGuest { + var guestAccessEvent *gomatrixserverlib.HeaderedEvent + guestAccess := "forbidden" + guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "") + if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil { + logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'") + } + if guestAccessEvent != nil { + guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String() + } + + // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event + // is present on the room and has the guest_access value can_join. + if guestAccess != "can_join" { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorNotAllowed, + Msg: "Guest access is forbidden", + } + } + } + // If we should do a forced federated join then do that. var joinedVia gomatrixserverlib.ServerName if forceFederatedJoin { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 518bb3722..595ceb526 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -3,18 +3,23 @@ package roomserver_test import ( "context" "net/http" + "reflect" "testing" "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/userapi" + + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" ) @@ -29,7 +34,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s return base, db, close } -func Test_SharedUsers(t *testing.T) { +func TestUsers(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + + t.Run("shared users", func(t *testing.T) { + testSharedUsers(t, rsAPI) + }) + + t.Run("kick users", func(t *testing.T) { + usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + rsAPI.SetUserAPI(usrAPI) + testKickUsers(t, rsAPI, usrAPI) + }) + }) + +} + +func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) { alice := test.NewUser(t) bob := test.NewUser(t) room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) @@ -43,36 +69,93 @@ func Test_SharedUsers(t *testing.T) { }, test.WithStateKey(bob.ID)) ctx := context.Background() - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, _, close := mustCreateDatabase(t, dbType) - defer close() - rsAPI := roomserver.NewInternalAPI(base) - // SetFederationAPI starts the room event input consumer - rsAPI.SetFederationAPI(nil, nil) - // Create the room - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { - t.Fatalf("failed to send events: %v", err) + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Query the shared users for Alice, there should only be Bob. + // This is used by the SyncAPI keychange consumer. + res := &api.QuerySharedUsersResponse{} + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } + // Also verify that we get the expected result when specifying OtherUserIDs. + // This is used by the SyncAPI when getting device list changes. + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } +} + +func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) { + // Create users and room; Bob is going to be the guest and kicked on revocation of guest access + alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest)) + + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true)) + + // Join with the guest user + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + // Create the users in the userapi, so the RSAPI can query the account type later + for _, u := range []*test.User{alice, bob} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &userAPI.PerformAccountCreationResponse{} + if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + // Create the room in the database + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Get the membership events BEFORE revoking guest access + membershipRes := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil { + t.Errorf("failed to query membership for room: %s", err) + } + + // revoke guest access + revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need + // to loop and wait for the events to be processed by the roomserver. + for i := 0; i <= 20; i++ { + // Get the membership events AFTER revoking guest access + membershipRes2 := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil { + t.Errorf("failed to query membership for room: %s", err) } - // Query the shared users for Alice, there should only be Bob. - // This is used by the SyncAPI keychange consumer. - res := &api.QuerySharedUsersResponse{} - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) + // The membership events should NOT match, as Bob (guest user) should now be kicked from the room + if !reflect.DeepEqual(membershipRes, membershipRes2) { + return } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) - } - // Also verify that we get the expected result when specifying OtherUserIDs. - // This is used by the SyncAPI when getting device list changes. - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) - } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) - } - }) + time.Sleep(time.Millisecond * 10) + } + + t.Errorf("memberships didn't change in time") } func Test_QueryLeftUsers(t *testing.T) { diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 511951fe6..804eb1a2d 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g return id, nil } } - return nil, fmt.Errorf("no signing identity %q", serverName) + return nil, fmt.Errorf("no signing identity for %q", serverName) } func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index ee7e7389c..3408bf46d 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -16,8 +16,10 @@ package config import ( "fmt" + "reflect" "testing" + "github.com/matrix-org/gomatrixserverlib" "gopkg.in/yaml.v2" ) @@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) { } } } + +func Test_SigningIdentityFor(t *testing.T) { + tests := []struct { + name string + virtualHosts []*VirtualHost + serverName gomatrixserverlib.ServerName + want *gomatrixserverlib.SigningIdentity + wantErr bool + }{ + { + name: "no virtual hosts defined", + wantErr: true, + }, + { + name: "no identity found", + serverName: gomatrixserverlib.ServerName("doesnotexist"), + wantErr: true, + }, + { + name: "found identity", + serverName: gomatrixserverlib.ServerName("main"), + want: &gomatrixserverlib.SigningIdentity{ServerName: "main"}, + }, + { + name: "identity found on virtual hosts", + serverName: gomatrixserverlib.ServerName("vh2"), + virtualHosts: []*VirtualHost{ + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}}, + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}}, + }, + want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Global{ + VirtualHosts: tt.virtualHosts, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "main", + }, + } + got, err := c.SigningIdentityFor(tt.serverName) + if (err != nil) != tt.wantErr { + t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sytest-blacklist b/sytest-blacklist index c35b03bd7..99cfbabc8 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -48,4 +48,7 @@ If a device list update goes missing, the server resyncs on the next one Leaves are present in non-gapped incremental syncs # Below test was passing for the wrong reason, failing correctly since #2858 -New federated private chats get full presence information (SYN-115) \ No newline at end of file +New federated private chats get full presence information (SYN-115) + +# We don't have any state to calculate m.room.guest_access when accepting invites +Guest users can accept invites to private rooms over federation \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 49ffb8fe8..215889a49 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -763,4 +763,7 @@ AS and main public room lists are separate local user has tags copied to the new room remote user has tags copied to the new room /upgrade moves remote aliases to the new room -Local and remote users' homeservers remove a room from their public directory on upgrade \ No newline at end of file +Local and remote users' homeservers remove a room from their public directory on upgrade +Guest users denied access over federation if guest access prohibited +Guest users are kicked from guest_access rooms on revocation of guest_access +Guest users are kicked from guest_access rooms on revocation of guest_access over federation \ No newline at end of file diff --git a/test/room.go b/test/room.go index 4328bf84f..685876cb0 100644 --- a/test/room.go +++ b/test/room.go @@ -38,11 +38,12 @@ var ( ) type Room struct { - ID string - Version gomatrixserverlib.RoomVersion - preset Preset - visibility gomatrixserverlib.HistoryVisibility - creator *User + ID string + Version gomatrixserverlib.RoomVersion + preset Preset + guestCanJoin bool + visibility gomatrixserverlib.HistoryVisibility + creator *User authEvents gomatrixserverlib.AuthEvents currentState map[string]*gomatrixserverlib.HeaderedEvent @@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) { r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) + if r.guestCanJoin { + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{ + "guest_access": "can_join", + }, WithStateKey("")) + } } // Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe. @@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier { r.Version = ver } } + +func GuestsCanJoin(canJoin bool) roomModifier { + return func(t *testing.T, r *Room) { + r.guestCanJoin = canJoin + } +} diff --git a/userapi/api/api.go b/userapi/api/api.go index d3f5aefc8..4ea2e91c3 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -50,6 +50,7 @@ type KeyserverUserAPI interface { type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the media api @@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct { ServerName gomatrixserverlib.ServerName Medium string } + +type QueryAccountByLocalpartRequest struct { + Localpart string + ServerName gomatrixserverlib.ServerName +} + +type QueryAccountByLocalpartResponse struct { + Account *Account +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index ce661770f..d10b5767b 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex return err } +func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error { + err := t.Impl.QueryAccountByLocalpart(ctx, req, res) + util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 3f256457e..0bb480da6 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } +func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) { + res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName) + return +} + // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // creating a 'device'. func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 87ae058c2..51b0fe3ef 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -60,6 +60,7 @@ const ( QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" + QueryAccountByLocalpartPath = "/userapi/queryAccountType" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation( h.httpClient, ctx, request, response, ) } + +func (h *httpUserInternalAPI) QueryAccountByLocalpart( + ctx context.Context, + req *api.QueryAccountByLocalpartRequest, + res *api.QueryAccountByLocalpartResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath, + h.httpClient, ctx, req, res, + ) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index f0579079f..b40b507c2 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics PerformSaveThreePIDAssociationPath, httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), ) + + internalAPIMux.Handle( + QueryAccountByLocalpartPath, + httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart), + ) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 8a19af195..dada56de4 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) { }) }) } + +func TestQueryAccountByLocalpart(t *testing.T) { + alice := test.NewUser(t) + + localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() + + createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType) + if err != nil { + t.Error(err) + } + + testCases := func(t *testing.T, internalAPI api.UserInternalAPI) { + // Query existing account + queryAccResp := &api.QueryAccountByLocalpartResponse{} + if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: userServername, + }, queryAccResp); err != nil { + t.Error(err) + } + if !reflect.DeepEqual(createdAcc, queryAccResp.Account) { + t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account) + } + + // Query non-existent account, this should result in an error + err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: "doesnotexist", + ServerName: userServername, + }, queryAccResp) + + if err == nil { + t.Fatalf("expected an error, but got none: %+v", queryAccResp) + } + } + + t.Run("Monolith", func(t *testing.T) { + testCases(t, intAPI) + // also test tracing + testCases(t, &api.UserInternalAPITrace{Impl: intAPI}) + }) + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + userapi.AddInternalRoutes(router, intAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + + userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5}) + if err != nil { + t.Fatalf("failed to create HTTP client: %s", err) + } + testCases(t, userHTTPApi) + + }) + }) +} From f47515e38b0bbf734bf977daedd836bf85465272 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 23 Dec 2022 12:52:47 +0100 Subject: [PATCH 069/124] Pushrule tweaks, make `pattern` non-optional on `EventMatchCondition` (#2918) This should fix https://github.com/matrix-org/dendrite/issues/2882 (Tested with FluffyChat 1.7.1) Also adds tests that the predefined push rules (as per the spec) is what we have in Dendrite. --- internal/pushrules/condition.go | 2 +- internal/pushrules/default_content.go | 9 +- internal/pushrules/default_override.go | 54 +++++---- internal/pushrules/default_pushrules_test.go | 111 +++++++++++++++++++ internal/pushrules/default_underride.go | 39 ++----- internal/pushrules/evaluate.go | 10 +- internal/pushrules/evaluate_test.go | 51 +++++---- internal/pushrules/pushrules.go | 10 +- internal/pushrules/util.go | 4 + internal/pushrules/validate.go | 5 +- internal/pushrules/validate_test.go | 19 ++-- userapi/consumers/roomserver_test.go | 6 - 12 files changed, 210 insertions(+), 110 deletions(-) create mode 100644 internal/pushrules/default_pushrules_test.go diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go index 2d9773c0f..c7b30da8e 100644 --- a/internal/pushrules/condition.go +++ b/internal/pushrules/condition.go @@ -14,7 +14,7 @@ type Condition struct { // Pattern indicates the value pattern that must match. Required // for EventMatchCondition. - Pattern string `json:"pattern,omitempty"` + Pattern *string `json:"pattern,omitempty"` // Is indicates the condition that must be fulfilled. Required for // RoomMemberCountCondition. diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go index 8982dd587..a055ba03c 100644 --- a/internal/pushrules/default_content.go +++ b/internal/pushrules/default_content.go @@ -15,13 +15,7 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule { RuleID: MRuleContainsUserName, Default: true, Enabled: true, - Pattern: localpart, - Conditions: []*Condition{ - { - Kind: EventMatchCondition, - Key: "content.body", - }, - }, + Pattern: &localpart, Actions: []*Action{ {Kind: NotifyAction}, { @@ -32,7 +26,6 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule { { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go index a9788df2f..f97427b71 100644 --- a/internal/pushrules/default_override.go +++ b/internal/pushrules/default_override.go @@ -22,15 +22,15 @@ const ( MRuleTombstone = ".m.rule.tombstone" MRuleRoomNotif = ".m.rule.roomnotif" MRuleReaction = ".m.rule.reaction" + MRuleRoomACLs = ".m.rule.room.server_acl" ) var ( mRuleMasterDefinition = Rule{ - RuleID: MRuleMaster, - Default: true, - Enabled: false, - Conditions: []*Condition{}, - Actions: []*Action{{Kind: DontNotifyAction}}, + RuleID: MRuleMaster, + Default: true, + Enabled: false, + Actions: []*Action{{Kind: DontNotifyAction}}, } mRuleSuppressNoticesDefinition = Rule{ RuleID: MRuleSuppressNotices, @@ -40,7 +40,7 @@ var ( { Kind: EventMatchCondition, Key: "content.msgtype", - Pattern: "m.notice", + Pattern: pointer("m.notice"), }, }, Actions: []*Action{{Kind: DontNotifyAction}}, @@ -53,7 +53,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.member", + Pattern: pointer("m.room.member"), }, }, Actions: []*Action{{Kind: DontNotifyAction}}, @@ -73,7 +73,6 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } @@ -85,12 +84,12 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.tombstone", + Pattern: pointer("m.room.tombstone"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: "", + Pattern: pointer(""), }, }, Actions: []*Action{ @@ -98,10 +97,27 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } + mRuleACLsDefinition = Rule{ + RuleID: MRuleRoomACLs, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: pointer("m.room.server_acl"), + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: pointer(""), + }, + }, + Actions: []*Action{}, + } mRuleRoomNotifDefinition = Rule{ RuleID: MRuleRoomNotif, Default: true, @@ -110,7 +126,7 @@ var ( { Kind: EventMatchCondition, Key: "content.body", - Pattern: "@room", + Pattern: pointer("@room"), }, { Kind: SenderNotificationPermissionCondition, @@ -122,7 +138,6 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } @@ -134,7 +149,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.reaction", + Pattern: pointer("m.reaction"), }, }, Actions: []*Action{ @@ -152,17 +167,17 @@ func mRuleInviteForMeDefinition(userID string) *Rule { { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.member", + Pattern: pointer("m.room.member"), }, { Kind: EventMatchCondition, Key: "content.membership", - Pattern: "invite", + Pattern: pointer("invite"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: userID, + Pattern: pointer(userID), }, }, Actions: []*Action{ @@ -172,11 +187,6 @@ func mRuleInviteForMeDefinition(userID string) *Rule { Tweak: SoundTweak, Value: "default", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } } diff --git a/internal/pushrules/default_pushrules_test.go b/internal/pushrules/default_pushrules_test.go new file mode 100644 index 000000000..dea829842 --- /dev/null +++ b/internal/pushrules/default_pushrules_test.go @@ -0,0 +1,111 @@ +package pushrules + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Tests that the pre-defined rules as of +// https://spec.matrix.org/v1.4/client-server-api/#predefined-rules +// are correct +func TestDefaultRules(t *testing.T) { + type testCase struct { + name string + inputBytes []byte + want Rule + } + + testCases := []testCase{ + // Default override rules + { + name: ".m.rule.master", + inputBytes: []byte(`{"rule_id":".m.rule.master","default":true,"enabled":false,"actions":["dont_notify"]}`), + want: mRuleMasterDefinition, + }, + { + name: ".m.rule.suppress_notices", + inputBytes: []byte(`{"rule_id":".m.rule.suppress_notices","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.msgtype","pattern":"m.notice"}],"actions":["dont_notify"]}`), + want: mRuleSuppressNoticesDefinition, + }, + { + name: ".m.rule.invite_for_me", + inputBytes: []byte(`{"rule_id":".m.rule.invite_for_me","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"},{"kind":"event_match","key":"content.membership","pattern":"invite"},{"kind":"event_match","key":"state_key","pattern":"@test:localhost"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: *mRuleInviteForMeDefinition("@test:localhost"), + }, + { + name: ".m.rule.member_event", + inputBytes: []byte(`{"rule_id":".m.rule.member_event","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"}],"actions":["dont_notify"]}`), + want: mRuleMemberEventDefinition, + }, + { + name: ".m.rule.contains_display_name", + inputBytes: []byte(`{"rule_id":".m.rule.contains_display_name","default":true,"enabled":true,"conditions":[{"kind":"contains_display_name"}],"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}]}`), + want: mRuleContainsDisplayNameDefinition, + }, + { + name: ".m.rule.tombstone", + inputBytes: []byte(`{"rule_id":".m.rule.tombstone","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.tombstone"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":["notify",{"set_tweak":"highlight"}]}`), + want: mRuleTombstoneDefinition, + }, + { + name: ".m.rule.room.server_acl", + inputBytes: []byte(`{"rule_id":".m.rule.room.server_acl","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.server_acl"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":[]}`), + want: mRuleACLsDefinition, + }, + { + name: ".m.rule.roomnotif", + inputBytes: []byte(`{"rule_id":".m.rule.roomnotif","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.body","pattern":"@room"},{"kind":"sender_notification_permission","key":"room"}],"actions":["notify",{"set_tweak":"highlight"}]}`), + want: mRuleRoomNotifDefinition, + }, + // Default content rules + { + name: ".m.rule.contains_user_name", + inputBytes: []byte(`{"rule_id":".m.rule.contains_user_name","default":true,"enabled":true,"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}],"pattern":"myLocalUser"}`), + want: *mRuleContainsUserNameDefinition("myLocalUser"), + }, + // default underride rules + { + name: ".m.rule.call", + inputBytes: []byte(`{"rule_id":".m.rule.call","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.call.invite"}],"actions":["notify",{"set_tweak":"sound","value":"ring"}]}`), + want: mRuleCallDefinition, + }, + { + name: ".m.rule.encrypted_room_one_to_one", + inputBytes: []byte(`{"rule_id":".m.rule.encrypted_room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: mRuleEncryptedRoomOneToOneDefinition, + }, + { + name: ".m.rule.room_one_to_one", + inputBytes: []byte(`{"rule_id":".m.rule.room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: mRuleRoomOneToOneDefinition, + }, + { + name: ".m.rule.message", + inputBytes: []byte(`{"rule_id":".m.rule.message","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify"]}`), + want: mRuleMessageDefinition, + }, + { + name: ".m.rule.encrypted", + inputBytes: []byte(`{"rule_id":".m.rule.encrypted","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify"]}`), + want: mRuleEncryptedDefinition, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := Rule{} + // unmarshal predefined push rules + err := json.Unmarshal(tc.inputBytes, &r) + assert.NoError(t, err) + assert.Equal(t, tc.want, r) + + // and reverse it to check we get the expected result + got, err := json.Marshal(r) + assert.NoError(t, err) + assert.Equal(t, string(got), string(tc.inputBytes)) + }) + + } +} diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go index 8da449a19..118bfae59 100644 --- a/internal/pushrules/default_underride.go +++ b/internal/pushrules/default_underride.go @@ -25,7 +25,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.call.invite", + Pattern: pointer("m.call.invite"), }, }, Actions: []*Action{ @@ -35,11 +35,6 @@ var ( Tweak: SoundTweak, Value: "ring", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleEncryptedRoomOneToOneDefinition = Rule{ @@ -54,7 +49,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.encrypted", + Pattern: pointer("m.room.encrypted"), }, }, Actions: []*Action{ @@ -64,11 +59,6 @@ var ( Tweak: SoundTweak, Value: "default", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleRoomOneToOneDefinition = Rule{ @@ -83,20 +73,15 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.message", + Pattern: pointer("m.room.message"), }, }, Actions: []*Action{ {Kind: NotifyAction}, { Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, + Tweak: SoundTweak, + Value: "default", }, }, } @@ -108,16 +93,11 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.message", + Pattern: pointer("m.room.message"), }, }, Actions: []*Action{ {Kind: NotifyAction}, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleEncryptedDefinition = Rule{ @@ -128,16 +108,11 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.encrypted", + Pattern: pointer("m.room.encrypted"), }, }, Actions: []*Action{ {Kind: NotifyAction}, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } ) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index 4ff9939a6..fc8e0f174 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -104,7 +104,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu case ContentKind: // TODO: "These configure behaviour for (unencrypted) messages // that match certain patterns." - Does that mean "content.body"? - return patternMatches("content.body", rule.Pattern, event) + if rule.Pattern == nil { + return false, nil + } + return patternMatches("content.body", *rule.Pattern, event) case RoomKind: return rule.RuleID == event.RoomID(), nil @@ -120,7 +123,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { switch cond.Kind { case EventMatchCondition: - return patternMatches(cond.Key, cond.Pattern, event) + if cond.Pattern == nil { + return false, fmt.Errorf("missing condition pattern") + } + return patternMatches(cond.Key, *cond.Pattern, event) case ContainsDisplayNameCondition: return patternMatches("content.body", ec.UserDisplayName(), event) diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index c5d5abd2a..ca8ae5519 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -79,8 +79,8 @@ func TestRuleMatches(t *testing.T) { {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, - {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true}, - {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false}, + {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, + {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, @@ -106,41 +106,44 @@ func TestConditionMatches(t *testing.T) { Name string Cond Condition EventJSON string - Want bool + WantMatch bool + WantErr bool }{ - {"empty", Condition{}, `{}`, false}, - {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, + {Name: "empty", Cond: Condition{}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "empty", Cond: Condition{Kind: "unknownstring"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, // Neither of these should match because `content` is not a full string match, // and `content.body` is not a string value. - {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, false}, - {"eventBodyMatch", Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3"}, `{"content":{"body": 3}}`, false}, + {Name: "eventMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content", Pattern: pointer("")}, EventJSON: `{"content":{}}`, WantMatch: false, WantErr: false}, + {Name: "eventBodyMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3", Pattern: pointer("")}, EventJSON: `{"content":{"body": "3"}}`, WantMatch: false, WantErr: false}, + {Name: "eventBodyMatch matches", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Pattern: pointer("world")}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: true, WantErr: false}, + {Name: "EventMatch missing pattern", Cond: Condition{Kind: EventMatchCondition, Key: "content.body"}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: false, WantErr: true}, - {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, - {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, + {Name: "displayNameNoMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"something without displayname"}}`, WantMatch: false, WantErr: false}, + {Name: "displayNameMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"hello Dear User, how are you?"}}`, WantMatch: true, WantErr: false}, - {"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false}, - {"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true}, - {"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false}, - {"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true}, - {"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false}, - {"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true}, - {"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false}, - {"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true}, - {"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false}, - {"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true}, + {Name: "roomMemberCountLessNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<2"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountLessMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<3"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountLessEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountLessEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==1"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountGreaterEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountGreaterEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountGreaterNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">2"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountGreaterMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">1"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, - {"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true}, - {"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false}, + {Name: "senderNotificationPermissionMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@poweruser:example.com"}`, WantMatch: true, WantErr: false}, + {Name: "senderNotificationPermissionNoMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@nobody:example.com"}`, WantMatch: false, WantErr: false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{2}) - if err != nil { + if err != nil && !tst.WantErr { t.Fatalf("conditionMatches failed: %v", err) } - if got != tst.Want { - t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.Want, tst.Name) + if got != tst.WantMatch { + t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.WantMatch, tst.Name) } }) } diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go index bbed1f95f..98deaf132 100644 --- a/internal/pushrules/pushrules.go +++ b/internal/pushrules/pushrules.go @@ -36,18 +36,18 @@ type Rule struct { // around. Required. Enabled bool `json:"enabled"` + // Conditions provide the rule's conditions for OverrideKind and + // UnderrideKind. Not allowed for other kinds. + Conditions []*Condition `json:"conditions,omitempty"` + // Actions describe the desired outcome, should the rule // match. Required. Actions []*Action `json:"actions"` - // Conditions provide the rule's conditions for OverrideKind and - // UnderrideKind. Not allowed for other kinds. - Conditions []*Condition `json:"conditions"` - // Pattern is the body pattern to match for ContentKind. Required // for that kind. The interpretation is the same as that of // Condition.Pattern. - Pattern string `json:"pattern"` + Pattern *string `json:"pattern,omitempty"` } // Scope only has one valid value. See also AccountRuleSets. diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go index fb9c05be2..de8fe5cd0 100644 --- a/internal/pushrules/util.go +++ b/internal/pushrules/util.go @@ -128,3 +128,7 @@ func parseRoomMemberCountCondition(s string) (func(int) bool, error) { b = int(v) return cmp, nil } + +func pointer[t any](s t) *t { + return &s +} diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go index 5d260f0b9..f50c51bd7 100644 --- a/internal/pushrules/validate.go +++ b/internal/pushrules/validate.go @@ -34,7 +34,10 @@ func ValidateRule(kind Kind, rule *Rule) []error { } case ContentKind: - if rule.Pattern == "" { + if rule.Pattern == nil { + errs = append(errs, fmt.Errorf("missing content rule pattern")) + } + if rule.Pattern != nil && *rule.Pattern == "" { errs = append(errs, fmt.Errorf("missing content rule pattern")) } diff --git a/internal/pushrules/validate_test.go b/internal/pushrules/validate_test.go index b276eb551..966e46259 100644 --- a/internal/pushrules/validate_test.go +++ b/internal/pushrules/validate_test.go @@ -12,15 +12,16 @@ func TestValidateRuleNegatives(t *testing.T) { Rule Rule WantErrString string }{ - {"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"}, - {"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"}, - {"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"}, - {"noActions", OverrideKind, Rule{}, "missing actions"}, - {"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"}, - {"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"}, - {"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"}, - {"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"}, - {"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"}, + {Name: "emptyRuleID", Kind: OverrideKind, Rule: Rule{}, WantErrString: "invalid rule ID"}, + {Name: "invalidKind", Kind: Kind("something else"), Rule: Rule{}, WantErrString: "invalid rule kind"}, + {Name: "ruleIDBackslash", Kind: OverrideKind, Rule: Rule{RuleID: "#foo\\:example.com"}, WantErrString: "invalid rule ID"}, + {Name: "noActions", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing actions"}, + {Name: "invalidAction", Kind: OverrideKind, Rule: Rule{Actions: []*Action{{}}}, WantErrString: "invalid rule action kind"}, + {Name: "invalidCondition", Kind: OverrideKind, Rule: Rule{Conditions: []*Condition{{}}}, WantErrString: "invalid rule condition kind"}, + {Name: "overrideNoCondition", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"}, + {Name: "underrideNoCondition", Kind: UnderrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"}, + {Name: "contentNoPattern", Kind: ContentKind, Rule: Rule{}, WantErrString: "missing content rule pattern"}, + {Name: "contentEmptyPattern", Kind: ContentKind, Rule: Rule{Pattern: pointer("")}, WantErrString: "missing content rule pattern"}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 265e3a3aa..39f4aab4a 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -81,11 +81,6 @@ func Test_evaluatePushRules(t *testing.T) { wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ {Kind: pushrules.NotifyAction}, - { - Kind: pushrules.SetTweakAction, - Tweak: pushrules.HighlightTweak, - Value: false, - }, }, }, { @@ -103,7 +98,6 @@ func Test_evaluatePushRules(t *testing.T) { { Kind: pushrules.SetTweakAction, Tweak: pushrules.HighlightTweak, - Value: true, }, }, }, From f762ce1050f2add409a83b1eeb6da5940177cfa7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 23 Dec 2022 14:11:11 +0100 Subject: [PATCH 070/124] Add clientapi tests (#2916) This PR - adds several tests for the clientapi, mostly around `/register` and auth fallback. - removes the now deprecated `homeserver` field from responses to `/register` and `/login` - slightly refactors auth fallback handling --- .github/workflows/dendrite.yml | 3 +- clientapi/routing/admin.go | 6 +- clientapi/routing/auth_fallback.go | 115 ++++----- clientapi/routing/auth_fallback_test.go | 149 ++++++++++++ clientapi/routing/login.go | 9 +- clientapi/routing/password.go | 4 +- clientapi/routing/register.go | 148 ++++-------- clientapi/routing/register_test.go | 306 ++++++++++++++++++++++++ clientapi/routing/routing.go | 4 +- cmd/create-account/main.go | 32 +-- internal/httputil/httpapi.go | 9 +- internal/validate.go | 84 ++++++- internal/validate_test.go | 170 +++++++++++++ setup/config/config.go | 12 +- setup/config/config_clientapi.go | 7 +- 15 files changed, 838 insertions(+), 220 deletions(-) create mode 100644 clientapi/routing/auth_fallback_test.go create mode 100644 internal/validate_test.go diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 2c04005d2..1de39850d 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -331,8 +331,7 @@ jobs: postgres: postgres api: full-http container: - # Temporary for debugging to see if this image is working better. - image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73 + image: matrixdotorg/sytest-dendrite volumes: - ${{ github.workspace }}:/src - /root/.cache/go-build:/github/home/.cache/go-build diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 8419622df..dbd913376 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -137,7 +137,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap request := struct { Password string `json:"password"` }{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + if err = json.NewDecoder(req.Body).Decode(&request); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()), @@ -150,8 +150,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } } - if resErr := internal.ValidatePassword(request.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(request.Password); err != nil { + return *internal.PasswordResponse(err) } updateReq := &userapi.PerformPasswordUpdateRequest{ diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index ad870993e..f8d3684fe 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -15,11 +15,11 @@ package routing import ( + "fmt" "html/template" "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/util" ) @@ -101,14 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s func AuthFallback( w http.ResponseWriter, req *http.Request, authType string, cfg *config.ClientAPI, -) *util.JSONResponse { - sessionID := req.URL.Query().Get("session") +) { + // We currently only support "m.login.recaptcha", so fail early if that's not requested + if authType == authtypes.LoginTypeRecaptcha { + if !cfg.RecaptchaEnabled { + writeHTTPMessage(w, req, + "Recaptcha login is disabled on this Homeserver", + http.StatusBadRequest, + ) + return + } + } else { + writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented) + return + } + sessionID := req.URL.Query().Get("session") if sessionID == "" { - return writeHTTPMessage(w, req, + writeHTTPMessage(w, req, "Session ID not provided", http.StatusBadRequest, ) + return } serveRecaptcha := func() { @@ -130,70 +144,44 @@ func AuthFallback( if req.Method == http.MethodGet { // Handle Recaptcha - if authType == authtypes.LoginTypeRecaptcha { - if err := checkRecaptchaEnabled(cfg, w, req); err != nil { - return err - } - - serveRecaptcha() - return nil - } - return &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown auth stage type"), - } + serveRecaptcha() + return } else if req.Method == http.MethodPost { // Handle Recaptcha - if authType == authtypes.LoginTypeRecaptcha { - if err := checkRecaptchaEnabled(cfg, w, req); err != nil { - return err - } - - clientIP := req.RemoteAddr - err := req.ParseForm() - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed") - res := jsonerror.InternalServerError() - return &res - } - - response := req.Form.Get(cfg.RecaptchaFormField) - if err := validateRecaptcha(cfg, response, clientIP); err != nil { - util.GetLogger(req.Context()).Error(err) - return err - } - - // Success. Add recaptcha as a completed login flow - sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) - - serveSuccess() - return nil + clientIP := req.RemoteAddr + err := req.ParseForm() + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed") + w.WriteHeader(http.StatusBadRequest) + serveRecaptcha() + return } - return &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown auth stage type"), + response := req.Form.Get(cfg.RecaptchaFormField) + err = validateRecaptcha(cfg, response, clientIP) + switch err { + case ErrMissingResponse: + w.WriteHeader(http.StatusBadRequest) + serveRecaptcha() // serve the initial page again, instead of nothing + return + case ErrInvalidCaptcha: + w.WriteHeader(http.StatusUnauthorized) + serveRecaptcha() + return + case nil: + default: // something else failed + util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") + serveRecaptcha() + return } - } - return &util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), - } -} -// checkRecaptchaEnabled creates an error response if recaptcha is not usable on homeserver. -func checkRecaptchaEnabled( - cfg *config.ClientAPI, - w http.ResponseWriter, - req *http.Request, -) *util.JSONResponse { - if !cfg.RecaptchaEnabled { - return writeHTTPMessage(w, req, - "Recaptcha login is disabled on this Homeserver", - http.StatusBadRequest, - ) + // Success. Add recaptcha as a completed login flow + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + + serveSuccess() + return } - return nil + writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed) } // writeHTTPMessage writes the given header and message to the HTTP response writer. @@ -201,13 +189,10 @@ func checkRecaptchaEnabled( func writeHTTPMessage( w http.ResponseWriter, req *http.Request, message string, header int, -) *util.JSONResponse { +) { w.WriteHeader(header) _, err := w.Write([]byte(message)) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("w.Write failed") - res := jsonerror.InternalServerError() - return &res } - return nil } diff --git a/clientapi/routing/auth_fallback_test.go b/clientapi/routing/auth_fallback_test.go new file mode 100644 index 000000000..0d77f9a01 --- /dev/null +++ b/clientapi/routing/auth_fallback_test.go @@ -0,0 +1,149 @@ +package routing + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test/testrig" +) + +func Test_AuthFallback(t *testing.T) { + base, _, _ := testrig.Base(nil) + defer base.Close() + + for _, useHCaptcha := range []bool{false, true} { + for _, recaptchaEnabled := range []bool{false, true} { + for _, wantErr := range []bool{false, true} { + t.Run(fmt.Sprintf("useHCaptcha(%v) - recaptchaEnabled(%v) - wantErr(%v)", useHCaptcha, recaptchaEnabled, wantErr), func(t *testing.T) { + // Set the defaults for each test + base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, Monolithic: true}) + base.Cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled + base.Cfg.ClientAPI.RecaptchaPublicKey = "pub" + base.Cfg.ClientAPI.RecaptchaPrivateKey = "priv" + if useHCaptcha { + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = "https://hcaptcha.com/siteverify" + base.Cfg.ClientAPI.RecaptchaApiJsUrl = "https://js.hcaptcha.com/1/api.js" + base.Cfg.ClientAPI.RecaptchaFormField = "h-captcha-response" + base.Cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha" + } + cfgErrs := &config.ConfigErrors{} + base.Cfg.ClientAPI.Verify(cfgErrs, true) + if len(*cfgErrs) > 0 { + t.Fatalf("(hCaptcha=%v) unexpected config errors: %s", useHCaptcha, cfgErrs.Error()) + } + + req := httptest.NewRequest(http.MethodGet, "/?session=1337", nil) + rec := httptest.NewRecorder() + + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if !recaptchaEnabled { + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest) + } + if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" { + t.Fatalf("unexpected response body: %s", rec.Body.String()) + } + } else { + if !strings.Contains(rec.Body.String(), base.Cfg.ClientAPI.RecaptchaSitekeyClass) { + t.Fatalf("body does not contain %s: %s", base.Cfg.ClientAPI.RecaptchaSitekeyClass, rec.Body.String()) + } + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if wantErr { + _, _ = w.Write([]byte(`{"success":false}`)) + return + } + _, _ = w.Write([]byte(`{"success":true}`)) + })) + defer srv.Close() // nolint: errcheck + + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL + + // check the result after sending the captcha + req = httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + req.Form = url.Values{} + req.Form.Add(base.Cfg.ClientAPI.RecaptchaFormField, "someRandomValue") + rec = httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if recaptchaEnabled { + if !wantErr { + if rec.Code != http.StatusOK { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusOK) + } + if rec.Body.String() != successTemplate { + t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), successTemplate) + } + } else { + if rec.Code != http.StatusUnauthorized { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusUnauthorized) + } + wantString := "Authentication" + if !strings.Contains(rec.Body.String(), wantString) { + t.Fatalf("expected response to contain '%s', but didn't: %s", wantString, rec.Body.String()) + } + } + } else { + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest) + } + if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" { + t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), "successTemplate") + } + } + }) + } + } + } + + t.Run("unknown fallbacks are handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, "DoesNotExist", &base.Cfg.ClientAPI) + if rec.Code != http.StatusNotImplemented { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusNotImplemented) + } + }) + + t.Run("unknown methods are handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodDelete, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusMethodNotAllowed) + } + }) + + t.Run("missing session parameter is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing session parameter is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing 'response' is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 0de324da1..778c8c0c3 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -23,15 +23,13 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) type loginResponse struct { - UserID string `json:"user_id"` - AccessToken string `json:"access_token"` - HomeServer gomatrixserverlib.ServerName `json:"home_server"` - DeviceID string `json:"device_id"` + UserID string `json:"user_id"` + AccessToken string `json:"access_token"` + DeviceID string `json:"device_id"` } type flows struct { @@ -116,7 +114,6 @@ func completeAuth( JSON: loginResponse{ UserID: performRes.Device.UserID, AccessToken: performRes.Device.AccessToken, - HomeServer: serverName, DeviceID: performRes.Device.ID, }, } diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index cd88b025a..f7f9da622 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -82,8 +82,8 @@ func Password( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. - if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil { - return *resErr + if err := internal.ValidatePassword(r.NewPassword); err != nil { + return *internal.PasswordResponse(err) } // Get the local part. diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 4abbcdf9e..6087bda0c 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -18,12 +18,12 @@ package routing import ( "context" "encoding/json" + "errors" "fmt" "io" "net" "net/http" "net/url" - "regexp" "sort" "strconv" "strings" @@ -60,10 +60,7 @@ var ( ) ) -const ( - maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain - sessionIDLength = 24 -) +const sessionIDLength = 24 // sessionsDict keeps track of completed auth stages for each session. // It shouldn't be passed by value because it contains a mutex. @@ -198,8 +195,7 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) { } var ( - sessions = newSessionsDict() - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) + sessions = newSessionsDict() ) // registerRequest represents the submitted registration request. @@ -262,10 +258,9 @@ func newUserInteractiveResponse( // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register type registerResponse struct { - UserID string `json:"user_id"` - AccessToken string `json:"access_token,omitempty"` - HomeServer gomatrixserverlib.ServerName `json:"home_server"` - DeviceID string `json:"device_id,omitempty"` + UserID string `json:"user_id"` + AccessToken string `json:"access_token,omitempty"` + DeviceID string `json:"device_id,omitempty"` } // recaptchaResponse represents the HTTP response from a Google Recaptcha server @@ -276,66 +271,28 @@ type recaptchaResponse struct { ErrorCodes []int `json:"error-codes"` } -// validateUsername returns an error response if the username is invalid -func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { - // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), - } - } else if !validUsernameRegex.MatchString(localpart) { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - } - } else if localpart[0] == '_' { // Regex checks its not a zero length string - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"), - } - } - return nil -} - -// validateApplicationServiceUsername returns an error response if the username is invalid for an application service -func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { - if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), - } - } else if !validUsernameRegex.MatchString(localpart) { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - } - } - return nil -} +var ( + ErrInvalidCaptcha = errors.New("invalid captcha response") + ErrMissingResponse = errors.New("captcha response is required") + ErrCaptchaDisabled = errors.New("captcha registration is disabled") +) // validateRecaptcha returns an error response if the captcha response is invalid func validateRecaptcha( cfg *config.ClientAPI, response string, clientip string, -) *util.JSONResponse { +) error { ip, _, _ := net.SplitHostPort(clientip) if !cfg.RecaptchaEnabled { - return &util.JSONResponse{ - Code: http.StatusConflict, - JSON: jsonerror.Unknown("Captcha registration is disabled"), - } + return ErrCaptchaDisabled } if response == "" { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Captcha response is required"), - } + return ErrMissingResponse } - // Make a POST request to Google's API to check the captcha response + // Make a POST request to the captcha provider API to check the captcha response resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI, url.Values{ "secret": {cfg.RecaptchaPrivateKey}, @@ -345,10 +302,7 @@ func validateRecaptcha( ) if err != nil { - return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"), - } + return err } // Close the request once we're finishing reading from it @@ -358,25 +312,16 @@ func validateRecaptcha( var r recaptchaResponse body, err := io.ReadAll(resp.Body) if err != nil { - return &util.JSONResponse{ - Code: http.StatusGatewayTimeout, - JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()), - } + return err } err = json.Unmarshal(body, &r) if err != nil { - return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()), - } + return err } // Check that we received a "success" if !r.Success { - return &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."), - } + return ErrInvalidCaptcha } return nil } @@ -508,8 +453,8 @@ func validateApplicationService( } // Check username application service is trying to register is valid - if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { - return "", err + if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { + return "", internal.UsernameResponse(err) } // No errors, registration valid @@ -564,15 +509,12 @@ func Register( if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } - if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { - r.Username, r.ServerName = l, d - } if req.URL.Query().Get("kind") == "guest" { return handleGuestRegistration(req, r, cfg, userAPI) } // Don't allow numeric usernames less than MAX_INT64. - if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil { + if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), @@ -584,7 +526,7 @@ func Register( ServerName: r.ServerName, } nres := &userapi.QueryNumericLocalpartResponse{} - if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { + if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") return jsonerror.InternalServerError() } @@ -601,8 +543,8 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil { + return *internal.UsernameResponse(err) } case accessTokenErr == nil: // Non-spec-compliant case (the access_token is specified but the login @@ -614,12 +556,12 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil { + return *internal.UsernameResponse(err) } } - if resErr := internal.ValidatePassword(r.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(r.Password); err != nil { + return *internal.PasswordResponse(err) } logger := util.GetLogger(req.Context()) @@ -697,7 +639,6 @@ func handleGuestRegistration( JSON: registerResponse{ UserID: devRes.Device.UserID, AccessToken: devRes.Device.AccessToken, - HomeServer: res.Account.ServerName, DeviceID: devRes.Device.ID, }, } @@ -761,9 +702,18 @@ func handleRegistrationFlow( switch r.Auth.Type { case authtypes.LoginTypeRecaptcha: // Check given captcha response - resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) - if resErr != nil { - return *resErr + err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) + switch err { + case ErrCaptchaDisabled: + return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())} + case ErrMissingResponse: + return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())} + case ErrInvalidCaptcha: + return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())} + case nil: + default: + util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") + return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()} } // Add Recaptcha to the list of completed registration stages @@ -924,8 +874,7 @@ func completeRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: userutil.MakeUserID(username, accRes.Account.ServerName), - HomeServer: accRes.Account.ServerName, + UserID: userutil.MakeUserID(username, accRes.Account.ServerName), }, } } @@ -958,7 +907,6 @@ func completeRegistration( result := registerResponse{ UserID: devRes.Device.UserID, AccessToken: devRes.Device.AccessToken, - HomeServer: accRes.Account.ServerName, DeviceID: devRes.Device.ID, } sessions.addCompletedRegistration(sessionID, result) @@ -1054,8 +1002,8 @@ func RegisterAvailable( } } - if err := validateUsername(username, domain); err != nil { - return *err + if err := internal.ValidateUsername(username, domain); err != nil { + return *internal.UsernameResponse(err) } // Check if this username is reserved by an application service @@ -1117,11 +1065,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien // downcase capitals ssrr.User = strings.ToLower(ssrr.User) - if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil { + return *internal.UsernameResponse(err) } - if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(ssrr.Password); err != nil { + return *internal.PasswordResponse(err) } deviceID := "shared_secret_registration" diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 85846c7d6..b8fd19e90 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -15,12 +15,27 @@ package routing import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "reflect" "regexp" + "strings" "testing" "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/util" ) var ( @@ -264,3 +279,294 @@ func TestSessionCleanUp(t *testing.T) { } }) } + +func Test_register(t *testing.T) { + testCases := []struct { + name string + kind string + password string + username string + loginType string + forceEmpty bool + registrationDisabled bool + guestsDisabled bool + enableRecaptcha bool + captchaBody string + wantResponse util.JSONResponse + }{ + { + name: "disallow guests", + kind: "guest", + guestsDisabled: true, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(`Guest registration is disabled on "test"`), + }, + }, + { + name: "allow guests", + kind: "guest", + }, + { + name: "unknown login type", + loginType: "im.not.known", + wantResponse: util.JSONResponse{ + Code: http.StatusNotImplemented, + JSON: jsonerror.Unknown("unknown/unimplemented auth type"), + }, + }, + { + name: "disabled registration", + registrationDisabled: true, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(`Registration is disabled on "test"`), + }, + }, + { + name: "successful registration, numeric ID", + username: "", + password: "someRandomPassword", + forceEmpty: true, + }, + { + name: "successful registration", + username: "success", + }, + { + name: "failing registration - user already exists", + username: "success", + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UserInUse("Desired user ID is already taken."), + }, + }, + { + name: "successful registration uppercase username", + username: "LOWERCASED", // this is going to be lower-cased + }, + { + name: "invalid username", + username: "#totalyNotValid", + wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid), + }, + { + name: "numeric username is forbidden", + username: "1337", + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), + }, + }, + { + name: "disabled recaptcha login", + loginType: authtypes.LoginTypeRecaptcha, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Unknown(ErrCaptchaDisabled.Error()), + }, + }, + { + name: "enabled recaptcha, no response defined", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrMissingResponse.Error()), + }, + }, + { + name: "invalid captcha response", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `notvalid`, + wantResponse: util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.BadJSON(ErrInvalidCaptcha.Error()), + }, + }, + { + name: "valid captcha response", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `success`, + }, + { + name: "captcha invalid from remote", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `i should fail for other reasons`, + wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.enableRecaptcha { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatal(err) + } + response := r.Form.Get("response") + + // Respond with valid JSON or no JSON at all to test happy/error cases + switch response { + case "success": + json.NewEncoder(w).Encode(recaptchaResponse{Success: true}) + case "notvalid": + json.NewEncoder(w).Encode(recaptchaResponse{Success: false}) + default: + + } + })) + defer srv.Close() + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL + } + + if err := base.Cfg.Derive(); err != nil { + t.Fatalf("failed to derive config: %s", err) + } + + base.Cfg.ClientAPI.RecaptchaEnabled = tc.enableRecaptcha + base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled + base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled + + if tc.kind == "" { + tc.kind = "user" + } + if tc.password == "" && !tc.forceEmpty { + tc.password = "someRandomPassword" + } + if tc.username == "" && !tc.forceEmpty { + tc.username = "valid" + } + if tc.loginType == "" { + tc.loginType = "m.login.dummy" + } + + reg := registerRequest{ + Password: tc.password, + Username: tc.username, + } + + body := &bytes.Buffer{} + err := json.NewEncoder(body).Encode(reg) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body) + + resp := Register(req, userAPI, &base.Cfg.ClientAPI) + t.Logf("Resp: %+v", resp) + + // The first request should return a userInteractiveResponse + switch r := resp.JSON.(type) { + case userInteractiveResponse: + // Check that the flows are the ones we configured + if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) { + t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows) + } + case *jsonerror.MatrixError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) + } + return + case registerResponse: + // this should only be possible on guest user registration, never for normal users + if tc.kind != "guest" { + t.Fatalf("got register response on first request: %+v", r) + } + // assert we've got a UserID, AccessToken and DeviceID + if r.UserID == "" { + t.Fatalf("missing userID in response") + } + if r.AccessToken == "" { + t.Fatalf("missing accessToken in response") + } + if r.DeviceID == "" { + t.Fatalf("missing deviceID in response") + } + return + default: + t.Logf("Got response: %T", resp.JSON) + } + + // If we reached this, we should have received a UIA response + uia, ok := resp.JSON.(userInteractiveResponse) + if !ok { + t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON) + } + t.Logf("%+v", uia) + + // Register the user + reg.Auth = authDict{ + Type: authtypes.LoginType(tc.loginType), + Session: uia.Session, + } + + if tc.captchaBody != "" { + reg.Auth.Response = tc.captchaBody + } + + dummy := "dummy" + reg.DeviceID = &dummy + reg.InitialDisplayName = &dummy + reg.Type = authtypes.LoginType(tc.loginType) + + err = json.NewEncoder(body).Encode(reg) + if err != nil { + t.Fatal(err) + } + + req = httptest.NewRequest(http.MethodPost, "/", body) + + resp = Register(req, userAPI, &base.Cfg.ClientAPI) + + switch resp.JSON.(type) { + case *jsonerror.MatrixError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + case util.JSONResponse: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + } + + rr, ok := resp.JSON.(registerResponse) + if !ok { + t.Fatalf("expected a registerresponse, got %T", resp.JSON) + } + + // validate the response + if tc.forceEmpty { + // when not supplying a username, one will be generated. Given this _SHOULD_ be + // the second user, set the username accordingly + reg.Username = "2" + } + wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test")) + if wantUserID != rr.UserID { + t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID) + } + if rr.DeviceID != *reg.DeviceID { + t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID) + } + if rr.AccessToken == "" { + t.Fatalf("missing accessToken in response") + } + }) + } + }) +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 69b46214c..09c2cd02f 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -639,9 +639,9 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - return AuthFallback(w, req, vars["authType"], cfg) + AuthFallback(w, req, vars["authType"], cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 15b043ed5..772778680 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -25,10 +25,10 @@ import ( "io" "net/http" "os" - "regexp" "strings" "time" + "github.com/matrix-org/dendrite/internal" "github.com/tidwall/gjson" "github.com/sirupsen/logrus" @@ -58,15 +58,14 @@ Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - isAdmin = flag.Bool("admin", false, "Create an admin account") - resetPassword = flag.Bool("reset-password", false, "Deprecated") - serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) - timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Deprecated") + serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") + timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") ) var cl = http.Client{ @@ -95,20 +94,21 @@ func main() { os.Exit(1) } - if !validUsernameRegex.MatchString(*username) { - logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='") + if err := internal.ValidateUsername(*username, cfg.Global.ServerName); err != nil { + logrus.WithError(err).Error("Specified username is invalid") os.Exit(1) } - if len(fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) > 255 { - logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) - } - pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin) if err != nil { logrus.Fatalln(err) } + if err = internal.ValidatePassword(pass); err != nil { + logrus.WithError(err).Error("Specified password is invalid") + os.Exit(1) + } + cl.Timeout = *timeout accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin) diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 383913c60..37d144f4e 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -198,17 +198,12 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // MakeHTMLAPI adds Span metrics to the HTML Handler function // This is used to serve HTML alongside JSON error messages -func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { +func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request)) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { span := opentracing.StartSpan(metricsName) defer span.Finish() req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) - if err := f(w, req); err != nil { - h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse { - return *err - })) - h.ServeHTTP(w, req) - } + f(w, req) } if !enableMetrics { diff --git a/internal/validate.go b/internal/validate.go index fc685ad50..0461b897e 100644 --- a/internal/validate.go +++ b/internal/validate.go @@ -15,30 +15,96 @@ package internal import ( + "errors" "fmt" "net/http" + "regexp" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based +const ( + maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain -const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based + maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 +) -// ValidatePassword returns an error response if the password is invalid -func ValidatePassword(password string) *util.JSONResponse { +var ( + ErrPasswordTooLong = fmt.Errorf("password too long: max %d characters", maxPasswordLength) + ErrPasswordWeak = fmt.Errorf("password too weak: min %d characters", minPasswordLength) + ErrUsernameTooLong = fmt.Errorf("username exceeds the maximum length of %d characters", maxUsernameLength) + ErrUsernameInvalid = errors.New("username can only contain characters a-z, 0-9, or '_-./='") + ErrUsernameUnderscore = errors.New("username cannot start with a '_'") + validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) +) + +// ValidatePassword returns an error if the password is invalid +func ValidatePassword(password string) error { // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 if len(password) > maxPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)), - } + return ErrPasswordTooLong } else if len(password) > 0 && len(password) < minPasswordLength { + return ErrPasswordWeak + } + return nil +} + +// PasswordResponse returns a util.JSONResponse for a given error, if any. +func PasswordResponse(err error) *util.JSONResponse { + switch err { + case ErrPasswordWeak: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), + JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()), + } + case ErrPasswordTooLong: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()), } } return nil } + +// ValidateUsername returns an error if the username is invalid +func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error { + // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { + return ErrUsernameTooLong + } else if !validUsernameRegex.MatchString(localpart) { + return ErrUsernameInvalid + } else if localpart[0] == '_' { // Regex checks its not a zero length string + return ErrUsernameUnderscore + } + return nil +} + +// UsernameResponse returns a util.JSONResponse for the given error, if any. +func UsernameResponse(err error) *util.JSONResponse { + switch err { + case ErrUsernameTooLong: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(err.Error()), + } + case ErrUsernameInvalid, ErrUsernameUnderscore: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(err.Error()), + } + } + return nil +} + +// ValidateApplicationServiceUsername returns an error if the username is invalid for an application service +func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error { + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { + return ErrUsernameTooLong + } else if !validUsernameRegex.MatchString(localpart) { + return ErrUsernameInvalid + } + return nil +} diff --git a/internal/validate_test.go b/internal/validate_test.go new file mode 100644 index 000000000..d0ad04707 --- /dev/null +++ b/internal/validate_test.go @@ -0,0 +1,170 @@ +package internal + +import ( + "net/http" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func Test_validatePassword(t *testing.T) { + tests := []struct { + name string + password string + wantError error + wantJSON *util.JSONResponse + }{ + { + name: "password too short", + password: "shortpw", + wantError: ErrPasswordWeak, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())}, + }, + { + name: "password too long", + password: strings.Repeat("a", maxPasswordLength+1), + wantError: ErrPasswordTooLong, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())}, + }, + { + name: "password OK", + password: util.RandomString(10), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := ValidatePassword(tt.password) + if !reflect.DeepEqual(gotErr, tt.wantError) { + t.Errorf("validatePassword() = %v, wantJSON %v", gotErr, tt.wantError) + } + + if got := PasswordResponse(gotErr); !reflect.DeepEqual(got, tt.wantJSON) { + t.Errorf("validatePassword() = %v, wantJSON %v", got, tt.wantJSON) + } + }) + } +} + +func Test_validateUsername(t *testing.T) { + tooLongUsername := strings.Repeat("a", maxUsernameLength) + tests := []struct { + name string + localpart string + domain gomatrixserverlib.ServerName + wantErr error + wantJSON *util.JSONResponse + }{ + { + name: "empty username", + localpart: "", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "invalid username", + localpart: "INVALIDUSERNAME", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "username too long", + localpart: tooLongUsername, + domain: "localhost", + wantErr: ErrUsernameTooLong, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()), + }, + }, + { + name: "localpart starting with an underscore", + localpart: "_notvalid", + domain: "localhost", + wantErr: ErrUsernameUnderscore, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()), + }, + }, + { + name: "valid username", + localpart: "valid", + domain: "localhost", + }, + { + name: "complex username", + localpart: "f00_bar-baz.=40/", + domain: "localhost", + }, + { + name: "rejects emoji username 💥", + localpart: "💥", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "special characters are allowed", + localpart: "/dev/null", + domain: "localhost", + }, + { + name: "special characters are allowed 2", + localpart: "i_am_allowed=1", + domain: "localhost", + }, + { + name: "not all special characters are allowed", + localpart: "notallowed#", // contains # + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "username containing numbers", + localpart: "hello1337", + domain: "localhost", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := ValidateUsername(tt.localpart, tt.domain) + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr) + } + if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) { + t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON) + } + + // Application services are allowed usernames starting with an underscore + if tt.wantErr == ErrUsernameUnderscore { + return + } + gotErr = ValidateApplicationServiceUsername(tt.localpart, tt.domain) + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr) + } + if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) { + t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON) + } + }) + } +} diff --git a/setup/config/config.go b/setup/config/config.go index 7e7ed1aa1..6523a2452 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -29,7 +29,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "golang.org/x/crypto/ed25519" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" jaegerconfig "github.com/uber/jaeger-client-go/config" jaegermetrics "github.com/uber/jaeger-lib/metrics" @@ -314,11 +314,13 @@ func (config *Dendrite) Derive() error { if config.ClientAPI.RecaptchaEnabled { config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey} - config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, - authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}) + config.Derived.Registration.Flows = []authtypes.Flow{ + {Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}, + } } else { - config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, - authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) + config.Derived.Registration.Flows = []authtypes.Flow{ + {Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}, + } } // Load application service configuration files diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 0a871da18..11628b1b0 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -78,9 +78,6 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { c.TURN.Verify(configErrs) c.RateLimiting.Verify(configErrs) if c.RecaptchaEnabled { - checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) - checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) - checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) if c.RecaptchaSiteVerifyAPI == "" { c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify" } @@ -93,6 +90,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { if c.RecaptchaSitekeyClass == "" { c.RecaptchaSitekeyClass = "g-recaptcha-response" } + checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) + checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) + checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) + checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass) } // Ensure there is any spam counter measure when enabling registration if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled { From e449d174ccf7569b2536289f3c8145298e80bc90 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 23 Dec 2022 14:28:15 +0100 Subject: [PATCH 071/124] Add possibility to run complement with coverage enabled (#2901) This adds the possibility to run Complement with coverage enabled. In combination with https://github.com/matrix-org/complement/pull/566 we should then be able to extract the coverage logs, combine them with https://github.com/wadey/gocovmerge (or similar) and upload them to Codecov (with different flags, depending on SQLite, HTTP etc.) --- Dockerfile | 27 --------------------- build/scripts/Complement.Dockerfile | 7 ++++-- build/scripts/ComplementLocal.Dockerfile | 5 +++- build/scripts/ComplementPostgres.Dockerfile | 7 ++++-- build/scripts/complement-cmd.sh | 22 +++++++++++++++++ 5 files changed, 36 insertions(+), 32 deletions(-) create mode 100755 build/scripts/complement-cmd.sh diff --git a/Dockerfile b/Dockerfile index a9bbce925..ede33e635 100644 --- a/Dockerfile +++ b/Dockerfile @@ -63,30 +63,3 @@ WORKDIR /etc/dendrite ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] EXPOSE 8008 8448 -# -# Builds the Complement image, used for integration tests -# -FROM base AS complement -LABEL org.opencontainers.image.title="Dendrite (Complement)" -RUN apk add --no-cache sqlite openssl ca-certificates - -COPY --from=build /out/generate-config /usr/bin/generate-config -COPY --from=build /out/generate-keys /usr/bin/generate-keys -COPY --from=build /out/dendrite-monolith-server /usr/bin/dendrite-monolith-server - -WORKDIR /dendrite -RUN /usr/bin/generate-keys --private-key matrix_key.pem && \ - mkdir /ca && \ - openssl genrsa -out /ca/ca.key 2048 && \ - openssl req -new -x509 -key /ca/ca.key -days 3650 -subj "/C=GB/ST=London/O=matrix.org/CN=Complement CA" -out /ca/ca.crt - -ENV SERVER_NAME=localhost -ENV API=0 -EXPOSE 8008 8448 - -# At runtime, generate TLS cert based on the CA now mounted at /ca -# At runtime, replace the SERVER_NAME with what we are told -CMD /usr/bin/generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /ca/ca.crt --tls-authority-key /ca/ca.key && \ - /usr/bin/generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ - cp /ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - /usr/bin/dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 79422e645..3a00fbdf0 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -16,13 +16,16 @@ RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ - CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \ + CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \ + cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost ENV API=0 +ENV COVER=0 EXPOSE 8008 8448 # At runtime, generate TLS cert based on the CA now mounted at /ca @@ -30,4 +33,4 @@ EXPOSE 8008 8448 CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} + exec /complement-cmd.sh diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile index 3a019fc20..e3fbe1aa8 100644 --- a/build/scripts/ComplementLocal.Dockerfile +++ b/build/scripts/ComplementLocal.Dockerfile @@ -12,18 +12,20 @@ FROM golang:1.18-stretch RUN apt-get update && apt-get install -y sqlite3 ENV SERVER_NAME=localhost +ENV COVER=0 EXPOSE 8008 8448 WORKDIR /runtime # This script compiles Dendrite for us. RUN echo '\ #!/bin/bash -eux \n\ - if test -f "/runtime/dendrite-monolith-server"; then \n\ + if test -f "/runtime/dendrite-monolith-server" && test -f "/runtime/dendrite-monolith-server-cover"; then \n\ echo "Skipping compilation; binaries exist" \n\ exit 0 \n\ fi \n\ cd /dendrite \n\ go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\ + go test -c -cover -covermode=atomic -o /runtime/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." /dendrite/cmd/dendrite-monolith-server \n\ ' > compile.sh && chmod +x compile.sh # This script runs Dendrite for us. Must be run in the /runtime directory. @@ -33,6 +35,7 @@ RUN echo '\ ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ + [ ${COVER} -eq 1 ] && exec ./dendrite-monolith-server-cover --test.coverprofile=integrationcover.log --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ ' > run.sh && chmod +x run.sh diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index 3faf43cc7..444cb947d 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -34,13 +34,16 @@ RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ - CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \ + CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \ + cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost ENV API=0 +ENV COVER=0 EXPOSE 8008 8448 @@ -51,4 +54,4 @@ CMD /build/run_postgres.sh && ./generate-keys --keysize 1024 --server $SERVER_NA # Bump max_open_conns up here in the global database config sed -i 's/max_open_conns:.*$/max_open_conns: 1990/g' dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} \ No newline at end of file + exec /complement-cmd.sh \ No newline at end of file diff --git a/build/scripts/complement-cmd.sh b/build/scripts/complement-cmd.sh new file mode 100755 index 000000000..061bd18eb --- /dev/null +++ b/build/scripts/complement-cmd.sh @@ -0,0 +1,22 @@ +#!/bin/bash -e + +# This script is intended to be used inside a docker container for Complement + +if [[ "${COVER}" -eq 1 ]]; then + echo "Running with coverage" + exec /dendrite/dendrite-monolith-server-cover \ + --really-enable-open-registration \ + --tls-cert server.crt \ + --tls-key server.key \ + --config dendrite.yaml \ + -api=${API:-0} \ + --test.coverprofile=integrationcover.log +else + echo "Not running with coverage" + exec /dendrite/dendrite-monolith-server \ + --really-enable-open-registration \ + --tls-cert server.crt \ + --tls-key server.key \ + --config dendrite.yaml \ + -api=${API:-0} +fi From 2e1fe589375b650f9b2d9a09e1fcffb3ab6fe5b6 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 5 Jan 2023 09:24:00 +0100 Subject: [PATCH 072/124] Fix backfilling (#2926) This should fix https://github.com/matrix-org/dendrite/issues/2923 --- go.mod | 2 +- go.sum | 4 ++-- roomserver/internal/perform/perform_backfill.go | 9 +++++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index d3eb4890a..2d7174150 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index ad9372c84..b12f65eab 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 069f017a9..d9214fdc6 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -122,11 +122,14 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform ctx, req.VirtualHost, requester, r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, ) - if err != nil { + // Only return an error if we really couldn't get any events. + if err != nil && len(events) == 0 { logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed") return err } - logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) + // If we got an error but still got events, that's fine, because a server might have returned a 404 (or something) + // but other servers could provide the missing event. + logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) // persist these new events - auth checks have already been done roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) @@ -319,6 +322,7 @@ FederationHit: FedClient: b.fsAPI, RememberAuthEvents: false, Server: srv, + Origin: b.virtualHost, } res, err := c.StateIDsBeforeEvent(ctx, targetEvent) if err != nil { @@ -394,6 +398,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr FedClient: b.fsAPI, RememberAuthEvents: false, Server: srv, + Origin: b.virtualHost, } result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs) if err != nil { From d579ddb8e7c1a7e118797bcef08113379535e6fb Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:44:10 +0100 Subject: [PATCH 073/124] Add simplified helm chart (#2905) As discussed yesterday, a simplified version of [my helm](https://github.com/S7evinK/dendrite-helm) which deploys a monolith with internal NATS and an optionally enabled PostgreSQL server. If the PostgreSQL dependency is not enabled, a user specified connection string is constructed. Co-authored-by: kegsay --- .github/workflows/gh-pages.yml | 52 +++ .github/workflows/helm.yml | 39 ++ .github/workflows/k8s.yml | 90 +++++ helm/cr.yaml | 1 + helm/ct.yaml | 7 + helm/dendrite/.helm-docs/about.gotmpl | 5 + helm/dendrite/.helm-docs/appservices.gotmpl | 5 + helm/dendrite/.helm-docs/database.gotmpl | 18 + helm/dendrite/.helm-docs/state.gotmpl | 3 + helm/dendrite/Chart.yaml | 19 + helm/dendrite/README.md | 147 ++++++++ helm/dendrite/README.md.gotmpl | 13 + helm/dendrite/ci/ct-ingress-values.yaml | 13 + .../ci/ct-postgres-sharedsecret-values.yaml | 16 + helm/dendrite/templates/_helpers.tpl | 72 ++++ helm/dendrite/templates/_overrides.yaml | 16 + helm/dendrite/templates/deployment.yaml | 103 ++++++ helm/dendrite/templates/ingress.yaml | 55 +++ helm/dendrite/templates/jobs.yaml | 99 +++++ helm/dendrite/templates/pvc.yaml | 48 +++ helm/dendrite/templates/secrets.yaml | 33 ++ helm/dendrite/templates/service.yaml | 17 + .../templates/tests/test-version.yaml | 17 + helm/dendrite/values.yaml | 348 ++++++++++++++++++ setup/config/config.go | 4 +- 25 files changed, 1238 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/gh-pages.yml create mode 100644 .github/workflows/helm.yml create mode 100644 .github/workflows/k8s.yml create mode 100644 helm/cr.yaml create mode 100644 helm/ct.yaml create mode 100644 helm/dendrite/.helm-docs/about.gotmpl create mode 100644 helm/dendrite/.helm-docs/appservices.gotmpl create mode 100644 helm/dendrite/.helm-docs/database.gotmpl create mode 100644 helm/dendrite/.helm-docs/state.gotmpl create mode 100644 helm/dendrite/Chart.yaml create mode 100644 helm/dendrite/README.md create mode 100644 helm/dendrite/README.md.gotmpl create mode 100644 helm/dendrite/ci/ct-ingress-values.yaml create mode 100644 helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml create mode 100644 helm/dendrite/templates/_helpers.tpl create mode 100644 helm/dendrite/templates/_overrides.yaml create mode 100644 helm/dendrite/templates/deployment.yaml create mode 100644 helm/dendrite/templates/ingress.yaml create mode 100644 helm/dendrite/templates/jobs.yaml create mode 100644 helm/dendrite/templates/pvc.yaml create mode 100644 helm/dendrite/templates/secrets.yaml create mode 100644 helm/dendrite/templates/service.yaml create mode 100644 helm/dendrite/templates/tests/test-version.yaml create mode 100644 helm/dendrite/values.yaml diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml new file mode 100644 index 000000000..b5a8f0bbd --- /dev/null +++ b/.github/workflows/gh-pages.yml @@ -0,0 +1,52 @@ +# Sample workflow for building and deploying a Jekyll site to GitHub Pages +name: Deploy GitHub Pages dependencies preinstalled + +on: + # Runs on pushes targeting the default branch + push: + branches: ["main"] + paths: + - 'docs/**' # only execute if we have docs changes + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow one concurrent deployment +concurrency: + group: "pages" + cancel-in-progress: true + +jobs: + # Build job + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup Pages + uses: actions/configure-pages@v2 + - name: Build with Jekyll + uses: actions/jekyll-build-pages@v1 + with: + source: ./docs + destination: ./_site + - name: Upload artifact + uses: actions/upload-pages-artifact@v1 + + # Deployment job + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v1 diff --git a/.github/workflows/helm.yml b/.github/workflows/helm.yml new file mode 100644 index 000000000..7cdc369ba --- /dev/null +++ b/.github/workflows/helm.yml @@ -0,0 +1,39 @@ +name: Release Charts + +on: + push: + branches: + - main + paths: + - 'helm/**' # only execute if we have helm chart changes + +jobs: + release: + # depending on default permission settings for your org (contents being read-only or read-write for workloads), you will have to add permissions + # see: https://docs.github.com/en/actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token + permissions: + contents: write + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Configure Git + run: | + git config user.name "$GITHUB_ACTOR" + git config user.email "$GITHUB_ACTOR@users.noreply.github.com" + + - name: Install Helm + uses: azure/setup-helm@v3 + with: + version: v3.10.0 + + - name: Run chart-releaser + uses: helm/chart-releaser-action@v1.4.1 + env: + CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}" + with: + config: helm/cr.yaml + charts_dir: helm/ diff --git a/.github/workflows/k8s.yml b/.github/workflows/k8s.yml new file mode 100644 index 000000000..fc5e8c906 --- /dev/null +++ b/.github/workflows/k8s.yml @@ -0,0 +1,90 @@ +name: k8s + +on: + push: + branches: ["main"] + paths: + - 'helm/**' # only execute if we have helm chart changes + pull_request: + branches: ["main"] + paths: + - 'helm/**' + +jobs: + lint: + name: Lint Helm chart + runs-on: ubuntu-latest + outputs: + changed: ${{ steps.list-changed.outputs.changed }} + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - uses: azure/setup-helm@v3 + with: + version: v3.10.0 + - uses: actions/setup-python@v4 + with: + python-version: 3.11 + check-latest: true + - uses: helm/chart-testing-action@v2.3.1 + - name: Get changed status + id: list-changed + run: | + changed=$(ct list-changed --config helm/ct.yaml --target-branch ${{ github.event.repository.default_branch }}) + if [[ -n "$changed" ]]; then + echo "::set-output name=changed::true" + fi + + - name: Run lint + run: ct lint --config helm/ct.yaml + + # only bother to run if lint step reports a change to the helm chart + install: + needs: + - lint + if: ${{ needs.lint.outputs.changed == 'true' }} + name: Install Helm charts + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + ref: ${{ inputs.checkoutCommit }} + - name: Install Kubernetes tools + uses: yokawasa/action-setup-kube-tools@v0.8.2 + with: + setup-tools: | + helmv3 + helm: "3.10.3" + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Set up chart-testing + uses: helm/chart-testing-action@v2.3.1 + - name: Create k3d cluster + uses: nolar/setup-k3d-k3s@v1 + with: + version: v1.21 + - name: Remove node taints + run: | + kubectl taint --all=true nodes node.cloudprovider.kubernetes.io/uninitialized- || true + - name: Run chart-testing (install) + run: ct install --config helm/ct.yaml + + # Install the chart using helm directly and test with create-account + - name: Install chart + run: | + helm install --values helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml dendrite helm/dendrite + - name: Wait for Postgres and Dendrite to be up + run: | + kubectl wait --for=condition=ready --timeout=90s pod -l app.kubernetes.io/name=postgresql || kubectl get pods -A + kubectl wait --for=condition=ready --timeout=90s pod -l app.kubernetes.io/name=dendrite || kubectl get pods -A + kubectl get pods -A + kubectl get services + kubectl get ingress + - name: Run create account + run: | + podName=$(kubectl get pods -l app.kubernetes.io/name=dendrite -o name) + kubectl exec "${podName}" -- /usr/bin/create-account -username alice -password somerandompassword \ No newline at end of file diff --git a/helm/cr.yaml b/helm/cr.yaml new file mode 100644 index 000000000..f895ab8d6 --- /dev/null +++ b/helm/cr.yaml @@ -0,0 +1 @@ +release-name-template: "helm-{{ .Name }}-{{ .Version }}" \ No newline at end of file diff --git a/helm/ct.yaml b/helm/ct.yaml new file mode 100644 index 000000000..af706fa3d --- /dev/null +++ b/helm/ct.yaml @@ -0,0 +1,7 @@ +remote: origin +target-branch: main +chart-repos: + - bitnami=https://charts.bitnami.com/bitnami +chart-dirs: + - helm +validate-maintainers: false \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/about.gotmpl b/helm/dendrite/.helm-docs/about.gotmpl new file mode 100644 index 000000000..a92c6be42 --- /dev/null +++ b/helm/dendrite/.helm-docs/about.gotmpl @@ -0,0 +1,5 @@ +{{ define "chart.about" }} +## About + +This chart creates a monolith deployment, including an optionally enabled PostgreSQL dependency to connect to. +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/appservices.gotmpl b/helm/dendrite/.helm-docs/appservices.gotmpl new file mode 100644 index 000000000..8a79a0780 --- /dev/null +++ b/helm/dendrite/.helm-docs/appservices.gotmpl @@ -0,0 +1,5 @@ +{{ define "chart.appservices" }} +## Usage with appservices + +Create a folder `appservices` and place your configurations in there. The configurations will be read and placed in a secret `dendrite-appservices-conf`. +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/database.gotmpl b/helm/dendrite/.helm-docs/database.gotmpl new file mode 100644 index 000000000..85ef01ecc --- /dev/null +++ b/helm/dendrite/.helm-docs/database.gotmpl @@ -0,0 +1,18 @@ +{{ define "chart.dbCreation" }} +## Manual database creation + +(You can skip this, if you're deploying the PostgreSQL dependency) + +You'll need to create the following database before starting Dendrite (see [installation](https://matrix-org.github.io/dendrite/installation/database#single-database-creation)): + +```postgres +create database dendrite +``` + +or + +```bash +sudo -u postgres createdb -O dendrite -E UTF-8 dendrite +``` + +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/state.gotmpl b/helm/dendrite/.helm-docs/state.gotmpl new file mode 100644 index 000000000..2fe987ddd --- /dev/null +++ b/helm/dendrite/.helm-docs/state.gotmpl @@ -0,0 +1,3 @@ +{{ define "chart.state" }} +Status: **NOT PRODUCTION READY** +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml new file mode 100644 index 000000000..15d1e6d19 --- /dev/null +++ b/helm/dendrite/Chart.yaml @@ -0,0 +1,19 @@ +apiVersion: v2 +name: dendrite +version: "0.10.8" +appVersion: "0.10.8" +description: Dendrite Matrix Homeserver +type: application +keywords: + - matrix + - chat + - homeserver + - dendrite +home: https://github.com/matrix-org/dendrite +sources: + - https://github.com/matrix-org/dendrite +dependencies: +- name: postgresql + version: 12.1.7 + repository: https://charts.bitnami.com/bitnami + condition: postgresql.enabled diff --git a/helm/dendrite/README.md b/helm/dendrite/README.md new file mode 100644 index 000000000..cb850d655 --- /dev/null +++ b/helm/dendrite/README.md @@ -0,0 +1,147 @@ +# dendrite + +![Version: 0.10.8](https://img.shields.io/badge/Version-0.10.8-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.10.8](https://img.shields.io/badge/AppVersion-0.10.8-informational?style=flat-square) +Dendrite Matrix Homeserver + +Status: **NOT PRODUCTION READY** + +## About + +This chart creates a monolith deployment, including an optionally enabled PostgreSQL dependency to connect to. + +## Manual database creation + +(You can skip this, if you're deploying the PostgreSQL dependency) + +You'll need to create the following database before starting Dendrite (see [installation](https://matrix-org.github.io/dendrite/installation/database#single-database-creation)): + +```postgres +create database dendrite +``` + +or + +```bash +sudo -u postgres createdb -O dendrite -E UTF-8 dendrite +``` + +## Usage with appservices + +Create a folder `appservices` and place your configurations in there. The configurations will be read and placed in a secret `dendrite-appservices-conf`. + +## Source Code + +* +## Requirements + +| Repository | Name | Version | +|------------|------|---------| +| https://charts.bitnami.com/bitnami | postgresql | 12.1.7 | +## Values + +| Key | Type | Default | Description | +|-----|------|---------|-------------| +| image.name | string | `"ghcr.io/matrix-org/dendrite-monolith:v0.10.8"` | Docker repository/image to use | +| image.pullPolicy | string | `"IfNotPresent"` | Kubernetes pullPolicy | +| signing_key.create | bool | `true` | Create a new signing key, if not exists | +| signing_key.existingSecret | string | `""` | Use an existing secret | +| resources | object | sets some sane default values | Default resource requests/limits. | +| persistence.storageClass | string | `""` | The storage class to use for volume claims. Defaults to the cluster default storage class. | +| persistence.jetstream.existingClaim | string | `""` | Use an existing volume claim for jetstream | +| persistence.jetstream.capacity | string | `"1Gi"` | PVC Storage Request for the jetstream volume | +| persistence.media.existingClaim | string | `""` | Use an existing volume claim for media files | +| persistence.media.capacity | string | `"1Gi"` | PVC Storage Request for the media volume | +| persistence.search.existingClaim | string | `""` | Use an existing volume claim for the fulltext search index | +| persistence.search.capacity | string | `"1Gi"` | PVC Storage Request for the search volume | +| dendrite_config.version | int | `2` | | +| dendrite_config.global.server_name | string | `""` | **REQUIRED** Servername for this Dendrite deployment. | +| dendrite_config.global.private_key | string | `"/etc/dendrite/secrets/signing.key"` | The private key to use. (**NOTE**: This is overriden in Helm) | +| dendrite_config.global.well_known_server_name | string | `""` | The server name to delegate server-server communications to, with optional port e.g. localhost:443 | +| dendrite_config.global.well_known_client_name | string | `""` | The server name to delegate client-server communications to, with optional port e.g. localhost:443 | +| dendrite_config.global.trusted_third_party_id_servers | list | `["matrix.org","vector.im"]` | Lists of domains that the server will trust as identity servers to verify third party identifiers such as phone numbers and email addresses. | +| dendrite_config.global.old_private_keys | string | `nil` | The paths and expiry timestamps (as a UNIX timestamp in millisecond precision) to old signing keys that were formerly in use on this domain name. These keys will not be used for federation request or event signing, but will be provided to any other homeserver that asks when trying to verify old events. | +| dendrite_config.global.disable_federation | bool | `false` | Disable federation. Dendrite will not be able to make any outbound HTTP requests to other servers and the federation API will not be exposed. | +| dendrite_config.global.key_validity_period | string | `"168h0m0s"` | | +| dendrite_config.global.database.connection_string | string | `""` | The connection string for connections to Postgres. This will be set automatically if using the Postgres dependency | +| dendrite_config.global.database.max_open_conns | int | `90` | Default database maximum open connections | +| dendrite_config.global.database.max_idle_conns | int | `5` | Default database maximum idle connections | +| dendrite_config.global.database.conn_max_lifetime | int | `-1` | Default database maximum lifetime | +| dendrite_config.global.jetstream.storage_path | string | `"/data/jetstream"` | Persistent directory to store JetStream streams in. | +| dendrite_config.global.jetstream.addresses | list | `[]` | NATS JetStream server addresses if not using internal NATS. | +| dendrite_config.global.jetstream.topic_prefix | string | `"Dendrite"` | The prefix for JetStream streams | +| dendrite_config.global.jetstream.in_memory | bool | `false` | Keep all data in memory. (**NOTE**: This is overriden in Helm to `false`) | +| dendrite_config.global.jetstream.disable_tls_validation | bool | `true` | Disables TLS validation. This should **NOT** be used in production. | +| dendrite_config.global.cache.max_size_estimated | string | `"1gb"` | The estimated maximum size for the global cache in bytes, or in terabytes, gigabytes, megabytes or kilobytes when the appropriate 'tb', 'gb', 'mb' or 'kb' suffix is specified. Note that this is not a hard limit, nor is it a memory limit for the entire process. A cache that is too small may ultimately provide little or no benefit. | +| dendrite_config.global.cache.max_age | string | `"1h"` | The maximum amount of time that a cache entry can live for in memory before it will be evicted and/or refreshed from the database. Lower values result in easier admission of new cache entries but may also increase database load in comparison to higher values, so adjust conservatively. Higher values may make it harder for new items to make it into the cache, e.g. if new rooms suddenly become popular. | +| dendrite_config.global.report_stats.enabled | bool | `false` | Configures phone-home statistics reporting. These statistics contain the server name, number of active users and some information on your deployment config. We use this information to understand how Dendrite is being used in the wild. | +| dendrite_config.global.report_stats.endpoint | string | `"https://matrix.org/report-usage-stats/push"` | Endpoint to report statistics to. | +| dendrite_config.global.presence.enable_inbound | bool | `false` | Controls whether we receive presence events from other servers | +| dendrite_config.global.presence.enable_outbound | bool | `false` | Controls whether we send presence events for our local users to other servers. (_May increase CPU/memory usage_) | +| dendrite_config.global.server_notices.enabled | bool | `false` | Server notices allows server admins to send messages to all users on the server. | +| dendrite_config.global.server_notices.local_part | string | `"_server"` | The local part for the user sending server notices. | +| dendrite_config.global.server_notices.display_name | string | `"Server Alerts"` | The display name for the user sending server notices. | +| dendrite_config.global.server_notices.avatar_url | string | `""` | The avatar URL (as a mxc:// URL) name for the user sending server notices. | +| dendrite_config.global.server_notices.room_name | string | `"Server Alerts"` | | +| dendrite_config.global.metrics.enabled | bool | `false` | Whether or not Prometheus metrics are enabled. | +| dendrite_config.global.metrics.basic_auth.user | string | `"metrics"` | HTTP basic authentication username | +| dendrite_config.global.metrics.basic_auth.password | string | `"metrics"` | HTTP basic authentication password | +| dendrite_config.global.dns_cache.enabled | bool | `false` | Whether or not the DNS cache is enabled. | +| dendrite_config.global.dns_cache.cache_size | int | `256` | Maximum number of entries to hold in the DNS cache | +| dendrite_config.global.dns_cache.cache_lifetime | string | `"10m"` | Duration for how long DNS cache items should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) | +| dendrite_config.global.profiling.enabled | bool | `false` | Enable pprof. You will need to manually create a port forwarding to the deployment to access PPROF, as it will only listen on localhost and the defined port. e.g. `kubectl port-forward deployments/dendrite 65432:65432` | +| dendrite_config.global.profiling.port | int | `65432` | pprof port, if enabled | +| dendrite_config.mscs | object | `{"mscs":["msc2946"]}` | Configuration for experimental MSC's. (Valid values are: msc2836 and msc2946) | +| dendrite_config.app_service_api.disable_tls_validation | bool | `false` | Disable the validation of TLS certificates of appservices. This is not recommended in production since it may allow appservice traffic to be sent to an insecure endpoint. | +| dendrite_config.app_service_api.config_files | list | `[]` | Appservice config files to load on startup. (**NOTE**: This is overriden by Helm, if a folder `./appservices/` exists) | +| dendrite_config.client_api.registration_disabled | bool | `true` | Prevents new users from being able to register on this homeserver, except when using the registration shared secret below. | +| dendrite_config.client_api.guests_disabled | bool | `true` | | +| dendrite_config.client_api.registration_shared_secret | string | `""` | If set, allows registration by anyone who knows the shared secret, regardless of whether registration is otherwise disabled. | +| dendrite_config.client_api.enable_registration_captcha | bool | `false` | enable reCAPTCHA registration | +| dendrite_config.client_api.recaptcha_public_key | string | `""` | reCAPTCHA public key | +| dendrite_config.client_api.recaptcha_private_key | string | `""` | reCAPTCHA private key | +| dendrite_config.client_api.recaptcha_bypass_secret | string | `""` | reCAPTCHA bypass secret | +| dendrite_config.client_api.recaptcha_siteverify_api | string | `""` | | +| dendrite_config.client_api.turn.turn_user_lifetime | string | `"24h"` | Duration for how long users should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) | +| dendrite_config.client_api.turn.turn_uris | list | `[]` | | +| dendrite_config.client_api.turn.turn_shared_secret | string | `""` | | +| dendrite_config.client_api.turn.turn_username | string | `""` | The TURN username | +| dendrite_config.client_api.turn.turn_password | string | `""` | The TURN password | +| dendrite_config.client_api.rate_limiting.enabled | bool | `true` | Enable rate limiting | +| dendrite_config.client_api.rate_limiting.threshold | int | `20` | After how many requests a rate limit should be activated | +| dendrite_config.client_api.rate_limiting.cooloff_ms | int | `500` | Cooloff time in milliseconds | +| dendrite_config.client_api.rate_limiting.exempt_user_ids | string | `nil` | Users which should be exempt from rate limiting | +| dendrite_config.federation_api.send_max_retries | int | `16` | Federation failure threshold. How many consecutive failures that we should tolerate when sending federation requests to a specific server. The backoff is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds, etc. The default value is 16 if not specified, which is circa 18 hours. | +| dendrite_config.federation_api.disable_tls_validation | bool | `false` | Disable TLS validation. This should **NOT** be used in production. | +| dendrite_config.federation_api.prefer_direct_fetch | bool | `false` | | +| dendrite_config.federation_api.disable_http_keepalives | bool | `false` | Prevents Dendrite from keeping HTTP connections open for reuse for future requests. Connections will be closed quicker but we may spend more time on TLS handshakes instead. | +| dendrite_config.federation_api.key_perspectives | list | See value.yaml | Perspective keyservers, to use as a backup when direct key fetch requests don't succeed. | +| dendrite_config.media_api.base_path | string | `"/data/media_store"` | The path to store media files (e.g. avatars) in | +| dendrite_config.media_api.max_file_size_bytes | int | `10485760` | The max file size for uploaded media files | +| dendrite_config.media_api.dynamic_thumbnails | bool | `false` | | +| dendrite_config.media_api.max_thumbnail_generators | int | `10` | The maximum number of simultaneous thumbnail generators to run. | +| dendrite_config.media_api.thumbnail_sizes | list | See value.yaml | A list of thumbnail sizes to be generated for media content. | +| dendrite_config.sync_api.real_ip_header | string | `"X-Real-IP"` | This option controls which HTTP header to inspect to find the real remote IP address of the client. This is likely required if Dendrite is running behind a reverse proxy server. | +| dendrite_config.sync_api.search | object | `{"enabled":true,"index_path":"/data/search","language":"en"}` | Configuration for the full-text search engine. | +| dendrite_config.sync_api.search.enabled | bool | `true` | Whether fulltext search is enabled. | +| dendrite_config.sync_api.search.index_path | string | `"/data/search"` | The path to store the search index in. | +| dendrite_config.sync_api.search.language | string | `"en"` | The language most likely to be used on the server - used when indexing, to ensure the returned results match expectations. A full list of possible languages can be found [here](https://github.com/matrix-org/dendrite/blob/76db8e90defdfb9e61f6caea8a312c5d60bcc005/internal/fulltext/bleve.go#L25-L46) | +| dendrite_config.user_api.bcrypt_cost | int | `10` | bcrypt cost to use when hashing passwords. (ranges from 4-31; 4 being least secure, 31 being most secure; _NOTE: Using a too high value can cause clients to timeout and uses more CPU._) | +| dendrite_config.user_api.openid_token_lifetime_ms | int | `3600000` | OpenID Token lifetime in milliseconds. | +| dendrite_config.user_api.push_gateway_disable_tls_validation | bool | `false` | | +| dendrite_config.user_api.auto_join_rooms | list | `[]` | Rooms to join users to after registration | +| dendrite_config.logging | list | `[{"level":"info","type":"std"}]` | Default logging configuration | +| postgresql.enabled | bool | See value.yaml | Enable and configure postgres as the database for dendrite. | +| postgresql.image.repository | string | `"bitnami/postgresql"` | | +| postgresql.image.tag | string | `"15.1.0"` | | +| postgresql.auth.username | string | `"dendrite"` | | +| postgresql.auth.password | string | `"changeme"` | | +| postgresql.auth.database | string | `"dendrite"` | | +| postgresql.persistence.enabled | bool | `false` | | +| ingress.enabled | bool | `false` | Create an ingress for a monolith deployment | +| ingress.hosts | list | `[]` | | +| ingress.className | string | `""` | | +| ingress.hostName | string | `""` | | +| ingress.annotations | object | `{}` | Extra, custom annotations | +| ingress.tls | list | `[]` | | +| service.type | string | `"ClusterIP"` | | +| service.port | int | `80` | | diff --git a/helm/dendrite/README.md.gotmpl b/helm/dendrite/README.md.gotmpl new file mode 100644 index 000000000..7c32f7b02 --- /dev/null +++ b/helm/dendrite/README.md.gotmpl @@ -0,0 +1,13 @@ +{{ template "chart.header" . }} +{{ template "chart.deprecationWarning" . }} +{{ template "chart.badgesSection" . }} +{{ template "chart.description" . }} +{{ template "chart.state" . }} +{{ template "chart.about" . }} +{{ template "chart.dbCreation" . }} +{{ template "chart.appservices" . }} +{{ template "chart.maintainersSection" . }} +{{ template "chart.sourcesSection" . }} +{{ template "chart.requirementsSection" . }} +{{ template "chart.valuesSection" . }} +{{ template "helm-docs.versionFooter" . }} \ No newline at end of file diff --git a/helm/dendrite/ci/ct-ingress-values.yaml b/helm/dendrite/ci/ct-ingress-values.yaml new file mode 100644 index 000000000..28311d33e --- /dev/null +++ b/helm/dendrite/ci/ct-ingress-values.yaml @@ -0,0 +1,13 @@ +--- +postgresql: + enabled: true + primary: + persistence: + size: 1Gi + +dendrite_config: + global: + server_name: "localhost" + +ingress: + enabled: true diff --git a/helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml b/helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml new file mode 100644 index 000000000..55e652c63 --- /dev/null +++ b/helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml @@ -0,0 +1,16 @@ +--- +postgresql: + enabled: true + primary: + persistence: + size: 1Gi + +dendrite_config: + global: + server_name: "localhost" + + client_api: + registration_shared_secret: "d233f2fcb0470845a8e150a20ef594ddbe0b4cf7fe482fb9d5120c198557acbf" # echo "dendrite" | sha256sum + +ingress: + enabled: true diff --git a/helm/dendrite/templates/_helpers.tpl b/helm/dendrite/templates/_helpers.tpl new file mode 100644 index 000000000..291f351bc --- /dev/null +++ b/helm/dendrite/templates/_helpers.tpl @@ -0,0 +1,72 @@ +{{- define "validate.config" }} +{{- if not .Values.signing_key.create -}} +{{- fail "You must create a signing key for configuration.signing_key. (see https://github.com/matrix-org/dendrite/blob/master/docs/INSTALL.md#server-key-generation)" -}} +{{- end -}} +{{- if not (or .Values.dendrite_config.global.database.host .Values.postgresql.enabled) -}} +{{- fail "Database server must be set." -}} +{{- end -}} +{{- if not (or .Values.dendrite_config.global.database.user .Values.postgresql.enabled) -}} +{{- fail "Database user must be set." -}} +{{- end -}} +{{- if not (or .Values.dendrite_config.global.database.password .Values.postgresql.enabled) -}} +{{- fail "Database password must be set." -}} +{{- end -}} +{{- end -}} + + +{{- define "image.name" -}} +image: {{ .name }} +imagePullPolicy: {{ .pullPolicy }} +{{- end -}} + +{{/* +Expand the name of the chart. +*/}} +{{- define "dendrite.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create a default fully qualified app name. +We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). +If release name contains chart name it will be used as a full name. +*/}} +{{- define "dendrite.fullname" -}} +{{- if .Values.fullnameOverride }} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- $name := default .Chart.Name .Values.nameOverride }} +{{- if contains $name .Release.Name }} +{{- .Release.Name | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} +{{- end }} +{{- end }} +{{- end }} + +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "dendrite.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Common labels +*/}} +{{- define "dendrite.labels" -}} +helm.sh/chart: {{ include "dendrite.chart" . }} +{{ include "dendrite.selectorLabels" . }} +{{- if .Chart.AppVersion }} +app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} +{{- end }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end }} + +{{/* +Selector labels +*/}} +{{- define "dendrite.selectorLabels" -}} +app.kubernetes.io/name: {{ include "dendrite.name" . }} +app.kubernetes.io/instance: {{ .Release.Name }} +{{- end }} \ No newline at end of file diff --git a/helm/dendrite/templates/_overrides.yaml b/helm/dendrite/templates/_overrides.yaml new file mode 100644 index 000000000..edb8ba83a --- /dev/null +++ b/helm/dendrite/templates/_overrides.yaml @@ -0,0 +1,16 @@ +{{- define "override.config" }} +{{- if .Values.postgresql.enabled }} +{{- $_ := set .Values.dendrite_config.global.database "connection_string" (print "postgresql://" .Values.postgresql.auth.username ":" .Values.postgresql.auth.password "@" .Release.Name "-postgresql/dendrite?sslmode=disable") -}} +{{ end }} +global: + private_key: /etc/dendrite/secrets/signing.key + jetstream: + in_memory: false +{{ if (gt (len (.Files.Glob "appservices/*")) 0) }} +app_service_api: + config_files: + {{- range $x, $y := .Files.Glob "appservices/*" }} + - /etc/dendrite/appservices/{{ base $x }} + {{ end }} +{{ end }} +{{ end }} diff --git a/helm/dendrite/templates/deployment.yaml b/helm/dendrite/templates/deployment.yaml new file mode 100644 index 000000000..629ffe528 --- /dev/null +++ b/helm/dendrite/templates/deployment.yaml @@ -0,0 +1,103 @@ +{{ template "validate.config" . }} +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + namespace: {{ $.Release.Namespace }} + name: {{ include "dendrite.fullname" . }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} +spec: + selector: + matchLabels: + {{- include "dendrite.selectorLabels" . | nindent 6 }} + replicas: 1 + template: + metadata: + labels: + {{- include "dendrite.selectorLabels" . | nindent 8 }} + annotations: + confighash-global: secret-{{ .Values.global | toYaml | sha256sum | trunc 32 }} + confighash-clientapi: clientapi-{{ .Values.clientapi | toYaml | sha256sum | trunc 32 }} + confighash-federationapi: federationapi-{{ .Values.federationapi | toYaml | sha256sum | trunc 32 }} + confighash-mediaapi: mediaapi-{{ .Values.mediaapi | toYaml | sha256sum | trunc 32 }} + confighash-syncapi: syncapi-{{ .Values.syncapi | toYaml | sha256sum | trunc 32 }} + spec: + volumes: + - name: {{ include "dendrite.fullname" . }}-conf-vol + secret: + secretName: {{ include "dendrite.fullname" . }}-conf + - name: {{ include "dendrite.fullname" . }}-signing-key + secret: + secretName: {{ default (print ( include "dendrite.fullname" . ) "-signing-key") $.Values.signing_key.existingSecret | quote }} + {{- if (gt (len ($.Files.Glob "appservices/*")) 0) }} + - name: {{ include "dendrite.fullname" . }}-appservices + secret: + secretName: {{ include "dendrite.fullname" . }}-appservices-conf + {{- end }} + - name: {{ include "dendrite.fullname" . }}-jetstream + persistentVolumeClaim: + claimName: {{ default (print ( include "dendrite.fullname" . ) "-jetstream-pvc") $.Values.persistence.jetstream.existingClaim | quote }} + - name: {{ include "dendrite.fullname" . }}-media + persistentVolumeClaim: + claimName: {{ default (print ( include "dendrite.fullname" . ) "-media-pvc") $.Values.persistence.media.existingClaim | quote }} + - name: {{ include "dendrite.fullname" . }}-search + persistentVolumeClaim: + claimName: {{ default (print ( include "dendrite.fullname" . ) "-search-pvc") $.Values.persistence.search.existingClaim | quote }} + containers: + - name: {{ $.Chart.Name }} + {{- include "image.name" $.Values.image | nindent 8 }} + args: + - '--config' + - '/etc/dendrite/dendrite.yaml' + ports: + - name: http + containerPort: 8008 + protocol: TCP + {{- if $.Values.dendrite_config.global.profiling.enabled }} + env: + - name: PPROFLISTEN + value: "localhost:{{- $.Values.global.profiling.port -}}" + {{- end }} + resources: + {{- toYaml $.Values.resources | nindent 10 }} + volumeMounts: + - mountPath: /etc/dendrite/ + name: {{ include "dendrite.fullname" . }}-conf-vol + - mountPath: /etc/dendrite/secrets/ + name: {{ include "dendrite.fullname" . }}-signing-key + {{- if (gt (len ($.Files.Glob "appservices/*")) 0) }} + - mountPath: /etc/dendrite/appservices + name: {{ include "dendrite.fullname" . }}-appservices + readOnly: true + {{ end }} + - mountPath: {{ .Values.dendrite_config.media_api.base_path }} + name: {{ include "dendrite.fullname" . }}-media + - mountPath: {{ .Values.dendrite_config.global.jetstream.storage_path }} + name: {{ include "dendrite.fullname" . }}-jetstream + - mountPath: {{ .Values.dendrite_config.sync_api.search.index_path }} + name: {{ include "dendrite.fullname" . }}-search + livenessProbe: + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + httpGet: + path: /_dendrite/monitor/health + port: http + readinessProbe: + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + httpGet: + path: /_dendrite/monitor/health + port: http + startupProbe: + initialDelaySeconds: 5 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + httpGet: + path: /_dendrite/monitor/up + port: http \ No newline at end of file diff --git a/helm/dendrite/templates/ingress.yaml b/helm/dendrite/templates/ingress.yaml new file mode 100644 index 000000000..8f86ad723 --- /dev/null +++ b/helm/dendrite/templates/ingress.yaml @@ -0,0 +1,55 @@ +{{- if .Values.ingress.enabled -}} + {{- $fullName := include "dendrite.fullname" . -}} + {{- $svcPort := .Values.service.port -}} + {{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }} + {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }} + {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}} + {{- end }} + {{- end }} + {{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}} +apiVersion: networking.k8s.io/v1 + {{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}} +apiVersion: networking.k8s.io/v1beta1 + {{- else -}} +apiVersion: extensions/v1beta1 + {{- end }} +kind: Ingress +metadata: + name: {{ $fullName }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} + annotations: + {{- with .Values.ingress.annotations }} + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + {{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }} + ingressClassName: {{ .Values.ingress.className }} + {{- end }} + {{- if .Values.ingress.tls }} + tls: + {{- range .Values.ingress.tls }} + - hosts: + {{- range .hosts }} + - {{ . | quote }} + {{- end }} + secretName: {{ .secretName }} + {{- end }} + {{- end }} + rules: + - host: {{ .Values.ingress.hostName | quote }} + http: + paths: + - path: / + pathType: ImplementationSpecific + backend: + {{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }} + service: + name: {{ $fullName }} + port: + number: {{ $svcPort }} + {{- else }} + serviceName: {{ $fullName }} + servicePort: {{ $svcPort }} + {{- end }} + {{- end }} \ No newline at end of file diff --git a/helm/dendrite/templates/jobs.yaml b/helm/dendrite/templates/jobs.yaml new file mode 100644 index 000000000..76915694d --- /dev/null +++ b/helm/dendrite/templates/jobs.yaml @@ -0,0 +1,99 @@ +{{ if and .Values.signing_key.create (not .Values.signing_key.existingSecret ) }} +{{ $name := (print ( include "dendrite.fullname" . ) "-signing-key") }} +{{ $secretName := (print ( include "dendrite.fullname" . ) "-signing-key") }} +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ $name }} + labels: + app.kubernetes.io/component: signingkey-job +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: {{ $name }} + labels: + app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} +rules: + - apiGroups: + - "" + resources: + - secrets + resourceNames: + - {{ $secretName }} + verbs: + - get + - update + - patch +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: {{ $name }} + labels: + app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: {{ $name }} +subjects: + - kind: ServiceAccount + name: {{ $name }} + namespace: {{ .Release.Namespace }} +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: generate-signing-key + labels: + {{- include "dendrite.labels" . | nindent 4 }} +spec: + template: + spec: + restartPolicy: "Never" + serviceAccount: {{ $name }} + containers: + - name: upload-key + image: bitnami/kubectl + command: + - sh + - -c + - | + # check if key already exists + key=$(kubectl get secret {{ $secretName }} -o jsonpath="{.data['signing\.key']}" 2> /dev/null) + [ $? -ne 0 ] && echo "Failed to get existing secret" && exit 1 + [ -n "$key" ] && echo "Key already created, exiting." && exit 0 + # wait for signing key + while [ ! -f /etc/dendrite/signing-key.pem ]; do + echo "Waiting for signing key.." + sleep 5; + done + # update secret + kubectl patch secret {{ $secretName }} -p "{\"data\":{\"signing.key\":\"$(base64 /etc/dendrite/signing-key.pem | tr -d '\n')\"}}" + [ $? -ne 0 ] && echo "Failed to update secret." && exit 1 + echo "Signing key successfully created." + volumeMounts: + - mountPath: /etc/dendrite/ + name: signing-key + readOnly: true + - name: generate-key + {{- include "image.name" $.Values.image | nindent 8 }} + command: + - sh + - -c + - | + /usr/bin/generate-keys -private-key /etc/dendrite/signing-key.pem + chown 1001:1001 /etc/dendrite/signing-key.pem + volumeMounts: + - mountPath: /etc/dendrite/ + name: signing-key + volumes: + - name: signing-key + emptyDir: {} + parallelism: 1 + completions: 1 + backoffLimit: 1 +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/templates/pvc.yaml b/helm/dendrite/templates/pvc.yaml new file mode 100644 index 000000000..897957e60 --- /dev/null +++ b/helm/dendrite/templates/pvc.yaml @@ -0,0 +1,48 @@ +{{ if not .Values.persistence.media.existingClaim }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-media-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.persistence.media.capacity }} + storageClassName: {{ .Values.persistence.storageClass }} +{{ end }} +{{ if not .Values.persistence.jetstream.existingClaim }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-jetstream-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.persistence.jetstream.capacity }} + storageClassName: {{ .Values.persistence.storageClass }} +{{ end }} +{{ if not .Values.persistence.search.existingClaim }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-search-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.persistence.search.capacity }} + storageClassName: {{ .Values.persistence.storageClass }} +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/templates/secrets.yaml b/helm/dendrite/templates/secrets.yaml new file mode 100644 index 000000000..d4b8ecbf2 --- /dev/null +++ b/helm/dendrite/templates/secrets.yaml @@ -0,0 +1,33 @@ +{{ if (gt (len (.Files.Glob "appservices/*")) 0) }} +--- +apiVersion: v1 +kind: Secret +metadata: + name: {{ include "dendrite.fullname" . }}-appservices-conf + namespace: {{ .Release.Namespace }} +type: Opaque +data: +{{ (.Files.Glob "appservices/*").AsSecrets | indent 2 }} +{{ end }} +{{ if and .Values.signing_key.create (not .Values.signing_key.existingSecret) }} +--- +apiVersion: v1 +kind: Secret +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-signing-key + namespace: {{ .Release.Namespace }} +type: Opaque +{{ end }} + +--- +apiVersion: v1 +kind: Secret +type: Opaque +metadata: + name: {{ include "dendrite.fullname" . }}-conf + namespace: {{ .Release.Namespace }} +stringData: + dendrite.yaml: | + {{ toYaml ( mustMergeOverwrite .Values.dendrite_config ( fromYaml (include "override.config" .) ) .Values.dendrite_config ) | nindent 4 }} \ No newline at end of file diff --git a/helm/dendrite/templates/service.yaml b/helm/dendrite/templates/service.yaml new file mode 100644 index 000000000..365a43f04 --- /dev/null +++ b/helm/dendrite/templates/service.yaml @@ -0,0 +1,17 @@ +{{ template "validate.config" . }} +--- +apiVersion: v1 +kind: Service +metadata: + namespace: {{ $.Release.Namespace }} + name: {{ include "dendrite.fullname" . }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} +spec: + selector: + {{- include "dendrite.selectorLabels" . | nindent 4 }} + ports: + - name: http + protocol: TCP + port: 8008 + targetPort: 8008 \ No newline at end of file diff --git a/helm/dendrite/templates/tests/test-version.yaml b/helm/dendrite/templates/tests/test-version.yaml new file mode 100644 index 000000000..d88751325 --- /dev/null +++ b/helm/dendrite/templates/tests/test-version.yaml @@ -0,0 +1,17 @@ +--- +apiVersion: v1 +kind: Pod +metadata: + name: "{{ include "dendrite.fullname" . }}-test-version" + labels: + {{- include "dendrite.selectorLabels" . | nindent 4 }} + annotations: + "helm.sh/hook": test +spec: + containers: + - name: curl + image: curlimages/curl + imagePullPolicy: IfNotPresent + args: + - 'http://{{- include "dendrite.fullname" . -}}:8008/_matrix/client/versions' + restartPolicy: Never diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml new file mode 100644 index 000000000..2c6e80942 --- /dev/null +++ b/helm/dendrite/values.yaml @@ -0,0 +1,348 @@ +image: + # -- Docker repository/image to use + name: "ghcr.io/matrix-org/dendrite-monolith:v0.10.8" + # -- Kubernetes pullPolicy + pullPolicy: IfNotPresent + + +# signing key to use +signing_key: + # -- Create a new signing key, if not exists + create: true + # -- Use an existing secret + existingSecret: "" + +# -- Default resource requests/limits. +# @default -- sets some sane default values +resources: + requests: + memory: "512Mi" + + limits: + memory: "4096Mi" + +persistence: + # -- The storage class to use for volume claims. Defaults to the + # cluster default storage class. + storageClass: "" + jetstream: + # -- Use an existing volume claim for jetstream + existingClaim: "" + # -- PVC Storage Request for the jetstream volume + capacity: "1Gi" + media: + # -- Use an existing volume claim for media files + existingClaim: "" + # -- PVC Storage Request for the media volume + capacity: "1Gi" + search: + # -- Use an existing volume claim for the fulltext search index + existingClaim: "" + # -- PVC Storage Request for the search volume + capacity: "1Gi" + +dendrite_config: + version: 2 + global: + # -- **REQUIRED** Servername for this Dendrite deployment. + server_name: "" + + # -- The private key to use. (**NOTE**: This is overriden in Helm) + private_key: /etc/dendrite/secrets/signing.key + + # -- The server name to delegate server-server communications to, with optional port + # e.g. localhost:443 + well_known_server_name: "" + + # -- The server name to delegate client-server communications to, with optional port + # e.g. localhost:443 + well_known_client_name: "" + + # -- Lists of domains that the server will trust as identity servers to verify third + # party identifiers such as phone numbers and email addresses. + trusted_third_party_id_servers: + - matrix.org + - vector.im + + # -- The paths and expiry timestamps (as a UNIX timestamp in millisecond precision) + # to old signing keys that were formerly in use on this domain name. These + # keys will not be used for federation request or event signing, but will be + # provided to any other homeserver that asks when trying to verify old events. + old_private_keys: + # If the old private key file is available: + # - private_key: old_matrix_key.pem + # expired_at: 1601024554498 + # If only the public key (in base64 format) and key ID are known: + # - public_key: mn59Kxfdq9VziYHSBzI7+EDPDcBS2Xl7jeUdiiQcOnM= + # key_id: ed25519:mykeyid + # expired_at: 1601024554498 + + # -- Disable federation. Dendrite will not be able to make any outbound HTTP requests + # to other servers and the federation API will not be exposed. + disable_federation: false + + key_validity_period: 168h0m0s + + database: + # -- The connection string for connections to Postgres. + # This will be set automatically if using the Postgres dependency + connection_string: "" + + # -- Default database maximum open connections + max_open_conns: 90 + # -- Default database maximum idle connections + max_idle_conns: 5 + # -- Default database maximum lifetime + conn_max_lifetime: -1 + + jetstream: + # -- Persistent directory to store JetStream streams in. + storage_path: "/data/jetstream" + # -- NATS JetStream server addresses if not using internal NATS. + addresses: [] + # -- The prefix for JetStream streams + topic_prefix: "Dendrite" + # -- Keep all data in memory. (**NOTE**: This is overriden in Helm to `false`) + in_memory: false + # -- Disables TLS validation. This should **NOT** be used in production. + disable_tls_validation: true + + cache: + # -- The estimated maximum size for the global cache in bytes, or in terabytes, + # gigabytes, megabytes or kilobytes when the appropriate 'tb', 'gb', 'mb' or + # 'kb' suffix is specified. Note that this is not a hard limit, nor is it a + # memory limit for the entire process. A cache that is too small may ultimately + # provide little or no benefit. + max_size_estimated: 1gb + # -- The maximum amount of time that a cache entry can live for in memory before + # it will be evicted and/or refreshed from the database. Lower values result in + # easier admission of new cache entries but may also increase database load in + # comparison to higher values, so adjust conservatively. Higher values may make + # it harder for new items to make it into the cache, e.g. if new rooms suddenly + # become popular. + max_age: 1h + + report_stats: + # -- Configures phone-home statistics reporting. These statistics contain the server + # name, number of active users and some information on your deployment config. + # We use this information to understand how Dendrite is being used in the wild. + enabled: false + # -- Endpoint to report statistics to. + endpoint: https://matrix.org/report-usage-stats/push + + presence: + # -- Controls whether we receive presence events from other servers + enable_inbound: false + # -- Controls whether we send presence events for our local users to other servers. + # (_May increase CPU/memory usage_) + enable_outbound: false + + server_notices: + # -- Server notices allows server admins to send messages to all users on the server. + enabled: false + # -- The local part for the user sending server notices. + local_part: "_server" + # -- The display name for the user sending server notices. + display_name: "Server Alerts" + # -- The avatar URL (as a mxc:// URL) name for the user sending server notices. + avatar_url: "" + # The room name to be used when sending server notices. This room name will + # appear in user clients. + room_name: "Server Alerts" + + # prometheus metrics + metrics: + # -- Whether or not Prometheus metrics are enabled. + enabled: false + # HTTP basic authentication to protect access to monitoring. + basic_auth: + # -- HTTP basic authentication username + user: "metrics" + # -- HTTP basic authentication password + password: metrics + + dns_cache: + # -- Whether or not the DNS cache is enabled. + enabled: false + # -- Maximum number of entries to hold in the DNS cache + cache_size: 256 + # -- Duration for how long DNS cache items should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) + cache_lifetime: "10m" + + profiling: + # -- Enable pprof. You will need to manually create a port forwarding to the deployment to access PPROF, + # as it will only listen on localhost and the defined port. + # e.g. `kubectl port-forward deployments/dendrite 65432:65432` + enabled: false + # -- pprof port, if enabled + port: 65432 + + # -- Configuration for experimental MSC's. (Valid values are: msc2836 and msc2946) + mscs: + mscs: + - msc2946 + # A list of enabled MSC's + # Currently valid values are: + # - msc2836 (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) + # - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946) + + + app_service_api: + # -- Disable the validation of TLS certificates of appservices. This is + # not recommended in production since it may allow appservice traffic + # to be sent to an insecure endpoint. + disable_tls_validation: false + # -- Appservice config files to load on startup. (**NOTE**: This is overriden by Helm, if a folder `./appservices/` exists) + config_files: [] + + client_api: + # -- Prevents new users from being able to register on this homeserver, except when + # using the registration shared secret below. + registration_disabled: true + + # Prevents new guest accounts from being created. Guest registration is also + # disabled implicitly by setting 'registration_disabled' above. + guests_disabled: true + + # -- If set, allows registration by anyone who knows the shared secret, regardless of + # whether registration is otherwise disabled. + registration_shared_secret: "" + + # -- enable reCAPTCHA registration + enable_registration_captcha: false + # -- reCAPTCHA public key + recaptcha_public_key: "" + # -- reCAPTCHA private key + recaptcha_private_key: "" + # -- reCAPTCHA bypass secret + recaptcha_bypass_secret: "" + recaptcha_siteverify_api: "" + + # TURN server information that this homeserver should send to clients. + turn: + # -- Duration for how long users should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) + turn_user_lifetime: "24h" + turn_uris: [] + turn_shared_secret: "" + # -- The TURN username + turn_username: "" + # -- The TURN password + turn_password: "" + + rate_limiting: + # -- Enable rate limiting + enabled: true + # -- After how many requests a rate limit should be activated + threshold: 20 + # -- Cooloff time in milliseconds + cooloff_ms: 500 + # -- Users which should be exempt from rate limiting + exempt_user_ids: + + federation_api: + # -- Federation failure threshold. How many consecutive failures that we should + # tolerate when sending federation requests to a specific server. The backoff + # is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds, etc. + # The default value is 16 if not specified, which is circa 18 hours. + send_max_retries: 16 + # -- Disable TLS validation. This should **NOT** be used in production. + disable_tls_validation: false + prefer_direct_fetch: false + # -- Prevents Dendrite from keeping HTTP connections + # open for reuse for future requests. Connections will be closed quicker + # but we may spend more time on TLS handshakes instead. + disable_http_keepalives: false + # -- Perspective keyservers, to use as a backup when direct key fetch + # requests don't succeed. + # @default -- See value.yaml + key_perspectives: + - server_name: matrix.org + keys: + - key_id: ed25519:auto + public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw + - key_id: ed25519:a_RXGa + public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ + + media_api: + # -- The path to store media files (e.g. avatars) in + base_path: "/data/media_store" + # -- The max file size for uploaded media files + max_file_size_bytes: 10485760 + # Whether to dynamically generate thumbnails if needed. + dynamic_thumbnails: false + # -- The maximum number of simultaneous thumbnail generators to run. + max_thumbnail_generators: 10 + # -- A list of thumbnail sizes to be generated for media content. + # @default -- See value.yaml + thumbnail_sizes: + - width: 32 + height: 32 + method: crop + - width: 96 + height: 96 + method: crop + - width: 640 + height: 480 + method: scale + + sync_api: + # -- This option controls which HTTP header to inspect to find the real remote IP + # address of the client. This is likely required if Dendrite is running behind + # a reverse proxy server. + real_ip_header: X-Real-IP + # -- Configuration for the full-text search engine. + search: + # -- Whether fulltext search is enabled. + enabled: true + # -- The path to store the search index in. + index_path: "/data/search" + # -- The language most likely to be used on the server - used when indexing, to + # ensure the returned results match expectations. A full list of possible languages + # can be found [here](https://github.com/matrix-org/dendrite/blob/76db8e90defdfb9e61f6caea8a312c5d60bcc005/internal/fulltext/bleve.go#L25-L46) + language: "en" + + user_api: + # -- bcrypt cost to use when hashing passwords. + # (ranges from 4-31; 4 being least secure, 31 being most secure; _NOTE: Using a too high value can cause clients to timeout and uses more CPU._) + bcrypt_cost: 10 + # -- OpenID Token lifetime in milliseconds. + openid_token_lifetime_ms: 3600000 + # - Disable TLS validation when hitting push gateways. This should **NOT** be used in production. + push_gateway_disable_tls_validation: false + # -- Rooms to join users to after registration + auto_join_rooms: [] + + # -- Default logging configuration + logging: + - type: std + level: info + +postgresql: + # -- Enable and configure postgres as the database for dendrite. + # @default -- See value.yaml + enabled: false + image: + repository: bitnami/postgresql + tag: "15.1.0" + auth: + username: dendrite + password: changeme + database: dendrite + + persistence: + enabled: false + +ingress: + # -- Create an ingress for a monolith deployment + enabled: false + hosts: [] + className: "" + hostName: "" + # -- Extra, custom annotations + annotations: {} + + tls: [] + +service: + type: ClusterIP + port: 80 diff --git a/setup/config/config.go b/setup/config/config.go index 6523a2452..41d2b6674 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -228,7 +228,7 @@ func loadConfig( privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath) if c.Global.KeyID, c.Global.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { - return nil, err + return nil, fmt.Errorf("failed to load private_key: %w", err) } for _, v := range c.Global.VirtualHosts { @@ -242,7 +242,7 @@ func loadConfig( } privateKeyPath := absPath(basePath, v.PrivateKeyPath) if v.KeyID, v.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { - return nil, err + return nil, fmt.Errorf("failed to load private_key for virtualhost %s: %w", v.ServerName, err) } } From 002310390f874f69e52be1a8d20b17a3cb11d126 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:51:07 +0100 Subject: [PATCH 074/124] Output to docs folder, hopefully --- helm/cr.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/helm/cr.yaml b/helm/cr.yaml index f895ab8d6..014803cf1 100644 --- a/helm/cr.yaml +++ b/helm/cr.yaml @@ -1 +1,2 @@ -release-name-template: "helm-{{ .Name }}-{{ .Version }}" \ No newline at end of file +release-name-template: "helm-{{ .Name }}-{{ .Version }}" +package-path: docs/ \ No newline at end of file From 3fd95e60cc5fc3ae0610d0aca2177d8436a65ee1 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:54:04 +0100 Subject: [PATCH 075/124] Try that again --- helm/cr.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/cr.yaml b/helm/cr.yaml index 014803cf1..884c2b46b 100644 --- a/helm/cr.yaml +++ b/helm/cr.yaml @@ -1,2 +1,2 @@ release-name-template: "helm-{{ .Name }}-{{ .Version }}" -package-path: docs/ \ No newline at end of file +pages-index-path: docs/index.yaml \ No newline at end of file From 54b47a98e57cf210b568fa99ae159fd000012eb2 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 6 Jan 2023 11:49:59 -0700 Subject: [PATCH 076/124] Add curl to dendrite docker containers --- Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile b/Dockerfile index ede33e635..6da555c04 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,6 +27,7 @@ RUN --mount=target=. \ # The dendrite base image # FROM alpine:latest AS dendrite-base +RUN apk --update --no-cache add curl LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" LABEL org.opencontainers.image.licenses="Apache-2.0" From 0995dc48224b90432e38fa92345cf5735bca6090 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 6 Jan 2023 12:02:43 -0700 Subject: [PATCH 077/124] Add curl to dendrite-demo-pinecone docker container --- build/docker/Dockerfile.demo-pinecone | 1 + 1 file changed, 1 insertion(+) diff --git a/build/docker/Dockerfile.demo-pinecone b/build/docker/Dockerfile.demo-pinecone index facd1e3af..90f515167 100644 --- a/build/docker/Dockerfile.demo-pinecone +++ b/build/docker/Dockerfile.demo-pinecone @@ -17,6 +17,7 @@ RUN go build -trimpath -o bin/ ./cmd/create-account RUN go build -trimpath -o bin/ ./cmd/generate-keys FROM alpine:latest +RUN apk --update --no-cache add curl LABEL org.opencontainers.image.title="Dendrite (Pinecone demo)" LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" From b0c5af6674465a3384a1b55c84325e7989ce1eb5 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 10 Jan 2023 17:02:38 +0100 Subject: [PATCH 078/124] Fix `/login` issue causing wrong device list updates (#2922) Fixes https://github.com/matrix-org/dendrite/issues/2914 and possibly https://github.com/matrix-org/dendrite/issues/2073? --- clientapi/auth/login_test.go | 5 +- clientapi/auth/password.go | 5 ++ clientapi/routing/login_test.go | 152 ++++++++++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 clientapi/routing/login_test.go diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index b79c573aa..044062c42 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -47,7 +48,7 @@ func TestLoginFromJSONReader(t *testing.T) { "password": "herpassword", "device_id": "adevice" }`, - WantUsername: "alice", + WantUsername: "@alice:example.com", WantDeviceID: "adevice", }, { @@ -174,7 +175,7 @@ func (ua *fakeUserInternalAPI) QueryAccountByPassword(ctx context.Context, req * return nil } res.Exists = true - res.Account = &uapi.Account{} + res.Account = &uapi.Account{UserID: userutil.MakeUserID(req.Localpart, req.ServerName)} return nil } diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 4de2b443c..f2b0383ab 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -101,6 +101,8 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } + // If we couldn't find the user by the lower cased localpart, try the provided + // localpart as is. if !res.Exists { err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ Localpart: localpart, @@ -122,5 +124,8 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } } + // Set the user, so login.Username() can do the right thing + r.Identifier.User = res.Account.UserID + r.User = res.Account.UserID return &r.Login, nil } diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go new file mode 100644 index 000000000..d429d7f8c --- /dev/null +++ b/clientapi/routing/login_test.go @@ -0,0 +1,152 @@ +package routing + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestLogin(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bobUser := &test.User{ID: "@bob:test", AccountType: uapi.AccountTypeUser} + charlie := &test.User{ID: "@Charlie:test", AccountType: uapi.AccountTypeUser} + vhUser := &test.User{ID: "@vhuser:vh1"} + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.ClientAPI.RateLimiting.Enabled = false + // add a vhost + base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, + }) + + rsAPI := roomserver.NewInternalAPI(base) + // Needed for /login + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, keyAPI, nil, &base.Cfg.MSCs, nil) + + // Create password + password := util.RandomString(8) + + // create the users + for _, u := range []*test.User{aliceAdmin, bobUser, vhUser, charlie} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + if !userRes.AccountCreated { + t.Fatalf("account not created") + } + } + + testCases := []struct { + name string + userID string + wantOK bool + }{ + { + name: "aliceAdmin can login", + userID: aliceAdmin.ID, + wantOK: true, + }, + { + name: "bobUser can login", + userID: bobUser.ID, + wantOK: true, + }, + { + name: "vhuser can login", + userID: vhUser.ID, + wantOK: true, + }, + { + name: "bob with uppercase can login", + userID: "@Bob:test", + wantOK: true, + }, + { + name: "Charlie can login (existing uppercase)", + userID: charlie.ID, + wantOK: true, + }, + { + name: "Charlie can not login with lowercase userID", + userID: strings.ToLower(charlie.ID), + wantOK: false, + }, + } + + ctx := context.Background() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": tc.userID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + + t.Logf("Response: %s", rec.Body.String()) + // get the response + resp := loginResponse{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + // everything OK + if !tc.wantOK && resp.AccessToken == "" { + return + } + if tc.wantOK && resp.AccessToken == "" { + t.Fatalf("expected accessToken after successful login but got none: %+v", resp) + } + + devicesResp := &uapi.QueryDevicesResponse{} + if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: resp.UserID}, devicesResp); err != nil { + t.Fatal(err) + } + for _, dev := range devicesResp.Devices { + // We expect the userID on the device to be the same as resp.UserID + if dev.UserID != resp.UserID { + t.Fatalf("unexpected userID on device: %s", dev.UserID) + } + } + }) + } + }) +} From 7482cd2b47cf19f4da9121461950409ea8f07a12 Mon Sep 17 00:00:00 2001 From: devonh Date: Tue, 10 Jan 2023 18:09:25 +0000 Subject: [PATCH 079/124] Handle DisplayName field in admin user registration endpoint (#2935) `/_synapse/admin/v1/register` has a `displayname` field that we were previously ignoring. This handles that field and adds the displayname to the new user if one was provided. --- clientapi/routing/register.go | 28 +++++-- clientapi/routing/register_secret.go | 13 ++-- clientapi/routing/register_secret_test.go | 2 +- clientapi/routing/register_test.go | 93 +++++++++++++++++++++++ 4 files changed, 123 insertions(+), 13 deletions(-) diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 6087bda0c..ff6a0900e 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -780,7 +780,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, r.ServerName, "", appserviceID, req.RemoteAddr, + req.Context(), userAPI, r.Username, r.ServerName, "", "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) @@ -800,7 +800,7 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.ServerName, r.Password, "", req.RemoteAddr, + req.Context(), userAPI, r.Username, r.ServerName, "", r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) @@ -824,10 +824,10 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.ClientUserAPI, - username string, serverName gomatrixserverlib.ServerName, + username string, serverName gomatrixserverlib.ServerName, displayName string, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, - displayName, deviceID *string, + deviceDisplayName, deviceID *string, accType userapi.AccountType, ) util.JSONResponse { if username == "" { @@ -887,12 +887,28 @@ func completeRegistration( } } + if displayName != "" { + nameReq := userapi.PerformUpdateDisplayNameRequest{ + Localpart: username, + ServerName: serverName, + DisplayName: displayName, + } + var nameRes userapi.PerformUpdateDisplayNameResponse + err = userAPI.SetDisplayName(ctx, &nameReq, &nameRes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("failed to set display name: " + err.Error()), + } + } + } + var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ Localpart: username, ServerName: serverName, AccessToken: token, - DeviceDisplayName: displayName, + DeviceDisplayName: deviceDisplayName, DeviceID: deviceID, IPAddr: ipAddr, UserAgent: userAgent, @@ -1077,5 +1093,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.DisplayName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/routing/register_secret.go b/clientapi/routing/register_secret.go index 1a974b77a..f384b604a 100644 --- a/clientapi/routing/register_secret.go +++ b/clientapi/routing/register_secret.go @@ -18,12 +18,13 @@ import ( ) type SharedSecretRegistrationRequest struct { - User string `json:"username"` - Password string `json:"password"` - Nonce string `json:"nonce"` - MacBytes []byte - MacStr string `json:"mac"` - Admin bool `json:"admin"` + User string `json:"username"` + Password string `json:"password"` + Nonce string `json:"nonce"` + MacBytes []byte + MacStr string `json:"mac"` + Admin bool `json:"admin"` + DisplayName string `json:"displayname,omitempty"` } func NewSharedSecretRegistrationRequest(reader io.ReadCloser) (*SharedSecretRegistrationRequest, error) { diff --git a/clientapi/routing/register_secret_test.go b/clientapi/routing/register_secret_test.go index a2ed35853..ca265d237 100644 --- a/clientapi/routing/register_secret_test.go +++ b/clientapi/routing/register_secret_test.go @@ -10,7 +10,7 @@ import ( func TestSharedSecretRegister(t *testing.T) { // these values have come from a local synapse instance to ensure compatibility - jsonStr := []byte(`{"admin":false,"mac":"f1ba8d37123866fd659b40de4bad9b0f8965c565","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice"}`) + jsonStr := []byte(`{"admin":false,"mac":"f1ba8d37123866fd659b40de4bad9b0f8965c565","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) sharedSecret := "dendritetest" req, err := NewSharedSecretRegistrationRequest(io.NopCloser(bytes.NewBuffer(jsonStr))) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index b8fd19e90..bccc1b79b 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -18,6 +18,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "reflect" @@ -35,7 +36,10 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" + "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/assert" ) var ( @@ -570,3 +574,92 @@ func Test_register(t *testing.T) { } }) } + +func TestRegisterUserWithDisplayName(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.Global.ServerName = "server" + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + deviceName, deviceID := "deviceName", "deviceID" + expectedDisplayName := "DisplayName" + response := completeRegistration( + base.Context(), + userAPI, + "user", + "server", + expectedDisplayName, + "password", + "", + "localhost", + "user agent", + "session", + false, + &deviceName, + &deviceID, + api.AccountTypeAdmin, + ) + + assert.Equal(t, http.StatusOK, response.Code) + + req := api.QueryProfileRequest{UserID: "@user:server"} + var res api.QueryProfileResponse + err := userAPI.QueryProfile(base.Context(), &req, &res) + assert.NoError(t, err) + assert.Equal(t, expectedDisplayName, res.DisplayName) + }) +} + +func TestRegisterAdminUsingSharedSecret(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.Global.ServerName = "server" + sharedSecret := "dendritetest" + base.Cfg.ClientAPI.RegistrationSharedSecret = sharedSecret + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + + expectedDisplayName := "rabbit" + jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) + req, err := NewSharedSecretRegistrationRequest(io.NopCloser(bytes.NewBuffer(jsonStr))) + assert.NoError(t, err) + if err != nil { + t.Fatalf("failed to read request: %s", err) + } + + r := NewSharedSecretRegistration(sharedSecret) + + // force the nonce to be known + r.nonces.Set(req.Nonce, true, cache.DefaultExpiration) + + _, err = r.IsValidMacLogin(req.Nonce, req.User, req.Password, req.Admin, req.MacBytes) + assert.NoError(t, err) + + body := &bytes.Buffer{} + err = json.NewEncoder(body).Encode(req) + assert.NoError(t, err) + ssrr := httptest.NewRequest(http.MethodPost, "/", body) + + response := handleSharedSecretRegistration( + &base.Cfg.ClientAPI, + userAPI, + r, + ssrr, + ) + assert.Equal(t, http.StatusOK, response.Code) + + profilReq := api.QueryProfileRequest{UserID: "@alice:server"} + var profileRes api.QueryProfileResponse + err = userAPI.QueryProfile(base.Context(), &profilReq, &profileRes) + assert.NoError(t, err) + assert.Equal(t, expectedDisplayName, profileRes.DisplayName) + }) +} From 97ebd72b5a731decdf8f67742179e1adc0f9f30d Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 10 Jan 2023 16:26:41 -0700 Subject: [PATCH 080/124] Add FAQs based on commonly asked questions from the community --- docs/FAQ.md | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/docs/FAQ.md b/docs/FAQ.md index 816130515..4047bfffc 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -6,6 +6,12 @@ permalink: /faq # FAQ +## Why does Dendrite exist? + +Dendrite aims to provide a matrix compatible server that has low resource usage compared to [Synapse](https://github.com/matrix-org/synapse). +It also aims to provide more flexibility when scaling either up or down. +Dendrite's code is also very easy to hack on which makes it suitable for experimenting with new matrix features such as peer-to-peer. + ## Is Dendrite stable? Mostly, although there are still bugs and missing features. If you are a confident power user and you are happy to spend some time debugging things when they go wrong, then please try out Dendrite. If you are a community, organisation or business that demands stability and uptime, then Dendrite is not for you yet - please install Synapse instead. @@ -34,6 +40,10 @@ No, Dendrite has a very different database schema to Synapse and the two are not Monolith deployments are always preferred where possible, and at this time, are far better tested than polylith deployments are. The only reason to consider a polylith deployment is if you wish to run different Dendrite components on separate physical machines, but this is an advanced configuration which we don't recommend. +## Can I configure which port Dendrite listens on? + +Yes, use the cli flag `-http-bind-address`. + ## I've installed Dendrite but federation isn't working Check the [Federation Tester](https://federationtester.matrix.org). You need at least: @@ -42,6 +52,10 @@ Check the [Federation Tester](https://federationtester.matrix.org). You need at * A valid TLS certificate for that DNS name * Either DNS SRV records or well-known files +## Whenever I try to connect from Element it says unable to connect to homeserver + +Check that your dendrite instance is running. Otherwise this is most likely due to a reverse proxy misconfiguration. + ## Does Dendrite work with my favourite client? It should do, although we are aware of some minor issues: @@ -49,6 +63,10 @@ It should do, although we are aware of some minor issues: * **Element Android**: registration does not work, but logging in with an existing account does * **Hydrogen**: occasionally sync can fail due to gaps in the `since` parameter, but clearing the cache fixes this +## Is there a public instance of Dendrite I can try out? + +Use [dendrite.matrix.org](https://dendrite.matrix.org) which we officially support. + ## Does Dendrite support Space Summaries? Yes, [Space Summaries](https://github.com/matrix-org/matrix-spec-proposals/pull/2946) were merged into the Matrix Spec as of 2022-01-17 however, they are still treated as an MSC (Matrix Specification Change) in Dendrite. In order to enable Space Summaries in Dendrite, you must add the MSC to the MSC configuration section in the configuration YAML. If the MSC is not enabled, a user will typically see a perpetual loading icon on the summary page. See below for a demonstration of how to add to the Dendrite configuration: @@ -84,10 +102,42 @@ Remember to add the config file(s) to the `app_service_api` section of the confi Yes, you can do this by disabling federation - set `disable_federation` to `true` in the `global` section of the Dendrite configuration file. +## How can I migrate a room in order to change the internal ID? + +This can be done by performing a room upgrade. Use the command `/upgraderoom ` in Element to do this. + +## How do I reset somebody's password on my server? + +Use the admin endpoint [resetpassword](https://matrix-org.github.io/dendrite/administration/adminapi#post-_dendriteadminresetpassworduserid) + ## Should I use PostgreSQL or SQLite for my databases? Please use PostgreSQL wherever possible, especially if you are planning to run a homeserver that caters to more than a couple of users. +## What data needs to be kept if transferring/backing up Dendrite? + +The list of files that need to be stored is: +- matrix-key.pem +- dendrite.yaml +- the postgres or sqlite DB +- the media store +- the search index (although this can be regenerated) + +Note that this list may change / be out of date. We don't officially maintain instructions for migrations like this. + +## How can I prepare enough storage for media caches? + +This might be what you want: [matrix-media-repo](https://github.com/turt2live/matrix-media-repo) +We don't officially support this or any other dedicated media storage solutions. + +## Is there an upgrade guide for Dendrite? + +Run a newer docker image. We don't officially support deployments other than Docker. +Most of the time you should be able to just +- stop +- replace binary +- start + ## Dendrite is using a lot of CPU Generally speaking, you should expect to see some CPU spikes, particularly if you are joining or participating in large rooms. However, constant/sustained high CPU usage is not expected - if you are experiencing that, please join `#dendrite-dev:matrix.org` and let us know what you were doing when the @@ -102,6 +152,10 @@ not expected. Join `#dendrite-dev:matrix.org` and let us know what you were doin ballooned, or file a GitHub issue if you can. If you can take a [memory profile](development/PROFILING.md) then that would be a huge help too, as that will help us to understand where the memory usage is happening. +## Do I need to generate the self-signed certificate if I'm going to use a reverse proxy? + +No, if you already have a proper certificate from some provider, like Let's Encrypt, and use that on your reverse proxy, and the reverse proxy does TLS termination, then you’re good and can use HTTP to the dendrite process. + ## Dendrite is running out of PostgreSQL database connections You may need to revisit the connection limit of your PostgreSQL server and/or make changes to the `max_connections` lines in your Dendrite configuration. Be aware that each Dendrite component opens its own database connections and has its own connection limit, even in monolith mode! From 11a07d855dd7f08fcd386cb778cbdd353ddd5aa4 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 11 Jan 2023 09:52:58 -0700 Subject: [PATCH 081/124] Initial attempt at adding cypress tests to ci --- .github/workflows/schedules.yaml | 35 ++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index d2a1f6e1f..63a60a241 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -73,3 +73,38 @@ jobs: path: | /logs/results.tap /logs/**/*.log* + + element_web: + runs-on: ubuntu-latest + steps: + - uses: tecolicom/actions-use-apt-tools@v1 + with: + # Our test suite includes some screenshot tests with unusual diacritics, which are + # supposed to be covered by STIXGeneral. + tools: fonts-stix + - uses: actions/checkout@v2 + with: + repository: matrix-org/matrix-react-sdk + - uses: actions/setup-node@v3 + with: + cache: 'yarn' + - name: Fetch layered build + run: scripts/ci/layered.sh + - name: Copy config + run: cp element.io/develop/config.json config.json + working-directory: ./element-web + - name: Build + env: + CI_PACKAGE: true + run: yarn build + working-directory: ./element-web + - name: "Run cypress tests" + uses: cypress-io/github-action@v4.1.1 + with: + browser: chrome + start: npx serve -p 8080 ./element-web/webapp + wait-on: 'http://localhost:8080' + env: + PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true + TMPDIR: ${{ runner.temp }} + HOMESERVER: 'dendrite' From 8fef692741f2e228c04b5f986ab65036e89947d2 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 11 Jan 2023 10:10:24 -0700 Subject: [PATCH 082/124] Edit cypress config before running tests --- .github/workflows/schedules.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index 63a60a241..ba05f2083 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -98,6 +98,9 @@ jobs: CI_PACKAGE: true run: yarn build working-directory: ./element-web + - name: Edit Test Config + run: | + sed -i '/HOMESERVER/c\ HOMESERVER: "dendrite",' cypress.config.ts - name: "Run cypress tests" uses: cypress-io/github-action@v4.1.1 with: @@ -107,4 +110,3 @@ jobs: env: PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true TMPDIR: ${{ runner.temp }} - HOMESERVER: 'dendrite' From b297ea7379d6d5b953a810fe2475b549a917cc9a Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 11 Jan 2023 10:40:38 -0700 Subject: [PATCH 083/124] Add cypress cloud recording --- .github/workflows/schedules.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index ba05f2083..45098925f 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -75,6 +75,7 @@ jobs: /logs/**/*.log* element_web: + timeout-minutes: 120 runs-on: ubuntu-latest steps: - uses: tecolicom/actions-use-apt-tools@v1 @@ -107,6 +108,13 @@ jobs: browser: chrome start: npx serve -p 8080 ./element-web/webapp wait-on: 'http://localhost:8080' + record: + true env: + # pass the Dashboard record key as an environment variable + CYPRESS_RECORD_KEY: ${{ secrets.CYPRESS_RECORD_KEY }} + # pass GitHub token to allow accurately detecting a build vs a re-run build + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true TMPDIR: ${{ runner.temp }} From 6ae1dd565c739efb1f847558e4170cdc0cb4085a Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 11 Jan 2023 10:46:52 -0700 Subject: [PATCH 084/124] Revert "Add cypress cloud recording" This reverts commit b297ea7379d6d5b953a810fe2475b549a917cc9a. --- .github/workflows/schedules.yaml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index 45098925f..ba05f2083 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -75,7 +75,6 @@ jobs: /logs/**/*.log* element_web: - timeout-minutes: 120 runs-on: ubuntu-latest steps: - uses: tecolicom/actions-use-apt-tools@v1 @@ -108,13 +107,6 @@ jobs: browser: chrome start: npx serve -p 8080 ./element-web/webapp wait-on: 'http://localhost:8080' - record: - true env: - # pass the Dashboard record key as an environment variable - CYPRESS_RECORD_KEY: ${{ secrets.CYPRESS_RECORD_KEY }} - # pass GitHub token to allow accurately detecting a build vs a re-run build - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true TMPDIR: ${{ runner.temp }} From 25dfbc6ec3991ba04f317cbae4a4dd51bab6013e Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 11 Jan 2023 10:47:37 -0700 Subject: [PATCH 085/124] Extend cypress test timeout in ci --- .github/workflows/schedules.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index ba05f2083..5636c4cf9 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -75,6 +75,7 @@ jobs: /logs/**/*.log* element_web: + timeout-minutes: 120 runs-on: ubuntu-latest steps: - uses: tecolicom/actions-use-apt-tools@v1 From 0491a8e3436bc17535a4c57d26376af83685a97c Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 12 Jan 2023 10:06:03 +0100 Subject: [PATCH 086/124] Fix room summary returning wrong heroes (#2930) This should fix #2910. Probably makes Sytest/Complement a bit upset, since this not using `sort.Strings` anymore. --- syncapi/storage/interface.go | 2 +- .../postgres/current_room_state_table.go | 92 +++++---- syncapi/storage/postgres/memberships_table.go | 27 --- syncapi/storage/shared/storage_sync.go | 58 +++++- .../sqlite3/current_room_state_table.go | 97 +++++++--- syncapi/storage/sqlite3/memberships_table.go | 39 ---- syncapi/storage/storage_test.go | 179 ++++++++++++++++++ syncapi/storage/tables/interface.go | 4 +- syncapi/storage/tables/memberships_test.go | 36 ---- syncapi/streams/stream_pdu.go | 54 +----- 10 files changed, 378 insertions(+), 210 deletions(-) diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 75afbce15..4e22f8a6f 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -45,7 +45,7 @@ type DatabaseTransaction interface { GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) - GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) + GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 48ed20021..3caafa14b 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" @@ -110,6 +111,15 @@ const selectSharedUsersSQL = "" + " SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + ") AND type = 'm.room.member' AND state_key = ANY($2) AND membership IN ('join', 'invite');" +const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2` + +const selectRoomHeroes = ` +SELECT state_key FROM syncapi_current_room_state +WHERE type = 'm.room.member' AND room_id = $1 AND membership = ANY($2) AND state_key != $3 +ORDER BY added_at, state_key +LIMIT 5 +` + type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -122,6 +132,8 @@ type currentRoomStateStatements struct { selectEventsWithEventIDsStmt *sql.Stmt selectStateEventStmt *sql.Stmt selectSharedUsersStmt *sql.Stmt + selectMembershipCountStmt *sql.Stmt + selectRoomHeroesStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -141,40 +153,21 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro return nil, err } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { - return nil, err - } - if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { - return nil, err - } - if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { - return nil, err - } - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } - if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertRoomStateStmt, upsertRoomStateSQL}, + {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL}, + {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL}, + {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL}, + {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL}, + {&s.selectCurrentStateStmt, selectCurrentStateSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, + {&s.selectJoinedUsersInRoomStmt, selectJoinedUsersInRoomSQL}, + {&s.selectEventsWithEventIDsStmt, selectEventsWithEventIDsSQL}, + {&s.selectStateEventStmt, selectStateEventSQL}, + {&s.selectSharedUsersStmt, selectSharedUsersSQL}, + {&s.selectMembershipCountStmt, selectMembershipCount}, + {&s.selectRoomHeroesStmt, selectRoomHeroes}, + }.Prepare(db) } // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. @@ -447,3 +440,34 @@ func (s *currentRoomStateStatements) SelectSharedUsers( } return result, rows.Err() } + +func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomHeroesStmt) + rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(memberships), excludeUserID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroesStmt: rows.close() failed") + + var stateKey string + result := make([]string, 0, 5) + for rows.Next() { + if err = rows.Scan(&stateKey); err != nil { + return nil, err + } + result = append(result, stateKey) + } + return result, rows.Err() +} + +func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt) + err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return 0, err + } + return count, nil +} diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index b555e8456..ac44b235f 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -19,10 +19,8 @@ import ( "database/sql" "fmt" - "github.com/lib/pq" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" + " SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" + ") t WHERE t.membership = $3" -const selectHeroesSQL = "" + - "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" - const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" @@ -81,7 +76,6 @@ WHERE ($3::text IS NULL OR t.membership = $3) type membershipsStatements struct { upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt - selectHeroesStmt *sql.Stmt selectMembershipForUserStmt *sql.Stmt selectMembersStmt *sql.Stmt } @@ -95,7 +89,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { return s, sqlutil.StatementList{ {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, - {&s.selectHeroesStmt, selectHeroesSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembersStmt, selectMembersSQL}, }.Prepare(db) @@ -129,26 +122,6 @@ func (s *membershipsStatements) SelectMembershipCount( return } -func (s *membershipsStatements) SelectHeroes( - ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string, -) (heroes []string, err error) { - stmt := sqlutil.TxStmt(txn, s.selectHeroesStmt) - var rows *sql.Rows - rows, err = stmt.QueryContext(ctx, roomID, userID, pq.StringArray(memberships)) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed") - var hero string - for rows.Next() { - if err = rows.Scan(&hero); err != nil { - return - } - heroes = append(heroes, hero) - } - return heroes, rows.Err() -} - // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // string as the membership. diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 77afa0290..c6933486c 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/syncapi/types" @@ -92,8 +93,61 @@ func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membe return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos) } -func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { - return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships) +func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID string) (*types.Summary, error) { + summary := &types.Summary{Heroes: []string{}} + + joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Join) + if err != nil { + return summary, err + } + inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Invite) + if err != nil { + return summary, err + } + summary.InvitedMemberCount = &inviteCount + summary.JoinedMemberCount = &joinCount + + // Get the room name and canonical alias, if any + filter := gomatrixserverlib.DefaultStateFilter() + filterTypes := []string{gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias} + filterRooms := []string{roomID} + + filter.Types = &filterTypes + filter.Rooms = &filterRooms + evs, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, &filter, nil) + if err != nil { + return summary, err + } + + for _, ev := range evs { + switch ev.Type() { + case gomatrixserverlib.MRoomName: + if gjson.GetBytes(ev.Content(), "name").Str != "" { + return summary, nil + } + case gomatrixserverlib.MRoomCanonicalAlias: + if gjson.GetBytes(ev.Content(), "alias").Str != "" { + return summary, nil + } + } + } + + // If there's no room name or canonical alias, get the room heroes, excluding the user + heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Join, gomatrixserverlib.Invite}) + if err != nil { + return summary, err + } + + // "When no joined or invited members are available, this should consist of the banned and left users" + if len(heroes) == 0 { + heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Leave, gomatrixserverlib.Ban}) + if err != nil { + return summary, err + } + } + summary.Heroes = heroes + + return summary, nil } func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 7a381f68b..6bc1b267a 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "strings" @@ -95,6 +96,15 @@ const selectSharedUsersSQL = "" + " SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + ") AND type = 'm.room.member' AND state_key IN ($2) AND membership IN ('join', 'invite');" +const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2` + +const selectRoomHeroes = ` +SELECT state_key FROM syncapi_current_room_state +WHERE type = 'm.room.member' AND room_id = $1 AND state_key != $2 AND membership IN ($3) +ORDER BY added_at, state_key +LIMIT 5 +` + type currentRoomStateStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -107,6 +117,8 @@ type currentRoomStateStatements struct { //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic selectStateEventStmt *sql.Stmt //selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic + selectMembershipCountStmt *sql.Stmt + //selectRoomHeroes *sql.Stmt - prepared at runtime due to variadic } func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { @@ -129,31 +141,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t return nil, err } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return nil, err - } - //if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { - // return nil, err - //} - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertRoomStateStmt, upsertRoomStateSQL}, + {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL}, + {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL}, + {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL}, + {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, + {&s.selectStateEventStmt, selectStateEventSQL}, + {&s.selectMembershipCountStmt, selectMembershipCount}, + }.Prepare(db) } // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. @@ -485,3 +482,53 @@ func (s *currentRoomStateStatements) SelectSharedUsers( return result, err } + +func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) { + params := make([]interface{}, len(memberships)+2) + params[0] = roomID + params[1] = excludeUserID + for k, v := range memberships { + params[k+2] = v + } + + query := strings.Replace(selectRoomHeroes, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) + var stmt *sql.Stmt + var err error + if txn != nil { + stmt, err = txn.Prepare(query) + } else { + stmt, err = s.db.Prepare(query) + } + if err != nil { + return []string{}, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRoomHeroes: stmt.close() failed") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroes: rows.close() failed") + + var stateKey string + result := make([]string, 0, 5) + for rows.Next() { + if err = rows.Scan(&stateKey); err != nil { + return nil, err + } + result = append(result, stateKey) + } + return result, rows.Err() +} + +func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt) + err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return 0, err + } + return count, nil +} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 7e54fac17..905a1e1a8 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -18,11 +18,9 @@ import ( "context" "database/sql" "fmt" - "strings" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" + " SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" + ") t WHERE t.membership = $3" -const selectHeroesSQL = "" + - "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5" - const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" @@ -99,7 +94,6 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembersStmt, selectMembersSQL}, - // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic }.Prepare(db) } @@ -131,39 +125,6 @@ func (s *membershipsStatements) SelectMembershipCount( return } -func (s *membershipsStatements) SelectHeroes( - ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string, -) (heroes []string, err error) { - stmtSQL := strings.Replace(selectHeroesSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) - stmt, err := s.db.PrepareContext(ctx, stmtSQL) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, stmt, "SelectHeroes: stmt.close() failed") - params := []interface{}{ - roomID, userID, - } - for _, membership := range memberships { - params = append(params, membership) - } - - stmt = sqlutil.TxStmt(txn, stmt) - var rows *sql.Rows - rows, err = stmt.QueryContext(ctx, params...) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed") - var hero string - for rows.Next() { - if err = rows.Scan(&hero); err != nil { - return - } - heroes = append(heroes, hero) - } - return heroes, rows.Err() -} - // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // string as the membership. diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 5ff185a32..166ddd233 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" ) var ctx = context.Background() @@ -664,3 +665,181 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ return &tok } */ + +func pointer[t any](s t) *t { + return &s +} + +func TestRoomSummary(t *testing.T) { + + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t) + + // Create some dummy users + moreUsers := []*test.User{} + moreUserIDs := []string{} + for i := 0; i < 10; i++ { + u := test.NewUser(t) + moreUsers = append(moreUsers, u) + moreUserIDs = append(moreUserIDs, u.ID) + } + + testCases := []struct { + name string + wantSummary *types.Summary + additionalEvents func(t *testing.T, room *test.Room) + }{ + { + name: "after initial creation", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{}}, + }, + { + name: "invited user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "invited user, but declined", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "joined user after invitation", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(3), InvitedMemberCount: pointer(0), Heroes: []string{charlie.ID, bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined/invited user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID, bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined/invited/left user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "leaving user after joining", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "many users", // heroes ordered by stream id + wantSummary: &types.Summary{JoinedMemberCount: pointer(len(moreUserIDs) + 1), InvitedMemberCount: pointer(0), Heroes: moreUserIDs[:5]}, + additionalEvents: func(t *testing.T, room *test.Room) { + for _, x := range moreUsers { + room.CreateAndInsert(t, x, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(x.ID)) + } + }, + }, + { + name: "canonical alias set", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]interface{}{ + "alias": "myalias", + }, test.WithStateKey("")) + }, + }, + { + name: "room name set", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomName, map[string]interface{}{ + "name": "my room name", + }, test.WithStateKey("")) + }, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close, closeBase := MustCreateDatabase(t, dbType) + defer close() + defer closeBase() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + r := test.NewRoom(t, alice) + + if tc.additionalEvents != nil { + tc.additionalEvents(t, r) + } + + // write the room before creating a transaction + MustWriteEvents(t, db, r.Events()) + + transaction, err := db.NewDatabaseTransaction(ctx) + assert.NoError(t, err) + defer transaction.Rollback() + + summary, err := transaction.GetRoomSummary(ctx, r.ID, alice.ID) + assert.NoError(t, err) + assert.Equal(t, tc.wantSummary, summary) + }) + } + }) +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index a0574b257..c02e4ecc5 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -115,6 +115,9 @@ type CurrentRoomState interface { SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error) // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) + + SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) + SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (int, error) } // BackwardsExtremities keeps track of backwards extremities for a room. @@ -185,7 +188,6 @@ type Receipts interface { type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) - SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) SelectMemberships( ctx context.Context, txn *sql.Tx, diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go index 0cee7f5a5..df593ae78 100644 --- a/syncapi/storage/tables/memberships_test.go +++ b/syncapi/storage/tables/memberships_test.go @@ -3,8 +3,6 @@ package tables_test import ( "context" "database/sql" - "reflect" - "sort" "testing" "time" @@ -88,43 +86,9 @@ func TestMembershipsTable(t *testing.T) { testUpsert(t, ctx, table, userEvents[0], alice, room) testMembershipCount(t, ctx, table, room) - testHeroes(t, ctx, table, alice, room, users) }) } -func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) { - - // Re-slice and sort the expected users - users = users[1:] - sort.Strings(users) - type testCase struct { - name string - memberships []string - wantHeroes []string - } - - testCases := []testCase{ - {name: "no memberships queried", memberships: []string{}}, - {name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships) - if err != nil { - t.Fatalf("unable to select heroes: %s", err) - } - if gotLen := len(got); gotLen != len(tc.wantHeroes) { - t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen) - } - - if !reflect.DeepEqual(got, tc.wantHeroes) { - t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got) - } - }) - } -} - func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) { t.Run("membership counts are correct", func(t *testing.T) { // After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index dd7845574..4664276cf 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "fmt" - "sort" "time" "github.com/matrix-org/dendrite/internal/caching" @@ -14,11 +13,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - - "github.com/matrix-org/dendrite/syncapi/notifier" ) // The max number of per-room goroutines to have running. @@ -339,7 +336,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case gomatrixserverlib.Join: jr := types.NewJoinResponse() if hasMembershipChange { - p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition) + jr.Summary, err = snapshot.GetRoomSummary(ctx, delta.RoomID, device.UserID) + if err != nil { + logrus.WithError(err).Warn("failed to get room summary") + } } jr.Timeline.PrevBatch = &prevBatch jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) @@ -411,45 +411,6 @@ func applyHistoryVisibilityFilter( return events, nil } -func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { - // Work out how many members are in the room. - joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) - invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) - - jr.Summary.JoinedMemberCount = &joinedCount - jr.Summary.InvitedMemberCount = &invitedCount - - fetchStates := []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomName}, - {EventType: gomatrixserverlib.MRoomCanonicalAlias}, - } - // Check if the room has a name or a canonical alias - latestState := &roomserverAPI.QueryLatestEventsAndStateResponse{} - err := p.rsAPI.QueryLatestEventsAndState(ctx, &roomserverAPI.QueryLatestEventsAndStateRequest{StateToFetch: fetchStates, RoomID: roomID}, latestState) - if err != nil { - return - } - // Check if the room has a name or canonical alias, if so, return. - for _, ev := range latestState.StateEvents { - switch ev.Type() { - case gomatrixserverlib.MRoomName: - if gjson.GetBytes(ev.Content(), "name").Str != "" { - return - } - case gomatrixserverlib.MRoomCanonicalAlias: - if gjson.GetBytes(ev.Content(), "alias").Str != "" { - return - } - } - } - heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"}) - if err != nil { - return - } - sort.Strings(heroes) - jr.Summary.Heroes = heroes -} - func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, snapshot storage.DatabaseTransaction, @@ -493,7 +454,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( return } - p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From) + jr.Summary, err = snapshot.GetRoomSummary(ctx, roomID, device.UserID) + if err != nil { + logrus.WithError(err).Warn("failed to get room summary") + } // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: From 477a44faa67eabba0f5d7f632b12fd6bb2d7ec5b Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 12 Jan 2023 09:22:53 -0700 Subject: [PATCH 087/124] Always initialize statistics server map --- federationapi/statistics/statistics.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 2ba99112c..0a44375c6 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -35,18 +35,13 @@ func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistic DB: db, FailuresUntilBlacklist: failuresUntilBlacklist, backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), } } // ForServer returns server statistics for the given server name. If it // does not exist, it will create empty statistics and return those. func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { - // If the map hasn't been initialised yet then do that. - if s.servers == nil { - s.mutex.Lock() - s.servers = make(map[gomatrixserverlib.ServerName]*ServerStatistics) - s.mutex.Unlock() - } // Look up if we have statistics for this server already. s.mutex.RLock() server, found := s.servers[serverName] From eeeb3017d662ad6777c1398b325aa98bc36bae94 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 16 Jan 2023 11:52:30 +0000 Subject: [PATCH 088/124] Switch the default config option values for `recaptcha_sitekey_class` and `recaptcha_form_field` (#2939) Attempting to use the [web auth fallback mechanism](https://spec.matrix.org/v1.5/client-server-api/#fallback) for Google ReCAPTCHA with the default setting for `client_api.recaptcha_sitekey_class` of "g-recaptcha-response" results in no captcha being rendered: ![image](https://user-images.githubusercontent.com/1342360/212482321-14980045-6e20-4d59-adaa-59a01ad88367.png) I cross-checked the captcha code between [dendrite.matrix.org's fallback page](https://dendrite.matrix.org/_matrix/client/r0/auth/m.login.recaptcha/fallback/web?session=asdhjaksd) and [matrix-client.matrix.org's one](https://matrix-client.matrix.org/_matrix/client/r0/auth/m.login.recaptcha/fallback/web?session=asdhjaksd) (which both use the same captcha public key) and noticed a discrepancy in the `class` attribute of the div that renders the captcha. [ReCAPTCHA's docs state](https://developers.google.com/recaptcha/docs/v3#automatically_bind_the_challenge_to_a_button) to use "g-recaptcha" as the class for the submit button. I noticed this when user `@parappanon:parappa.party` reported that they were also seeing no captcha being rendered on their Dendrite instance. Changing `client_api.recaptcha_sitekey_class` to "g-recaptcha" caused their captcha to render properly as well. There may have been a change in the class name from ReCAPTCHA v2 to v3? The [docs for v2](https://developers.google.com/recaptcha/docs/display#auto_render) also request one uses "g-recaptcha" though. Thus I propose changing the default setting to unbreak people's recaptcha auth fallback pages. Should fix dendrite.matrix.org as well. --- setup/config/config_clientapi.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 11628b1b0..1deba6bb5 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -85,10 +85,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js" } if c.RecaptchaFormField == "" { - c.RecaptchaFormField = "g-recaptcha" + c.RecaptchaFormField = "g-recaptcha-response" } if c.RecaptchaSitekeyClass == "" { - c.RecaptchaSitekeyClass = "g-recaptcha-response" + c.RecaptchaSitekeyClass = "g-recaptcha" } checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) From 8582c7520abbfca680da9ba16e40a9a92b9fd21c Mon Sep 17 00:00:00 2001 From: Umar Getagazov Date: Tue, 17 Jan 2023 11:07:42 +0300 Subject: [PATCH 089/124] Omit state field from `/messages` response if empty (#2940) The field type is `[ClientEvent]` in the [spec](https://spec.matrix.org/v1.5/client-server-api/#get_matrixclientv3roomsroomidmessages), but right now `null` can also be returned. Omit the field completely if it's empty. Some clients (rightfully) assume it's either not present at all or it's of the right type (see https://github.com/matrix-org/matrix-react-sdk/pull/9913). ### Pull Request Checklist * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * The PR is a simple struct tag fix * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Umar Getagazov ` Signed-off-by: Umar Getagazov --- syncapi/routing/messages.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 0d740ebfc..cafba17c9 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -57,7 +57,7 @@ type messagesResp struct { StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token End string `json:"end,omitempty"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` - State []gomatrixserverlib.ClientEvent `json:"state"` + State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` } // OnIncomingMessagesRequest implements the /messages endpoint from the From 0d0280cf5ff71ec975b17d0f6dadcae7e46574b5 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 17 Jan 2023 10:08:23 +0100 Subject: [PATCH 090/124] `/sync` performance optimizations (#2927) Since #2849 there is no limit for the current state we fetch to calculate history visibility. In large rooms this can cause us to fetch thousands of membership events we don't really care about. This now only gets the state event types and senders in our timeline, which should significantly reduce the amount of events we fetch from the database. Also removes `MaxTopologicalPosition`, as it is an unnecessary DB call, given we use the result in `topological_position < $1` calls. --- federationapi/consumers/roomserver.go | 2 +- syncapi/routing/memberships.go | 19 +- syncapi/routing/messages.go | 31 +-- syncapi/storage/interface.go | 2 - .../output_room_events_topology_table.go | 19 -- syncapi/storage/shared/storage_sync.go | 17 +- .../output_room_events_topology_table.go | 16 -- syncapi/storage/storage_test.go | 88 ++++++- syncapi/storage/tables/interface.go | 2 - syncapi/streams/stream_pdu.go | 35 ++- syncapi/syncapi_test.go | 246 ++++++++++++++++++ 11 files changed, 372 insertions(+), 105 deletions(-) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 0c1080afa..52b5744a6 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -195,7 +195,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } // If we added new hosts, inform them about our known presence events for this room - if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { + if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { membership, _ := ore.Event.Membership() if membership == gomatrixserverlib.Join { s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 3fcc3235c..9ffdf513f 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -16,16 +16,16 @@ package routing import ( "encoding/json" + "math" "net/http" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) type getMembershipResponse struct { @@ -87,19 +87,18 @@ func GetMemberships( if err != nil { return jsonerror.InternalServerError() } + defer db.Rollback() // nolint: errcheck atToken, err := types.NewTopologyTokenFromString(at) if err != nil { + atToken = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} if queryRes.HasBeenInRoom && !queryRes.IsInRoom { // If you have left the room then this will be the members of the room when you left. atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID) - } else { - // If you are joined to the room then this will be the current members of the room. - atToken, err = db.MaxTopologicalPosition(req.Context(), roomID) - } - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") - return jsonerror.InternalServerError() + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") + return jsonerror.InternalServerError() + } } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index cafba17c9..4a01ec357 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -17,6 +17,7 @@ package routing import ( "context" "fmt" + "math" "net/http" "sort" "time" @@ -177,10 +178,11 @@ func OnIncomingMessagesRequest( // If "to" isn't provided, it defaults to either the earliest stream // position (if we're going backward) or to the latest one (if we're // going forward). - to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed") - return jsonerror.InternalServerError() + to = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} + if backwardOrdering { + // go 1 earlier than the first event so we correctly fetch the earliest event + // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. + to = types.TopologyToken{} } wasToProvided = false } @@ -577,24 +579,3 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] return events, nil } - -// setToDefault returns the default value for the "to" query parameter of a -// request to /messages if not provided. It defaults to either the earliest -// topological position (if we're going backward) or to the latest one (if we're -// going forward). -// Returns an error if there was an issue with retrieving the latest position -// from the database -func setToDefault( - ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool, - roomID string, -) (to types.TopologyToken, err error) { - if backwardOrdering { - // go 1 earlier than the first event so we correctly fetch the earliest event - // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. - to = types.TopologyToken{} - } else { - to, err = snapshot.MaxTopologicalPosition(ctx, roomID) - } - - return -} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 4e22f8a6f..a4ba82327 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -84,8 +84,6 @@ type DatabaseTransaction interface { EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) - // MaxTopologicalPosition returns the highest topological position for a given room. - MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 6fab900eb..d0e99f267 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -65,14 +65,6 @@ const selectPositionInTopologySQL = "" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" - // Select the max topological position for the room, then sort by stream position and take the highest, - // returning both topological and stream positions. -const selectMaxPositionInTopologySQL = "" + - "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + - " WHERE topological_position=(" + - "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + - ") ORDER BY stream_position DESC LIMIT 1" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" @@ -84,7 +76,6 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt } @@ -107,9 +98,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { return nil, err } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { return nil, err } @@ -189,10 +177,3 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } - -func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( - ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, spos types.StreamPosition, err error) { - err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos) - return -} diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index c6933486c..7b07cac5e 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "math" "github.com/matrix-org/gomatrixserverlib" "github.com/tidwall/gjson" @@ -269,16 +270,6 @@ func (d *DatabaseTransaction) BackwardExtremitiesForRoom( return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID) } -func (d *DatabaseTransaction) MaxTopologicalPosition( - ctx context.Context, roomID string, -) (types.TopologyToken, error) { - depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) - if err != nil { - return types.TopologyToken{}, err - } - return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil -} - func (d *DatabaseTransaction) EventPositionInTopology( ctx context.Context, eventID string, ) (types.TopologyToken, error) { @@ -297,11 +288,7 @@ func (d *DatabaseTransaction) StreamToTopologicalPosition( case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward return types.TopologyToken{PDUPosition: streamPos}, nil case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward - topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) - if err != nil { - return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) - } - return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil + return types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, nil case err != nil: // some other error happened return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err) default: diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 81b264988..879456441 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -61,10 +61,6 @@ const selectPositionInTopologySQL = "" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" -const selectMaxPositionInTopologySQL = "" + - "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 ORDER BY stream_position DESC" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" @@ -77,7 +73,6 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt } @@ -102,9 +97,6 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { return nil, err } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { return nil, err } @@ -182,11 +174,3 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } - -func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( - ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, spos types.StreamPosition, err error) { - stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) - err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) - return -} diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 166ddd233..e65367d8b 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "math" "reflect" "testing" @@ -199,10 +200,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { _ = MustWriteEvents(t, db, events) WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { - from, err := snapshot.MaxTopologicalPosition(ctx, r.ID) - if err != nil { - t.Fatalf("failed to get MaxTopologicalPosition: %s", err) - } + from := types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} t.Logf("max topo pos = %+v", from) // head towards the beginning of time to := types.TopologyToken{} @@ -219,6 +217,88 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { }) } +func TestStreamToTopologicalPosition(t *testing.T) { + alice := test.NewUser(t) + r := test.NewRoom(t, alice) + + testCases := []struct { + name string + roomID string + streamPos types.StreamPosition + backwardOrdering bool + wantToken types.TopologyToken + }{ + { + name: "forward ordering found streamPos returns found position", + roomID: r.ID, + streamPos: 1, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1}, + }, + { + name: "forward ordering not found streamPos returns max position", + roomID: r.ID, + streamPos: 100, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, + }, + { + name: "backward ordering found streamPos returns found position", + roomID: r.ID, + streamPos: 1, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1}, + }, + { + name: "backward ordering not found streamPos returns maxDepth with param pduPosition", + roomID: r.ID, + streamPos: 100, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 5, PDUPosition: 100}, + }, + { + name: "backward non-existent room returns zero token", + roomID: "!doesnotexist:localhost", + streamPos: 1, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 0, PDUPosition: 1}, + }, + { + name: "forward non-existent room returns max token", + roomID: "!doesnotexist:localhost", + streamPos: 1, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close, closeBase := MustCreateDatabase(t, dbType) + defer close() + defer closeBase() + + txn, err := db.NewDatabaseTransaction(ctx) + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + MustWriteEvents(t, db, r.Events()) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + token, err := txn.StreamToTopologicalPosition(ctx, tc.roomID, tc.streamPos, tc.backwardOrdering) + if err != nil { + t.Fatal(err) + } + if tc.wantToken != token { + t.Fatalf("expected token %q, got %q", tc.wantToken, token) + } + }) + } + + }) +} + /* // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index c02e4ecc5..8366a67dc 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -91,8 +91,6 @@ type Topology interface { SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error) // SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to. SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) - // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. - SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 4664276cf..44013e37c 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -384,19 +384,32 @@ func applyHistoryVisibilityFilter( roomID, userID string, recentEvents []*gomatrixserverlib.HeaderedEvent, ) ([]*gomatrixserverlib.HeaderedEvent, error) { - // We need to make sure we always include the latest states events, if they are in the timeline. - // We grep at least limit * 2 events, to ensure we really get the needed events. - filter := gomatrixserverlib.DefaultStateFilter() - stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) - if err != nil { - // Not a fatal error, we can continue without the stateEvents, - // they are only needed if there are state events in the timeline. - logrus.WithError(err).Warnf("Failed to get current room state for history visibility") + // We need to make sure we always include the latest state events, if they are in the timeline. + alwaysIncludeIDs := make(map[string]struct{}) + var stateTypes []string + var senders []string + for _, ev := range recentEvents { + if ev.StateKey() != nil { + stateTypes = append(stateTypes, ev.Type()) + senders = append(senders, ev.Sender()) + } } - alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents)) - for _, ev := range stateEvents { - alwaysIncludeIDs[ev.EventID()] = struct{}{} + + // Only get the state again if there are state events in the timeline + if len(stateTypes) > 0 { + filter := gomatrixserverlib.DefaultStateFilter() + filter.Types = &stateTypes + filter.Senders = &senders + stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) + if err != nil { + return nil, fmt.Errorf("failed to get current room state for history visibility calculation: %w", err) + } + + for _, ev := range stateEvents { + alwaysIncludeIDs[ev.EventID()] = struct{}{} + } } + startTime := time.Now() events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") if err != nil { diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 483274481..666a872f8 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -521,6 +521,252 @@ func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatr } } +func TestGetMembership(t *testing.T) { + alice := test.NewUser(t) + + aliceDev := userapi.Device{ + ID: "ALICEID", + UserID: alice.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + bob := test.NewUser(t) + bobDev := userapi.Device{ + ID: "BOBID", + UserID: bob.ID, + AccessToken: "notjoinedtoanyrooms", + } + + testCases := []struct { + name string + roomID string + additionalEvents func(t *testing.T, room *test.Room) + request func(t *testing.T, room *test.Room) *http.Request + wantOK bool + wantMemberCount int + useSleep bool // :/ + }{ + { + name: "/members - Alice joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "/members - Bob never joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Bob never joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Alice joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: true, + }, + { + name: "Alice leaves before Bob joins, should not be able to see Bob", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "Alice leaves after Bob joins, should be able to see Bob", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 2, + }, + { + name: "/joined_members - Alice leaves, shouldn't be able to see members ", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + useSleep: true, + wantOK: false, + }, + { + name: "'at' specified, returns memberships before Bob joins", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "at": "t2_5", + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "'membership=leave' specified, returns no memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "membership": "leave", + })) + }, + wantOK: true, + wantMemberCount: 0, + }, + { + name: "'not_membership=join' specified, returns no memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "not_membership": "join", + })) + }, + wantOK: true, + wantMemberCount: 0, + }, + { + name: "'not_membership=leave' & 'membership=join' specified, returns correct memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "not_membership": "leave", + "membership": "join", + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "non-existent room ID", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", "!notavalidroom:test"), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: false, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + // Use an actual roomserver for this + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{}) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + room := test.NewRoom(t, alice) + t.Cleanup(func() { + t.Logf("running cleanup for %s", tc.name) + }) + // inject additional events + if tc.additionalEvents != nil { + tc.additionalEvents(t, room) + } + if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // wait for the events to come down sync + if tc.useSleep { + time.Sleep(time.Millisecond * 100) + } else { + syncUntil(t, base, aliceDev.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) + return gjson.Get(syncBody, path).Exists() + }) + } + + w := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, tc.request(t, room)) + if w.Code != 200 && tc.wantOK { + t.Logf("%s", w.Body.String()) + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + t.Logf("[%s] Resp: %s", tc.name, w.Body.String()) + + // check we got the expected events + if tc.wantOK { + memberCount := len(gjson.GetBytes(w.Body.Bytes(), "chunk").Array()) + if memberCount != tc.wantMemberCount { + t.Fatalf("expected %d members, got %d", tc.wantMemberCount, memberCount) + } + } + }) + } + }) +} + func TestSendToDevice(t *testing.T) { test.WithAllDatabases(t, testSendToDevice) } From b55a7c238fb4b4db9ff4da0a25f0f83316d20f5e Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 17 Jan 2023 19:04:02 +0100 Subject: [PATCH 091/124] Version 0.10.9 (#2942) --- CHANGES.md | 20 ++++++++++++++++++++ internal/version.go | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index f5a82cfe2..fa8230659 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,25 @@ # Changelog +## Dendrite 0.10.9 (2023-01-17) + +### Features + +* Stale device lists are now cleaned up on startup, removing entries for users the server doesn't share a room with anymore +* Dendrite now has its own Helm chart +* Guest access is now handled correctly (disallow joins, kick guests on revocation of guest access, as well as over federation) + +### Fixes + +* Push rules have seen several tweaks and fixes, which should, for example, fix notifications for `m.read_receipts` +* Outgoing presence will now correctly be sent to newly joined hosts +* Fixes the `/_dendrite/admin/resetPassword/{userID}` admin endpoint to use the correct variable +* Federated backfilling for medium/large rooms has been fixed +* `/login` causing wrong device list updates has been resolved +* `/sync` should now return the correct room summary heroes +* The default config options for `recaptcha_sitekey_class` and `recaptcha_form_field` are now set correctly +* `/messages` now omits empty `state` to be more spec compliant (contributed by [handlerug](https://github.com/handlerug)) +* `/sync` has been optimised to only query state events for history visibility if they are really needed + ## Dendrite 0.10.8 (2022-11-29) ### Features diff --git a/internal/version.go b/internal/version.go index 685237b9e..ff31dd784 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 10 - VersionPatch = 8 + VersionPatch = 9 VersionTag = "" // example: "rc1" ) From 67f5c5bc1e837bbdee14d7d3388984ed8960528a Mon Sep 17 00:00:00 2001 From: genofire Date: Wed, 18 Jan 2023 08:45:34 +0100 Subject: [PATCH 092/124] =?UTF-8?q?fix(helm):=20extract=20image=20tag=20to?= =?UTF-8?q?=20value=20(and=20use=20as=20default=20from=20Chart.=E2=80=A6?= =?UTF-8?q?=20(#2934)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit improve image tag handling on the default helm way. with usage of appVersion from: https://github.com/matrix-org/dendrite/blob/0995dc48224b90432e38fa92345cf5735bca6090/helm/dendrite/Chart.yaml#L4 maybe you like to review @S7evinK ? ### Pull Request Checklist * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Geno ` --- helm/dendrite/Chart.yaml | 4 ++-- helm/dendrite/templates/_helpers.tpl | 4 +++- helm/dendrite/templates/deployment.yaml | 4 ++-- helm/dendrite/templates/jobs.yaml | 3 ++- helm/dendrite/templates/service.yaml | 2 +- helm/dendrite/values.yaml | 6 ++++-- 6 files changed, 14 insertions(+), 9 deletions(-) diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml index 15d1e6d19..6e6641c8d 100644 --- a/helm/dendrite/Chart.yaml +++ b/helm/dendrite/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v2 name: dendrite -version: "0.10.8" -appVersion: "0.10.8" +version: "0.10.9" +appVersion: "0.10.9" description: Dendrite Matrix Homeserver type: application keywords: diff --git a/helm/dendrite/templates/_helpers.tpl b/helm/dendrite/templates/_helpers.tpl index 291f351bc..026706588 100644 --- a/helm/dendrite/templates/_helpers.tpl +++ b/helm/dendrite/templates/_helpers.tpl @@ -15,9 +15,11 @@ {{- define "image.name" -}} -image: {{ .name }} +{{- with .Values.image -}} +image: {{ .repository }}:{{ .tag | default (printf "v%s" $.Chart.AppVersion) }} imagePullPolicy: {{ .pullPolicy }} {{- end -}} +{{- end -}} {{/* Expand the name of the chart. diff --git a/helm/dendrite/templates/deployment.yaml b/helm/dendrite/templates/deployment.yaml index 629ffe528..b463c7d0b 100644 --- a/helm/dendrite/templates/deployment.yaml +++ b/helm/dendrite/templates/deployment.yaml @@ -45,8 +45,8 @@ spec: persistentVolumeClaim: claimName: {{ default (print ( include "dendrite.fullname" . ) "-search-pvc") $.Values.persistence.search.existingClaim | quote }} containers: - - name: {{ $.Chart.Name }} - {{- include "image.name" $.Values.image | nindent 8 }} + - name: {{ .Chart.Name }} + {{- include "image.name" . | nindent 8 }} args: - '--config' - '/etc/dendrite/dendrite.yaml' diff --git a/helm/dendrite/templates/jobs.yaml b/helm/dendrite/templates/jobs.yaml index 76915694d..c10f358b0 100644 --- a/helm/dendrite/templates/jobs.yaml +++ b/helm/dendrite/templates/jobs.yaml @@ -8,6 +8,7 @@ metadata: name: {{ $name }} labels: app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: Role @@ -80,7 +81,7 @@ spec: name: signing-key readOnly: true - name: generate-key - {{- include "image.name" $.Values.image | nindent 8 }} + {{- include "image.name" . | nindent 8 }} command: - sh - -c diff --git a/helm/dendrite/templates/service.yaml b/helm/dendrite/templates/service.yaml index 365a43f04..3b571df1f 100644 --- a/helm/dendrite/templates/service.yaml +++ b/helm/dendrite/templates/service.yaml @@ -13,5 +13,5 @@ spec: ports: - name: http protocol: TCP - port: 8008 + port: {{ .Values.service.port }} targetPort: 8008 \ No newline at end of file diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml index 2c6e80942..87027a886 100644 --- a/helm/dendrite/values.yaml +++ b/helm/dendrite/values.yaml @@ -1,8 +1,10 @@ image: # -- Docker repository/image to use - name: "ghcr.io/matrix-org/dendrite-monolith:v0.10.8" + repository: "ghcr.io/matrix-org/dendrite-monolith" # -- Kubernetes pullPolicy pullPolicy: IfNotPresent + # Overrides the image tag whose default is the chart appVersion. + tag: "" # signing key to use @@ -345,4 +347,4 @@ ingress: service: type: ClusterIP - port: 80 + port: 8008 From 738686ae686004c5efa9fe2096502cdc426c6dd8 Mon Sep 17 00:00:00 2001 From: Neil Date: Thu, 19 Jan 2023 20:02:32 +0000 Subject: [PATCH 093/124] Add `/_dendrite/admin/purgeRoom/{roomID}` (#2662) This adds a new admin endpoint `/_dendrite/admin/purgeRoom/{roomID}`. It completely erases all database entries for a given room ID. The roomserver will start by clearing all data for that room and then will generate an output event to notify downstream components (i.e. the sync API and federation API) to do the same. It does not currently clear media and it is currently not implemented for SQLite since it relies on SQL array operations right now. Co-authored-by: Neil Alexander Co-authored-by: Till Faelligen <2353100+S7evinK@users.noreply.github.com> --- clientapi/admin_test.go | 103 ++++++++++- clientapi/routing/admin.go | 32 ++++ clientapi/routing/routing.go | 6 + federationapi/consumers/roomserver.go | 15 +- federationapi/storage/interface.go | 2 + federationapi/storage/shared/storage.go | 15 ++ internal/sqlutil/sql.go | 21 +++ internal/sqlutil/sqlutil_test.go | 51 +++++- roomserver/api/api.go | 1 + roomserver/api/api_trace.go | 10 ++ roomserver/api/output.go | 8 + roomserver/api/perform.go | 8 + roomserver/internal/perform/perform_admin.go | 37 ++++ roomserver/inthttp/client.go | 12 ++ roomserver/inthttp/server.go | 5 + roomserver/roomserver_test.go | 165 ++++++++++++++++++ roomserver/storage/interface.go | 1 + .../storage/postgres/purge_statements.go | 133 ++++++++++++++ roomserver/storage/postgres/rooms_table.go | 14 ++ roomserver/storage/postgres/storage.go | 5 + roomserver/storage/shared/storage.go | 16 ++ .../storage/sqlite3/purge_statements.go | 153 ++++++++++++++++ roomserver/storage/sqlite3/rooms_table.go | 14 ++ .../storage/sqlite3/state_block_table.go | 3 +- .../storage/sqlite3/state_snapshot_table.go | 33 +++- roomserver/storage/sqlite3/storage.go | 6 + roomserver/storage/tables/interface.go | 7 + syncapi/consumers/roomserver.go | 21 +++ syncapi/storage/interface.go | 2 + .../postgres/backwards_extremities_table.go | 27 +-- syncapi/storage/postgres/invites_table.go | 31 ++-- syncapi/storage/postgres/memberships_table.go | 22 ++- .../postgres/notification_data_table.go | 12 ++ .../postgres/output_room_events_table.go | 12 ++ .../output_room_events_topology_table.go | 42 ++--- syncapi/storage/postgres/peeks_table.go | 39 +++-- syncapi/storage/postgres/receipt_table.go | 27 +-- syncapi/storage/shared/storage_consumer.go | 14 -- syncapi/storage/shared/storage_sync.go | 47 +++++ .../sqlite3/backwards_extremities_table.go | 27 +-- syncapi/storage/sqlite3/invites_table.go | 31 ++-- syncapi/storage/sqlite3/memberships_table.go | 12 ++ .../sqlite3/notification_data_table.go | 12 ++ .../sqlite3/output_room_events_table.go | 12 ++ .../output_room_events_topology_table.go | 42 ++--- syncapi/storage/sqlite3/peeks_table.go | 39 +++-- syncapi/storage/sqlite3/receipt_table.go | 27 +-- syncapi/storage/tables/interface.go | 9 + 48 files changed, 1213 insertions(+), 170 deletions(-) create mode 100644 roomserver/storage/postgres/purge_statements.go create mode 100644 roomserver/storage/sqlite3/purge_statements.go diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 0d973f350..c7ca019ff 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -7,9 +7,12 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/tidwall/gjson" @@ -41,7 +44,7 @@ func TestAdminResetPassword(t *testing.T) { userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) keyAPI.SetUserAPI(userAPI) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil) + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) // Create the users in the userapi and login accessTokens := map[*test.User]string{ @@ -112,6 +115,7 @@ func TestAdminResetPassword(t *testing.T) { } for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case t.Run(tc.name, func(t *testing.T) { req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID) if tc.requestOpt != nil { @@ -132,3 +136,100 @@ func TestAdminResetPassword(t *testing.T) { } }) } + +func TestPurgeRoom(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t) + room := test.NewRoom(t, aliceAdmin, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite Bob + room.CreateAndInsert(t, aliceAdmin, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + fedClient := base.CreateFederationClient() + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + + // this starts the JetStream consumers + syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI) + federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + rsAPI.SetFederationAPI(nil, nil) + keyAPI.SetUserAPI(userAPI) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) + + // Create the users in the userapi and login + accessTokens := map[*test.User]string{ + aliceAdmin: "", + } + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + } + + testCases := []struct { + name string + roomID string + wantOK bool + }{ + {name: "Can purge existing room", wantOK: true, roomID: room.ID}, + {name: "Can not purge non-existent room", wantOK: false, roomID: "!doesnotexist:localhost"}, + {name: "rejects invalid room ID", wantOK: false, roomID: "@doesnotexist:localhost"}, + } + + for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/purgeRoom/"+tc.roomID) + + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) + + rec := httptest.NewRecorder() + base.DendriteAdminMux.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + + }) +} diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index dbd913376..4b4dedfd1 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -1,6 +1,7 @@ package routing import ( + "context" "encoding/json" "fmt" "net/http" @@ -98,6 +99,37 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi } } +func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + roomID, ok := vars["roomID"] + if !ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Expecting room ID."), + } + } + res := &roomserverAPI.PerformAdminPurgeRoomResponse{} + if err := rsAPI.PerformAdminPurgeRoom( + context.Background(), + &roomserverAPI.PerformAdminPurgeRoomRequest{ + RoomID: roomID, + }, + res, + ); err != nil { + return util.ErrorResponse(err) + } + if err := res.Error; err != nil { + return err.JSONResponse() + } + return util.JSONResponse{ + Code: 200, + JSON: res, + } +} + func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { if req.Body == nil { return util.JSONResponse{ diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 09c2cd02f..93f6ea901 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -165,6 +165,12 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", + httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminPurgeRoom(req, cfg, device, rsAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminResetPassword(req, cfg, device, userAPI) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 52b5744a6..82a4db3f7 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/queue" @@ -90,8 +91,10 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms msg := msgs[0] // Guaranteed to exist if onMessage is called receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType)) - // Only handle events we care about - if receivedType != api.OutputTypeNewRoomEvent && receivedType != api.OutputTypeNewInboundPeek { + // Only handle events we care about, avoids unneeded unmarshalling + switch receivedType { + case api.OutputTypeNewRoomEvent, api.OutputTypeNewInboundPeek, api.OutputTypePurgeRoom: + default: return true } @@ -126,6 +129,14 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return false } + case api.OutputTypePurgeRoom: + log.WithField("room_id", output.PurgeRoom.RoomID).Warn("Purging room from federation API") + if err := s.db.PurgeRoom(ctx, output.PurgeRoom.RoomID); err != nil { + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from federation API") + } else { + logrus.WithField("room_id", output.PurgeRoom.RoomID).Warn("Room purged from federation API") + } + default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 276cd9a50..2b4d905fc 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -71,4 +71,6 @@ type Database interface { GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) // DeleteExpiredEDUs cleans up expired EDUs DeleteExpiredEDUs(ctx context.Context) error + + PurgeRoom(ctx context.Context, roomID string) error } diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 1e1ea9e17..6cda55725 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -259,3 +259,18 @@ func (d *Database) GetNotaryKeys( }) return sks, err } + +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge joined hosts: %w", err) + } + if err := d.FederationInboundPeeks.DeleteInboundPeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge inbound peeks: %w", err) + } + if err := d.FederationOutboundPeeks.DeleteOutboundPeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge outbound peeks: %w", err) + } + return nil + }) +} diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 19483b268..81c055edd 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -124,6 +124,11 @@ type QueryProvider interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } +// ExecProvider defines the interface for querys used by RunLimitedVariablesExec. +type ExecProvider interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + // SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement // SQLlite can handle. See https://www.sqlite.org/limits.html for more information. const SQLite3MaxVariables = 999 @@ -153,6 +158,22 @@ func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvide return nil } +// RunLimitedVariablesExec split up a query with more variables than the used database can handle in multiple queries. +func RunLimitedVariablesExec(ctx context.Context, query string, qp ExecProvider, variables []interface{}, limit uint) error { + var start int + for start < len(variables) { + n := minOfInts(len(variables)-start, int(limit)) + nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1) + _, err := qp.ExecContext(ctx, nextQuery, variables[start:start+n]...) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("ExecContext returned an error") + return err + } + start = start + n + } + return nil +} + // StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. type StatementList []struct { Statement **sql.Stmt diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go index 79469cddc..c40757893 100644 --- a/internal/sqlutil/sqlutil_test.go +++ b/internal/sqlutil/sqlutil_test.go @@ -3,10 +3,11 @@ package sqlutil import ( "context" "database/sql" + "errors" "reflect" "testing" - sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/DATA-DOG/go-sqlmock" ) func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { @@ -164,6 +165,54 @@ func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { } } +func TestRunLimitedVariablesExec(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + + // Query and expect two queries to be executed + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + variables := []interface{}{ + 1, 2, 3, 4, + } + + query := "DELETE FROM WHERE id IN ($1)" + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables, 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 3 parameters, still queries two times + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:3], 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 2 parameters, queries only once + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:2], 2); err != nil { + t.Fatal(err) + } + + // Test with invalid query (typo) should return an error + mock.ExpectExec(`DELTE FROM`). + WillReturnResult(sqlmock.NewResult(0, 0)). + WillReturnError(errors.New("typo in query")) + + if err = RunLimitedVariablesExec(context.Background(), "DELTE FROM", db, variables[:2], 2); err == nil { + t.Fatal("expected an error, but got none") + } +} + func assertNoError(t *testing.T, err error, msg string) { t.Helper() if err == nil { diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 420ef278a..a8228ae81 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -151,6 +151,7 @@ type ClientRoomserverAPI interface { PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error + PerformAdminPurgeRoom(ctx context.Context, req *PerformAdminPurgeRoomRequest, res *PerformAdminPurgeRoomResponse) error PerformAdminDownloadState(ctx context.Context, req *PerformAdminDownloadStateRequest, res *PerformAdminDownloadStateResponse) error PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index b23263d17..166b651a2 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -137,6 +137,16 @@ func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser( return err } +func (t *RoomserverInternalAPITrace) PerformAdminPurgeRoom( + ctx context.Context, + req *PerformAdminPurgeRoomRequest, + res *PerformAdminPurgeRoomResponse, +) error { + err := t.Impl.PerformAdminPurgeRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformAdminPurgeRoom req=%+v res=%+v", js(req), js(res)) + return err +} + func (t *RoomserverInternalAPITrace) PerformAdminDownloadState( ctx context.Context, req *PerformAdminDownloadStateRequest, diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 36d0625c7..0c0f52c45 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -55,6 +55,8 @@ const ( OutputTypeNewInboundPeek OutputType = "new_inbound_peek" // OutputTypeRetirePeek indicates that the kafka event is an OutputRetirePeek OutputTypeRetirePeek OutputType = "retire_peek" + // OutputTypePurgeRoom indicates the event is an OutputPurgeRoom + OutputTypePurgeRoom OutputType = "purge_room" ) // An OutputEvent is an entry in the roomserver output kafka log. @@ -78,6 +80,8 @@ type OutputEvent struct { NewInboundPeek *OutputNewInboundPeek `json:"new_inbound_peek,omitempty"` // The content of event with type OutputTypeRetirePeek RetirePeek *OutputRetirePeek `json:"retire_peek,omitempty"` + // The content of the event with type OutputPurgeRoom + PurgeRoom *OutputPurgeRoom `json:"purge_room,omitempty"` } // Type of the OutputNewRoomEvent. @@ -257,3 +261,7 @@ type OutputRetirePeek struct { UserID string DeviceID string } + +type OutputPurgeRoom struct { + RoomID string +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index e789b9568..83cb0460a 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -241,6 +241,14 @@ type PerformAdminEvacuateUserResponse struct { Error *PerformError } +type PerformAdminPurgeRoomRequest struct { + RoomID string `json:"room_id"` +} + +type PerformAdminPurgeRoomResponse struct { + Error *PerformError `json:"error,omitempty"` +} + type PerformAdminDownloadStateRequest struct { RoomID string `json:"room_id"` UserID string `json:"user_id"` diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index d42f4e45d..3256162b4 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) type Admin struct { @@ -242,6 +243,42 @@ func (r *Admin) PerformAdminEvacuateUser( return nil } +func (r *Admin) PerformAdminPurgeRoom( + ctx context.Context, + req *api.PerformAdminPurgeRoomRequest, + res *api.PerformAdminPurgeRoomResponse, +) error { + // Validate we actually got a room ID and nothing else + if _, _, err := gomatrixserverlib.SplitID('!', req.RoomID); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Malformed room ID: %s", err), + } + return nil + } + + logrus.WithField("room_id", req.RoomID).Warn("Purging room from roomserver") + if err := r.DB.PurgeRoom(ctx, req.RoomID); err != nil { + logrus.WithField("room_id", req.RoomID).WithError(err).Warn("Failed to purge room from roomserver") + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: err.Error(), + } + return nil + } + + logrus.WithField("room_id", req.RoomID).Warn("Room purged from roomserver") + + return r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{ + { + Type: api.OutputTypePurgeRoom, + PurgeRoom: &api.OutputPurgeRoom{ + RoomID: req.RoomID, + }, + }, + }) +} + func (r *Admin) PerformAdminDownloadState( ctx context.Context, req *api.PerformAdminDownloadStateRequest, diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 8a2e0a03c..556a137ba 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -40,6 +40,7 @@ const ( RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom" RoomserverPerformAdminEvacuateUserPath = "/roomserver/performAdminEvacuateUser" RoomserverPerformAdminDownloadStatePath = "/roomserver/performAdminDownloadState" + RoomserverPerformAdminPurgeRoomPath = "/roomserver/performAdminPurgeRoom" // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" @@ -285,6 +286,17 @@ func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser( ) } +func (h *httpRoomserverInternalAPI) PerformAdminPurgeRoom( + ctx context.Context, + request *api.PerformAdminPurgeRoomRequest, + response *api.PerformAdminPurgeRoomResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformAdminPurgeRoom", h.roomserverURL+RoomserverPerformAdminPurgeRoomPath, + h.httpClient, ctx, request, response, + ) +} + // QueryLatestEventsAndState implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 4d21909b7..f3a51b0b1 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -65,6 +65,11 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", enableMetrics, r.PerformAdminEvacuateUser), ) + internalAPIMux.Handle( + RoomserverPerformAdminPurgeRoomPath, + httputil.MakeInternalRPCAPI("RoomserverPerformAdminPurgeRoom", enableMetrics, r.PerformAdminPurgeRoom), + ) + internalAPIMux.Handle( RoomserverPerformAdminDownloadStatePath, httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", enableMetrics, r.PerformAdminDownloadState), diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 595ceb526..3ec2560d6 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -14,6 +14,10 @@ import ( userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/roomserver" @@ -223,3 +227,164 @@ func Test_QueryLeftUsers(t *testing.T) { }) } + +func TestPurgeRoom(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite Bob + inviteEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, db, close := mustCreateDatabase(t, dbType) + defer close() + + jsCtx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsCtx, &base.Cfg.Global.JetStream) + + fedClient := base.CreateFederationClient() + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + + // this starts the JetStream consumers + syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI) + federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + rsAPI.SetFederationAPI(nil, nil) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // some dummy entries to validate after purging + publishResp := &api.PerformPublishResponse{} + if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil { + t.Fatal(err) + } + if publishResp.Error != nil { + t.Fatal(publishResp.Error) + } + + isPublished, err := db.GetPublishedRoom(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if !isPublished { + t.Fatalf("room should be published before purging") + } + + aliasResp := &api.SetRoomAliasResponse{} + if err = rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{RoomID: room.ID, Alias: "myalias", UserID: alice.ID}, aliasResp); err != nil { + t.Fatal(err) + } + // check the alias is actually there + aliasesResp := &api.GetAliasesForRoomIDResponse{} + if err = rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: room.ID}, aliasesResp); err != nil { + t.Fatal(err) + } + wantAliases := 1 + if gotAliases := len(aliasesResp.Aliases); gotAliases != wantAliases { + t.Fatalf("expected %d aliases, got %d", wantAliases, gotAliases) + } + + // validate the room exists before purging + roomInfo, err := db.RoomInfo(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if roomInfo == nil { + t.Fatalf("room does not exist") + } + // remember the roomInfo before purging + existingRoomInfo := roomInfo + + // validate there is an invite for bob + nids, err := db.EventStateKeyNIDs(ctx, []string{bob.ID}) + if err != nil { + t.Fatal(err) + } + bobNID, ok := nids[bob.ID] + if !ok { + t.Fatalf("%s does not exist", bob.ID) + } + + _, inviteEventIDs, _, err := db.GetInvitesForUser(ctx, roomInfo.RoomNID, bobNID) + if err != nil { + t.Fatal(err) + } + wantInviteCount := 1 + if inviteCount := len(inviteEventIDs); inviteCount != wantInviteCount { + t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount) + } + if inviteEventIDs[0] != inviteEvent.EventID() { + t.Fatalf("expected invite event ID %s, got %s", inviteEvent.EventID(), inviteEventIDs[0]) + } + + // purge the room from the database + purgeResp := &api.PerformAdminPurgeRoomResponse{} + if err = rsAPI.PerformAdminPurgeRoom(ctx, &api.PerformAdminPurgeRoomRequest{RoomID: room.ID}, purgeResp); err != nil { + t.Fatal(err) + } + + // wait for all consumers to process the purge event + var sum = 1 + timeout := time.Second * 5 + deadline, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + for sum > 0 { + if deadline.Err() != nil { + t.Fatalf("test timed out after %s", timeout) + } + sum = 0 + consumerCh := jsCtx.Consumers(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)) + for x := range consumerCh { + sum += x.NumAckPending + } + time.Sleep(time.Millisecond) + } + + roomInfo, err = db.RoomInfo(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if roomInfo != nil { + t.Fatalf("room should not exist after purging: %+v", roomInfo) + } + + // validation below + + // There should be no invite left + _, inviteEventIDs, _, err = db.GetInvitesForUser(ctx, existingRoomInfo.RoomNID, bobNID) + if err != nil { + t.Fatal(err) + } + + if inviteCount := len(inviteEventIDs); inviteCount > 0 { + t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount) + } + + // aliases should be deleted + aliases, err := db.GetAliasesForRoomID(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if aliasCount := len(aliases); aliasCount > 0 { + t.Fatalf("expected there to be only %d invite events, got %d", 0, aliasCount) + } + + // published room should be deleted + isPublished, err = db.GetPublishedRoom(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if isPublished { + t.Fatalf("room should not be published after purging") + } + }) +} diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 92bc2e66f..e0b9c56b3 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -173,5 +173,6 @@ type Database interface { GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) + PurgeRoom(ctx context.Context, roomID string) error UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error } diff --git a/roomserver/storage/postgres/purge_statements.go b/roomserver/storage/postgres/purge_statements.go new file mode 100644 index 000000000..efba439bd --- /dev/null +++ b/roomserver/storage/postgres/purge_statements.go @@ -0,0 +1,133 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const purgeEventJSONSQL = "" + + "DELETE FROM roomserver_event_json WHERE event_nid = ANY(" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeEventsSQL = "" + + "DELETE FROM roomserver_events WHERE room_nid = $1" + +const purgeInvitesSQL = "" + + "DELETE FROM roomserver_invites WHERE room_nid = $1" + +const purgeMembershipsSQL = "" + + "DELETE FROM roomserver_membership WHERE room_nid = $1" + +const purgePreviousEventsSQL = "" + + "DELETE FROM roomserver_previous_events WHERE event_nids && ANY(" + + " SELECT ARRAY_AGG(event_nid) FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgePublishedSQL = "" + + "DELETE FROM roomserver_published WHERE room_id = $1" + +const purgeRedactionsSQL = "" + + "DELETE FROM roomserver_redactions WHERE redaction_event_id = ANY(" + + " SELECT event_id FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeRoomAliasesSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE room_id = $1" + +const purgeRoomSQL = "" + + "DELETE FROM roomserver_rooms WHERE room_nid = $1" + +const purgeStateBlockEntriesSQL = "" + + "DELETE FROM roomserver_state_block WHERE state_block_nid = ANY(" + + " SELECT DISTINCT UNNEST(state_block_nids) FROM roomserver_state_snapshots WHERE room_nid = $1" + + ")" + +const purgeStateSnapshotEntriesSQL = "" + + "DELETE FROM roomserver_state_snapshots WHERE room_nid = $1" + +type purgeStatements struct { + purgeEventJSONStmt *sql.Stmt + purgeEventsStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt + purgePreviousEventsStmt *sql.Stmt + purgePublishedStmt *sql.Stmt + purgeRedactionStmt *sql.Stmt + purgeRoomAliasesStmt *sql.Stmt + purgeRoomStmt *sql.Stmt + purgeStateBlockEntriesStmt *sql.Stmt + purgeStateSnapshotEntriesStmt *sql.Stmt +} + +func PreparePurgeStatements(db *sql.DB) (*purgeStatements, error) { + s := &purgeStatements{} + + return s, sqlutil.StatementList{ + {&s.purgeEventJSONStmt, purgeEventJSONSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, + {&s.purgePublishedStmt, purgePublishedSQL}, + {&s.purgePreviousEventsStmt, purgePreviousEventsSQL}, + {&s.purgeRedactionStmt, purgeRedactionsSQL}, + {&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL}, + {&s.purgeRoomStmt, purgeRoomSQL}, + {&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL}, + {&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL}, + }.Prepare(db) +} + +func (s *purgeStatements) PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, +) error { + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateBlockEntriesStmt, + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 994399532..c8346733d 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -58,6 +58,9 @@ const insertRoomNIDSQL = "" + const selectRoomNIDSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1 FOR UPDATE" + const selectLatestEventNIDsSQL = "" + "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" @@ -85,6 +88,7 @@ const bulkSelectRoomNIDsSQL = "" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -106,6 +110,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { return s, sqlutil.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 23a5f79eb..872084383 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -189,6 +189,10 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + purge, err := PreparePurgeStatements(db) + if err != nil { + return err + } d.Database = shared.Database{ DB: db, Cache: cache, @@ -206,6 +210,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room MembershipTable: membership, PublishedTable: published, RedactionsTable: redactions, + Purge: purge, } return nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 725cc5bc7..654b078d2 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -43,6 +43,7 @@ type Database struct { MembershipTable tables.Membership PublishedTable tables.Published RedactionsTable tables.Redactions + Purge tables.Purge GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } @@ -1445,6 +1446,21 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget }) } +// PurgeRoom removes all information about a given room from the roomserver. +// For large rooms this operation may take a considerable amount of time. +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err := d.RoomsTable.SelectRoomNIDForUpdate(ctx, txn, roomID) + if err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("room %s does not exist", roomID) + } + return fmt.Errorf("failed to lock the room: %w", err) + } + return d.Purge.PurgeRoom(ctx, txn, roomNID, roomID) + }) +} + func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/roomserver/storage/sqlite3/purge_statements.go b/roomserver/storage/sqlite3/purge_statements.go new file mode 100644 index 000000000..c7b4d27a5 --- /dev/null +++ b/roomserver/storage/sqlite3/purge_statements.go @@ -0,0 +1,153 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const purgeEventJSONSQL = "" + + "DELETE FROM roomserver_event_json WHERE event_nid IN (" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeEventsSQL = "" + + "DELETE FROM roomserver_events WHERE room_nid = $1" + +const purgeInvitesSQL = "" + + "DELETE FROM roomserver_invites WHERE room_nid = $1" + +const purgeMembershipsSQL = "" + + "DELETE FROM roomserver_membership WHERE room_nid = $1" + +const purgePreviousEventsSQL = "" + + "DELETE FROM roomserver_previous_events WHERE event_nids IN(" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgePublishedSQL = "" + + "DELETE FROM roomserver_published WHERE room_id = $1" + +const purgeRedactionsSQL = "" + + "DELETE FROM roomserver_redactions WHERE redaction_event_id IN(" + + " SELECT event_id FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeRoomAliasesSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE room_id = $1" + +const purgeRoomSQL = "" + + "DELETE FROM roomserver_rooms WHERE room_nid = $1" + +const purgeStateSnapshotEntriesSQL = "" + + "DELETE FROM roomserver_state_snapshots WHERE room_nid = $1" + +type purgeStatements struct { + purgeEventJSONStmt *sql.Stmt + purgeEventsStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt + purgePreviousEventsStmt *sql.Stmt + purgePublishedStmt *sql.Stmt + purgeRedactionStmt *sql.Stmt + purgeRoomAliasesStmt *sql.Stmt + purgeRoomStmt *sql.Stmt + purgeStateSnapshotEntriesStmt *sql.Stmt + stateSnapshot *stateSnapshotStatements +} + +func PreparePurgeStatements(db *sql.DB, stateSnapshot *stateSnapshotStatements) (*purgeStatements, error) { + s := &purgeStatements{stateSnapshot: stateSnapshot} + return s, sqlutil.StatementList{ + {&s.purgeEventJSONStmt, purgeEventJSONSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, + {&s.purgePublishedStmt, purgePublishedSQL}, + {&s.purgePreviousEventsStmt, purgePreviousEventsSQL}, + {&s.purgeRedactionStmt, purgeRedactionsSQL}, + {&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL}, + {&s.purgeRoomStmt, purgeRoomSQL}, + //{&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL}, + {&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL}, + }.Prepare(db) +} + +func (s *purgeStatements) PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, +) error { + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + if err := s.purgeStateBlocks(ctx, txn, roomNID); err != nil { + return err + } + + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil +} + +func (s *purgeStatements) purgeStateBlocks( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) error { + // Get all stateBlockNIDs + stateBlockNIDs, err := s.stateSnapshot.selectStateBlockNIDsForRoomNID(ctx, txn, roomNID) + if err != nil { + return err + } + params := make([]interface{}, len(stateBlockNIDs)) + seenNIDs := make(map[types.StateBlockNID]struct{}, len(stateBlockNIDs)) + // dedupe NIDs + for k, v := range stateBlockNIDs { + if _, ok := seenNIDs[v]; ok { + continue + } + params[k] = v + seenNIDs[v] = struct{}{} + } + + query := "DELETE FROM roomserver_state_block WHERE state_block_nid IN($1)" + return sqlutil.RunLimitedVariablesExec(ctx, query, txn, params, sqlutil.SQLite3MaxVariables) +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 25b611b3e..7556b3461 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -74,10 +74,14 @@ const bulkSelectRoomIDsSQL = "" + const bulkSelectRoomNIDsSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -105,6 +109,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, }.Prepare(db) } @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 4e67d4da1..ae8181cfa 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -68,7 +67,7 @@ func CreateStateBlockTable(db *sql.DB) error { return err } -func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func PrepareStateBlockTable(db *sql.DB) (*stateBlockStatements, error) { s := &stateBlockStatements{ db: db, } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 73827522c..930ad14dd 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -62,10 +62,14 @@ const bulkSelectStateBlockNIDsSQL = "" + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" +const selectStateBlockNIDsForRoomNID = "" + + "SELECT state_block_nids FROM roomserver_state_snapshots WHERE room_nid = $1" + type stateSnapshotStatements struct { db *sql.DB insertStateStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt + selectStateBlockNIDsStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -73,7 +77,7 @@ func CreateStateSnapshotTable(db *sql.DB) error { return err } -func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) { s := &stateSnapshotStatements{ db: db, } @@ -81,6 +85,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + {&s.selectStateBlockNIDsStmt, selectStateBlockNIDsForRoomNID}, }.Prepare(db) } @@ -146,3 +151,29 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( ) ([]types.EventNID, error) { return nil, tables.OptimisationNotSupportedError } + +func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.StateBlockNID, error) { + var res []types.StateBlockNID + rows, err := sqlutil.TxStmt(txn, s.selectStateBlockNIDsStmt).QueryContext(ctx, roomNID) + if err != nil { + return res, nil + } + defer internal.CloseAndLogIfError(ctx, rows, "selectStateBlockNIDsForRoomNID: rows.close() failed") + + var stateBlockNIDs []types.StateBlockNID + var stateBlockNIDsJSON string + for rows.Next() { + if err = rows.Scan(&stateBlockNIDsJSON); err != nil { + return nil, err + } + if err = json.Unmarshal([]byte(stateBlockNIDsJSON), &stateBlockNIDs); err != nil { + return nil, err + } + + res = append(res, stateBlockNIDs...) + } + + return res, rows.Err() +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 01c3f879c..392edd289 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -197,6 +197,11 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + purge, err := PreparePurgeStatements(db, stateSnapshot) + if err != nil { + return err + } + d.Database = shared.Database{ DB: db, Cache: cache, @@ -215,6 +220,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room PublishedTable: published, RedactionsTable: redactions, GetRoomUpdaterFn: d.GetRoomUpdater, + Purge: purge, } return nil } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 80fcf72dd..64145f83d 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -73,6 +73,7 @@ type Events interface { type Rooms interface { InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error) SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) + SelectRoomNIDForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error @@ -173,6 +174,12 @@ type Redactions interface { MarkRedactionValidated(ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool) error } +type Purge interface { + PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, + ) error +} + // StrippedEvent represents a stripped event for returning extracted content values. type StrippedEvent struct { RoomID string diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 1b67f5684..21838039a 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -23,6 +23,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -127,6 +128,12 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms s.onRetirePeek(s.ctx, *output.RetirePeek) case api.OutputTypeRedactedEvent: err = s.onRedactEvent(s.ctx, *output.RedactedEvent) + case api.OutputTypePurgeRoom: + err = s.onPurgeRoom(s.ctx, *output.PurgeRoom) + if err != nil { + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from sync API") + return true // non-fatal, as otherwise we end up in a loop of trying to purge the room + } default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", @@ -473,6 +480,20 @@ func (s *OutputRoomEventConsumer) onRetirePeek( s.notifier.OnRetirePeek(msg.RoomID, msg.UserID, msg.DeviceID, types.StreamingToken{PDUPosition: sp}) } +func (s *OutputRoomEventConsumer) onPurgeRoom( + ctx context.Context, req api.OutputPurgeRoom, +) error { + logrus.WithField("room_id", req.RoomID).Warn("Purging room from sync API") + + if err := s.db.PurgeRoom(ctx, req.RoomID); err != nil { + logrus.WithField("room_id", req.RoomID).WithError(err).Error("Failed to purge room from sync API") + return err + } else { + logrus.WithField("room_id", req.RoomID).Warn("Room purged from sync API") + return nil + } +} + func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.HeaderedEvent) (*gomatrixserverlib.HeaderedEvent, error) { if event.StateKey() == nil { return event, nil diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index a4ba82327..a7a127e3a 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -132,6 +132,8 @@ type Database interface { // PurgeRoomState completely purges room state from the sync API. This is done when // receiving an output event that completely resets the state. PurgeRoomState(ctx context.Context, roomID string) error + // PurgeRoom entirely eliminates a room from the sync API, timeline, state and all. + PurgeRoom(ctx context.Context, roomID string) error // UpsertAccountData keeps track of new or updated account data, by saving the type // of the new/updated data, and the user ID and room ID the data is related to (empty) // room ID means the data isn't specific to any room) diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index 8fc92091f..c20d860a7 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -47,10 +47,14 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const purgeBackwardExtremitiesSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + purgeBackwardExtremitiesStmt *sql.Stmt } func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -59,16 +63,12 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err - } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL}, + {&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL}, + {&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL}, + {&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL}, + }.Prepare(db) } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( @@ -106,3 +106,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return } + +func (s *backwardExtremitiesStatements) PurgeBackwardExtremities( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index aada70d5e..151bffa5d 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -62,11 +62,15 @@ const selectInviteEventsInRangeSQL = "" + const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" +const purgeInvitesSQL = "" + + "DELETE FROM syncapi_invite_events WHERE room_id = $1" + type inviteEventsStatements struct { insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt } func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { @@ -75,19 +79,13 @@ func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { if err != nil { return nil, err } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL}, + {&s.deleteInviteEventStmt, deleteInviteEventSQL}, + {&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + }.Prepare(db) } func (s *inviteEventsStatements) InsertInviteEvent( @@ -181,3 +179,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID( } return } + +func (s *inviteEventsStatements) PurgeInvites( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index ac44b235f..47833893a 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -65,18 +65,22 @@ const selectMembershipCountSQL = "" + const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" +const purgeMembershipsSQL = "" + + "DELETE FROM syncapi_memberships WHERE room_id = $1" + const selectMembersSQL = ` -SELECT event_id FROM ( - SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC -) t -WHERE ($3::text IS NULL OR t.membership = $3) - AND ($4::text IS NULL OR t.membership <> $4) + SELECT event_id FROM ( + SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC + ) t + WHERE ($3::text IS NULL OR t.membership = $3) + AND ($4::text IS NULL OR t.membership <> $4) ` type membershipsStatements struct { upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt selectMembershipForUserStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt selectMembersStmt *sql.Stmt } @@ -90,6 +94,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, {&s.selectMembersStmt, selectMembersSQL}, }.Prepare(db) } @@ -139,6 +144,13 @@ func (s *membershipsStatements) SelectMembershipForUser( return membership, topologyPos, nil } +func (s *membershipsStatements) PurgeMemberships( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID) + return err +} + func (s *membershipsStatements) SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go index 2c7b24800..7edfd54a6 100644 --- a/syncapi/storage/postgres/notification_data_table.go +++ b/syncapi/storage/postgres/notification_data_table.go @@ -37,6 +37,7 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + {&r.purgeNotificationData, purgeNotificationDataSQL}, }.Prepare(db) } @@ -44,6 +45,7 @@ type notificationDataStatements struct { upsertRoomUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt selectMaxID *sql.Stmt + purgeNotificationData *sql.Stmt } const notificationDataSchema = ` @@ -70,6 +72,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` +const purgeNotificationDataSQL = "" + + "DELETE FROM syncapi_notification_data WHERE room_id = $1" + func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { err = sqlutil.TxStmt(txn, r.upsertRoomUnreadCounts).QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) return @@ -106,3 +111,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } + +func (s *notificationDataStatements) PurgeNotificationData( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 3b69b26f6..0075fc8d3 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -176,6 +176,9 @@ const selectContextAfterEventSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id ASC LIMIT $3" +const purgeEventsSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3" type outputRoomEventsStatements struct { @@ -193,6 +196,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + purgeEventsStmt *sql.Stmt selectSearchStmt *sql.Stmt } @@ -230,6 +234,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, {&s.selectSearchStmt, selectSearchSQL}, }.Prepare(db) } @@ -658,6 +663,13 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { return result, rows.Err() } +func (s *outputRoomEventsStatements) PurgeEvents( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID) + return err +} + func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit) if err != nil { diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index d0e99f267..2382fca5c 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -18,11 +18,12 @@ import ( "context" "database/sql" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsTopologySchema = ` @@ -71,6 +72,9 @@ const selectStreamToTopologicalPositionAscSQL = "" + const selectStreamToTopologicalPositionDescSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;" +const purgeEventsTopologySQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt @@ -78,6 +82,7 @@ type outputRoomEventsTopologyStatements struct { selectPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt + purgeEventsTopologyStmt *sql.Stmt } func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -86,25 +91,15 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertEventInTopologyStmt, insertEventInTopologySQL}, + {&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL}, + {&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL}, + {&s.selectPositionInTopologyStmt, selectPositionInTopologySQL}, + {&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL}, + {&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL}, + {&s.purgeEventsTopologyStmt, purgeEventsTopologySQL}, + }.Prepare(db) } // InsertEventInTopology inserts the given event in the room's topology, based @@ -177,3 +172,10 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } + +func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/peeks_table.go b/syncapi/storage/postgres/peeks_table.go index e20a4882f..64183073d 100644 --- a/syncapi/storage/postgres/peeks_table.go +++ b/syncapi/storage/postgres/peeks_table.go @@ -65,6 +65,9 @@ const selectPeekingDevicesSQL = "" + const selectMaxPeekIDSQL = "" + "SELECT MAX(id) FROM syncapi_peeks" +const purgePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1" + type peekStatements struct { db *sql.DB insertPeekStmt *sql.Stmt @@ -73,6 +76,7 @@ type peekStatements struct { selectPeeksInRangeStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt + purgePeeksStmt *sql.Stmt } func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { @@ -83,25 +87,15 @@ func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { s := &peekStatements{ db: db, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertPeekStmt, insertPeekSQL}, + {&s.deletePeekStmt, deletePeekSQL}, + {&s.deletePeeksStmt, deletePeeksSQL}, + {&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL}, + {&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL}, + {&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL}, + {&s.purgePeeksStmt, purgePeeksSQL}, + }.Prepare(db) } func (s *peekStatements) InsertPeek( @@ -184,3 +178,10 @@ func (s *peekStatements) SelectMaxPeekID( } return } + +func (s *peekStatements) PurgePeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index 327a7a372..0fcbebfcb 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -62,11 +62,15 @@ const selectRoomReceipts = "" + const selectMaxReceiptIDSQL = "" + "SELECT MAX(id) FROM syncapi_receipts" +const purgeReceiptsSQL = "" + + "DELETE FROM syncapi_receipts WHERE room_id = $1" + type receiptStatements struct { db *sql.DB upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt + purgeReceiptsStmt *sql.Stmt } func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { @@ -86,16 +90,12 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { r := &receiptStatements{ db: db, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - return r, nil + return r, sqlutil.StatementList{ + {&r.upsertReceipt, upsertReceipt}, + {&r.selectRoomReceipts, selectRoomReceipts}, + {&r.selectMaxReceiptID, selectMaxReceiptIDSQL}, + {&r.purgeReceiptsStmt, purgeReceiptsSQL}, + }.Prepare(db) } func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { @@ -138,3 +138,10 @@ func (s *receiptStatements) SelectMaxReceiptID( } return } + +func (s *receiptStatements) PurgeReceipts( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index df2338cf8..aeeebb1d2 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -242,20 +242,6 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e return nil } -func (d *Database) PurgeRoomState( - ctx context.Context, roomID string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // If the event is a create event then we'll delete all of the existing - // data for the room. The only reason that a create event would be replayed - // to us in this way is if we're about to receive the entire room state. - if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { - return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) - } - return nil - }) -} - func (d *Database) WriteEvent( ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 7b07cac5e..8385b95a5 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -649,6 +649,53 @@ func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) return d.Presence.GetMaxPresenceID(ctx, d.txn) } +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.BackwardExtremities.PurgeBackwardExtremities(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge backward extremities: %w", err) + } + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge current room state: %w", err) + } + if err := d.Invites.PurgeInvites(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge invites: %w", err) + } + if err := d.Memberships.PurgeMemberships(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge memberships: %w", err) + } + if err := d.NotificationData.PurgeNotificationData(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge notification data: %w", err) + } + if err := d.OutputEvents.PurgeEvents(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge events: %w", err) + } + if err := d.Topology.PurgeEventsTopology(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge events topology: %w", err) + } + if err := d.Peeks.PurgePeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge peeks: %w", err) + } + if err := d.Receipts.PurgeReceipts(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge receipts: %w", err) + } + return nil + }) +} + +func (d *Database) PurgeRoomState( + ctx context.Context, roomID string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) + } + return nil + }) +} + func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) { id, err := d.Relations.SelectMaxRelationID(ctx, d.txn) return types.StreamPosition(id), err diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 3a5fd6be3..2d8cf2ed2 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -47,11 +47,15 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const purgeBackwardExtremitiesSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { db *sql.DB insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + purgeBackwardExtremitiesStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -62,16 +66,12 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err - } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL}, + {&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL}, + {&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL}, + {&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL}, + }.Prepare(db) } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( @@ -109,3 +109,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err } + +func (s *backwardExtremitiesStatements) PurgeBackwardExtremities( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index e2dbcd5c8..19450099a 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -57,6 +57,9 @@ const selectInviteEventsInRangeSQL = "" + const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" +const purgeInvitesSQL = "" + + "DELETE FROM syncapi_invite_events WHERE room_id = $1" + type inviteEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -64,6 +67,7 @@ type inviteEventsStatements struct { selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt } func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) { @@ -75,19 +79,13 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Inv if err != nil { return nil, err } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL}, + {&s.deleteInviteEventStmt, deleteInviteEventSQL}, + {&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + }.Prepare(db) } func (s *inviteEventsStatements) InsertInviteEvent( @@ -192,3 +190,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID( } return } + +func (s *inviteEventsStatements) PurgeInvites( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 905a1e1a8..2cc46a10a 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -72,6 +72,9 @@ SELECT event_id FROM AND ($4 IS NULL OR t.membership <> $4) ` +const purgeMembershipsSQL = "" + + "DELETE FROM syncapi_memberships WHERE room_id = $1" + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt @@ -79,6 +82,7 @@ type membershipsStatements struct { //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic selectMembershipForUserStmt *sql.Stmt selectMembersStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -94,6 +98,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembersStmt, selectMembersSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, }.Prepare(db) } @@ -142,6 +147,13 @@ func (s *membershipsStatements) SelectMembershipForUser( return membership, topologyPos, nil } +func (s *membershipsStatements) PurgeMemberships( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID) + return err +} + func (s *membershipsStatements) SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index 6242898e1..af2b2c074 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -38,6 +38,7 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + {&r.purgeNotificationData, purgeNotificationDataSQL}, // {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime }.Prepare(db) } @@ -47,6 +48,7 @@ type notificationDataStatements struct { streamIDStatements *StreamIDStatements upsertRoomUnreadCounts *sql.Stmt selectMaxID *sql.Stmt + purgeNotificationData *sql.Stmt //selectUserUnreadCountsForRooms *sql.Stmt } @@ -73,6 +75,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` +const purgeNotificationDataSQL = "" + + "DELETE FROM syncapi_notification_data WHERE room_id = $1" + func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { pos, err = r.streamIDStatements.nextNotificationID(ctx, nil) if err != nil { @@ -124,3 +129,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } + +func (s *notificationDataStatements) PurgeNotificationData( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 1aa4bfff7..db708c083 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -120,6 +120,9 @@ const selectContextAfterEventSQL = "" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC" +const purgeEventsSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -130,6 +133,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + purgeEventsStmt *sql.Stmt //selectSearchStmt *sql.Stmt - prepared at runtime } @@ -163,6 +167,7 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, //{&s.selectSearchStmt, selectSearchSQL}, - prepared at runtime }.Prepare(db) } @@ -666,6 +671,13 @@ func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs [ return } +func (s *outputRoomEventsStatements) PurgeEvents( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID) + return err +} + func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { params := make([]interface{}, len(types)) for i := range types { diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 879456441..dc698de2d 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -18,10 +18,11 @@ import ( "context" "database/sql" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsTopologySchema = ` @@ -67,6 +68,9 @@ const selectStreamToTopologicalPositionAscSQL = "" + const selectStreamToTopologicalPositionDescSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;" +const purgeEventsTopologySQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { db *sql.DB insertEventInTopologyStmt *sql.Stmt @@ -75,6 +79,7 @@ type outputRoomEventsTopologyStatements struct { selectPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt + purgeEventsTopologyStmt *sql.Stmt } func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -85,25 +90,15 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertEventInTopologyStmt, insertEventInTopologySQL}, + {&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL}, + {&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL}, + {&s.selectPositionInTopologyStmt, selectPositionInTopologySQL}, + {&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL}, + {&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL}, + {&s.purgeEventsTopologyStmt, purgeEventsTopologySQL}, + }.Prepare(db) } // insertEventInTopology inserts the given event in the room's topology, based @@ -174,3 +169,10 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } + +func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index 4ef51b103..5d5200abc 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -64,6 +64,9 @@ const selectPeekingDevicesSQL = "" + const selectMaxPeekIDSQL = "" + "SELECT MAX(id) FROM syncapi_peeks" +const purgePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1" + type peekStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -73,6 +76,7 @@ type peekStatements struct { selectPeeksInRangeStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt + purgePeeksStmt *sql.Stmt } func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) { @@ -84,25 +88,15 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks db: db, streamIDStatements: streamID, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertPeekStmt, insertPeekSQL}, + {&s.deletePeekStmt, deletePeekSQL}, + {&s.deletePeeksStmt, deletePeeksSQL}, + {&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL}, + {&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL}, + {&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL}, + {&s.purgePeeksStmt, purgePeeksSQL}, + }.Prepare(db) } func (s *peekStatements) InsertPeek( @@ -204,3 +198,10 @@ func (s *peekStatements) SelectMaxPeekID( } return } + +func (s *peekStatements) PurgePeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index a4a9b4395..ca3d80fb4 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -58,12 +58,16 @@ const selectRoomReceipts = "" + const selectMaxReceiptIDSQL = "" + "SELECT MAX(id) FROM syncapi_receipts" +const purgeReceiptsSQL = "" + + "DELETE FROM syncapi_receipts WHERE room_id = $1" + type receiptStatements struct { db *sql.DB streamIDStatements *StreamIDStatements upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt + purgeReceiptsStmt *sql.Stmt } func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) { @@ -84,16 +88,12 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re db: db, streamIDStatements: streamID, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - return r, nil + return r, sqlutil.StatementList{ + {&r.upsertReceipt, upsertReceipt}, + {&r.selectRoomReceipts, selectRoomReceipts}, + {&r.selectMaxReceiptID, selectMaxReceiptIDSQL}, + {&r.purgeReceiptsStmt, purgeReceiptsSQL}, + }.Prepare(db) } // UpsertReceipt creates new user receipts @@ -153,3 +153,10 @@ func (s *receiptStatements) SelectMaxReceiptID( } return } + +func (s *receiptStatements) PurgeReceipts( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 8366a67dc..145e197cc 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -39,6 +39,7 @@ type Invites interface { // for the room. SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgeInvites(ctx context.Context, txn *sql.Tx, roomID string) error } type Peeks interface { @@ -48,6 +49,7 @@ type Peeks interface { SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) SelectPeekingDevices(ctxt context.Context, txn *sql.Tx) (peekingDevices map[string][]types.PeekingDevice, err error) SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgePeeks(ctx context.Context, txn *sql.Tx, roomID string) error } type Events interface { @@ -75,6 +77,8 @@ type Events interface { SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + + PurgeEvents(ctx context.Context, txn *sql.Tx, roomID string) error ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) } @@ -93,6 +97,7 @@ type Topology interface { SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) + PurgeEventsTopology(ctx context.Context, txn *sql.Tx, roomID string) error } type CurrentRoomState interface { @@ -146,6 +151,7 @@ type BackwardsExtremities interface { SelectBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) + PurgeBackwardExtremities(ctx context.Context, txn *sql.Tx, roomID string) error } // SendToDevice tracks send-to-device messages which are sent to individual @@ -181,12 +187,14 @@ type Receipts interface { UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgeReceipts(ctx context.Context, txn *sql.Tx, roomID string) error } type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + PurgeMemberships(ctx context.Context, txn *sql.Tx, roomID string) error SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, @@ -198,6 +206,7 @@ type NotificationData interface { UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) + PurgeNotificationData(ctx context.Context, txn *sql.Tx, roomID string) error } type Ignores interface { From ce2bfc3f2e507a012044906af7f25c9dc52873d7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 20 Jan 2023 12:45:56 +0100 Subject: [PATCH 094/124] Make tests more reliable (#2948) When using `testrig.CreateBase` and then using that base for other `NewInternalAPI` calls, we never actually shutdown the components. `testrig.CreateBase` returns a `close` function, which only removes the database, so still running components have issues connecting to the database, since we ripped it out underneath it - which can result in "Disk I/O" or "pq deadlock detected" issues. --- federationapi/federationapi.go | 5 +---- federationapi/federationapi_test.go | 6 +++--- federationapi/routing/routing.go | 17 ++++++++++++----- setup/base/base.go | 2 ++ test/testrig/base.go | 12 ++++++++++-- 5 files changed, 28 insertions(+), 14 deletions(-) diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 87eb751f5..ce0ce98e9 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -85,10 +85,7 @@ func AddPublicRoutes( } routing.Setup( - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - cfg, + base, rsAPI, f, keyRing, federation, userAPI, keyAPI, mscCfg, servers, producer, diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 68a06a033..7009230cc 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -273,12 +273,12 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { cfg.Global.ServerName = gomatrixserverlib.ServerName("localhost") cfg.Global.PrivateKey = privKey cfg.Global.JetStream.InMemory = true - base := base.NewBaseDendrite(cfg, "Monolith") + b := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) keyRing := &test.NopJSONVerifier{} // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(base, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) - baseURL, cancel := test.ListenAndServe(t, base.PublicFederationAPIMux, true) + federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) + baseURL, cancel := test.ListenAndServe(t, b.PublicFederationAPIMux, true) defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 0a3ab7a88..04eb3d067 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -32,6 +32,7 @@ import ( keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -49,8 +50,7 @@ import ( // applied: // nolint: gocyclo func Setup( - fedMux, keyMux, wkMux *mux.Router, - cfg *config.FederationAPI, + base *base.BaseDendrite, rsAPI roomserverAPI.FederationRoomserverAPI, fsAPI *fedInternal.FederationInternalAPI, keys gomatrixserverlib.JSONVerifier, @@ -61,9 +61,16 @@ func Setup( servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, ) { - prometheus.MustRegister( - pduCountTotal, eduCountTotal, - ) + fedMux := base.PublicFederationAPIMux + keyMux := base.PublicKeyAPIMux + wkMux := base.PublicWellKnownAPIMux + cfg := &base.Cfg.FederationAPI + + if base.EnableMetrics { + prometheus.MustRegister( + pduCountTotal, eduCountTotal, + ) + } v2keysmux := keyMux.PathPrefix("/v2").Subrouter() v1fedmux := fedMux.PathPrefix("/v1").Subrouter() diff --git a/setup/base/base.go b/setup/base/base.go index d3adbf53f..ff38209fb 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -264,6 +264,8 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base // Close implements io.Closer func (b *BaseDendrite) Close() error { + b.ProcessContext.ShutdownDendrite() + b.ProcessContext.WaitForShutdown() return b.tracerCloser.Close() } diff --git a/test/testrig/base.go b/test/testrig/base.go index 7bc26a5c5..52e6ef5f1 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -62,7 +62,12 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f MaxIdleConnections: 2, ConnMaxLifetimeSeconds: 60, } - return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close + base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) + return base, func() { + base.ShutdownDendrite() + base.WaitForShutdown() + close() + } case test.DBTypeSQLite: cfg.Defaults(config.DefaultOpts{ Generate: true, @@ -72,7 +77,10 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) - return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() { + base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) + return base, func() { + base.ShutdownDendrite() + base.WaitForShutdown() // cleanup db files. This risks getting out of sync as we add more database strings :( dbFiles := []config.DataSource{ cfg.FederationAPI.Database.ConnectionString, From a2b486091218e761adc0a00ce19ed4b600e489a2 Mon Sep 17 00:00:00 2001 From: Bernhard Feichtinger <43303168+BieHDC@users.noreply.github.com> Date: Fri, 20 Jan 2023 13:13:36 +0100 Subject: [PATCH 095/124] Fix oversight in cmd/generate-config (#2946) The -dir argument was ignored for media_api->base_path. Signed-off-by: `Bernhard Feichtinger <43303168+BieHDC@users.noreply.github.com>` --- cmd/generate-config/main.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 33b18c471..5f75f5e4d 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -54,6 +54,9 @@ func main() { } else { cfg.Global.DatabaseOptions.ConnectionString = uri } + cfg.MediaAPI.BasePath = config.Path(filepath.Join(*dirPath, "media")) + cfg.Global.JetStream.StoragePath = config.Path(*dirPath) + cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(*dirPath, "searchindex")) cfg.Logging = []config.LogrusHook{ { Type: "file", From caf310fd7976ed3fe8abbbf8cb72d380c7efd3c2 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 20 Jan 2023 15:18:06 +0100 Subject: [PATCH 096/124] AWSY missing federation tests (#2943) In an attempt to fix the missing AWSY tests and to get to 100% server-server compliance. --- are-we-synapse-yet.list | 10 +- go.mod | 2 +- go.sum | 4 + roomserver/internal/input/input_events.go | 107 +++++++++++----------- sytest-blacklist | 38 +------- sytest-whitelist | 14 ++- 6 files changed, 83 insertions(+), 92 deletions(-) diff --git a/are-we-synapse-yet.list b/are-we-synapse-yet.list index 81c0f8049..585374738 100644 --- a/are-we-synapse-yet.list +++ b/are-we-synapse-yet.list @@ -936,4 +936,12 @@ fst Room state after a rejected message event is the same as before fst Room state after a rejected state event is the same as before fpb Federation publicRoom Name/topic keys are correct fed New federated private chats get full presence information (SYN-115) (10 subtests) -dvk Rejects invalid device keys \ No newline at end of file +dvk Rejects invalid device keys +rmv User can create and send/receive messages in a room with version 10 +rmv local user can join room with version 10 +rmv User can invite local user to room with version 10 +rmv remote user can join room with version 10 +rmv User can invite remote user to room with version 10 +rmv Remote user can backfill in a room with version 10 +rmv Can reject invites over federation for rooms with version 10 +rmv Can receive redactions from regular users over federation in room version 10 \ No newline at end of file diff --git a/go.mod b/go.mod index 2d7174150..a86dd2cb8 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab + github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index b12f65eab..e5cd67bed 100644 --- a/go.sum +++ b/go.sum @@ -350,6 +350,10 @@ github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8= github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45 h1:zGrmcm2M4F4f+zk5JXAkw3oHa/zXhOh5XVGBdl7GdPo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 h1:P7me2oCmksST9B4+1I1nA+XrnDQwIqAWmy6ntQrXwc8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 4179fc1ef..67edb3217 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -24,6 +24,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/tidwall/gjson" "github.com/matrix-org/gomatrixserverlib" @@ -40,7 +41,6 @@ import ( "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -166,6 +166,7 @@ func (r *Inputer) processRoomEvent( missingPrev = !input.HasState && len(missingPrevIDs) > 0 } + // If we have missing events (auth or prev), we build a list of servers to ask if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ RoomID: event.RoomID(), @@ -200,59 +201,8 @@ func (r *Inputer) processRoomEvent( } } - // First of all, check that the auth events of the event are known. - // If they aren't then we will ask the federation API for them. isRejected := false - authEvents := gomatrixserverlib.NewAuthEvents(nil) - knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.fetchAuthEvents: %w", err) - } - - // Check if the event is allowed by its auth events. If it isn't then - // we consider the event to be "rejected" — it will still be persisted. var rejectionErr error - if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { - isRejected = true - logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) - } - - // Accumulate the auth event NIDs. - authEventIDs := event.AuthEventIDs() - authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) - for _, authEventID := range authEventIDs { - if _, ok := knownEvents[authEventID]; !ok { - // Unknown auth events only really matter if the event actually failed - // auth. If it passed auth then we can assume that everything that was - // known was sufficient, even if extraneous auth events were specified - // but weren't found. - if isRejected { - if event.StateKey() != nil { - return fmt.Errorf( - "missing auth event %s for state event %s (type %q, state key %q)", - authEventID, event.EventID(), event.Type(), *event.StateKey(), - ) - } else { - return fmt.Errorf( - "missing auth event %s for timeline event %s (type %q)", - authEventID, event.EventID(), event.Type(), - ) - } - } - } else { - authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) - } - } - - var softfail bool - if input.Kind == api.KindNew { - // Check that the event passes authentication checks based on the - // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) - if err != nil { - logger.WithError(err).Warn("Error authing soft-failed event") - } - } // At this point we are checking whether we know all of the prev events, and // if we know the state before the prev events. This is necessary before we @@ -314,6 +264,59 @@ func (r *Inputer) processRoomEvent( } } + // Check that the auth events of the event are known. + // If they aren't then we will ask the federation API for them. + authEvents := gomatrixserverlib.NewAuthEvents(nil) + knownEvents := map[string]*types.Event{} + if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return fmt.Errorf("r.fetchAuthEvents: %w", err) + } + + // Check if the event is allowed by its auth events. If it isn't then + // we consider the event to be "rejected" — it will still be persisted. + if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + isRejected = true + rejectionErr = err + logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) + } + + // Accumulate the auth event NIDs. + authEventIDs := event.AuthEventIDs() + authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) + for _, authEventID := range authEventIDs { + if _, ok := knownEvents[authEventID]; !ok { + // Unknown auth events only really matter if the event actually failed + // auth. If it passed auth then we can assume that everything that was + // known was sufficient, even if extraneous auth events were specified + // but weren't found. + if isRejected { + if event.StateKey() != nil { + return fmt.Errorf( + "missing auth event %s for state event %s (type %q, state key %q)", + authEventID, event.EventID(), event.Type(), *event.StateKey(), + ) + } else { + return fmt.Errorf( + "missing auth event %s for timeline event %s (type %q)", + authEventID, event.EventID(), event.Type(), + ) + } + } + } else { + authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) + } + } + + var softfail bool + if input.Kind == api.KindNew { + // Check that the event passes authentication checks based on the + // current room state. + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) + if err != nil { + logger.WithError(err).Warn("Error authing soft-failed event") + } + } + // Get the state before the event so that we can work out if the event was // allowed at the time, and also to get the history visibility. We won't // bother doing this if the event was already rejected as it just ends up diff --git a/sytest-blacklist b/sytest-blacklist index 99cfbabc8..bb0ee368f 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,54 +1,18 @@ -# Relies on a rejected PL event which will never be accepted into the DAG - -# Caused by - -Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state - -# We don't implement lazy membership loading yet - +# Blacklisted due to https://github.com/matrix-org/matrix-spec/issues/942 The only membership state included in a gapped incremental sync is for senders in the timeline -# Blacklisted out of flakiness after #1479 - -Invited user can reject local invite after originator leaves -Invited user can reject invite for empty room -If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes - -# Blacklisted due to flakiness - -Forgotten room messages cannot be paginated - -# Blacklisted due to flakiness after #1774 - -Local device key changes get to remote servers with correct prev_id - -# we don't support groups - -Remove group category -Remove group role - # Flakey - AS-ghosted users can use rooms themselves AS-ghosted users can use rooms via AS Events in rooms with AS-hosted room aliases are sent to AS server Inviting an AS-hosted user asks the AS server Accesing an AS-hosted room alias asks the AS server -# More flakey - -Guest users can join guest_access rooms - # This will fail in HTTP API mode, so blacklisted for now - If a device list update goes missing, the server resyncs on the next one # Might be a bug in the test because leaves do appear :-( - Leaves are present in non-gapped incremental syncs -# Below test was passing for the wrong reason, failing correctly since #2858 -New federated private chats get full presence information (SYN-115) - # We don't have any state to calculate m.room.guest_access when accepting invites Guest users can accept invites to private rooms over federation \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 215889a49..1f6ecc29e 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -766,4 +766,16 @@ remote user has tags copied to the new room Local and remote users' homeservers remove a room from their public directory on upgrade Guest users denied access over federation if guest access prohibited Guest users are kicked from guest_access rooms on revocation of guest_access -Guest users are kicked from guest_access rooms on revocation of guest_access over federation \ No newline at end of file +Guest users are kicked from guest_access rooms on revocation of guest_access over federation +User can create and send/receive messages in a room with version 10 +local user can join room with version 10 +User can invite local user to room with version 10 +remote user can join room with version 10 +User can invite remote user to room with version 10 +Remote user can backfill in a room with version 10 +Can reject invites over federation for rooms with version 10 +Can receive redactions from regular users over federation in room version 10 +New federated private chats get full presence information (SYN-115) +/state returns M_NOT_FOUND for an outlier +/state_ids returns M_NOT_FOUND for an outlier +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state \ No newline at end of file From 25cb65acdbb6702e84a6bcb6245d6d23d90c2359 Mon Sep 17 00:00:00 2001 From: Catalan Lover <48515417+FSG-Cat@users.noreply.github.com> Date: Fri, 20 Jan 2023 15:41:29 +0100 Subject: [PATCH 097/124] Change Default Room version to 10 (#2933) This PR implements [MSC3904](https://github.com/matrix-org/matrix-spec-proposals/pull/3904). This PR is almost identical to #2781 but this PR is also filed well technically 1 day before the MSC passes FCP but well everyone knows this MSC is expected to have passed FCP on monday so im refiling this change today on saturday as i was doing prep work for monday. I assume that this PR wont be counted as clogging the queue since by the next time i expect to be a work day for this project this PR will be implementing an FCP passed disposition merge MSC. Also as for the lack of tests i belive that this simple change does not need to pass new tests due to that these tests are expected to already have been passed by the successful use of Dendrite with Room version 10 already. ### Pull Request Checklist * [X] I have added tests for PR _or_ I have justified why this PR doesn't need tests. * [X] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) Signed-off-by: Catalan Lover Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com> Co-authored-by: kegsay --- roomserver/version/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roomserver/version/version.go b/roomserver/version/version.go index 729d00a80..c40d8e0f7 100644 --- a/roomserver/version/version.go +++ b/roomserver/version/version.go @@ -23,7 +23,7 @@ import ( // DefaultRoomVersion contains the room version that will, by // default, be used to create new rooms on this server. func DefaultRoomVersion() gomatrixserverlib.RoomVersion { - return gomatrixserverlib.RoomVersionV9 + return gomatrixserverlib.RoomVersionV10 } // RoomVersions returns a map of all known room versions to this From 430932f0f161dd836c98082ff97b57beedec02e6 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 20 Jan 2023 16:20:01 +0100 Subject: [PATCH 098/124] Version 0.11.0 (#2949) --- CHANGES.md | 14 ++++++++++++++ helm/dendrite/Chart.yaml | 4 ++-- helm/dendrite/README.md | 7 ++++--- helm/dendrite/values.yaml | 2 +- internal/version.go | 4 ++-- 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index fa8230659..e1f7affb5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,19 @@ # Changelog +## Dendrite 0.11.0 (2023-01-20) + +The last three missing federation API Sytests have been fixed - bringing us to 100% server-server Synapse parity, with client-server parity at 93% 🎉 + +### Features + +* Added `/_dendrite/admin/purgeRoom/{roomID}` to clean up the database +* The default room version was updated to 10 (contributed by [FSG-Cat](https://github.com/FSG-Cat)) + +### Fixes + +* An oversight in the `create-config` binary, which now correctly sets the media path if specified (contributed by [BieHDC](https://github.com/BieHDC)) +* The Helm chart now uses the `$.Chart.AppVersion` as the default image version to pull, with the possibility to override it (contributed by [genofire](https://github.com/genofire)) + ## Dendrite 0.10.9 (2023-01-17) ### Features diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml index 6e6641c8d..174fc5496 100644 --- a/helm/dendrite/Chart.yaml +++ b/helm/dendrite/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v2 name: dendrite -version: "0.10.9" -appVersion: "0.10.9" +version: "0.11.0" +appVersion: "0.11.0" description: Dendrite Matrix Homeserver type: application keywords: diff --git a/helm/dendrite/README.md b/helm/dendrite/README.md index cb850d655..6a1658429 100644 --- a/helm/dendrite/README.md +++ b/helm/dendrite/README.md @@ -1,6 +1,6 @@ # dendrite -![Version: 0.10.8](https://img.shields.io/badge/Version-0.10.8-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.10.8](https://img.shields.io/badge/AppVersion-0.10.8-informational?style=flat-square) +![Version: 0.11.0](https://img.shields.io/badge/Version-0.11.0-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.11.0](https://img.shields.io/badge/AppVersion-0.11.0-informational?style=flat-square) Dendrite Matrix Homeserver Status: **NOT PRODUCTION READY** @@ -41,8 +41,9 @@ Create a folder `appservices` and place your configurations in there. The confi | Key | Type | Default | Description | |-----|------|---------|-------------| -| image.name | string | `"ghcr.io/matrix-org/dendrite-monolith:v0.10.8"` | Docker repository/image to use | +| image.repository | string | `"ghcr.io/matrix-org/dendrite-monolith"` | Docker repository/image to use | | image.pullPolicy | string | `"IfNotPresent"` | Kubernetes pullPolicy | +| image.tag | string | `""` | Overrides the image tag whose default is the chart appVersion. | | signing_key.create | bool | `true` | Create a new signing key, if not exists | | signing_key.existingSecret | string | `""` | Use an existing secret | | resources | object | sets some sane default values | Default resource requests/limits. | @@ -144,4 +145,4 @@ Create a folder `appservices` and place your configurations in there. The confi | ingress.annotations | object | `{}` | Extra, custom annotations | | ingress.tls | list | `[]` | | | service.type | string | `"ClusterIP"` | | -| service.port | int | `80` | | +| service.port | int | `8008` | | diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml index 87027a886..848241ab6 100644 --- a/helm/dendrite/values.yaml +++ b/helm/dendrite/values.yaml @@ -3,7 +3,7 @@ image: repository: "ghcr.io/matrix-org/dendrite-monolith" # -- Kubernetes pullPolicy pullPolicy: IfNotPresent - # Overrides the image tag whose default is the chart appVersion. + # -- Overrides the image tag whose default is the chart appVersion. tag: "" diff --git a/internal/version.go b/internal/version.go index ff31dd784..fbe4a01b0 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 10 - VersionPatch = 9 + VersionMinor = 11 + VersionPatch = 0 VersionTag = "" // example: "rc1" ) From 48fa869fa3578741d1d5775d30f24f6b097ab995 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 23 Jan 2023 13:17:15 +0100 Subject: [PATCH 099/124] Use `t.TempDir` for SQLite databases, so tests don't rip out each others databases (#2950) This should hopefully finally fix issues about `disk I/O error` as seen [here](https://gitlab.alpinelinux.org/alpine/aports/-/jobs/955030/raw) Hopefully this will also fix `SSL accept attempt failed` issues by disabling HTTP keep alives when generating a config for CI. --- cmd/generate-config/main.go | 1 + internal/log.go | 2 ++ internal/log_unix.go | 2 ++ setup/jetstream/helpers.go | 5 +++++ sytest-blacklist | 1 + sytest-whitelist | 7 ++++++- test/db.go | 12 +++++------- test/testrig/base.go | 37 ++++++++++++++----------------------- 8 files changed, 36 insertions(+), 31 deletions(-) diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 5f75f5e4d..56a145653 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -70,6 +70,7 @@ func main() { cfg.AppServiceAPI.DisableTLSValidation = true cfg.ClientAPI.RateLimiting.Enabled = false cfg.FederationAPI.DisableTLSValidation = false + cfg.FederationAPI.DisableHTTPKeepalives = true // don't hit matrix.org when running tests!!! cfg.FederationAPI.KeyPerspectives = config.KeyPerspectives{} cfg.MediaAPI.BasePath = config.Path(filepath.Join(*dirPath, "media")) diff --git a/internal/log.go b/internal/log.go index d7e852c81..da6e20418 100644 --- a/internal/log.go +++ b/internal/log.go @@ -24,6 +24,7 @@ import ( "path/filepath" "runtime" "strings" + "sync" "github.com/matrix-org/util" @@ -37,6 +38,7 @@ import ( // this unfortunately results in us adding the same hook multiple times. // This map ensures we only ever add one level hook. var stdLevelLogAdded = make(map[logrus.Level]bool) +var levelLogAddedMu = &sync.Mutex{} type utcFormatter struct { logrus.Formatter diff --git a/internal/log_unix.go b/internal/log_unix.go index b38e7c2e8..8f34c320d 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -85,6 +85,8 @@ func checkSyslogHookParams(params map[string]interface{}) { } func setupStdLogHook(level logrus.Level) { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() if stdLevelLogAdded[level] { return } diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index c1ce9583f..533652160 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -77,6 +77,11 @@ func JetStreamConsumer( // The consumer was deleted so stop. return } else { + // Unfortunately, there's no ErrServerShutdown or similar, so we need to compare the string + if err.Error() == "nats: Server Shutdown" { + logrus.WithContext(ctx).Warn("nats server shutting down") + return + } // Something else went wrong, so we'll panic. sentry.CaptureException(err) logrus.WithContext(ctx).WithField("subject", subj).Fatal(err) diff --git a/sytest-blacklist b/sytest-blacklist index bb0ee368f..49a3cc870 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -7,6 +7,7 @@ AS-ghosted users can use rooms via AS Events in rooms with AS-hosted room aliases are sent to AS server Inviting an AS-hosted user asks the AS server Accesing an AS-hosted room alias asks the AS server +If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes # This will fail in HTTP API mode, so blacklisted for now If a device list update goes missing, the server resyncs on the next one diff --git a/sytest-whitelist b/sytest-whitelist index 1f6ecc29e..c61e0bc3c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -778,4 +778,9 @@ Can receive redactions from regular users over federation in room version 10 New federated private chats get full presence information (SYN-115) /state returns M_NOT_FOUND for an outlier /state_ids returns M_NOT_FOUND for an outlier -Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state \ No newline at end of file +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state +Invited user can reject invite for empty room +Invited user can reject local invite after originator leaves +Guest users can join guest_access rooms +Forgotten room messages cannot be paginated +Local device key changes get to remote servers with correct prev_id \ No newline at end of file diff --git a/test/db.go b/test/db.go index 17f637e18..54ded6adb 100644 --- a/test/db.go +++ b/test/db.go @@ -22,6 +22,7 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "testing" "github.com/lib/pq" @@ -103,13 +104,10 @@ func currentUser() string { // TODO: namespace for concurrent package tests func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { if dbType == DBTypeSQLite { - // this will be made in the current working directory which namespaces concurrent package runs correctly - dbname := "dendrite_test.db" + // this will be made in the t.TempDir, which is unique per test + dbname := filepath.Join(t.TempDir(), "dendrite_test.db") return fmt.Sprintf("file:%s", dbname), func() { - err := os.Remove(dbname) - if err != nil { - t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err) - } + t.Cleanup(func() {}) // removes the t.TempDir } } @@ -176,7 +174,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) { for dbName, dbType := range dbs { dbt := dbType t.Run(dbName, func(tt *testing.T) { - //tt.Parallel() + tt.Parallel() testFn(tt, dbt) }) } diff --git a/test/testrig/base.go b/test/testrig/base.go index 52e6ef5f1..9773da223 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -15,18 +15,14 @@ package testrig import ( - "errors" "fmt" - "io/fs" - "os" - "strings" + "path/filepath" "testing" - "github.com/nats-io/nats.go" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/nats-io/nats.go" ) func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) { @@ -77,27 +73,22 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) + + // Use a temp dir provided by go for tests, this will be cleanup by a call to t.CleanUp() + tempDir := t.TempDir() + cfg.FederationAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "federationapi.db")) + cfg.KeyServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "keyserver.db")) + cfg.MSCs.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "mscs.db")) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "mediaapi.db")) + cfg.RoomServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "roomserver.db")) + cfg.SyncAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "syncapi.db")) + cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "userapi.db")) + base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) return base, func() { base.ShutdownDendrite() base.WaitForShutdown() - // cleanup db files. This risks getting out of sync as we add more database strings :( - dbFiles := []config.DataSource{ - cfg.FederationAPI.Database.ConnectionString, - cfg.KeyServer.Database.ConnectionString, - cfg.MSCs.Database.ConnectionString, - cfg.MediaAPI.Database.ConnectionString, - cfg.RoomServer.Database.ConnectionString, - cfg.SyncAPI.Database.ConnectionString, - cfg.UserAPI.AccountDatabase.ConnectionString, - } - for _, fileURI := range dbFiles { - path := strings.TrimPrefix(string(fileURI), "file:") - err := os.Remove(path) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - t.Fatalf("failed to cleanup sqlite db '%s': %s", fileURI, err) - } - } + t.Cleanup(func() {}) // removes t.TempDir, where all database files are created } default: t.Fatalf("unknown db type: %v", dbType) From 5b73592f5a4dddf64184fcbe33f4c1835c656480 Mon Sep 17 00:00:00 2001 From: devonh Date: Mon, 23 Jan 2023 17:55:12 +0000 Subject: [PATCH 100/124] Initial Store & Forward Implementation (#2917) This adds store & forward relays into dendrite for p2p. A few things have changed: - new relay api serves new http endpoints for s&f federation - updated outbound federation queueing which will attempt to forward using s&f if appropriate - database entries to track s&f relays for other nodes --- .gitignore | 1 + build/gobind-pinecone/monolith.go | 201 ++++- build/gobind-pinecone/monolith_test.go | 198 +++++ cmd/dendrite-demo-pinecone/ARCHITECTURE.md | 59 ++ cmd/dendrite-demo-pinecone/README.md | 39 + cmd/dendrite-demo-pinecone/main.go | 125 ++- federationapi/api/api.go | 25 +- federationapi/federationapi.go | 5 +- federationapi/internal/api.go | 9 +- .../internal/federationclient_test.go | 202 +++++ federationapi/internal/perform.go | 73 +- federationapi/internal/perform_test.go | 190 ++++ federationapi/inthttp/client.go | 12 + federationapi/queue/destinationqueue.go | 80 +- federationapi/queue/queue.go | 25 +- federationapi/queue/queue_test.go | 436 ++++------ federationapi/routing/profile_test.go | 94 ++ federationapi/routing/query_test.go | 94 ++ federationapi/routing/routing.go | 14 +- federationapi/routing/send.go | 339 +------- federationapi/routing/send_test.go | 605 ++----------- federationapi/statistics/statistics.go | 186 +++- federationapi/statistics/statistics_test.go | 58 +- federationapi/storage/interface.go | 47 +- .../storage/postgres/assumed_offline_table.go | 107 +++ .../storage/postgres/relay_servers_table.go | 137 +++ federationapi/storage/postgres/storage.go | 10 + .../storage/shared/receipt/receipt.go | 42 + federationapi/storage/shared/storage.go | 184 +++- federationapi/storage/shared/storage_edus.go | 29 +- federationapi/storage/shared/storage_pdus.go | 27 +- .../storage/sqlite3/assumed_offline_table.go | 107 +++ .../storage/sqlite3/relay_servers_table.go | 148 ++++ federationapi/storage/sqlite3/storage.go | 13 +- federationapi/storage/storage_test.go | 103 ++- federationapi/storage/tables/interface.go | 27 + .../tables/relay_servers_table_test.go | 224 +++++ go.mod | 20 +- go.sum | 42 +- internal/log.go | 2 + internal/log_unix.go | 4 +- internal/transactionrequest.go | 356 ++++++++ internal/transactionrequest_test.go | 820 ++++++++++++++++++ mediaapi/routing/routing.go | 26 +- relayapi/api/api.go | 56 ++ relayapi/internal/api.go | 53 ++ relayapi/internal/perform.go | 141 +++ relayapi/internal/perform_test.go | 121 +++ relayapi/relayapi.go | 74 ++ relayapi/relayapi_test.go | 154 ++++ relayapi/routing/relaytxn.go | 74 ++ relayapi/routing/relaytxn_test.go | 220 +++++ relayapi/routing/routing.go | 123 +++ relayapi/routing/sendrelay.go | 77 ++ relayapi/routing/sendrelay_test.go | 209 +++++ relayapi/storage/interface.go | 47 + .../postgres/relay_queue_json_table.go | 113 +++ .../storage/postgres/relay_queue_table.go | 156 ++++ relayapi/storage/postgres/storage.go | 64 ++ relayapi/storage/shared/storage.go | 170 ++++ .../storage/sqlite3/relay_queue_json_table.go | 137 +++ relayapi/storage/sqlite3/relay_queue_table.go | 168 ++++ relayapi/storage/sqlite3/storage.go | 64 ++ relayapi/storage/storage.go | 46 + relayapi/storage/tables/interface.go | 66 ++ .../tables/relay_queue_json_table_test.go | 173 ++++ .../storage/tables/relay_queue_table_test.go | 229 +++++ setup/base/base.go | 6 + setup/config/config.go | 5 +- setup/config/config_federationapi.go | 7 + setup/config/config_relayapi.go | 52 ++ setup/config/config_test.go | 29 +- setup/monolith.go | 7 + test/db.go | 1 - test/memory_federation_db.go | 488 +++++++++++ test/memory_relay_db.go | 140 +++ test/testrig/base.go | 4 +- 77 files changed, 7646 insertions(+), 1373 deletions(-) create mode 100644 build/gobind-pinecone/monolith_test.go create mode 100644 cmd/dendrite-demo-pinecone/ARCHITECTURE.md create mode 100644 federationapi/internal/federationclient_test.go create mode 100644 federationapi/internal/perform_test.go create mode 100644 federationapi/routing/profile_test.go create mode 100644 federationapi/routing/query_test.go create mode 100644 federationapi/storage/postgres/assumed_offline_table.go create mode 100644 federationapi/storage/postgres/relay_servers_table.go create mode 100644 federationapi/storage/shared/receipt/receipt.go create mode 100644 federationapi/storage/sqlite3/assumed_offline_table.go create mode 100644 federationapi/storage/sqlite3/relay_servers_table.go create mode 100644 federationapi/storage/tables/relay_servers_table_test.go create mode 100644 internal/transactionrequest.go create mode 100644 internal/transactionrequest_test.go create mode 100644 relayapi/api/api.go create mode 100644 relayapi/internal/api.go create mode 100644 relayapi/internal/perform.go create mode 100644 relayapi/internal/perform_test.go create mode 100644 relayapi/relayapi.go create mode 100644 relayapi/relayapi_test.go create mode 100644 relayapi/routing/relaytxn.go create mode 100644 relayapi/routing/relaytxn_test.go create mode 100644 relayapi/routing/routing.go create mode 100644 relayapi/routing/sendrelay.go create mode 100644 relayapi/routing/sendrelay_test.go create mode 100644 relayapi/storage/interface.go create mode 100644 relayapi/storage/postgres/relay_queue_json_table.go create mode 100644 relayapi/storage/postgres/relay_queue_table.go create mode 100644 relayapi/storage/postgres/storage.go create mode 100644 relayapi/storage/shared/storage.go create mode 100644 relayapi/storage/sqlite3/relay_queue_json_table.go create mode 100644 relayapi/storage/sqlite3/relay_queue_table.go create mode 100644 relayapi/storage/sqlite3/storage.go create mode 100644 relayapi/storage/storage.go create mode 100644 relayapi/storage/tables/interface.go create mode 100644 relayapi/storage/tables/relay_queue_json_table_test.go create mode 100644 relayapi/storage/tables/relay_queue_table_test.go create mode 100644 setup/config/config_relayapi.go create mode 100644 test/memory_federation_db.go create mode 100644 test/memory_relay_db.go diff --git a/.gitignore b/.gitignore index e4f0112c4..fe5e82797 100644 --- a/.gitignore +++ b/.gitignore @@ -56,6 +56,7 @@ dendrite.yaml # Database files *.db +*.db-journal # Log files *.log* diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index b8f8111d2..ff61ea6c8 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -41,13 +41,16 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/relayapi" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" userapiAPI "github.com/matrix-org/dendrite/userapi/api" @@ -67,24 +70,27 @@ import ( ) const ( - PeerTypeRemote = pineconeRouter.PeerTypeRemote - PeerTypeMulticast = pineconeRouter.PeerTypeMulticast - PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth - PeerTypeBonjour = pineconeRouter.PeerTypeBonjour + PeerTypeRemote = pineconeRouter.PeerTypeRemote + PeerTypeMulticast = pineconeRouter.PeerTypeMulticast + PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth + PeerTypeBonjour = pineconeRouter.PeerTypeBonjour + relayServerRetryInterval = time.Second * 30 ) type DendriteMonolith struct { - logger logrus.Logger - PineconeRouter *pineconeRouter.Router - PineconeMulticast *pineconeMulticast.Multicast - PineconeQUIC *pineconeSessions.Sessions - PineconeManager *pineconeConnections.ConnectionManager - StorageDirectory string - CacheDirectory string - listener net.Listener - httpServer *http.Server - processContext *process.ProcessContext - userAPI userapiAPI.UserInternalAPI + logger logrus.Logger + baseDendrite *base.BaseDendrite + PineconeRouter *pineconeRouter.Router + PineconeMulticast *pineconeMulticast.Multicast + PineconeQUIC *pineconeSessions.Sessions + PineconeManager *pineconeConnections.ConnectionManager + StorageDirectory string + CacheDirectory string + listener net.Listener + httpServer *http.Server + userAPI userapiAPI.UserInternalAPI + federationAPI api.FederationInternalAPI + relayServersQueried map[gomatrixserverlib.ServerName]bool } func (m *DendriteMonolith) PublicKey() string { @@ -326,6 +332,7 @@ func (m *DendriteMonolith) Start() { cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(m.StorageDirectory, prefix))) cfg.MediaAPI.BasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(m.StorageDirectory, prefix))) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true @@ -335,9 +342,9 @@ func (m *DendriteMonolith) Start() { panic(err) } - base := base.NewBaseDendrite(cfg, "Monolith") + base := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) + m.baseDendrite = base base.ConfigureAdminEndpoints() - defer base.Close() // nolint: errcheck federation := conn.CreateFederationClient(base, m.PineconeQUIC) @@ -346,11 +353,11 @@ func (m *DendriteMonolith) Start() { rsAPI := roomserver.NewInternalAPI(base) - fsAPI := federationapi.NewInternalAPI( + m.federationAPI = federationapi.NewInternalAPI( base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, m.federationAPI, rsAPI) m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(m.userAPI) @@ -358,10 +365,24 @@ func (m *DendriteMonolith) Start() { // The underlying roomserver implementation needs to be able to call the fedsender. // This is different to rsAPI which can be the http client which doesn't need this dependency - rsAPI.SetFederationAPI(fsAPI, keyRing) + rsAPI.SetFederationAPI(m.federationAPI, keyRing) userProvider := users.NewPineconeUserProvider(m.PineconeRouter, m.PineconeQUIC, m.userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, fsAPI, federation) + roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, m.federationAPI, federation) + + js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &base.Cfg.FederationAPI, + UserAPI: m.userAPI, + } + relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer) monolith := setup.Monolith{ Config: base.Cfg, @@ -370,10 +391,11 @@ func (m *DendriteMonolith) Start() { KeyRing: keyRing, AppserviceAPI: asAPI, - FederationAPI: fsAPI, + FederationAPI: m.federationAPI, RoomserverAPI: rsAPI, UserAPI: m.userAPI, KeyAPI: keyAPI, + RelayAPI: relayAPI, ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } @@ -411,8 +433,6 @@ func (m *DendriteMonolith) Start() { Handler: h2c.NewHandler(pMux, h2s), } - m.processContext = base.ProcessContext - go func() { m.logger.Info("Listening on ", cfg.Global.ServerName) @@ -420,7 +440,7 @@ func (m *DendriteMonolith) Start() { case net.ErrClosed, http.ErrServerClosed: m.logger.Info("Stopped listening on ", cfg.Global.ServerName) default: - m.logger.Fatal(err) + m.logger.Error("Stopped listening on ", cfg.Global.ServerName) } }() go func() { @@ -430,33 +450,44 @@ func (m *DendriteMonolith) Start() { case net.ErrClosed, http.ErrServerClosed: m.logger.Info("Stopped listening on ", cfg.Global.ServerName) default: - m.logger.Fatal(err) + m.logger.Error("Stopped listening on ", cfg.Global.ServerName) } }() go func(ch <-chan pineconeEvents.Event) { eLog := logrus.WithField("pinecone", "events") + stopRelayServerSync := make(chan bool) + + relayRetriever := RelayServerRetriever{ + Context: context.Background(), + ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), + FederationAPI: m.federationAPI, + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + RelayAPI: monolith.RelayAPI, + running: *atomic.NewBool(false), + } + relayRetriever.InitializeRelayServers(eLog) for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: + if !relayRetriever.running.Load() { + go relayRetriever.SyncRelayServers(stopRelayServerSync) + } case pineconeEvents.PeerRemoved: - case pineconeEvents.TreeParentUpdate: - case pineconeEvents.SnakeDescUpdate: - case pineconeEvents.TreeRootAnnUpdate: - case pineconeEvents.SnakeEntryAdded: - case pineconeEvents.SnakeEntryRemoved: + if relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { + stopRelayServerSync <- true + } case pineconeEvents.BroadcastReceived: - eLog.Info("Broadcast received from: ", e.PeerID) + // eLog.Info("Broadcast received from: ", e.PeerID) req := &api.PerformWakeupServersRequest{ ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, } res := &api.PerformWakeupServersResponse{} - if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + if err := m.federationAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) } - case pineconeEvents.BandwidthReport: default: } } @@ -464,12 +495,106 @@ func (m *DendriteMonolith) Start() { } func (m *DendriteMonolith) Stop() { - m.processContext.ShutdownDendrite() + m.baseDendrite.Close() + m.baseDendrite.WaitForShutdown() _ = m.listener.Close() m.PineconeMulticast.Stop() _ = m.PineconeQUIC.Close() _ = m.PineconeRouter.Close() - m.processContext.WaitForComponentsToFinish() +} + +type RelayServerRetriever struct { + Context context.Context + ServerName gomatrixserverlib.ServerName + FederationAPI api.FederationInternalAPI + RelayAPI relayServerAPI.RelayInternalAPI + relayServersQueried map[gomatrixserverlib.ServerName]bool + queriedServersMutex sync.Mutex + running atomic.Bool +} + +func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} + response := api.P2PQueryRelayServersResponse{} + err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) + if err != nil { + eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + } + for _, server := range response.RelayServers { + m.relayServersQueried[server] = false + } + + eLog.Infof("Registered relay servers: %v", response.RelayServers) +} + +func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { + defer m.running.Store(false) + + t := time.NewTimer(relayServerRetryInterval) + for { + relayServersToQuery := []gomatrixserverlib.ServerName{} + func() { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + for server, complete := range m.relayServersQueried { + if !complete { + relayServersToQuery = append(relayServersToQuery, server) + } + } + }() + if len(relayServersToQuery) == 0 { + // All relay servers have been synced. + return + } + m.queryRelayServers(relayServersToQuery) + t.Reset(relayServerRetryInterval) + + select { + case <-stop: + if !t.Stop() { + <-t.C + } + return + case <-t.C: + } + } +} + +func (m *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + + result := map[gomatrixserverlib.ServerName]bool{} + for server, queried := range m.relayServersQueried { + result[server] = queried + } + return result +} + +func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("querying relay servers for any available transactions") + for _, server := range relayServers { + userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) + if err != nil { + return + } + err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) + if err == nil { + func() { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + m.relayServersQueried[server] = true + }() + // TODO : What happens if your relay receives new messages after this point? + // Should you continue to check with them, or should they try and contact you? + // They could send a "new_async_events" message your way maybe? + // Then you could mark them as needing to be queried again. + // What if you miss this message? + // Maybe you should try querying them again after a certain period of time as a backup? + } else { + logrus.Errorf("Failed querying relay server: %s", err.Error()) + } + } } const MaxFrameSize = types.MaxFrameSize diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go new file mode 100644 index 000000000..edcf22bbe --- /dev/null +++ b/build/gobind-pinecone/monolith_test.go @@ -0,0 +1,198 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gobind + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/api" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "gotest.tools/v3/poll" +) + +var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + +type TestNetConn struct { + net.Conn + shouldFail bool +} + +func (t *TestNetConn) Read(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + n := copy(b, TestBuf) + return n, nil + } +} + +func (t *TestNetConn) Write(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + return len(b), nil + } +} + +func (t *TestNetConn) Close() error { + if t.shouldFail { + return fmt.Errorf("Failed") + } else { + return nil + } +} + +func TestConduitStoresPort(t *testing.T) { + conduit := Conduit{port: 7} + assert.Equal(t, 7, conduit.Port()) +} + +func TestConduitRead(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + b := make([]byte, len(TestBuf)) + bytes, err := conduit.Read(b) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) + assert.Equal(t, TestBuf, b) +} + +func TestConduitReadCopy(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + result, err := conduit.ReadCopy() + assert.NoError(t, err) + assert.Equal(t, TestBuf, result) +} + +func TestConduitWrite(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + bytes, err := conduit.Write(TestBuf) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) +} + +func TestConduitClose(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + assert.True(t, conduit.closed.Load()) +} + +func TestConduitReadClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + b := make([]byte, len(TestBuf)) + _, err = conduit.Read(b) + assert.Error(t, err) +} + +func TestConduitReadCopyClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.ReadCopy() + assert.Error(t, err) +} + +func TestConduitWriteClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.Write(TestBuf) + assert.Error(t, err) +} + +func TestConduitReadCopyFails(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{shouldFail: true}} + _, err := conduit.ReadCopy() + assert.Error(t, err) +} + +var testRelayServers = []gomatrixserverlib.ServerName{"relay1", "relay2"} + +type FakeFedAPI struct { + api.FederationInternalAPI +} + +func (f *FakeFedAPI) P2PQueryRelayServers(ctx context.Context, req *api.P2PQueryRelayServersRequest, res *api.P2PQueryRelayServersResponse) error { + res.RelayServers = testRelayServers + return nil +} + +type FakeRelayAPI struct { + relayServerAPI.RelayInternalAPI +} + +func (r *FakeRelayAPI) PerformRelayServerSync(ctx context.Context, userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error { + return nil +} + +func TestRelayRetrieverInitialization(t *testing.T) { + retriever := RelayServerRetriever{ + Context: context.Background(), + ServerName: "server", + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + FederationAPI: &FakeFedAPI{}, + RelayAPI: &FakeRelayAPI{}, + } + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) +} + +func TestRelayRetrieverSync(t *testing.T) { + retriever := RelayServerRetriever{ + Context: context.Background(), + ServerName: "server", + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + FederationAPI: &FakeFedAPI{}, + RelayAPI: &FakeRelayAPI{}, + } + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) + + stopRelayServerSync := make(chan bool) + go retriever.SyncRelayServers(stopRelayServerSync) + + check := func(log poll.LogT) poll.Result { + relayServers := retriever.GetQueriedServerStatus() + for _, queried := range relayServers { + if !queried { + return poll.Continue("waiting for all servers to be queried") + } + } + + stopRelayServerSync <- true + return poll.Success() + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestMonolithStarts(t *testing.T) { + monolith := DendriteMonolith{} + monolith.Start() + monolith.PublicKey() + monolith.Stop() +} diff --git a/cmd/dendrite-demo-pinecone/ARCHITECTURE.md b/cmd/dendrite-demo-pinecone/ARCHITECTURE.md new file mode 100644 index 000000000..1b0941053 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/ARCHITECTURE.md @@ -0,0 +1,59 @@ +## Relay Server Architecture + +Relay Servers function similar to the way physical mail drop boxes do. +A node can have many associated relay servers. Matrix events can be sent to them instead of to the destination node, and the destination node will eventually retrieve them from the relay server. +Nodes that want to send events to an offline node need to know what relay servers are associated with their intended destination. +Currently this is manually configured in the dendrite database. In the future this information could be configurable in the app and shared automatically via other means. + +Currently events are sent as complete Matrix Transactions. +Transactions include a list of PDUs, (which contain, among other things, lists of authorization events, previous events, and signatures) a list of EDUs, and other information about the transaction. +There is no additional information sent along with the transaction other than what is typically added to them during Matrix federation today. +In the future this will probably need to change in order to handle more complex room state resolution during p2p usage. + +### Relay Server Architecture + +``` + 0 +--------------------+ + +----------------------------------------+ | P2P Node A | + | Relay Server | | +--------+ | + | | | | Client | | + | +--------------------+ | | +--------+ | + | | Relay Server API | | | | | + | | | | | V | + | .--------. 2 | +-------------+ | | 1 | +------------+ | + | |`--------`| <----- | Forwarder | <------------- | Homeserver | | + | | Database | | +-------------+ | | | +------------+ | + | `----------` | | | +--------------------+ + | ^ | | | + | | 4 | +-------------+ | | + | `------------ | Retriever | <------. +--------------------+ + | | +-------------+ | | | | P2P Node B | + | | | | | | +--------+ | + | +--------------------+ | | | | Client | | + | | | | +--------+ | + +----------------------------------------+ | | | | + | | V | + 3 | | +------------+ | + `------ | Homeserver | | + | +------------+ | + +--------------------+ +``` + +- 0: This relay server is currently only acting on behalf of `P2P Node B`. It will only receive, and later forward events that are destined for `P2P Node B`. +- 1: When `P2P Node A` fails sending directly to `P2P Node B` (after a configurable number of attempts), it checks for any known relay servers associated with `P2P Node B` and sends to all of them. + - If sending to any of the relay servers succeeds, that transaction is considered to be successfully sent. +- 2: The relay server `forwarder` stores the transaction json in its database and marks it as destined for `P2P Node B`. +- 3: When `P2P Node B` comes online, it queries all its relay servers for any missed messages. +- 4: The relay server `retriever` will look in its database for any transactions that are destined for `P2P Node B` and returns them one at a time. + +For now, it is important that we don’t design out a hybrid approach of having both sender-side and recipient-side relay servers. +Both approaches make sense and determining which makes for a better experience depends on the use case. + +#### Sender-Side Relay Servers + +If we are running around truly ad-hoc, and I don't know when or where you will be able to pick up messages, then having a sender designated server makes sense to give things the best chance at making their way to the destination. +But in order to achieve this, you are either relying on p2p presence broadcasts for the relay to know when to try forwarding (which means you are in a pretty small network), or the relay just keeps on periodically attempting to forward to the destination which will lead to a lot of extra traffic on the network. + +#### Recipient-Side Relay Servers + +If we have agreed to some static relay server before going off and doing other things, or if we are talking about more global p2p federation, then having a recipient designated relay server can cut down on redundant traffic since it will sit there idle until the recipient pulls events from it. diff --git a/cmd/dendrite-demo-pinecone/README.md b/cmd/dendrite-demo-pinecone/README.md index d6dd95905..5cacd0924 100644 --- a/cmd/dendrite-demo-pinecone/README.md +++ b/cmd/dendrite-demo-pinecone/README.md @@ -24,3 +24,42 @@ Then point your favourite Matrix client to the homeserver URL`http://localhost: If your peering connection is operational then you should see a `Connected TCP:` line in the log output. If not then try a different peer. Once logged in, you should be able to open the room directory or join a room by its ID. + +## Store & Forward Relays + +To test out the store & forward relay functionality, you need a minimum of 3 instances. +One instance will act as the relay, and the other two instances will be the users trying to communicate. +Then you can send messages between the two nodes and watch as the relay is used if the receiving node is offline. + +### Launching the Nodes + +Relay Server: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir relay/ -listen "[::]:49000" +``` + +Node 1: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir node-1/ -peer "[::]:49000" -port 8007 +``` + +Node 2: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir node-2/ -peer "[::]:49000" -port 8009 +``` + +### Database Setup + +At the moment, the database must be manually configured. +For both `Node 1` and `Node 2` add the following entries to their respective `relay_server` table in the federationapi database: +``` +server_name: {node_1_public_key}, relay_server_name: {relay_public_key} +server_name: {node_2_public_key}, relay_server_name: {relay_public_key} +``` + +After editing the database you will need to relaunch the nodes for the changes to be picked up by dendrite. + +### Testing + +Now you can run two separate instances of element and connect them to `Node 1` and `Node 2`. +You can shutdown one of the nodes and continue sending messages. If you wait long enough, the message will be sent to the relay server. (you can see this in the log output of the relay server) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 3f627b41d..a813c37a2 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -38,16 +38,21 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/relayapi" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" + "go.uber.org/atomic" pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" @@ -66,6 +71,8 @@ var ( instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)") ) +const relayServerRetryInterval = time.Second * 30 + // nolint:gocyclo func main() { flag.Parse() @@ -139,6 +146,7 @@ func main() { cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName))) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(*instanceDir, *instanceName))) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) cfg.ClientAPI.RegistrationDisabled = false @@ -224,6 +232,20 @@ func main() { userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation) roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation) + js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &base.Cfg.FederationAPI, + UserAPI: userAPI, + } + relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer) + monolith := setup.Monolith{ Config: base.Cfg, Client: conn.CreateClient(base, pQUIC), @@ -235,6 +257,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, KeyAPI: keyAPI, + RelayAPI: relayAPI, ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } @@ -305,27 +328,38 @@ func main() { go func(ch <-chan pineconeEvents.Event) { eLog := logrus.WithField("pinecone", "events") + relayServerSyncRunning := atomic.NewBool(false) + stopRelayServerSync := make(chan bool) + + m := RelayServerRetriever{ + Context: context.Background(), + ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()), + FederationAPI: fsAPI, + RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + RelayAPI: monolith.RelayAPI, + } + m.InitializeRelayServers(eLog) for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: + if !relayServerSyncRunning.Load() { + go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning) + } case pineconeEvents.PeerRemoved: - case pineconeEvents.TreeParentUpdate: - case pineconeEvents.SnakeDescUpdate: - case pineconeEvents.TreeRootAnnUpdate: - case pineconeEvents.SnakeEntryAdded: - case pineconeEvents.SnakeEntryRemoved: + if relayServerSyncRunning.Load() && pRouter.TotalPeerCount() == 0 { + stopRelayServerSync <- true + } case pineconeEvents.BroadcastReceived: - eLog.Info("Broadcast received from: ", e.PeerID) + // eLog.Info("Broadcast received from: ", e.PeerID) req := &api.PerformWakeupServersRequest{ ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, } res := &api.PerformWakeupServersResponse{} if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) } - case pineconeEvents.BandwidthReport: default: } } @@ -333,3 +367,78 @@ func main() { base.WaitForShutdown() } + +type RelayServerRetriever struct { + Context context.Context + ServerName gomatrixserverlib.ServerName + FederationAPI api.FederationInternalAPI + RelayServersQueried map[gomatrixserverlib.ServerName]bool + RelayAPI relayServerAPI.RelayInternalAPI +} + +func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} + response := api.P2PQueryRelayServersResponse{} + err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) + if err != nil { + eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + } + for _, server := range response.RelayServers { + m.RelayServersQueried[server] = false + } + + eLog.Infof("Registered relay servers: %v", response.RelayServers) +} + +func (m *RelayServerRetriever) syncRelayServers(stop <-chan bool, running atomic.Bool) { + defer running.Store(false) + + t := time.NewTimer(relayServerRetryInterval) + for { + relayServersToQuery := []gomatrixserverlib.ServerName{} + for server, complete := range m.RelayServersQueried { + if !complete { + relayServersToQuery = append(relayServersToQuery, server) + } + } + if len(relayServersToQuery) == 0 { + // All relay servers have been synced. + return + } + m.queryRelayServers(relayServersToQuery) + t.Reset(relayServerRetryInterval) + + select { + case <-stop: + // We have been asked to stop syncing, drain the timer and return. + if !t.Stop() { + <-t.C + } + return + case <-t.C: + // The timer has expired. Continue to the next loop iteration. + } + } +} + +func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("querying relay servers for any available transactions") + for _, server := range relayServers { + userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) + if err != nil { + return + } + err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) + if err == nil { + m.RelayServersQueried[server] = true + // TODO : What happens if your relay receives new messages after this point? + // Should you continue to check with them, or should they try and contact you? + // They could send a "new_async_events" message your way maybe? + // Then you could mark them as needing to be queried again. + // What if you miss this message? + // Maybe you should try querying them again after a certain period of time as a backup? + } else { + logrus.Errorf("Failed querying relay server: %s", err.Error()) + } + } +} diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 50d0339e4..417b08521 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -18,6 +18,7 @@ type FederationInternalAPI interface { gomatrixserverlib.KeyDatabase ClientFederationAPI RoomserverFederationAPI + P2PFederationAPI QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) @@ -30,7 +31,6 @@ type FederationInternalAPI interface { request *PerformBroadcastEDURequest, response *PerformBroadcastEDUResponse, ) error - PerformWakeupServers( ctx context.Context, request *PerformWakeupServersRequest, @@ -71,6 +71,15 @@ type RoomserverFederationAPI interface { LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } +type P2PFederationAPI interface { + // Relay Server sync api used in the pinecone demos. + P2PQueryRelayServers( + ctx context.Context, + request *P2PQueryRelayServersRequest, + response *P2PQueryRelayServersResponse, + ) error +} + // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError @@ -82,6 +91,7 @@ type KeyserverFederationAPI interface { // an interface for gmsl.FederationClient - contains functions called by federationapi only. type FederationClient interface { + P2PFederationClient gomatrixserverlib.KeyClient SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) @@ -110,6 +120,11 @@ type FederationClient interface { LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } +type P2PFederationClient interface { + P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) + P2PGetTransactionFromRelay(ctx context.Context, u gomatrixserverlib.UserID, prev gomatrixserverlib.RelayEntry, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetRelayTransaction, err error) +} + // FederationClientError is returned from FederationClient methods in the event of a problem. type FederationClientError struct { Err string @@ -233,3 +248,11 @@ type InputPublicKeysRequest struct { type InputPublicKeysResponse struct { } + +type P2PQueryRelayServersRequest struct { + Server gomatrixserverlib.ServerName +} + +type P2PQueryRelayServersResponse struct { + RelayServers []gomatrixserverlib.ServerName +} diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index ce0ce98e9..ed9a545d6 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -113,7 +113,10 @@ func NewInternalAPI( _ = federationDB.RemoveAllServersFromBlacklist() } - stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) + stats := statistics.NewStatistics( + federationDB, + cfg.FederationMaxRetries+1, + cfg.P2PFederationRetriesUntilAssumedOffline+1) js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 14056eafc..99773a750 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -109,13 +109,14 @@ func NewFederationInternalAPI( func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) { stats := a.statistics.ForServer(s) - until, blacklisted := stats.BackoffInfo() - if blacklisted { + if stats.Blacklisted() { return stats, &api.FederationClientError{ Blacklisted: true, } } + now := time.Now() + until := stats.BackoffInfo() if until != nil && now.Before(*until) { return stats, &api.FederationClientError{ RetryAfter: time.Until(*until), @@ -163,7 +164,7 @@ func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted( RetryAfter: retryAfter, } } - stats.Success() + stats.Success(statistics.SendDirect) return res, nil } @@ -171,7 +172,7 @@ func (a *FederationInternalAPI) doRequestIfNotBlacklisted( s gomatrixserverlib.ServerName, request func() (interface{}, error), ) (interface{}, error) { stats := a.statistics.ForServer(s) - if _, blacklisted := stats.BackoffInfo(); blacklisted { + if blacklisted := stats.Blacklisted(); blacklisted { return stats, &api.FederationClientError{ Err: fmt.Sprintf("server %q is blacklisted", s), Blacklisted: true, diff --git a/federationapi/internal/federationclient_test.go b/federationapi/internal/federationclient_test.go new file mode 100644 index 000000000..49137e2d8 --- /dev/null +++ b/federationapi/internal/federationclient_test.go @@ -0,0 +1,202 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/federationapi/queue" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + FailuresUntilAssumedOffline = 3 + FailuresUntilBlacklist = 8 +) + +func (t *testFedClient) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) { + t.queryKeysCalled = true + if t.shouldFail { + return gomatrixserverlib.RespQueryKeys{}, fmt.Errorf("Failure") + } + return gomatrixserverlib.RespQueryKeys{}, nil +} + +func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) { + t.claimKeysCalled = true + if t.shouldFail { + return gomatrixserverlib.RespClaimKeys{}, fmt.Errorf("Failure") + } + return gomatrixserverlib.RespClaimKeys{}, nil +} + +func TestFederationClientQueryKeys(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.Nil(t, err) + assert.True(t, fedClient.queryKeysCalled) +} + +func TestFederationClientQueryKeysBlacklisted(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + testDB.AddServerToBlacklist("server") + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.False(t, fedClient.queryKeysCalled) +} + +func TestFederationClientQueryKeysFailure(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{shouldFail: true} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.True(t, fedClient.queryKeysCalled) +} + +func TestFederationClientClaimKeys(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil) + assert.Nil(t, err) + assert.True(t, fedClient.claimKeysCalled) +} + +func TestFederationClientClaimKeysBlacklisted(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + testDB.AddServerToBlacklist("server") + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.False(t, fedClient.claimKeysCalled) +} diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index d86d07e03..552942f28 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/consumers" + "github.com/matrix-org/dendrite/federationapi/statistics" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" ) @@ -24,6 +25,10 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( request *api.PerformDirectoryLookupRequest, response *api.PerformDirectoryLookupResponse, ) (err error) { + if !r.shouldAttemptDirectFederation(request.ServerName) { + return fmt.Errorf("relay servers have no meaningful response for directory lookup.") + } + dir, err := r.federation.LookupRoomAlias( ctx, r.cfg.Matrix.ServerName, @@ -36,7 +41,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( } response.RoomID = dir.RoomID response.ServerNames = dir.Servers - r.statistics.ForServer(request.ServerName).Success() + r.statistics.ForServer(request.ServerName).Success(statistics.SendDirect) return nil } @@ -144,6 +149,10 @@ func (r *FederationInternalAPI) performJoinUsingServer( supportedVersions []gomatrixserverlib.RoomVersion, unsigned map[string]interface{}, ) error { + if !r.shouldAttemptDirectFederation(serverName) { + return fmt.Errorf("relay servers have no meaningful response for join.") + } + _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) if err != nil { return err @@ -164,7 +173,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.MakeJoin: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" @@ -219,7 +228,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.SendJoin: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // If the remote server returned an event in the "event" key of // the send_join request then we should use that instead. It may @@ -407,6 +416,10 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( serverName gomatrixserverlib.ServerName, supportedVersions []gomatrixserverlib.RoomVersion, ) error { + if !r.shouldAttemptDirectFederation(serverName) { + return fmt.Errorf("relay servers have no meaningful response for outbound peek.") + } + // create a unique ID for this peek. // for now we just use the room ID again. In future, if we ever // support concurrent peeks to the same room with different filters @@ -446,7 +459,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.Peek: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // Work out if we support the room version that has been supplied in // the peek response. @@ -516,6 +529,10 @@ func (r *FederationInternalAPI) PerformLeave( // Try each server that we were provided until we land on one that // successfully completes the make-leave send-leave dance. for _, serverName := range request.ServerNames { + if !r.shouldAttemptDirectFederation(serverName) { + continue + } + // Try to perform a make_leave using the information supplied in the // request. respMakeLeave, err := r.federation.MakeLeave( @@ -585,7 +602,7 @@ func (r *FederationInternalAPI) PerformLeave( continue } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) return nil } @@ -616,6 +633,12 @@ func (r *FederationInternalAPI) PerformInvite( return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } + // TODO (devon): This should be allowed via a relay. Currently only transactions + // can be sent to relays. Would need to extend relays to handle invites. + if !r.shouldAttemptDirectFederation(destination) { + return fmt.Errorf("relay servers have no meaningful response for invite.") + } + logrus.WithFields(logrus.Fields{ "event_id": request.Event.EventID(), "user_id": *request.Event.StateKey(), @@ -682,12 +705,8 @@ func (r *FederationInternalAPI) PerformWakeupServers( func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { for _, srv := range destinations { - // Check the statistics cache for the blacklist status to prevent hitting - // the database unnecessarily. - if r.queues.IsServerBlacklisted(srv) { - _ = r.db.RemoveServerFromBlacklist(srv) - } - r.queues.RetryServer(srv) + wasBlacklisted := r.statistics.ForServer(srv).MarkServerAlive() + r.queues.RetryServer(srv, wasBlacklisted) } } @@ -719,7 +738,9 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { return fmt.Errorf("auth chain response is missing m.room.create event") } -func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion { +func setDefaultRoomVersionFromJoinEvent( + joinEvent gomatrixserverlib.EventBuilder, +) gomatrixserverlib.RoomVersion { // if auth events are not event references we know it must be v3+ // we have to do these shenanigans to satisfy sytest, specifically for: // "Outbound federation rejects m.room.create events with an unknown room version" @@ -802,3 +823,31 @@ func federatedAuthProvider( return returning, nil } } + +// P2PQueryRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PQueryRelayServers( + ctx context.Context, + request *api.P2PQueryRelayServersRequest, + response *api.P2PQueryRelayServersResponse, +) error { + logrus.Infof("Getting relay servers for: %s", request.Server) + relayServers, err := r.db.P2PGetRelayServersForServer(ctx, request.Server) + if err != nil { + return err + } + + response.RelayServers = relayServers + return nil +} + +func (r *FederationInternalAPI) shouldAttemptDirectFederation( + destination gomatrixserverlib.ServerName, +) bool { + var shouldRelay bool + stats := r.statistics.ForServer(destination) + if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 { + shouldRelay = true + } + + return !shouldRelay +} diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go new file mode 100644 index 000000000..e8e0d00a3 --- /dev/null +++ b/federationapi/internal/perform_test.go @@ -0,0 +1,190 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/queue" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type testFedClient struct { + api.FederationClient + queryKeysCalled bool + claimKeysCalled bool + shouldFail bool +} + +func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { + return gomatrixserverlib.RespDirectory{}, nil +} + +func TestPerformWakeupServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + testDB.AddServerToBlacklist(server) + testDB.SetServerAssumedOffline(context.Background(), server) + blacklisted, err := testDB.IsServerBlacklisted(server) + assert.NoError(t, err) + assert.True(t, blacklisted) + offline, err := testDB.IsServerAssumedOffline(context.Background(), server) + assert.NoError(t, err) + assert.True(t, offline) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{server}, + } + res := api.PerformWakeupServersResponse{} + err = fedAPI.PerformWakeupServers(context.Background(), &req, &res) + assert.NoError(t, err) + + blacklisted, err = testDB.IsServerBlacklisted(server) + assert.NoError(t, err) + assert.False(t, blacklisted) + offline, err = testDB.IsServerAssumedOffline(context.Background(), server) + assert.NoError(t, err) + assert.False(t, offline) +} + +func TestQueryRelayServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.P2PQueryRelayServersRequest{ + Server: server, + } + res := api.P2PQueryRelayServersResponse{} + err = fedAPI.P2PQueryRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + assert.Equal(t, len(relayServers), len(res.RelayServers)) +} + +func TestPerformDirectoryLookup(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformDirectoryLookupRequest{ + RoomAlias: "room", + ServerName: "server", + } + res := api.PerformDirectoryLookupResponse{} + err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res) + assert.NoError(t, err) +} + +func TestPerformDirectoryLookupRelaying(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + testDB.SetServerAssumedOffline(context.Background(), server) + testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"}) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: server, + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformDirectoryLookupRequest{ + RoomAlias: "room", + ServerName: server, + } + res := api.PerformDirectoryLookupResponse{} + err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res) + assert.Error(t, err) +} diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 6eefdc7cd..6130a567d 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -24,6 +24,7 @@ const ( FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" + FederationAPIQueryRelayServers = "/federationapi/queryRelayServers" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" @@ -510,3 +511,14 @@ func (h *httpFederationInternalAPI) QueryPublicKeys( h.httpClient, ctx, request, response, ) } + +func (h *httpFederationInternalAPI) P2PQueryRelayServers( + ctx context.Context, + request *api.P2PQueryRelayServersRequest, + response *api.P2PQueryRelayServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryRelayServers", h.federationAPIURL+FederationAPIQueryRelayServers, + h.httpClient, ctx, request, response, + ) +} diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index a4a87fe99..51350916d 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -29,7 +29,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -70,7 +70,7 @@ type destinationQueue struct { // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return @@ -84,8 +84,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re oq.pendingMutex.Lock() if len(oq.pendingPDUs) < maxPDUsInMemory { oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ - pdu: event, - receipt: receipt, + pdu: event, + dbReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -101,7 +101,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re // sendEDU adds the EDU event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return @@ -115,8 +115,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share oq.pendingMutex.Lock() if len(oq.pendingEDUs) < maxEDUsInMemory { oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ - edu: event, - receipt: receipt, + edu: event, + dbReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -210,10 +210,10 @@ func (oq *destinationQueue) getPendingFromDatabase() { gotPDUs := map[string]struct{}{} gotEDUs := map[string]struct{}{} for _, pdu := range oq.pendingPDUs { - gotPDUs[pdu.receipt.String()] = struct{}{} + gotPDUs[pdu.dbReceipt.String()] = struct{}{} } for _, edu := range oq.pendingEDUs { - gotEDUs[edu.receipt.String()] = struct{}{} + gotEDUs[edu.dbReceipt.String()] = struct{}{} } overflowed := false @@ -371,7 +371,7 @@ func (oq *destinationQueue) backgroundSend() { // If we have pending PDUs or EDUs then construct a transaction. // Try sending the next transaction and see what happens. - terr := oq.nextTransaction(toSendPDUs, toSendEDUs) + terr, sendMethod := oq.nextTransaction(toSendPDUs, toSendEDUs) if terr != nil { // We failed to send the transaction. Mark it as a failure. _, blacklisted := oq.statistics.Failure() @@ -388,18 +388,19 @@ func (oq *destinationQueue) backgroundSend() { return } } else { - oq.handleTransactionSuccess(pduCount, eduCount) + oq.handleTransactionSuccess(pduCount, eduCount, sendMethod) } } } // nextTransaction creates a new transaction from the pending event // queue and sends it. -// Returns an error if the transaction wasn't sent. +// Returns an error if the transaction wasn't sent. And whether the success +// was to a relay server or not. func (oq *destinationQueue) nextTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) error { +) (err error, sendMethod statistics.SendMethod) { // Create the transaction. t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) @@ -407,7 +408,37 @@ func (oq *destinationQueue) nextTransaction( // Try to send the transaction to the destination server. ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) defer cancel() - _, err := oq.client.SendTransaction(ctx, t) + + relayServers := oq.statistics.KnownRelayServers() + if oq.statistics.AssumedOffline() && len(relayServers) > 0 { + sendMethod = statistics.SendViaRelay + relaySuccess := false + logrus.Infof("Sending to relay servers: %v", relayServers) + // TODO : how to pass through actual userID here?!?!?!?! + userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) + if userErr != nil { + return userErr, sendMethod + } + + // Attempt sending to each known relay server. + for _, relayServer := range relayServers { + _, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer) + if relayErr != nil { + err = relayErr + } else { + // If sending to one of the relay servers succeeds, consider the send successful. + relaySuccess = true + } + } + + // Clear the error if sending to any of the relay servers succeeded. + if relaySuccess { + err = nil + } + } else { + sendMethod = statistics.SendDirect + _, err = oq.client.SendTransaction(ctx, t) + } switch errResponse := err.(type) { case nil: // Clean up the transaction in the database. @@ -427,7 +458,7 @@ func (oq *destinationQueue) nextTransaction( oq.transactionIDMutex.Lock() oq.transactionID = "" oq.transactionIDMutex.Unlock() - return nil + return nil, sendMethod case gomatrix.HTTPError: // Report that we failed to send the transaction and we // will retry again, subject to backoff. @@ -437,13 +468,13 @@ func (oq *destinationQueue) nextTransaction( // to a 400-ish error code := errResponse.Code logrus.Debug("Transaction failed with HTTP", code) - return err + return err, sendMethod default: logrus.WithFields(logrus.Fields{ "destination": oq.destination, logrus.ErrorKey: err, }).Debugf("Failed to send transaction %q", t.TransactionID) - return err + return err, sendMethod } } @@ -453,7 +484,7 @@ func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) createTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) { +) (gomatrixserverlib.Transaction, []*receipt.Receipt, []*receipt.Receipt) { // If there's no projected transaction ID then generate one. If // the transaction succeeds then we'll set it back to "" so that // we generate a new one next time. If it fails, we'll preserve @@ -474,8 +505,8 @@ func (oq *destinationQueue) createTransaction( t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) t.TransactionID = oq.transactionID - var pduReceipts []*shared.Receipt - var eduReceipts []*shared.Receipt + var pduReceipts []*receipt.Receipt + var eduReceipts []*receipt.Receipt // Go through PDUs that we retrieved from the database, if any, // and add them into the transaction. @@ -487,7 +518,7 @@ func (oq *destinationQueue) createTransaction( // Append the JSON of the event, since this is a json.RawMessage type in the // gomatrixserverlib.Transaction struct t.PDUs = append(t.PDUs, pdu.pdu.JSON()) - pduReceipts = append(pduReceipts, pdu.receipt) + pduReceipts = append(pduReceipts, pdu.dbReceipt) } // Do the same for pending EDUS in the queue. @@ -497,7 +528,7 @@ func (oq *destinationQueue) createTransaction( continue } t.EDUs = append(t.EDUs, *edu.edu) - eduReceipts = append(eduReceipts, edu.receipt) + eduReceipts = append(eduReceipts, edu.dbReceipt) } return t, pduReceipts, eduReceipts @@ -530,10 +561,11 @@ func (oq *destinationQueue) blacklistDestination() { // handleTransactionSuccess updates the cached event queues as well as the success and // backoff information for this server. -func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) { +func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int, sendMethod statistics.SendMethod) { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. - oq.statistics.Success() + + oq.statistics.Success(sendMethod) oq.pendingMutex.Lock() defer oq.pendingMutex.Unlock() diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 75b1b36be..5d6b8d44c 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -30,7 +30,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -138,13 +138,13 @@ func NewOutgoingQueues( } type queuedPDU struct { - receipt *shared.Receipt - pdu *gomatrixserverlib.HeaderedEvent + dbReceipt *receipt.Receipt + pdu *gomatrixserverlib.HeaderedEvent } type queuedEDU struct { - receipt *shared.Receipt - edu *gomatrixserverlib.EDU + dbReceipt *receipt.Receipt + edu *gomatrixserverlib.EDU } func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { @@ -374,24 +374,13 @@ func (oqs *OutgoingQueues) SendEDU( return nil } -// IsServerBlacklisted returns whether or not the provided server is currently -// blacklisted. -func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool { - return oqs.statistics.ForServer(srv).Blacklisted() -} - // RetryServer attempts to resend events to the given server if we had given up. -func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { +func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName, wasBlacklisted bool) { if oqs.disabled { return } - serverStatistics := oqs.statistics.ForServer(srv) - forceWakeup := serverStatistics.Blacklisted() - serverStatistics.RemoveBlacklist() - serverStatistics.ClearBackoff() - if queue := oqs.getQueue(srv); queue != nil { - queue.wakeQueueIfEventsPending(forceWakeup) + queue.wakeQueueIfEventsPending(wasBlacklisted) } } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index c317edc21..36e2ccbc2 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - "sync" "testing" "time" @@ -26,13 +25,11 @@ import ( "gotest.tools/v3/poll" "github.com/matrix-org/gomatrixserverlib" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" @@ -57,7 +54,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } else { // Fake Database - db := createDatabase() + db := test.NewInMemoryFederationDatabase() b := struct { ProcessContext *process.ProcessContext }{ProcessContext: process.NewProcessContext()} @@ -65,220 +62,6 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } -func createDatabase() storage.Database { - return &fakeDatabase{ - pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), - pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), - associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - } -} - -type fakeDatabase struct { - storage.Database - dbMutex sync.Mutex - pendingPDUServers map[gomatrixserverlib.ServerName]struct{} - pendingEDUServers map[gomatrixserverlib.ServerName]struct{} - blacklistedServers map[gomatrixserverlib.ServerName]struct{} - pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent - pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU - associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} - associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} -} - -var nidMutex sync.Mutex -var nid = int64(0) - -func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal([]byte(js), &event); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingPDUs[&receipt] = &event - return &receipt, nil - } - - var edu gomatrixserverlib.EDU - if err := json.Unmarshal([]byte(js), &edu); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingEDUs[&receipt] = &edu - return &receipt, nil - } - - return nil, errors.New("Failed to determine type of json to store") -} - -func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - pduCount := 0 - pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) - if receipts, ok := d.associatedPDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingPDUs[receipt]; ok { - pdus[receipt] = event - pduCount++ - if pduCount == limit { - break - } - } - } - } - return pdus, nil -} - -func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - eduCount := 0 - edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) - if receipts, ok := d.associatedEDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingEDUs[receipt]; ok { - edus[receipt] = event - eduCount++ - if eduCount == limit { - break - } - } - } - } - return edus, nil -} - -func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingPDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedPDUs[destination]; !ok { - d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedPDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("PDU doesn't exist") - } -} - -func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingEDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedEDUs[destination]; !ok { - d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedEDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("EDU doesn't exist") - } -} - -func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if pdus, ok := d.associatedPDUs[serverName]; ok { - for _, receipt := range receipts { - delete(pdus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if edus, ok := d.associatedEDUs[serverName]; ok { - for _, receipt := range receipts { - delete(edus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingPDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingEDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers[serverName] = struct{}{} - return nil -} - -func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - delete(d.blacklistedServers, serverName) - return nil -} - -func (d *fakeDatabase) RemoveAllServersFromBlacklist() error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) - return nil -} - -func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - isBlacklisted := false - if _, ok := d.blacklistedServers[serverName]; ok { - isBlacklisted = true - } - - return isBlacklisted, nil -} - type stubFederationRoomServerAPI struct { rsapi.FederationRoomserverAPI } @@ -290,8 +73,10 @@ func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Cont type stubFederationClient struct { api.FederationClient - shouldTxSucceed bool - txCount atomic.Uint32 + shouldTxSucceed bool + shouldTxRelaySucceed bool + txCount atomic.Uint32 + txRelayCount atomic.Uint32 } func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { @@ -304,6 +89,16 @@ func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixse return gomatrixserverlib.RespSend{}, result } +func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) { + var result error + if !f.shouldTxRelaySucceed { + result = fmt.Errorf("relay transaction failed") + } + + f.txRelayCount.Add(1) + return gomatrixserverlib.EmptyResp{}, result +} + func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { t.Helper() content := `{"type":"m.room.message"}` @@ -319,15 +114,18 @@ func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} } -func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { +func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32, shouldTxSucceed bool, shouldTxRelaySucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase) fc := &stubFederationClient{ - shouldTxSucceed: shouldTxSucceed, - txCount: *atomic.NewUint32(0), + shouldTxSucceed: shouldTxSucceed, + shouldTxRelaySucceed: shouldTxRelaySucceed, + txCount: *atomic.NewUint32(0), + txRelayCount: *atomic.NewUint32(0), } rs := &stubFederationRoomServerAPI{} - stats := statistics.NewStatistics(db, failuresUntilBlacklist) + + stats := statistics.NewStatistics(db, failuresUntilBlacklist, failuresUntilAssumedOffline) signingInfo := []*gomatrixserverlib.SigningIdentity{ { KeyID: "ed21019:auto", @@ -344,7 +142,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -373,7 +171,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -402,7 +200,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -432,7 +230,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -462,7 +260,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -513,7 +311,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -564,7 +362,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -596,7 +394,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -628,7 +426,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -662,7 +460,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -696,7 +494,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -730,8 +528,8 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -747,7 +545,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -781,8 +579,8 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -801,7 +599,7 @@ func TestSendPDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -845,7 +643,7 @@ func TestSendEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -889,7 +687,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -940,7 +738,7 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -978,7 +776,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { destination := gomatrixserverlib.ServerName("remotehost") destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. defer close() defer func() { @@ -1023,8 +821,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) assert.NoError(t, dbErrPDU) @@ -1038,3 +836,147 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) }) } + +func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + failuresUntilAssumedOffline := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + relayServers := []gomatrixserverlib.ServerName{"relayserver"} + queues.statistics.ForServer(destination).AddRelayServers(relayServers) + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + if fc.txRelayCount.Load() == 1 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load()) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) + assert.Equal(t, true, assumedOffline) +} + +func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + failuresUntilAssumedOffline := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + relayServers := []gomatrixserverlib.ServerName{"relayserver"} + queues.statistics.ForServer(destination).AddRelayServers(relayServers) + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + if fc.txRelayCount.Load() == 1 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load()) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) + assert.Equal(t, true, assumedOffline) +} diff --git a/federationapi/routing/profile_test.go b/federationapi/routing/profile_test.go new file mode 100644 index 000000000..763656081 --- /dev/null +++ b/federationapi/routing/profile_test.go @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/hex" + "io" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" +) + +type fakeUserAPI struct { + userAPI.FederationUserAPI +} + +func (u *fakeUserAPI) QueryProfile(ctx context.Context, req *userAPI.QueryProfileRequest, res *userAPI.QueryProfileResponse) error { + return nil +} + +func TestHandleQueryProfile(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedClient := fakeFedClient{} + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + userapi := fakeUserAPI{} + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin))) + type queryContent struct{} + content := queryContent{} + err := req.SetContent(content) + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + // vars := map[string]string{"room_alias": "#room:server"} + w := httptest.NewRecorder() + // httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + data, _ := io.ReadAll(res.Body) + println(string(data)) + assert.Equal(t, 200, res.StatusCode) + }) +} diff --git a/federationapi/routing/query_test.go b/federationapi/routing/query_test.go new file mode 100644 index 000000000..21f35bf0c --- /dev/null +++ b/federationapi/routing/query_test.go @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/hex" + "io" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedclient "github.com/matrix-org/dendrite/federationapi/api" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" +) + +type fakeFedClient struct { + fedclient.FederationClient +} + +func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { + return +} + +func TestHandleQueryDirectory(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedClient := fakeFedClient{} + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + userapi := fakeUserAPI{} + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server")) + type queryContent struct{} + content := queryContent{} + err := req.SetContent(content) + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + // vars := map[string]string{"room_alias": "#room:server"} + w := httptest.NewRecorder() + // httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + data, _ := io.ReadAll(res.Body) + println(string(data)) + assert.Equal(t, 200, res.StatusCode) + }) +} diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 04eb3d067..5eb30c6ec 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -41,6 +41,12 @@ import ( "github.com/sirupsen/logrus" ) +const ( + SendRouteName = "Send" + QueryDirectoryRouteName = "QueryDirectory" + QueryProfileRouteName = "QueryProfile" +) + // Setup registers HTTP handlers with the given ServeMux. // The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly // path unescape twice (once from the router, once from MakeFedAPI). We need to have this enabled @@ -68,7 +74,7 @@ func Setup( if base.EnableMetrics { prometheus.MustRegister( - pduCountTotal, eduCountTotal, + internal.PDUCountTotal, internal.EDUCountTotal, ) } @@ -138,7 +144,7 @@ func Setup( cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer, ) }, - )).Methods(http.MethodPut, http.MethodOptions) + )).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName) v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, @@ -248,7 +254,7 @@ func Setup( httpReq, federation, cfg, rsAPI, fsAPI, ) }, - )).Methods(http.MethodGet) + )).Methods(http.MethodGet).Name(QueryDirectoryRouteName) v1fedmux.Handle("/query/profile", MakeFedAPI( "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, @@ -257,7 +263,7 @@ func Setup( httpReq, userAPI, cfg, ) }, - )).Methods(http.MethodGet) + )).Methods(http.MethodGet).Name(QueryProfileRouteName) v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index a146d85bd..67b513c90 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -17,26 +17,20 @@ package routing import ( "context" "encoding/json" - "fmt" "net/http" "sync" "time" - "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" - "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" - syncTypes "github.com/matrix-org/dendrite/syncapi/types" ) const ( @@ -56,26 +50,6 @@ const ( MetricsWorkMissingPrevEvents = "missing_prev_events" ) -var ( - pduCountTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "recv_pdus", - Help: "Number of incoming PDUs from remote servers with labels for success", - }, - []string{"status"}, // 'success' or 'total' - ) - eduCountTotal = prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "recv_edus", - Help: "Number of incoming EDUs from remote servers", - }, - ) -) - var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse // Send implements /_matrix/federation/v1/send/{txnID} @@ -123,18 +97,6 @@ func Send( defer close(ch) defer inFlightTxnsPerOrigin.Delete(index) - t := txnReq{ - rsAPI: rsAPI, - keys: keys, - ourServerName: cfg.Matrix.ServerName, - federation: federation, - servers: servers, - keyAPI: keyAPI, - roomsMu: mu, - producer: producer, - inboundPresenceEnabled: cfg.Matrix.Presence.EnableInbound, - } - var txnEvents struct { PDUs []json.RawMessage `json:"pdus"` EDUs []gomatrixserverlib.EDU `json:"edus"` @@ -155,16 +117,23 @@ func Send( } } - // TODO: Really we should have a function to convert FederationRequest to txnReq - t.PDUs = txnEvents.PDUs - t.EDUs = txnEvents.EDUs - t.Origin = request.Origin() - t.TransactionID = txnID - t.Destination = cfg.Matrix.ServerName + t := internal.NewTxnReq( + rsAPI, + keyAPI, + cfg.Matrix.ServerName, + keys, + mu, + producer, + cfg.Matrix.Presence.EnableInbound, + txnEvents.PDUs, + txnEvents.EDUs, + request.Origin(), + txnID, + cfg.Matrix.ServerName) util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) - resp, jsonErr := t.processTransaction(httpReq.Context()) + resp, jsonErr := t.ProcessTransaction(httpReq.Context()) if jsonErr != nil { util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") return *jsonErr @@ -181,283 +150,3 @@ func Send( ch <- res return res } - -type txnReq struct { - gomatrixserverlib.Transaction - rsAPI api.FederationRoomserverAPI - keyAPI keyapi.FederationKeyAPI - ourServerName gomatrixserverlib.ServerName - keys gomatrixserverlib.JSONVerifier - federation txnFederationClient - roomsMu *internal.MutexByRoom - servers federationAPI.ServersInRoomProvider - producer *producers.SyncAPIProducer - inboundPresenceEnabled bool -} - -// A subset of FederationClient functionality that txn requires. Useful for testing. -type txnFederationClient interface { - LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - res gomatrixserverlib.RespState, err error, - ) - LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, - roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) -} - -func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - t.processEDUs(ctx) - }() - - results := make(map[string]gomatrixserverlib.PDUResult) - roomVersions := make(map[string]gomatrixserverlib.RoomVersion) - getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { - if v, ok := roomVersions[roomID]; ok { - return v - } - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) - return "" - } - roomVersions[roomID] = verRes.RoomVersion - return verRes.RoomVersion - } - - for _, pdu := range t.PDUs { - pduCountTotal.WithLabelValues("total").Inc() - var header struct { - RoomID string `json:"room_id"` - } - if err := json.Unmarshal(pdu, &header); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") - // We don't know the event ID at this point so we can't return the - // failure in the PDU results - continue - } - roomVersion := getRoomVersion(header.RoomID) - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) - if err != nil { - if _, ok := err.(gomatrixserverlib.BadJSONError); ok { - // Room version 6 states that homeservers should strictly enforce canonical JSON - // on PDUs. - // - // This enforces that the entire transaction is rejected if a single bad PDU is - // sent. It is unclear if this is the correct behaviour or not. - // - // See https://github.com/matrix-org/synapse/issues/7543 - return nil, &util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON("PDU contains bad JSON"), - } - } - util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) - continue - } - if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { - continue - } - if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: "Forbidden by server ACLs", - } - continue - } - if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - continue - } - - // pass the event to the roomserver which will do auth checks - // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently - // discarded by the caller of this function - if err = api.SendEvents( - ctx, - t.rsAPI, - api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - event.Headered(roomVersion), - }, - t.Destination, - t.Origin, - api.DoNotSendToOtherServers, - nil, - true, - ); err != nil { - util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - continue - } - - results[event.EventID()] = gomatrixserverlib.PDUResult{} - pduCountTotal.WithLabelValues("success").Inc() - } - - wg.Wait() - return &gomatrixserverlib.RespSend{PDUs: results}, nil -} - -// nolint:gocyclo -func (t *txnReq) processEDUs(ctx context.Context) { - for _, e := range t.EDUs { - eduCountTotal.Inc() - switch e.Type { - case gomatrixserverlib.MTyping: - // https://matrix.org/docs/spec/server_server/latest#typing-notifications - var typingPayload struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` - Typing bool `json:"typing"` - } - if err := json.Unmarshal(e.Content, &typingPayload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") - continue - } - if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") - } - case gomatrixserverlib.MDirectToDevice: - // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema - var directPayload gomatrixserverlib.ToDeviceMessage - if err := json.Unmarshal(e.Content, &directPayload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") - continue - } - if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - for userID, byUser := range directPayload.Messages { - for deviceID, message := range byUser { - // TODO: check that the user and the device actually exist here - if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { - sentry.CaptureException(err) - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "sender": directPayload.Sender, - "user_id": userID, - "device_id": deviceID, - }).Error("Failed to send send-to-device event to JetStream") - } - } - } - case gomatrixserverlib.MDeviceListUpdate: - if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { - sentry.CaptureException(err) - util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") - } - case gomatrixserverlib.MReceipt: - // https://matrix.org/docs/spec/server_server/r0.1.4#receipts - payload := map[string]types.FederationReceiptMRead{} - - if err := json.Unmarshal(e.Content, &payload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") - continue - } - - for roomID, receipt := range payload { - for userID, mread := range receipt.User { - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") - continue - } - if t.Origin != domain { - util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) - continue - } - if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "sender": t.Origin, - "user_id": userID, - "room_id": roomID, - "events": mread.EventIDs, - }).Error("Failed to send receipt event to JetStream") - continue - } - } - } - case types.MSigningKeyUpdate: - if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil { - sentry.CaptureException(err) - logrus.WithError(err).Errorf("Failed to process signing key update") - } - case gomatrixserverlib.MPresence: - if t.inboundPresenceEnabled { - if err := t.processPresence(ctx, e); err != nil { - logrus.WithError(err).Errorf("Failed to process presence update") - } - } - default: - util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") - } - } -} - -// processPresence handles m.receipt events -func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error { - payload := types.Presence{} - if err := json.Unmarshal(e.Content, &payload); err != nil { - return err - } - for _, content := range payload.Push { - if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - presence, ok := syncTypes.PresenceFromString(content.Presence) - if !ok { - continue - } - if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil { - return err - } - } - return nil -} - -// processReceiptEvent sends receipt events to JetStream -func (t *txnReq) processReceiptEvent(ctx context.Context, - userID, roomID, receiptType string, - timestamp gomatrixserverlib.Timestamp, - eventIDs []string, -) error { - if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { - return nil - } else if serverName == t.ourServerName { - return nil - } else if serverName != t.Origin { - return nil - } - // store every event - for _, eventID := range eventIDs { - if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { - return fmt.Errorf("unable to set receipt event: %w", err) - } - } - - return nil -} diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index b8bfe0221..d7feee0e5 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -1,552 +1,87 @@ -package routing +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test import ( - "context" + "encoding/hex" "encoding/json" - "fmt" + "net/http/httptest" "testing" - "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/api" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") - testDestination = gomatrixserverlib.ServerName("white.orchard") + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") ) -var ( - testRoomVersion = gomatrixserverlib.RoomVersionV1 - testData = []json.RawMessage{ - []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), - // messages - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), - } - testEvents = []*gomatrixserverlib.HeaderedEvent{} - testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) -) +type sendContent struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` +} -func init() { - for _, j := range testData { - e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) +func TestHandleSend(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedapi := fedAPI.NewInternalAPI(base, nil, nil, nil, nil, true) + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, nil, nil, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234") + content := sendContent{} + err := req.SetContent(content) if err != nil { - panic("cannot load test data: " + err.Error()) + t.Fatalf("Error: %s", err.Error()) } - h := e.Headered(testRoomVersion) - testEvents = append(testEvents, h) - if e.StateKey() != nil { - testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: e.Type(), - StateKey: *e.StateKey(), - }] = h + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) } - } + vars := map[string]string{"txnID": "1234"} + w := httptest.NewRecorder() + httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + assert.Equal(t, 200, res.StatusCode) + }) } - -type testRoomserverAPI struct { - api.RoomserverInternalAPITrace - inputRoomEvents []api.InputRoomEvent - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse -} - -func (t *testRoomserverAPI) InputRoomEvents( - ctx context.Context, - request *api.InputRoomEventsRequest, - response *api.InputRoomEventsResponse, -) error { - t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) - for _, ire := range request.InputRoomEvents { - fmt.Println("InputRoomEvents: ", ire.Event.EventID()) - } - return nil -} - -// Query the latest events and state for a room from the room server. -func (t *testRoomserverAPI) QueryLatestEventsAndState( - ctx context.Context, - request *api.QueryLatestEventsAndStateRequest, - response *api.QueryLatestEventsAndStateResponse, -) error { - r := t.queryLatestEventsAndState(request) - response.RoomExists = r.RoomExists - response.RoomVersion = testRoomVersion - response.LatestEvents = r.LatestEvents - response.StateEvents = r.StateEvents - response.Depth = r.Depth - return nil -} - -// Query the state after a list of events in a room from the room server. -func (t *testRoomserverAPI) QueryStateAfterEvents( - ctx context.Context, - request *api.QueryStateAfterEventsRequest, - response *api.QueryStateAfterEventsResponse, -) error { - response.RoomVersion = testRoomVersion - res := t.queryStateAfterEvents(request) - response.PrevEventsExist = res.PrevEventsExist - response.RoomExists = res.RoomExists - response.StateEvents = res.StateEvents - return nil -} - -// Query a list of events by event ID. -func (t *testRoomserverAPI) QueryEventsByID( - ctx context.Context, - request *api.QueryEventsByIDRequest, - response *api.QueryEventsByIDResponse, -) error { - res := t.queryEventsByID(request) - response.Events = res.Events - return nil -} - -// Query if a server is joined to a room -func (t *testRoomserverAPI) QueryServerJoinedToRoom( - ctx context.Context, - request *api.QueryServerJoinedToRoomRequest, - response *api.QueryServerJoinedToRoomResponse, -) error { - response.RoomExists = true - response.IsInRoom = true - return nil -} - -// Asks for the room version for a given room. -func (t *testRoomserverAPI) QueryRoomVersionForRoom( - ctx context.Context, - request *api.QueryRoomVersionForRoomRequest, - response *api.QueryRoomVersionForRoomResponse, -) error { - response.RoomVersion = testRoomVersion - return nil -} - -func (t *testRoomserverAPI) QueryServerBannedFromRoom( - ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, -) error { - res.Banned = false - return nil -} - -type txnFedClient struct { - state map[string]gomatrixserverlib.RespState // event_id to response - stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response - getEvent map[string]gomatrixserverlib.Transaction // event_id to response - getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) -} - -func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - res gomatrixserverlib.RespState, err error, -) { - fmt.Println("testFederationClient.LookupState", eventID) - r, ok := c.state[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /state for event %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { - fmt.Println("testFederationClient.LookupStateIDs", eventID) - r, ok := c.stateIDs[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { - fmt.Println("testFederationClient.GetEvent", eventID) - r, ok := c.getEvent[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, - roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { - return c.getMissingEvents(missing) -} - -func mustCreateTransaction(rsAPI api.FederationRoomserverAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { - t := &txnReq{ - rsAPI: rsAPI, - keys: &test.NopJSONVerifier{}, - federation: fedClient, - roomsMu: internal.NewMutexByRoom(), - } - t.PDUs = pdus - t.Origin = testOrigin - t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - t.Destination = testDestination - return t -} - -func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) { - res, err := txn.processTransaction(context.Background()) - if err != nil { - t.Errorf("txn.processTransaction returned an error: %v", err) - return - } - if len(res.PDUs) != len(txn.PDUs) { - t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) - return - } -NextPDU: - for eventID, result := range res.PDUs { - if result.Error == "" { - continue - } - for _, eventIDWantError := range pdusWithErrors { - if eventID == eventIDWantError { - break NextPDU - } - } - t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) - } -} - -/* -func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []*gomatrixserverlib.HeaderedEvent) { -NextTuple: - for _, t := range tuples { - for _, o := range omitTuples { - if t == o { - break NextTuple - } - } - h, ok := testStateEvents[t] - if ok { - result = append(result, h) - } - } - return -} -*/ - -func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { - for _, g := range got { - fmt.Println("GOT ", g.Event.EventID()) - } - if len(got) != len(want) { - t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) - return - } - for i := range got { - if got[i].Event.EventID() != want[i].EventID() { - t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) - } - } -} - -// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on -// to the roomserver. It's the most basic test possible. -func TestBasicTransaction(t *testing.T) { - rsAPI := &testRoomserverAPI{} - pdus := []json.RawMessage{ - testData[len(testData)-1], // a message event - } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) -} - -// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver -// as it does the auth check. -func TestTransactionFailAuthChecks(t *testing.T) { - rsAPI := &testRoomserverAPI{} - pdus := []json.RawMessage{ - testData[len(testData)-1], // a message event - } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, []string{}) - // expect message to be sent to the roomserver - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) -} - -// The purpose of this test is to make sure that when an event is received for which we do not know the prev_events, -// we request them from /get_missing_events. It works by setting PrevEventsExist=false in the roomserver query response, -// resulting in a call to /get_missing_events which returns the missing prev event. Both events should be processed in -// topological order and sent to the roomserver. -/* -func TestTransactionFetchMissingPrevEvents(t *testing.T) { - haveEvent := testEvents[len(testEvents)-3] - prevEvent := testEvents[len(testEvents)-2] - inputEvent := testEvents[len(testEvents)-1] - - var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions - rsAPI = &testRoomserverAPI{ - queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { - res := api.QueryEventsByIDResponse{} - for _, ev := range testEvents { - for _, id := range req.EventIDs { - if ev.EventID() == id { - res.Events = append(res.Events, ev) - } - } - } - return res - }, - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: true, - StateEvents: testEvents[:5], - } - }, - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - missingPrevEvent := []string{"missing_prev_event"} - if len(req.PrevEventIDs) == 1 { - switch req.PrevEventIDs[0] { - case haveEvent.EventID(): - missingPrevEvent = []string{} - case prevEvent.EventID(): - // we only have this event if we've been send prevEvent - if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() { - missingPrevEvent = []string{} - } - } - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, - queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { - return api.QueryLatestEventsAndStateResponse{ - RoomExists: true, - Depth: haveEvent.Depth(), - LatestEvents: []gomatrixserverlib.EventReference{ - haveEvent.EventReference(), - }, - StateEvents: fromStateTuples(req.StateToFetch, nil), - } - }, - } - - cli := &txnFedClient{ - getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { - if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) { - t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, haveEvent.EventID()) - } - if !reflect.DeepEqual(missing.LatestEvents, []string{inputEvent.EventID()}) { - t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, inputEvent.EventID()) - } - return gomatrixserverlib.RespMissingEvents{ - Events: []*gomatrixserverlib.Event{ - prevEvent.Unwrap(), - }, - }, nil - }, - } - - pdus := []json.RawMessage{ - inputEvent.JSON(), - } - txn := mustCreateTransaction(rsAPI, cli, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent}) -} - -// The purpose of this test is to check that when there are missing prev_events and we still haven't been able to fill -// in the hole with /get_missing_events that the state BEFORE the events we want to persist is fetched via /state_ids -// and /event. It works by setting PrevEventsExist=false in the roomserver query response, resulting in -// a call to /get_missing_events which returns 1 out of the 2 events it needs to fill in the gap. Synapse and Dendrite -// both give up after 1x /get_missing_events call, relying on requesting the state AFTER the missing event in order to -// continue. The DAG looks something like: -// FE GME TXN -// A ---> B ---> C ---> D -// TXN=event in the txn, GME=response to /get_missing_events, FE=roomserver's forward extremity. Should result in: -// - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event. -// - state resolution is done to check C is allowed. -// This results in B being sent as an outlier FIRST, then C,D. -func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { - eventA := testEvents[len(testEvents)-5] - // this is also len(testEvents)-4 - eventB := testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }] - eventC := testEvents[len(testEvents)-3] - eventD := testEvents[len(testEvents)-2] - fmt.Println("a:", eventA.EventID()) - fmt.Println("b:", eventB.EventID()) - fmt.Println("c:", eventC.EventID()) - fmt.Println("d:", eventD.EventID()) - var rsAPI *testRoomserverAPI - rsAPI = &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - omitTuples := []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }, - } - askingForEvent := req.PrevEventIDs[0] - haveEventB := false - haveEventC := false - for _, ev := range rsAPI.inputRoomEvents { - switch ev.Event.EventID() { - case eventB.EventID(): - haveEventB = true - omitTuples = nil // include event B now - case eventC.EventID(): - haveEventC = true - } - } - prevEventExists := false - if askingForEvent == eventC.EventID() { - prevEventExists = haveEventC - } else if askingForEvent == eventB.EventID() { - prevEventExists = haveEventB - } - var stateEvents []*gomatrixserverlib.HeaderedEvent - if prevEventExists { - stateEvents = fromStateTuples(req.StateToFetch, omitTuples) - } - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: prevEventExists, - RoomExists: true, - StateEvents: stateEvents, - } - }, - - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - askingForEvent := req.PrevEventIDs[0] - haveEventB := false - haveEventC := false - for _, ev := range rsAPI.inputRoomEvents { - switch ev.Event.EventID() { - case eventB.EventID(): - haveEventB = true - case eventC.EventID(): - haveEventC = true - } - } - prevEventExists := false - if askingForEvent == eventC.EventID() { - prevEventExists = haveEventC - } else if askingForEvent == eventB.EventID() { - prevEventExists = haveEventB - } - - var missingPrevEvent []string - if !prevEventExists { - missingPrevEvent = []string{"test"} - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, - - queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { - omitTuples := []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}, - } - return api.QueryLatestEventsAndStateResponse{ - RoomExists: true, - Depth: eventA.Depth(), - LatestEvents: []gomatrixserverlib.EventReference{ - eventA.EventReference(), - }, - StateEvents: fromStateTuples(req.StateToFetch, omitTuples), - } - }, - queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { - var res api.QueryEventsByIDResponse - fmt.Println("queryEventsByID ", req.EventIDs) - for _, wantEventID := range req.EventIDs { - for _, ev := range testStateEvents { - // roomserver is missing the power levels event unless it's been sent to us recently as an outlier - if wantEventID == eventB.EventID() { - fmt.Println("Asked for pl event") - for _, inEv := range rsAPI.inputRoomEvents { - fmt.Println("recv ", inEv.Event.EventID()) - if inEv.Event.EventID() == wantEventID { - res.Events = append(res.Events, inEv.Event) - break - } - } - continue - } - if ev.EventID() == wantEventID { - res.Events = append(res.Events, ev) - } - } - } - return res - }, - } - // /state_ids for event B returns every state event but B (it's the state before) - var authEventIDs []string - var stateEventIDs []string - for _, ev := range testStateEvents { - if ev.EventID() == eventB.EventID() { - continue - } - // state res checks what auth events you give it, and this isn't a valid auth event - if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility { - authEventIDs = append(authEventIDs, ev.EventID()) - } - stateEventIDs = append(stateEventIDs, ev.EventID()) - } - cli := &txnFedClient{ - stateIDs: map[string]gomatrixserverlib.RespStateIDs{ - eventB.EventID(): { - StateEventIDs: stateEventIDs, - AuthEventIDs: authEventIDs, - }, - }, - // /event for event B returns it - getEvent: map[string]gomatrixserverlib.Transaction{ - eventB.EventID(): { - PDUs: []json.RawMessage{ - eventB.JSON(), - }, - }, - }, - // /get_missing_events should be done exactly once - getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { - if !reflect.DeepEqual(missing.EarliestEvents, []string{eventA.EventID()}) { - t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, eventA.EventID()) - } - if !reflect.DeepEqual(missing.LatestEvents, []string{eventD.EventID()}) { - t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, eventD.EventID()) - } - // just return event C, not event B so /state_ids logic kicks in as there will STILL be missing prev_events - return gomatrixserverlib.RespMissingEvents{ - Events: []*gomatrixserverlib.Event{ - eventC.Unwrap(), - }, - }, nil - }, - } - - pdus := []json.RawMessage{ - eventD.JSON(), - } - txn := mustCreateTransaction(rsAPI, cli, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) -} -*/ diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 0a44375c6..866c09336 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -1,6 +1,7 @@ package statistics import ( + "context" "math" "math/rand" "sync" @@ -28,14 +29,24 @@ type Statistics struct { // just blacklist the host altogether? The backoff is exponential, // so the max time here to attempt is 2**failures seconds. FailuresUntilBlacklist uint32 + + // How many times should we tolerate consecutive failures before we + // mark the destination as offline. At this point we should attempt + // to send messages to the user's async relay servers if we know them. + FailuresUntilAssumedOffline uint32 } -func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics { +func NewStatistics( + db storage.Database, + failuresUntilBlacklist uint32, + failuresUntilAssumedOffline uint32, +) Statistics { return Statistics{ - DB: db, - FailuresUntilBlacklist: failuresUntilBlacklist, - backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), - servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), + DB: db, + FailuresUntilBlacklist: failuresUntilBlacklist, + FailuresUntilAssumedOffline: failuresUntilAssumedOffline, + backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), } } @@ -50,8 +61,9 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS if !found { s.mutex.Lock() server = &ServerStatistics{ - statistics: s, - serverName: serverName, + statistics: s, + serverName: serverName, + knownRelayServers: []gomatrixserverlib.ServerName{}, } s.servers[serverName] = server s.mutex.Unlock() @@ -61,24 +73,49 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS } else { server.blacklisted.Store(blacklisted) } + assumedOffline, err := s.DB.IsServerAssumedOffline(context.Background(), serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName) + } else { + server.assumedOffline.Store(assumedOffline) + } + + knownRelayServers, err := s.DB.P2PGetRelayServersForServer(context.Background(), serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName) + } else { + server.relayMutex.Lock() + server.knownRelayServers = knownRelayServers + server.relayMutex.Unlock() + } } return server } +type SendMethod uint8 + +const ( + SendDirect SendMethod = iota + SendViaRelay +) + // ServerStatistics contains information about our interactions with a // remote federated host, e.g. how many times we were successful, how // many times we failed etc. It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - successCounter atomic.Uint32 // how many times have we succeeded? - backoffNotifier func() // notifies destination queue when backoff completes - notifierMutex sync.Mutex + statistics *Statistics // + serverName gomatrixserverlib.ServerName // + blacklisted atomic.Bool // is the node blacklisted + assumedOffline atomic.Bool // is the node assumed to be offline + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes + notifierMutex sync.Mutex + knownRelayServers []gomatrixserverlib.ServerName + relayMutex sync.Mutex } const maxJitterMultiplier = 1.4 @@ -113,13 +150,19 @@ func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) { // attempt, which increases the sent counter and resets the idle and // failure counters. If a host was blacklisted at this point then // we will unblacklist it. -func (s *ServerStatistics) Success() { +// `relay` specifies whether the success was to the actual destination +// or one of their relay servers. +func (s *ServerStatistics) Success(method SendMethod) { s.cancel() s.backoffCount.Store(0) - s.successCounter.Inc() - if s.statistics.DB != nil { - if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + // NOTE : Sending to the final destination vs. a relay server has + // slightly different semantics. + if method == SendDirect { + s.successCounter.Inc() + if s.blacklisted.Load() && s.statistics.DB != nil { + if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + } } } } @@ -139,7 +182,18 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { // start a goroutine which will wait out the backoff and // unset the backoffStarted flag when done. if s.backoffStarted.CompareAndSwap(false, true) { - if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist { + backoffCount := s.backoffCount.Inc() + + if backoffCount >= s.statistics.FailuresUntilAssumedOffline { + s.assumedOffline.CompareAndSwap(false, true) + if s.statistics.DB != nil { + if err := s.statistics.DB.SetServerAssumedOffline(context.Background(), s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName) + } + } + } + + if backoffCount >= s.statistics.FailuresUntilBlacklist { s.blacklisted.Store(true) if s.statistics.DB != nil { if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { @@ -157,13 +211,21 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { s.backoffUntil.Store(until) s.statistics.backoffMutex.Lock() - defer s.statistics.backoffMutex.Unlock() s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished) + s.statistics.backoffMutex.Unlock() } return s.backoffUntil.Load().(time.Time), false } +// MarkServerAlive removes the assumed offline and blacklisted statuses from this server. +// Returns whether the server was blacklisted before this point. +func (s *ServerStatistics) MarkServerAlive() bool { + s.removeAssumedOffline() + wasBlacklisted := s.removeBlacklist() + return wasBlacklisted +} + // ClearBackoff stops the backoff timer for this destination if it is running // and removes the timer from the backoffTimers map. func (s *ServerStatistics) ClearBackoff() { @@ -191,13 +253,13 @@ func (s *ServerStatistics) backoffFinished() { } // BackoffInfo returns information about the current or previous backoff. -// Returns the last backoffUntil time and whether the server is currently blacklisted or not. -func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) { +// Returns the last backoffUntil time. +func (s *ServerStatistics) BackoffInfo() *time.Time { until, ok := s.backoffUntil.Load().(time.Time) if ok { - return &until, s.blacklisted.Load() + return &until } - return nil, s.blacklisted.Load() + return nil } // Blacklisted returns true if the server is blacklisted and false @@ -206,10 +268,33 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } -// RemoveBlacklist removes the blacklisted status from the server. -func (s *ServerStatistics) RemoveBlacklist() { +// AssumedOffline returns true if the server is assumed offline and false +// otherwise. +func (s *ServerStatistics) AssumedOffline() bool { + return s.assumedOffline.Load() +} + +// removeBlacklist removes the blacklisted status from the server. +// Returns whether the server was blacklisted. +func (s *ServerStatistics) removeBlacklist() bool { + var wasBlacklisted bool + + if s.Blacklisted() { + wasBlacklisted = true + _ = s.statistics.DB.RemoveServerFromBlacklist(s.serverName) + } s.cancel() s.backoffCount.Store(0) + + return wasBlacklisted +} + +// removeAssumedOffline removes the assumed offline status from the server. +func (s *ServerStatistics) removeAssumedOffline() { + if s.AssumedOffline() { + _ = s.statistics.DB.RemoveServerAssumedOffline(context.Background(), s.serverName) + } + s.assumedOffline.Store(false) } // SuccessCount returns the number of successful requests. This is @@ -217,3 +302,46 @@ func (s *ServerStatistics) RemoveBlacklist() { func (s *ServerStatistics) SuccessCount() uint32 { return s.successCounter.Load() } + +// KnownRelayServers returns the list of relay servers associated with this +// server. +func (s *ServerStatistics) KnownRelayServers() []gomatrixserverlib.ServerName { + s.relayMutex.Lock() + defer s.relayMutex.Unlock() + return s.knownRelayServers +} + +func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.ServerName) { + seenSet := make(map[gomatrixserverlib.ServerName]bool) + uniqueList := []gomatrixserverlib.ServerName{} + for _, srv := range relayServers { + if seenSet[srv] { + continue + } + seenSet[srv] = true + uniqueList = append(uniqueList, srv) + } + + err := s.statistics.DB.P2PAddRelayServersForServer(context.Background(), s.serverName, uniqueList) + if err != nil { + logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList) + return + } + + for _, newServer := range uniqueList { + alreadyKnown := false + knownRelayServers := s.KnownRelayServers() + for _, srv := range knownRelayServers { + if srv == newServer { + alreadyKnown = true + } + } + if !alreadyKnown { + { + s.relayMutex.Lock() + s.knownRelayServers = append(s.knownRelayServers, newServer) + s.relayMutex.Unlock() + } + } + } +} diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 6aa997f44..183b9aa0c 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -4,17 +4,26 @@ import ( "math" "testing" "time" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + FailuresUntilAssumedOffline = 3 + FailuresUntilBlacklist = 8 ) func TestBackoff(t *testing.T) { - stats := NewStatistics(nil, 7) + stats := NewStatistics(nil, FailuresUntilBlacklist, FailuresUntilAssumedOffline) server := ServerStatistics{ statistics: &stats, serverName: "test.com", } // Start by checking that counting successes works. - server.Success() + server.Success(SendDirect) if successes := server.SuccessCount(); successes != 1 { t.Fatalf("Expected success count 1, got %d", successes) } @@ -31,9 +40,8 @@ func TestBackoff(t *testing.T) { // side effects since a backoff is already in progress. If it does // then we'll fail. until, blacklisted := server.Failure() - - // Get the duration. - _, blacklist := server.BackoffInfo() + blacklist := server.Blacklisted() + assumedOffline := server.AssumedOffline() duration := time.Until(until) // Unset the backoff, or otherwise our next call will think that @@ -41,16 +49,43 @@ func TestBackoff(t *testing.T) { server.cancel() server.backoffStarted.Store(false) + if i >= stats.FailuresUntilAssumedOffline { + if !assumedOffline { + t.Fatalf("Backoff %d should have resulted in assuming the destination was offline but didn't", i) + } + } + + // Check if we should be assumed offline by now. + if i >= stats.FailuresUntilAssumedOffline { + if !assumedOffline { + t.Fatalf("Backoff %d should have resulted in assumed offline but didn't", i) + } else { + t.Logf("Backoff %d is assumed offline as expected", i) + } + } else { + if assumedOffline { + t.Fatalf("Backoff %d should not have resulted in assumed offline but did", i) + } else { + t.Logf("Backoff %d is not assumed offline as expected", i) + } + } + // Check if we should be blacklisted by now. if i >= stats.FailuresUntilBlacklist { if !blacklist { t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) } else if blacklist != blacklisted { - t.Fatalf("BackoffInfo and Failure returned different blacklist values") + t.Fatalf("Blacklisted and Failure returned different blacklist values") } else { t.Logf("Backoff %d is blacklisted as expected", i) continue } + } else { + if blacklist { + t.Fatalf("Backoff %d should not have resulted in blacklist but did", i) + } else { + t.Logf("Backoff %d is not blacklisted as expected", i) + } } // Check if the duration is what we expect. @@ -69,3 +104,14 @@ func TestBackoff(t *testing.T) { } } } + +func TestRelayServersListing(t *testing.T) { + stats := NewStatistics(test.NewInMemoryFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline) + server := ServerStatistics{statistics: &stats} + server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + relayServers := server.KnownRelayServers() + assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) + server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + relayServers = server.KnownRelayServers() + assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) +} diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 2b4d905fc..4f5300af1 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -20,11 +20,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/types" ) type Database interface { + P2PDatabase gomatrixserverlib.KeyDatabase UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) @@ -34,16 +35,16 @@ type Database interface { // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) - StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) + StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error) - GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) - GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) + GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) + GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error - AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error - CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error - CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error + CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error + CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) @@ -54,6 +55,18 @@ type Database interface { RemoveAllServersFromBlacklist() error IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) + // Adds the server to the list of assumed offline servers. + // If the server already exists in the table, nothing happens and returns success. + SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + // Removes the server from the list of assumed offline servers. + // If the server doesn't exist in the table, nothing happens and returns success. + RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + // Purges all entries from the assumed offline table. + RemoveAllServersAssumedOffline(ctx context.Context) error + // Gets whether the provided server is present in the table. + // If it is present, returns true. If not, returns false. + IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error) + AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) @@ -74,3 +87,21 @@ type Database interface { PurgeRoom(ctx context.Context, roomID string) error } + +type P2PDatabase interface { + // Stores the given list of servers as relay servers for the provided destination server. + // Providing duplicates will only lead to a single entry and won't lead to an error. + P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Get the list of relay servers associated with the provided destination server. + // If no entry exists in the table, an empty list is returned and does not result in an error. + P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + + // Deletes any entries for the provided destination server that match the provided relayServers list. + // If any of the provided servers don't match an entry, nothing happens and no error is returned. + P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Deletes all entries for the provided destination server. + // If the destination server doesn't exist in the table, nothing happens and no error is returned. + P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error +} diff --git a/federationapi/storage/postgres/assumed_offline_table.go b/federationapi/storage/postgres/assumed_offline_table.go new file mode 100644 index 000000000..5695d2e54 --- /dev/null +++ b/federationapi/storage/postgres/assumed_offline_table.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT PRIMARY KEY NOT NULL +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "TRUNCATE federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL}, + {&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL}, + {&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL}, + {&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL}, + }.Prepare(db) +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/postgres/relay_servers_table.go b/federationapi/storage/postgres/relay_servers_table.go new file mode 100644 index 000000000..f7267978f --- /dev/null +++ b/federationapi/storage/postgres/relay_servers_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayServersSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_relay_servers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The relay server name for a given destination + relay_server_name TEXT NOT NULL, + UNIQUE (server_name, relay_server_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx + ON federationsender_relay_servers (server_name); +` + +const insertRelayServersSQL = "" + + "INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectRelayServersSQL = "" + + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" + +const deleteRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)" + +const deleteAllRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" + +type relayServersStatements struct { + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + deleteRelayServersStmt *sql.Stmt + deleteAllRelayServersStmt *sql.Stmt +} + +func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) { + s = &relayServersStatements{ + db: db, + } + _, err = db.Exec(relayServersSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertRelayServersStmt, insertRelayServersSQL}, + {&s.selectRelayServersStmt, selectRelayServersSQL}, + {&s.deleteRelayServersStmt, deleteRelayServersSQL}, + {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, + }.Prepare(db) +} + +func (s *relayServersStatements) InsertRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + for _, relayServer := range relayServers { + stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { + return err + } + } + return nil +} + +func (s *relayServersStatements) SelectRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var relayServer string + if err = rows.Scan(&relayServer); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(relayServer)) + } + return result, nil +} + +func (s *relayServersStatements) DeleteRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers)) + return err +} + +func (s *relayServersStatements) DeleteAllRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index fe84e932e..b81f128e7 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -62,6 +62,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewPostgresAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } + relayServers, err := NewPostgresRelayServersTable(d.db) + if err != nil { + return nil, err + } inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) if err != nil { return nil, err @@ -104,6 +112,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, FederationInboundPeeks: inboundPeeks, FederationOutboundPeeks: outboundPeeks, NotaryServerKeysJSON: notaryJSON, diff --git a/federationapi/storage/shared/receipt/receipt.go b/federationapi/storage/shared/receipt/receipt.go new file mode 100644 index 000000000..b347269c1 --- /dev/null +++ b/federationapi/storage/shared/receipt/receipt.go @@ -0,0 +1,42 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// A Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. +// We don't actually export the NIDs but we need the caller to be able +// to pass them back so that we can clean up if the transaction sends +// successfully. + +package receipt + +import "fmt" + +// Receipt is a wrapper type used to represent a nid that corresponds to a unique row entry +// in some database table. +// The internal nid value cannot be modified after a Receipt has been created. +// This guarantees a receipt will always refer to the same table entry that it was created +// to represent. +type Receipt struct { + nid int64 +} + +func NewReceipt(nid int64) Receipt { + return Receipt{nid: nid} +} + +func (r *Receipt) GetNID() int64 { + return r.nid +} + +func (r *Receipt) String() string { + return fmt.Sprintf("%d", r.nid) +} diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 6cda55725..6769637bc 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -20,6 +20,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal/caching" @@ -37,6 +38,8 @@ type Database struct { FederationQueueJSON tables.FederationQueueJSON FederationJoinedHosts tables.FederationJoinedHosts FederationBlacklist tables.FederationBlacklist + FederationAssumedOffline tables.FederationAssumedOffline + FederationRelayServers tables.FederationRelayServers FederationOutboundPeeks tables.FederationOutboundPeeks FederationInboundPeeks tables.FederationInboundPeeks NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON @@ -44,22 +47,6 @@ type Database struct { ServerSigningKeys tables.FederationServerSigningKeys } -// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. -// We don't actually export the NIDs but we need the caller to be able -// to pass them back so that we can clean up if the transaction sends -// successfully. -type Receipt struct { - nid int64 -} - -func NewReceipt(nid int64) Receipt { - return Receipt{nid: nid} -} - -func (r *Receipt) String() string { - return fmt.Sprintf("%d", r.nid) -} - // UpdateRoom updates the joined hosts for a room and returns what the joined // hosts were before the update, or nil if this was a duplicate message. // This is called when we receive a message from kafka, so we pass in @@ -113,11 +100,18 @@ func (d *Database) GetJoinedHosts( // GetAllJoinedHosts returns the currently joined hosts for // all rooms known to the federation sender. // Returns an error if something goes wrong. -func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetJoinedHostsForRooms( + ctx context.Context, + roomIDs []string, + excludeSelf, + excludeBlacklisted bool, +) ([]gomatrixserverlib.ServerName, error) { servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err @@ -139,7 +133,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, // metadata entries. func (d *Database) StoreJSON( ctx context.Context, js string, -) (*Receipt, error) { +) (*receipt.Receipt, error) { var nid int64 var err error _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -149,18 +143,21 @@ func (d *Database) StoreJSON( if err != nil { return nil, fmt.Errorf("d.insertQueueJSON: %w", err) } - return &Receipt{ - nid: nid, - }, nil + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil } -func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) }) } -func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName) }) @@ -172,51 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error { }) } -func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { +func (d *Database) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } -func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName) + }) +} + +func (d *Database) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName) + }) +} + +func (d *Database) RemoveAllServersAssumedOffline( + ctx context.Context, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn) + }) +} + +func (d *Database) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName) +} + +func (d *Database) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers) + }) +} + +func (d *Database) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName) +} + +func (d *Database) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers) + }) +} + +func (d *Database) P2PRemoveAllRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName) + }) +} + +func (d *Database) AddOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { +func (d *Database) GetOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID, + peekID string, +) (*types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { +func (d *Database) GetOutboundPeeks( + ctx context.Context, + roomID string, +) ([]types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID) } -func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) AddInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { +func (d *Database) GetInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, +) (*types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { +func (d *Database) GetInboundPeeks( + ctx context.Context, + roomID string, +) ([]types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID) } -func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { +func (d *Database) UpdateNotaryKeys( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + serverKeys gomatrixserverlib.ServerKeys, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { validUntil := serverKeys.ValidUntilTS // Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid. @@ -251,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv } func (d *Database) GetNotaryKeys( - ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID, + ctx context.Context, + serverName gomatrixserverlib.ServerName, + optKeyIDs []gomatrixserverlib.KeyID, ) (sks []gomatrixserverlib.ServerKeys, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs) diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index be8355f31..cff1ade6f 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -22,6 +22,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{ func (d *Database) AssociateEDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, ) error { @@ -62,12 +63,12 @@ func (d *Database) AssociateEDUWithDestinations( var err error for destination := range destinations { err = d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire ) } return err @@ -81,10 +82,10 @@ func (d *Database) GetPendingEDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - edus map[*Receipt]*gomatrixserverlib.EDU, + edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error, ) { - edus = make(map[*Receipt]*gomatrixserverlib.EDU) + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) if err != nil { @@ -94,7 +95,8 @@ func (d *Database) GetPendingEDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok { - edus[&Receipt{nid}] = edu + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = edu } else { retrieve = append(retrieve, nid) } @@ -110,7 +112,8 @@ func (d *Database) GetPendingEDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - edus[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = &event d.Cache.StoreFederationQueuedEDU(nid, &event) } @@ -124,7 +127,7 @@ func (d *Database) GetPendingEDUs( func (d *Database) CleanEDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -132,7 +135,7 @@ func (d *Database) CleanEDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index da4cb979d..854e00553 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -30,17 +31,17 @@ import ( func (d *Database) AssociatePDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + dbReceipt *receipt.Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error for destination := range destinations { err = d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - "", // transaction ID - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table ) } return err @@ -54,7 +55,7 @@ func (d *Database) GetPendingPDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - events map[*Receipt]*gomatrixserverlib.HeaderedEvent, + events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error, ) { // Strictly speaking this doesn't need to be using the writer @@ -62,7 +63,7 @@ func (d *Database) GetPendingPDUs( // a guarantee of transactional isolation, it's actually useful // to know in SQLite mode that nothing else is trying to modify // the database. - events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent) + events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit) if err != nil { @@ -72,7 +73,8 @@ func (d *Database) GetPendingPDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok { - events[&Receipt{nid}] = event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = event } else { retrieve = append(retrieve, nid) } @@ -88,7 +90,8 @@ func (d *Database) GetPendingPDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - events[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = &event d.Cache.StoreFederationQueuedPDU(nid, &event) } @@ -103,7 +106,7 @@ func (d *Database) GetPendingPDUs( func (d *Database) CleanPDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -111,7 +114,7 @@ func (d *Database) CleanPDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/sqlite3/assumed_offline_table.go b/federationapi/storage/sqlite3/assumed_offline_table.go new file mode 100644 index 000000000..ff2afb4da --- /dev/null +++ b/federationapi/storage/sqlite3/assumed_offline_table.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT PRIMARY KEY NOT NULL +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL}, + {&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL}, + {&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL}, + {&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL}, + }.Prepare(db) +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/sqlite3/relay_servers_table.go b/federationapi/storage/sqlite3/relay_servers_table.go new file mode 100644 index 000000000..27c3cca2c --- /dev/null +++ b/federationapi/storage/sqlite3/relay_servers_table.go @@ -0,0 +1,148 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayServersSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_relay_servers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The relay server name for a given destination + relay_server_name TEXT NOT NULL, + UNIQUE (server_name, relay_server_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx + ON federationsender_relay_servers (server_name); +` + +const insertRelayServersSQL = "" + + "INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectRelayServersSQL = "" + + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" + +const deleteRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)" + +const deleteAllRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" + +type relayServersStatements struct { + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + // deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic + deleteAllRelayServersStmt *sql.Stmt +} + +func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) { + s = &relayServersStatements{ + db: db, + } + _, err = db.Exec(relayServersSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertRelayServersStmt, insertRelayServersSQL}, + {&s.selectRelayServersStmt, selectRelayServersSQL}, + {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, + }.Prepare(db) +} + +func (s *relayServersStatements) InsertRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + for _, relayServer := range relayServers { + stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { + return err + } + } + return nil +} + +func (s *relayServersStatements) SelectRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var relayServer string + if err = rows.Scan(&relayServer); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(relayServer)) + } + return result, nil +} + +func (s *relayServersStatements) DeleteRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1) + deleteStmt, err := s.db.Prepare(deleteSQL) + if err != nil { + return err + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + params := make([]interface{}, len(relayServers)+1) + params[0] = serverName + for i, v := range relayServers { + params[i+1] = v + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *relayServersStatements) DeleteAllRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index d13b5defc..1e7e41a2c 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -1,5 +1,4 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -61,6 +60,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } + relayServers, err := NewSQLiteRelayServersTable(d.db) + if err != nil { + return nil, err + } outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) if err != nil { return nil, err @@ -103,6 +110,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, FederationOutboundPeeks: outboundPeeks, FederationInboundPeeks: inboundPeeks, NotaryServerKeysJSON: notaryKeys, diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 5b57d40d4..1d2a13e81 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -6,14 +6,13 @@ import ( "testing" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/stretchr/testify/assert" - "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { @@ -246,3 +245,99 @@ func TestInboundPeeking(t *testing.T) { assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } + +func TestServersAssumedOffline(t *testing.T) { + server1 := gomatrixserverlib.ServerName("server1") + server2 := gomatrixserverlib.ServerName("server2") + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + + // Set server1 & server2 as assumed offline. + err := db.SetServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + err = db.SetServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + + // Ensure both servers are assumed offline. + isOffline, err := db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.True(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.True(t, isOffline) + + // Set server1 as not assumed offline. + err = db.RemoveServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + + // Ensure both servers have correct state. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.False(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.True(t, isOffline) + + // Re-set server1 as assumed offline. + err = db.SetServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + + // Ensure server1 is assumed offline. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.True(t, isOffline) + + err = db.RemoveAllServersAssumedOffline(context.Background()) + assert.Nil(t, err) + + // Ensure both servers have correct state. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.False(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.False(t, isOffline) + }) +} + +func TestRelayServersStored(t *testing.T) { + server := gomatrixserverlib.ServerName("server") + relayServer1 := gomatrixserverlib.ServerName("relayserver1") + relayServer2 := gomatrixserverlib.ServerName("relayserver2") + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + + err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + assert.Nil(t, err) + + relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Equal(t, relayServer1, relayServers[0]) + + err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Zero(t, len(relayServers)) + + err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2}) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Equal(t, relayServer1, relayServers[0]) + assert.Equal(t, relayServer2, relayServers[1]) + + err = db.P2PRemoveAllRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Zero(t, len(relayServers)) + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 2b36edb46..762504e45 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -49,6 +49,19 @@ type FederationQueueJSON interface { SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) } +type FederationQueueTransactions interface { + InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +type FederationTransactionJSON interface { + InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} + type FederationJoinedHosts interface { InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error @@ -66,6 +79,20 @@ type FederationBlacklist interface { DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } +type FederationAssumedOffline interface { + InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) + DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error +} + +type FederationRelayServers interface { + InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error +} + type FederationOutboundPeeks interface { InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) diff --git a/federationapi/storage/tables/relay_servers_table_test.go b/federationapi/storage/tables/relay_servers_table_test.go new file mode 100644 index 000000000..b41211551 --- /dev/null +++ b/federationapi/storage/tables/relay_servers_table_test.go @@ -0,0 +1,224 @@ +package tables_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + server1 = "server1" + server2 = "server2" + server3 = "server3" + server4 = "server4" +) + +type RelayServersDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.FederationRelayServers +} + +func mustCreateRelayServersTable( + t *testing.T, + dbType test.DBType, +) (database RelayServersDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.FederationRelayServers + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayServersTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayServersTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayServersDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func Equal(a, b []gomatrixserverlib.ServerName) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +func TestShouldInsertRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldInsertRelayServersWithDuplicates(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2} + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + // Insert the same list again, this shouldn't fail and should have no effect. + err = db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldGetRelayServersUnknownDestination(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + + // Query relay servers for a destination that doesn't exist in the table. + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, []gomatrixserverlib.ServerName{}) { + t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers) + } + }) +} + +func TestShouldDeleteCorrectRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + relayServers1 := []gomatrixserverlib.ServerName{server2, server3} + relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4} + + err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) + } + err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error()) + } + + expectedRelayServers := []gomatrixserverlib.ServerName{server3} + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldDeleteAllRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteAllRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) + } + + expectedRelayServers1 := []gomatrixserverlib.ServerName{} + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers1) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers) + } + relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} diff --git a/go.mod b/go.mod index a86dd2cb8..871e94eb3 100644 --- a/go.mod +++ b/go.mod @@ -22,9 +22,9 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 - github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 - github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f + github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb + github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.15 github.com/nats-io/nats-server/v2 v2.9.8 github.com/nats-io/nats.go v1.20.0 @@ -37,17 +37,17 @@ require ( github.com/prometheus/client_golang v1.13.0 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.1 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.1.0 + golang.org/x/crypto v0.5.0 golang.org/x/image v0.1.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e - golang.org/x/net v0.1.0 - golang.org/x/term v0.1.0 + golang.org/x/net v0.5.0 + golang.org/x/term v0.4.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 @@ -119,12 +119,12 @@ require ( github.com/prometheus/procfs v0.8.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect go.etcd.io/bbolt v1.3.6 // indirect golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 // indirect golang.org/x/mod v0.6.0 // indirect - golang.org/x/sys v0.1.0 // indirect - golang.org/x/text v0.4.0 // indirect + golang.org/x/sys v0.4.0 // indirect + golang.org/x/text v0.6.0 // indirect golang.org/x/time v0.1.0 // indirect golang.org/x/tools v0.2.0 // indirect google.golang.org/protobuf v1.28.1 // indirect diff --git a/go.sum b/go.sum index e5cd67bed..1ca3e8a80 100644 --- a/go.sum +++ b/go.sum @@ -348,16 +348,12 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45 h1:zGrmcm2M4F4f+zk5JXAkw3oHa/zXhOh5XVGBdl7GdPo= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 h1:P7me2oCmksST9B4+1I1nA+XrnDQwIqAWmy6ntQrXwc8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= -github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= -github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= -github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f h1:niRWEVkeeekpjxwnMhKn8PD0PUloDsNXP8W+Ez/co/M= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb h1:2L+ltfNKab56FoBBqAvbBLjoAbxwwoZie+B8d+Mp3JI= +github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= +github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= +github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -494,12 +490,13 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= @@ -543,8 +540,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= -golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -625,8 +622,8 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= +golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -701,12 +698,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.4.0 h1:O7UWfv5+A2qiuulQk30kVinPoMtoIPeVaKLEgLpVkvg= +golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -714,8 +711,9 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/log.go b/internal/log.go index da6e20418..9e8656c5b 100644 --- a/internal/log.go +++ b/internal/log.go @@ -101,6 +101,8 @@ func SetupPprof() { // SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded. func SetupStdLogging() { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() logrus.SetReportCaller(true) logrus.SetFormatter(&utcFormatter{ &logrus.TextFormatter{ diff --git a/internal/log_unix.go b/internal/log_unix.go index 8f34c320d..859427041 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -32,6 +32,8 @@ import ( // If something fails here it means that the logging was improperly configured, // so we just exit with the error func SetupHookLogging(hooks []config.LogrusHook, componentName string) { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() for _, hook := range hooks { // Check we received a proper logging level level, err := logrus.ParseLevel(hook.Level) @@ -85,8 +87,6 @@ func checkSyslogHookParams(params map[string]interface{}) { } func setupStdLogHook(level logrus.Level) { - levelLogAddedMu.Lock() - defer levelLogAddedMu.Unlock() if stdLevelLogAdded[level] { return } diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go new file mode 100644 index 000000000..95673fc14 --- /dev/null +++ b/internal/transactionrequest.go @@ -0,0 +1,356 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/federationapi/types" + keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/roomserver/api" + syncTypes "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" +) + +var ( + PDUCountTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "federationapi", + Name: "recv_pdus", + Help: "Number of incoming PDUs from remote servers with labels for success", + }, + []string{"status"}, // 'success' or 'total' + ) + EDUCountTotal = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "federationapi", + Name: "recv_edus", + Help: "Number of incoming EDUs from remote servers", + }, + ) +) + +type TxnReq struct { + gomatrixserverlib.Transaction + rsAPI api.FederationRoomserverAPI + keyAPI keyapi.FederationKeyAPI + ourServerName gomatrixserverlib.ServerName + keys gomatrixserverlib.JSONVerifier + roomsMu *MutexByRoom + producer *producers.SyncAPIProducer + inboundPresenceEnabled bool +} + +func NewTxnReq( + rsAPI api.FederationRoomserverAPI, + keyAPI keyapi.FederationKeyAPI, + ourServerName gomatrixserverlib.ServerName, + keys gomatrixserverlib.JSONVerifier, + roomsMu *MutexByRoom, + producer *producers.SyncAPIProducer, + inboundPresenceEnabled bool, + pdus []json.RawMessage, + edus []gomatrixserverlib.EDU, + origin gomatrixserverlib.ServerName, + transactionID gomatrixserverlib.TransactionID, + destination gomatrixserverlib.ServerName, +) TxnReq { + t := TxnReq{ + rsAPI: rsAPI, + keyAPI: keyAPI, + ourServerName: ourServerName, + keys: keys, + roomsMu: roomsMu, + producer: producer, + inboundPresenceEnabled: inboundPresenceEnabled, + } + + t.PDUs = pdus + t.EDUs = edus + t.Origin = origin + t.TransactionID = transactionID + t.Destination = destination + + return t +} + +func (t *TxnReq) ProcessTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if t.producer != nil { + t.processEDUs(ctx) + } + }() + + results := make(map[string]gomatrixserverlib.PDUResult) + roomVersions := make(map[string]gomatrixserverlib.RoomVersion) + getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { + if v, ok := roomVersions[roomID]; ok { + return v + } + verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} + verRes := api.QueryRoomVersionForRoomResponse{} + if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) + return "" + } + roomVersions[roomID] = verRes.RoomVersion + return verRes.RoomVersion + } + + for _, pdu := range t.PDUs { + PDUCountTotal.WithLabelValues("total").Inc() + var header struct { + RoomID string `json:"room_id"` + } + if err := json.Unmarshal(pdu, &header); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") + // We don't know the event ID at this point so we can't return the + // failure in the PDU results + continue + } + roomVersion := getRoomVersion(header.RoomID) + event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + if err != nil { + if _, ok := err.(gomatrixserverlib.BadJSONError); ok { + // Room version 6 states that homeservers should strictly enforce canonical JSON + // on PDUs. + // + // This enforces that the entire transaction is rejected if a single bad PDU is + // sent. It is unclear if this is the correct behaviour or not. + // + // See https://github.com/matrix-org/synapse/issues/7543 + return nil, &util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON("PDU contains bad JSON"), + } + } + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + continue + } + if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + continue + } + if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: "Forbidden by server ACLs", + } + continue + } + if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue + } + + // pass the event to the roomserver which will do auth checks + // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently + // discarded by the caller of this function + if err = api.SendEvents( + ctx, + t.rsAPI, + api.KindNew, + []*gomatrixserverlib.HeaderedEvent{ + event.Headered(roomVersion), + }, + t.Destination, + t.Origin, + api.DoNotSendToOtherServers, + nil, + true, + ); err != nil { + util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue + } + + results[event.EventID()] = gomatrixserverlib.PDUResult{} + PDUCountTotal.WithLabelValues("success").Inc() + } + + wg.Wait() + return &gomatrixserverlib.RespSend{PDUs: results}, nil +} + +// nolint:gocyclo +func (t *TxnReq) processEDUs(ctx context.Context) { + for _, e := range t.EDUs { + EDUCountTotal.Inc() + switch e.Type { + case gomatrixserverlib.MTyping: + // https://matrix.org/docs/spec/server_server/latest#typing-notifications + var typingPayload struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + Typing bool `json:"typing"` + } + if err := json.Unmarshal(e.Content, &typingPayload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") + continue + } + if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") + } + case gomatrixserverlib.MDirectToDevice: + // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema + var directPayload gomatrixserverlib.ToDeviceMessage + if err := json.Unmarshal(e.Content, &directPayload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") + continue + } + if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + for userID, byUser := range directPayload.Messages { + for deviceID, message := range byUser { + // TODO: check that the user and the device actually exist here + if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { + sentry.CaptureException(err) + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": directPayload.Sender, + "user_id": userID, + "device_id": deviceID, + }).Error("Failed to send send-to-device event to JetStream") + } + } + } + case gomatrixserverlib.MDeviceListUpdate: + if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { + sentry.CaptureException(err) + util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") + } + case gomatrixserverlib.MReceipt: + // https://matrix.org/docs/spec/server_server/r0.1.4#receipts + payload := map[string]types.FederationReceiptMRead{} + + if err := json.Unmarshal(e.Content, &payload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") + continue + } + + for roomID, receipt := range payload { + for userID, mread := range receipt.User { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") + continue + } + if t.Origin != domain { + util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + continue + } + if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": t.Origin, + "user_id": userID, + "room_id": roomID, + "events": mread.EventIDs, + }).Error("Failed to send receipt event to JetStream") + continue + } + } + } + case types.MSigningKeyUpdate: + if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil { + sentry.CaptureException(err) + logrus.WithError(err).Errorf("Failed to process signing key update") + } + case gomatrixserverlib.MPresence: + if t.inboundPresenceEnabled { + if err := t.processPresence(ctx, e); err != nil { + logrus.WithError(err).Errorf("Failed to process presence update") + } + } + default: + util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") + } + } +} + +// processPresence handles m.receipt events +func (t *TxnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error { + payload := types.Presence{} + if err := json.Unmarshal(e.Content, &payload); err != nil { + return err + } + for _, content := range payload.Push { + if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + presence, ok := syncTypes.PresenceFromString(content.Presence) + if !ok { + continue + } + if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil { + return err + } + } + return nil +} + +// processReceiptEvent sends receipt events to JetStream +func (t *TxnReq) processReceiptEvent(ctx context.Context, + userID, roomID, receiptType string, + timestamp gomatrixserverlib.Timestamp, + eventIDs []string, +) error { + if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { + return nil + } else if serverName == t.ourServerName { + return nil + } else if serverName != t.Origin { + return nil + } + // store every event + for _, eventID := range eventIDs { + if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { + return fmt.Errorf("unable to set receipt event: %w", err) + } + } + + return nil +} diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go new file mode 100644 index 000000000..dd1bd3502 --- /dev/null +++ b/internal/transactionrequest_test.go @@ -0,0 +1,820 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/producers" + keyAPI "github.com/matrix-org/dendrite/keyserver/api" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" + "gotest.tools/v3/poll" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testDestination = gomatrixserverlib.ServerName("white.orchard") +) + +var ( + invalidSignatures = json.RawMessage(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localishhost","sender":"@userid:localhost","signatures":{"localhost":{"ed2559:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiaQiWAQ"}},"type":"m.room.member"}`) + testData = []json.RawMessage{ + []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), + // messages + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), + } + testEvent = []byte(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`) + testRoomVersion = gomatrixserverlib.RoomVersionV1 + testEvents = []*gomatrixserverlib.HeaderedEvent{} + testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) +) + +type FakeRsAPI struct { + rsAPI.RoomserverInternalAPI + shouldFailQuery bool + bannedFromRoom bool + shouldEventsFail bool +} + +func (r *FakeRsAPI) QueryRoomVersionForRoom( + ctx context.Context, + req *rsAPI.QueryRoomVersionForRoomRequest, + res *rsAPI.QueryRoomVersionForRoomResponse, +) error { + if r.shouldFailQuery { + return fmt.Errorf("Failure") + } + res.RoomVersion = gomatrixserverlib.RoomVersionV10 + return nil +} + +func (r *FakeRsAPI) QueryServerBannedFromRoom( + ctx context.Context, + req *rsAPI.QueryServerBannedFromRoomRequest, + res *rsAPI.QueryServerBannedFromRoomResponse, +) error { + if r.bannedFromRoom { + res.Banned = true + } else { + res.Banned = false + } + return nil +} + +func (r *FakeRsAPI) InputRoomEvents( + ctx context.Context, + req *rsAPI.InputRoomEventsRequest, + res *rsAPI.InputRoomEventsResponse, +) error { + if r.shouldEventsFail { + return fmt.Errorf("Failure") + } + return nil +} + +func TestEmptyTransactionRequest(t *testing.T) { + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", nil, nil, nil, false, []json.RawMessage{}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func TestProcessTransactionRequestPDU(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUs(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, append(testData, testEvent), []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestBadPDU(t *testing.T) { + pdu := json.RawMessage("{\"room_id\":\"asdf\"}") + pdu2 := json.RawMessage("\"roomid\":\"asdf\"") + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{pdu, pdu2, testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUQueryFailure(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{shouldFailQuery: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func TestProcessTransactionRequestPDUBannedFromRoom(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{bannedFromRoom: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUInvalidSignature(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{invalidSignatures}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUSendFail(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{shouldEventsFail: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func createTransactionWithEDU(ctx *process.ProcessContext, edus []gomatrixserverlib.EDU) (TxnReq, nats.JetStreamContext, *config.Dendrite) { + cfg := &config.Dendrite{} + cfg.Defaults(config.DefaultOpts{ + Generate: true, + Monolithic: true, + }) + cfg.Global.JetStream.InMemory = true + natsInstance := &jetstream.NATSInstance{} + js, _ := natsInstance.Prepare(ctx, &cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &cfg.FederationAPI, + UserAPI: nil, + } + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, producer, true, []json.RawMessage{}, edus, "kaer.morhen", "", "ourserver") + return txn, js, cfg +} + +func TestProcessTransactionRequestEDUTyping(t *testing.T) { + var err error + roomID := "!roomid:kaer.morhen" + userID := "@userid:kaer.morhen" + typing := true + edu := gomatrixserverlib.EDU{Type: "m.typing"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "room_id": roomID, + "user_id": userID, + "typing": typing, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.typing"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + room := msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, room) + user := msg.Header.Get(jetstream.UserID) + assert.Equal(t, userID, user) + typ, parseErr := strconv.ParseBool(msg.Header.Get("typing")) + if parseErr != nil { + return true + } + assert.Equal(t, typing, typ) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + cfg.Global.JetStream.Durable("TestTypingConsumer"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUToDevice(t *testing.T) { + var err error + sender := "@userid:kaer.morhen" + messageID := "$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg" + msgType := "m.dendrite.test" + edu := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "sender": sender, + "type": msgType, + "message_id": messageID, + "messages": map[string]interface{}{ + "@alice:example.org": map[string]interface{}{ + "IWHQUZUIAH": map[string]interface{}{ + "algorithm": "m.megolm.v1.aes-sha2", + "room_id": "!Cuyf34gef24t:localhost", + "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ", + "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY...", + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputSendToDeviceEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, sender, output.Sender) + assert.Equal(t, msgType, output.Type) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + cfg.Global.JetStream.Durable("TestToDevice"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUDeviceListUpdate(t *testing.T) { + var err error + deviceID := "QBUAZIFURK" + userID := "@john:example.com" + edu := gomatrixserverlib.EDU{Type: "m.device_list_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "device_display_name": "Mobile", + "device_id": deviceID, + "key": "value", + "keys": map[string]interface{}{ + "algorithms": []string{ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", + }, + "device_id": "JLAFKJWSCS", + "keys": map[string]interface{}{ + "curve25519:JLAFKJWSCS": "3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI", + "ed25519:JLAFKJWSCS": "lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI", + }, + "signatures": map[string]interface{}{ + "@alice:example.com": map[string]interface{}{ + "ed25519:JLAFKJWSCS": "dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA", + }, + }, + "user_id": "@alice:example.com", + }, + "prev_id": []int{ + 5, + }, + "stream_id": 6, + "user_id": userID, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.device_list_update"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output gomatrixserverlib.DeviceListUpdateEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, userID, output.UserID) + assert.Equal(t, deviceID, output.DeviceID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + cfg.Global.JetStream.Durable("TestDeviceListUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUReceipt(t *testing.T) { + var err error + roomID := "!some_room:example.org" + edu := gomatrixserverlib.EDU{Type: "m.receipt"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:kaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.receipt"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badUser := gomatrixserverlib.EDU{Type: "m.receipt"} + if badUser.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "johnkaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badDomain := gomatrixserverlib.EDU{Type: "m.receipt"} + if badDomain.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:bad.domain": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + edus := []gomatrixserverlib.EDU{badEDU, badUser, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputReceiptEvent + output.RoomID = msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, output.RoomID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + cfg.Global.JetStream.Durable("TestReceipt"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUSigningKeyUpdate(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output keyAPI.CrossSigningKeyUpdate + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + cfg.Global.JetStream.Durable("TestSigningKeyUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUPresence(t *testing.T) { + var err error + userID := "@john:kaer.morhen" + presence := "online" + edu := gomatrixserverlib.EDU{Type: "m.presence"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "push": []map[string]interface{}{{ + "currently_active": true, + "last_active_ago": 5000, + "presence": presence, + "status_msg": "Making cupcakes", + "user_id": userID, + }}, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.presence"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + userIDRes := msg.Header.Get(jetstream.UserID) + presenceRes := msg.Header.Get("presence") + assert.Equal(t, userID, userIDRes) + assert.Equal(t, presence, presenceRes) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + cfg.Global.JetStream.Durable("TestPresence"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUUnhandled(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.unhandled"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, _, _ := createTransactionWithEDU(ctx, []gomatrixserverlib.EDU{edu}) + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func init() { + for _, j := range testData { + e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) + if err != nil { + panic("cannot load test data: " + err.Error()) + } + h := e.Headered(testRoomVersion) + testEvents = append(testEvents, h) + if e.StateKey() != nil { + testStateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = h + } + } +} + +type testRoomserverAPI struct { + rsAPI.RoomserverInternalAPITrace + inputRoomEvents []rsAPI.InputRoomEvent + queryStateAfterEvents func(*rsAPI.QueryStateAfterEventsRequest) rsAPI.QueryStateAfterEventsResponse + queryEventsByID func(req *rsAPI.QueryEventsByIDRequest) rsAPI.QueryEventsByIDResponse + queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse +} + +func (t *testRoomserverAPI) InputRoomEvents( + ctx context.Context, + request *rsAPI.InputRoomEventsRequest, + response *rsAPI.InputRoomEventsResponse, +) error { + t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) + for _, ire := range request.InputRoomEvents { + fmt.Println("InputRoomEvents: ", ire.Event.EventID()) + } + return nil +} + +// Query the latest events and state for a room from the room server. +func (t *testRoomserverAPI) QueryLatestEventsAndState( + ctx context.Context, + request *rsAPI.QueryLatestEventsAndStateRequest, + response *rsAPI.QueryLatestEventsAndStateResponse, +) error { + r := t.queryLatestEventsAndState(request) + response.RoomExists = r.RoomExists + response.RoomVersion = testRoomVersion + response.LatestEvents = r.LatestEvents + response.StateEvents = r.StateEvents + response.Depth = r.Depth + return nil +} + +// Query the state after a list of events in a room from the room server. +func (t *testRoomserverAPI) QueryStateAfterEvents( + ctx context.Context, + request *rsAPI.QueryStateAfterEventsRequest, + response *rsAPI.QueryStateAfterEventsResponse, +) error { + response.RoomVersion = testRoomVersion + res := t.queryStateAfterEvents(request) + response.PrevEventsExist = res.PrevEventsExist + response.RoomExists = res.RoomExists + response.StateEvents = res.StateEvents + return nil +} + +// Query a list of events by event ID. +func (t *testRoomserverAPI) QueryEventsByID( + ctx context.Context, + request *rsAPI.QueryEventsByIDRequest, + response *rsAPI.QueryEventsByIDResponse, +) error { + res := t.queryEventsByID(request) + response.Events = res.Events + return nil +} + +// Query if a server is joined to a room +func (t *testRoomserverAPI) QueryServerJoinedToRoom( + ctx context.Context, + request *rsAPI.QueryServerJoinedToRoomRequest, + response *rsAPI.QueryServerJoinedToRoomResponse, +) error { + response.RoomExists = true + response.IsInRoom = true + return nil +} + +// Asks for the room version for a given room. +func (t *testRoomserverAPI) QueryRoomVersionForRoom( + ctx context.Context, + request *rsAPI.QueryRoomVersionForRoomRequest, + response *rsAPI.QueryRoomVersionForRoomResponse, +) error { + response.RoomVersion = testRoomVersion + return nil +} + +func (t *testRoomserverAPI) QueryServerBannedFromRoom( + ctx context.Context, req *rsAPI.QueryServerBannedFromRoomRequest, res *rsAPI.QueryServerBannedFromRoomResponse, +) error { + res.Banned = false + return nil +} + +func mustCreateTransaction(rsAPI rsAPI.FederationRoomserverAPI, pdus []json.RawMessage) *TxnReq { + t := NewTxnReq( + rsAPI, + nil, + "", + &test.NopJSONVerifier{}, + NewMutexByRoom(), + nil, + false, + pdus, + nil, + testOrigin, + gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())), + testDestination) + t.PDUs = pdus + t.Origin = testOrigin + t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + t.Destination = testDestination + return &t +} + +func mustProcessTransaction(t *testing.T, txn *TxnReq, pdusWithErrors []string) { + res, err := txn.ProcessTransaction(context.Background()) + if err != nil { + t.Errorf("txn.processTransaction returned an error: %v", err) + return + } + if len(res.PDUs) != len(txn.PDUs) { + t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) + return + } +NextPDU: + for eventID, result := range res.PDUs { + if result.Error == "" { + continue + } + for _, eventIDWantError := range pdusWithErrors { + if eventID == eventIDWantError { + break NextPDU + } + } + t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) + } +} + +func assertInputRoomEvents(t *testing.T, got []rsAPI.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { + for _, g := range got { + fmt.Println("GOT ", g.Event.EventID()) + } + if len(got) != len(want) { + t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) + return + } + for i := range got { + if got[i].Event.EventID() != want[i].EventID() { + t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) + } + } +} + +// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on +// to the roomserver. It's the most basic test possible. +func TestBasicTransaction(t *testing.T) { + rsAPI := &testRoomserverAPI{} + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, pdus) + mustProcessTransaction(t, txn, nil) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} + +// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver +// as it does the auth check. +func TestTransactionFailAuthChecks(t *testing.T) { + rsAPI := &testRoomserverAPI{} + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, pdus) + mustProcessTransaction(t, txn, []string{}) + // expect message to be sent to the roomserver + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 9dcfa955f..50af2f884 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -108,13 +108,16 @@ func makeDownloadAPI( activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, ) http.HandlerFunc { - counterVec := promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: name, - Help: "Total number of media_api requests for either thumbnails or full downloads", - }, - []string{"code"}, - ) + var counterVec *prometheus.CounterVec + if cfg.Matrix.Metrics.Enabled { + counterVec = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: name, + Help: "Total number of media_api requests for either thumbnails or full downloads", + }, + []string{"code"}, + ) + } httpHandler := func(w http.ResponseWriter, req *http.Request) { req = util.RequestWithLogging(req) @@ -166,5 +169,12 @@ func makeDownloadAPI( vars["downloadName"], ) } - return promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + + var handlerFunc http.HandlerFunc + if counterVec != nil { + handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + } else { + handlerFunc = http.HandlerFunc(httpHandler) + } + return handlerFunc } diff --git a/relayapi/api/api.go b/relayapi/api/api.go new file mode 100644 index 000000000..9db393225 --- /dev/null +++ b/relayapi/api/api.go @@ -0,0 +1,56 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayInternalAPI is used to query information from the relay server. +type RelayInternalAPI interface { + RelayServerAPI + + // Retrieve from external relay server all transactions stored for us and process them. + PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, + ) error +} + +// RelayServerAPI exposes the store & query transaction functionality of a relay server. +type RelayServerAPI interface { + // Store transactions for forwarding to the destination at a later time. + PerformStoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, + userID gomatrixserverlib.UserID, + ) error + + // Obtain the oldest stored transaction for the specified userID. + QueryTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + previousEntry gomatrixserverlib.RelayEntry, + ) (QueryRelayTransactionsResponse, error) +} + +type QueryRelayTransactionsResponse struct { + Transaction gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id"` + EntriesQueued bool `json:"entries_queued"` +} diff --git a/relayapi/internal/api.go b/relayapi/internal/api.go new file mode 100644 index 000000000..3ff8c2add --- /dev/null +++ b/relayapi/internal/api.go @@ -0,0 +1,53 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + fedAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type RelayInternalAPI struct { + db storage.Database + fedClient fedAPI.FederationClient + rsAPI rsAPI.RoomserverInternalAPI + keyRing *gomatrixserverlib.KeyRing + producer *producers.SyncAPIProducer + presenceEnabledInbound bool + serverName gomatrixserverlib.ServerName +} + +func NewRelayInternalAPI( + db storage.Database, + fedClient fedAPI.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, + presenceEnabledInbound bool, + serverName gomatrixserverlib.ServerName, +) *RelayInternalAPI { + return &RelayInternalAPI{ + db: db, + fedClient: fedClient, + rsAPI: rsAPI, + keyRing: keyRing, + producer: producer, + presenceEnabledInbound: presenceEnabledInbound, + serverName: serverName, + } +} diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go new file mode 100644 index 000000000..594299334 --- /dev/null +++ b/relayapi/internal/perform.go @@ -0,0 +1,141 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// PerformRelayServerSync implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, +) error { + // Providing a default RelayEntry (EntryID = 0) is done to ask the relay if there are any + // transactions available for this node. + prevEntry := gomatrixserverlib.RelayEntry{} + asyncResponse, err := r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) + if err != nil { + logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Txn) + + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + for asyncResponse.EntriesQueued { + // There are still more entries available for this node from the relay. + logrus.Infof("Retrieving next entry from relay, previous: %v", prevEntry) + asyncResponse, err = r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + if err != nil { + logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Txn) + } + + return nil +} + +// PerformStoreTransaction implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformStoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, + userID gomatrixserverlib.UserID, +) error { + logrus.Warnf("Storing transaction for %v", userID) + receipt, err := r.db.StoreTransaction(ctx, transaction) + if err != nil { + logrus.Errorf("db.StoreTransaction: %s", err.Error()) + return err + } + err = r.db.AssociateTransactionWithDestinations( + ctx, + map[gomatrixserverlib.UserID]struct{}{ + userID: {}, + }, + transaction.TransactionID, + receipt) + + return err +} + +// QueryTransactions implements api.RelayInternalAPI +func (r *RelayInternalAPI) QueryTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + previousEntry gomatrixserverlib.RelayEntry, +) (api.QueryRelayTransactionsResponse, error) { + logrus.Infof("QueryTransactions for %s", userID.Raw()) + if previousEntry.EntryID > 0 { + logrus.Infof("Cleaning previous entry (%v) from db for %s", + previousEntry.EntryID, + userID.Raw(), + ) + prevReceipt := receipt.NewReceipt(previousEntry.EntryID) + err := r.db.CleanTransactions(ctx, userID, []*receipt.Receipt{&prevReceipt}) + if err != nil { + logrus.Errorf("db.CleanTransactions: %s", err.Error()) + return api.QueryRelayTransactionsResponse{}, err + } + } + + transaction, receipt, err := r.db.GetTransaction(ctx, userID) + if err != nil { + logrus.Errorf("db.GetTransaction: %s", err.Error()) + return api.QueryRelayTransactionsResponse{}, err + } + + response := api.QueryRelayTransactionsResponse{} + if transaction != nil && receipt != nil { + logrus.Infof("Obtained transaction (%v) for %s", transaction.TransactionID, userID.Raw()) + response.Transaction = *transaction + response.EntryID = receipt.GetNID() + response.EntriesQueued = true + } else { + logrus.Infof("No more entries in the queue for %s", userID.Raw()) + response.EntryID = 0 + response.EntriesQueued = false + } + + return response, nil +} + +func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) { + logrus.Warn("Processing transaction from relay server") + mu := internal.NewMutexByRoom() + t := internal.NewTxnReq( + r.rsAPI, + nil, + r.serverName, + r.keyRing, + mu, + r.producer, + r.presenceEnabledInbound, + txn.PDUs, + txn.EDUs, + txn.Origin, + txn.TransactionID, + txn.Destination) + + t.ProcessTransaction(context.TODO()) +} diff --git a/relayapi/internal/perform_test.go b/relayapi/internal/perform_test.go new file mode 100644 index 000000000..fb71b7d0e --- /dev/null +++ b/relayapi/internal/perform_test.go @@ -0,0 +1,121 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "testing" + + fedAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type testFedClient struct { + fedAPI.FederationClient + shouldFail bool + queryCount uint + queueDepth uint +} + +func (f *testFedClient) P2PGetTransactionFromRelay( + ctx context.Context, + u gomatrixserverlib.UserID, + prev gomatrixserverlib.RelayEntry, + relayServer gomatrixserverlib.ServerName, +) (res gomatrixserverlib.RespGetRelayTransaction, err error) { + f.queryCount++ + if f.shouldFail { + return res, fmt.Errorf("Error") + } + + res = gomatrixserverlib.RespGetRelayTransaction{ + Txn: gomatrixserverlib.Transaction{}, + EntryID: 0, + } + if f.queueDepth > 0 { + res.EntriesQueued = true + } else { + res.EntriesQueued = false + } + f.queueDepth -= 1 + + return +} + +func TestPerformRelayServerSync(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.NoError(t, err) +} + +func TestPerformRelayServerSyncFedError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{shouldFail: true} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.Error(t, err) +} + +func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{queueDepth: 2} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.NoError(t, err) + assert.Equal(t, uint(3), fedClient.queryCount) +} diff --git a/relayapi/relayapi.go b/relayapi/relayapi.go new file mode 100644 index 000000000..f9f9d4ff9 --- /dev/null +++ b/relayapi/relayapi.go @@ -0,0 +1,74 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relayapi + +import ( + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. +func AddPublicRoutes( + base *base.BaseDendrite, + keyRing gomatrixserverlib.JSONVerifier, + relayAPI api.RelayInternalAPI, +) { + fedCfg := &base.Cfg.FederationAPI + + relay, ok := relayAPI.(*internal.RelayInternalAPI) + if !ok { + panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " + + "RelayInternalAPI. This is a programming error.") + } + + routing.Setup( + base.PublicFederationAPIMux, + fedCfg, + relay, + keyRing, + ) +} + +func NewRelayInternalAPI( + base *base.BaseDendrite, + fedClient *gomatrixserverlib.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, +) api.RelayInternalAPI { + cfg := &base.Cfg.RelayAPI + + relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) + if err != nil { + logrus.WithError(err).Panic("failed to connect to relay db") + } + + return internal.NewRelayInternalAPI( + relayDB, + fedClient, + rsAPI, + keyRing, + producer, + base.Cfg.Global.Presence.EnableInbound, + base.Cfg.Global.ServerName, + ) +} diff --git a/relayapi/relayapi_test.go b/relayapi/relayapi_test.go new file mode 100644 index 000000000..dfa06811d --- /dev/null +++ b/relayapi/relayapi_test.go @@ -0,0 +1,154 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relayapi_test + +import ( + "crypto/ed25519" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + "github.com/matrix-org/dendrite/relayapi" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func TestCreateNewRelayInternalAPI(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + assert.NotNil(t, relayAPI) + }) +} + +func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + if dbType == test.DBTypeSQLite { + base.Cfg.RelayAPI.Database.ConnectionString = "file:" + } else { + base.Cfg.RelayAPI.Database.ConnectionString = "test" + } + defer close() + + assert.Panics(t, func() { + relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + }) + }) +} + +func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + assert.Panics(t, func() { + relayapi.AddPublicRoutes(base, nil, nil) + }) + }) +} + +func createGetRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, userID string) *http.Request { + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID) + content := gomatrixserverlib.RelayEntry{EntryID: 0} + req.SetContent(content) + req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) + httpreq, _ := req.HTTPRequest() + vars := map[string]string{"userID": userID} + httpreq = mux.SetURLVars(httpreq, vars) + return httpreq +} + +type sendRelayContent struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` +} + +func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnID string, userID string) *http.Request { + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID) + content := sendRelayContent{} + req.SetContent(content) + req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) + httpreq, _ := req.HTTPRequest() + vars := map[string]string{"userID": userID, "txnID": txnID} + httpreq = mux.SetURLVars(httpreq, vars) + return httpreq +} + +func TestCreateRelayPublicRoutes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + assert.NotNil(t, relayAPI) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + relayapi.AddPublicRoutes(base, keyRing, relayAPI) + + testCases := []struct { + name string + req *http.Request + wantCode int + wantJoinedRooms []string + }{ + { + name: "relay_txn invalid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"), + wantCode: 400, + }, + { + name: "relay_txn valid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), + wantCode: 200, + }, + { + name: "send_relay invalid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"), + wantCode: 400, + }, + { + name: "send_relay valid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), + wantCode: 200, + }, + } + + for _, tc := range testCases { + w := httptest.NewRecorder() + base.PublicFederationAPIMux.ServeHTTP(w, tc.req) + if w.Code != tc.wantCode { + t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) + } + } + }) +} diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go new file mode 100644 index 000000000..1b11b0ecd --- /dev/null +++ b/relayapi/routing/relaytxn.go @@ -0,0 +1,74 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type RelayTransactionResponse struct { + Transaction gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id,omitempty"` + EntriesQueued bool `json:"entries_queued"` +} + +// GetTransactionFromRelay implements /_matrix/federation/v1/relay_txn/{userID} +// This endpoint can be extracted into a separate relay server service. +func GetTransactionFromRelay( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + relayAPI api.RelayInternalAPI, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + logrus.Infof("Handling relay_txn for %s", userID.Raw()) + + previousEntry := gomatrixserverlib.RelayEntry{} + if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("invalid json provided"), + } + } + if previousEntry.EntryID < 0 { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("Invalid entry id provided. Must be >= 0."), + } + } + logrus.Infof("Previous entry provided: %v", previousEntry.EntryID) + + response, err := relayAPI.QueryTransactions(httpReq.Context(), userID, previousEntry) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: RelayTransactionResponse{ + Transaction: response.Transaction, + EntryID: response.EntryID, + EntriesQueued: response.EntriesQueued, + }, + } +} diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go new file mode 100644 index 000000000..a47fdb198 --- /dev/null +++ b/relayapi/routing/relaytxn_test.go @@ -0,0 +1,220 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func createQuery( + userID gomatrixserverlib.UserID, + prevEntry gomatrixserverlib.RelayEntry, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/relay_txn/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("GET", userID.Domain(), "relay", path) + request.SetContent(prevEntry) + + return request +} + +func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.Equal(t, false, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} + +func TestGetInvalidPrevEntryFails(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusInternalServerError, response.Code) +} + +func TestGetReturnsSavedTransaction(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + receipt, err := db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + assert.NoError(t, err, "Failed to associate transaction with user") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Transaction) + + // And once more to clear the queue + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.False(t, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} + +func TestGetReturnsMultipleSavedTransactions(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + receipt, err := db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + assert.NoError(t, err, "Failed to associate transaction with user") + + transaction2 := createTransaction() + receipt2, err := db.StoreTransaction(context.Background(), transaction2) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction2.TransactionID, + receipt2) + assert.NoError(t, err, "Failed to associate transaction with user") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Transaction) + + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction2, jsonResponse.Transaction) + + // And once more to clear the queue + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.False(t, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go new file mode 100644 index 000000000..6df0cdc5f --- /dev/null +++ b/relayapi/routing/routing.go @@ -0,0 +1,123 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "fmt" + "net/http" + "time" + + "github.com/getsentry/sentry-go" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/httputil" + relayInternal "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// Setup registers HTTP handlers with the given ServeMux. +// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly +// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled +// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz] +// +// Due to Setup being used to call many other functions, a gocyclo nolint is +// applied: +// nolint: gocyclo +func Setup( + fedMux *mux.Router, + cfg *config.FederationAPI, + relayAPI *relayInternal.RelayInternalAPI, + keys gomatrixserverlib.JSONVerifier, +) { + v1fedmux := fedMux.PathPrefix("/v1").Subrouter() + + v1fedmux.Handle("/send_relay/{txnID}/{userID}", MakeRelayAPI( + "send_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return SendTransactionToRelay( + httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]), + *userID, + ) + }, + )).Methods(http.MethodPut, http.MethodOptions) + + v1fedmux.Handle("/relay_txn/{userID}", MakeRelayAPI( + "get_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return GetTransactionFromRelay(httpReq, request, relayAPI, *userID) + }, + )).Methods(http.MethodGet, http.MethodOptions) +} + +// MakeRelayAPI makes an http.Handler that checks matrix relay authentication. +func MakeRelayAPI( + metricsName string, serverName gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, + keyRing gomatrixserverlib.JSONVerifier, + f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + req, time.Now(), serverName, isLocalServerName, keyRing, + ) + if fedReq == nil { + return errResp + } + // add the user to Sentry, if enabled + hub := sentry.GetHubFromContext(req.Context()) + if hub != nil { + hub.Scope().SetTag("origin", string(fedReq.Origin())) + hub.Scope().SetTag("uri", fedReq.RequestURI()) + } + defer func() { + if r := recover(); r != nil { + if hub != nil { + hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path)) + } + // re-panic to return the 500 + panic(r) + } + }() + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") + } + + jsonRes := f(req, fedReq, vars) + // do not log 4xx as errors as they are client fails, not server fails + if hub != nil && jsonRes.Code >= 500 { + hub.Scope().SetExtra("response", jsonRes) + hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code)) + } + return jsonRes + } + return httputil.MakeExternalAPI(metricsName, h) +} diff --git a/relayapi/routing/sendrelay.go b/relayapi/routing/sendrelay.go new file mode 100644 index 000000000..a7027f293 --- /dev/null +++ b/relayapi/routing/sendrelay.go @@ -0,0 +1,77 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +// SendTransactionToRelay implements PUT /_matrix/federation/v1/relay_txn/{txnID}/{userID} +// This endpoint can be extracted into a separate relay server service. +func SendTransactionToRelay( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + relayAPI api.RelayInternalAPI, + txnID gomatrixserverlib.TransactionID, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + var txnEvents struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` + } + + if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { + logrus.Info("The request body could not be decoded into valid JSON." + err.Error()) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON." + err.Error()), + } + } + + // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs. + // https://matrix.org/docs/spec/server_server/latest#transactions + if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"), + } + } + + t := gomatrixserverlib.Transaction{} + t.PDUs = txnEvents.PDUs + t.EDUs = txnEvents.EDUs + t.Origin = fedReq.Origin() + t.TransactionID = txnID + t.Destination = userID.Domain() + + util.GetLogger(httpReq.Context()).Warnf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, fedReq.Origin(), len(t.PDUs), len(t.EDUs)) + + err := relayAPI.PerformStoreTransaction(httpReq.Context(), t, userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("could not store the transaction for forwarding"), + } + } + + return util.JSONResponse{Code: 200} +} diff --git a/relayapi/routing/sendrelay_test.go b/relayapi/routing/sendrelay_test.go new file mode 100644 index 000000000..d9ed75002 --- /dev/null +++ b/relayapi/routing/sendrelay_test.go @@ -0,0 +1,209 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") +) + +func createTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + return txn +} + +func createFederationRequest( + userID gomatrixserverlib.UserID, + txnID gomatrixserverlib.TransactionID, + origin gomatrixserverlib.ServerName, + destination gomatrixserverlib.ServerName, + content interface{}, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/send_relay/" + string(txnID) + "/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("PUT", origin, destination, path) + request.SetContent(content) + + return request +} + +func TestForwardEmptyReturnsOk(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.Equal(t, 200, response.Code) +} + +func TestForwardBadJSONReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field bool `json:"pdus"` + } + content := BadData{ + Field: false, + } + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyPDUsReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field []json.RawMessage `json:"pdus"` + } + content := BadData{ + Field: []json.RawMessage{}, + } + for i := 0; i < 51; i++ { + content.Field = append(content.Field, []byte{}) + } + assert.Greater(t, len(content.Field), 50) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyEDUsReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field []gomatrixserverlib.EDU `json:"edus"` + } + content := BadData{ + Field: []gomatrixserverlib.EDU{}, + } + for i := 0; i < 101; i++ { + content.Field = append(content.Field, gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}) + } + assert.Greater(t, len(content.Field), 100) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestUniqueTransactionStoredInDatabase(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay( + httpReq, &request, relayAPI, txn.TransactionID, *userID) + transaction, _, err := db.GetTransaction(context.Background(), *userID) + assert.NoError(t, err, "Failed retrieving transaction") + + transactionCount, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err, "Failed retrieving transaction count") + + assert.Equal(t, 200, response.Code) + assert.Equal(t, int64(1), transactionCount) + assert.Equal(t, txn.TransactionID, transaction.TransactionID) +} diff --git a/relayapi/storage/interface.go b/relayapi/storage/interface.go new file mode 100644 index 000000000..f5f9a06e5 --- /dev/null +++ b/relayapi/storage/interface.go @@ -0,0 +1,47 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database interface { + // Adds a new transaction to the queue json table. + // Adding a duplicate transaction will result in a new row being added and a new unique nid. + // return: unique nid representing this entry. + StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*receipt.Receipt, error) + + // Adds a new transaction_id: server_name mapping with associated json table nid to the queue + // entry table for each provided destination. + AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error + + // Removes every server_name: receipt pair provided from the queue entries table. + // Will then remove every entry for each receipt provided from the queue json table. + // If any of the entries don't exist in either table, nothing will happen for that entry and + // an error will not be generated. + CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error + + // Gets the oldest transaction for the provided server_name. + // If no transactions exist, returns nil and no error. + GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) + + // Gets the number of transactions being stored for the provided server_name. + // If the server doesn't exist in the database then 0 is returned with no error. + GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) +} diff --git a/relayapi/storage/postgres/relay_queue_json_table.go b/relayapi/storage/postgres/relay_queue_json_table.go new file mode 100644 index 000000000..74410fc88 --- /dev/null +++ b/relayapi/storage/postgres/relay_queue_json_table.go @@ -0,0 +1,113 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid BIGSERIAL, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); +` + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + + " VALUES ($1)" + + " RETURNING json_nid" + +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + + " WHERE json_nid = ANY($1)" + +type relayQueueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + deleteJSONStmt *sql.Stmt + selectJSONStmt *sql.Stmt +} + +func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ + db: db, + } + _, err = s.db.Exec(relayQueueJSONSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertQueueJSONSQL}, + {&s.deleteJSONStmt, deleteQueueJSONSQL}, + {&s.selectJSONStmt, selectQueueJSONSQL}, + }.Prepare(db) +} + +func (s *relayQueueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (int64, error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + var lastid int64 + if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil { + return 0, err + } + return lastid, nil +} + +func (s *relayQueueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) + _, err := stmt.ExecContext(ctx, pq.Int64Array(nids)) + return err +} + +func (s *relayQueueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, s.selectJSONStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, err + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/relayapi/storage/postgres/relay_queue_table.go b/relayapi/storage/postgres/relay_queue_table.go new file mode 100644 index 000000000..e97cf8cc0 --- /dev/null +++ b/relayapi/storage/postgres/relay_queue_table.go @@ -0,0 +1,156 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The destination server that we will send the event to. + server_name TEXT NOT NULL, + -- The JSON NID from the relayapi_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); +` + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + + " WHERE server_name = $1" + + " ORDER BY json_nid" + + " LIMIT $2" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + + " WHERE server_name = $1" + +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + deleteQueueEntriesStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt +} + +func NewPostgresRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ + db: db, + } + _, err = s.db.Exec(relayQueueSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertQueueEntryStmt, insertQueueEntrySQL}, + {&s.deleteQueueEntriesStmt, deleteQueueEntriesSQL}, + {&s.selectQueueEntriesStmt, selectQueueEntriesSQL}, + {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL}, + }.Prepare(db) +} + +func (s *relayQueueStatements) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *relayQueueStatements) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) + return err +} + +func (s *relayQueueStatements) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *relayQueueStatements) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/relayapi/storage/postgres/storage.go b/relayapi/storage/postgres/storage.go new file mode 100644 index 000000000..1042beba7 --- /dev/null +++ b/relayapi/storage/postgres/storage.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the relayapi +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + return nil, err + } + queue, err := NewPostgresRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewPostgresRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/shared/storage.go b/relayapi/storage/shared/storage.go new file mode 100644 index 000000000..0993707bf --- /dev/null +++ b/relayapi/storage/shared/storage.go @@ -0,0 +1,170 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database struct { + DB *sql.DB + IsLocalServerName func(gomatrixserverlib.ServerName) bool + Cache caching.FederationCache + Writer sqlutil.Writer + RelayQueue tables.RelayQueue + RelayQueueJSON tables.RelayQueueJSON +} + +func (d *Database) StoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, +) (*receipt.Receipt, error) { + var err error + jsonTransaction, err := json.Marshal(transaction) + if err != nil { + return nil, fmt.Errorf("failed to marshal: %w", err) + } + + var nid int64 + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(jsonTransaction)) + return err + }) + if err != nil { + return nil, fmt.Errorf("d.insertQueueJSON: %w", err) + } + + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil +} + +func (d *Database) AssociateTransactionWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.UserID]struct{}, + transactionID gomatrixserverlib.TransactionID, + dbReceipt *receipt.Receipt, +) error { + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var lastErr error + for destination := range destinations { + destination := destination + err := d.RelayQueue.InsertQueueEntry( + ctx, + txn, + transactionID, + destination.Domain(), + dbReceipt.GetNID(), + ) + if err != nil { + lastErr = fmt.Errorf("d.insertQueueEntry: %w", err) + } + } + return lastErr + }) + + return err +} + +func (d *Database) CleanTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + receipts []*receipt.Receipt, +) error { + nids := make([]int64, len(receipts)) + for i, dbReceipt := range receipts { + nids[i] = dbReceipt.GetNID() + } + + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + deleteEntryErr := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids) + // TODO : If there are still queue entries for any of these nids for other destinations + // then we shouldn't delete the json entries. + // But this can't happen with the current api design. + // There will only ever be one server entry for each nid since each call to send_relay + // only accepts a single server name and inside there we create a new json entry. + // So for multiple destinations we would call send_relay multiple times and have multiple + // json entries of the same transaction. + // + // TLDR; this works as expected right now but can easily be optimised in the future. + deleteJSONErr := d.RelayQueueJSON.DeleteQueueJSON(ctx, txn, nids) + + if deleteEntryErr != nil { + return fmt.Errorf("d.deleteQueueEntries: %w", deleteEntryErr) + } + if deleteJSONErr != nil { + return fmt.Errorf("d.deleteQueueJSON: %w", deleteJSONErr) + } + return nil + }) + + return err +} + +func (d *Database) GetTransaction( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) { + entriesRequested := 1 + nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), entriesRequested) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err) + } + if len(nids) == 0 { + return nil, nil, nil + } + firstNID := nids[0] + + txns := map[int64][]byte{} + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids) + return err + }) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err) + } + + transaction := &gomatrixserverlib.Transaction{} + if _, ok := txns[firstNID]; !ok { + return nil, nil, fmt.Errorf("Failed retrieving json blob for transaction: %d", firstNID) + } + + err = json.Unmarshal(txns[firstNID], transaction) + if err != nil { + return nil, nil, fmt.Errorf("Unmarshal transaction: %w", err) + } + + newReceipt := receipt.NewReceipt(firstNID) + return transaction, &newReceipt, nil +} + +func (d *Database) GetTransactionCount( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (int64, error) { + count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain()) + if err != nil { + return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err) + } + return count, nil +} diff --git a/relayapi/storage/sqlite3/relay_queue_json_table.go b/relayapi/storage/sqlite3/relay_queue_json_table.go new file mode 100644 index 000000000..502da3b00 --- /dev/null +++ b/relayapi/storage/sqlite3/relay_queue_json_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid INTEGER PRIMARY KEY AUTOINCREMENT, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); +` + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + + " VALUES ($1)" + +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + + " WHERE json_nid IN ($1)" + +type relayQueueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic + //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ + db: db, + } + _, err = db.Exec(relayQueueJSONSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertQueueJSONSQL}, + }.Prepare(db) +} + +func (s *relayQueueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (lastid int64, err error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return 0, fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err = res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("res.LastInsertId: %w", err) + } + return +} + +func (s *relayQueueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(nids)) + for k, v := range nids { + iNIDs[k] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) + return err +} + +func (s *relayQueueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) + selectStmt, err := txn.Prepare(selectSQL) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(jsonNIDs)) + for k, v := range jsonNIDs { + iNIDs[k] = v + } + + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, selectStmt) + rows, err := stmt.QueryContext(ctx, iNIDs...) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err) + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/relayapi/storage/sqlite3/relay_queue_table.go b/relayapi/storage/sqlite3/relay_queue_table.go new file mode 100644 index 000000000..49c6b4de5 --- /dev/null +++ b/relayapi/storage/sqlite3/relay_queue_table.go @@ -0,0 +1,168 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The domain part of the user ID the m.room.member event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the relayapi_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); +` + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + + " WHERE server_name = $1" + + " ORDER BY json_nid" + + " LIMIT $2" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + + " WHERE server_name = $1" + +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt + // deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ + db: db, + } + _, err = db.Exec(relayQueueSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertQueueEntryStmt, insertQueueEntrySQL}, + {&s.selectQueueEntriesStmt, selectQueueEntriesSQL}, + {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL}, + }.Prepare(db) +} + +func (s *relayQueueStatements) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *relayQueueStatements) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err) + } + + params := make([]interface{}, len(jsonNIDs)+1) + params[0] = serverName + for k, v := range jsonNIDs { + params[k+1] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *relayQueueStatements) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *relayQueueStatements) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/relayapi/storage/sqlite3/storage.go b/relayapi/storage/sqlite3/storage.go new file mode 100644 index 000000000..3ed4ab046 --- /dev/null +++ b/relayapi/storage/sqlite3/storage.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the federation sender +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + return nil, err + } + queue, err := NewSQLiteRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/storage.go b/relayapi/storage/storage.go new file mode 100644 index 000000000..16ecbcfb7 --- /dev/null +++ b/relayapi/storage/storage.go @@ -0,0 +1,46 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !wasm +// +build !wasm + +package storage + +import ( + "fmt" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (Database, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/relayapi/storage/tables/interface.go b/relayapi/storage/tables/interface.go new file mode 100644 index 000000000..9056a5678 --- /dev/null +++ b/relayapi/storage/tables/interface.go @@ -0,0 +1,66 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayQueue table contains a mapping of server name to transaction id and the corresponding nid. +// These are the transactions being stored for the given destination server. +// The nids correspond to entries in the RelayQueueJSON table. +type RelayQueue interface { + // Adds a new transaction_id: server_name mapping with associated json table nid to the table. + // Will ensure only one transaction id is present for each server_name: nid mapping. + // Adding duplicates will silently do nothing. + InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + + // Removes multiple entries from the table corresponding the the list of nids provided. + // If any of the provided nids don't match a row in the table, that deletion is considered + // successful. + DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + + // Get a list of nids associated with the provided server name. + // Returns up to `limit` nids. The entries are returned oldest first. + // Will return an empty list if no matches were found. + SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + + // Get the number of entries in the table associated with the provided server name. + // If there are no matching rows, a count of 0 is returned with err set to nil. + SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +// RelayQueueJSON table contains a map of nid to the raw transaction json. +type RelayQueueJSON interface { + // Adds a new transaction to the table. + // Adding a duplicate transaction will result in a new row being added and a new unique nid. + // return: unique nid representing this entry. + InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + + // Removes multiple nids from the table. + // If any of the provided nids don't match a row in the table, that deletion is considered + // successful. + DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + + // Get the transaction json corresponding to the provided nids. + // Will return a partial result containing any matching nid from the table. + // Will return an empty map if no matches were found. + // It is the caller's responsibility to deal with the results appropriately. + // return: map indexed by nid of each matching transaction json. + SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} diff --git a/relayapi/storage/tables/relay_queue_json_table_test.go b/relayapi/storage/tables/relay_queue_json_table_test.go new file mode 100644 index 000000000..efa3363e5 --- /dev/null +++ b/relayapi/storage/tables/relay_queue_json_table_test.go @@ -0,0 +1,173 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") +) + +func mustCreateTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + + return txn +} + +type RelayQueueJSONDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.RelayQueueJSON +} + +func mustCreateQueueJSONTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueJSONDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.RelayQueueJSON + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayQueueJSONTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayQueueJSONDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + _, err = db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + var storedJSON map[int64][]byte + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, 1, len(storedJSON)) + + var storedTx gomatrixserverlib.Transaction + json.Unmarshal(storedJSON[1], &storedTx) + + assert.Equal(t, transaction, storedTx) + }) +} + +func TestShouldDeleteTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + storedJSON := map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + storedJSON = map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, 0, len(storedJSON)) + }) +} diff --git a/relayapi/storage/tables/relay_queue_table_test.go b/relayapi/storage/tables/relay_queue_table_test.go new file mode 100644 index 000000000..99f9922c0 --- /dev/null +++ b/relayapi/storage/tables/relay_queue_table_test.go @@ -0,0 +1,229 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type RelayQueueDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.RelayQueue +} + +func mustCreateQueueTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.RelayQueue + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayQueueTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayQueueTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayQueueDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, nid, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) + }) +} + +func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(2) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName = gomatrixserverlib.ServerName("domain") + oldestNID := int64(1) + err = db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, oldestNID) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 1) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, oldestNID, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) + + retrievedNids, err = db.Table.SelectQueueEntries(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, oldestNID, retrievedNids[0]) + assert.Equal(t, nid, retrievedNids[1]) + assert.Equal(t, 2, len(retrievedNids)) + }) +} + +func TestShouldDeleteQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(0), count) + }) +} + +func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano())) + serverName2 := gomatrixserverlib.ServerName("domain2") + nid2 := int64(2) + transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano())) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(1), count) + + count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(1), count) + }) +} diff --git a/setup/base/base.go b/setup/base/base.go index ff38209fb..de8f81517 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -595,6 +595,12 @@ func (b *BaseDendrite) WaitForShutdown() { logrus.Warnf("failed to flush all Sentry events!") } } + if b.Fulltext != nil { + err := b.Fulltext.Close() + if err != nil { + logrus.Warnf("failed to close full text search!") + } + } logrus.Warnf("Dendrite is exiting now") } diff --git a/setup/config/config.go b/setup/config/config.go index 41d2b6674..2b38cd512 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -62,6 +62,7 @@ type Dendrite struct { RoomServer RoomServer `yaml:"room_server"` SyncAPI SyncAPI `yaml:"sync_api"` UserAPI UserAPI `yaml:"user_api"` + RelayAPI RelayAPI `yaml:"relay_api"` MSCs MSCs `yaml:"mscs"` @@ -349,6 +350,7 @@ func (c *Dendrite) Defaults(opts DefaultOpts) { c.SyncAPI.Defaults(opts) c.UserAPI.Defaults(opts) c.AppServiceAPI.Defaults(opts) + c.RelayAPI.Defaults(opts) c.MSCs.Defaults(opts) c.Wiring() } @@ -361,7 +363,7 @@ func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) { &c.Global, &c.ClientAPI, &c.FederationAPI, &c.KeyServer, &c.MediaAPI, &c.RoomServer, &c.SyncAPI, &c.UserAPI, - &c.AppServiceAPI, &c.MSCs, + &c.AppServiceAPI, &c.RelayAPI, &c.MSCs, } { c.Verify(configErrs, isMonolith) } @@ -377,6 +379,7 @@ func (c *Dendrite) Wiring() { c.SyncAPI.Matrix = &c.Global c.UserAPI.Matrix = &c.Global c.AppServiceAPI.Matrix = &c.Global + c.RelayAPI.Matrix = &c.Global c.MSCs.Matrix = &c.Global c.ClientAPI.Derived = &c.Derived diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 0f853865f..6c198018d 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -18,6 +18,12 @@ type FederationAPI struct { // The default value is 16 if not specified, which is circa 18 hours. FederationMaxRetries uint32 `yaml:"send_max_retries"` + // P2P Feature: How many consecutive failures that we should tolerate when + // sending federation requests to a specific server until we should assume they + // are offline. If we assume they are offline then we will attempt to send + // messages to their relay server if we know of one that is appropriate. + P2PFederationRetriesUntilAssumedOffline uint32 `yaml:"p2p_retries_until_assumed_offline"` + // FederationDisableTLSValidation disables the validation of X.509 TLS certs // on remote federation endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` @@ -43,6 +49,7 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) { c.Database.Defaults(10) } c.FederationMaxRetries = 16 + c.P2PFederationRetriesUntilAssumedOffline = 2 c.DisableTLSValidation = false c.DisableHTTPKeepalives = false if opts.Generate { diff --git a/setup/config/config_relayapi.go b/setup/config/config_relayapi.go new file mode 100644 index 000000000..5a6b093d4 --- /dev/null +++ b/setup/config/config_relayapi.go @@ -0,0 +1,52 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +type RelayAPI struct { + Matrix *Global `yaml:"-"` + + InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` + ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` + + // The database stores information used by the relay queue to + // forward transactions to remote servers. + Database DatabaseOptions `yaml:"database,omitempty"` +} + +func (c *RelayAPI) Defaults(opts DefaultOpts) { + if !opts.Monolithic { + c.InternalAPI.Listen = "http://localhost:7775" + c.InternalAPI.Connect = "http://localhost:7775" + c.ExternalAPI.Listen = "http://[::]:8075" + c.Database.Defaults(10) + } + if opts.Generate { + if !opts.Monolithic { + c.Database.ConnectionString = "file:relayapi.db" + } + } +} + +func (c *RelayAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { + if isMonolith { // polylith required configs below + return + } + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "relay_api.database.connection_string", string(c.Database.ConnectionString)) + } + checkURL(configErrs, "relay_api.external_api.listen", string(c.ExternalAPI.Listen)) + checkURL(configErrs, "relay_api.internal_api.listen", string(c.InternalAPI.Listen)) + checkURL(configErrs, "relay_api.internal_api.connect", string(c.InternalAPI.Connect)) +} diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 3408bf46d..ffbf4c3c5 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -20,11 +20,12 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) func TestLoadConfigRelative(t *testing.T) { - _, err := loadConfig("/my/config/dir", []byte(testConfig), + cfg, err := loadConfig("/my/config/dir", []byte(testConfig), mockReadFile{ "/my/config/dir/matrix_key.pem": testKey, "/my/config/dir/tls_cert.pem": testCert, @@ -34,6 +35,15 @@ func TestLoadConfigRelative(t *testing.T) { if err != nil { t.Error("failed to load config:", err) } + + configErrors := &ConfigErrors{} + cfg.Verify(configErrors, false) + if len(*configErrors) > 0 { + for _, err := range *configErrors { + logrus.Errorf("Configuration error: %s", err) + } + t.Error("configuration verification failed") + } } const testConfig = ` @@ -68,6 +78,8 @@ global: display_name: "Server alerts" avatar: "" room_name: "Server Alerts" + jetstream: + addresses: ["test"] app_service_api: internal_api: listen: http://localhost:7777 @@ -84,7 +96,7 @@ client_api: connect: http://localhost:7771 external_api: listen: http://[::]:8071 - registration_disabled: false + registration_disabled: true registration_shared_secret: "" enable_registration_captcha: false recaptcha_public_key: "" @@ -112,6 +124,8 @@ federation_api: connect: http://localhost:7772 external_api: listen: http://[::]:8072 + database: + connection_string: file:federationapi.db key_server: internal_api: listen: http://localhost:7779 @@ -194,6 +208,17 @@ user_api: max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 +relay_api: + internal_api: + listen: http://localhost:7775 + connect: http://localhost:7775 + external_api: + listen: http://[::]:8075 + database: + connection_string: file:relayapi.db +mscs: + database: + connection_string: file:mscs.db tracing: enabled: false jaeger: diff --git a/setup/monolith.go b/setup/monolith.go index 41a897024..5bbe4019e 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -23,6 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal/transactions" keyAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/mediaapi" + "github.com/matrix-org/dendrite/relayapi" + relayAPI "github.com/matrix-org/dendrite/relayapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" @@ -44,6 +46,7 @@ type Monolith struct { RoomserverAPI roomserverAPI.RoomserverInternalAPI UserAPI userapi.UserInternalAPI KeyAPI keyAPI.KeyInternalAPI + RelayAPI relayAPI.RelayInternalAPI // Optional ExtPublicRoomsProvider api.ExtraPublicRoomsProvider @@ -71,4 +74,8 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) { syncapi.AddPublicRoutes( base, m.UserAPI, m.RoomserverAPI, m.KeyAPI, ) + + if m.RelayAPI != nil { + relayapi.AddPublicRoutes(base, m.KeyRing, m.RelayAPI) + } } diff --git a/test/db.go b/test/db.go index 54ded6adb..d2f405d49 100644 --- a/test/db.go +++ b/test/db.go @@ -101,7 +101,6 @@ func currentUser() string { // Returns the connection string to use and a close function which must be called when the test finishes. // Calling this function twice will return the same database, which will have data from previous tests // unless close() is called. -// TODO: namespace for concurrent package tests func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { if dbType == DBTypeSQLite { // this will be made in the t.TempDir, which is unique per test diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go new file mode 100644 index 000000000..cc9e1e8fd --- /dev/null +++ b/test/memory_federation_db.go @@ -0,0 +1,488 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "encoding/json" + "errors" + "sync" + "time" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/federationapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var nidMutex sync.Mutex +var nid = int64(0) + +type InMemoryFederationDatabase struct { + dbMutex sync.Mutex + pendingPDUServers map[gomatrixserverlib.ServerName]struct{} + pendingEDUServers map[gomatrixserverlib.ServerName]struct{} + blacklistedServers map[gomatrixserverlib.ServerName]struct{} + assumedOffline map[gomatrixserverlib.ServerName]struct{} + pendingPDUs map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent + pendingEDUs map[*receipt.Receipt]*gomatrixserverlib.EDU + associatedPDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} + associatedEDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} + relayServers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName +} + +func NewInMemoryFederationDatabase() *InMemoryFederationDatabase { + return &InMemoryFederationDatabase{ + pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + assumedOffline: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUs: make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingEDUs: make(map[*receipt.Receipt]*gomatrixserverlib.EDU), + associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), + associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), + relayServers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName), + } +} + +func (d *InMemoryFederationDatabase) StoreJSON( + ctx context.Context, + js string, +) (*receipt.Receipt, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal([]byte(js), &event); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + newReceipt := receipt.NewReceipt(nid) + d.pendingPDUs[&newReceipt] = &event + return &newReceipt, nil + } + + var edu gomatrixserverlib.EDU + if err := json.Unmarshal([]byte(js), &edu); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + newReceipt := receipt.NewReceipt(nid) + d.pendingEDUs[&newReceipt] = &edu + return &newReceipt, nil + } + + return nil, errors.New("Failed to determine type of json to store") +} + +func (d *InMemoryFederationDatabase) GetPendingPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + pduCount := 0 + pdus = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) + if receipts, ok := d.associatedPDUs[serverName]; ok { + for dbReceipt := range receipts { + if event, ok := d.pendingPDUs[dbReceipt]; ok { + pdus[dbReceipt] = event + pduCount++ + if pduCount == limit { + break + } + } + } + } + return pdus, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + eduCount := 0 + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) + if receipts, ok := d.associatedEDUs[serverName]; ok { + for dbReceipt := range receipts { + if event, ok := d.pendingEDUs[dbReceipt]; ok { + edus[dbReceipt] = event + eduCount++ + if eduCount == limit { + break + } + } + } + } + return edus, nil +} + +func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.ServerName]struct{}, + dbReceipt *receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingPDUs[dbReceipt]; ok { + for destination := range destinations { + if _, ok := d.associatedPDUs[destination]; !ok { + d.associatedPDUs[destination] = make(map[*receipt.Receipt]struct{}) + } + d.associatedPDUs[destination][dbReceipt] = struct{}{} + } + + return nil + } else { + return errors.New("PDU doesn't exist") + } +} + +func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.ServerName]struct{}, + dbReceipt *receipt.Receipt, + eduType string, + expireEDUTypes map[string]time.Duration, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingEDUs[dbReceipt]; ok { + for destination := range destinations { + if _, ok := d.associatedEDUs[destination]; !ok { + d.associatedEDUs[destination] = make(map[*receipt.Receipt]struct{}) + } + d.associatedEDUs[destination][dbReceipt] = struct{}{} + } + + return nil + } else { + return errors.New("EDU doesn't exist") + } +} + +func (d *InMemoryFederationDatabase) CleanPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipts []*receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if pdus, ok := d.associatedPDUs[serverName]; ok { + for _, dbReceipt := range receipts { + delete(pdus, dbReceipt) + } + } + + return nil +} + +func (d *InMemoryFederationDatabase) CleanEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipts []*receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if edus, ok := d.associatedEDUs[serverName]; ok { + for _, dbReceipt := range receipts { + delete(edus, dbReceipt) + } + } + + return nil +} + +func (d *InMemoryFederationDatabase) GetPendingPDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if pdus, ok := d.associatedPDUs[serverName]; ok { + count = int64(len(pdus)) + } + return count, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if edus, ok := d.associatedEDUs[serverName]; ok { + count = int64(len(edus)) + } + return count, nil +} + +func (d *InMemoryFederationDatabase) GetPendingPDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingPDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingEDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *InMemoryFederationDatabase) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers[serverName] = struct{}{} + return nil +} + +func (d *InMemoryFederationDatabase) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.blacklistedServers, serverName) + return nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersFromBlacklist() error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *InMemoryFederationDatabase) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + isBlacklisted := false + if _, ok := d.blacklistedServers[serverName]; ok { + isBlacklisted = true + } + + return isBlacklisted, nil +} + +func (d *InMemoryFederationDatabase) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline[serverName] = struct{}{} + return nil +} + +func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.assumedOffline, serverName) + return nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine( + ctx context.Context, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *InMemoryFederationDatabase) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + assumedOffline := false + if _, ok := d.assumedOffline[serverName]; ok { + assumedOffline = true + } + + return assumedOffline, nil +} + +func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + knownRelayServers := []gomatrixserverlib.ServerName{} + if relayServers, ok := d.relayServers[serverName]; ok { + knownRelayServers = relayServers + } + + return knownRelayServers, nil +} + +func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if knownRelayServers, ok := d.relayServers[serverName]; ok { + for _, relayServer := range relayServers { + alreadyKnown := false + for _, knownRelayServer := range knownRelayServers { + if relayServer == knownRelayServer { + alreadyKnown = true + } + } + if !alreadyKnown { + d.relayServers[serverName] = append(d.relayServers[serverName], relayServer) + } + } + } else { + d.relayServers[serverName] = relayServers + } + + return nil +} + +func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) FetcherName() string { + return "" +} + +func (d *InMemoryFederationDatabase) StoreKeys(ctx context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error { + return nil +} + +func (d *InMemoryFederationDatabase) UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context.Context) error { + return nil +} + +func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { + return nil +} + +func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { + return nil +} + +func (d *InMemoryFederationDatabase) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) DeleteExpiredEDUs(ctx context.Context) error { + return nil +} + +func (d *InMemoryFederationDatabase) PurgeRoom(ctx context.Context, roomID string) error { + return nil +} diff --git a/test/memory_relay_db.go b/test/memory_relay_db.go new file mode 100644 index 000000000..db93919df --- /dev/null +++ b/test/memory_relay_db.go @@ -0,0 +1,140 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + + "github.com/matrix-org/gomatrixserverlib" +) + +type InMemoryRelayDatabase struct { + nid int64 + nidMutex sync.Mutex + transactions map[int64]json.RawMessage + associations map[gomatrixserverlib.ServerName][]int64 +} + +func NewInMemoryRelayDatabase() *InMemoryRelayDatabase { + return &InMemoryRelayDatabase{ + nid: 1, + nidMutex: sync.Mutex{}, + transactions: make(map[int64]json.RawMessage), + associations: make(map[gomatrixserverlib.ServerName][]int64), + } +} + +func (d *InMemoryRelayDatabase) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + if _, ok := d.associations[serverName]; !ok { + d.associations[serverName] = []int64{} + } + d.associations[serverName] = append(d.associations[serverName], nid) + return nil +} + +func (d *InMemoryRelayDatabase) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + for _, nid := range jsonNIDs { + for index, associatedNID := range d.associations[serverName] { + if associatedNID == nid { + d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...) + } + } + } + + return nil +} + +func (d *InMemoryRelayDatabase) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + results := []int64{} + resultCount := limit + if limit > len(d.associations[serverName]) { + resultCount = len(d.associations[serverName]) + } + if resultCount > 0 { + for i := 0; i < resultCount; i++ { + results = append(results, d.associations[serverName][i]) + } + } + + return results, nil +} + +func (d *InMemoryRelayDatabase) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + return int64(len(d.associations[serverName])), nil +} + +func (d *InMemoryRelayDatabase) InsertQueueJSON( + ctx context.Context, + txn *sql.Tx, + json string, +) (int64, error) { + d.nidMutex.Lock() + defer d.nidMutex.Unlock() + + nid := d.nid + d.transactions[nid] = []byte(json) + d.nid++ + + return nid, nil +} + +func (d *InMemoryRelayDatabase) DeleteQueueJSON( + ctx context.Context, + txn *sql.Tx, + nids []int64, +) error { + for _, nid := range nids { + delete(d.transactions, nid) + } + + return nil +} + +func (d *InMemoryRelayDatabase) SelectQueueJSON( + ctx context.Context, + txn *sql.Tx, + jsonNIDs []int64, +) (map[int64][]byte, error) { + result := make(map[int64][]byte) + for _, nid := range jsonNIDs { + if transaction, ok := d.transactions[nid]; ok { + result[nid] = transaction + } + } + + return result, nil +} diff --git a/test/testrig/base.go b/test/testrig/base.go index 9773da223..dfc0d8aaf 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -67,9 +67,10 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f case test.DBTypeSQLite: cfg.Defaults(config.DefaultOpts{ Generate: true, - Monolithic: false, // because we need a database per component + Monolithic: true, }) cfg.Global.ServerName = "test" + // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) @@ -83,6 +84,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f cfg.RoomServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "roomserver.db")) cfg.SyncAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "syncapi.db")) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "userapi.db")) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "relayapi.db")) base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) return base, func() { From ace44458b25768099f7b86663f2bb45ddf0d39c9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 26 Jan 2023 08:25:39 +0100 Subject: [PATCH 101/124] Bump commonmarker from 0.23.6 to 0.23.7 in /docs (#2952) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [commonmarker](https://github.com/gjtorikian/commonmarker) from 0.23.6 to 0.23.7.
Release notes

Sourced from commonmarker's releases.

v0.23.7

What's Changed

Full Changelog: https://github.com/gjtorikian/commonmarker/compare/v0.23.6...v0.23.7

v0.23.7.pre1

What's Changed

Full Changelog: https://github.com/gjtorikian/commonmarker/compare/v0.23.6...v0.23.7.pre1

Changelog

Sourced from commonmarker's changelog.

Changelog

v1.0.0.pre6 (2023-01-09)

Full Changelog

Closed issues:

  • Cargo.lock prevents Ruby 3.2.0 from installing commonmarker v1.0.0.pre4 #211

Merged pull requests:

  • always use rb_sys (don't use Ruby's emerging cargo tooling where available) #213 (kivikakk)

v1.0.0.pre5 (2023-01-08)

Full Changelog

Merged pull requests:

v1.0.0.pre4 (2022-12-28)

Full Changelog

Closed issues:

  • Will the cmark-gfm branch continue to be maintained for awhile? #207

Merged pull requests:

v1.0.0.pre3 (2022-11-30)

Full Changelog

Closed issues:

  • Code block incorrectly parsed in commonmarker 1.0.0.pre #202

Merged pull requests:

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=commonmarker&package-manager=bundler&previous-version=0.23.6&new-version=0.23.7)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) - `@dependabot use these labels` will set the current labels as the default for future PRs for this repo and language - `@dependabot use these reviewers` will set the current reviewers as the default for future PRs for this repo and language - `@dependabot use these assignees` will set the current assignees as the default for future PRs for this repo and language - `@dependabot use this milestone` will set the current milestone as the default for future PRs for this repo and language You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/matrix-org/dendrite/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/Gemfile.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index 509a8cbcf..5d79365f8 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -14,7 +14,7 @@ GEM execjs coffee-script-source (1.11.1) colorator (1.1.0) - commonmarker (0.23.6) + commonmarker (0.23.7) concurrent-ruby (1.1.10) dnsruby (1.61.9) simpleidn (~> 0.1) From 80738cc2a05941b91ad715bbf6831b636dbab08d Mon Sep 17 00:00:00 2001 From: Lukas <32310269+LukasLJL@users.noreply.github.com> Date: Thu, 26 Jan 2023 16:25:17 +0100 Subject: [PATCH 102/124] Added Landing Page (#2885) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I have added/copied a landing page like Synpase does. Recently I have installed Dendrite and was wondering why it´s not working. After some troubleshooting I figured out there is no landing page like synpase has, so the Server was running just fine. Hopefuly this PR can fix this problem and may help other users who run into this issue. I have not written any unit tests, because it´s just a simple landing page with a redirect to a static site. ### Pull Request Checklist * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Lukas Huida` Co-authored-by: Till Faelligen <2353100+S7evinK@users.noreply.github.com> --- internal/httputil/paths.go | 1 + setup/base/base.go | 28 +++++++++++++++ setup/base/base_test.go | 57 ++++++++++++++++++++++++++++++ setup/base/static/index.gotmpl | 63 ++++++++++++++++++++++++++++++++++ 4 files changed, 149 insertions(+) create mode 100644 setup/base/base_test.go create mode 100644 setup/base/static/index.gotmpl diff --git a/internal/httputil/paths.go b/internal/httputil/paths.go index 12cf59eb4..62eff0415 100644 --- a/internal/httputil/paths.go +++ b/internal/httputil/paths.go @@ -19,6 +19,7 @@ const ( PublicFederationPathPrefix = "/_matrix/federation/" PublicKeyPathPrefix = "/_matrix/key/" PublicMediaPathPrefix = "/_matrix/media/" + PublicStaticPath = "/_matrix/static/" PublicWellKnownPrefix = "/.well-known/matrix/" InternalPathPrefix = "/api/" DendriteAdminPathPrefix = "/_dendrite/" diff --git a/setup/base/base.go b/setup/base/base.go index de8f81517..6ea68119d 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -15,11 +15,14 @@ package base import ( + "bytes" "context" "crypto/tls" "database/sql" + "embed" "encoding/json" "fmt" + "html/template" "io" "net" "net/http" @@ -65,6 +68,9 @@ import ( userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" ) +//go:embed static/*.gotmpl +var staticContent embed.FS + // BaseDendrite is a base for creating new instances of dendrite. It parses // command line flags and config, and exposes methods for creating various // resources. All errors are handled by logging then exiting, so all methods @@ -79,6 +85,7 @@ type BaseDendrite struct { PublicKeyAPIMux *mux.Router PublicMediaAPIMux *mux.Router PublicWellKnownAPIMux *mux.Router + PublicStaticMux *mux.Router InternalAPIMux *mux.Router DendriteAdminMux *mux.Router SynapseAdminMux *mux.Router @@ -250,6 +257,7 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base PublicKeyAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicKeyPathPrefix).Subrouter().UseEncodedPath(), PublicMediaAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicMediaPathPrefix).Subrouter().UseEncodedPath(), PublicWellKnownAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicWellKnownPrefix).Subrouter().UseEncodedPath(), + PublicStaticMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicStaticPath).Subrouter().UseEncodedPath(), InternalAPIMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.InternalPathPrefix).Subrouter().UseEncodedPath(), DendriteAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.DendriteAdminPathPrefix).Subrouter().UseEncodedPath(), SynapseAdminMux: mux.NewRouter().SkipClean(true).PathPrefix(httputil.SynapseAdminPathPrefix).Subrouter().UseEncodedPath(), @@ -405,6 +413,7 @@ func (b *BaseDendrite) configureHTTPErrors() { for _, router := range []*mux.Router{ b.PublicMediaAPIMux, b.DendriteAdminMux, b.SynapseAdminMux, b.PublicWellKnownAPIMux, + b.PublicStaticMux, } { router.NotFoundHandler = notFoundCORSHandler router.MethodNotAllowedHandler = notAllowedCORSHandler @@ -478,6 +487,11 @@ func (b *BaseDendrite) SetupAndServeHTTP( b.configureHTTPErrors() + //Redirect for Landing Page + externalRouter.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, httputil.PublicStaticPath, http.StatusFound) + }) + internalRouter.PathPrefix(httputil.InternalPathPrefix).Handler(b.InternalAPIMux) if b.Cfg.Global.Metrics.Enabled { internalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), b.Cfg.Global.Metrics.BasicAuth)) @@ -485,6 +499,19 @@ func (b *BaseDendrite) SetupAndServeHTTP( b.ConfigureAdminEndpoints() + // Parse and execute the landing page template + tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) + landingPage := &bytes.Buffer{} + if err := tmpl.ExecuteTemplate(landingPage, "index.gotmpl", map[string]string{ + "Version": internal.VersionString(), + }); err != nil { + logrus.WithError(err).Fatal("failed to execute landing page template") + } + + b.PublicStaticMux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { + _, _ = w.Write(landingPage.Bytes()) + }) + var clientHandler http.Handler clientHandler = b.PublicClientAPIMux if b.Cfg.Global.Sentry.Enabled { @@ -510,6 +537,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( externalRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(b.SynapseAdminMux) externalRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(b.PublicMediaAPIMux) externalRouter.PathPrefix(httputil.PublicWellKnownPrefix).Handler(b.PublicWellKnownAPIMux) + externalRouter.PathPrefix(httputil.PublicStaticPath).Handler(b.PublicStaticMux) b.startupLock.Unlock() if internalAddr != NoListener && internalAddr != externalAddr { diff --git a/setup/base/base_test.go b/setup/base/base_test.go new file mode 100644 index 000000000..61cb530a9 --- /dev/null +++ b/setup/base/base_test.go @@ -0,0 +1,57 @@ +package base_test + +import ( + "bytes" + "embed" + "html/template" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/stretchr/testify/assert" +) + +//go:embed static/*.gotmpl +var staticContent embed.FS + +func TestLandingPage(t *testing.T) { + // generate the expected result + tmpl := template.Must(template.ParseFS(staticContent, "static/*.gotmpl")) + expectedRes := &bytes.Buffer{} + err := tmpl.ExecuteTemplate(expectedRes, "index.gotmpl", map[string]string{ + "Version": internal.VersionString(), + }) + assert.NoError(t, err) + + b, _, _ := testrig.Base(nil) + defer b.Close() + + // hack: create a server and close it immediately, just to get a random port assigned + s := httptest.NewServer(nil) + s.Close() + + // start base with the listener and wait for it to be started + go b.SetupAndServeHTTP("", config.HTTPAddress(s.URL), nil, nil) + time.Sleep(time.Millisecond * 10) + + // When hitting /, we should be redirected to /_matrix/static, which should contain the landing page + req, err := http.NewRequest(http.MethodGet, s.URL, nil) + assert.NoError(t, err) + + // do the request + resp, err := s.Client().Do(req) + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // read the response + buf := &bytes.Buffer{} + _, err = buf.ReadFrom(resp.Body) + assert.NoError(t, err) + + // Using .String() for user friendly output + assert.Equal(t, expectedRes.String(), buf.String(), "response mismatch") +} diff --git a/setup/base/static/index.gotmpl b/setup/base/static/index.gotmpl new file mode 100644 index 000000000..b3c5576eb --- /dev/null +++ b/setup/base/static/index.gotmpl @@ -0,0 +1,63 @@ + + + + Dendrite is running + + + + +

It works! Dendrite {{ .Version }} is running

+

Your Dendrite server is listening on this port and is ready for messages.

+

To use this server you'll need a Matrix client. +

+

Welcome to the Matrix universe :)

+
+

+ + + matrix.org + + +

+ + From 24a865aeb76dbff1a8a95d549e19070ff102e441 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 26 Jan 2023 10:15:42 -0700 Subject: [PATCH 103/124] Move relay arch into relayapi and add docs for new endpoints --- .../ARCHITECTURE.md | 77 ++++++++++++++++++- 1 file changed, 76 insertions(+), 1 deletion(-) rename {cmd/dendrite-demo-pinecone => relayapi}/ARCHITECTURE.md (57%) diff --git a/cmd/dendrite-demo-pinecone/ARCHITECTURE.md b/relayapi/ARCHITECTURE.md similarity index 57% rename from cmd/dendrite-demo-pinecone/ARCHITECTURE.md rename to relayapi/ARCHITECTURE.md index 1b0941053..f4cf529b4 100644 --- a/cmd/dendrite-demo-pinecone/ARCHITECTURE.md +++ b/relayapi/ARCHITECTURE.md @@ -10,7 +10,7 @@ Transactions include a list of PDUs, (which contain, among other things, lists o There is no additional information sent along with the transaction other than what is typically added to them during Matrix federation today. In the future this will probably need to change in order to handle more complex room state resolution during p2p usage. -### Relay Server Architecture +### Design ``` 0 +--------------------+ @@ -57,3 +57,78 @@ But in order to achieve this, you are either relying on p2p presence broadcasts #### Recipient-Side Relay Servers If we have agreed to some static relay server before going off and doing other things, or if we are talking about more global p2p federation, then having a recipient designated relay server can cut down on redundant traffic since it will sit there idle until the recipient pulls events from it. + +### API + +Relay servers make use of 2 new matrix federation endpoints. +These are: +- PUT /_matrix/federation/v1/send_relay/{txnID}/{userID} +- GET /_matrix/federation/v1/relay_txn/{userID} + +#### Send_Relay + +The `send_relay` endpoint is used to send events to a relay server that are destined for some other node. Servers can send events to this endpoint if they wish for the relay server to store & forward events for them when they go offline. + +##### Request + +###### Request Parameters + +| Name | Type | Description | +|--------|--------|-----------------------------------------------------| +| txnID | string | **Required:** The transaction ID. | +| userID | string | **Required:** The destination for this transaciton. | + +###### Request Body + +| Name | Type | Description | +|--------|--------|----------------------------------------| +| pdus | [PDU] | **Required:** List of pdus. Max 50. | +| edus | [EDU] | List of edus. May be omitted. Max 100. | + +##### Responses + +| Code | Reason | +|--------|--------------------------------------------------| +| 200 | Successfully stored transaction for forwarding. | +| 400 | Invalid userID. | +| 400 | Invalid request body. | +| 400 | Too many pdus or edus. | +| 500 | Server failed processing transaction. | + +#### Relay_Txn + +The `relay_txn` endpoint is used to get events from a relay server that are destined for you. Servers can send events to this endpoint if they wish for the relay server to store & forward events for them when they go offline. + +##### Request + +**This needs to be changed to prevent nodes from obtaining transactions not destined for them. Possibly by adding a signature field to the request.** + +###### Request Parameters + +| Name | Type | Description | +|--------|--------|----------------------------------------------------------------| +| userID | string | **Required:** The user ID that events are being requested for. | + +###### Request Body + +| Name | Type | Description | +|----------|--------|----------------------------------------| +| entry_id | int64 | **Required:** The id of the previous transaction received from the relay. Provided in the previous response to this endpoint. | + +##### Responses + +| Code | Reason | +|--------|--------------------------------------------------| +| 200 | Successfully stored transaction for forwarding. | +| 400 | Invalid userID. | +| 400 | Invalid request body. | +| 400 | Invalid previous entry. Must be >= 0 | +| 500 | Server failed processing transaction. | + +###### 200 Response Body + +| Name | Type | Description | +|----------------|--------------|--------------------------------------------------------------------------| +| transaction | Transaction | **Required:** A matrix transaction. | +| entry_id | int64 | An ID associated with this transaction. | +| entries_queued | bool | **Required:** Whether or not there are more events stored for this user. | From 2debabf0f09bb6e55063bbaa00dfb77090789abc Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 26 Jan 2023 10:58:44 -0700 Subject: [PATCH 104/124] Bump bleve to v2.3.6 --- go.mod | 28 ++++++++-------- go.sum | 100 ++++++++++++++++----------------------------------------- 2 files changed, 42 insertions(+), 86 deletions(-) diff --git a/go.mod b/go.mod index 871e94eb3..8d5eafcef 100644 --- a/go.mod +++ b/go.mod @@ -6,7 +6,7 @@ require ( github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/MFAshby/stdemuxerhook v1.0.0 github.com/Masterminds/semver/v3 v3.1.1 - github.com/blevesearch/bleve/v2 v2.3.4 + github.com/blevesearch/bleve/v2 v2.3.6 github.com/codeclysm/extract v2.2.0+incompatible github.com/dgraph-io/ristretto v0.1.1 github.com/docker/docker v20.10.19+incompatible @@ -58,24 +58,24 @@ require ( require ( github.com/HdrHistogram/hdrhistogram-go v1.1.2 // indirect github.com/Microsoft/go-winio v0.5.2 // indirect - github.com/RoaringBitmap/roaring v1.2.1 // indirect + github.com/RoaringBitmap/roaring v1.2.3 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/bits-and-blooms/bitset v1.3.3 // indirect - github.com/blevesearch/bleve_index_api v1.0.3 // indirect - github.com/blevesearch/geo v0.1.14 // indirect + github.com/bits-and-blooms/bitset v1.5.0 // indirect + github.com/blevesearch/bleve_index_api v1.0.5 // indirect + github.com/blevesearch/geo v0.1.17 // indirect github.com/blevesearch/go-porterstemmer v1.0.3 // indirect github.com/blevesearch/gtreap v0.1.1 // indirect github.com/blevesearch/mmap-go v1.0.4 // indirect - github.com/blevesearch/scorch_segment_api/v2 v2.1.2 // indirect - github.com/blevesearch/segment v0.9.0 // indirect + github.com/blevesearch/scorch_segment_api/v2 v2.1.4 // indirect + github.com/blevesearch/segment v0.9.1 // indirect github.com/blevesearch/snowballstem v0.9.0 // indirect - github.com/blevesearch/upsidedown_store_api v1.0.1 // indirect - github.com/blevesearch/vellum v1.0.8 // indirect - github.com/blevesearch/zapx/v11 v11.3.5 // indirect - github.com/blevesearch/zapx/v12 v12.3.5 // indirect - github.com/blevesearch/zapx/v13 v13.3.5 // indirect - github.com/blevesearch/zapx/v14 v14.3.5 // indirect - github.com/blevesearch/zapx/v15 v15.3.5 // indirect + github.com/blevesearch/upsidedown_store_api v1.0.2 // indirect + github.com/blevesearch/vellum v1.0.9 // indirect + github.com/blevesearch/zapx/v11 v11.3.7 // indirect + github.com/blevesearch/zapx/v12 v12.3.7 // indirect + github.com/blevesearch/zapx/v13 v13.3.7 // indirect + github.com/blevesearch/zapx/v14 v14.3.7 // indirect + github.com/blevesearch/zapx/v15 v15.3.8 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/docker/distribution v2.8.1+incompatible // indirect diff --git a/go.sum b/go.sum index 1ca3e8a80..d57960561 100644 --- a/go.sum +++ b/go.sum @@ -50,9 +50,8 @@ github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0 github.com/Microsoft/go-winio v0.5.2 h1:a9IhgEQBCUEk6QCdml9CiJGhAws+YwffDHEMp1VMrpA= github.com/Microsoft/go-winio v0.5.2/go.mod h1:WpS1mjBmmwHBEWmogvA2mj8546UReBk4v8QkMxJ6pZY= github.com/RoaringBitmap/roaring v0.4.7/go.mod h1:8khRDP4HmeXns4xIj9oGrKSz7XTQiJx2zgh7AcNke4w= -github.com/RoaringBitmap/roaring v0.9.4/go.mod h1:icnadbWcNyfEHlYdr+tDlOTih1Bf/h+rzPpv4sbomAA= -github.com/RoaringBitmap/roaring v1.2.1 h1:58/LJlg/81wfEHd5L9qsHduznOIhyv4qb1yWcSvVq9A= -github.com/RoaringBitmap/roaring v1.2.1/go.mod h1:icnadbWcNyfEHlYdr+tDlOTih1Bf/h+rzPpv4sbomAA= +github.com/RoaringBitmap/roaring v1.2.3 h1:yqreLINqIrX22ErkKI0vY47/ivtJr6n+kMhVOVmhWBY= +github.com/RoaringBitmap/roaring v1.2.3/go.mod h1:plvDsJQpxOC5bw8LRteu/MLWHsHez/3y6cubLI4/1yE= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= @@ -70,51 +69,45 @@ github.com/anacrolix/missinggo v1.2.1 h1:0IE3TqX5y5D0IxeMwTyIgqdDew4QrzcXaaEnJQy github.com/anacrolix/missinggo v1.2.1/go.mod h1:J5cMhif8jPmFoC3+Uvob3OXXNIhOUikzMt+uUjeM21Y= github.com/anacrolix/missinggo/perf v1.0.0/go.mod h1:ljAFWkBuzkO12MQclXzZrosP5urunoLS0Cbvb4V0uMQ= github.com/anacrolix/tagflag v0.0.0-20180109131632-2146c8d41bf0/go.mod h1:1m2U/K6ZT+JZG0+bdMK6qauP49QT4wE5pmhJXOKKCHw= -github.com/armon/consul-api v0.0.0-20180202201655-eb2c6b5be1b6/go.mod h1:grANhF5doyWs3UAsr3K4I6qtAmlQcZDesFNEHPZAzj8= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bits-and-blooms/bitset v1.2.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= -github.com/bits-and-blooms/bitset v1.3.3 h1:R1XWiopGiXf66xygsiLpzLo67xEYvMkHw3w+rCOSAwg= -github.com/bits-and-blooms/bitset v1.3.3/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= -github.com/blevesearch/bleve/v2 v2.3.4 h1:SSb7/cwGzo85LWX1jchIsXM8ZiNNMX3shT5lROM63ew= -github.com/blevesearch/bleve/v2 v2.3.4/go.mod h1:Ot0zYum8XQRfPcwhae8bZmNyYubynsoMjVvl1jPqL30= -github.com/blevesearch/bleve_index_api v1.0.3 h1:DDSWaPXOZZJ2BB73ZTWjKxydAugjwywcqU+91AAqcAg= -github.com/blevesearch/bleve_index_api v1.0.3/go.mod h1:fiwKS0xLEm+gBRgv5mumf0dhgFr2mDgZah1pqv1c1M4= -github.com/blevesearch/geo v0.1.13/go.mod h1:cRIvqCdk3cgMhGeHNNe6yPzb+w56otxbfo1FBJfR2Pc= -github.com/blevesearch/geo v0.1.14 h1:TTDpJN6l9ck/cUYbXSn4aCElNls0Whe44rcQKsB7EfU= -github.com/blevesearch/geo v0.1.14/go.mod h1:cRIvqCdk3cgMhGeHNNe6yPzb+w56otxbfo1FBJfR2Pc= -github.com/blevesearch/go-metrics v0.0.0-20201227073835-cf1acfcdf475/go.mod h1:9eJDeqxJ3E7WnLebQUlPD7ZjSce7AnDb9vjGmMCbD0A= +github.com/bits-and-blooms/bitset v1.5.0 h1:NpE8frKRLGHIcEzkR+gZhiioW1+WbYV6fKwD6ZIpQT8= +github.com/bits-and-blooms/bitset v1.5.0/go.mod h1:gIdJ4wp64HaoK2YrL1Q5/N7Y16edYb8uY+O0FJTyyDA= +github.com/blevesearch/bleve/v2 v2.3.6 h1:NlntUHcV5CSWIhpugx4d/BRMGCiaoI8ZZXrXlahzNq4= +github.com/blevesearch/bleve/v2 v2.3.6/go.mod h1:JM2legf1cKVkdV8Ehu7msKIOKC0McSw0Q16Fmv9vsW4= +github.com/blevesearch/bleve_index_api v1.0.5 h1:Lc986kpC4Z0/n1g3gg8ul7H+lxgOQPcXb9SxvQGu+tw= +github.com/blevesearch/bleve_index_api v1.0.5/go.mod h1:YXMDwaXFFXwncRS8UobWs7nvo0DmusriM1nztTlj1ms= +github.com/blevesearch/geo v0.1.17 h1:AguzI6/5mHXapzB0gE9IKWo+wWPHZmXZoscHcjFgAFA= +github.com/blevesearch/geo v0.1.17/go.mod h1:uRMGWG0HJYfWfFJpK3zTdnnr1K+ksZTuWKhXeSokfnM= github.com/blevesearch/go-porterstemmer v1.0.3 h1:GtmsqID0aZdCSNiY8SkuPJ12pD4jI+DdXTAn4YRcHCo= github.com/blevesearch/go-porterstemmer v1.0.3/go.mod h1:angGc5Ht+k2xhJdZi511LtmxuEf0OVpvUUNrwmM1P7M= -github.com/blevesearch/goleveldb v1.0.1/go.mod h1:WrU8ltZbIp0wAoig/MHbrPCXSOLpe79nz5lv5nqfYrQ= github.com/blevesearch/gtreap v0.1.1 h1:2JWigFrzDMR+42WGIN/V2p0cUvn4UP3C4Q5nmaZGW8Y= github.com/blevesearch/gtreap v0.1.1/go.mod h1:QaQyDRAT51sotthUWAH4Sj08awFSSWzgYICSZ3w0tYk= -github.com/blevesearch/mmap-go v1.0.2/go.mod h1:ol2qBqYaOUsGdm7aRMRrYGgPvnwLe6Y+7LMvAB5IbSA= github.com/blevesearch/mmap-go v1.0.4 h1:OVhDhT5B/M1HNPpYPBKIEJaD0F3Si+CrEKULGCDPWmc= github.com/blevesearch/mmap-go v1.0.4/go.mod h1:EWmEAOmdAS9z/pi/+Toxu99DnsbhG1TIxUoRmJw/pSs= -github.com/blevesearch/scorch_segment_api/v2 v2.1.2 h1:TAte9VZLWda5WAVlZTTZ+GCzEHqGJb4iB2aiZSA6Iv8= -github.com/blevesearch/scorch_segment_api/v2 v2.1.2/go.mod h1:rvoQXZGq8drq7vXbNeyiRzdEOwZkjkiYGf1822i6CRA= -github.com/blevesearch/segment v0.9.0 h1:5lG7yBCx98or7gK2cHMKPukPZ/31Kag7nONpoBt22Ac= -github.com/blevesearch/segment v0.9.0/go.mod h1:9PfHYUdQCgHktBgvtUOF4x+pc4/l8rdH0u5spnW85UQ= -github.com/blevesearch/snowball v0.6.1/go.mod h1:ZF0IBg5vgpeoUhnMza2v0A/z8m1cWPlwhke08LpNusg= +github.com/blevesearch/scorch_segment_api/v2 v2.1.4 h1:LmGmo5twU3gV+natJbKmOktS9eMhokPGKWuR+jX84vk= +github.com/blevesearch/scorch_segment_api/v2 v2.1.4/go.mod h1:PgVnbbg/t1UkgezPDu8EHLi1BHQ17xUwsFdU6NnOYS0= +github.com/blevesearch/segment v0.9.1 h1:+dThDy+Lvgj5JMxhmOVlgFfkUtZV2kw49xax4+jTfSU= +github.com/blevesearch/segment v0.9.1/go.mod h1:zN21iLm7+GnBHWTao9I+Au/7MBiL8pPFtJBJTsk6kQw= github.com/blevesearch/snowballstem v0.9.0 h1:lMQ189YspGP6sXvZQ4WZ+MLawfV8wOmPoD/iWeNXm8s= github.com/blevesearch/snowballstem v0.9.0/go.mod h1:PivSj3JMc8WuaFkTSRDW2SlrulNWPl4ABg1tC/hlgLs= -github.com/blevesearch/upsidedown_store_api v1.0.1 h1:1SYRwyoFLwG3sj0ed89RLtM15amfX2pXlYbFOnF8zNU= -github.com/blevesearch/upsidedown_store_api v1.0.1/go.mod h1:MQDVGpHZrpe3Uy26zJBf/a8h0FZY6xJbthIMm8myH2Q= -github.com/blevesearch/vellum v1.0.8 h1:iMGh4lfxza4BnWO/UJTMPlI3HsK9YawjPv+TteVa9ck= -github.com/blevesearch/vellum v1.0.8/go.mod h1:+cpRi/tqq49xUYSQN2P7A5zNSNrS+MscLeeaZ3J46UA= -github.com/blevesearch/zapx/v11 v11.3.5 h1:eBQWQ7huA+mzm0sAGnZDwgGGli7S45EO+N+ObFWssbI= -github.com/blevesearch/zapx/v11 v11.3.5/go.mod h1:5UdIa/HRMdeRCiLQOyFESsnqBGiip7vQmYReA9toevU= -github.com/blevesearch/zapx/v12 v12.3.5 h1:5pX2hU+R1aZihT7ac1dNWh1n4wqkIM9pZzWp0ANED9s= -github.com/blevesearch/zapx/v12 v12.3.5/go.mod h1:ANcthYRZQycpbRut/6ArF5gP5HxQyJqiFcuJCBju/ss= -github.com/blevesearch/zapx/v13 v13.3.5 h1:eJ3gbD+Nu8p36/O6lhfdvWQ4pxsGYSuTOBrLLPVWJ74= -github.com/blevesearch/zapx/v13 v13.3.5/go.mod h1:FV+dRnScFgKnRDIp08RQL4JhVXt1x2HE3AOzqYa6fjo= -github.com/blevesearch/zapx/v14 v14.3.5 h1:hEvVjZaagFCvOUJrlFQ6/Z6Jjy0opM3g7TMEo58TwP4= -github.com/blevesearch/zapx/v14 v14.3.5/go.mod h1:954A/eKFb+pg/ncIYWLWCKY+mIjReM9FGTGIO2Wu1cU= -github.com/blevesearch/zapx/v15 v15.3.5 h1:NVD0qq8vRk66ImJn1KloXT5ckqPDUZT7VbVJs9jKlac= -github.com/blevesearch/zapx/v15 v15.3.5/go.mod h1:QMUh2hXCaYIWFKPYGavq/Iga2zbHWZ9DZAa9uFbWyvg= +github.com/blevesearch/upsidedown_store_api v1.0.2 h1:U53Q6YoWEARVLd1OYNc9kvhBMGZzVrdmaozG2MfoB+A= +github.com/blevesearch/upsidedown_store_api v1.0.2/go.mod h1:M01mh3Gpfy56Ps/UXHjEO/knbqyQ1Oamg8If49gRwrQ= +github.com/blevesearch/vellum v1.0.9 h1:PL+NWVk3dDGPCV0hoDu9XLLJgqU4E5s/dOeEJByQ2uQ= +github.com/blevesearch/vellum v1.0.9/go.mod h1:ul1oT0FhSMDIExNjIxHqJoGpVrBpKCdgDQNxfqgJt7k= +github.com/blevesearch/zapx/v11 v11.3.7 h1:Y6yIAF/DVPiqZUA/jNgSLXmqewfzwHzuwfKyfdG+Xaw= +github.com/blevesearch/zapx/v11 v11.3.7/go.mod h1:Xk9Z69AoAWIOvWudNDMlxJDqSYGf90LS0EfnaAIvXCA= +github.com/blevesearch/zapx/v12 v12.3.7 h1:DfQ6rsmZfEK4PzzJJRXjiM6AObG02+HWvprlXQ1Y7eI= +github.com/blevesearch/zapx/v12 v12.3.7/go.mod h1:SgEtYIBGvM0mgIBn2/tQE/5SdrPXaJUaT/kVqpAPxm0= +github.com/blevesearch/zapx/v13 v13.3.7 h1:igIQg5eKmjw168I7av0Vtwedf7kHnQro/M+ubM4d2l8= +github.com/blevesearch/zapx/v13 v13.3.7/go.mod h1:yyrB4kJ0OT75UPZwT/zS+Ru0/jYKorCOOSY5dBzAy+s= +github.com/blevesearch/zapx/v14 v14.3.7 h1:gfe+fbWslDWP/evHLtp/GOvmNM3sw1BbqD7LhycBX20= +github.com/blevesearch/zapx/v14 v14.3.7/go.mod h1:9J/RbOkqZ1KSjmkOes03AkETX7hrXT0sFMpWH4ewC4w= +github.com/blevesearch/zapx/v15 v15.3.8 h1:q4uMngBHzL1IIhRc8AJUEkj6dGOE3u1l3phLu7hq8uk= +github.com/blevesearch/zapx/v15 v15.3.8/go.mod h1:m7Y6m8soYUvS7MjN9eKlz1xrLCcmqfFadmu7GhWIrLY= github.com/bradfitz/iter v0.0.0-20140124041915-454541ec3da2/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= github.com/bradfitz/iter v0.0.0-20190303215204-33e6a9893b0c/go.mod h1:PyRFw1Lt2wKX4ZVSQ2mk+PeDa1rxyObEDlApuIsUKuo= github.com/bradfitz/iter v0.0.0-20191230175014-e8f45d346db8 h1:GKTyiRCL6zVf5wWaqKnf+7Qs6GbEPfd4iMOitWzXJx8= @@ -130,12 +123,6 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/codeclysm/extract v2.2.0+incompatible h1:q3wyckoA30bhUSiwdQezMqVhwd8+WGE64/GL//LtUhI= github.com/codeclysm/extract v2.2.0+incompatible/go.mod h1:2nhFMPHiU9At61hz+12bfrlpXSUrOnK+wR+KlGO4Uks= -github.com/coreos/etcd v3.3.10+incompatible/go.mod h1:uF7uidLiAD3TWHmW31ZFd/JWoc32PjwdhPthX9715RE= -github.com/coreos/go-etcd v2.0.0+incompatible/go.mod h1:Jez6KQU2B/sWsbdaef3ED8NzMklzPG4d5KIOhIy30Tk= -github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= -github.com/couchbase/ghistogram v0.1.0/go.mod h1:s1Jhy76zqfEecpNWJfWUiKZookAFaiGOEoyzgHt9i7k= -github.com/couchbase/moss v0.2.0/go.mod h1:9MaHIaRuy9pvLPUJxB8sh8OrLfyDczECVL37grCIubs= -github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -164,7 +151,6 @@ github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7 github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/frankban/quicktest v1.0.0/go.mod h1:R98jIehRai+d1/3Hv2//jOVCTJhW1VBavT6B6CuGq2k= github.com/frankban/quicktest v1.14.3 h1:FJKSZTDHjyhriyC81FLQ0LY93eSai0ZyR/ZIkd3ZUKE= -github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/getsentry/sentry-go v0.14.0 h1:rlOBkuFZRKKdUnKO+0U3JclRDQKlRu5vVQtkWSQvC70= github.com/getsentry/sentry-go v0.14.0/go.mod h1:RZPJKSw+adu8PBNygiri/A98FqVr2HtRckJk9XVxJ9I= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= @@ -241,7 +227,6 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= -github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gologme/log v1.3.0 h1:l781G4dE+pbigClDSDzSaaYKtiueHCILUa/qSDsmHAo= @@ -261,7 +246,6 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -289,15 +273,11 @@ github.com/h2non/filetype v1.1.3/go.mod h1:319b3zT68BvV+WRj7cwy856M2ehB3HqNOt6sy github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huandu/xstrings v1.0.0 h1:pO2K/gKgKaat5LdpAhxhluX2GPQMaI3W5FUz/I/UnWk= github.com/huandu/xstrings v1.0.0/go.mod h1:4qWG/gcEcfX4z/mBDHJ++3ReCw9ibxbsNJbcucJdbSo= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4= -github.com/json-iterator/go v0.0.0-20171115153421-f7279a603ede/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.10/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= @@ -337,7 +317,6 @@ github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lucas-clemente/quic-go v0.30.0 h1:nwLW0h8ahVQ5EPTIM7uhl/stHqQDea15oRlYKZmw2O0= github.com/lucas-clemente/quic-go v0.30.0/go.mod h1:ssOrRsOmdxa768Wr78vnh2B8JozgLsMzG/g+0qEC7uk= -github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= github.com/marten-seemann/qtls-go1-19 v0.1.1 h1:mnbxeq3oEyQxQXwI4ReCgW9DPoPR94sNlqWoDZnjRIE= @@ -365,8 +344,6 @@ github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182aff github.com/miekg/dns v1.1.50 h1:DQUfb9uc6smULcREF09Uc+/Gd46YWqJd5DbpPE9xkcA= github.com/minio/highwayhash v1.0.2 h1:Aak5U0nElisjDCfPSG79Tgzkn2gl66NxOMspRrKnA/g= github.com/minio/highwayhash v1.0.2/go.mod h1:BQskDq+xkJ12lmlUUi7U0M5Swg3EWR+dLTk+kldvVxY= -github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= -github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/moby/term v0.0.0-20220808134915-39b0c02b01ae h1:O4SWKdcHVCvYqyDV+9CJA1fcDN2L11Bule0iFy3YlAI= github.com/moby/term v0.0.0-20220808134915-39b0c02b01ae/go.mod h1:E2VnQOmVuvZB6UYnnDB0qG5Nq/1tD9acaOpo6xmt0Kw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -400,11 +377,8 @@ github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 h1:Dmx8g2747UTVPzSkmohk84S3g/uWqd6+f4SSLPhLcfA= github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79/go.mod h1:E26fwEtRNigBfFfHDWsklmo0T7Ixbg0XXgck+Hq4O9k= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= -github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo/v2 v2.3.0 h1:kUMoxMoQG3ogk/QWyKh3zibV7BKZ+xBpWil1cTylVqc= github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= -github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= github.com/onsi/gomega v1.22.1 h1:pY8O4lBfsHKZHM/6nrxkhVPUznOlIu3quZcKP/M20KI= github.com/onsi/gomega v1.22.1/go.mod h1:x6n7VNe4hw0vkyYUM4mjIXx3JbLiPaBPNgB7PRQ1tuM= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= @@ -415,8 +389,6 @@ github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+ github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pelletier/go-toml v1.2.0 h1:T5zMGML61Wp+FlcbWjRDT7yAxhJNAiPPLOFECq181zc= -github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= github.com/pelletier/go-toml/v2 v2.0.5 h1:ipoSadvV8oGUjnUbMub59IDPPwfxF694nG/jwbMiyQg= github.com/philhofer/fwd v1.0.0/go.mod h1:gk3iGcWd9+svBvR0sR+KPcfE+RNWozjowpeBVG3ZVNU= github.com/pingcap/errors v0.11.4 h1:lFuQV/oaUMGcD2tqt+01ROSmJs75VG1ToEOkZIZ4nE4= @@ -458,7 +430,6 @@ github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa/go.mod h1:qq github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/russross/blackfriday v1.5.2/go.mod h1:JO/DiYxRf+HjHt06OyowR9PTA263kcR/rfWxYHBV53g= github.com/ryszard/goskiplist v0.0.0-20150312221310-2dfbae5fcf46/go.mod h1:uAQ5PCi+MFsC7HjREoAz1BU+Mq60+05gifQSsHSDG/8= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= @@ -468,12 +439,7 @@ github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0 github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v0.0.0-20181108003508-044398e4856c/go.mod h1:XDJAKZRPZ1CvBcN2aX5YOUTYGHki24fSF0Iv48Ibg0s= -github.com/spf13/afero v1.1.2/go.mod h1:j4pytiNVoe2o6bmDsKpLACNPDBIoEAkihy7loJ1B0CQ= -github.com/spf13/cast v1.3.0/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= -github.com/spf13/cobra v0.0.5/go.mod h1:3K3wKZymM7VvHMDS9+Akkh4K60UwM26emMESw8tLCHU= -github.com/spf13/jwalterweatherman v1.0.0/go.mod h1:cQK4TGJAtQXfYWX+Ddv3mKDzgVb68N+wFjFa4jdeBTo= github.com/spf13/pflag v1.0.3/go.mod h1:DYY7MBk1bdzusC3SYhjObp+wFpr4gzcvqqNjLnInEg4= -github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DMA2s= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= @@ -506,11 +472,9 @@ github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVK github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= -github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0= github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4= -github.com/xordataexchange/crypt v0.0.3-0.20170626215501-b2862e3d0a77/go.mod h1:aYKd//L2LvnjZzWKhF00oedf4jCCReLcmhLdhm1A27Q= github.com/yggdrasil-network/yggdrasil-go v0.4.6 h1:GALUDV9QPz/5FVkbazpkTc9EABHufA556JwUJZr41j4= github.com/yggdrasil-network/yggdrasil-go v0.4.6/go.mod h1:PBMoAOvQjA9geNEeGyMXA9QgCS6Bu+9V+1VkWM84wpw= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -519,7 +483,6 @@ github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9de github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU= go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= @@ -531,7 +494,6 @@ go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= -golang.org/x/crypto v0.0.0-20181203042331-505ab145d0a9/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -588,7 +550,6 @@ golang.org/x/mod v0.6.0 h1:b9gGHsz9/HhJ3HF5DHQytPpuwocVTChQJK3AvoLRD5I= golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -645,10 +606,7 @@ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181205085412-a5c9d58dba9a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20181221143128-b4a75ba826a6/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190130150945-aca44879d564/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -861,13 +819,11 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/h2non/bimg.v1 v1.1.9 h1:wZIUbeOnwr37Ta4aofhIv8OI8v4ujpjXC9mXnAGpQjM= gopkg.in/h2non/bimg.v1 v1.1.9/go.mod h1:PgsZL7dLwUbsGm1NYps320GxGgvQNTnecMCZqxV11So= gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY= gopkg.in/macaroon.v2 v2.1.0 h1:HZcsjBCzq9t0eBPMKqTN/uSN6JOm78ZJ2INbqcBQOUI= gopkg.in/macaroon.v2 v2.1.0/go.mod h1:OUb+TQP/OP0WOerC2Jp/3CwhIKyIa9kQjuc7H24e6/o= -gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= From 63df85db6d5bc528a784dc52e550fc64385c5f67 Mon Sep 17 00:00:00 2001 From: devonh Date: Sat, 28 Jan 2023 23:27:53 +0000 Subject: [PATCH 105/124] Relay integration to pinecone demos (#2955) This extends the dendrite monolith for pinecone to integrate the s&f features into the mobile apps. Also makes a few tweaks to federation queueing/statistics to make some edge cases more robust. --- build/gobind-pinecone/monolith.go | 287 +++++++++++++++++++----- build/gobind-pinecone/monolith_test.go | 124 ++++++++++ federationapi/api/api.go | 32 ++- federationapi/internal/perform.go | 30 +++ federationapi/internal/perform_test.go | 41 ++++ federationapi/inthttp/client.go | 24 ++ federationapi/queue/destinationqueue.go | 67 +++--- federationapi/queue/queue_test.go | 4 +- federationapi/statistics/statistics.go | 2 + setup/config/config_federationapi.go | 2 +- test/memory_federation_db.go | 31 ++- 11 files changed, 559 insertions(+), 85 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index ff61ea6c8..5e8e5875c 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -78,19 +78,19 @@ const ( ) type DendriteMonolith struct { - logger logrus.Logger - baseDendrite *base.BaseDendrite - PineconeRouter *pineconeRouter.Router - PineconeMulticast *pineconeMulticast.Multicast - PineconeQUIC *pineconeSessions.Sessions - PineconeManager *pineconeConnections.ConnectionManager - StorageDirectory string - CacheDirectory string - listener net.Listener - httpServer *http.Server - userAPI userapiAPI.UserInternalAPI - federationAPI api.FederationInternalAPI - relayServersQueried map[gomatrixserverlib.ServerName]bool + logger logrus.Logger + baseDendrite *base.BaseDendrite + PineconeRouter *pineconeRouter.Router + PineconeMulticast *pineconeMulticast.Multicast + PineconeQUIC *pineconeSessions.Sessions + PineconeManager *pineconeConnections.ConnectionManager + StorageDirectory string + CacheDirectory string + listener net.Listener + httpServer *http.Server + userAPI userapiAPI.UserInternalAPI + federationAPI api.FederationInternalAPI + relayRetriever RelayServerRetriever } func (m *DendriteMonolith) PublicKey() string { @@ -167,6 +167,152 @@ func (m *DendriteMonolith) SetStaticPeer(uri string) { } } +func getServerKeyFromString(nodeID string) (gomatrixserverlib.ServerName, error) { + var nodeKey gomatrixserverlib.ServerName + if userID, err := gomatrixserverlib.NewUserID(nodeID, false); err == nil { + hexKey, decodeErr := hex.DecodeString(string(userID.Domain())) + if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { + return "", fmt.Errorf("UserID domain is not a valid ed25519 public key: %v", userID.Domain()) + } else { + nodeKey = userID.Domain() + } + } else { + hexKey, decodeErr := hex.DecodeString(nodeID) + if decodeErr != nil || len(hexKey) != ed25519.PublicKeySize { + return "", fmt.Errorf("Relay server uri is not a valid ed25519 public key: %v", nodeID) + } else { + nodeKey = gomatrixserverlib.ServerName(nodeID) + } + } + + return nodeKey, nil +} + +func updateNodeRelayServers( + node gomatrixserverlib.ServerName, + relays []gomatrixserverlib.ServerName, + ctx context.Context, + fedAPI api.FederationInternalAPI, +) { + // Get the current relay list + request := api.P2PQueryRelayServersRequest{Server: node} + response := api.P2PQueryRelayServersResponse{} + err := fedAPI.P2PQueryRelayServers(ctx, &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of relay servers for %s: %s", node, err.Error()) + } + + // Remove old, non-matching relays + var serversToRemove []gomatrixserverlib.ServerName + for _, existingServer := range response.RelayServers { + shouldRemove := true + for _, newServer := range relays { + if newServer == existingServer { + shouldRemove = false + break + } + } + + if shouldRemove { + serversToRemove = append(serversToRemove, existingServer) + } + } + removeRequest := api.P2PRemoveRelayServersRequest{ + Server: node, + RelayServers: serversToRemove, + } + removeResponse := api.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(ctx, &removeRequest, &removeResponse) + if err != nil { + logrus.Warnf("Failed removing old relay servers for %s: %s", node, err.Error()) + } + + // Add new relays + addRequest := api.P2PAddRelayServersRequest{ + Server: node, + RelayServers: relays, + } + addResponse := api.P2PAddRelayServersResponse{} + err = fedAPI.P2PAddRelayServers(ctx, &addRequest, &addResponse) + if err != nil { + logrus.Warnf("Failed adding relay servers for %s: %s", node, err.Error()) + } +} + +func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { + relays := []gomatrixserverlib.ServerName{} + for _, uri := range strings.Split(uris, ",") { + uri = strings.TrimSpace(uri) + if len(uri) == 0 { + continue + } + + nodeKey, err := getServerKeyFromString(uri) + if err != nil { + logrus.Errorf(err.Error()) + continue + } + relays = append(relays, nodeKey) + } + + nodeKey, err := getServerKeyFromString(nodeID) + if err != nil { + logrus.Errorf(err.Error()) + return + } + + if string(nodeKey) == m.PublicKey() { + logrus.Infof("Setting own relay servers to: %v", relays) + m.relayRetriever.SetRelayServers(relays) + } else { + updateNodeRelayServers( + gomatrixserverlib.ServerName(nodeKey), + relays, + m.baseDendrite.Context(), + m.federationAPI, + ) + } +} + +func (m *DendriteMonolith) GetRelayServers(nodeID string) string { + nodeKey, err := getServerKeyFromString(nodeID) + if err != nil { + logrus.Errorf(err.Error()) + return "" + } + + relaysString := "" + if string(nodeKey) == m.PublicKey() { + relays := m.relayRetriever.GetRelayServers() + + for i, relay := range relays { + if i != 0 { + // Append a comma to the previous entry if there is one. + relaysString += "," + } + relaysString += string(relay) + } + } else { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(nodeKey)} + response := api.P2PQueryRelayServersResponse{} + err := m.federationAPI.P2PQueryRelayServers(m.baseDendrite.Context(), &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + return "" + } + + for i, relay := range response.RelayServers { + if i != 0 { + // Append a comma to the previous entry if there is one. + relaysString += "," + } + relaysString += string(relay) + } + } + + return relaysString +} + func (m *DendriteMonolith) DisconnectType(peertype int) { for _, p := range m.PineconeRouter.Peers() { if int(peertype) == p.PeerType { @@ -454,28 +600,28 @@ func (m *DendriteMonolith) Start() { } }() - go func(ch <-chan pineconeEvents.Event) { - eLog := logrus.WithField("pinecone", "events") - stopRelayServerSync := make(chan bool) + stopRelayServerSync := make(chan bool) - relayRetriever := RelayServerRetriever{ - Context: context.Background(), - ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), - FederationAPI: m.federationAPI, - relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - RelayAPI: monolith.RelayAPI, - running: *atomic.NewBool(false), - } - relayRetriever.InitializeRelayServers(eLog) + eLog := logrus.WithField("pinecone", "events") + m.relayRetriever = RelayServerRetriever{ + Context: context.Background(), + ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), + FederationAPI: m.federationAPI, + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + RelayAPI: monolith.RelayAPI, + running: *atomic.NewBool(false), + quit: stopRelayServerSync, + } + m.relayRetriever.InitializeRelayServers(eLog) + + go func(ch <-chan pineconeEvents.Event) { for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: - if !relayRetriever.running.Load() { - go relayRetriever.SyncRelayServers(stopRelayServerSync) - } + m.relayRetriever.StartSync() case pineconeEvents.PeerRemoved: - if relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { + if m.relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { stopRelayServerSync <- true } case pineconeEvents.BroadcastReceived: @@ -495,7 +641,7 @@ func (m *DendriteMonolith) Start() { } func (m *DendriteMonolith) Stop() { - m.baseDendrite.Close() + _ = m.baseDendrite.Close() m.baseDendrite.WaitForShutdown() _ = m.listener.Close() m.PineconeMulticast.Stop() @@ -511,32 +657,68 @@ type RelayServerRetriever struct { relayServersQueried map[gomatrixserverlib.ServerName]bool queriedServersMutex sync.Mutex running atomic.Bool + quit <-chan bool } -func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { - request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} +func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(r.ServerName)} response := api.P2PQueryRelayServersResponse{} - err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) + err := r.FederationAPI.P2PQueryRelayServers(r.Context, &request, &response) if err != nil { eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) } + + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() for _, server := range response.RelayServers { - m.relayServersQueried[server] = false + r.relayServersQueried[server] = false } eLog.Infof("Registered relay servers: %v", response.RelayServers) } -func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { - defer m.running.Store(false) +func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.ServerName) { + updateNodeRelayServers(r.ServerName, servers, r.Context, r.FederationAPI) + + // Replace list of servers to sync with and mark them all as unsynced. + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool) + for _, server := range servers { + r.relayServersQueried[server] = false + } + + r.StartSync() +} + +func (r *RelayServerRetriever) GetRelayServers() []gomatrixserverlib.ServerName { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + relayServers := []gomatrixserverlib.ServerName{} + for server := range r.relayServersQueried { + relayServers = append(relayServers, server) + } + + return relayServers +} + +func (r *RelayServerRetriever) StartSync() { + if !r.running.Load() { + logrus.Info("Starting relay server sync") + go r.SyncRelayServers(r.quit) + } +} + +func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { + defer r.running.Store(false) t := time.NewTimer(relayServerRetryInterval) for { relayServersToQuery := []gomatrixserverlib.ServerName{} func() { - m.queriedServersMutex.Lock() - defer m.queriedServersMutex.Unlock() - for server, complete := range m.relayServersQueried { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + for server, complete := range r.relayServersQueried { if !complete { relayServersToQuery = append(relayServersToQuery, server) } @@ -544,9 +726,10 @@ func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { }() if len(relayServersToQuery) == 0 { // All relay servers have been synced. + logrus.Info("Finished syncing with all known relays") return } - m.queryRelayServers(relayServersToQuery) + r.queryRelayServers(relayServersToQuery) t.Reset(relayServerRetryInterval) select { @@ -560,30 +743,32 @@ func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { } } -func (m *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { - m.queriedServersMutex.Lock() - defer m.queriedServersMutex.Unlock() +func (r *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() result := map[gomatrixserverlib.ServerName]bool{} - for server, queried := range m.relayServersQueried { + for server, queried := range r.relayServersQueried { result[server] = queried } return result } -func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { - logrus.Info("querying relay servers for any available transactions") +func (r *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("Querying relay servers for any available transactions") for _, server := range relayServers { - userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) + userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.ServerName), false) if err != nil { return } - err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) + + logrus.Infof("Syncing with relay: %s", string(server)) + err = r.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) if err == nil { func() { - m.queriedServersMutex.Lock() - defer m.queriedServersMutex.Unlock() - m.relayServersQueried[server] = true + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried[server] = true }() // TODO : What happens if your relay receives new messages after this point? // Should you continue to check with them, or should they try and contact you? diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go index edcf22bbe..3c8873e09 100644 --- a/build/gobind-pinecone/monolith_test.go +++ b/build/gobind-pinecone/monolith_test.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "net" + "strings" "testing" "time" @@ -196,3 +197,126 @@ func TestMonolithStarts(t *testing.T) { monolith.PublicKey() monolith.Stop() } + +func TestMonolithSetRelayServers(t *testing.T) { + testCases := []struct { + name string + nodeID string + relays string + expectedRelays string + expectSelf bool + }{ + { + name: "assorted valid, invalid, empty & self keys", + nodeID: "@valid:abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectSelf: true, + }, + { + name: "invalid node key", + nodeID: "@invalid:notakey", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "", + expectSelf: false, + }, + { + name: "node is self", + nodeID: "self", + relays: "@valid:123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd,@invalid:notakey,,", + expectedRelays: "123456123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectSelf: false, + }, + } + + for _, tc := range testCases { + monolith := DendriteMonolith{} + monolith.Start() + + inputRelays := tc.relays + expectedRelays := tc.expectedRelays + if tc.expectSelf { + inputRelays += "," + monolith.PublicKey() + expectedRelays += "," + monolith.PublicKey() + } + nodeID := tc.nodeID + if nodeID == "self" { + nodeID = monolith.PublicKey() + } + + monolith.SetRelayServers(nodeID, inputRelays) + relays := monolith.GetRelayServers(nodeID) + monolith.Stop() + + if !containSameKeys(strings.Split(relays, ","), strings.Split(expectedRelays, ",")) { + t.Fatalf("%s: expected %s got %s", tc.name, expectedRelays, relays) + } + } +} + +func containSameKeys(expected []string, actual []string) bool { + if len(expected) != len(actual) { + return false + } + + for _, expectedKey := range expected { + hasMatch := false + for _, actualKey := range actual { + if actualKey == expectedKey { + hasMatch = true + } + } + + if !hasMatch { + return false + } + } + + return true +} + +func TestParseServerKey(t *testing.T) { + testCases := []struct { + name string + serverKey string + expectedErr bool + expectedKey gomatrixserverlib.ServerName + }{ + { + name: "valid userid as key", + serverKey: "@valid:abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectedErr: false, + expectedKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + }, + { + name: "valid key", + serverKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + expectedErr: false, + expectedKey: "abcdef123456abcdef123456abcdef123456abcdef123456abcdef123456abcd", + }, + { + name: "invalid userid key", + serverKey: "@invalid:notakey", + expectedErr: true, + expectedKey: "", + }, + { + name: "invalid key", + serverKey: "@invalid:notakey", + expectedErr: true, + expectedKey: "", + }, + } + + for _, tc := range testCases { + key, err := getServerKeyFromString(tc.serverKey) + if tc.expectedErr && err == nil { + t.Fatalf("%s: expected an error", tc.name) + } else if !tc.expectedErr && err != nil { + t.Fatalf("%s: didn't expect an error: %s", tc.name, err.Error()) + } + if tc.expectedKey != key { + t.Fatalf("%s: keys not equal. expected: %s got: %s", tc.name, tc.expectedKey, key) + } + } +} diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 417b08521..e4c0b2714 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -72,12 +72,26 @@ type RoomserverFederationAPI interface { } type P2PFederationAPI interface { - // Relay Server sync api used in the pinecone demos. + // Get the relay servers associated for the given server. P2PQueryRelayServers( ctx context.Context, request *P2PQueryRelayServersRequest, response *P2PQueryRelayServersResponse, ) error + + // Add relay server associations to the given server. + P2PAddRelayServers( + ctx context.Context, + request *P2PAddRelayServersRequest, + response *P2PAddRelayServersResponse, + ) error + + // Remove relay server associations from the given server. + P2PRemoveRelayServers( + ctx context.Context, + request *P2PRemoveRelayServersRequest, + response *P2PRemoveRelayServersResponse, + ) error } // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver @@ -256,3 +270,19 @@ type P2PQueryRelayServersRequest struct { type P2PQueryRelayServersResponse struct { RelayServers []gomatrixserverlib.ServerName } + +type P2PAddRelayServersRequest struct { + Server gomatrixserverlib.ServerName + RelayServers []gomatrixserverlib.ServerName +} + +type P2PAddRelayServersResponse struct { +} + +type P2PRemoveRelayServersRequest struct { + Server gomatrixserverlib.ServerName + RelayServers []gomatrixserverlib.ServerName +} + +type P2PRemoveRelayServersResponse struct { +} diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 552942f28..b9684f767 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -840,6 +840,36 @@ func (r *FederationInternalAPI) P2PQueryRelayServers( return nil } +// P2PAddRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PAddRelayServers( + ctx context.Context, + request *api.P2PAddRelayServersRequest, + response *api.P2PAddRelayServersResponse, +) error { + logrus.Infof("Adding relay servers for: %s", request.Server) + err := r.db.P2PAddRelayServersForServer(ctx, request.Server, request.RelayServers) + if err != nil { + return err + } + + return nil +} + +// P2PRemoveRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PRemoveRelayServers( + ctx context.Context, + request *api.P2PRemoveRelayServersRequest, + response *api.P2PRemoveRelayServersResponse, +) error { + logrus.Infof("Adding relay servers for: %s", request.Server) + err := r.db.P2PRemoveRelayServersForServer(ctx, request.Server, request.RelayServers) + if err != nil { + return err + } + + return nil +} + func (r *FederationInternalAPI) shouldAttemptDirectFederation( destination gomatrixserverlib.ServerName, ) bool { diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go index e8e0d00a3..e6e366f99 100644 --- a/federationapi/internal/perform_test.go +++ b/federationapi/internal/perform_test.go @@ -123,6 +123,47 @@ func TestQueryRelayServers(t *testing.T) { assert.Equal(t, len(relayServers), len(res.RelayServers)) } +func TestRemoveRelayServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.P2PRemoveRelayServersRequest{ + Server: server, + RelayServers: []gomatrixserverlib.ServerName{"relay1"}, + } + res := api.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + finalRelays, err := testDB.P2PGetRelayServersForServer(context.Background(), server) + assert.NoError(t, err) + assert.Equal(t, 1, len(finalRelays)) + assert.Equal(t, gomatrixserverlib.ServerName("relay2"), finalRelays[0]) +} + func TestPerformDirectoryLookup(t *testing.T) { testDB := test.NewInMemoryFederationDatabase() diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 6130a567d..00e069d1e 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -25,6 +25,8 @@ const ( FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" FederationAPIQueryRelayServers = "/federationapi/queryRelayServers" + FederationAPIAddRelayServers = "/federationapi/addRelayServers" + FederationAPIRemoveRelayServers = "/federationapi/removeRelayServers" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" @@ -522,3 +524,25 @@ func (h *httpFederationInternalAPI) P2PQueryRelayServers( h.httpClient, ctx, request, response, ) } + +func (h *httpFederationInternalAPI) P2PAddRelayServers( + ctx context.Context, + request *api.P2PAddRelayServersRequest, + response *api.P2PAddRelayServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "AddRelayServers", h.federationAPIURL+FederationAPIAddRelayServers, + h.httpClient, ctx, request, response, + ) +} + +func (h *httpFederationInternalAPI) P2PRemoveRelayServers( + ctx context.Context, + request *api.P2PRemoveRelayServersRequest, + response *api.P2PRemoveRelayServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "RemoveRelayServers", h.federationAPIURL+FederationAPIRemoveRelayServers, + h.httpClient, ctx, request, response, + ) +} diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 51350916d..12e6db9fa 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -410,34 +410,49 @@ func (oq *destinationQueue) nextTransaction( defer cancel() relayServers := oq.statistics.KnownRelayServers() - if oq.statistics.AssumedOffline() && len(relayServers) > 0 { - sendMethod = statistics.SendViaRelay - relaySuccess := false - logrus.Infof("Sending to relay servers: %v", relayServers) - // TODO : how to pass through actual userID here?!?!?!?! - userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) - if userErr != nil { - return userErr, sendMethod - } - - // Attempt sending to each known relay server. - for _, relayServer := range relayServers { - _, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer) - if relayErr != nil { - err = relayErr - } else { - // If sending to one of the relay servers succeeds, consider the send successful. - relaySuccess = true - } - } - - // Clear the error if sending to any of the relay servers succeeded. - if relaySuccess { - err = nil - } - } else { + hasRelayServers := len(relayServers) > 0 + shouldSendToRelays := oq.statistics.AssumedOffline() && hasRelayServers + if !shouldSendToRelays { sendMethod = statistics.SendDirect _, err = oq.client.SendTransaction(ctx, t) + } else { + // Try sending directly to the destination first in case they came back online. + sendMethod = statistics.SendDirect + _, err = oq.client.SendTransaction(ctx, t) + if err != nil { + // The destination is still offline, try sending to relays. + sendMethod = statistics.SendViaRelay + relaySuccess := false + logrus.Infof("Sending %q to relay servers: %v", t.TransactionID, relayServers) + // TODO : how to pass through actual userID here?!?!?!?! + userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) + if userErr != nil { + return userErr, sendMethod + } + + // Attempt sending to each known relay server. + for _, relayServer := range relayServers { + _, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer) + if relayErr != nil { + err = relayErr + } else { + // If sending to one of the relay servers succeeds, consider the send successful. + relaySuccess = true + + // TODO : what about if the dest comes back online but can't see their relay? + // How do I sync with the dest in that case? + // Should change the database to have a "relay success" flag on events and if + // I see the node back online, maybe directly send through the backlog of events + // with "relay success"... could lead to duplicate events, but only those that + // I sent. And will lead to a much more consistent experience. + } + } + + // Clear the error if sending to any of the relay servers succeeded. + if relaySuccess { + err = nil + } + } } switch errResponse := err.(type) { case nil: diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 36e2ccbc2..bccfb3428 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -923,7 +923,7 @@ func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { assert.NoError(t, err) check := func(log poll.LogT) poll.Result { - if fc.txCount.Load() == 1 { + if fc.txCount.Load() >= 1 { if fc.txRelayCount.Load() == 1 { data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -962,7 +962,7 @@ func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { assert.NoError(t, err) check := func(log poll.LogT) poll.Result { - if fc.txCount.Load() == 1 { + if fc.txCount.Load() >= 1 { if fc.txRelayCount.Load() == 1 { data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 866c09336..e29e3b140 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -164,6 +164,8 @@ func (s *ServerStatistics) Success(method SendMethod) { logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) } } + + s.removeAssumedOffline() } } diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 6c198018d..cd7d90562 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -49,7 +49,7 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) { c.Database.Defaults(10) } c.FederationMaxRetries = 16 - c.P2PFederationRetriesUntilAssumedOffline = 2 + c.P2PFederationRetriesUntilAssumedOffline = 1 c.DisableTLSValidation = false c.DisableHTTPKeepalives = false if opts.Generate { diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go index cc9e1e8fd..de0dc54eb 100644 --- a/test/memory_federation_db.go +++ b/test/memory_federation_db.go @@ -399,6 +399,33 @@ func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( return nil } +func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if knownRelayServers, ok := d.relayServers[serverName]; ok { + for _, relayServer := range relayServers { + for i, knownRelayServer := range knownRelayServers { + if relayServer == knownRelayServer { + d.relayServers[serverName] = append( + d.relayServers[serverName][:i], + d.relayServers[serverName][i+1:]..., + ) + break + } + } + } + } else { + d.relayServers[serverName] = relayServers + } + + return nil +} + func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { return nil, nil } @@ -431,10 +458,6 @@ func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context. return nil } -func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { - return nil -} - func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { return nil } From 0f998e3af3be06a1f0626de8cb74413c5da310f4 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Sun, 29 Jan 2023 12:26:16 -0700 Subject: [PATCH 106/124] Add pinecone demo toggle for dis/enabling relaying for other nodes --- build/gobind-pinecone/monolith.go | 13 ++++++-- cmd/dendrite-demo-pinecone/main.go | 2 +- relayapi/api/api.go | 6 ++++ relayapi/internal/api.go | 6 ++++ relayapi/internal/perform.go | 14 ++++++++ relayapi/internal/perform_test.go | 6 ++-- relayapi/relayapi.go | 2 ++ relayapi/relayapi_test.go | 52 ++++++++++++++++++++++++++---- relayapi/routing/relaytxn.go | 4 +-- relayapi/routing/relaytxn_test.go | 8 ++--- relayapi/routing/routing.go | 17 ++++++++++ relayapi/routing/sendrelay.go | 4 ++- relayapi/routing/sendrelay_test.go | 10 +++--- 13 files changed, 119 insertions(+), 25 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 5e8e5875c..a219621b1 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -90,6 +90,7 @@ type DendriteMonolith struct { httpServer *http.Server userAPI userapiAPI.UserInternalAPI federationAPI api.FederationInternalAPI + relayAPI relayServerAPI.RelayInternalAPI relayRetriever RelayServerRetriever } @@ -313,6 +314,14 @@ func (m *DendriteMonolith) GetRelayServers(nodeID string) string { return relaysString } +func (m *DendriteMonolith) RelayingEnabled() bool { + return m.relayAPI.RelayingEnabled() +} + +func (m *DendriteMonolith) SetRelayingEnabled(enabled bool) { + m.relayAPI.SetRelayingEnabled(enabled) +} + func (m *DendriteMonolith) DisconnectType(peertype int) { for _, p := range m.PineconeRouter.Peers() { if int(peertype) == p.PeerType { @@ -528,7 +537,7 @@ func (m *DendriteMonolith) Start() { Config: &base.Cfg.FederationAPI, UserAPI: m.userAPI, } - relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer) + m.relayAPI = relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer, false) monolith := setup.Monolith{ Config: base.Cfg, @@ -541,7 +550,7 @@ func (m *DendriteMonolith) Start() { RoomserverAPI: rsAPI, UserAPI: m.userAPI, KeyAPI: keyAPI, - RelayAPI: relayAPI, + RelayAPI: m.relayAPI, ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index a813c37a2..83655b69c 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -244,7 +244,7 @@ func main() { Config: &base.Cfg.FederationAPI, UserAPI: userAPI, } - relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer) + relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer, true) monolith := setup.Monolith{ Config: base.Cfg, diff --git a/relayapi/api/api.go b/relayapi/api/api.go index 9db393225..9b4b62e58 100644 --- a/relayapi/api/api.go +++ b/relayapi/api/api.go @@ -30,6 +30,12 @@ type RelayInternalAPI interface { userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName, ) error + + // Tells the relayapi whether or not it should act as a relay server for external servers. + SetRelayingEnabled(bool) + + // Obtain whether the relayapi is currently configured to act as a relay server for external servers. + RelayingEnabled() bool } // RelayServerAPI exposes the store & query transaction functionality of a relay server. diff --git a/relayapi/internal/api.go b/relayapi/internal/api.go index 3ff8c2add..3a5762fb8 100644 --- a/relayapi/internal/api.go +++ b/relayapi/internal/api.go @@ -15,6 +15,8 @@ package internal import ( + "sync" + fedAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/relayapi/storage" @@ -30,6 +32,8 @@ type RelayInternalAPI struct { producer *producers.SyncAPIProducer presenceEnabledInbound bool serverName gomatrixserverlib.ServerName + relayingEnabledMutex sync.Mutex + relayingEnabled bool } func NewRelayInternalAPI( @@ -40,6 +44,7 @@ func NewRelayInternalAPI( producer *producers.SyncAPIProducer, presenceEnabledInbound bool, serverName gomatrixserverlib.ServerName, + relayingEnabled bool, ) *RelayInternalAPI { return &RelayInternalAPI{ db: db, @@ -49,5 +54,6 @@ func NewRelayInternalAPI( producer: producer, presenceEnabledInbound: presenceEnabledInbound, serverName: serverName, + relayingEnabled: relayingEnabled, } } diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go index 594299334..c00d0d0ed 100644 --- a/relayapi/internal/perform.go +++ b/relayapi/internal/perform.go @@ -24,6 +24,20 @@ import ( "github.com/sirupsen/logrus" ) +// SetRelayingEnabled implements api.RelayInternalAPI +func (r *RelayInternalAPI) SetRelayingEnabled(enabled bool) { + r.relayingEnabledMutex.Lock() + defer r.relayingEnabledMutex.Unlock() + r.relayingEnabled = enabled +} + +// RelayingEnabled implements api.RelayInternalAPI +func (r *RelayInternalAPI) RelayingEnabled() bool { + r.relayingEnabledMutex.Lock() + defer r.relayingEnabledMutex.Unlock() + return r.relayingEnabled +} + // PerformRelayServerSync implements api.RelayInternalAPI func (r *RelayInternalAPI) PerformRelayServerSync( ctx context.Context, diff --git a/relayapi/internal/perform_test.go b/relayapi/internal/perform_test.go index fb71b7d0e..5673b1994 100644 --- a/relayapi/internal/perform_test.go +++ b/relayapi/internal/perform_test.go @@ -72,7 +72,7 @@ func TestPerformRelayServerSync(t *testing.T) { fedClient := &testFedClient{} relayAPI := NewRelayInternalAPI( - &db, fedClient, nil, nil, nil, false, "", + &db, fedClient, nil, nil, nil, false, "", true, ) err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) @@ -92,7 +92,7 @@ func TestPerformRelayServerSyncFedError(t *testing.T) { fedClient := &testFedClient{shouldFail: true} relayAPI := NewRelayInternalAPI( - &db, fedClient, nil, nil, nil, false, "", + &db, fedClient, nil, nil, nil, false, "", true, ) err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) @@ -112,7 +112,7 @@ func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) { fedClient := &testFedClient{queueDepth: 2} relayAPI := NewRelayInternalAPI( - &db, fedClient, nil, nil, nil, false, "", + &db, fedClient, nil, nil, nil, false, "", true, ) err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) diff --git a/relayapi/relayapi.go b/relayapi/relayapi.go index f9f9d4ff9..200a1814a 100644 --- a/relayapi/relayapi.go +++ b/relayapi/relayapi.go @@ -54,6 +54,7 @@ func NewRelayInternalAPI( rsAPI rsAPI.RoomserverInternalAPI, keyRing *gomatrixserverlib.KeyRing, producer *producers.SyncAPIProducer, + relayingEnabled bool, ) api.RelayInternalAPI { cfg := &base.Cfg.RelayAPI @@ -70,5 +71,6 @@ func NewRelayInternalAPI( producer, base.Cfg.Global.Presence.EnableInbound, base.Cfg.Global.ServerName, + relayingEnabled, ) } diff --git a/relayapi/relayapi_test.go b/relayapi/relayapi_test.go index dfa06811d..f1b3262aa 100644 --- a/relayapi/relayapi_test.go +++ b/relayapi/relayapi_test.go @@ -36,7 +36,7 @@ func TestCreateNewRelayInternalAPI(t *testing.T) { base, close := testrig.CreateBaseDendrite(t, dbType) defer close() - relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true) assert.NotNil(t, relayAPI) }) } @@ -52,7 +52,7 @@ func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) { defer close() assert.Panics(t, func() { - relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true) }) }) } @@ -108,7 +108,7 @@ func TestCreateRelayPublicRoutes(t *testing.T) { base, close := testrig.CreateBaseDendrite(t, dbType) defer close() - relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, true) assert.NotNil(t, relayAPI) serverKeyAPI := &signing.YggdrasilKeys{} @@ -116,10 +116,9 @@ func TestCreateRelayPublicRoutes(t *testing.T) { relayapi.AddPublicRoutes(base, keyRing, relayAPI) testCases := []struct { - name string - req *http.Request - wantCode int - wantJoinedRooms []string + name string + req *http.Request + wantCode int }{ { name: "relay_txn invalid user id", @@ -152,3 +151,42 @@ func TestCreateRelayPublicRoutes(t *testing.T) { } }) } + +func TestDisableRelayPublicRoutes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil, false) + assert.NotNil(t, relayAPI) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + relayapi.AddPublicRoutes(base, keyRing, relayAPI) + + testCases := []struct { + name string + req *http.Request + wantCode int + }{ + { + name: "relay_txn valid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), + wantCode: 404, + }, + { + name: "send_relay valid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), + wantCode: 404, + }, + } + + for _, tc := range testCases { + w := httptest.NewRecorder() + base.PublicFederationAPIMux.ServeHTTP(w, tc.req) + if w.Code != tc.wantCode { + t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) + } + } + }) +} diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go index 1b11b0ecd..76241a33b 100644 --- a/relayapi/routing/relaytxn.go +++ b/relayapi/routing/relaytxn.go @@ -31,7 +31,7 @@ type RelayTransactionResponse struct { EntriesQueued bool `json:"entries_queued"` } -// GetTransactionFromRelay implements /_matrix/federation/v1/relay_txn/{userID} +// GetTransactionFromRelay implements GET /_matrix/federation/v1/relay_txn/{userID} // This endpoint can be extracted into a separate relay server service. func GetTransactionFromRelay( httpReq *http.Request, @@ -39,7 +39,7 @@ func GetTransactionFromRelay( relayAPI api.RelayInternalAPI, userID gomatrixserverlib.UserID, ) util.JSONResponse { - logrus.Infof("Handling relay_txn for %s", userID.Raw()) + logrus.Infof("Processing relay_txn for %s", userID.Raw()) previousEntry := gomatrixserverlib.RelayEntry{} if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go index a47fdb198..3cb4c8adc 100644 --- a/relayapi/routing/relaytxn_test.go +++ b/relayapi/routing/relaytxn_test.go @@ -57,7 +57,7 @@ func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { assert.NoError(t, err, "Failed to store transaction") relayAPI := internal.NewRelayInternalAPI( - &db, nil, nil, nil, nil, false, "", + &db, nil, nil, nil, nil, false, "", true, ) request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) @@ -90,7 +90,7 @@ func TestGetInvalidPrevEntryFails(t *testing.T) { assert.NoError(t, err, "Failed to store transaction") relayAPI := internal.NewRelayInternalAPI( - &db, nil, nil, nil, nil, false, "", + &db, nil, nil, nil, nil, false, "", true, ) request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) @@ -123,7 +123,7 @@ func TestGetReturnsSavedTransaction(t *testing.T) { assert.NoError(t, err, "Failed to associate transaction with user") relayAPI := internal.NewRelayInternalAPI( - &db, nil, nil, nil, nil, false, "", + &db, nil, nil, nil, nil, false, "", true, ) request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) @@ -186,7 +186,7 @@ func TestGetReturnsMultipleSavedTransactions(t *testing.T) { assert.NoError(t, err, "Failed to associate transaction with user") relayAPI := internal.NewRelayInternalAPI( - &db, nil, nil, nil, nil, false, "", + &db, nil, nil, nil, nil, false, "", true, ) request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go index 6df0cdc5f..c908e950a 100644 --- a/relayapi/routing/routing.go +++ b/relayapi/routing/routing.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) // Setup registers HTTP handlers with the given ServeMux. @@ -48,6 +49,14 @@ func Setup( v1fedmux.Handle("/send_relay/{txnID}/{userID}", MakeRelayAPI( "send_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + logrus.Infof("Handling send_relay from: %s", request.Origin()) + if !relayAPI.RelayingEnabled() { + logrus.Warnf("Dropping send_relay from: %s", request.Origin()) + return util.JSONResponse{ + Code: http.StatusNotFound, + } + } + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) if err != nil { return util.JSONResponse{ @@ -65,6 +74,14 @@ func Setup( v1fedmux.Handle("/relay_txn/{userID}", MakeRelayAPI( "get_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + logrus.Infof("Handling relay_txn from: %s", request.Origin()) + if !relayAPI.RelayingEnabled() { + logrus.Warnf("Dropping relay_txn from: %s", request.Origin()) + return util.JSONResponse{ + Code: http.StatusNotFound, + } + } + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) if err != nil { return util.JSONResponse{ diff --git a/relayapi/routing/sendrelay.go b/relayapi/routing/sendrelay.go index a7027f293..85a722c8a 100644 --- a/relayapi/routing/sendrelay.go +++ b/relayapi/routing/sendrelay.go @@ -25,7 +25,7 @@ import ( "github.com/sirupsen/logrus" ) -// SendTransactionToRelay implements PUT /_matrix/federation/v1/relay_txn/{txnID}/{userID} +// SendTransactionToRelay implements PUT /_matrix/federation/v1/send_relay/{txnID}/{userID} // This endpoint can be extracted into a separate relay server service. func SendTransactionToRelay( httpReq *http.Request, @@ -34,6 +34,8 @@ func SendTransactionToRelay( txnID gomatrixserverlib.TransactionID, userID gomatrixserverlib.UserID, ) util.JSONResponse { + logrus.Infof("Processing send_relay for %s", userID.Raw()) + var txnEvents struct { PDUs []json.RawMessage `json:"pdus"` EDUs []gomatrixserverlib.EDU `json:"edus"` diff --git a/relayapi/routing/sendrelay_test.go b/relayapi/routing/sendrelay_test.go index d9ed75002..66594c47c 100644 --- a/relayapi/routing/sendrelay_test.go +++ b/relayapi/routing/sendrelay_test.go @@ -72,7 +72,7 @@ func TestForwardEmptyReturnsOk(t *testing.T) { request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) relayAPI := internal.NewRelayInternalAPI( - &db, nil, nil, nil, nil, false, "", + &db, nil, nil, nil, nil, false, "", true, ) response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) @@ -101,7 +101,7 @@ func TestForwardBadJSONReturnsError(t *testing.T) { request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) relayAPI := internal.NewRelayInternalAPI( - &db, nil, nil, nil, nil, false, "", + &db, nil, nil, nil, nil, false, "", true, ) response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) @@ -135,7 +135,7 @@ func TestForwardTooManyPDUsReturnsError(t *testing.T) { request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) relayAPI := internal.NewRelayInternalAPI( - &db, nil, nil, nil, nil, false, "", + &db, nil, nil, nil, nil, false, "", true, ) response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) @@ -169,7 +169,7 @@ func TestForwardTooManyEDUsReturnsError(t *testing.T) { request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) relayAPI := internal.NewRelayInternalAPI( - &db, nil, nil, nil, nil, false, "", + &db, nil, nil, nil, nil, false, "", true, ) response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) @@ -192,7 +192,7 @@ func TestUniqueTransactionStoredInDatabase(t *testing.T) { request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) relayAPI := internal.NewRelayInternalAPI( - &db, nil, nil, nil, nil, false, "", + &db, nil, nil, nil, nil, false, "", true, ) response := routing.SendTransactionToRelay( From f98003c030432074e0c411cd70e09bacae820940 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Sun, 29 Jan 2023 18:13:39 -0700 Subject: [PATCH 107/124] Add cmd line option to pinecone demo for enabling relaying --- cmd/dendrite-demo-pinecone/main.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 83655b69c..3f63ab654 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -64,11 +64,12 @@ import ( ) var ( - instanceName = flag.String("name", "dendrite-p2p-pinecone", "the name of this P2P demo instance") - instancePort = flag.Int("port", 8008, "the port that the client API will listen on") - instancePeer = flag.String("peer", "", "the static Pinecone peers to connect to, comma separated-list") - instanceListen = flag.String("listen", ":0", "the port Pinecone peers can connect to") - instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)") + instanceName = flag.String("name", "dendrite-p2p-pinecone", "the name of this P2P demo instance") + instancePort = flag.Int("port", 8008, "the port that the client API will listen on") + instancePeer = flag.String("peer", "", "the static Pinecone peers to connect to, comma separated-list") + instanceListen = flag.String("listen", ":0", "the port Pinecone peers can connect to") + instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)") + instanceRelayingEnabled = flag.Bool("relay", false, "whether to enable store & forward relaying for other nodes") ) const relayServerRetryInterval = time.Second * 30 @@ -244,7 +245,8 @@ func main() { Config: &base.Cfg.FederationAPI, UserAPI: userAPI, } - relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer, true) + relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer, *instanceRelayingEnabled) + logrus.Infof("Relaying enabled: %v", relayAPI.RelayingEnabled()) monolith := setup.Monolith{ Config: base.Cfg, From 7b3334778f44e98e50bf6b0deb5a04a7a89a03de Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 31 Jan 2023 12:31:57 -0700 Subject: [PATCH 108/124] Use gmsl relay_txn response type --- relayapi/internal/perform.go | 4 ++-- relayapi/internal/perform_test.go | 4 ++-- relayapi/routing/relaytxn.go | 10 ++-------- relayapi/routing/relaytxn_test.go | 12 ++++++------ 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go index c00d0d0ed..62c7d446e 100644 --- a/relayapi/internal/perform.go +++ b/relayapi/internal/perform.go @@ -52,7 +52,7 @@ func (r *RelayInternalAPI) PerformRelayServerSync( logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) return err } - r.processTransaction(&asyncResponse.Txn) + r.processTransaction(&asyncResponse.Transaction) prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} for asyncResponse.EntriesQueued { @@ -64,7 +64,7 @@ func (r *RelayInternalAPI) PerformRelayServerSync( logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) return err } - r.processTransaction(&asyncResponse.Txn) + r.processTransaction(&asyncResponse.Transaction) } return nil diff --git a/relayapi/internal/perform_test.go b/relayapi/internal/perform_test.go index 5673b1994..278706a3e 100644 --- a/relayapi/internal/perform_test.go +++ b/relayapi/internal/perform_test.go @@ -46,8 +46,8 @@ func (f *testFedClient) P2PGetTransactionFromRelay( } res = gomatrixserverlib.RespGetRelayTransaction{ - Txn: gomatrixserverlib.Transaction{}, - EntryID: 0, + Transaction: gomatrixserverlib.Transaction{}, + EntryID: 0, } if f.queueDepth > 0 { res.EntriesQueued = true diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go index 76241a33b..63b42ec7d 100644 --- a/relayapi/routing/relaytxn.go +++ b/relayapi/routing/relaytxn.go @@ -25,12 +25,6 @@ import ( "github.com/sirupsen/logrus" ) -type RelayTransactionResponse struct { - Transaction gomatrixserverlib.Transaction `json:"transaction"` - EntryID int64 `json:"entry_id,omitempty"` - EntriesQueued bool `json:"entries_queued"` -} - // GetTransactionFromRelay implements GET /_matrix/federation/v1/relay_txn/{userID} // This endpoint can be extracted into a separate relay server service. func GetTransactionFromRelay( @@ -41,7 +35,7 @@ func GetTransactionFromRelay( ) util.JSONResponse { logrus.Infof("Processing relay_txn for %s", userID.Raw()) - previousEntry := gomatrixserverlib.RelayEntry{} + var previousEntry gomatrixserverlib.RelayEntry if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { return util.JSONResponse{ Code: http.StatusInternalServerError, @@ -65,7 +59,7 @@ func GetTransactionFromRelay( return util.JSONResponse{ Code: http.StatusOK, - JSON: RelayTransactionResponse{ + JSON: gomatrixserverlib.RespGetRelayTransaction{ Transaction: response.Transaction, EntryID: response.EntryID, EntriesQueued: response.EntriesQueued, diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go index 3cb4c8adc..4c099a642 100644 --- a/relayapi/routing/relaytxn_test.go +++ b/relayapi/routing/relaytxn_test.go @@ -64,7 +64,7 @@ func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse := response.JSON.(routing.RelayTransactionResponse) + jsonResponse := response.JSON.(gomatrixserverlib.RespGetRelayTransaction) assert.Equal(t, false, jsonResponse.EntriesQueued) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) @@ -130,7 +130,7 @@ func TestGetReturnsSavedTransaction(t *testing.T) { response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse := response.JSON.(routing.RelayTransactionResponse) + jsonResponse := response.JSON.(gomatrixserverlib.RespGetRelayTransaction) assert.True(t, jsonResponse.EntriesQueued) assert.Equal(t, transaction, jsonResponse.Transaction) @@ -139,7 +139,7 @@ func TestGetReturnsSavedTransaction(t *testing.T) { response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse = response.JSON.(routing.RelayTransactionResponse) + jsonResponse = response.JSON.(gomatrixserverlib.RespGetRelayTransaction) assert.False(t, jsonResponse.EntriesQueued) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) @@ -193,7 +193,7 @@ func TestGetReturnsMultipleSavedTransactions(t *testing.T) { response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse := response.JSON.(routing.RelayTransactionResponse) + jsonResponse := response.JSON.(gomatrixserverlib.RespGetRelayTransaction) assert.True(t, jsonResponse.EntriesQueued) assert.Equal(t, transaction, jsonResponse.Transaction) @@ -201,7 +201,7 @@ func TestGetReturnsMultipleSavedTransactions(t *testing.T) { response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse = response.JSON.(routing.RelayTransactionResponse) + jsonResponse = response.JSON.(gomatrixserverlib.RespGetRelayTransaction) assert.True(t, jsonResponse.EntriesQueued) assert.Equal(t, transaction2, jsonResponse.Transaction) @@ -210,7 +210,7 @@ func TestGetReturnsMultipleSavedTransactions(t *testing.T) { response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) assert.Equal(t, http.StatusOK, response.Code) - jsonResponse = response.JSON.(routing.RelayTransactionResponse) + jsonResponse = response.JSON.(gomatrixserverlib.RespGetRelayTransaction) assert.False(t, jsonResponse.EntriesQueued) assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) From b935da6c33a94375329bf17a0abb13f74a24c8b2 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 31 Jan 2023 12:32:36 -0700 Subject: [PATCH 109/124] Use new gmsl RelayEvents type for send_relay request body --- relayapi/routing/sendrelay.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/relayapi/routing/sendrelay.go b/relayapi/routing/sendrelay.go index 85a722c8a..ce744cb49 100644 --- a/relayapi/routing/sendrelay.go +++ b/relayapi/routing/sendrelay.go @@ -36,11 +36,7 @@ func SendTransactionToRelay( ) util.JSONResponse { logrus.Infof("Processing send_relay for %s", userID.Raw()) - var txnEvents struct { - PDUs []json.RawMessage `json:"pdus"` - EDUs []gomatrixserverlib.EDU `json:"edus"` - } - + var txnEvents gomatrixserverlib.RelayEvents if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { logrus.Info("The request body could not be decoded into valid JSON." + err.Error()) return util.JSONResponse{ From 4af88ff0e6fe364eb8f98c64aebd26b059b17196 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 31 Jan 2023 12:49:47 -0700 Subject: [PATCH 110/124] Update gmsl dependency --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 8d5eafcef..25f022707 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f + github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1 github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index d57960561..38122e63c 100644 --- a/go.sum +++ b/go.sum @@ -327,8 +327,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f h1:niRWEVkeeekpjxwnMhKn8PD0PUloDsNXP8W+Ez/co/M= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1 h1:JSw0nmjMrgBmoM2aQsa78LTpI5BnuD9+vOiEQ4Qo0qw= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230131183213-122f1e0e3fa1/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb h1:2L+ltfNKab56FoBBqAvbBLjoAbxwwoZie+B8d+Mp3JI= github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= From 4738fe656fc0975b5faa4fcdf6c11a776b9563a0 Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 1 Feb 2023 16:32:31 +0000 Subject: [PATCH 111/124] Roomserver published pkey migration (#2960) Adds a missed migration to update the primary key on the roomserver_published table in postgres. Primary key was changed in #2836. --- ...0230131091021_published_appservice_pkey.go | 32 +++++++++++++++++++ .../storage/postgres/published_table.go | 14 +++++--- 2 files changed, 42 insertions(+), 4 deletions(-) create mode 100644 roomserver/storage/postgres/deltas/20230131091021_published_appservice_pkey.go diff --git a/roomserver/storage/postgres/deltas/20230131091021_published_appservice_pkey.go b/roomserver/storage/postgres/deltas/20230131091021_published_appservice_pkey.go new file mode 100644 index 000000000..add66446b --- /dev/null +++ b/roomserver/storage/postgres/deltas/20230131091021_published_appservice_pkey.go @@ -0,0 +1,32 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpPulishedAppservicePrimaryKey(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE roomserver_published RENAME CONSTRAINT roomserver_published_pkey TO roomserver_published_pkeyold; +CREATE UNIQUE INDEX roomserver_published_pkey ON roomserver_published (room_id, appservice_id, network_id); +ALTER TABLE roomserver_published DROP CONSTRAINT roomserver_published_pkeyold; +ALTER TABLE roomserver_published ADD PRIMARY KEY USING INDEX roomserver_published_pkey;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 61caccb0e..eca81d81f 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -65,10 +65,16 @@ func CreatePublishedTable(db *sql.DB) error { return err } m := sqlutil.NewMigrator(db) - m.AddMigrations(sqlutil.Migration{ - Version: "roomserver: published appservice", - Up: deltas.UpPulishedAppservice, - }) + m.AddMigrations([]sqlutil.Migration{ + { + Version: "roomserver: published appservice", + Up: deltas.UpPulishedAppservice, + }, + { + Version: "roomserver: published appservice pkey", + Up: deltas.UpPulishedAppservicePrimaryKey, + }, + }...) return m.Up(context.Background()) } From be43b9c0eae5fd45e38f954c23cfadbdfa77cb32 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 31 Jan 2023 15:51:08 -0700 Subject: [PATCH 112/124] Refactor common relay sync struct to remove duplication --- build/gobind-pinecone/monolith.go | 218 ++-------------- build/gobind-pinecone/monolith_test.go | 69 ----- cmd/dendrite-demo-pinecone/main.go | 112 ++------- cmd/dendrite-demo-pinecone/relay/retriever.go | 237 ++++++++++++++++++ .../relay/retriever_test.go | 99 ++++++++ 5 files changed, 366 insertions(+), 369 deletions(-) create mode 100644 cmd/dendrite-demo-pinecone/relay/retriever.go create mode 100644 cmd/dendrite-demo-pinecone/relay/retriever_test.go diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index a219621b1..f44eed89b 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -36,6 +36,7 @@ import ( "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" @@ -70,11 +71,10 @@ import ( ) const ( - PeerTypeRemote = pineconeRouter.PeerTypeRemote - PeerTypeMulticast = pineconeRouter.PeerTypeMulticast - PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth - PeerTypeBonjour = pineconeRouter.PeerTypeBonjour - relayServerRetryInterval = time.Second * 30 + PeerTypeRemote = pineconeRouter.PeerTypeRemote + PeerTypeMulticast = pineconeRouter.PeerTypeMulticast + PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth + PeerTypeBonjour = pineconeRouter.PeerTypeBonjour ) type DendriteMonolith struct { @@ -91,7 +91,7 @@ type DendriteMonolith struct { userAPI userapiAPI.UserInternalAPI federationAPI api.FederationInternalAPI relayAPI relayServerAPI.RelayInternalAPI - relayRetriever RelayServerRetriever + relayRetriever relay.RelayServerRetriever } func (m *DendriteMonolith) PublicKey() string { @@ -189,57 +189,6 @@ func getServerKeyFromString(nodeID string) (gomatrixserverlib.ServerName, error) return nodeKey, nil } -func updateNodeRelayServers( - node gomatrixserverlib.ServerName, - relays []gomatrixserverlib.ServerName, - ctx context.Context, - fedAPI api.FederationInternalAPI, -) { - // Get the current relay list - request := api.P2PQueryRelayServersRequest{Server: node} - response := api.P2PQueryRelayServersResponse{} - err := fedAPI.P2PQueryRelayServers(ctx, &request, &response) - if err != nil { - logrus.Warnf("Failed obtaining list of relay servers for %s: %s", node, err.Error()) - } - - // Remove old, non-matching relays - var serversToRemove []gomatrixserverlib.ServerName - for _, existingServer := range response.RelayServers { - shouldRemove := true - for _, newServer := range relays { - if newServer == existingServer { - shouldRemove = false - break - } - } - - if shouldRemove { - serversToRemove = append(serversToRemove, existingServer) - } - } - removeRequest := api.P2PRemoveRelayServersRequest{ - Server: node, - RelayServers: serversToRemove, - } - removeResponse := api.P2PRemoveRelayServersResponse{} - err = fedAPI.P2PRemoveRelayServers(ctx, &removeRequest, &removeResponse) - if err != nil { - logrus.Warnf("Failed removing old relay servers for %s: %s", node, err.Error()) - } - - // Add new relays - addRequest := api.P2PAddRelayServersRequest{ - Server: node, - RelayServers: relays, - } - addResponse := api.P2PAddRelayServersResponse{} - err = fedAPI.P2PAddRelayServers(ctx, &addRequest, &addResponse) - if err != nil { - logrus.Warnf("Failed adding relay servers for %s: %s", node, err.Error()) - } -} - func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { relays := []gomatrixserverlib.ServerName{} for _, uri := range strings.Split(uris, ",") { @@ -266,7 +215,7 @@ func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { logrus.Infof("Setting own relay servers to: %v", relays) m.relayRetriever.SetRelayServers(relays) } else { - updateNodeRelayServers( + relay.UpdateNodeRelayServers( gomatrixserverlib.ServerName(nodeKey), relays, m.baseDendrite.Context(), @@ -610,27 +559,23 @@ func (m *DendriteMonolith) Start() { }() stopRelayServerSync := make(chan bool) - eLog := logrus.WithField("pinecone", "events") - m.relayRetriever = RelayServerRetriever{ - Context: context.Background(), - ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), - FederationAPI: m.federationAPI, - relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - RelayAPI: monolith.RelayAPI, - running: *atomic.NewBool(false), - quit: stopRelayServerSync, - } + m.relayRetriever = relay.NewRelayServerRetriever( + context.Background(), + gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), + m.federationAPI, + monolith.RelayAPI, + stopRelayServerSync, + ) m.relayRetriever.InitializeRelayServers(eLog) go func(ch <-chan pineconeEvents.Event) { - for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: m.relayRetriever.StartSync() case pineconeEvents.PeerRemoved: - if m.relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { + if m.relayRetriever.IsRunning() && m.PineconeRouter.TotalPeerCount() == 0 { stopRelayServerSync <- true } case pineconeEvents.BroadcastReceived: @@ -658,139 +603,6 @@ func (m *DendriteMonolith) Stop() { _ = m.PineconeRouter.Close() } -type RelayServerRetriever struct { - Context context.Context - ServerName gomatrixserverlib.ServerName - FederationAPI api.FederationInternalAPI - RelayAPI relayServerAPI.RelayInternalAPI - relayServersQueried map[gomatrixserverlib.ServerName]bool - queriedServersMutex sync.Mutex - running atomic.Bool - quit <-chan bool -} - -func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { - request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(r.ServerName)} - response := api.P2PQueryRelayServersResponse{} - err := r.FederationAPI.P2PQueryRelayServers(r.Context, &request, &response) - if err != nil { - eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) - } - - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - for _, server := range response.RelayServers { - r.relayServersQueried[server] = false - } - - eLog.Infof("Registered relay servers: %v", response.RelayServers) -} - -func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.ServerName) { - updateNodeRelayServers(r.ServerName, servers, r.Context, r.FederationAPI) - - // Replace list of servers to sync with and mark them all as unsynced. - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - r.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool) - for _, server := range servers { - r.relayServersQueried[server] = false - } - - r.StartSync() -} - -func (r *RelayServerRetriever) GetRelayServers() []gomatrixserverlib.ServerName { - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - relayServers := []gomatrixserverlib.ServerName{} - for server := range r.relayServersQueried { - relayServers = append(relayServers, server) - } - - return relayServers -} - -func (r *RelayServerRetriever) StartSync() { - if !r.running.Load() { - logrus.Info("Starting relay server sync") - go r.SyncRelayServers(r.quit) - } -} - -func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { - defer r.running.Store(false) - - t := time.NewTimer(relayServerRetryInterval) - for { - relayServersToQuery := []gomatrixserverlib.ServerName{} - func() { - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - for server, complete := range r.relayServersQueried { - if !complete { - relayServersToQuery = append(relayServersToQuery, server) - } - } - }() - if len(relayServersToQuery) == 0 { - // All relay servers have been synced. - logrus.Info("Finished syncing with all known relays") - return - } - r.queryRelayServers(relayServersToQuery) - t.Reset(relayServerRetryInterval) - - select { - case <-stop: - if !t.Stop() { - <-t.C - } - return - case <-t.C: - } - } -} - -func (r *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - - result := map[gomatrixserverlib.ServerName]bool{} - for server, queried := range r.relayServersQueried { - result[server] = queried - } - return result -} - -func (r *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { - logrus.Info("Querying relay servers for any available transactions") - for _, server := range relayServers { - userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.ServerName), false) - if err != nil { - return - } - - logrus.Infof("Syncing with relay: %s", string(server)) - err = r.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) - if err == nil { - func() { - r.queriedServersMutex.Lock() - defer r.queriedServersMutex.Unlock() - r.relayServersQueried[server] = true - }() - // TODO : What happens if your relay receives new messages after this point? - // Should you continue to check with them, or should they try and contact you? - // They could send a "new_async_events" message your way maybe? - // Then you could mark them as needing to be queried again. - // What if you miss this message? - // Maybe you should try querying them again after a certain period of time as a backup? - } else { - logrus.Errorf("Failed querying relay server: %s", err.Error()) - } - } -} - const MaxFrameSize = types.MaxFrameSize type Conduit struct { diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go index 3c8873e09..f6bf2ef03 100644 --- a/build/gobind-pinecone/monolith_test.go +++ b/build/gobind-pinecone/monolith_test.go @@ -15,19 +15,13 @@ package gobind import ( - "context" "fmt" "net" "strings" "testing" - "time" - "github.com/matrix-org/dendrite/federationapi/api" - relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/gomatrixserverlib" - "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "gotest.tools/v3/poll" ) var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} @@ -128,69 +122,6 @@ func TestConduitReadCopyFails(t *testing.T) { assert.Error(t, err) } -var testRelayServers = []gomatrixserverlib.ServerName{"relay1", "relay2"} - -type FakeFedAPI struct { - api.FederationInternalAPI -} - -func (f *FakeFedAPI) P2PQueryRelayServers(ctx context.Context, req *api.P2PQueryRelayServersRequest, res *api.P2PQueryRelayServersResponse) error { - res.RelayServers = testRelayServers - return nil -} - -type FakeRelayAPI struct { - relayServerAPI.RelayInternalAPI -} - -func (r *FakeRelayAPI) PerformRelayServerSync(ctx context.Context, userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error { - return nil -} - -func TestRelayRetrieverInitialization(t *testing.T) { - retriever := RelayServerRetriever{ - Context: context.Background(), - ServerName: "server", - relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - FederationAPI: &FakeFedAPI{}, - RelayAPI: &FakeRelayAPI{}, - } - - retriever.InitializeRelayServers(logrus.WithField("test", "relay")) - relayServers := retriever.GetQueriedServerStatus() - assert.Equal(t, 2, len(relayServers)) -} - -func TestRelayRetrieverSync(t *testing.T) { - retriever := RelayServerRetriever{ - Context: context.Background(), - ServerName: "server", - relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - FederationAPI: &FakeFedAPI{}, - RelayAPI: &FakeRelayAPI{}, - } - - retriever.InitializeRelayServers(logrus.WithField("test", "relay")) - relayServers := retriever.GetQueriedServerStatus() - assert.Equal(t, 2, len(relayServers)) - - stopRelayServerSync := make(chan bool) - go retriever.SyncRelayServers(stopRelayServerSync) - - check := func(log poll.LogT) poll.Result { - relayServers := retriever.GetQueriedServerStatus() - for _, queried := range relayServers { - if !queried { - return poll.Continue("waiting for all servers to be queried") - } - } - - stopRelayServerSync <- true - return poll.Success() - } - poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) -} - func TestMonolithStarts(t *testing.T) { monolith := DendriteMonolith{} monolith.Start() diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 3f63ab654..abffbb67a 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -33,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" @@ -43,7 +44,6 @@ import ( "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/relayapi" - relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" @@ -52,7 +52,6 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" - "go.uber.org/atomic" pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" @@ -72,8 +71,6 @@ var ( instanceRelayingEnabled = flag.Bool("relay", false, "whether to enable store & forward relaying for other nodes") ) -const relayServerRetryInterval = time.Second * 30 - // nolint:gocyclo func main() { flag.Parse() @@ -328,28 +325,24 @@ func main() { logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) }() + stopRelayServerSync := make(chan bool) + eLog := logrus.WithField("pinecone", "events") + relayRetriever := relay.NewRelayServerRetriever( + context.Background(), + gomatrixserverlib.ServerName(pRouter.PublicKey().String()), + monolith.FederationAPI, + monolith.RelayAPI, + stopRelayServerSync, + ) + relayRetriever.InitializeRelayServers(eLog) + go func(ch <-chan pineconeEvents.Event) { - eLog := logrus.WithField("pinecone", "events") - relayServerSyncRunning := atomic.NewBool(false) - stopRelayServerSync := make(chan bool) - - m := RelayServerRetriever{ - Context: context.Background(), - ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()), - FederationAPI: fsAPI, - RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool), - RelayAPI: monolith.RelayAPI, - } - m.InitializeRelayServers(eLog) - for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: - if !relayServerSyncRunning.Load() { - go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning) - } + relayRetriever.StartSync() case pineconeEvents.PeerRemoved: - if relayServerSyncRunning.Load() && pRouter.TotalPeerCount() == 0 { + if relayRetriever.IsRunning() && pRouter.TotalPeerCount() == 0 { stopRelayServerSync <- true } case pineconeEvents.BroadcastReceived: @@ -359,7 +352,7 @@ func main() { ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, } res := &api.PerformWakeupServersResponse{} - if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + if err := monolith.FederationAPI.PerformWakeupServers(base.Context(), req, res); err != nil { eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) } default: @@ -369,78 +362,3 @@ func main() { base.WaitForShutdown() } - -type RelayServerRetriever struct { - Context context.Context - ServerName gomatrixserverlib.ServerName - FederationAPI api.FederationInternalAPI - RelayServersQueried map[gomatrixserverlib.ServerName]bool - RelayAPI relayServerAPI.RelayInternalAPI -} - -func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { - request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} - response := api.P2PQueryRelayServersResponse{} - err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) - if err != nil { - eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) - } - for _, server := range response.RelayServers { - m.RelayServersQueried[server] = false - } - - eLog.Infof("Registered relay servers: %v", response.RelayServers) -} - -func (m *RelayServerRetriever) syncRelayServers(stop <-chan bool, running atomic.Bool) { - defer running.Store(false) - - t := time.NewTimer(relayServerRetryInterval) - for { - relayServersToQuery := []gomatrixserverlib.ServerName{} - for server, complete := range m.RelayServersQueried { - if !complete { - relayServersToQuery = append(relayServersToQuery, server) - } - } - if len(relayServersToQuery) == 0 { - // All relay servers have been synced. - return - } - m.queryRelayServers(relayServersToQuery) - t.Reset(relayServerRetryInterval) - - select { - case <-stop: - // We have been asked to stop syncing, drain the timer and return. - if !t.Stop() { - <-t.C - } - return - case <-t.C: - // The timer has expired. Continue to the next loop iteration. - } - } -} - -func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { - logrus.Info("querying relay servers for any available transactions") - for _, server := range relayServers { - userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) - if err != nil { - return - } - err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) - if err == nil { - m.RelayServersQueried[server] = true - // TODO : What happens if your relay receives new messages after this point? - // Should you continue to check with them, or should they try and contact you? - // They could send a "new_async_events" message your way maybe? - // Then you could mark them as needing to be queried again. - // What if you miss this message? - // Maybe you should try querying them again after a certain period of time as a backup? - } else { - logrus.Errorf("Failed querying relay server: %s", err.Error()) - } - } -} diff --git a/cmd/dendrite-demo-pinecone/relay/retriever.go b/cmd/dendrite-demo-pinecone/relay/retriever.go new file mode 100644 index 000000000..1b5c617ef --- /dev/null +++ b/cmd/dendrite-demo-pinecone/relay/retriever.go @@ -0,0 +1,237 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relay + +import ( + "context" + "sync" + "time" + + federationAPI "github.com/matrix-org/dendrite/federationapi/api" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "go.uber.org/atomic" +) + +const ( + relayServerRetryInterval = time.Second * 30 +) + +type RelayServerRetriever struct { + ctx context.Context + serverName gomatrixserverlib.ServerName + federationAPI federationAPI.FederationInternalAPI + relayAPI relayServerAPI.RelayInternalAPI + relayServersQueried map[gomatrixserverlib.ServerName]bool + queriedServersMutex sync.Mutex + running atomic.Bool + quit <-chan bool +} + +func NewRelayServerRetriever( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + federationAPI federationAPI.FederationInternalAPI, + relayAPI relayServerAPI.RelayInternalAPI, + quit <-chan bool, +) RelayServerRetriever { + return RelayServerRetriever{ + ctx: ctx, + serverName: serverName, + federationAPI: federationAPI, + relayAPI: relayAPI, + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + running: *atomic.NewBool(false), + quit: quit, + } +} + +func (r *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := federationAPI.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(r.serverName)} + response := federationAPI.P2PQueryRelayServersResponse{} + err := r.federationAPI.P2PQueryRelayServers(r.ctx, &request, &response) + if err != nil { + eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + } + + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + for _, server := range response.RelayServers { + r.relayServersQueried[server] = false + } + + eLog.Infof("Registered relay servers: %v", response.RelayServers) +} + +func (r *RelayServerRetriever) SetRelayServers(servers []gomatrixserverlib.ServerName) { + UpdateNodeRelayServers(r.serverName, servers, r.ctx, r.federationAPI) + + // Replace list of servers to sync with and mark them all as unsynced. + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried = make(map[gomatrixserverlib.ServerName]bool) + for _, server := range servers { + r.relayServersQueried[server] = false + } + + r.StartSync() +} + +func (r *RelayServerRetriever) GetRelayServers() []gomatrixserverlib.ServerName { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + relayServers := []gomatrixserverlib.ServerName{} + for server := range r.relayServersQueried { + relayServers = append(relayServers, server) + } + + return relayServers +} + +func (r *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + + result := map[gomatrixserverlib.ServerName]bool{} + for server, queried := range r.relayServersQueried { + result[server] = queried + } + return result +} + +func (r *RelayServerRetriever) StartSync() { + if !r.running.Load() { + logrus.Info("Starting relay server sync") + go r.SyncRelayServers(r.quit) + } +} + +func (r *RelayServerRetriever) IsRunning() bool { + return r.running.Load() +} + +func (r *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { + defer r.running.Store(false) + + t := time.NewTimer(relayServerRetryInterval) + for { + relayServersToQuery := []gomatrixserverlib.ServerName{} + func() { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + for server, complete := range r.relayServersQueried { + if !complete { + relayServersToQuery = append(relayServersToQuery, server) + } + } + }() + if len(relayServersToQuery) == 0 { + // All relay servers have been synced. + logrus.Info("Finished syncing with all known relays") + return + } + r.queryRelayServers(relayServersToQuery) + t.Reset(relayServerRetryInterval) + + select { + case <-stop: + if !t.Stop() { + <-t.C + } + return + case <-t.C: + } + } +} + +func (r *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("Querying relay servers for any available transactions") + for _, server := range relayServers { + userID, err := gomatrixserverlib.NewUserID("@user:"+string(r.serverName), false) + if err != nil { + return + } + + logrus.Infof("Syncing with relay: %s", string(server)) + err = r.relayAPI.PerformRelayServerSync(context.Background(), *userID, server) + if err == nil { + func() { + r.queriedServersMutex.Lock() + defer r.queriedServersMutex.Unlock() + r.relayServersQueried[server] = true + }() + // TODO : What happens if your relay receives new messages after this point? + // Should you continue to check with them, or should they try and contact you? + // They could send a "new_async_events" message your way maybe? + // Then you could mark them as needing to be queried again. + // What if you miss this message? + // Maybe you should try querying them again after a certain period of time as a backup? + } else { + logrus.Errorf("Failed querying relay server: %s", err.Error()) + } + } +} + +func UpdateNodeRelayServers( + node gomatrixserverlib.ServerName, + relays []gomatrixserverlib.ServerName, + ctx context.Context, + fedAPI federationAPI.FederationInternalAPI, +) { + // Get the current relay list + request := federationAPI.P2PQueryRelayServersRequest{Server: node} + response := federationAPI.P2PQueryRelayServersResponse{} + err := fedAPI.P2PQueryRelayServers(ctx, &request, &response) + if err != nil { + logrus.Warnf("Failed obtaining list of relay servers for %s: %s", node, err.Error()) + } + + // Remove old, non-matching relays + var serversToRemove []gomatrixserverlib.ServerName + for _, existingServer := range response.RelayServers { + shouldRemove := true + for _, newServer := range relays { + if newServer == existingServer { + shouldRemove = false + break + } + } + + if shouldRemove { + serversToRemove = append(serversToRemove, existingServer) + } + } + removeRequest := federationAPI.P2PRemoveRelayServersRequest{ + Server: node, + RelayServers: serversToRemove, + } + removeResponse := federationAPI.P2PRemoveRelayServersResponse{} + err = fedAPI.P2PRemoveRelayServers(ctx, &removeRequest, &removeResponse) + if err != nil { + logrus.Warnf("Failed removing old relay servers for %s: %s", node, err.Error()) + } + + // Add new relays + addRequest := federationAPI.P2PAddRelayServersRequest{ + Server: node, + RelayServers: relays, + } + addResponse := federationAPI.P2PAddRelayServersResponse{} + err = fedAPI.P2PAddRelayServers(ctx, &addRequest, &addResponse) + if err != nil { + logrus.Warnf("Failed adding relay servers for %s: %s", node, err.Error()) + } +} diff --git a/cmd/dendrite-demo-pinecone/relay/retriever_test.go b/cmd/dendrite-demo-pinecone/relay/retriever_test.go new file mode 100644 index 000000000..8f86a3770 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/relay/retriever_test.go @@ -0,0 +1,99 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relay + +import ( + "context" + "testing" + "time" + + federationAPI "github.com/matrix-org/dendrite/federationapi/api" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "gotest.tools/v3/poll" +) + +var testRelayServers = []gomatrixserverlib.ServerName{"relay1", "relay2"} + +type FakeFedAPI struct { + federationAPI.FederationInternalAPI +} + +func (f *FakeFedAPI) P2PQueryRelayServers( + ctx context.Context, + req *federationAPI.P2PQueryRelayServersRequest, + res *federationAPI.P2PQueryRelayServersResponse, +) error { + res.RelayServers = testRelayServers + return nil +} + +type FakeRelayAPI struct { + relayServerAPI.RelayInternalAPI +} + +func (r *FakeRelayAPI) PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, +) error { + return nil +} + +func TestRelayRetrieverInitialization(t *testing.T) { + retriever := NewRelayServerRetriever( + context.Background(), + "server", + &FakeFedAPI{}, + &FakeRelayAPI{}, + make(<-chan bool), + ) + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) +} + +func TestRelayRetrieverSync(t *testing.T) { + retriever := NewRelayServerRetriever( + context.Background(), + "server", + &FakeFedAPI{}, + &FakeRelayAPI{}, + make(<-chan bool), + ) + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) + + stopRelayServerSync := make(chan bool) + go retriever.SyncRelayServers(stopRelayServerSync) + + check := func(log poll.LogT) poll.Result { + relayServers := retriever.GetQueriedServerStatus() + for _, queried := range relayServers { + if !queried { + return poll.Continue("waiting for all servers to be queried") + } + } + + stopRelayServerSync <- true + return poll.Success() + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} From 529feb07eec0b52eee66bc1f4ce6985b83485e73 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 31 Jan 2023 16:10:48 -0700 Subject: [PATCH 113/124] Refactor conduit type from pinecone demo into its own package --- build/gobind-pinecone/monolith.go | 71 ++-------- build/gobind-pinecone/monolith_test.go | 101 --------------- cmd/dendrite-demo-pinecone/conduit/conduit.go | 84 ++++++++++++ .../conduit/conduit_test.go | 121 ++++++++++++++++++ 4 files changed, 214 insertions(+), 163 deletions(-) create mode 100644 cmd/dendrite-demo-pinecone/conduit/conduit.go create mode 100644 cmd/dendrite-demo-pinecone/conduit/conduit_test.go diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index f44eed89b..848ac79fe 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -21,20 +21,17 @@ import ( "crypto/tls" "encoding/hex" "fmt" - "io" "net" "net/http" "os" "path/filepath" "strings" - "sync" "time" - "go.uber.org/atomic" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conduit" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" @@ -291,16 +288,14 @@ func (m *DendriteMonolith) DisconnectPort(port int) { m.PineconeRouter.Disconnect(types.SwitchPortID(port), nil) } -func (m *DendriteMonolith) Conduit(zone string, peertype int) (*Conduit, error) { +func (m *DendriteMonolith) Conduit(zone string, peertype int) (*conduit.Conduit, error) { l, r := net.Pipe() - conduit := &Conduit{conn: r, port: 0} + newConduit := conduit.NewConduit(r, 0) go func() { - conduit.portMutex.Lock() - defer conduit.portMutex.Unlock() - logrus.Errorf("Attempting authenticated connect") + var port types.SwitchPortID var err error - if conduit.port, err = m.PineconeRouter.Connect( + if port, err = m.PineconeRouter.Connect( l, pineconeRouter.ConnectionZone(zone), pineconeRouter.ConnectionPeerType(peertype), @@ -308,12 +303,13 @@ func (m *DendriteMonolith) Conduit(zone string, peertype int) (*Conduit, error) logrus.Errorf("Authenticated connect failed: %s", err) _ = l.Close() _ = r.Close() - _ = conduit.Close() + _ = newConduit.Close() return } - logrus.Infof("Authenticated connect succeeded (port %d)", conduit.port) + newConduit.SetPort(port) + logrus.Infof("Authenticated connect succeeded (port %d)", newConduit.Port()) }() - return conduit, nil + return &newConduit, nil } func (m *DendriteMonolith) RegisterUser(localpart, password string) (string, error) { @@ -602,52 +598,3 @@ func (m *DendriteMonolith) Stop() { _ = m.PineconeQUIC.Close() _ = m.PineconeRouter.Close() } - -const MaxFrameSize = types.MaxFrameSize - -type Conduit struct { - closed atomic.Bool - conn net.Conn - port types.SwitchPortID - portMutex sync.Mutex -} - -func (c *Conduit) Port() int { - c.portMutex.Lock() - defer c.portMutex.Unlock() - return int(c.port) -} - -func (c *Conduit) Read(b []byte) (int, error) { - if c.closed.Load() { - return 0, io.EOF - } - return c.conn.Read(b) -} - -func (c *Conduit) ReadCopy() ([]byte, error) { - if c.closed.Load() { - return nil, io.EOF - } - var buf [65535 * 2]byte - n, err := c.conn.Read(buf[:]) - if err != nil { - return nil, err - } - return buf[:n], nil -} - -func (c *Conduit) Write(b []byte) (int, error) { - if c.closed.Load() { - return 0, io.EOF - } - return c.conn.Write(b) -} - -func (c *Conduit) Close() error { - if c.closed.Load() { - return io.ErrClosedPipe - } - c.closed.Store(true) - return c.conn.Close() -} diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go index f6bf2ef03..434e07ef2 100644 --- a/build/gobind-pinecone/monolith_test.go +++ b/build/gobind-pinecone/monolith_test.go @@ -15,113 +15,12 @@ package gobind import ( - "fmt" - "net" "strings" "testing" "github.com/matrix-org/gomatrixserverlib" - "github.com/stretchr/testify/assert" ) -var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} - -type TestNetConn struct { - net.Conn - shouldFail bool -} - -func (t *TestNetConn) Read(b []byte) (int, error) { - if t.shouldFail { - return 0, fmt.Errorf("Failed") - } else { - n := copy(b, TestBuf) - return n, nil - } -} - -func (t *TestNetConn) Write(b []byte) (int, error) { - if t.shouldFail { - return 0, fmt.Errorf("Failed") - } else { - return len(b), nil - } -} - -func (t *TestNetConn) Close() error { - if t.shouldFail { - return fmt.Errorf("Failed") - } else { - return nil - } -} - -func TestConduitStoresPort(t *testing.T) { - conduit := Conduit{port: 7} - assert.Equal(t, 7, conduit.Port()) -} - -func TestConduitRead(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - b := make([]byte, len(TestBuf)) - bytes, err := conduit.Read(b) - assert.NoError(t, err) - assert.Equal(t, len(TestBuf), bytes) - assert.Equal(t, TestBuf, b) -} - -func TestConduitReadCopy(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - result, err := conduit.ReadCopy() - assert.NoError(t, err) - assert.Equal(t, TestBuf, result) -} - -func TestConduitWrite(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - bytes, err := conduit.Write(TestBuf) - assert.NoError(t, err) - assert.Equal(t, len(TestBuf), bytes) -} - -func TestConduitClose(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - err := conduit.Close() - assert.NoError(t, err) - assert.True(t, conduit.closed.Load()) -} - -func TestConduitReadClosed(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - err := conduit.Close() - assert.NoError(t, err) - b := make([]byte, len(TestBuf)) - _, err = conduit.Read(b) - assert.Error(t, err) -} - -func TestConduitReadCopyClosed(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - err := conduit.Close() - assert.NoError(t, err) - _, err = conduit.ReadCopy() - assert.Error(t, err) -} - -func TestConduitWriteClosed(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{}} - err := conduit.Close() - assert.NoError(t, err) - _, err = conduit.Write(TestBuf) - assert.Error(t, err) -} - -func TestConduitReadCopyFails(t *testing.T) { - conduit := Conduit{conn: &TestNetConn{shouldFail: true}} - _, err := conduit.ReadCopy() - assert.Error(t, err) -} - func TestMonolithStarts(t *testing.T) { monolith := DendriteMonolith{} monolith.Start() diff --git a/cmd/dendrite-demo-pinecone/conduit/conduit.go b/cmd/dendrite-demo-pinecone/conduit/conduit.go new file mode 100644 index 000000000..be139c19c --- /dev/null +++ b/cmd/dendrite-demo-pinecone/conduit/conduit.go @@ -0,0 +1,84 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conduit + +import ( + "io" + "net" + "sync" + + "github.com/matrix-org/pinecone/types" + "go.uber.org/atomic" +) + +type Conduit struct { + closed atomic.Bool + conn net.Conn + portMutex sync.Mutex + port types.SwitchPortID +} + +func NewConduit(conn net.Conn, port int) Conduit { + return Conduit{ + conn: conn, + port: types.SwitchPortID(port), + } +} + +func (c *Conduit) Port() int { + c.portMutex.Lock() + defer c.portMutex.Unlock() + return int(c.port) +} + +func (c *Conduit) SetPort(port types.SwitchPortID) { + c.portMutex.Lock() + defer c.portMutex.Unlock() + c.port = port +} + +func (c *Conduit) Read(b []byte) (int, error) { + if c.closed.Load() { + return 0, io.EOF + } + return c.conn.Read(b) +} + +func (c *Conduit) ReadCopy() ([]byte, error) { + if c.closed.Load() { + return nil, io.EOF + } + var buf [65535 * 2]byte + n, err := c.conn.Read(buf[:]) + if err != nil { + return nil, err + } + return buf[:n], nil +} + +func (c *Conduit) Write(b []byte) (int, error) { + if c.closed.Load() { + return 0, io.EOF + } + return c.conn.Write(b) +} + +func (c *Conduit) Close() error { + if c.closed.Load() { + return io.ErrClosedPipe + } + c.closed.Store(true) + return c.conn.Close() +} diff --git a/cmd/dendrite-demo-pinecone/conduit/conduit_test.go b/cmd/dendrite-demo-pinecone/conduit/conduit_test.go new file mode 100644 index 000000000..d8cd3133f --- /dev/null +++ b/cmd/dendrite-demo-pinecone/conduit/conduit_test.go @@ -0,0 +1,121 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package conduit + +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + +type TestNetConn struct { + net.Conn + shouldFail bool +} + +func (t *TestNetConn) Read(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + n := copy(b, TestBuf) + return n, nil + } +} + +func (t *TestNetConn) Write(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + return len(b), nil + } +} + +func (t *TestNetConn) Close() error { + if t.shouldFail { + return fmt.Errorf("Failed") + } else { + return nil + } +} + +func TestConduitStoresPort(t *testing.T) { + conduit := Conduit{port: 7} + assert.Equal(t, 7, conduit.Port()) +} + +func TestConduitRead(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + b := make([]byte, len(TestBuf)) + bytes, err := conduit.Read(b) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) + assert.Equal(t, TestBuf, b) +} + +func TestConduitReadCopy(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + result, err := conduit.ReadCopy() + assert.NoError(t, err) + assert.Equal(t, TestBuf, result) +} + +func TestConduitWrite(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + bytes, err := conduit.Write(TestBuf) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) +} + +func TestConduitClose(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + assert.True(t, conduit.closed.Load()) +} + +func TestConduitReadClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + b := make([]byte, len(TestBuf)) + _, err = conduit.Read(b) + assert.Error(t, err) +} + +func TestConduitReadCopyClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.ReadCopy() + assert.Error(t, err) +} + +func TestConduitWriteClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.Write(TestBuf) + assert.Error(t, err) +} + +func TestConduitReadCopyFails(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{shouldFail: true}} + _, err := conduit.ReadCopy() + assert.Error(t, err) +} From 2f8377e94b5b32f4957c1f6c64acdd8bd2880e23 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 31 Jan 2023 16:13:59 -0700 Subject: [PATCH 114/124] Remove nolint: gocyclo from relayapi routing setup --- relayapi/routing/routing.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go index c908e950a..8ee0743ef 100644 --- a/relayapi/routing/routing.go +++ b/relayapi/routing/routing.go @@ -34,10 +34,6 @@ import ( // The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly // path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled // so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz] -// -// Due to Setup being used to call many other functions, a gocyclo nolint is -// applied: -// nolint: gocyclo func Setup( fedMux *mux.Router, cfg *config.FederationAPI, From d4f64f91ca5afc16bfd55f64b59310f13473f016 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 31 Jan 2023 16:27:51 -0700 Subject: [PATCH 115/124] Refactor pinecone demo to remove duplicate key setup --- build/gobind-pinecone/monolith.go | 42 ++--------------- cmd/dendrite-demo-pinecone/keys/keys.go | 63 +++++++++++++++++++++++++ cmd/dendrite-demo-pinecone/main.go | 39 ++------------- 3 files changed, 69 insertions(+), 75 deletions(-) create mode 100644 cmd/dendrite-demo-pinecone/keys/keys.go diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 848ac79fe..dd4838737 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -23,7 +23,6 @@ import ( "fmt" "net" "net/http" - "os" "path/filepath" "strings" "time" @@ -33,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conduit" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/keys" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" @@ -49,7 +49,6 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" userapiAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -351,45 +350,10 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e return loginRes.Device.AccessToken, nil } -// nolint:gocyclo func (m *DendriteMonolith) Start() { - var sk ed25519.PrivateKey - var pk ed25519.PublicKey - keyfile := filepath.Join(m.StorageDirectory, "p2p.pem") - if _, err := os.Stat(keyfile); os.IsNotExist(err) { - oldkeyfile := filepath.Join(m.StorageDirectory, "p2p.key") - if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) { - if err = test.NewMatrixKey(keyfile); err != nil { - panic("failed to generate a new PEM key: " + err.Error()) - } - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } else { - if sk, err = os.ReadFile(oldkeyfile); err != nil { - panic("failed to read the old private key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - if err = test.SaveMatrixKey(keyfile, sk); err != nil { - panic("failed to convert the private key to PEM format: " + err.Error()) - } - } - } else { - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } - - pk = sk.Public().(ed25519.PublicKey) + oldKeyfile := filepath.Join(m.StorageDirectory, "p2p.key") + sk, pk := keys.GetOrCreateKey(keyfile, oldKeyfile) var err error m.listener, err = net.Listen("tcp", "localhost:65432") diff --git a/cmd/dendrite-demo-pinecone/keys/keys.go b/cmd/dendrite-demo-pinecone/keys/keys.go new file mode 100644 index 000000000..db84d8ec9 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/keys/keys.go @@ -0,0 +1,63 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package keys + +import ( + "crypto/ed25519" + "os" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" +) + +func GetOrCreateKey(keyfile string, oldKeyfile string) (ed25519.PrivateKey, ed25519.PublicKey) { + var sk ed25519.PrivateKey + var pk ed25519.PublicKey + + if _, err := os.Stat(keyfile); os.IsNotExist(err) { + if _, err = os.Stat(oldKeyfile); os.IsNotExist(err) { + if err = test.NewMatrixKey(keyfile); err != nil { + panic("failed to generate a new PEM key: " + err.Error()) + } + if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { + panic("failed to load PEM key: " + err.Error()) + } + if len(sk) != ed25519.PrivateKeySize { + panic("the private key is not long enough") + } + } else { + if sk, err = os.ReadFile(oldKeyfile); err != nil { + panic("failed to read the old private key: " + err.Error()) + } + if len(sk) != ed25519.PrivateKeySize { + panic("the private key is not long enough") + } + if err = test.SaveMatrixKey(keyfile, sk); err != nil { + panic("failed to convert the private key to PEM format: " + err.Error()) + } + } + } else { + if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { + panic("failed to load PEM key: " + err.Error()) + } + if len(sk) != ed25519.PrivateKeySize { + panic("the private key is not long enough") + } + } + + pk = sk.Public().(ed25519.PublicKey) + + return sk, pk +} diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index abffbb67a..9ee792e4a 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -33,6 +33,7 @@ import ( "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/keys" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" @@ -49,7 +50,6 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" @@ -71,7 +71,6 @@ var ( instanceRelayingEnabled = flag.Bool("relay", false, "whether to enable store & forward relaying for other nodes") ) -// nolint:gocyclo func main() { flag.Parse() internal.SetupPprof() @@ -97,40 +96,8 @@ func main() { pk = sk.Public().(ed25519.PublicKey) } else { keyfile := filepath.Join(*instanceDir, *instanceName) + ".pem" - if _, err := os.Stat(keyfile); os.IsNotExist(err) { - oldkeyfile := *instanceName + ".key" - if _, err = os.Stat(oldkeyfile); os.IsNotExist(err) { - if err = test.NewMatrixKey(keyfile); err != nil { - panic("failed to generate a new PEM key: " + err.Error()) - } - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } else { - if sk, err = os.ReadFile(oldkeyfile); err != nil { - panic("failed to read the old private key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - if err := test.SaveMatrixKey(keyfile, sk); err != nil { - panic("failed to convert the private key to PEM format: " + err.Error()) - } - } - } else { - var err error - if _, sk, err = config.LoadMatrixKey(keyfile, os.ReadFile); err != nil { - panic("failed to load PEM key: " + err.Error()) - } - if len(sk) != ed25519.PrivateKeySize { - panic("the private key is not long enough") - } - } - - pk = sk.Public().(ed25519.PublicKey) + oldKeyfile := *instanceName + ".key" + sk, pk = keys.GetOrCreateKey(keyfile, oldKeyfile) cfg.Defaults(config.DefaultOpts{ Generate: true, From dbc2869cbd36b069f4a0745557320c5fe1daf0f7 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 31 Jan 2023 17:35:21 -0700 Subject: [PATCH 116/124] Refactor pinecone demo to remove duplicate pinecone setup --- build/gobind-pinecone/monolith.go | 125 ++++++------------ cmd/dendrite-demo-pinecone/main.go | 86 ++++-------- .../{keys => monolith}/keys.go | 2 +- .../monolith/monolith.go | 80 +++++++++++ 4 files changed, 153 insertions(+), 140 deletions(-) rename cmd/dendrite-demo-pinecone/{keys => monolith}/keys.go (99%) create mode 100644 cmd/dendrite-demo-pinecone/monolith/monolith.go diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index dd4838737..cca92151b 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -32,7 +32,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conduit" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/keys" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/monolith" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" @@ -47,21 +47,18 @@ import ( "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/userapi" userapiAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/pinecone/types" "github.com/sirupsen/logrus" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" - pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" pineconeEvents "github.com/matrix-org/pinecone/router/events" - pineconeSessions "github.com/matrix-org/pinecone/sessions" - "github.com/matrix-org/pinecone/types" _ "golang.org/x/mobile/bind" ) @@ -74,24 +71,20 @@ const ( ) type DendriteMonolith struct { - logger logrus.Logger - baseDendrite *base.BaseDendrite - PineconeRouter *pineconeRouter.Router - PineconeMulticast *pineconeMulticast.Multicast - PineconeQUIC *pineconeSessions.Sessions - PineconeManager *pineconeConnections.ConnectionManager - StorageDirectory string - CacheDirectory string - listener net.Listener - httpServer *http.Server - userAPI userapiAPI.UserInternalAPI - federationAPI api.FederationInternalAPI - relayAPI relayServerAPI.RelayInternalAPI - relayRetriever relay.RelayServerRetriever + logger logrus.Logger + baseDendrite *base.BaseDendrite + p2pMonolith monolith.P2PMonolith + StorageDirectory string + listener net.Listener + httpServer *http.Server + userAPI userapiAPI.UserInternalAPI + federationAPI api.FederationInternalAPI + relayAPI relayServerAPI.RelayInternalAPI + relayRetriever relay.RelayServerRetriever } func (m *DendriteMonolith) PublicKey() string { - return m.PineconeRouter.PublicKey().String() + return m.p2pMonolith.Router.PublicKey().String() } func (m *DendriteMonolith) BaseURL() string { @@ -99,11 +92,11 @@ func (m *DendriteMonolith) BaseURL() string { } func (m *DendriteMonolith) PeerCount(peertype int) int { - return m.PineconeRouter.PeerCount(peertype) + return m.p2pMonolith.Router.PeerCount(peertype) } func (m *DendriteMonolith) SessionCount() int { - return len(m.PineconeQUIC.Protocol("matrix").Sessions()) + return len(m.p2pMonolith.Sessions.Protocol("matrix").Sessions()) } type InterfaceInfo struct { @@ -145,22 +138,22 @@ func (m *DendriteMonolith) RegisterNetworkCallback(intfCallback InterfaceRetriev } return intfs } - m.PineconeMulticast.RegisterNetworkCallback(callback) + m.p2pMonolith.Multicast.RegisterNetworkCallback(callback) } func (m *DendriteMonolith) SetMulticastEnabled(enabled bool) { if enabled { - m.PineconeMulticast.Start() + m.p2pMonolith.Multicast.Start() } else { - m.PineconeMulticast.Stop() + m.p2pMonolith.Multicast.Stop() m.DisconnectType(int(pineconeRouter.PeerTypeMulticast)) } } func (m *DendriteMonolith) SetStaticPeer(uri string) { - m.PineconeManager.RemovePeers() + m.p2pMonolith.ConnManager.RemovePeers() for _, uri := range strings.Split(uri, ",") { - m.PineconeManager.AddPeer(strings.TrimSpace(uri)) + m.p2pMonolith.ConnManager.AddPeer(strings.TrimSpace(uri)) } } @@ -268,23 +261,23 @@ func (m *DendriteMonolith) SetRelayingEnabled(enabled bool) { } func (m *DendriteMonolith) DisconnectType(peertype int) { - for _, p := range m.PineconeRouter.Peers() { + for _, p := range m.p2pMonolith.Router.Peers() { if int(peertype) == p.PeerType { - m.PineconeRouter.Disconnect(types.SwitchPortID(p.Port), nil) + m.p2pMonolith.Router.Disconnect(types.SwitchPortID(p.Port), nil) } } } func (m *DendriteMonolith) DisconnectZone(zone string) { - for _, p := range m.PineconeRouter.Peers() { + for _, p := range m.p2pMonolith.Router.Peers() { if zone == p.Zone { - m.PineconeRouter.Disconnect(types.SwitchPortID(p.Port), nil) + m.p2pMonolith.Router.Disconnect(types.SwitchPortID(p.Port), nil) } } } func (m *DendriteMonolith) DisconnectPort(port int) { - m.PineconeRouter.Disconnect(types.SwitchPortID(port), nil) + m.p2pMonolith.Router.Disconnect(types.SwitchPortID(port), nil) } func (m *DendriteMonolith) Conduit(zone string, peertype int) (*conduit.Conduit, error) { @@ -294,7 +287,7 @@ func (m *DendriteMonolith) Conduit(zone string, peertype int) (*conduit.Conduit, logrus.Errorf("Attempting authenticated connect") var port types.SwitchPortID var err error - if port, err = m.PineconeRouter.Connect( + if port, err = m.p2pMonolith.Router.Connect( l, pineconeRouter.ConnectionZone(zone), pineconeRouter.ConnectionPeerType(peertype), @@ -312,7 +305,7 @@ func (m *DendriteMonolith) Conduit(zone string, peertype int) (*conduit.Conduit, } func (m *DendriteMonolith) RegisterUser(localpart, password string) (string, error) { - pubkey := m.PineconeRouter.PublicKey() + pubkey := m.p2pMonolith.Router.PublicKey() userID := userutil.MakeUserID( localpart, gomatrixserverlib.ServerName(hex.EncodeToString(pubkey[:])), @@ -353,7 +346,7 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e func (m *DendriteMonolith) Start() { keyfile := filepath.Join(m.StorageDirectory, "p2p.pem") oldKeyfile := filepath.Join(m.StorageDirectory, "p2p.key") - sk, pk := keys.GetOrCreateKey(keyfile, oldKeyfile) + sk, pk := monolith.GetOrCreateKey(keyfile, oldKeyfile) var err error m.listener, err = net.Listen("tcp", "localhost:65432") @@ -367,50 +360,20 @@ func (m *DendriteMonolith) Start() { m.logger.SetOutput(BindLogger{}) logrus.SetOutput(BindLogger{}) - pineconeEventChannel := make(chan pineconeEvents.Event) - m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) - m.PineconeRouter.EnableHopLimiting() - m.PineconeRouter.EnableWakeupBroadcasts() - m.PineconeRouter.Subscribe(pineconeEventChannel) - - m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) - m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) - m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil) + m.p2pMonolith = monolith.P2PMonolith{} + m.p2pMonolith.SetupPinecone(sk) prefix := hex.EncodeToString(pk) - cfg := &config.Dendrite{} - cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, - }) + cfg := monolith.GenerateDefaultConfig(sk, m.StorageDirectory, prefix) cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) - cfg.Global.PrivateKey = sk cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.InMemory = false - cfg.Global.JetStream.StoragePath = config.Path(filepath.Join(m.CacheDirectory, prefix)) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.MediaAPI.BasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) - cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) - cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(m.StorageDirectory, prefix))) - cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} - cfg.ClientAPI.RegistrationDisabled = false - cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true - cfg.SyncAPI.Fulltext.Enabled = true - cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(m.CacheDirectory, "search")) - if err = cfg.Derive(); err != nil { - panic(err) - } base := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) m.baseDendrite = base base.ConfigureAdminEndpoints() - federation := conn.CreateFederationClient(base, m.PineconeQUIC) + federation := conn.CreateFederationClient(base, m.p2pMonolith.Sessions) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() @@ -431,8 +394,8 @@ func (m *DendriteMonolith) Start() { // This is different to rsAPI which can be the http client which doesn't need this dependency rsAPI.SetFederationAPI(m.federationAPI, keyRing) - userProvider := users.NewPineconeUserProvider(m.PineconeRouter, m.PineconeQUIC, m.userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, m.federationAPI, federation) + userProvider := users.NewPineconeUserProvider(m.p2pMonolith.Router, m.p2pMonolith.Sessions, m.userAPI, federation) + roomProvider := rooms.NewPineconeRoomProvider(m.p2pMonolith.Router, m.p2pMonolith.Sessions, m.federationAPI, federation) js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) producer := &producers.SyncAPIProducer{ @@ -450,7 +413,7 @@ func (m *DendriteMonolith) Start() { monolith := setup.Monolith{ Config: base.Cfg, - Client: conn.CreateClient(base, m.PineconeQUIC), + Client: conn.CreateClient(base, m.p2pMonolith.Sessions), FedClient: federation, KeyRing: keyRing, @@ -471,14 +434,14 @@ func (m *DendriteMonolith) Start() { httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) - httpRouter.HandleFunc("/pinecone", m.PineconeRouter.ManholeHandler) + httpRouter.HandleFunc("/pinecone", m.p2pMonolith.Router.ManholeHandler) pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - pHTTP := m.PineconeQUIC.Protocol("matrix").HTTP() + pHTTP := m.p2pMonolith.Sessions.Protocol("matrix").HTTP() pHTTP.Mux().Handle(users.PublicURL, pMux) pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux) pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux) @@ -500,7 +463,7 @@ func (m *DendriteMonolith) Start() { go func() { m.logger.Info("Listening on ", cfg.Global.ServerName) - switch m.httpServer.Serve(m.PineconeQUIC.Protocol("matrix")) { + switch m.httpServer.Serve(m.p2pMonolith.Sessions.Protocol("matrix")) { case net.ErrClosed, http.ErrServerClosed: m.logger.Info("Stopped listening on ", cfg.Global.ServerName) default: @@ -522,7 +485,7 @@ func (m *DendriteMonolith) Start() { eLog := logrus.WithField("pinecone", "events") m.relayRetriever = relay.NewRelayServerRetriever( context.Background(), - gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), + gomatrixserverlib.ServerName(m.p2pMonolith.Router.PublicKey().String()), m.federationAPI, monolith.RelayAPI, stopRelayServerSync, @@ -535,7 +498,7 @@ func (m *DendriteMonolith) Start() { case pineconeEvents.PeerAdded: m.relayRetriever.StartSync() case pineconeEvents.PeerRemoved: - if m.relayRetriever.IsRunning() && m.PineconeRouter.TotalPeerCount() == 0 { + if m.relayRetriever.IsRunning() && m.p2pMonolith.Router.TotalPeerCount() == 0 { stopRelayServerSync <- true } case pineconeEvents.BroadcastReceived: @@ -551,14 +514,14 @@ func (m *DendriteMonolith) Start() { default: } } - }(pineconeEventChannel) + }(m.p2pMonolith.EventChannel) } func (m *DendriteMonolith) Stop() { _ = m.baseDendrite.Close() m.baseDendrite.WaitForShutdown() _ = m.listener.Close() - m.PineconeMulticast.Stop() - _ = m.PineconeQUIC.Close() - _ = m.PineconeRouter.Close() + m.p2pMonolith.Multicast.Stop() + _ = m.p2pMonolith.Sessions.Close() + _ = m.p2pMonolith.Router.Close() } diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 9ee792e4a..ffb7ace19 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -33,7 +33,7 @@ import ( "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/keys" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/monolith" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" @@ -53,11 +53,8 @@ import ( "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" - pineconeConnections "github.com/matrix-org/pinecone/connections" - pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" pineconeEvents "github.com/matrix-org/pinecone/router/events" - pineconeSessions "github.com/matrix-org/pinecone/sessions" "github.com/sirupsen/logrus" ) @@ -87,7 +84,7 @@ func main() { } } - cfg := &config.Dendrite{} + var cfg *config.Dendrite // use custom config if config flag is set if configFlagSet { @@ -97,31 +94,8 @@ func main() { } else { keyfile := filepath.Join(*instanceDir, *instanceName) + ".pem" oldKeyfile := *instanceName + ".key" - sk, pk = keys.GetOrCreateKey(keyfile, oldKeyfile) - - cfg.Defaults(config.DefaultOpts{ - Generate: true, - Monolithic: true, - }) - cfg.Global.PrivateKey = sk - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", filepath.Join(*instanceDir, *instanceName))) - cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(*instanceDir, *instanceName))) - cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(*instanceDir, *instanceName))) - cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(*instanceDir, *instanceName))) - cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName))) - cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName))) - cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName))) - cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(*instanceDir, *instanceName))) - cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} - cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) - cfg.ClientAPI.RegistrationDisabled = false - cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true - cfg.MediaAPI.BasePath = config.Path(*instanceDir) - cfg.SyncAPI.Fulltext.Enabled = true - cfg.SyncAPI.Fulltext.IndexPath = config.Path(*instanceDir) - if err := cfg.Derive(); err != nil { - panic(err) - } + sk, pk = monolith.GetOrCreateKey(keyfile, oldKeyfile) + cfg = monolith.GenerateDefaultConfig(sk, *instanceDir, *instanceName) } cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) @@ -131,19 +105,13 @@ func main() { base.ConfigureAdminEndpoints() defer base.Close() // nolint: errcheck - pineconeEventChannel := make(chan pineconeEvents.Event) - pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) - pRouter.EnableHopLimiting() - pRouter.EnableWakeupBroadcasts() - pRouter.Subscribe(pineconeEventChannel) + p2pMonolith := monolith.P2PMonolith{} + p2pMonolith.SetupPinecone(sk) + p2pMonolith.Multicast.Start() - pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) - pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) - pManager := pineconeConnections.NewConnectionManager(pRouter, nil) - pMulticast.Start() if instancePeer != nil && *instancePeer != "" { for _, peer := range strings.Split(*instancePeer, ",") { - pManager.AddPeer(strings.Trim(peer, " \t\r\n")) + p2pMonolith.ConnManager.AddPeer(strings.Trim(peer, " \t\r\n")) } } @@ -162,7 +130,7 @@ func main() { continue } - port, err := pRouter.Connect( + port, err := p2pMonolith.Router.Connect( conn, pineconeRouter.ConnectionPeerType(pineconeRouter.PeerTypeRemote), ) @@ -175,7 +143,8 @@ func main() { } }() - federation := conn.CreateFederationClient(base, pQUIC) + // TODO : factor this dendrite setup out to a common place + federation := conn.CreateFederationClient(base, p2pMonolith.Sessions) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() @@ -194,8 +163,8 @@ func main() { rsComponent.SetFederationAPI(fsAPI, keyRing) - userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation) + userProvider := users.NewPineconeUserProvider(p2pMonolith.Router, p2pMonolith.Sessions, userAPI, federation) + roomProvider := rooms.NewPineconeRoomProvider(p2pMonolith.Router, p2pMonolith.Sessions, fsAPI, federation) js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) producer := &producers.SyncAPIProducer{ @@ -212,9 +181,9 @@ func main() { relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer, *instanceRelayingEnabled) logrus.Infof("Relaying enabled: %v", relayAPI.RelayingEnabled()) - monolith := setup.Monolith{ + m := setup.Monolith{ Config: base.Cfg, - Client: conn.CreateClient(base, pQUIC), + Client: conn.CreateClient(base, p2pMonolith.Sessions), FedClient: federation, KeyRing: keyRing, @@ -227,7 +196,7 @@ func main() { ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } - monolith.AddAllPublicRoutes(base) + m.AddAllPublicRoutes(base) wsUpgrader := websocket.Upgrader{ CheckOrigin: func(_ *http.Request) bool { @@ -247,7 +216,7 @@ func main() { return } conn := conn.WrapWebSocketConn(c) - if _, err = pRouter.Connect( + if _, err = p2pMonolith.Router.Connect( conn, pineconeRouter.ConnectionZone("websocket"), pineconeRouter.ConnectionPeerType(pineconeRouter.PeerTypeRemote), @@ -255,7 +224,7 @@ func main() { logrus.WithError(err).Error("Failed to connect WebSocket peer to Pinecone switch") } }) - httpRouter.HandleFunc("/pinecone", pRouter.ManholeHandler) + httpRouter.HandleFunc("/pinecone", p2pMonolith.Router.ManholeHandler) embed.Embed(httpRouter, *instancePort, "Pinecone Demo") pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() @@ -263,7 +232,7 @@ func main() { pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - pHTTP := pQUIC.Protocol("matrix").HTTP() + pHTTP := p2pMonolith.Sessions.Protocol("matrix").HTTP() pHTTP.Mux().Handle(users.PublicURL, pMux) pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux) pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux) @@ -281,10 +250,11 @@ func main() { Handler: pMux, } + // TODO : factor these funcs out to a common place go func() { - pubkey := pRouter.PublicKey() + pubkey := p2pMonolith.Router.PublicKey() logrus.Info("Listening on ", hex.EncodeToString(pubkey[:])) - logrus.Fatal(httpServer.Serve(pQUIC.Protocol("matrix"))) + logrus.Fatal(httpServer.Serve(p2pMonolith.Sessions.Protocol("matrix"))) }() go func() { httpBindAddr := fmt.Sprintf(":%d", *instancePort) @@ -296,9 +266,9 @@ func main() { eLog := logrus.WithField("pinecone", "events") relayRetriever := relay.NewRelayServerRetriever( context.Background(), - gomatrixserverlib.ServerName(pRouter.PublicKey().String()), - monolith.FederationAPI, - monolith.RelayAPI, + gomatrixserverlib.ServerName(p2pMonolith.Router.PublicKey().String()), + m.FederationAPI, + m.RelayAPI, stopRelayServerSync, ) relayRetriever.InitializeRelayServers(eLog) @@ -309,7 +279,7 @@ func main() { case pineconeEvents.PeerAdded: relayRetriever.StartSync() case pineconeEvents.PeerRemoved: - if relayRetriever.IsRunning() && pRouter.TotalPeerCount() == 0 { + if relayRetriever.IsRunning() && p2pMonolith.Router.TotalPeerCount() == 0 { stopRelayServerSync <- true } case pineconeEvents.BroadcastReceived: @@ -319,13 +289,13 @@ func main() { ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, } res := &api.PerformWakeupServersResponse{} - if err := monolith.FederationAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + if err := m.FederationAPI.PerformWakeupServers(base.Context(), req, res); err != nil { eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) } default: } } - }(pineconeEventChannel) + }(p2pMonolith.EventChannel) base.WaitForShutdown() } diff --git a/cmd/dendrite-demo-pinecone/keys/keys.go b/cmd/dendrite-demo-pinecone/monolith/keys.go similarity index 99% rename from cmd/dendrite-demo-pinecone/keys/keys.go rename to cmd/dendrite-demo-pinecone/monolith/keys.go index db84d8ec9..637f24a43 100644 --- a/cmd/dendrite-demo-pinecone/keys/keys.go +++ b/cmd/dendrite-demo-pinecone/monolith/keys.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package keys +package monolith import ( "crypto/ed25519" diff --git a/cmd/dendrite-demo-pinecone/monolith/monolith.go b/cmd/dendrite-demo-pinecone/monolith/monolith.go new file mode 100644 index 000000000..d3f0562bc --- /dev/null +++ b/cmd/dendrite-demo-pinecone/monolith/monolith.go @@ -0,0 +1,80 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package monolith + +import ( + "crypto/ed25519" + "fmt" + "path/filepath" + + "github.com/matrix-org/dendrite/setup/config" + "github.com/sirupsen/logrus" + + pineconeConnections "github.com/matrix-org/pinecone/connections" + pineconeMulticast "github.com/matrix-org/pinecone/multicast" + pineconeRouter "github.com/matrix-org/pinecone/router" + pineconeEvents "github.com/matrix-org/pinecone/router/events" + pineconeSessions "github.com/matrix-org/pinecone/sessions" +) + +type P2PMonolith struct { + Sessions *pineconeSessions.Sessions + Multicast *pineconeMulticast.Multicast + ConnManager *pineconeConnections.ConnectionManager + Router *pineconeRouter.Router + EventChannel chan pineconeEvents.Event +} + +func GenerateDefaultConfig(sk ed25519.PrivateKey, storageDir string, dbPrefix string) *config.Dendrite { + cfg := config.Dendrite{} + cfg.Defaults(config.DefaultOpts{ + Generate: true, + Monolithic: true, + }) + cfg.Global.PrivateKey = sk + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", filepath.Join(storageDir, dbPrefix))) + cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(storageDir, dbPrefix))) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(storageDir, dbPrefix))) + cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(storageDir, dbPrefix))) + cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(storageDir, dbPrefix))) + cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(storageDir, dbPrefix))) + cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(storageDir, dbPrefix))) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(storageDir, dbPrefix))) + cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} + cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(storageDir, dbPrefix))) + cfg.ClientAPI.RegistrationDisabled = false + cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true + cfg.MediaAPI.BasePath = config.Path(filepath.Join(storageDir, "media")) + cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(storageDir, "media")) + cfg.SyncAPI.Fulltext.Enabled = true + cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(storageDir, "search")) + if err := cfg.Derive(); err != nil { + panic(err) + } + + return &cfg +} + +func (p *P2PMonolith) SetupPinecone(sk ed25519.PrivateKey) { + p.EventChannel = make(chan pineconeEvents.Event) + p.Router = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) + p.Router.EnableHopLimiting() + p.Router.EnableWakeupBroadcasts() + p.Router.Subscribe(p.EventChannel) + + p.Sessions = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), p.Router, []string{"matrix"}) + p.Multicast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), p.Router) + p.ConnManager = pineconeConnections.NewConnectionManager(p.Router, nil) +} From 048e35026c5396594d26cef6a2412c00ffe6f35f Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 1 Feb 2023 13:41:27 -0700 Subject: [PATCH 117/124] Refactor common pinecone demo code to remove major duplication --- build/gobind-pinecone/monolith.go | 226 ++---------- cmd/dendrite-demo-pinecone/main.go | 196 +---------- .../monolith/monolith.go | 321 +++++++++++++++++- 3 files changed, 347 insertions(+), 396 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index cca92151b..597c91353 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -18,47 +18,25 @@ import ( "context" "crypto/ed25519" "crypto/rand" - "crypto/tls" "encoding/hex" "fmt" "net" - "net/http" "path/filepath" "strings" - "time" - "github.com/gorilla/mux" - "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conduit" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/monolith" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" - "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/federationapi/producers" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" - "github.com/matrix-org/dendrite/relayapi" - relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" - "github.com/matrix-org/dendrite/roomserver" - "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/userapi" userapiAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/pinecone/types" "github.com/sirupsen/logrus" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" - pineconeEvents "github.com/matrix-org/pinecone/router/events" _ "golang.org/x/mobile/bind" ) @@ -70,17 +48,17 @@ const ( PeerTypeBonjour = pineconeRouter.PeerTypeBonjour ) +// Re-export Conduit in this package for bindings. +type Conduit struct { + conduit.Conduit +} + type DendriteMonolith struct { logger logrus.Logger - baseDendrite *base.BaseDendrite p2pMonolith monolith.P2PMonolith StorageDirectory string + CacheDirectory string listener net.Listener - httpServer *http.Server - userAPI userapiAPI.UserInternalAPI - federationAPI api.FederationInternalAPI - relayAPI relayServerAPI.RelayInternalAPI - relayRetriever relay.RelayServerRetriever } func (m *DendriteMonolith) PublicKey() string { @@ -88,7 +66,7 @@ func (m *DendriteMonolith) PublicKey() string { } func (m *DendriteMonolith) BaseURL() string { - return fmt.Sprintf("http://%s", m.listener.Addr().String()) + return fmt.Sprintf("http://%s", m.p2pMonolith.Addr()) } func (m *DendriteMonolith) PeerCount(peertype int) int { @@ -96,7 +74,7 @@ func (m *DendriteMonolith) PeerCount(peertype int) int { } func (m *DendriteMonolith) SessionCount() int { - return len(m.p2pMonolith.Sessions.Protocol("matrix").Sessions()) + return len(m.p2pMonolith.Sessions.Protocol(monolith.SessionProtocol).Sessions()) } type InterfaceInfo struct { @@ -202,13 +180,13 @@ func (m *DendriteMonolith) SetRelayServers(nodeID string, uris string) { if string(nodeKey) == m.PublicKey() { logrus.Infof("Setting own relay servers to: %v", relays) - m.relayRetriever.SetRelayServers(relays) + m.p2pMonolith.RelayRetriever.SetRelayServers(relays) } else { relay.UpdateNodeRelayServers( gomatrixserverlib.ServerName(nodeKey), relays, - m.baseDendrite.Context(), - m.federationAPI, + m.p2pMonolith.BaseDendrite.Context(), + m.p2pMonolith.GetFederationAPI(), ) } } @@ -222,7 +200,7 @@ func (m *DendriteMonolith) GetRelayServers(nodeID string) string { relaysString := "" if string(nodeKey) == m.PublicKey() { - relays := m.relayRetriever.GetRelayServers() + relays := m.p2pMonolith.RelayRetriever.GetRelayServers() for i, relay := range relays { if i != 0 { @@ -234,7 +212,7 @@ func (m *DendriteMonolith) GetRelayServers(nodeID string) string { } else { request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(nodeKey)} response := api.P2PQueryRelayServersResponse{} - err := m.federationAPI.P2PQueryRelayServers(m.baseDendrite.Context(), &request, &response) + err := m.p2pMonolith.GetFederationAPI().P2PQueryRelayServers(m.p2pMonolith.BaseDendrite.Context(), &request, &response) if err != nil { logrus.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) return "" @@ -253,11 +231,11 @@ func (m *DendriteMonolith) GetRelayServers(nodeID string) string { } func (m *DendriteMonolith) RelayingEnabled() bool { - return m.relayAPI.RelayingEnabled() + return m.p2pMonolith.GetRelayAPI().RelayingEnabled() } func (m *DendriteMonolith) SetRelayingEnabled(enabled bool) { - m.relayAPI.SetRelayingEnabled(enabled) + m.p2pMonolith.GetRelayAPI().SetRelayingEnabled(enabled) } func (m *DendriteMonolith) DisconnectType(peertype int) { @@ -280,9 +258,9 @@ func (m *DendriteMonolith) DisconnectPort(port int) { m.p2pMonolith.Router.Disconnect(types.SwitchPortID(port), nil) } -func (m *DendriteMonolith) Conduit(zone string, peertype int) (*conduit.Conduit, error) { +func (m *DendriteMonolith) Conduit(zone string, peertype int) (*Conduit, error) { l, r := net.Pipe() - newConduit := conduit.NewConduit(r, 0) + newConduit := Conduit{conduit.NewConduit(r, 0)} go func() { logrus.Errorf("Attempting authenticated connect") var port types.SwitchPortID @@ -316,7 +294,7 @@ func (m *DendriteMonolith) RegisterUser(localpart, password string) (string, err Password: password, } userRes := &userapiAPI.PerformAccountCreationResponse{} - if err := m.userAPI.PerformAccountCreation(context.Background(), userReq, userRes); err != nil { + if err := m.p2pMonolith.GetUserAPI().PerformAccountCreation(context.Background(), userReq, userRes); err != nil { return userID, fmt.Errorf("userAPI.PerformAccountCreation: %w", err) } return userID, nil @@ -334,7 +312,7 @@ func (m *DendriteMonolith) RegisterDevice(localpart, deviceID string) (string, e AccessToken: hex.EncodeToString(accessTokenBytes[:n]), } loginRes := &userapiAPI.PerformDeviceCreationResponse{} - if err := m.userAPI.PerformDeviceCreation(context.Background(), loginReq, loginRes); err != nil { + if err := m.p2pMonolith.GetUserAPI().PerformDeviceCreation(context.Background(), loginReq, loginRes); err != nil { return "", fmt.Errorf("userAPI.PerformDeviceCreation: %w", err) } if !loginRes.DeviceCreated { @@ -348,12 +326,6 @@ func (m *DendriteMonolith) Start() { oldKeyfile := filepath.Join(m.StorageDirectory, "p2p.key") sk, pk := monolith.GetOrCreateKey(keyfile, oldKeyfile) - var err error - m.listener, err = net.Listen("tcp", "localhost:65432") - if err != nil { - panic(err) - } - m.logger = logrus.Logger{ Out: BindLogger{}, } @@ -364,164 +336,20 @@ func (m *DendriteMonolith) Start() { m.p2pMonolith.SetupPinecone(sk) prefix := hex.EncodeToString(pk) - cfg := monolith.GenerateDefaultConfig(sk, m.StorageDirectory, prefix) + cfg := monolith.GenerateDefaultConfig(sk, m.StorageDirectory, m.CacheDirectory, prefix) cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.InMemory = false - base := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) - m.baseDendrite = base - base.ConfigureAdminEndpoints() + enableRelaying := false + enableMetrics := true + enableWebsockets := true + m.p2pMonolith.SetupDendrite(cfg, 65432, enableRelaying, enableMetrics, enableWebsockets) - federation := conn.CreateFederationClient(base, m.p2pMonolith.Sessions) - - serverKeyAPI := &signing.YggdrasilKeys{} - keyRing := serverKeyAPI.KeyRing() - - rsAPI := roomserver.NewInternalAPI(base) - - m.federationAPI = federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, keyRing, true, - ) - - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, m.federationAPI, rsAPI) - m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(m.userAPI) - - asAPI := appservice.NewInternalAPI(base, m.userAPI, rsAPI) - - // The underlying roomserver implementation needs to be able to call the fedsender. - // This is different to rsAPI which can be the http client which doesn't need this dependency - rsAPI.SetFederationAPI(m.federationAPI, keyRing) - - userProvider := users.NewPineconeUserProvider(m.p2pMonolith.Router, m.p2pMonolith.Sessions, m.userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(m.p2pMonolith.Router, m.p2pMonolith.Sessions, m.federationAPI, federation) - - js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - producer := &producers.SyncAPIProducer{ - JetStream: js, - TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), - TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), - TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), - TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), - TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - Config: &base.Cfg.FederationAPI, - UserAPI: m.userAPI, - } - m.relayAPI = relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer, false) - - monolith := setup.Monolith{ - Config: base.Cfg, - Client: conn.CreateClient(base, m.p2pMonolith.Sessions), - FedClient: federation, - KeyRing: keyRing, - - AppserviceAPI: asAPI, - FederationAPI: m.federationAPI, - RoomserverAPI: rsAPI, - UserAPI: m.userAPI, - KeyAPI: keyAPI, - RelayAPI: m.relayAPI, - ExtPublicRoomsProvider: roomProvider, - ExtUserDirectoryProvider: userProvider, - } - monolith.AddAllPublicRoutes(base) - - httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) - httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) - httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) - httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) - httpRouter.HandleFunc("/pinecone", m.p2pMonolith.Router.ManholeHandler) - - pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() - pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) - pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) - pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - - pHTTP := m.p2pMonolith.Sessions.Protocol("matrix").HTTP() - pHTTP.Mux().Handle(users.PublicURL, pMux) - pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux) - pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux) - - // Build both ends of a HTTP multiplex. - h2s := &http2.Server{} - m.httpServer = &http.Server{ - Addr: ":0", - TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 30 * time.Second, - BaseContext: func(_ net.Listener) context.Context { - return context.Background() - }, - Handler: h2c.NewHandler(pMux, h2s), - } - - go func() { - m.logger.Info("Listening on ", cfg.Global.ServerName) - - switch m.httpServer.Serve(m.p2pMonolith.Sessions.Protocol("matrix")) { - case net.ErrClosed, http.ErrServerClosed: - m.logger.Info("Stopped listening on ", cfg.Global.ServerName) - default: - m.logger.Error("Stopped listening on ", cfg.Global.ServerName) - } - }() - go func() { - logrus.Info("Listening on ", m.listener.Addr()) - - switch http.Serve(m.listener, httpRouter) { - case net.ErrClosed, http.ErrServerClosed: - m.logger.Info("Stopped listening on ", cfg.Global.ServerName) - default: - m.logger.Error("Stopped listening on ", cfg.Global.ServerName) - } - }() - - stopRelayServerSync := make(chan bool) - eLog := logrus.WithField("pinecone", "events") - m.relayRetriever = relay.NewRelayServerRetriever( - context.Background(), - gomatrixserverlib.ServerName(m.p2pMonolith.Router.PublicKey().String()), - m.federationAPI, - monolith.RelayAPI, - stopRelayServerSync, - ) - m.relayRetriever.InitializeRelayServers(eLog) - - go func(ch <-chan pineconeEvents.Event) { - for event := range ch { - switch e := event.(type) { - case pineconeEvents.PeerAdded: - m.relayRetriever.StartSync() - case pineconeEvents.PeerRemoved: - if m.relayRetriever.IsRunning() && m.p2pMonolith.Router.TotalPeerCount() == 0 { - stopRelayServerSync <- true - } - case pineconeEvents.BroadcastReceived: - // eLog.Info("Broadcast received from: ", e.PeerID) - - req := &api.PerformWakeupServersRequest{ - ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, - } - res := &api.PerformWakeupServersResponse{} - if err := m.federationAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) - } - default: - } - } - }(m.p2pMonolith.EventChannel) + useTCPListener := false + m.p2pMonolith.StartMonolith(useTCPListener) } func (m *DendriteMonolith) Stop() { - _ = m.baseDendrite.Close() - m.baseDendrite.WaitForShutdown() - _ = m.listener.Close() - m.p2pMonolith.Multicast.Stop() - _ = m.p2pMonolith.Sessions.Close() - _ = m.p2pMonolith.Router.Close() + m.p2pMonolith.Stop() } diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index ffb7ace19..0b76188b8 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -15,48 +15,24 @@ package main import ( - "context" "crypto/ed25519" - "crypto/tls" "encoding/hex" "flag" "fmt" "net" - "net/http" "os" "path/filepath" "strings" - "time" - "github.com/gorilla/mux" - "github.com/gorilla/websocket" - "github.com/matrix-org/dendrite/appservice" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed" "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/monolith" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" - "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" - "github.com/matrix-org/dendrite/federationapi" - "github.com/matrix-org/dendrite/federationapi/api" - "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/httputil" - "github.com/matrix-org/dendrite/keyserver" - "github.com/matrix-org/dendrite/relayapi" - "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/jetstream" - "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" pineconeRouter "github.com/matrix-org/pinecone/router" - pineconeEvents "github.com/matrix-org/pinecone/router/events" - - "github.com/sirupsen/logrus" ) var ( @@ -95,16 +71,12 @@ func main() { keyfile := filepath.Join(*instanceDir, *instanceName) + ".pem" oldKeyfile := *instanceName + ".key" sk, pk = monolith.GetOrCreateKey(keyfile, oldKeyfile) - cfg = monolith.GenerateDefaultConfig(sk, *instanceDir, *instanceName) + cfg = monolith.GenerateDefaultConfig(sk, *instanceDir, *instanceDir, *instanceName) } cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) - base := base.NewBaseDendrite(cfg, "Monolith") - base.ConfigureAdminEndpoints() - defer base.Close() // nolint: errcheck - p2pMonolith := monolith.P2PMonolith{} p2pMonolith.SetupPinecone(sk) p2pMonolith.Multicast.Start() @@ -115,6 +87,14 @@ func main() { } } + enableMetrics := true + enableWebsockets := true + p2pMonolith.SetupDendrite(cfg, *instancePort, *instanceRelayingEnabled, enableMetrics, enableWebsockets) + + useTCPListener := false + p2pMonolith.StartMonolith(useTCPListener) + p2pMonolith.WaitForShutdown() + go func() { listener, err := net.Listen("tcp", *instanceListen) if err != nil { @@ -142,160 +122,4 @@ func main() { fmt.Println("Inbound connection", conn.RemoteAddr(), "is connected to port", port) } }() - - // TODO : factor this dendrite setup out to a common place - federation := conn.CreateFederationClient(base, p2pMonolith.Sessions) - - serverKeyAPI := &signing.YggdrasilKeys{} - keyRing := serverKeyAPI.KeyRing() - - rsComponent := roomserver.NewInternalAPI(base) - rsAPI := rsComponent - fsAPI := federationapi.NewInternalAPI( - base, federation, rsAPI, base.Caches, keyRing, true, - ) - - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsComponent) - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) - keyAPI.SetUserAPI(userAPI) - - asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - - rsComponent.SetFederationAPI(fsAPI, keyRing) - - userProvider := users.NewPineconeUserProvider(p2pMonolith.Router, p2pMonolith.Sessions, userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(p2pMonolith.Router, p2pMonolith.Sessions, fsAPI, federation) - - js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) - producer := &producers.SyncAPIProducer{ - JetStream: js, - TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), - TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), - TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), - TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), - TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - Config: &base.Cfg.FederationAPI, - UserAPI: userAPI, - } - relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer, *instanceRelayingEnabled) - logrus.Infof("Relaying enabled: %v", relayAPI.RelayingEnabled()) - - m := setup.Monolith{ - Config: base.Cfg, - Client: conn.CreateClient(base, p2pMonolith.Sessions), - FedClient: federation, - KeyRing: keyRing, - - AppserviceAPI: asAPI, - FederationAPI: fsAPI, - RoomserverAPI: rsAPI, - UserAPI: userAPI, - KeyAPI: keyAPI, - RelayAPI: relayAPI, - ExtPublicRoomsProvider: roomProvider, - ExtUserDirectoryProvider: userProvider, - } - m.AddAllPublicRoutes(base) - - wsUpgrader := websocket.Upgrader{ - CheckOrigin: func(_ *http.Request) bool { - return true - }, - } - httpRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() - httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) - httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) - httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) - httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) - httpRouter.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { - c, err := wsUpgrader.Upgrade(w, r, nil) - if err != nil { - logrus.WithError(err).Error("Failed to upgrade WebSocket connection") - return - } - conn := conn.WrapWebSocketConn(c) - if _, err = p2pMonolith.Router.Connect( - conn, - pineconeRouter.ConnectionZone("websocket"), - pineconeRouter.ConnectionPeerType(pineconeRouter.PeerTypeRemote), - ); err != nil { - logrus.WithError(err).Error("Failed to connect WebSocket peer to Pinecone switch") - } - }) - httpRouter.HandleFunc("/pinecone", p2pMonolith.Router.ManholeHandler) - embed.Embed(httpRouter, *instancePort, "Pinecone Demo") - - pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() - pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) - pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) - pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - - pHTTP := p2pMonolith.Sessions.Protocol("matrix").HTTP() - pHTTP.Mux().Handle(users.PublicURL, pMux) - pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, pMux) - pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, pMux) - - // Build both ends of a HTTP multiplex. - httpServer := &http.Server{ - Addr: ":0", - TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - IdleTimeout: 60 * time.Second, - BaseContext: func(_ net.Listener) context.Context { - return context.Background() - }, - Handler: pMux, - } - - // TODO : factor these funcs out to a common place - go func() { - pubkey := p2pMonolith.Router.PublicKey() - logrus.Info("Listening on ", hex.EncodeToString(pubkey[:])) - logrus.Fatal(httpServer.Serve(p2pMonolith.Sessions.Protocol("matrix"))) - }() - go func() { - httpBindAddr := fmt.Sprintf(":%d", *instancePort) - logrus.Info("Listening on ", httpBindAddr) - logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) - }() - - stopRelayServerSync := make(chan bool) - eLog := logrus.WithField("pinecone", "events") - relayRetriever := relay.NewRelayServerRetriever( - context.Background(), - gomatrixserverlib.ServerName(p2pMonolith.Router.PublicKey().String()), - m.FederationAPI, - m.RelayAPI, - stopRelayServerSync, - ) - relayRetriever.InitializeRelayServers(eLog) - - go func(ch <-chan pineconeEvents.Event) { - for event := range ch { - switch e := event.(type) { - case pineconeEvents.PeerAdded: - relayRetriever.StartSync() - case pineconeEvents.PeerRemoved: - if relayRetriever.IsRunning() && p2pMonolith.Router.TotalPeerCount() == 0 { - stopRelayServerSync <- true - } - case pineconeEvents.BroadcastReceived: - // eLog.Info("Broadcast received from: ", e.PeerID) - - req := &api.PerformWakeupServersRequest{ - ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, - } - res := &api.PerformWakeupServersResponse{} - if err := m.FederationAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) - } - default: - } - } - }(p2pMonolith.EventChannel) - - base.WaitForShutdown() } diff --git a/cmd/dendrite-demo-pinecone/monolith/monolith.go b/cmd/dendrite-demo-pinecone/monolith/monolith.go index d3f0562bc..61e71e877 100644 --- a/cmd/dendrite-demo-pinecone/monolith/monolith.go +++ b/cmd/dendrite-demo-pinecone/monolith/monolith.go @@ -15,12 +15,43 @@ package monolith import ( + "context" "crypto/ed25519" + "crypto/tls" + "encoding/hex" "fmt" + "net" + "net/http" "path/filepath" + "time" + "github.com/gorilla/mux" + "github.com/gorilla/websocket" + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/conn" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/embed" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/relay" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/rooms" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + "github.com/matrix-org/dendrite/federationapi" + federationAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/relayapi" + relayAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/userapi" + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" @@ -29,22 +60,33 @@ import ( pineconeSessions "github.com/matrix-org/pinecone/sessions" ) +const SessionProtocol = "matrix" + type P2PMonolith struct { - Sessions *pineconeSessions.Sessions - Multicast *pineconeMulticast.Multicast - ConnManager *pineconeConnections.ConnectionManager - Router *pineconeRouter.Router - EventChannel chan pineconeEvents.Event + BaseDendrite *base.BaseDendrite + Sessions *pineconeSessions.Sessions + Multicast *pineconeMulticast.Multicast + ConnManager *pineconeConnections.ConnectionManager + Router *pineconeRouter.Router + EventChannel chan pineconeEvents.Event + RelayRetriever relay.RelayServerRetriever + + dendrite setup.Monolith + port int + httpMux *mux.Router + pineconeMux *mux.Router + listener net.Listener + httpListenAddr string } -func GenerateDefaultConfig(sk ed25519.PrivateKey, storageDir string, dbPrefix string) *config.Dendrite { +func GenerateDefaultConfig(sk ed25519.PrivateKey, storageDir string, cacheDir string, dbPrefix string) *config.Dendrite { cfg := config.Dendrite{} cfg.Defaults(config.DefaultOpts{ Generate: true, Monolithic: true, }) cfg.Global.PrivateKey = sk - cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", filepath.Join(storageDir, dbPrefix))) + cfg.Global.JetStream.StoragePath = config.Path(fmt.Sprintf("%s/", filepath.Join(cacheDir, dbPrefix))) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-account.db", filepath.Join(storageDir, dbPrefix))) cfg.MediaAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mediaapi.db", filepath.Join(storageDir, dbPrefix))) cfg.SyncAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-syncapi.db", filepath.Join(storageDir, dbPrefix))) @@ -56,10 +98,10 @@ func GenerateDefaultConfig(sk ed25519.PrivateKey, storageDir string, dbPrefix st cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(storageDir, dbPrefix))) cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true - cfg.MediaAPI.BasePath = config.Path(filepath.Join(storageDir, "media")) - cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(storageDir, "media")) + cfg.MediaAPI.BasePath = config.Path(filepath.Join(cacheDir, "media")) + cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(cacheDir, "media")) cfg.SyncAPI.Fulltext.Enabled = true - cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(storageDir, "search")) + cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(cacheDir, "search")) if err := cfg.Derive(); err != nil { panic(err) } @@ -74,7 +116,264 @@ func (p *P2PMonolith) SetupPinecone(sk ed25519.PrivateKey) { p.Router.EnableWakeupBroadcasts() p.Router.Subscribe(p.EventChannel) - p.Sessions = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), p.Router, []string{"matrix"}) + p.Sessions = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), p.Router, []string{SessionProtocol}) p.Multicast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), p.Router) p.ConnManager = pineconeConnections.NewConnectionManager(p.Router, nil) } + +func (p *P2PMonolith) SetupDendrite(cfg *config.Dendrite, port int, enableRelaying bool, enableMetrics bool, enableWebsockets bool) { + if enableMetrics { + p.BaseDendrite = base.NewBaseDendrite(cfg, "Monolith") + } else { + p.BaseDendrite = base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) + } + p.port = port + p.BaseDendrite.ConfigureAdminEndpoints() + + federation := conn.CreateFederationClient(p.BaseDendrite, p.Sessions) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + + rsComponent := roomserver.NewInternalAPI(p.BaseDendrite) + rsAPI := rsComponent + fsAPI := federationapi.NewInternalAPI( + p.BaseDendrite, federation, rsAPI, p.BaseDendrite.Caches, keyRing, true, + ) + + keyAPI := keyserver.NewInternalAPI(p.BaseDendrite, &p.BaseDendrite.Cfg.KeyServer, fsAPI, rsComponent) + userAPI := userapi.NewInternalAPI(p.BaseDendrite, &cfg.UserAPI, nil, keyAPI, rsAPI, p.BaseDendrite.PushGatewayHTTPClient()) + keyAPI.SetUserAPI(userAPI) + + asAPI := appservice.NewInternalAPI(p.BaseDendrite, userAPI, rsAPI) + + rsComponent.SetFederationAPI(fsAPI, keyRing) + + userProvider := users.NewPineconeUserProvider(p.Router, p.Sessions, userAPI, federation) + roomProvider := rooms.NewPineconeRoomProvider(p.Router, p.Sessions, fsAPI, federation) + + js, _ := p.BaseDendrite.NATS.Prepare(p.BaseDendrite.ProcessContext, &p.BaseDendrite.Cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: p.BaseDendrite.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &p.BaseDendrite.Cfg.FederationAPI, + UserAPI: userAPI, + } + relayAPI := relayapi.NewRelayInternalAPI(p.BaseDendrite, federation, rsAPI, keyRing, producer, enableRelaying) + logrus.Infof("Relaying enabled: %v", relayAPI.RelayingEnabled()) + + p.dendrite = setup.Monolith{ + Config: p.BaseDendrite.Cfg, + Client: conn.CreateClient(p.BaseDendrite, p.Sessions), + FedClient: federation, + KeyRing: keyRing, + + AppserviceAPI: asAPI, + FederationAPI: fsAPI, + RoomserverAPI: rsAPI, + UserAPI: userAPI, + KeyAPI: keyAPI, + RelayAPI: relayAPI, + ExtPublicRoomsProvider: roomProvider, + ExtUserDirectoryProvider: userProvider, + } + p.dendrite.AddAllPublicRoutes(p.BaseDendrite) + + p.setupHttpServers(userProvider, enableWebsockets) +} + +func (p *P2PMonolith) GetFederationAPI() federationAPI.FederationInternalAPI { + return p.dendrite.FederationAPI +} + +func (p *P2PMonolith) GetRelayAPI() relayAPI.RelayInternalAPI { + return p.dendrite.RelayAPI +} + +func (p *P2PMonolith) GetUserAPI() userAPI.UserInternalAPI { + return p.dendrite.UserAPI +} + +func (p *P2PMonolith) StartMonolith(useTCPListener bool) { + p.startHTTPServers(useTCPListener) + p.startEventHandler() +} + +func (p *P2PMonolith) Stop() { + _ = p.BaseDendrite.Close() + p.WaitForShutdown() +} + +func (p *P2PMonolith) WaitForShutdown() { + p.BaseDendrite.WaitForShutdown() + p.closeAllResources() +} + +func (p *P2PMonolith) closeAllResources() { + if p.listener != nil { + _ = p.listener.Close() + } + + if p.Multicast != nil { + p.Multicast.Stop() + } + + if p.Sessions != nil { + _ = p.Sessions.Close() + } + + if p.Router != nil { + _ = p.Router.Close() + } +} + +func (p *P2PMonolith) Addr() string { + return p.httpListenAddr +} + +func (p *P2PMonolith) setupHttpServers(userProvider *users.PineconeUserProvider, enableWebsockets bool) { + p.httpMux = mux.NewRouter().SkipClean(true).UseEncodedPath() + p.httpMux.PathPrefix(httputil.InternalPathPrefix).Handler(p.BaseDendrite.InternalAPIMux) + p.httpMux.PathPrefix(httputil.PublicClientPathPrefix).Handler(p.BaseDendrite.PublicClientAPIMux) + p.httpMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(p.BaseDendrite.PublicMediaAPIMux) + p.httpMux.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(p.BaseDendrite.DendriteAdminMux) + p.httpMux.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(p.BaseDendrite.SynapseAdminMux) + + if enableWebsockets { + wsUpgrader := websocket.Upgrader{ + CheckOrigin: func(_ *http.Request) bool { + return true + }, + } + p.httpMux.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { + c, err := wsUpgrader.Upgrade(w, r, nil) + if err != nil { + logrus.WithError(err).Error("Failed to upgrade WebSocket connection") + return + } + conn := conn.WrapWebSocketConn(c) + if _, err = p.Router.Connect( + conn, + pineconeRouter.ConnectionZone("websocket"), + pineconeRouter.ConnectionPeerType(pineconeRouter.PeerTypeRemote), + ); err != nil { + logrus.WithError(err).Error("Failed to connect WebSocket peer to Pinecone switch") + } + }) + } + + p.httpMux.HandleFunc("/pinecone", p.Router.ManholeHandler) + + if enableWebsockets { + embed.Embed(p.httpMux, p.port, "Pinecone Demo") + } + + p.pineconeMux = mux.NewRouter().SkipClean(true).UseEncodedPath() + p.pineconeMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) + p.pineconeMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(p.BaseDendrite.PublicFederationAPIMux) + p.pineconeMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(p.BaseDendrite.PublicMediaAPIMux) + + pHTTP := p.Sessions.Protocol(SessionProtocol).HTTP() + pHTTP.Mux().Handle(users.PublicURL, p.pineconeMux) + pHTTP.Mux().Handle(httputil.PublicFederationPathPrefix, p.pineconeMux) + pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, p.pineconeMux) +} + +func (p *P2PMonolith) startHTTPServers(useTCPListener bool) { + var handler http.Handler + var httpServeFunc func() error + if useTCPListener { + var err error + p.httpListenAddr = "localhost:" + fmt.Sprint(p.port) + p.listener, err = net.Listen("tcp", p.httpListenAddr) + if err != nil { + panic(err) + } + + h2s := &http2.Server{} + handler = h2c.NewHandler(p.pineconeMux, h2s) + httpServeFunc = func() error { return http.Serve(p.listener, p.httpMux) } + } else { + handler = p.pineconeMux + p.httpListenAddr = fmt.Sprintf(":%d", p.port) + httpServeFunc = func() error { return http.ListenAndServe(p.httpListenAddr, p.httpMux) } + } + + go func() { + // Build both ends of a HTTP multiplex. + httpServer := &http.Server{ + Addr: ":0", + TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){}, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 30 * time.Second, + BaseContext: func(_ net.Listener) context.Context { + return context.Background() + }, + Handler: handler, + } + + pubkey := p.Router.PublicKey() + pubkeyString := hex.EncodeToString(pubkey[:]) + logrus.Info("Listening on ", pubkeyString) + + switch httpServer.Serve(p.Sessions.Protocol(SessionProtocol)) { + case net.ErrClosed, http.ErrServerClosed: + logrus.Info("Stopped listening on ", pubkeyString) + default: + logrus.Error("Stopped listening on ", pubkeyString) + } + }() + + go func() { + logrus.Info("Listening on ", p.httpListenAddr) + switch httpServeFunc() { + case net.ErrClosed, http.ErrServerClosed: + logrus.Info("Stopped listening on ", p.httpListenAddr) + default: + logrus.Error("Stopped listening on ", p.httpListenAddr) + } + }() +} + +func (p *P2PMonolith) startEventHandler() { + stopRelayServerSync := make(chan bool) + eLog := logrus.WithField("pinecone", "events") + p.RelayRetriever = relay.NewRelayServerRetriever( + context.Background(), + gomatrixserverlib.ServerName(p.Router.PublicKey().String()), + p.dendrite.FederationAPI, + p.dendrite.RelayAPI, + stopRelayServerSync, + ) + p.RelayRetriever.InitializeRelayServers(eLog) + + go func(ch <-chan pineconeEvents.Event) { + for event := range ch { + switch e := event.(type) { + case pineconeEvents.PeerAdded: + p.RelayRetriever.StartSync() + case pineconeEvents.PeerRemoved: + if p.RelayRetriever.IsRunning() && p.Router.TotalPeerCount() == 0 { + stopRelayServerSync <- true + } + case pineconeEvents.BroadcastReceived: + // eLog.Info("Broadcast received from: ", e.PeerID) + + req := &federationAPI.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + } + res := &federationAPI.PerformWakeupServersResponse{} + if err := p.dendrite.FederationAPI.PerformWakeupServers(p.BaseDendrite.Context(), req, res); err != nil { + eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) + } + default: + } + } + }(p.EventChannel) +} From a666c06da1274ed98bfbcc4862fc58deb8448ec9 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 1 Feb 2023 14:11:48 -0700 Subject: [PATCH 118/124] Consolidate pinecone demo http server variations --- build/gobind-pinecone/monolith.go | 8 ++--- cmd/dendrite-demo-pinecone/main.go | 4 +-- .../monolith/monolith.go | 32 ++++--------------- 3 files changed, 10 insertions(+), 34 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 597c91353..e8b22c067 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -342,12 +342,10 @@ func (m *DendriteMonolith) Start() { cfg.Global.JetStream.InMemory = false enableRelaying := false - enableMetrics := true - enableWebsockets := true + enableMetrics := false + enableWebsockets := false m.p2pMonolith.SetupDendrite(cfg, 65432, enableRelaying, enableMetrics, enableWebsockets) - - useTCPListener := false - m.p2pMonolith.StartMonolith(useTCPListener) + m.p2pMonolith.StartMonolith() } func (m *DendriteMonolith) Stop() { diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 0b76188b8..7706a73f5 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -90,9 +90,7 @@ func main() { enableMetrics := true enableWebsockets := true p2pMonolith.SetupDendrite(cfg, *instancePort, *instanceRelayingEnabled, enableMetrics, enableWebsockets) - - useTCPListener := false - p2pMonolith.StartMonolith(useTCPListener) + p2pMonolith.StartMonolith() p2pMonolith.WaitForShutdown() go func() { diff --git a/cmd/dendrite-demo-pinecone/monolith/monolith.go b/cmd/dendrite-demo-pinecone/monolith/monolith.go index 61e71e877..6f1c69a78 100644 --- a/cmd/dendrite-demo-pinecone/monolith/monolith.go +++ b/cmd/dendrite-demo-pinecone/monolith/monolith.go @@ -50,8 +50,6 @@ import ( userAPI "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" - "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" @@ -199,8 +197,8 @@ func (p *P2PMonolith) GetUserAPI() userAPI.UserInternalAPI { return p.dendrite.UserAPI } -func (p *P2PMonolith) StartMonolith(useTCPListener bool) { - p.startHTTPServers(useTCPListener) +func (p *P2PMonolith) StartMonolith() { + p.startHTTPServers() p.startEventHandler() } @@ -284,26 +282,7 @@ func (p *P2PMonolith) setupHttpServers(userProvider *users.PineconeUserProvider, pHTTP.Mux().Handle(httputil.PublicMediaPathPrefix, p.pineconeMux) } -func (p *P2PMonolith) startHTTPServers(useTCPListener bool) { - var handler http.Handler - var httpServeFunc func() error - if useTCPListener { - var err error - p.httpListenAddr = "localhost:" + fmt.Sprint(p.port) - p.listener, err = net.Listen("tcp", p.httpListenAddr) - if err != nil { - panic(err) - } - - h2s := &http2.Server{} - handler = h2c.NewHandler(p.pineconeMux, h2s) - httpServeFunc = func() error { return http.Serve(p.listener, p.httpMux) } - } else { - handler = p.pineconeMux - p.httpListenAddr = fmt.Sprintf(":%d", p.port) - httpServeFunc = func() error { return http.ListenAndServe(p.httpListenAddr, p.httpMux) } - } - +func (p *P2PMonolith) startHTTPServers() { go func() { // Build both ends of a HTTP multiplex. httpServer := &http.Server{ @@ -315,7 +294,7 @@ func (p *P2PMonolith) startHTTPServers(useTCPListener bool) { BaseContext: func(_ net.Listener) context.Context { return context.Background() }, - Handler: handler, + Handler: p.pineconeMux, } pubkey := p.Router.PublicKey() @@ -330,9 +309,10 @@ func (p *P2PMonolith) startHTTPServers(useTCPListener bool) { } }() + p.httpListenAddr = fmt.Sprintf(":%d", p.port) go func() { logrus.Info("Listening on ", p.httpListenAddr) - switch httpServeFunc() { + switch http.ListenAndServe(p.httpListenAddr, p.httpMux) { case net.ErrClosed, http.ErrServerClosed: logrus.Info("Stopped listening on ", p.httpListenAddr) default: From 9c826d064dacc00d4f4f385a09aa2c8d381e7317 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 2 Feb 2023 10:27:38 +0100 Subject: [PATCH 119/124] Bump activesupport from 6.0.5 to 6.0.6.1 in /docs (#2959) Bumps [activesupport](https://github.com/rails/rails) from 6.0.5 to 6.0.6.1.
Release notes

Sourced from activesupport's releases.

v6.0.6.1

Active Support

  • No changes.

Active Model

  • No changes.

Active Record

  • Make sanitize_as_sql_comment more strict

    Though this method was likely never meant to take user input, it was attempting sanitization. That sanitization could be bypassed with carefully crafted input.

    This commit makes the sanitization more robust by replacing any occurrances of "/" or "/" with "/ " or " /". It also performs a first pass to remove one surrounding comment to avoid compatibility issues for users relying on the existing removal.

    This also clarifies in the documentation of annotate that it should not be provided user input.

    [CVE-2023-22794]

Action View

  • No changes.

Action Pack

  • No changes.

Active Job

  • No changes.

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=activesupport&package-manager=bundler&previous-version=6.0.5&new-version=6.0.6.1)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) - `@dependabot use these labels` will set the current labels as the default for future PRs for this repo and language - `@dependabot use these reviewers` will set the current reviewers as the default for future PRs for this repo and language - `@dependabot use these assignees` will set the current assignees as the default for future PRs for this repo and language - `@dependabot use this milestone` will set the current milestone as the default for future PRs for this repo and language You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/matrix-org/dendrite/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/Gemfile.lock | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index 5d79365f8..a61786c1d 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -1,7 +1,7 @@ GEM remote: https://rubygems.org/ specs: - activesupport (6.0.5) + activesupport (6.0.6.1) concurrent-ruby (~> 1.0, >= 1.0.2) i18n (>= 0.7, < 2) minitest (~> 5.1) @@ -15,7 +15,7 @@ GEM coffee-script-source (1.11.1) colorator (1.1.0) commonmarker (0.23.7) - concurrent-ruby (1.1.10) + concurrent-ruby (1.2.0) dnsruby (1.61.9) simpleidn (~> 0.1) em-websocket (0.5.3) @@ -229,7 +229,7 @@ GEM jekyll (>= 3.5, < 5.0) jekyll-feed (~> 0.9) jekyll-seo-tag (~> 2.1) - minitest (5.15.0) + minitest (5.17.0) multipart-post (2.1.1) nokogiri (1.13.10-arm64-darwin) racc (~> 1.4) @@ -265,13 +265,13 @@ GEM thread_safe (0.3.6) typhoeus (1.4.0) ethon (>= 0.9.0) - tzinfo (1.2.10) + tzinfo (1.2.11) thread_safe (~> 0.1) unf (0.1.4) unf_ext unf_ext (0.0.8.1) unicode-display_width (1.8.0) - zeitwerk (2.5.4) + zeitwerk (2.6.6) PLATFORMS arm64-darwin-21 From baf118b08cee29cf7435a8a871a7aab423baf779 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 3 Feb 2023 13:42:35 +0100 Subject: [PATCH 120/124] Add Sytest/Complement coverage to scheduled runs (#2962) This adds Sytest and Complement coverage reporting to the nightly scheduled CI runs. Fixes a few API mode related issues as well, since we seemingly never really ran them with Complement. Also fixes a bug related to device list changes: When we pass in an empty `newlyLeftRooms` slice, we got a list of all currently joined rooms with the corresponding members. When we then got the `newlyJoinedRooms`, we wouldn't update the `changed` slice, because we already got the user from the `newlyLeftRooms` query. This is fixed by simply ignoring empty `newlyLeftRooms`. --- .github/workflows/dendrite.yml | 3 +- .github/workflows/schedules.yaml | 176 ++++++++++++++++++- build/scripts/complement-cmd.sh | 2 +- clientapi/routing/leaveroom.go | 10 ++ clientapi/routing/register.go | 12 ++ clientapi/routing/routing.go | 23 ++- roomserver/internal/perform/perform_leave.go | 3 +- syncapi/internal/keychange.go | 64 +++---- 8 files changed, 253 insertions(+), 40 deletions(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 1de39850d..2dc5f74c0 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -460,7 +460,8 @@ jobs: name: Run Complement Tests env: COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} - API: ${{ matrix.api && 1 }} + COMPLEMENT_DENDRITE_API: ${{ matrix.api && 1 }} + COMPLEMENT_SHARE_ENV_PREFIX: COMPLEMENT_DENDRITE_ working-directory: complement integration-tests-done: diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index 5636c4cf9..fa2304c1f 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -19,11 +19,18 @@ jobs: fail-fast: false matrix: include: - - label: SQLite + - label: SQLite native - - label: SQLite, full HTTP APIs + - label: SQLite Cgo + cgo: 1 + + - label: SQLite native, full HTTP APIs api: full-http + - label: SQLite Cgo, full HTTP APIs + api: full-http + cgo: 1 + - label: PostgreSQL postgres: postgres @@ -41,6 +48,7 @@ jobs: API: ${{ matrix.api && 1 }} SYTEST_BRANCH: ${{ github.head_ref }} RACE_DETECTION: 1 + COVER: 1 steps: - uses: actions/checkout@v3 - uses: actions/cache@v3 @@ -73,7 +81,171 @@ jobs: path: | /logs/results.tap /logs/**/*.log* + + sytest-coverage: + timeout-minutes: 5 + name: "Sytest Coverage" + runs-on: ubuntu-latest + needs: sytest # only run once Sytest is done + if: ${{ always() }} + steps: + - uses: actions/checkout@v3 + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: '>=1.19.0' + cache: true + - name: Download all artifacts + uses: actions/download-artifact@v3 + - name: Install gocovmerge + run: go install github.com/wadey/gocovmerge@latest + - name: Run gocovmerge + run: | + find -name 'integrationcover.log' -printf '"%p"\n' | xargs gocovmerge | grep -Ev 'relayapi|setup/mscs|api_trace' > sytest.cov + go tool cover -func=sytest.cov + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./sytest.cov + flags: sytest + fail_ci_if_error: true + + # run Complement + complement: + name: "Complement (${{ matrix.label }})" + timeout-minutes: 60 + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - label: SQLite native + cgo: 0 + + - label: SQLite Cgo + cgo: 1 + + - label: SQLite native, full HTTP APIs + api: full-http + cgo: 0 + + - label: SQLite Cgo, full HTTP APIs + api: full-http + cgo: 1 + + - label: PostgreSQL + postgres: Postgres + cgo: 0 + + - label: PostgreSQL, full HTTP APIs + postgres: Postgres + api: full-http + cgo: 0 + steps: + # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. + # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path + - name: "Set Go Version" + run: | + echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH + echo "~/go/bin" >> $GITHUB_PATH + - name: "Install Complement Dependencies" + # We don't need to install Go because it is included on the Ubuntu 20.04 image: + # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64 + run: | + sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev + go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest + - name: Run actions/checkout@v3 for dendrite + uses: actions/checkout@v3 + with: + path: dendrite + + # Attempt to check out the same branch of Complement as the PR. If it + # doesn't exist, fallback to main. + - name: Checkout complement + shell: bash + run: | + mkdir -p complement + # Attempt to use the version of complement which best matches the current + # build. Depending on whether this is a PR or release, etc. we need to + # use different fallbacks. + # + # 1. First check if there's a similarly named branch (GITHUB_HEAD_REF + # for pull requests, otherwise GITHUB_REF). + # 2. Attempt to use the base branch, e.g. when merging into release-vX.Y + # (GITHUB_BASE_REF for pull requests). + # 3. Use the default complement branch ("master"). + for BRANCH_NAME in "$GITHUB_HEAD_REF" "$GITHUB_BASE_REF" "${GITHUB_REF#refs/heads/}" "master"; do + # Skip empty branch names and merge commits. + if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then + continue + fi + (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break + done + # Build initial Dendrite image + - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . + working-directory: dendrite + env: + DOCKER_BUILDKIT: 1 + + - name: Create post test script + run: | + cat < /tmp/posttest.sh + #!/bin/bash + mkdir -p /tmp/Complement/logs/\$2/\$1/ + docker cp \$1:/dendrite/complementcover.log /tmp/Complement/logs/\$2/\$1/ + EOF + chmod +x /tmp/posttest.sh + # Run Complement + - run: | + set -o pipefail && + go test -v -json -tags dendrite_blacklist ./tests/... 2>&1 | gotestfmt + shell: bash + name: Run Complement Tests + env: + COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} + COMPLEMENT_DENDRITE_API: ${{ matrix.api && 1 }} + COMPLEMENT_SHARE_ENV_PREFIX: COMPLEMENT_DENDRITE_ + COMPLEMENT_DENDRITE_COVER: 1 + COMPLEMENT_POST_TEST_SCRIPT: /tmp/posttest.sh + working-directory: complement + + - name: Upload Complement logs + uses: actions/upload-artifact@v2 + if: ${{ always() }} + with: + name: Complement Logs - (Dendrite, ${{ join(matrix.*, ', ') }}) + path: | + /tmp/Complement/**/complementcover.log + + complement-coverage: + timeout-minutes: 5 + name: "Complement Coverage" + runs-on: ubuntu-latest + needs: complement # only run once Complement is done + if: ${{ always() }} + steps: + - uses: actions/checkout@v3 + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: '>=1.19.0' + cache: true + - name: Download all artifacts + uses: actions/download-artifact@v3 + - name: Install gocovmerge + run: go install github.com/wadey/gocovmerge@latest + - name: Run gocovmerge + run: | + find -name 'complementcover.log' -printf '"%p"\n' | xargs gocovmerge | grep -Ev 'relayapi|setup/mscs|api_trace' > complement.cov + go tool cover -func=complement.cov + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + files: ./complement.cov + flags: complement + fail_ci_if_error: true + element_web: timeout-minutes: 120 runs-on: ubuntu-latest diff --git a/build/scripts/complement-cmd.sh b/build/scripts/complement-cmd.sh index 061bd18eb..def091e27 100755 --- a/build/scripts/complement-cmd.sh +++ b/build/scripts/complement-cmd.sh @@ -10,7 +10,7 @@ if [[ "${COVER}" -eq 1 ]]; then --tls-key server.key \ --config dendrite.yaml \ -api=${API:-0} \ - --test.coverprofile=integrationcover.log + --test.coverprofile=complementcover.log else echo "Not running with coverage" exec /dendrite/dendrite-monolith-server \ diff --git a/clientapi/routing/leaveroom.go b/clientapi/routing/leaveroom.go index a71661851..86414afca 100644 --- a/clientapi/routing/leaveroom.go +++ b/clientapi/routing/leaveroom.go @@ -18,6 +18,7 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/httputil" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" @@ -44,6 +45,15 @@ func LeaveRoomByID( JSON: jsonerror.LeaveServerNoticeError(), } } + switch e := err.(type) { + case httputil.InternalAPIError: + if e.Message == jsonerror.LeaveServerNoticeError().Error() { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.LeaveServerNoticeError(), + } + } + } return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown(err.Error()), diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index ff6a0900e..be2b192b2 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -31,6 +31,8 @@ import ( "time" "github.com/matrix-org/dendrite/internal" + internalHTTPUtil "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" @@ -859,6 +861,16 @@ func completeRegistration( JSON: jsonerror.UserInUse("Desired user ID is already taken."), } } + switch e := err.(type) { + case internalHTTPUtil.InternalAPIError: + conflictErr := &userapi.ErrorConflict{Message: sqlutil.ErrUserExists.Error()} + if e.Message == conflictErr.Error() { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UserInUse("Desired user ID is already taken."), + } + } + } return util.JSONResponse{ Code: http.StatusInternalServerError, JSON: jsonerror.Unknown("failed to create account: " + err.Error()), diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 93f6ea901..66610c0a0 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -18,6 +18,7 @@ import ( "context" "net/http" "strings" + "sync" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/setup/base" @@ -198,18 +199,24 @@ func Setup( // server notifications if cfg.Matrix.ServerNotices.Enabled { logrus.Info("Enabling server notices at /_synapse/admin/v1/send_server_notice") - serverNotificationSender, err := getSenderDevice(context.Background(), rsAPI, userAPI, cfg) - if err != nil { - logrus.WithError(err).Fatal("unable to get account for sending sending server notices") - } + var serverNotificationSender *userapi.Device + var err error + notificationSenderOnce := &sync.Once{} synapseAdminRouter.Handle("/admin/v1/send_server_notice/{txnID}", httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + notificationSenderOnce.Do(func() { + serverNotificationSender, err = getSenderDevice(context.Background(), rsAPI, userAPI, cfg) + if err != nil { + logrus.WithError(err).Fatal("unable to get account for sending sending server notices") + } + }) // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r } - vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + var vars map[string]string + vars, err = httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } @@ -225,6 +232,12 @@ func Setup( synapseAdminRouter.Handle("/admin/v1/send_server_notice", httputil.MakeAuthAPI("send_server_notice", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + notificationSenderOnce.Do(func() { + serverNotificationSender, err = getSenderDevice(context.Background(), rsAPI, userAPI, cfg) + if err != nil { + logrus.WithError(err).Fatal("unable to get account for sending sending server notices") + } + }) // not specced, but ensure we're rate limiting requests to this endpoint if r := rateLimits.Limit(req, device); r != nil { return *r diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index fa998e3e1..86f1dfaee 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -109,7 +110,7 @@ func (r *Leaver) performLeaveRoomByID( // mimic the returned values from Synapse res.Message = "You cannot reject this invite" res.Code = 403 - return nil, fmt.Errorf("You cannot reject this invite") + return nil, jsonerror.LeaveServerNoticeError() } } } diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 3d6b2a7f3..4867f7d9e 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -144,38 +144,42 @@ func TrackChangedUsers( // - Loop set of users and decrement by 1 for each user in newly left room. // - If count=0 then they share no more rooms so inform BOTH parties of this via 'left'=[...] in /sync. var queryRes roomserverAPI.QuerySharedUsersResponse - err = rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ - UserID: userID, - IncludeRoomIDs: newlyLeftRooms, - }, &queryRes) - if err != nil { - return nil, nil, err - } var stateRes roomserverAPI.QueryBulkStateContentResponse - err = rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ - RoomIDs: newlyLeftRooms, - StateTuples: []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomMember, - StateKey: "*", - }, - }, - AllowWildcards: true, - }, &stateRes) - if err != nil { - return nil, nil, err - } - for _, state := range stateRes.Rooms { - for tuple, membership := range state { - if membership != gomatrixserverlib.Join { - continue - } - queryRes.UserIDsToCount[tuple.StateKey]-- + if len(newlyLeftRooms) > 0 { + err = rsAPI.QuerySharedUsers(ctx, &roomserverAPI.QuerySharedUsersRequest{ + UserID: userID, + IncludeRoomIDs: newlyLeftRooms, + }, &queryRes) + if err != nil { + return nil, nil, err } - } - for userID, count := range queryRes.UserIDsToCount { - if count <= 0 { - left = append(left, userID) // left is returned + + err = rsAPI.QueryBulkStateContent(ctx, &roomserverAPI.QueryBulkStateContentRequest{ + RoomIDs: newlyLeftRooms, + StateTuples: []gomatrixserverlib.StateKeyTuple{ + { + EventType: gomatrixserverlib.MRoomMember, + StateKey: "*", + }, + }, + AllowWildcards: true, + }, &stateRes) + if err != nil { + return nil, nil, err + } + for _, state := range stateRes.Rooms { + for tuple, membership := range state { + if membership != gomatrixserverlib.Join { + continue + } + queryRes.UserIDsToCount[tuple.StateKey]-- + } + } + + for userID, count := range queryRes.UserIDsToCount { + if count <= 0 { + left = append(left, userID) // left is returned + } } } From 26f86a76b6720e6c9eaba8aa1088ac40251e937a Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 3 Feb 2023 09:05:57 -0700 Subject: [PATCH 121/124] Update dendrite-pinecone gobindings build script --- build/gobind-pinecone/build.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) mode change 100644 => 100755 build/gobind-pinecone/build.sh diff --git a/build/gobind-pinecone/build.sh b/build/gobind-pinecone/build.sh old mode 100644 new mode 100755 index 0f1b1aab9..51844506c --- a/build/gobind-pinecone/build.sh +++ b/build/gobind-pinecone/build.sh @@ -7,7 +7,7 @@ do case "$option" in a) gomobile bind -v -target android -trimpath -ldflags="-s -w" github.com/matrix-org/dendrite/build/gobind-pinecone ;; - i) gomobile bind -v -target ios -trimpath -ldflags="" github.com/matrix-org/dendrite/build/gobind-pinecone ;; + i) gomobile bind -v -target ios -trimpath -ldflags="" -o ~/DendriteBindings/Gobind.xcframework . ;; *) echo "No target specified, specify -a or -i"; exit 1 ;; esac done \ No newline at end of file From 4ed61740abb69b35ac7fcc1fc138dfb3d87f53c0 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Sat, 4 Feb 2023 05:55:59 -0700 Subject: [PATCH 122/124] Disable fulltext search in pinecone builds --- build/gobind-pinecone/monolith.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index e8b22c067..c4408ba18 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -340,6 +340,9 @@ func (m *DendriteMonolith) Start() { cfg.Global.ServerName = gomatrixserverlib.ServerName(hex.EncodeToString(pk)) cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) cfg.Global.JetStream.InMemory = false + // NOTE : disabled for now since there is a 64 bit alignment panic on 32 bit systems + // This isn't actually fixed: https://github.com/blevesearch/zapx/pull/147 + cfg.SyncAPI.Fulltext.Enabled = false enableRelaying := false enableMetrics := false From cf254ba0445e2509f77f41dbec69f632b126b847 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Sat, 4 Feb 2023 06:05:39 -0700 Subject: [PATCH 123/124] Add max frame size to pinecone bindings --- build/gobind-pinecone/monolith.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index c4408ba18..16797eec0 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -46,6 +46,8 @@ const ( PeerTypeMulticast = pineconeRouter.PeerTypeMulticast PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth PeerTypeBonjour = pineconeRouter.PeerTypeBonjour + + MaxFrameSize = types.MaxFrameSize ) // Re-export Conduit in this package for bindings. From eb29a315507f0075c2c6a495ac59c64a7f45f9fc Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 7 Feb 2023 14:31:23 +0100 Subject: [PATCH 124/124] Optimize `/sync` and history visibility (#2961) Should fix the following issues or make a lot less worse when using Postgres: The main issue behind #2911: The client gives up after a certain time, causing a cascade of context errors, because the response couldn't be built up fast enough. This mostly happens on accounts with many rooms, due to the inefficient way we're getting recent events and current state For #2777: The queries for getting the membership events for history visibility were being executed for each room (I think 185?), resulting in a whooping 2k queries for membership events. (Getting the statesnapshot -> block nids -> actual wanted membership event) Both should now be better by: - Using a LATERAL join to get all recent events for all joined rooms in one go (TODO: maybe do the same for room summary and current state etc) - If we're lazy loading on initial syncs, we're now not getting the whole current state, just to drop the majority of it because we're lazy loading members - we add a filter to exclude membership events on the first call to `CurrentState`. - Using an optimized query to get the membership events needed to calculate history visibility --------- Co-authored-by: kegsay --- roomserver/api/query.go | 9 +- roomserver/internal/query/query.go | 26 ++- roomserver/storage/interface.go | 7 + roomserver/storage/postgres/events_table.go | 3 + .../storage/postgres/state_snapshot_table.go | 69 ++++++- roomserver/storage/shared/storage.go | 6 + .../storage/sqlite3/state_snapshot_table.go | 5 + roomserver/storage/tables/interface.go | 4 + .../tables/state_snapshot_table_test.go | 2 + syncapi/internal/history_visibility.go | 27 ++- syncapi/routing/context.go | 67 ++++--- syncapi/routing/messages.go | 44 +---- syncapi/storage/interface.go | 2 +- .../postgres/current_room_state_table.go | 9 + .../deltas/20230201152200_rename_index.go | 29 +++ .../postgres/output_room_events_table.go | 162 ++++++++++++---- syncapi/storage/shared/storage_sync.go | 54 ++++-- syncapi/storage/shared/storage_sync_test.go | 72 ++++++++ syncapi/storage/sqlite3/account_data_table.go | 4 +- .../sqlite3/current_room_state_table.go | 9 + .../sqlite3/output_room_events_table.go | 92 +++++----- syncapi/storage/storage_test.go | 73 +++++++- .../storage/tables/current_room_state_test.go | 40 ++++ syncapi/storage/tables/interface.go | 2 +- syncapi/streams/stream_pdu.go | 64 ++++--- syncapi/syncapi_test.go | 173 ++++++++++++++++++ syncapi/types/types.go | 5 + 27 files changed, 838 insertions(+), 221 deletions(-) create mode 100644 syncapi/storage/postgres/deltas/20230201152200_rename_index.go create mode 100644 syncapi/storage/shared/storage_sync_test.go diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 76f8298ca..4ef548e19 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -433,7 +433,7 @@ func (r *QueryCurrentStateResponse) UnmarshalJSON(data []byte) error { return nil } -// QueryMembershipAtEventRequest requests the membership events for a user +// QueryMembershipAtEventRequest requests the membership event for a user // for a list of eventIDs. type QueryMembershipAtEventRequest struct { RoomID string @@ -443,9 +443,10 @@ type QueryMembershipAtEventRequest struct { // QueryMembershipAtEventResponse is the response to QueryMembershipAtEventRequest. type QueryMembershipAtEventResponse struct { - // Memberships is a map from eventID to a list of events (if any). Events that - // do not have known state will return an empty array here. - Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` + // Membership is a map from eventID to membership event. Events that + // do not have known state will return a nil event, resulting in a "leave" membership + // when calculating history visibility. + Membership map[string]*gomatrixserverlib.HeaderedEvent `json:"membership"` } // QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 69d841dda..1083bb237 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" @@ -216,7 +217,8 @@ func (r *Queryer) QueryMembershipAtEvent( request *api.QueryMembershipAtEventRequest, response *api.QueryMembershipAtEventResponse, ) error { - response.Memberships = make(map[string][]*gomatrixserverlib.HeaderedEvent) + response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) + info, err := r.DB.RoomInfo(ctx, request.RoomID) if err != nil { return fmt.Errorf("unable to get roomInfo: %w", err) @@ -234,7 +236,17 @@ func (r *Queryer) QueryMembershipAtEvent( return fmt.Errorf("requested stateKeyNID for %s was not found", request.UserID) } - stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, info, request.EventIDs, stateKeyNIDs[request.UserID]) + response.Membership, err = r.DB.GetMembershipForHistoryVisibility(ctx, stateKeyNIDs[request.UserID], info, request.EventIDs...) + switch err { + case nil: + return nil + case tables.OptimisationNotSupportedError: // fallthrough, slow way of getting the membership events for each event + default: + return err + } + + response.Membership = make(map[string]*gomatrixserverlib.HeaderedEvent) + stateEntries, err := helpers.MembershipAtEvent(ctx, r.DB, nil, request.EventIDs, stateKeyNIDs[request.UserID]) if err != nil { return fmt.Errorf("unable to get state before event: %w", err) } @@ -258,7 +270,7 @@ func (r *Queryer) QueryMembershipAtEvent( for _, eventID := range request.EventIDs { stateEntry, ok := stateEntries[eventID] if !ok || len(stateEntry) == 0 { - response.Memberships[eventID] = []*gomatrixserverlib.HeaderedEvent{} + response.Membership[eventID] = nil continue } @@ -275,15 +287,15 @@ func (r *Queryer) QueryMembershipAtEvent( return fmt.Errorf("unable to get memberships at state: %w", err) } - res := make([]*gomatrixserverlib.HeaderedEvent, 0, len(memberships)) - + // Iterate over all membership events we got. Given we only query the membership for + // one user and assuming this user only ever has one membership event associated to + // a given event, overwrite any other existing membership events. for i := range memberships { ev := memberships[i] if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(request.UserID) { - res = append(res, ev.Headered(info.RoomVersion)) + response.Membership[eventID] = ev.Event.Headered(info.RoomVersion) } } - response.Memberships[eventID] = res } return nil diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index e0b9c56b3..998915122 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -175,4 +175,11 @@ type Database interface { GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) PurgeRoom(ctx context.Context, roomID string) error UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error + + // GetMembershipForHistoryVisibility queries the membership events for the given eventIDs. + // Returns a map from (input) eventID -> membership event. If no membership event is found, returns an empty event, resulting in + // a membership of "leave" when calculating history visibility. + GetMembershipForHistoryVisibility( + ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string, + ) (map[string]*gomatrixserverlib.HeaderedEvent, error) } diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 9b5ed6eda..f4a21c8a7 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -69,6 +69,9 @@ CREATE TABLE IF NOT EXISTS roomserver_events ( auth_event_nids BIGINT[] NOT NULL, is_rejected BOOLEAN NOT NULL DEFAULT FALSE ); + +-- Create an index which helps in resolving membership events (event_type_nid = 5) - (used for history visibility) +CREATE INDEX IF NOT EXISTS roomserver_events_memberships_idx ON roomserver_events (room_nid, event_state_key_nid) WHERE (event_type_nid = 5); ` const insertEventSQL = "" + diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index a00c026f4..0e83cfc25 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -21,10 +21,10 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -99,10 +99,26 @@ const bulkSelectStateForHistoryVisibilitySQL = ` AND (event_type_nid = 7 OR event_state_key LIKE '%:' || $2); ` +// bulkSelectMembershipForHistoryVisibilitySQL is an optimization to get membership events for a specific user for defined set of events. +// Returns the event_id of the event we want the membership event for, the event_id of the membership event and the membership event JSON. +const bulkSelectMembershipForHistoryVisibilitySQL = ` +SELECT re.event_id, re2.event_id, rej.event_json +FROM roomserver_events re +LEFT JOIN roomserver_state_snapshots rss on re.state_snapshot_nid = rss.state_snapshot_nid +CROSS JOIN unnest(rss.state_block_nids) AS blocks(block_nid) +LEFT JOIN roomserver_state_block rsb ON rsb.state_block_nid = blocks.block_nid +CROSS JOIN unnest(rsb.event_nids) AS rsb2(event_nid) +JOIN roomserver_events re2 ON re2.room_nid = $3 AND re2.event_type_nid = 5 AND re2.event_nid = rsb2.event_nid AND re2.event_state_key_nid = $1 +LEFT JOIN roomserver_event_json rej ON rej.event_nid = re2.event_nid +WHERE re.event_id = ANY($2) + +` + type stateSnapshotStatements struct { - insertStateStmt *sql.Stmt - bulkSelectStateBlockNIDsStmt *sql.Stmt - bulkSelectStateForHistoryVisibilityStmt *sql.Stmt + insertStateStmt *sql.Stmt + bulkSelectStateBlockNIDsStmt *sql.Stmt + bulkSelectStateForHistoryVisibilityStmt *sql.Stmt + bulktSelectMembershipForHistoryVisibilityStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -110,13 +126,14 @@ func CreateStateSnapshotTable(db *sql.DB) error { return err } -func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) { s := &stateSnapshotStatements{} return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, {&s.bulkSelectStateForHistoryVisibilityStmt, bulkSelectStateForHistoryVisibilitySQL}, + {&s.bulktSelectMembershipForHistoryVisibilityStmt, bulkSelectMembershipForHistoryVisibilitySQL}, }.Prepare(db) } @@ -185,3 +202,45 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( } return results, rows.Err() } + +func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility( + ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, +) (map[string]*gomatrixserverlib.HeaderedEvent, error) { + stmt := sqlutil.TxStmt(txn, s.bulktSelectMembershipForHistoryVisibilityStmt) + rows, err := stmt.QueryContext(ctx, userNID, pq.Array(eventIDs), roomInfo.RoomNID) + if err != nil { + return nil, err + } + defer rows.Close() // nolint: errcheck + result := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) + var evJson []byte + var eventID string + var membershipEventID string + + knownEvents := make(map[string]*gomatrixserverlib.HeaderedEvent, len(eventIDs)) + + for rows.Next() { + if err = rows.Scan(&eventID, &membershipEventID, &evJson); err != nil { + return nil, err + } + if len(evJson) == 0 { + result[eventID] = &gomatrixserverlib.HeaderedEvent{} + continue + } + // If we already know this event, don't try to marshal the json again + if ev, ok := knownEvents[membershipEventID]; ok { + result[eventID] = ev + continue + } + event, err := gomatrixserverlib.NewEventFromTrustedJSON(evJson, false, roomInfo.RoomVersion) + if err != nil { + result[eventID] = &gomatrixserverlib.HeaderedEvent{} + // not fatal + continue + } + he := event.Headered(roomInfo.RoomVersion) + result[eventID] = he + knownEvents[membershipEventID] = he + } + return result, rows.Err() +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 654b078d2..f8672496b 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool { return true } +func (d *Database) GetMembershipForHistoryVisibility( + ctx context.Context, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, +) (map[string]*gomatrixserverlib.HeaderedEvent, error) { + return d.StateSnapshotTable.BulkSelectMembershipForHistoryVisibility(ctx, nil, userNID, roomInfo, eventIDs...) +} + func (d *Database) EventTypeNIDs( ctx context.Context, eventTypes []string, ) (map[string]types.EventTypeNID, error) { diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 930ad14dd..e57e1a4bf 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -152,6 +153,10 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( return nil, tables.OptimisationNotSupportedError } +func (s *stateSnapshotStatements) BulkSelectMembershipForHistoryVisibility(ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string) (map[string]*gomatrixserverlib.HeaderedEvent, error) { + return nil, tables.OptimisationNotSupportedError +} + func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.StateBlockNID, error) { diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 64145f83d..c7f1064db 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -91,6 +91,10 @@ type StateSnapshot interface { // which users are in a room faster than having to load the entire room state. In the // case of SQLite, this will return tables.OptimisationNotSupportedError. BulkSelectStateForHistoryVisibility(ctx context.Context, txn *sql.Tx, stateSnapshotNID types.StateSnapshotNID, domain string) ([]types.EventNID, error) + + BulkSelectMembershipForHistoryVisibility( + ctx context.Context, txn *sql.Tx, userNID types.EventStateKeyNID, roomInfo *types.RoomInfo, eventIDs ...string, + ) (map[string]*gomatrixserverlib.HeaderedEvent, error) } type StateBlock interface { diff --git a/roomserver/storage/tables/state_snapshot_table_test.go b/roomserver/storage/tables/state_snapshot_table_test.go index b2e59377d..c7c991b20 100644 --- a/roomserver/storage/tables/state_snapshot_table_test.go +++ b/roomserver/storage/tables/state_snapshot_table_test.go @@ -29,6 +29,8 @@ func mustCreateStateSnapshotTable(t *testing.T, dbType test.DBType) (tab tables. assert.NoError(t, err) err = postgres.CreateEventsTable(db) assert.NoError(t, err) + err = postgres.CreateEventJSONTable(db) + assert.NoError(t, err) err = postgres.CreateStateBlockTable(db) assert.NoError(t, err) // ... and then the snapshot table itself diff --git a/syncapi/internal/history_visibility.go b/syncapi/internal/history_visibility.go index 71d7ddd15..87c59e03e 100644 --- a/syncapi/internal/history_visibility.go +++ b/syncapi/internal/history_visibility.go @@ -121,10 +121,7 @@ func ApplyHistoryVisibilityFilter( // Get the mapping from eventID -> eventVisibility eventsFiltered := make([]*gomatrixserverlib.HeaderedEvent, 0, len(events)) - visibilities, err := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) - if err != nil { - return eventsFiltered, err - } + visibilities := visibilityForEvents(ctx, rsAPI, events, userID, events[0].RoomID()) for _, ev := range events { evVis := visibilities[ev.EventID()] evVis.membershipCurrent = membershipCurrent @@ -175,7 +172,7 @@ func visibilityForEvents( rsAPI api.SyncRoomserverAPI, events []*gomatrixserverlib.HeaderedEvent, userID, roomID string, -) (map[string]eventVisibility, error) { +) map[string]eventVisibility { eventIDs := make([]string, len(events)) for i := range events { eventIDs[i] = events[i].EventID() @@ -185,6 +182,7 @@ func visibilityForEvents( // get the membership events for all eventIDs membershipResp := &api.QueryMembershipAtEventResponse{} + err := rsAPI.QueryMembershipAtEvent(ctx, &api.QueryMembershipAtEventRequest{ RoomID: roomID, EventIDs: eventIDs, @@ -201,19 +199,20 @@ func visibilityForEvents( membershipAtEvent: gomatrixserverlib.Leave, // default to leave, to not expose events by accident visibility: event.Visibility, } - membershipEvs, ok := membershipResp.Memberships[eventID] - if !ok { + ev, ok := membershipResp.Membership[eventID] + if !ok || ev == nil { result[eventID] = vis continue } - for _, ev := range membershipEvs { - membership, err := ev.Membership() - if err != nil { - return result, err - } - vis.membershipAtEvent = membership + + membership, err := ev.Membership() + if err != nil { + result[eventID] = vis + continue } + vis.membershipAtEvent = membership + result[eventID] = vis } - return result, nil + return result } diff --git a/syncapi/routing/context.go b/syncapi/routing/context.go index 095a868c7..76f003671 100644 --- a/syncapi/routing/context.go +++ b/syncapi/routing/context.go @@ -67,6 +67,8 @@ func Context( errMsg = "unable to parse filter" case *strconv.NumError: errMsg = "unable to parse limit" + default: + errMsg = err.Error() } return util.JSONResponse{ Code: http.StatusBadRequest, @@ -167,7 +169,18 @@ func Context( eventsBeforeClient := gomatrixserverlib.HeaderedToClientEvents(eventsBeforeFiltered, gomatrixserverlib.FormatAll) eventsAfterClient := gomatrixserverlib.HeaderedToClientEvents(eventsAfterFiltered, gomatrixserverlib.FormatAll) - newState := applyLazyLoadMembers(device, filter, eventsAfterClient, eventsBeforeClient, state, lazyLoadCache) + + newState := state + if filter.LazyLoadMembers { + allEvents := append(eventsBeforeFiltered, eventsAfterFiltered...) + allEvents = append(allEvents, &requestedEvent) + evs := gomatrixserverlib.HeaderedToClientEvents(allEvents, gomatrixserverlib.FormatAll) + newState, err = applyLazyLoadMembers(ctx, device, snapshot, roomID, evs, lazyLoadCache) + if err != nil { + logrus.WithError(err).Error("unable to load membership events") + return jsonerror.InternalServerError() + } + } ev := gomatrixserverlib.HeaderedToClientEvent(&requestedEvent, gomatrixserverlib.FormatAll) response := ContextRespsonse{ @@ -244,41 +257,43 @@ func getStartEnd(ctx context.Context, snapshot storage.DatabaseTransaction, star } func applyLazyLoadMembers( + ctx context.Context, device *userapi.Device, - filter *gomatrixserverlib.RoomEventFilter, - eventsAfter, eventsBefore []gomatrixserverlib.ClientEvent, - state []*gomatrixserverlib.HeaderedEvent, + snapshot storage.DatabaseTransaction, + roomID string, + events []gomatrixserverlib.ClientEvent, lazyLoadCache caching.LazyLoadCache, -) []*gomatrixserverlib.HeaderedEvent { - if filter == nil || !filter.LazyLoadMembers { - return state - } - allEvents := append(eventsBefore, eventsAfter...) - x := make(map[string]struct{}) +) ([]*gomatrixserverlib.HeaderedEvent, error) { + eventSenders := make(map[string]struct{}) // get members who actually send an event - for _, e := range allEvents { + for _, e := range events { // Don't add membership events the client should already know about if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, e.RoomID, e.Sender); cached { continue } - x[e.Sender] = struct{}{} + eventSenders[e.Sender] = struct{}{} } - newState := []*gomatrixserverlib.HeaderedEvent{} - membershipEvents := []*gomatrixserverlib.HeaderedEvent{} - for _, event := range state { - if event.Type() != gomatrixserverlib.MRoomMember { - newState = append(newState, event) - } else { - // did the user send an event? - if _, ok := x[event.Sender()]; ok { - membershipEvents = append(membershipEvents, event) - lazyLoadCache.StoreLazyLoadedUser(device, event.RoomID(), event.Sender(), event.EventID()) - } - } + wantUsers := make([]string, 0, len(eventSenders)) + for userID := range eventSenders { + wantUsers = append(wantUsers, userID) } - // Add the membershipEvents to the end of the list, to make Sytest happy - return append(newState, membershipEvents...) + + // Query missing membership events + filter := gomatrixserverlib.DefaultStateFilter() + filter.Senders = &wantUsers + filter.Types = &[]string{gomatrixserverlib.MRoomMember} + memberships, err := snapshot.GetStateEventsForRoom(ctx, roomID, &filter) + if err != nil { + return nil, err + } + + // cache the membership events + for _, membership := range memberships { + lazyLoadCache.StoreLazyLoadedUser(device, roomID, *membership.StateKey(), membership.EventID()) + } + + return memberships, nil } func parseRoomEventFilter(req *http.Request) (*gomatrixserverlib.RoomEventFilter, error) { diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 4a01ec357..02d8fcc7e 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -64,6 +64,7 @@ type messagesResp struct { // OnIncomingMessagesRequest implements the /messages endpoint from the // client-server API. // See: https://matrix.org/docs/spec/client_server/latest.html#get-matrix-client-r0-rooms-roomid-messages +// nolint:gocyclo func OnIncomingMessagesRequest( req *http.Request, db storage.Database, roomID string, device *userapi.Device, rsAPI api.SyncRoomserverAPI, @@ -246,7 +247,14 @@ func OnIncomingMessagesRequest( Start: start.String(), End: end.String(), } - res.applyLazyLoadMembers(req.Context(), snapshot, roomID, device, filter.LazyLoadMembers, lazyLoadCache) + if filter.LazyLoadMembers { + membershipEvents, err := applyLazyLoadMembers(req.Context(), device, snapshot, roomID, clientEvents, lazyLoadCache) + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("failed to apply lazy loading") + return jsonerror.InternalServerError() + } + res.State = append(res.State, gomatrixserverlib.HeaderedToClientEvents(membershipEvents, gomatrixserverlib.FormatAll)...) + } // If we didn't return any events, set the end to an empty string, so it will be omitted // in the response JSON. @@ -265,40 +273,6 @@ func OnIncomingMessagesRequest( } } -// applyLazyLoadMembers loads membership events for users returned in Chunk, if the filter has -// LazyLoadMembers enabled. -func (m *messagesResp) applyLazyLoadMembers( - ctx context.Context, - db storage.DatabaseTransaction, - roomID string, - device *userapi.Device, - lazyLoad bool, - lazyLoadCache caching.LazyLoadCache, -) { - if !lazyLoad { - return - } - membershipToUser := make(map[string]*gomatrixserverlib.HeaderedEvent) - for _, evt := range m.Chunk { - // Don't add membership events the client should already know about - if _, cached := lazyLoadCache.IsLazyLoadedUserCached(device, roomID, evt.Sender); cached { - continue - } - membership, err := db.GetStateEvent(ctx, roomID, gomatrixserverlib.MRoomMember, evt.Sender) - if err != nil { - util.GetLogger(ctx).WithError(err).Error("failed to get membership event for user") - continue - } - if membership != nil { - membershipToUser[evt.Sender] = membership - lazyLoadCache.StoreLazyLoadedUser(device, roomID, evt.Sender, membership.EventID()) - } - } - for _, evt := range membershipToUser { - m.State = append(m.State, gomatrixserverlib.HeaderedToClientEvent(evt, gomatrixserverlib.FormatAll)) - } -} - func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) { req := api.QueryMembershipForUserRequest{ RoomID: roomID, diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index a7a127e3a..04c2020a0 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -46,7 +46,7 @@ type DatabaseTransaction interface { RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) - RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) + RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) InviteEventsInRange(ctx context.Context, targetUserID string, r types.Range) (map[string]*gomatrixserverlib.HeaderedEvent, map[string]*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 3caafa14b..0d607b7c0 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -275,6 +275,15 @@ func (s *currentRoomStateStatements) SelectCurrentState( ) ([]*gomatrixserverlib.HeaderedEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt) senders, notSenders := getSendersStateFilterFilter(stateFilter) + // We're going to query members later, so remove them from this request + if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers { + notTypes := &[]string{gomatrixserverlib.MRoomMember} + if stateFilter.NotTypes != nil { + *stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember) + } else { + stateFilter.NotTypes = notTypes + } + } rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(senders), pq.StringArray(notSenders), diff --git a/syncapi/storage/postgres/deltas/20230201152200_rename_index.go b/syncapi/storage/postgres/deltas/20230201152200_rename_index.go new file mode 100644 index 000000000..5a0ec5050 --- /dev/null +++ b/syncapi/storage/postgres/deltas/20230201152200_rename_index.go @@ -0,0 +1,29 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "context" + "database/sql" + "fmt" +) + +func UpRenameOutputRoomEventsIndex(ctx context.Context, tx *sql.Tx) error { + _, err := tx.ExecContext(ctx, `ALTER TABLE syncapi_output_room_events RENAME CONSTRAINT syncapi_event_id_idx TO syncapi_output_room_event_id_idx;`) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 0075fc8d3..59fb99aa3 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -19,18 +19,17 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "sort" + "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - - "github.com/lib/pq" "github.com/matrix-org/gomatrixserverlib" - - "github.com/matrix-org/dendrite/internal/sqlutil" ) const outputRoomEventsSchema = ` @@ -44,7 +43,7 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events ( -- This isn't a problem for us since we just want to order by this field. id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_stream_id'), -- The event ID for the event - event_id TEXT NOT NULL CONSTRAINT syncapi_event_id_idx UNIQUE, + event_id TEXT NOT NULL CONSTRAINT syncapi_output_room_event_id_idx UNIQUE, -- The 'room_id' key for the event. room_id TEXT NOT NULL, -- The headered JSON for the event, containing potentially additional metadata such as @@ -79,13 +78,16 @@ CREATE INDEX IF NOT EXISTS syncapi_output_room_events_room_id_idx ON syncapi_out CREATE INDEX IF NOT EXISTS syncapi_output_room_events_exclude_from_sync_idx ON syncapi_output_room_events (exclude_from_sync); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_add_state_ids_idx ON syncapi_output_room_events ((add_state_ids IS NOT NULL)); CREATE INDEX IF NOT EXISTS syncapi_output_room_events_remove_state_ids_idx ON syncapi_output_room_events ((remove_state_ids IS NOT NULL)); +CREATE INDEX IF NOT EXISTS syncapi_output_room_events_recent_events_idx ON syncapi_output_room_events (room_id, exclude_from_sync, id, sender, type); + + ` const insertEventSQL = "" + "INSERT INTO syncapi_output_room_events (" + "room_id, event_id, headered_event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id, exclude_from_sync, history_visibility" + ") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) " + - "ON CONFLICT ON CONSTRAINT syncapi_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " + + "ON CONFLICT ON CONSTRAINT syncapi_output_room_event_id_idx DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $11) " + "RETURNING id" const selectEventsSQL = "" + @@ -109,14 +111,29 @@ const selectRecentEventsSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id DESC LIMIT $8" -const selectRecentEventsForSyncSQL = "" + - "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + - " WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" + - " AND ( $4::text[] IS NULL OR sender = ANY($4) )" + - " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + - " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + - " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + - " ORDER BY id DESC LIMIT $8" +// selectRecentEventsForSyncSQL contains an optimization to get the recent events for a list of rooms, using a LATERAL JOIN +// The sub select inside LATERAL () is executed for all room_ids it gets as a parameter $1 +const selectRecentEventsForSyncSQL = ` +WITH room_ids AS ( + SELECT unnest($1::text[]) AS room_id +) +SELECT x.* +FROM room_ids, + LATERAL ( + SELECT room_id, event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility + FROM syncapi_output_room_events recent_events + WHERE + recent_events.room_id = room_ids.room_id + AND recent_events.exclude_from_sync = FALSE + AND id > $2 AND id <= $3 + AND ( $4::text[] IS NULL OR sender = ANY($4) ) + AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) ) + AND ( $6::text[] IS NULL OR type LIKE ANY($6) ) + AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) ) + ORDER BY recent_events.id DESC + LIMIT $8 + ) AS x +` const selectEarlyEventsSQL = "" + "SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id, history_visibility FROM syncapi_output_room_events" + @@ -207,12 +224,30 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { return nil, err } + migrationName := "syncapi: rename dupe index (output_room_events)" + + var cName string + err = db.QueryRowContext(context.Background(), "select constraint_name from information_schema.table_constraints where table_name = 'syncapi_output_room_events' AND constraint_name = 'syncapi_event_id_idx'").Scan(&cName) + switch err { + case sql.ErrNoRows: // migration was already executed, as the index was renamed + if err = sqlutil.InsertMigration(context.Background(), db, migrationName); err != nil { + return nil, fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err) + } + case nil: + default: + return nil, err + } + m := sqlutil.NewMigrator(db) m.AddMigrations( sqlutil.Migration{ Version: "syncapi: add history visibility column (output_room_events)", Up: deltas.UpAddHistoryVisibilityColumnOutputRoomEvents, }, + sqlutil.Migration{ + Version: migrationName, + Up: deltas.UpRenameOutputRoomEventsIndex, + }, ) err = m.Up(context.Background()) if err != nil { @@ -398,9 +433,9 @@ func (s *outputRoomEventsStatements) InsertEvent( // from sync. func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + roomIDs []string, ra types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool, -) ([]types.StreamEvent, bool, error) { +) (map[string]types.RecentEvents, error) { var stmt *sql.Stmt if onlySyncEvents { stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt) @@ -408,8 +443,9 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt) } senders, notSenders := getSendersRoomEventFilter(eventFilter) + rows, err := stmt.QueryContext( - ctx, roomID, r.Low(), r.High(), + ctx, pq.StringArray(roomIDs), ra.Low(), ra.High(), pq.StringArray(senders), pq.StringArray(notSenders), pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)), @@ -417,34 +453,80 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( eventFilter.Limit+1, ) if err != nil { - return nil, false, err + return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, false, err - } - if chronologicalOrder { - // The events need to be returned from oldest to latest, which isn't - // necessary the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - } - // we queried for 1 more than the limit, so if we returned one more mark limited=true - limited := false - if len(events) > eventFilter.Limit { - limited = true - // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. - if chronologicalOrder { - events = events[1:] - } else { - events = events[:len(events)-1] + + result := make(map[string]types.RecentEvents) + + for rows.Next() { + var ( + roomID string + eventID string + streamPos types.StreamPosition + eventBytes []byte + excludeFromSync bool + sessionID *int64 + txnID *string + transactionID *api.TransactionID + historyVisibility gomatrixserverlib.HistoryVisibility + ) + if err := rows.Scan(&roomID, &eventID, &streamPos, &eventBytes, &sessionID, &excludeFromSync, &txnID, &historyVisibility); err != nil { + return nil, err } + // TODO: Handle redacted events + var ev gomatrixserverlib.HeaderedEvent + if err := ev.UnmarshalJSONWithEventID(eventBytes, eventID); err != nil { + return nil, err + } + + if sessionID != nil && txnID != nil { + transactionID = &api.TransactionID{ + SessionID: *sessionID, + TransactionID: *txnID, + } + } + + r := result[roomID] + + ev.Visibility = historyVisibility + r.Events = append(r.Events, types.StreamEvent{ + HeaderedEvent: &ev, + StreamPosition: streamPos, + TransactionID: transactionID, + ExcludeFromSync: excludeFromSync, + }) + + result[roomID] = r } - return events, limited, nil + if chronologicalOrder { + for roomID, evs := range result { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(evs.Events, func(i int, j int) bool { + return evs.Events[i].StreamPosition < evs.Events[j].StreamPosition + }) + + if len(evs.Events) > eventFilter.Limit { + evs.Limited = true + evs.Events = evs.Events[1:] + } + + result[roomID] = evs + } + } else { + for roomID, evs := range result { + if len(evs.Events) > eventFilter.Limit { + evs.Limited = true + evs.Events = evs.Events[:len(evs.Events)-1] + } + + result[roomID] = evs + } + } + return result, rows.Err() } // selectEarlyEvents returns the earliest events in the given room, starting diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 8385b95a5..931bc9e23 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -151,8 +151,8 @@ func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID return summary, nil } -func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { - return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents) +func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) { + return d.OutputEvents.SelectRecentEvents(ctx, d.txn, roomIDs, r, eventFilter, chronologicalOrder, onlySyncEvents) } func (d *DatabaseTransaction) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) { @@ -370,19 +370,25 @@ func (d *DatabaseTransaction) GetStateDeltas( } // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeededFiltered, eventMapFiltered, err := d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil + stateFiltered := state + // avoid hitting the database if the result would be the same as above + if !isStatefilterEmpty(stateFilter) { + var stateNeededFiltered map[string]map[string]bool + var eventMapFiltered map[string]types.StreamEvent + stateNeededFiltered, eventMapFiltered, err = d.OutputEvents.SelectStateInRange(ctx, d.txn, r, stateFilter, allRoomIDs) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err } - return nil, nil, err - } - stateFiltered, err := d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered) - if err != nil { - if err == sql.ErrNoRows { - return nil, nil, nil + stateFiltered, err = d.fetchStateEvents(ctx, d.txn, stateNeededFiltered, eventMapFiltered) + if err != nil { + if err == sql.ErrNoRows { + return nil, nil, nil + } + return nil, nil, err } - return nil, nil, err } // find out which rooms this user is peeking, if any. @@ -701,6 +707,28 @@ func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) return types.StreamPosition(id), err } +func isStatefilterEmpty(filter *gomatrixserverlib.StateFilter) bool { + if filter == nil { + return true + } + switch { + case filter.NotTypes != nil && len(*filter.NotTypes) > 0: + return false + case filter.Types != nil && len(*filter.Types) > 0: + return false + case filter.Senders != nil && len(*filter.Senders) > 0: + return false + case filter.NotSenders != nil && len(*filter.NotSenders) > 0: + return false + case filter.NotRooms != nil && len(*filter.NotRooms) > 0: + return false + case filter.ContainsURL != nil: + return false + default: + return true + } +} + func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) ( events []types.StreamEvent, prevBatch, nextBatch string, err error, ) { diff --git a/syncapi/storage/shared/storage_sync_test.go b/syncapi/storage/shared/storage_sync_test.go new file mode 100644 index 000000000..c56720db7 --- /dev/null +++ b/syncapi/storage/shared/storage_sync_test.go @@ -0,0 +1,72 @@ +package shared + +import ( + "testing" + + "github.com/matrix-org/gomatrixserverlib" +) + +func Test_isStatefilterEmpty(t *testing.T) { + filterSet := []string{"a"} + boolValue := false + + tests := []struct { + name string + filter *gomatrixserverlib.StateFilter + want bool + }{ + { + name: "nil filter is empty", + filter: nil, + want: true, + }, + { + name: "Empty filter is empty", + filter: &gomatrixserverlib.StateFilter{}, + want: true, + }, + { + name: "NotTypes is set", + filter: &gomatrixserverlib.StateFilter{ + NotTypes: &filterSet, + }, + }, + { + name: "Types is set", + filter: &gomatrixserverlib.StateFilter{ + Types: &filterSet, + }, + }, + { + name: "Senders is set", + filter: &gomatrixserverlib.StateFilter{ + Senders: &filterSet, + }, + }, + { + name: "NotSenders is set", + filter: &gomatrixserverlib.StateFilter{ + NotSenders: &filterSet, + }, + }, + { + name: "NotRooms is set", + filter: &gomatrixserverlib.StateFilter{ + NotRooms: &filterSet, + }, + }, + { + name: "ContainsURL is set", + filter: &gomatrixserverlib.StateFilter{ + ContainsURL: &boolValue, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isStatefilterEmpty(tt.filter); got != tt.want { + t.Errorf("isStatefilterEmpty() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/syncapi/storage/sqlite3/account_data_table.go b/syncapi/storage/sqlite3/account_data_table.go index d8967113a..de0e72dbd 100644 --- a/syncapi/storage/sqlite3/account_data_table.go +++ b/syncapi/storage/sqlite3/account_data_table.go @@ -105,7 +105,9 @@ func (s *accountDataStatements) SelectAccountDataInRange( filter.Senders, filter.NotSenders, filter.Types, filter.NotTypes, []string{}, nil, filter.Limit, FilterOrderAsc) - + if err != nil { + return + } rows, err := stmt.QueryContext(ctx, params...) if err != nil { return diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 6bc1b267a..35b746c5c 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -267,6 +267,15 @@ func (s *currentRoomStateStatements) SelectCurrentState( stateFilter *gomatrixserverlib.StateFilter, excludeEventIDs []string, ) ([]*gomatrixserverlib.HeaderedEvent, error) { + // We're going to query members later, so remove them from this request + if stateFilter.LazyLoadMembers && !stateFilter.IncludeRedundantMembers { + notTypes := &[]string{gomatrixserverlib.MRoomMember} + if stateFilter.NotTypes != nil { + *stateFilter.NotTypes = append(*stateFilter.NotTypes, gomatrixserverlib.MRoomMember) + } else { + stateFilter.NotTypes = notTypes + } + } stmt, params, err := prepareWithFilters( s.db, txn, selectCurrentStateSQL, []interface{}{ diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index db708c083..23bc68a41 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -368,9 +368,9 @@ func (s *outputRoomEventsStatements) InsertEvent( func (s *outputRoomEventsStatements) SelectRecentEvents( ctx context.Context, txn *sql.Tx, - roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, + roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool, -) ([]types.StreamEvent, bool, error) { +) (map[string]types.RecentEvents, error) { var query string if onlySyncEvents { query = selectRecentEventsForSyncSQL @@ -378,49 +378,55 @@ func (s *outputRoomEventsStatements) SelectRecentEvents( query = selectRecentEventsSQL } - stmt, params, err := prepareWithFilters( - s.db, txn, query, - []interface{}{ - roomID, r.Low(), r.High(), - }, - eventFilter.Senders, eventFilter.NotSenders, - eventFilter.Types, eventFilter.NotTypes, - nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, - ) - if err != nil { - return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err) - } - defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") - - rows, err := stmt.QueryContext(ctx, params...) - if err != nil { - return nil, false, err - } - defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") - events, err := rowsToStreamEvents(rows) - if err != nil { - return nil, false, err - } - if chronologicalOrder { - // The events need to be returned from oldest to latest, which isn't - // necessary the way the SQL query returns them, so a sort is necessary to - // ensure the events are in the right order in the slice. - sort.SliceStable(events, func(i int, j int) bool { - return events[i].StreamPosition < events[j].StreamPosition - }) - } - // we queried for 1 more than the limit, so if we returned one more mark limited=true - limited := false - if len(events) > eventFilter.Limit { - limited = true - // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. - if chronologicalOrder { - events = events[1:] - } else { - events = events[:len(events)-1] + result := make(map[string]types.RecentEvents, len(roomIDs)) + for _, roomID := range roomIDs { + stmt, params, err := prepareWithFilters( + s.db, txn, query, + []interface{}{ + roomID, r.Low(), r.High(), + }, + eventFilter.Senders, eventFilter.NotSenders, + eventFilter.Types, eventFilter.NotTypes, + nil, eventFilter.ContainsURL, eventFilter.Limit+1, FilterOrderDesc, + ) + if err != nil { + return nil, fmt.Errorf("s.prepareWithFilters: %w", err) } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRecentEvents: stmt.close() failed") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed") + events, err := rowsToStreamEvents(rows) + if err != nil { + return nil, err + } + if chronologicalOrder { + // The events need to be returned from oldest to latest, which isn't + // necessary the way the SQL query returns them, so a sort is necessary to + // ensure the events are in the right order in the slice. + sort.SliceStable(events, func(i int, j int) bool { + return events[i].StreamPosition < events[j].StreamPosition + }) + } + res := types.RecentEvents{} + // we queried for 1 more than the limit, so if we returned one more mark limited=true + if len(events) > eventFilter.Limit { + res.Limited = true + // re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last. + if chronologicalOrder { + events = events[1:] + } else { + events = events[:len(events)-1] + } + } + res.Events = events + result[roomID] = res } - return events, limited, nil + + return result, nil } func (s *outputRoomEventsStatements) SelectEarlyEvents( diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index e65367d8b..05d498bc2 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -156,12 +156,12 @@ func TestRecentEventsPDU(t *testing.T) { tc := testCases[i] t.Run(tc.Name, func(st *testing.T) { var filter gomatrixserverlib.RoomEventFilter - var gotEvents []types.StreamEvent + var gotEvents map[string]types.RecentEvents var limited bool filter.Limit = tc.Limit WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { var err error - gotEvents, limited, err = snapshot.RecentEvents(ctx, r.ID, types.Range{ + gotEvents, err = snapshot.RecentEvents(ctx, []string{r.ID}, types.Range{ From: tc.From, To: tc.To, }, &filter, !tc.ReverseOrder, true) @@ -169,15 +169,18 @@ func TestRecentEventsPDU(t *testing.T) { st.Fatalf("failed to do sync: %s", err) } }) + streamEvents := gotEvents[r.ID] + limited = streamEvents.Limited if limited != tc.WantLimited { st.Errorf("got limited=%v want %v", limited, tc.WantLimited) } - if len(gotEvents) != len(tc.WantEvents) { + if len(streamEvents.Events) != len(tc.WantEvents) { st.Errorf("got %d events, want %d", len(gotEvents), len(tc.WantEvents)) } - for j := range gotEvents { - if !reflect.DeepEqual(gotEvents[j].JSON(), tc.WantEvents[j].JSON()) { - st.Errorf("event %d got %s want %s", j, string(gotEvents[j].JSON()), string(tc.WantEvents[j].JSON())) + + for j := range streamEvents.Events { + if !reflect.DeepEqual(streamEvents.Events[j].JSON(), tc.WantEvents[j].JSON()) { + st.Errorf("event %d got %s want %s", j, string(streamEvents.Events[j].JSON()), string(tc.WantEvents[j].JSON())) } } }) @@ -923,3 +926,61 @@ func TestRoomSummary(t *testing.T) { } }) } + +func TestRecentEvents(t *testing.T) { + alice := test.NewUser(t) + room1 := test.NewRoom(t, alice) + room2 := test.NewRoom(t, alice) + roomIDs := []string{room1.ID, room2.ID} + rooms := map[string]*test.Room{ + room1.ID: room1, + room2.ID: room2, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + filter := gomatrixserverlib.DefaultRoomEventFilter() + db, close, closeBase := MustCreateDatabase(t, dbType) + t.Cleanup(func() { + close() + closeBase() + }) + + MustWriteEvents(t, db, room1.Events()) + MustWriteEvents(t, db, room2.Events()) + + transaction, err := db.NewDatabaseTransaction(ctx) + assert.NoError(t, err) + defer transaction.Rollback() + + // get all recent events from 0 to 100 (we only created 5 events, so we should get 5 back) + roomEvs, err := transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, true, true) + assert.NoError(t, err) + assert.Equal(t, len(roomEvs), 2, "unexpected recent events response") + for _, recentEvents := range roomEvs { + assert.Equal(t, 5, len(recentEvents.Events), "unexpected recent events for room") + } + + // update the filter to only return one event + filter.Limit = 1 + roomEvs, err = transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, true, true) + assert.NoError(t, err) + assert.Equal(t, len(roomEvs), 2, "unexpected recent events response") + for roomID, recentEvents := range roomEvs { + origEvents := rooms[roomID].Events() + assert.Equal(t, true, recentEvents.Limited, "expected events to be limited") + assert.Equal(t, 1, len(recentEvents.Events), "unexpected recent events for room") + assert.Equal(t, origEvents[len(origEvents)-1].EventID(), recentEvents.Events[0].EventID()) + } + + // not chronologically ordered still returns the events in order (given ORDER BY id DESC) + roomEvs, err = transaction.RecentEvents(ctx, roomIDs, types.Range{From: 0, To: 100}, &filter, false, true) + assert.NoError(t, err) + assert.Equal(t, len(roomEvs), 2, "unexpected recent events response") + for roomID, recentEvents := range roomEvs { + origEvents := rooms[roomID].Events() + assert.Equal(t, true, recentEvents.Limited, "expected events to be limited") + assert.Equal(t, 1, len(recentEvents.Events), "unexpected recent events for room") + assert.Equal(t, origEvents[len(origEvents)-1].EventID(), recentEvents.Events[0].EventID()) + } + }) +} diff --git a/syncapi/storage/tables/current_room_state_test.go b/syncapi/storage/tables/current_room_state_test.go index 23287c500..c7af4f977 100644 --- a/syncapi/storage/tables/current_room_state_test.go +++ b/syncapi/storage/tables/current_room_state_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" ) func newCurrentRoomStateTable(t *testing.T, dbType test.DBType) (tables.CurrentRoomState, *sql.DB, func()) { @@ -79,6 +80,9 @@ func TestCurrentRoomStateTable(t *testing.T) { return fmt.Errorf("SelectEventsWithEventIDs\nexpected id %q not returned", id) } } + + testCurrentState(t, ctx, txn, tab, room) + return nil }) if err != nil { @@ -86,3 +90,39 @@ func TestCurrentRoomStateTable(t *testing.T) { } }) } + +func testCurrentState(t *testing.T, ctx context.Context, txn *sql.Tx, tab tables.CurrentRoomState, room *test.Room) { + t.Run("test currentState", func(t *testing.T) { + // returns the complete state of the room with a default filter + filter := gomatrixserverlib.DefaultStateFilter() + evs, err := tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) + if err != nil { + t.Fatal(err) + } + expectCount := 5 + if gotCount := len(evs); gotCount != expectCount { + t.Fatalf("expected %d state events, got %d", expectCount, gotCount) + } + // When lazy loading, we expect no membership event, so only 4 events + filter.LazyLoadMembers = true + expectCount = 4 + evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) + if err != nil { + t.Fatal(err) + } + if gotCount := len(evs); gotCount != expectCount { + t.Fatalf("expected %d state events, got %d", expectCount, gotCount) + } + // same as above, but with existing NotTypes defined + notTypes := []string{gomatrixserverlib.MRoomMember} + filter.NotTypes = ¬Types + evs, err = tab.SelectCurrentState(ctx, txn, room.ID, &filter, nil) + if err != nil { + t.Fatal(err) + } + if gotCount := len(evs); gotCount != expectCount { + t.Fatalf("expected %d state events, got %d", expectCount, gotCount) + } + }) + +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 145e197cc..727d6bf2c 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -66,7 +66,7 @@ type Events interface { // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`. - SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) + SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomIDs []string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) (map[string]types.RecentEvents, error) // SelectEarlyEvents returns the earliest events in the given room. SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string, filter *gomatrixserverlib.RoomEventFilter, preserveOrder bool) ([]types.StreamEvent, error) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 44013e37c..6af25c028 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -82,19 +82,24 @@ func (p *PDUStreamProvider) CompleteSync( req.Log.WithError(err).Error("unable to update event filter with ignored users") } - // Invalidate the lazyLoadCache, otherwise we end up with missing displaynames/avatars - // TODO: This might be inefficient, when joined to many and/or large rooms. + recentEvents, err := snapshot.RecentEvents(ctx, joinedRoomIDs, r, &eventFilter, true, true) + if err != nil { + return from + } + // Build up a /sync response. Add joined rooms. for _, roomID := range joinedRoomIDs { + events := recentEvents[roomID] + // Invalidate the lazyLoadCache, otherwise we end up with missing displaynames/avatars + // TODO: This might be inefficient, when joined to many and/or large rooms. joinedUsers := p.notifier.JoinedUsers(roomID) for _, sharedUser := range joinedUsers { p.lazyLoadCache.InvalidateLazyLoadedUser(req.Device, roomID, sharedUser) } - } - // Build up a /sync response. Add joined rooms. - for _, roomID := range joinedRoomIDs { + // get the join response for each room jr, jerr := p.getJoinResponseForCompleteSync( - ctx, snapshot, roomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, false, + ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, false, + events.Events, events.Limited, ) if jerr != nil { req.Log.WithError(jerr).Error("p.getJoinResponseForCompleteSync failed") @@ -113,11 +118,25 @@ func (p *PDUStreamProvider) CompleteSync( req.Log.WithError(err).Error("p.DB.PeeksInRange failed") return from } - for _, peek := range peeks { - if !peek.Deleted { + if len(peeks) > 0 { + peekRooms := make([]string, 0, len(peeks)) + for _, peek := range peeks { + if !peek.Deleted { + peekRooms = append(peekRooms, peek.RoomID) + } + } + + recentEvents, err = snapshot.RecentEvents(ctx, peekRooms, r, &eventFilter, true, true) + if err != nil { + return from + } + + for _, roomID := range peekRooms { var jr *types.JoinResponse + events := recentEvents[roomID] jr, err = p.getJoinResponseForCompleteSync( - ctx, snapshot, peek.RoomID, r, &stateFilter, &eventFilter, req.WantFullState, req.Device, true, + ctx, snapshot, roomID, &stateFilter, req.WantFullState, req.Device, true, + events.Events, events.Limited, ) if err != nil { req.Log.WithError(err).Error("p.getJoinResponseForCompleteSync failed") @@ -126,7 +145,7 @@ func (p *PDUStreamProvider) CompleteSync( } continue } - req.Response.Rooms.Peek[peek.RoomID] = jr + req.Response.Rooms.Peek[roomID] = jr } } @@ -227,7 +246,7 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( stateFilter *gomatrixserverlib.StateFilter, req *types.SyncRequest, ) (types.StreamPosition, error) { - + var err error originalLimit := eventFilter.Limit // If we're going backwards, grep at least X events, this is mostly to satisfy Sytest if r.Backwards && originalLimit < recentEventBackwardsLimit { @@ -238,8 +257,8 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } } - recentStreamEvents, limited, err := snapshot.RecentEvents( - ctx, delta.RoomID, r, + dbEvents, err := snapshot.RecentEvents( + ctx, []string{delta.RoomID}, r, eventFilter, true, true, ) if err != nil { @@ -248,6 +267,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( } return r.From, fmt.Errorf("p.DB.RecentEvents: %w", err) } + + recentStreamEvents := dbEvents[delta.RoomID].Events + limited := dbEvents[delta.RoomID].Limited + recentEvents := gomatrixserverlib.HeaderedReverseTopologicalOrdering( snapshot.StreamEventsToEvents(device, recentStreamEvents), gomatrixserverlib.TopologicalOrderByPrevEvents, @@ -420,7 +443,7 @@ func applyHistoryVisibilityFilter( "room_id": roomID, "before": len(recentEvents), "after": len(events), - }).Trace("Applied history visibility (sync)") + }).Debugf("Applied history visibility (sync)") return events, nil } @@ -428,25 +451,16 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, snapshot storage.DatabaseTransaction, roomID string, - r types.Range, stateFilter *gomatrixserverlib.StateFilter, - eventFilter *gomatrixserverlib.RoomEventFilter, wantFullState bool, device *userapi.Device, isPeek bool, + recentStreamEvents []types.StreamEvent, + limited bool, ) (jr *types.JoinResponse, err error) { jr = types.NewJoinResponse() // TODO: When filters are added, we may need to call this multiple times to get enough events. // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 - recentStreamEvents, limited, err := snapshot.RecentEvents( - ctx, roomID, r, eventFilter, true, true, - ) - if err != nil { - if err == sql.ErrNoRows { - return jr, nil - } - return - } // Work our way through the timeline events and pick out the event IDs // of any state events that appear in the timeline. We'll specifically diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 666a872f8..643c30265 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "github.com/matrix-org/dendrite/syncapi/routing" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/tidwall/gjson" @@ -448,6 +449,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/messages", room.ID), test.WithQueryParams(map[string]string{ "access_token": bobDev.AccessToken, "dir": "b", + "filter": `{"lazy_load_members":true}`, // check that lazy loading doesn't break history visibility }))) if w.Code != 200 { t.Logf("%s", w.Body.String()) @@ -905,6 +907,177 @@ func testSendToDevice(t *testing.T, dbType test.DBType) { } } +func TestContext(t *testing.T) { + test.WithAllDatabases(t, testContext) +} + +func testContext(t *testing.T, dbType test.DBType) { + + tests := []struct { + name string + roomID string + eventID string + params map[string]string + wantError bool + wantStateLength int + wantBeforeLength int + wantAfterLength int + }{ + { + name: "invalid filter", + params: map[string]string{ + "filter": "{", + }, + wantError: true, + }, + { + name: "invalid limit", + params: map[string]string{ + "limit": "abc", + }, + wantError: true, + }, + { + name: "high limit", + params: map[string]string{ + "limit": "100000", + }, + }, + { + name: "fine limit", + params: map[string]string{ + "limit": "10", + }, + }, + { + name: "last event without lazy loading", + wantStateLength: 5, + }, + { + name: "last event with lazy loading", + params: map[string]string{ + "filter": `{"lazy_load_members":true}`, + }, + wantStateLength: 1, + }, + { + name: "invalid room", + roomID: "!doesnotexist", + wantError: true, + }, + { + name: "invalid eventID", + eventID: "$doesnotexist", + wantError: true, + }, + { + name: "state is limited", + params: map[string]string{ + "limit": "1", + }, + wantStateLength: 1, + }, + { + name: "events are not limited", + wantBeforeLength: 7, + }, + { + name: "all events are limited", + params: map[string]string{ + "limit": "1", + }, + wantStateLength: 1, + wantBeforeLength: 1, + wantAfterLength: 1, + }, + } + + user := test.NewUser(t) + alice := userapi.Device{ + ID: "ALICEID", + UserID: user.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + // Use an actual roomserver for this + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{alice}}, rsAPI, &syncKeyAPI{}) + + room := test.NewRoom(t, user) + + room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world 1!"}) + room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world 2!"}) + thirdMsg := room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world3!"}) + room.CreateAndInsert(t, user, "m.room.message", map[string]interface{}{"body": "hello world4!"}) + + if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + syncUntil(t, base, alice.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, thirdMsg.EventID()) + return gjson.Get(syncBody, path).Exists() + }) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + params := map[string]string{ + "access_token": alice.AccessToken, + } + w := httptest.NewRecorder() + // test overrides + roomID := room.ID + if tc.roomID != "" { + roomID = tc.roomID + } + eventID := thirdMsg.EventID() + if tc.eventID != "" { + eventID = tc.eventID + } + requestPath := fmt.Sprintf("/_matrix/client/v3/rooms/%s/context/%s", roomID, eventID) + if tc.params != nil { + for k, v := range tc.params { + params[k] = v + } + } + base.PublicClientAPIMux.ServeHTTP(w, test.NewRequest(t, "GET", requestPath, test.WithQueryParams(params))) + + if tc.wantError && w.Code == 200 { + t.Fatalf("Expected an error, but got none") + } + t.Log(w.Body.String()) + resp := routing.ContextRespsonse{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + if tc.wantStateLength > 0 && tc.wantStateLength != len(resp.State) { + t.Fatalf("expected %d state events, got %d", tc.wantStateLength, len(resp.State)) + } + if tc.wantBeforeLength > 0 && tc.wantBeforeLength != len(resp.EventsBefore) { + t.Fatalf("expected %d before events, got %d", tc.wantBeforeLength, len(resp.EventsBefore)) + } + if tc.wantAfterLength > 0 && tc.wantAfterLength != len(resp.EventsAfter) { + t.Fatalf("expected %d after events, got %d", tc.wantAfterLength, len(resp.EventsAfter)) + } + + if !tc.wantError && resp.Event.EventID != eventID { + t.Fatalf("unexpected eventID %s, expected %s", resp.Event.EventID, eventID) + } + }) + } +} + func syncUntil(t *testing.T, base *base.BaseDendrite, accessToken string, skip bool, diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 9fbadc06c..6495dd535 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -63,6 +63,11 @@ type StreamEvent struct { ExcludeFromSync bool } +type RecentEvents struct { + Limited bool + Events []StreamEvent +} + // Range represents a range between two stream positions. type Range struct { // From is the position the client has already received.