From b2712cd2b1ca8537d97d7cd6a416120826638e5c Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 4 Nov 2022 20:58:24 +0100 Subject: [PATCH 01/42] Fix GHA release script --- .github/workflows/docker.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index f0500ccc5..846844173 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -32,7 +32,7 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV + echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) [ ${BRANCH} == "main" ] && BRANCH="" echo "BRANCH=${BRANCH}" >> $GITHUB_ENV @@ -112,7 +112,7 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV + echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) [ ${BRANCH} == "main" ] && BRANCH="" echo "BRANCH=${BRANCH}" >> $GITHUB_ENV @@ -191,7 +191,7 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV + echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) [ ${BRANCH} == "main" ] && BRANCH="" echo "BRANCH=${BRANCH}" >> $GITHUB_ENV @@ -258,7 +258,7 @@ jobs: if: github.event_name == 'release' # Only for GitHub releases run: | echo "RELEASE_VERSION=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV - echo "BUILD=$(git rev-parse --short HEAD || "") >> $GITHUB_ENV + echo "BUILD=$(git rev-parse --short HEAD || \"\")" >> $GITHUB_ENV BRANCH=$(git symbolic-ref --short HEAD | tr -d \/) [ ${BRANCH} == "main" ] && BRANCH="" echo "BRANCH=${BRANCH}" >> $GITHUB_ENV From a7b74176e3dbcfd36525388665c64674c707e5c3 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 4 Nov 2022 21:49:18 +0000 Subject: [PATCH 02/42] Revert Docker user change --- Dockerfile | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Dockerfile b/Dockerfile index 499992343..a9bbce925 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,7 +24,7 @@ RUN --mount=target=. \ go build -v -ldflags="${FLAGS}" -trimpath -o /out/ ./cmd/... # -# The dendrite base image; mainly creates a user and switches to it +# The dendrite base image # FROM alpine:latest AS dendrite-base LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" @@ -32,8 +32,6 @@ LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" LABEL org.opencontainers.image.licenses="Apache-2.0" LABEL org.opencontainers.image.documentation="https://matrix-org.github.io/dendrite/" LABEL org.opencontainers.image.vendor="The Matrix.org Foundation C.I.C." -RUN addgroup dendrite && adduser dendrite -G dendrite -u 1337 -D -USER dendrite # # Builds the polylith image and only contains the polylith binary From c125203eb6e5dfeaa4290aeea8bbed9a73caf16c Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 7 Nov 2022 09:47:18 +0100 Subject: [PATCH 03/42] Handle `m.room.tombstone` events in the UserAPI (#2864) Fixes #2863 and makes ``` /upgrade preserves direct room state local user has tags copied to the new room remote user has tags copied to the new room ``` pass. --- sytest-whitelist | 5 +- userapi/consumers/roomserver.go | 117 +++++++++++++++++++++++++++++++- 2 files changed, 120 insertions(+), 2 deletions(-) diff --git a/sytest-whitelist b/sytest-whitelist index f4311d339..bb4f0a279 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -758,4 +758,7 @@ Can get rooms/{roomId}/members at a given point Can filter rooms/{roomId}/members Current state appears in timeline in private history with many messages after AS can publish rooms in their own list -AS and main public room lists are separate \ No newline at end of file +AS and main public room lists are separate +/upgrade preserves direct room state +local user has tags copied to the new room +remote user has tags copied to the new room \ No newline at end of file diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 97c17e188..b6b30a095 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -2,12 +2,16 @@ package consumers import ( "context" + "database/sql" "encoding/json" + "errors" "fmt" "strings" "sync" "time" + "github.com/tidwall/gjson" + "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" log "github.com/sirupsen/logrus" @@ -185,13 +189,115 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy } } +func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error { + for _, membership := range localMembers { + // Copy any existing push rules from old -> new room + if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { + return err + } + + // preserve m.direct room state + if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil { + return err + } + + // copy existing m.tag entries, if any + if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { + return err + } + } + return nil +} + +func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error { + pushRules, err := s.db.QueryPushRules(ctx, localpart) + if err != nil { + return fmt.Errorf("failed to query pushrules for user: %w", err) + } + if pushRules == nil { + return nil + } + + for _, roomRule := range pushRules.Global.Room { + if roomRule.RuleID != oldRoomID { + continue + } + cpRool := *roomRule + cpRool.RuleID = newRoomID + pushRules.Global.Room = append(pushRules.Global.Room, &cpRool) + rules, err := json.Marshal(pushRules) + if err != nil { + return err + } + if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil { + return fmt.Errorf("failed to update pushrules: %w", err) + } + } + return nil +} + +// updateMDirect copies the "is_direct" flag from oldRoomID to newROomID +func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error { + // this is most likely not a DM, so skip updating m.direct state + if roomSize > 2 { + return nil + } + // Get direct message state + directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct") + if err != nil { + return fmt.Errorf("failed to get m.direct from database: %w", err) + } + directChats := gjson.ParseBytes(directChatsRaw) + newDirectChats := make(map[string][]string) + // iterate over all userID -> roomIDs + directChats.ForEach(func(userID, roomIDs gjson.Result) bool { + var found bool + for _, roomID := range roomIDs.Array() { + newDirectChats[userID.Str] = append(newDirectChats[userID.Str], roomID.Str) + // add the new roomID to m.direct + if roomID.Str == oldRoomID { + found = true + newDirectChats[userID.Str] = append(newDirectChats[userID.Str], newRoomID) + } + } + // Only hit the database if we found the old room as a DM for this user + if found { + var data []byte + data, err = json.Marshal(newDirectChats) + if err != nil { + return true + } + if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil { + return true + } + } + return true + }) + if err != nil { + return fmt.Errorf("failed to update m.direct state") + } + return nil +} + +func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error { + tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag") + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + if tag == nil { + return nil + } + return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag) +} + func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { members, roomSize, err := s.localRoomMembers(ctx, event.RoomID()) if err != nil { return fmt.Errorf("s.localRoomMembers: %w", err) } - if event.Type() == gomatrixserverlib.MRoomMember { + switch { + case event.Type() == gomatrixserverlib.MRoomMember: cevent := gomatrixserverlib.HeaderedToClientEvent(event, gomatrixserverlib.FormatAll) var member *localMembership member, err = newLocalMembership(&cevent) @@ -203,6 +309,15 @@ func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gom // should also be pushed to the target user. members = append(members, member) } + case event.Type() == "m.room.tombstone" && event.StateKeyEquals(""): + // Handle room upgrades + oldRoomID := event.RoomID() + newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str + if err = s.handleRoomUpgrade(ctx, oldRoomID, newRoomID, members, roomSize); err != nil { + // while inconvenient, this shouldn't stop us from sending push notifications + log.WithError(err).Errorf("UserAPI: failed to handle room upgrade for users") + } + } // TODO: run in parallel with localRoomMembers. From 205a15621a67ca7b851dc8ece8cfbcce94fb8b4d Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 7 Nov 2022 15:07:47 +0100 Subject: [PATCH 04/42] Add custom build flag to satisfy Sytest --- keyserver/internal/device_list_update.go | 5 ++-- .../internal/device_list_update_default.go | 22 ++++++++++++++++ .../internal/device_list_update_sytest.go | 25 +++++++++++++++++++ .../storage/postgres/stale_device_lists.go | 6 ++--- .../storage/sqlite3/stale_device_lists.go | 6 ++--- 5 files changed, 55 insertions(+), 9 deletions(-) create mode 100644 keyserver/internal/device_list_update_default.go create mode 100644 keyserver/internal/device_list_update_sytest.go diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 8b02f3d6c..3f7c0d8b4 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -47,7 +47,6 @@ var ( ) ) -const defaultWaitTime = time.Minute const requestTimeout = time.Second * 30 func init() { @@ -454,7 +453,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go } else if e.Code >= 300 { // We didn't get a real FederationClientError (e.g. in polylith mode, where gomatrix.HTTPError // are "converted" to FederationClientError), but we probably shouldn't hit them every $waitTime seconds. - return time.Hour, err + return hourWaitTime, err } case net.Error: // Use the default waitTime, if it's a timeout. @@ -468,7 +467,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go // This is to avoid spamming remote servers, which may not be Matrix servers anymore. if e.Code >= 300 { logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError") - return time.Hour, err + return hourWaitTime, err } default: // Something else failed diff --git a/keyserver/internal/device_list_update_default.go b/keyserver/internal/device_list_update_default.go new file mode 100644 index 000000000..7d357c951 --- /dev/null +++ b/keyserver/internal/device_list_update_default.go @@ -0,0 +1,22 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !vw + +package internal + +import "time" + +const defaultWaitTime = time.Minute +const hourWaitTime = time.Hour diff --git a/keyserver/internal/device_list_update_sytest.go b/keyserver/internal/device_list_update_sytest.go new file mode 100644 index 000000000..1c60d2eb9 --- /dev/null +++ b/keyserver/internal/device_list_update_sytest.go @@ -0,0 +1,25 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build vw + +package internal + +import "time" + +// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite +// results in a one-hour wait time from a previous device so the test times out. This is fine for +// production, but makes an otherwise passing test fail. +const defaultWaitTime = time.Second +const hourWaitTime = time.Second diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go index 63281adfb..d0fe50d00 100644 --- a/keyserver/storage/postgres/stale_device_lists.go +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -43,10 +43,10 @@ const upsertStaleDeviceListSQL = "" + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" type staleDeviceListsStatements struct { upsertStaleDeviceListStmt *sql.Stmt @@ -77,7 +77,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) return err } diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go index fc2cc37c4..1e08b266c 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -43,10 +43,10 @@ const upsertStaleDeviceListSQL = "" + " DO UPDATE SET is_stale = $3, ts_added_secs = $4" const selectStaleDeviceListsWithDomainsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC" const selectStaleDeviceListsSQL = "" + - "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" type staleDeviceListsStatements struct { db *sql.DB @@ -80,7 +80,7 @@ func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, if err != nil { return err } - _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, time.Now().Unix()) + _, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now())) return err } From a5cabdbac5678171167062fca06e60a93a21ddbf Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 9 Nov 2022 09:24:29 +0000 Subject: [PATCH 05/42] Remove unspecced fields from `Transaction` (update to matrix-org/gomatrixserverlib@715dc88) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 7dd9e0b2c..af315b47f 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e + github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 5d172e86b..ed2d97eed 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e h1:6I34fdyiHMRCxL6GOb/G8ZyI1WWlb6ZxCF2hIGSMSCc= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221101165746-0e4a8bb6db7e/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 h1:Bet5n+//Yh+A2SuPHD67N8jrOhC/EIKvEgisfsKhTss= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= From bdaae060ccead4594cfa88f45e9656c3a094535c Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 9 Nov 2022 14:07:29 +0000 Subject: [PATCH 06/42] Update Ristretto --- go.mod | 2 +- go.sum | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index af315b47f..47c43fbc7 100644 --- a/go.mod +++ b/go.mod @@ -8,7 +8,7 @@ require ( github.com/Masterminds/semver/v3 v3.1.1 github.com/blevesearch/bleve/v2 v2.3.4 github.com/codeclysm/extract v2.2.0+incompatible - github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d + github.com/dgraph-io/ristretto v0.1.1 github.com/docker/docker v20.10.19+incompatible github.com/docker/go-connections v0.4.0 github.com/getsentry/sentry-go v0.14.0 diff --git a/go.sum b/go.sum index ed2d97eed..a805a5bf9 100644 --- a/go.sum +++ b/go.sum @@ -141,8 +141,8 @@ github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d h1:Wrc3UKTS+cffkOx0xRGFC+ZesNuTfn0ThvEC72N0krk= -github.com/dgraph-io/ristretto v0.1.1-0.20220403145359-8e850b710d6d/go.mod h1:RAy2GVV4sTWVlNMavv3xhLsk18rxhfhDnombTe6EF5c= +github.com/dgraph-io/ristretto v0.1.1 h1:6CWw5tJNgpegArSHpNHJKldNeq03FQCwYvfMVWajOK8= +github.com/dgraph-io/ristretto v0.1.1/go.mod h1:S1GPSBCYCIhmVNfcth17y2zZtQT6wzkzgwUve0VDWWA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2 h1:tdlZCpZ/P9DhczCTSixgIKmwPv6+wP5DGjqLYw5SUiA= github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/docker/distribution v2.8.1+incompatible h1:Q50tZOPR6T/hjNsyc9g8/syEs6bk8XXApsHjKukMl68= @@ -526,7 +526,6 @@ go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= -go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= golang.org/x/crypto v0.0.0-20180723164146-c126467f60eb/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4= @@ -697,6 +696,7 @@ golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= From 503d9c7586d7c5ca91cdb78631553256628861cd Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 10 Nov 2022 10:07:19 +0000 Subject: [PATCH 07/42] Improve logging in upgrade tests --- cmd/dendrite-upgrade-tests/main.go | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index dce22472d..6630a0620 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -381,18 +381,22 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string) lastErr = nil break } - if lastErr != nil { - logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{ - ShowStdout: true, - ShowStderr: true, - }) - // ignore errors when cannot get logs, it's just for debugging anyways - if err == nil { - logbody, err := io.ReadAll(logs) - if err == nil { - log.Printf("Container logs:\n\n%s\n\n", string(logbody)) + logs, err := dockerClient.ContainerLogs(context.Background(), containerID, types.ContainerLogsOptions{ + ShowStdout: true, + ShowStderr: true, + Follow: true, + }) + // ignore errors when cannot get logs, it's just for debugging anyways + if err == nil { + go func() { + for { + if body, err := io.ReadAll(logs); err == nil && len(body) > 0 { + log.Printf("%s: %s", version, string(body)) + } else { + return + } } - } + }() } return baseURL, containerID, lastErr } From efa50253f68f8ec1c6a9f4a8d82b329512027b00 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 10 Nov 2022 10:16:56 +0000 Subject: [PATCH 08/42] Fix lint error --- cmd/dendrite-upgrade-tests/main.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 6630a0620..131ce4af9 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -367,7 +367,8 @@ func runImage(dockerClient *client.Client, volumeName, version, imageID string) // hit /versions to check it is up var lastErr error for i := 0; i < 500; i++ { - res, err := http.Get(versionsURL) + var res *http.Response + res, err = http.Get(versionsURL) if err != nil { lastErr = fmt.Errorf("GET %s => error: %s", versionsURL, err) time.Sleep(50 * time.Millisecond) From 0193549201299f5dcce919b2aeb3b1c40bdfcefa Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 11 Nov 2022 10:35:17 +0100 Subject: [PATCH 09/42] Send presence to newly added servers (#2869) This should make `New federated private chats get full presence information (SYN-115)` happy. --- federationapi/consumers/roomserver.go | 120 ++++++++++++++++++++++---- federationapi/federationapi.go | 4 +- roomserver/api/api.go | 1 + 3 files changed, 107 insertions(+), 18 deletions(-) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index a42733628..d16af6626 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -18,6 +18,10 @@ import ( "context" "encoding/json" "fmt" + "strconv" + "time" + + syncAPITypes "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" @@ -34,14 +38,16 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { - ctx context.Context - cfg *config.FederationAPI - rsAPI api.FederationRoomserverAPI - jetstream nats.JetStreamContext - durable string - db storage.Database - queues *queue.OutgoingQueues - topic string + ctx context.Context + cfg *config.FederationAPI + rsAPI api.FederationRoomserverAPI + jetstream nats.JetStreamContext + natsClient *nats.Conn + durable string + db storage.Database + queues *queue.OutgoingQueues + topic string + topicPresence string } // NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers. @@ -49,19 +55,22 @@ func NewOutputRoomEventConsumer( process *process.ProcessContext, cfg *config.FederationAPI, js nats.JetStreamContext, + natsClient *nats.Conn, queues *queue.OutgoingQueues, store storage.Database, rsAPI api.FederationRoomserverAPI, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ - ctx: process.Context(), - cfg: cfg, - jetstream: js, - db: store, - queues: queues, - rsAPI: rsAPI, - durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), + ctx: process.Context(), + cfg: cfg, + jetstream: js, + natsClient: natsClient, + db: store, + queues: queues, + rsAPI: rsAPI, + durable: cfg.Matrix.JetStream.Durable("FederationAPIRoomServerConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputRoomEvent), + topicPresence: cfg.Matrix.JetStream.Prefixed(jetstream.RequestPresence), } } @@ -146,6 +155,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee // processMessage updates the list of currently joined hosts in the room // and then sends the event to the hosts that were joined before the event. func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rewritesState bool) error { + addsStateEvents, missingEventIDs := ore.NeededStateEventIDs() // Ask the roomserver and add in the rest of the results into the set. @@ -184,6 +194,14 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew return err } + // If we added new hosts, inform them about our known presence events for this room + if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { + membership, _ := ore.Event.Membership() + if membership == gomatrixserverlib.Join { + s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) + } + } + if oldJoinedHosts == nil { // This means that there is nothing to update as this is a duplicate // message. @@ -213,6 +231,76 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew ) } +func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) { + joined := make([]gomatrixserverlib.ServerName, len(addedJoined)) + for _, added := range addedJoined { + joined = append(joined, added.ServerName) + } + + // get our locally joined users + var queryRes api.QueryMembershipsForRoomResponse + err := s.rsAPI.QueryMembershipsForRoom(s.ctx, &api.QueryMembershipsForRoomRequest{ + JoinedOnly: true, + LocalOnly: true, + RoomID: roomID, + }, &queryRes) + if err != nil { + log.WithError(err).Error("failed to calculate joined rooms for user") + return + } + + // send every presence we know about to the remote server + content := types.Presence{} + for _, ev := range queryRes.JoinEvents { + msg := nats.NewMsg(s.topicPresence) + msg.Header.Set(jetstream.UserID, ev.Sender) + + var presence *nats.Msg + presence, err = s.natsClient.RequestMsg(msg, time.Second*10) + if err != nil { + log.WithError(err).Errorf("unable to get presence") + continue + } + + statusMsg := presence.Header.Get("status_msg") + e := presence.Header.Get("error") + if e != "" { + continue + } + var lastActive int + lastActive, err = strconv.Atoi(presence.Header.Get("last_active_ts")) + if err != nil { + continue + } + + p := syncAPITypes.PresenceInternal{LastActiveTS: gomatrixserverlib.Timestamp(lastActive)} + + content.Push = append(content.Push, types.PresenceContent{ + CurrentlyActive: p.CurrentlyActive(), + LastActiveAgo: p.LastActiveAgo(), + Presence: presence.Header.Get("presence"), + StatusMsg: &statusMsg, + UserID: ev.Sender, + }) + } + + if len(content.Push) == 0 { + return + } + + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MPresence, + Origin: string(s.cfg.Matrix.ServerName), + } + if edu.Content, err = json.Marshal(content); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return + } + if err := s.queues.SendEDU(edu, s.cfg.Matrix.ServerName, joined); err != nil { + log.WithError(err).Error("failed to send EDU") + } +} + // joinedHostsAtEvent works out a list of matrix servers that were joined to // the room at the event (including peeking ones) // It is important to use the state at the event for sending messages because: diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 202da6c51..e35b9c7f0 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -118,7 +118,7 @@ func NewInternalAPI( stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) - js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, @@ -132,7 +132,7 @@ func NewInternalAPI( ) rsConsumer := consumers.NewOutputRoomEventConsumer( - base.ProcessContext, cfg, js, queues, + base.ProcessContext, cfg, js, nats, queues, federationDB, rsAPI, ) if err = rsConsumer.Start(); err != nil { diff --git a/roomserver/api/api.go b/roomserver/api/api.go index a1373a62b..01e87ec8a 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -177,6 +177,7 @@ type FederationRoomserverAPI interface { QueryBulkStateContentAPI // QueryServerBannedFromRoom returns whether a server is banned from a room by server ACLs. QueryServerBannedFromRoom(ctx context.Context, req *QueryServerBannedFromRoomRequest, res *QueryServerBannedFromRoomResponse) error + QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error QueryRoomVersionForRoom(ctx context.Context, req *QueryRoomVersionForRoomRequest, res *QueryRoomVersionForRoomResponse) error GetRoomIDForAlias(ctx context.Context, req *GetRoomIDForAliasRequest, res *GetRoomIDForAliasResponse) error QueryEventsByID(ctx context.Context, req *QueryEventsByIDRequest, res *QueryEventsByIDResponse) error From d35a5642e89a2a1b64f1c2ed1cb13e6080987b1c Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 11 Nov 2022 10:52:08 +0100 Subject: [PATCH 10/42] Deny guest access on several endpoints (#2873) Second part for guest access, this adds a `WithAllowGuests()` option to `MakeAuthAPI`, allowing guests to access the specified endpoints. Endpoints taken from the [spec](https://spec.matrix.org/v1.4/client-server-api/#client-behaviour-14) and by checking Synapse endpoints for `allow_guest=true`. --- clientapi/routing/routing.go | 68 +++++++++++++++++------------------ internal/httputil/httpapi.go | 29 +++++++++++++++ setup/mscs/msc2946/msc2946.go | 2 +- syncapi/routing/routing.go | 18 +++++----- 4 files changed, 73 insertions(+), 44 deletions(-) diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index f35aa7e12..1b3ef120a 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -252,7 +252,7 @@ func Setup( return JoinRoomByIDOrAlias( req, device, rsAPI, userAPI, vars["roomIDOrAlias"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) if mscCfg.Enabled("msc2753") { @@ -274,7 +274,7 @@ func Setup( v3mux.Handle("/joined_rooms", httputil.MakeAuthAPI("joined_rooms", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetJoinedRooms(req, device, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/join", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -288,7 +288,7 @@ func Setup( return JoinRoomByIDOrAlias( req, device, rsAPI, userAPI, vars["roomID"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/leave", httputil.MakeAuthAPI("membership", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -302,7 +302,7 @@ func Setup( return LeaveRoomByID( req, device, rsAPI, vars["roomID"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/unpeek", httputil.MakeAuthAPI("unpeek", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -361,7 +361,7 @@ func Setup( return util.ErrorResponse(err) } return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, rsAPI, nil) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -372,7 +372,7 @@ func Setup( txnID := vars["txnID"] return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID, nil, cfg, rsAPI, transactionsCache) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -381,7 +381,7 @@ func Setup( return util.ErrorResponse(err) } return OnIncomingStateRequest(req.Context(), device, rsAPI, vars["roomID"]) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/aliases", httputil.MakeAuthAPI("aliases", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -400,7 +400,7 @@ func Setup( eventType := strings.TrimSuffix(vars["type"], "/") eventFormat := req.URL.Query().Get("format") == "event" return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], eventType, "", eventFormat) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", httputil.MakeAuthAPI("room_state", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -409,7 +409,7 @@ func Setup( } eventFormat := req.URL.Query().Get("format") == "event" return OnIncomingStateTypeRequest(req.Context(), device, rsAPI, vars["roomID"], vars["type"], vars["stateKey"], eventFormat) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}", httputil.MakeAuthAPI("send_message", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -420,7 +420,7 @@ func Setup( emptyString := "" eventType := strings.TrimSuffix(vars["eventType"], "/") return SendEvent(req, device, vars["roomID"], eventType, nil, &emptyString, cfg, rsAPI, nil) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}", @@ -431,7 +431,7 @@ func Setup( } stateKey := vars["stateKey"] return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, rsAPI, nil) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/register", httputil.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse { @@ -575,7 +575,7 @@ func Setup( } txnID := vars["txnID"] return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) // This is only here because sytest refers to /unstable for this endpoint @@ -589,7 +589,7 @@ func Setup( } txnID := vars["txnID"] return SendToDevice(req, device, syncProducer, transactionsCache, vars["eventType"], &txnID) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/account/whoami", @@ -598,7 +598,7 @@ func Setup( return *r } return Whoami(req, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/account/password", @@ -830,7 +830,7 @@ func Setup( return util.ErrorResponse(err) } return SetDisplayName(req, userAPI, device, vars["userID"], cfg, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) // Browsers use the OPTIONS HTTP method to check if the CORS policy allows // PUT requests, so we need to allow this method @@ -871,7 +871,7 @@ func Setup( v3mux.Handle("/thirdparty/protocols", httputil.MakeAuthAPI("thirdparty_protocols", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Protocols(req, asAPI, device, "") - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/protocol/{protocolID}", @@ -881,7 +881,7 @@ func Setup( return util.ErrorResponse(err) } return Protocols(req, asAPI, device, vars["protocolID"]) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/user/{protocolID}", @@ -891,13 +891,13 @@ func Setup( return util.ErrorResponse(err) } return User(req, asAPI, device, vars["protocolID"], req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/user", httputil.MakeAuthAPI("thirdparty_user", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return User(req, asAPI, device, "", req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/location/{protocolID}", @@ -907,13 +907,13 @@ func Setup( return util.ErrorResponse(err) } return Location(req, asAPI, device, vars["protocolID"], req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/thirdparty/location", httputil.MakeAuthAPI("thirdparty_location", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return Location(req, asAPI, device, "", req.URL.Query()) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/initialSync", @@ -1054,7 +1054,7 @@ func Setup( v3mux.Handle("/devices", httputil.MakeAuthAPI("get_devices", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return GetDevicesByLocalpart(req, userAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", @@ -1064,7 +1064,7 @@ func Setup( return util.ErrorResponse(err) } return GetDeviceByID(req, userAPI, device, vars["deviceID"]) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", @@ -1074,7 +1074,7 @@ func Setup( return util.ErrorResponse(err) } return UpdateDeviceByID(req, userAPI, device, vars["deviceID"]) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPut, http.MethodOptions) v3mux.Handle("/devices/{deviceID}", @@ -1116,21 +1116,21 @@ func Setup( // Stub implementations for sytest v3mux.Handle("/events", - httputil.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse { + httputil.MakeAuthAPI("events", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "chunk": []interface{}{}, "start": "", "end": "", }} - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/initialSync", - httputil.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse { + httputil.MakeAuthAPI("initial_sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{ "end": "", }} - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/rooms/{roomId}/tags", @@ -1169,7 +1169,7 @@ func Setup( return *r } return GetCapabilities(req, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) // Key Backup Versions (Metadata) @@ -1350,7 +1350,7 @@ func Setup( postDeviceSigningSignatures := httputil.MakeAuthAPI("post_device_signing_signatures", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadCrossSigningDeviceSignatures(req, keyAPI, device) - }) + }, httputil.WithAllowGuests()) v3mux.Handle("/keys/device_signing/upload", postDeviceSigningKeys).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/signatures/upload", postDeviceSigningSignatures).Methods(http.MethodPost, http.MethodOptions) @@ -1362,22 +1362,22 @@ func Setup( v3mux.Handle("/keys/upload/{deviceID}", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/upload", httputil.MakeAuthAPI("keys_upload", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return UploadKeys(req, keyAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/query", httputil.MakeAuthAPI("keys_query", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return QueryKeys(req, keyAPI, device) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/keys/claim", httputil.MakeAuthAPI("keys_claim", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return ClaimKeys(req, keyAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodPost, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/receipt/{receiptType}/{eventId}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 36dcaf453..4f33a3f79 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -42,10 +42,26 @@ type BasicAuth struct { Password string `yaml:"password"` } +type AuthAPIOpts struct { + GuestAccessAllowed bool +} + +// AuthAPIOption is an option to MakeAuthAPI to add additional checks (e.g. guest access) to verify +// the user is allowed to do specific things. +type AuthAPIOption func(opts *AuthAPIOpts) + +// WithAllowGuests checks that guest users have access to this endpoint +func WithAllowGuests() AuthAPIOption { + return func(opts *AuthAPIOpts) { + opts.GuestAccessAllowed = true + } +} + // MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request. func MakeAuthAPI( metricsName string, userAPI userapi.QueryAcccessTokenAPI, f func(*http.Request, *userapi.Device) util.JSONResponse, + checks ...AuthAPIOption, ) http.Handler { h := func(req *http.Request) util.JSONResponse { logger := util.GetLogger(req.Context()) @@ -76,6 +92,19 @@ func MakeAuthAPI( } }() + // apply additional checks, if any + opts := AuthAPIOpts{} + for _, opt := range checks { + opt(&opts) + } + + if !opts.GuestAccessAllowed && device.AccountType == userapi.AccountTypeGuest { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.GuestAccessForbidden("Guest access not allowed"), + } + } + jsonRes := f(req, device) // do not log 4xx as errors as they are client fails, not server fails if hub != nil && jsonRes.Code >= 500 { diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index bc9df0f96..d7c022544 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -57,7 +57,7 @@ func Enable( base *base.BaseDendrite, rsAPI roomserver.RoomserverInternalAPI, userAPI userapi.UserInternalAPI, fsAPI fs.FederationInternalAPI, keyRing gomatrixserverlib.JSONVerifier, cache caching.SpaceSummaryRoomsCache, ) error { - clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName)) + clientAPI := httputil.MakeAuthAPI("spaces", userAPI, spacesHandler(rsAPI, fsAPI, cache, base.Cfg.Global.ServerName), httputil.WithAllowGuests()) base.PublicClientAPIMux.Handle("/v1/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) base.PublicClientAPIMux.Handle("/unstable/org.matrix.msc2946/rooms/{roomID}/hierarchy", clientAPI).Methods(http.MethodGet, http.MethodOptions) diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index bc3ad2384..4cc1a6a85 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -51,7 +51,7 @@ func Setup( // TODO: Add AS support for all handlers below. v3mux.Handle("/sync", httputil.MakeAuthAPI("sync", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingSyncRequest(req, device) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) @@ -59,7 +59,7 @@ func Setup( return util.ErrorResponse(err) } return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], device, rsAPI, cfg, srp, lazyLoadCache) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/event/{eventID}", httputil.MakeAuthAPI("rooms_get_event", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -68,7 +68,7 @@ func Setup( return util.ErrorResponse(err) } return GetEvent(req, device, vars["roomID"], vars["eventID"], cfg, syncDB, rsAPI) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/user/{userId}/filter", @@ -93,7 +93,7 @@ func Setup( v3mux.Handle("/keys/changes", httputil.MakeAuthAPI("keys_changes", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return srp.OnIncomingKeyChangeRequest(req, device) - })).Methods(http.MethodGet, http.MethodOptions) + }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomId}/context/{eventId}", httputil.MakeAuthAPI(gomatrixserverlib.Join, userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { @@ -108,7 +108,7 @@ func Setup( vars["roomId"], vars["eventId"], lazyLoadCache, ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}", @@ -122,7 +122,7 @@ func Setup( req, device, syncDB, rsAPI, vars["roomId"], vars["eventId"], "", "", ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}", @@ -136,7 +136,7 @@ func Setup( req, device, syncDB, rsAPI, vars["roomId"], vars["eventId"], vars["relType"], "", ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v1unstablemux.Handle("/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}", @@ -150,7 +150,7 @@ func Setup( req, device, syncDB, rsAPI, vars["roomId"], vars["eventId"], vars["relType"], vars["eventType"], ) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/search", @@ -191,7 +191,7 @@ func Setup( at := req.URL.Query().Get("at") return GetMemberships(req, device, vars["roomID"], syncDB, rsAPI, false, membership, notMembership, at) - }), + }, httputil.WithAllowGuests()), ).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/joined_members", From c648c671a326f2d626cf34db52cbcc9999b95bba Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 11 Nov 2022 10:52:43 +0100 Subject: [PATCH 11/42] Fix issue with missing user NIDs (#2874) This should fix #2696 and possibly other related issues regarding missing user NIDs. (https://github.com/matrix-org/dendrite/issues/2094?) --- roomserver/storage/shared/storage.go | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 4455ec3bf..958787070 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -110,6 +110,19 @@ func (d *Database) eventStateKeyNIDs( for eventStateKey, nid := range nids { result[eventStateKey] = nid } + // We received some nids, but are still missing some, work out which and create them + if len(eventStateKeys) < len(result) { + for _, eventStateKey := range eventStateKeys { + if _, ok := result[eventStateKey]; ok { + continue + } + nid, err := d.assignStateKeyNID(ctx, txn, eventStateKey) + if err != nil { + return result, err + } + result[eventStateKey] = nid + } + } return result, nil } @@ -1243,7 +1256,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu } - eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys) + eventStateKeyNIDMap, err := d.eventStateKeyNIDs(ctx, nil, eventStateKeys) if err != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err) } @@ -1309,7 +1322,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ if err != nil { return nil, err } - userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs) + userNIDsMap, err := d.eventStateKeyNIDs(ctx, nil, userIDs) if err != nil { return nil, err } From 72ce6acf7176ac9e597e4fd6587605199e9e0d7f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 11 Nov 2022 11:21:16 +0000 Subject: [PATCH 12/42] Run upgrade tests for SQLite too (#2875) This should hopefully catch problems with database migrations in SQLite as well as PostgreSQL. --- .github/workflows/dendrite.yml | 10 ++++--- cmd/dendrite-upgrade-tests/main.go | 43 +++++++++++++++++++++++++++--- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index a37e45e46..f96cbadd4 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -231,8 +231,10 @@ jobs: ${{ runner.os }}-go-upgrade - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests - - name: Test upgrade + - name: Test upgrade (PostgreSQL) run: ./dendrite-upgrade-tests --head . + - name: Test upgrade (SQLite) + run: ./dendrite-upgrade-tests --sqlite --head . # run database upgrade tests, skipping over one version upgrade_test_direct: @@ -256,7 +258,9 @@ jobs: ${{ runner.os }}-go-upgrade - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests - - name: Test upgrade + - name: Test upgrade (PostgreSQL) + run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head . + - name: Test upgrade (SQLite) run: ./dendrite-upgrade-tests -direct -from HEAD-2 --head . # run Sytest in different variations @@ -434,7 +438,7 @@ jobs: permissions: packages: write contents: read - security-events: write # To upload Trivy sarif files + security-events: write # To upload Trivy sarif files if: github.repository == 'matrix-org/dendrite' && github.ref_name == 'main' needs: [integration-tests-done] uses: matrix-org/dendrite/.github/workflows/docker.yml@main diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 131ce4af9..75446d18c 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -38,6 +38,7 @@ var ( flagHead = flag.String("head", "", "Location to a dendrite repository to treat as HEAD instead of Github") flagDockerHost = flag.String("docker-host", "localhost", "The hostname of the docker client. 'localhost' if running locally, 'host.docker.internal' if running in Docker.") flagDirect = flag.Bool("direct", false, "If a direct upgrade from the defined FROM version to TO should be done") + flagSqlite = flag.Bool("sqlite", false, "Test SQLite instead of PostgreSQL") alphaNumerics = regexp.MustCompile("[^a-zA-Z0-9]+") ) @@ -49,7 +50,7 @@ const HEAD = "HEAD" // due to the error: // When using COPY with more than one source file, the destination must be a directory and end with a / // We need to run a postgres anyway, so use the dockerfile associated with Complement instead. -const Dockerfile = `FROM golang:1.18-stretch as build +const DockerfilePostgreSQL = `FROM golang:1.18-stretch as build RUN apt-get update && apt-get install -y postgresql WORKDIR /build @@ -92,6 +93,42 @@ ENV SERVER_NAME=localhost EXPOSE 8008 8448 CMD /build/run_dendrite.sh ` +const DockerfileSQLite = `FROM golang:1.18-stretch as build +RUN apt-get update && apt-get install -y postgresql +WORKDIR /build + +# Copy the build context to the repo as this is the right dendrite code. This is different to the +# Complement Dockerfile which wgets a branch. +COPY . . + +RUN go build ./cmd/dendrite-monolith-server +RUN go build ./cmd/generate-keys +RUN go build ./cmd/generate-config +RUN ./generate-config --ci > dendrite.yaml +RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key + +# Make sure the SQLite databases are in a persistent location, we're already mapping +# the postgresql folder so let's just use that for simplicity +RUN sed -i "s%connection_string:.file:%connection_string: file:\/var\/lib\/postgresql\/9.6\/main\/%g" dendrite.yaml + +# This entry script starts postgres, waits for it to be up then starts dendrite +RUN echo '\ +sed -i "s/server_name: localhost/server_name: ${SERVER_NAME}/g" dendrite.yaml \n\ +PARAMS="--tls-cert server.crt --tls-key server.key --config dendrite.yaml" \n\ +./dendrite-monolith-server --really-enable-open-registration ${PARAMS} || ./dendrite-monolith-server ${PARAMS} \n\ +' > run_dendrite.sh && chmod +x run_dendrite.sh + +ENV SERVER_NAME=localhost +EXPOSE 8008 8448 +CMD /build/run_dendrite.sh ` + +func dockerfile() []byte { + if *flagSqlite { + return []byte(DockerfileSQLite) + } + return []byte(DockerfilePostgreSQL) +} + const dendriteUpgradeTestLabel = "dendrite_upgrade_test" // downloadArchive downloads an arbitrary github archive of the form: @@ -150,7 +187,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, if branchOrTagName == HEAD && *flagHead != "" { log.Printf("%s: Using %s as HEAD", branchOrTagName, *flagHead) // add top level Dockerfile - err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), []byte(Dockerfile), os.ModePerm) + err = os.WriteFile(path.Join(*flagHead, "Dockerfile"), dockerfile(), os.ModePerm) if err != nil { return "", fmt.Errorf("custom HEAD: failed to inject /Dockerfile: %w", err) } @@ -166,7 +203,7 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, // pull an archive, this contains a top-level directory which screws with the build context // which we need to fix up post download u := fmt.Sprintf("https://github.com/matrix-org/dendrite/archive/%s.tar.gz", branchOrTagName) - tarball, err = downloadArchive(httpClient, tmpDir, u, []byte(Dockerfile)) + tarball, err = downloadArchive(httpClient, tmpDir, u, dockerfile()) if err != nil { return "", fmt.Errorf("failed to download archive %s: %w", u, err) } From e177e0ae73d7cc34ffb9869681a6bf177f805205 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 11 Nov 2022 16:44:59 +0100 Subject: [PATCH 13/42] Fix oops, add simple UT --- roomserver/internal/helpers/helpers_test.go | 56 +++++++++++++++++++++ roomserver/storage/shared/storage.go | 2 +- 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 roomserver/internal/helpers/helpers_test.go diff --git a/roomserver/internal/helpers/helpers_test.go b/roomserver/internal/helpers/helpers_test.go new file mode 100644 index 000000000..aa5c30e44 --- /dev/null +++ b/roomserver/internal/helpers/helpers_test.go @@ -0,0 +1,56 @@ +package helpers + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/roomserver/types" + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/setup/base" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + + "github.com/matrix-org/dendrite/roomserver/storage" +) + +func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { + base, close := testrig.CreateBaseDendrite(t, dbType) + db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches) + if err != nil { + t.Fatalf("failed to create Database: %v", err) + } + return base, db, close +} + +func TestIsInvitePendingWithoutNID(t *testing.T) { + + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat)) + _ = bob + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + _, db, close := mustCreateDatabase(t, dbType) + defer close() + + // store all events + var authNIDs []types.EventNID + for _, x := range room.Events() { + + evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false) + assert.NoError(t, err) + authNIDs = append(authNIDs, evNID) + } + + // Alice should have no pending invites and should have a NID + pendingInvite, _, _, _, err := IsInvitePending(context.Background(), db, room.ID, alice.ID) + assert.NoError(t, err, "failed to get pending invites") + assert.False(t, pendingInvite, "unexpected pending invite") + + // Bob should have no pending invites and receive a new NID + pendingInvite, _, _, _, err = IsInvitePending(context.Background(), db, room.ID, bob.ID) + assert.NoError(t, err, "failed to get pending invites") + assert.False(t, pendingInvite, "unexpected pending invite") + }) +} diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 958787070..08e912a00 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -111,7 +111,7 @@ func (d *Database) eventStateKeyNIDs( result[eventStateKey] = nid } // We received some nids, but are still missing some, work out which and create them - if len(eventStateKeys) < len(result) { + if len(eventStateKeys) > len(result) { for _, eventStateKey := range eventStateKeys { if _, ok := result[eventStateKey]; ok { continue From 529df30b5649e67a2f98114e6640d259cba53566 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 11 Nov 2022 16:41:37 +0000 Subject: [PATCH 14/42] Virtual hosting schema and logic changes (#2876) Note that virtual users cannot federate correctly yet. --- appservice/appservice.go | 6 +- clientapi/auth/password.go | 21 +- clientapi/routing/admin.go | 5 + clientapi/routing/login.go | 1 + clientapi/routing/notification.go | 11 +- clientapi/routing/password.go | 12 +- clientapi/routing/pusher.go | 8 +- clientapi/routing/register.go | 10 +- clientapi/routing/routing.go | 2 +- clientapi/routing/server_notices.go | 8 +- clientapi/routing/threepid.go | 14 +- federationapi/federationapi.go | 18 +- federationapi/federationapi_keys_test.go | 2 +- federationapi/queue/destinationqueue.go | 2 +- federationapi/queue/queue.go | 26 +-- federationapi/queue/queue_test.go | 10 +- federationapi/routing/keys.go | 24 ++- federationapi/routing/routing.go | 2 +- keyserver/internal/internal.go | 30 +-- keyserver/keyserver.go | 8 +- userapi/api/api.go | 55 +++-- userapi/api/api_trace.go | 4 +- userapi/consumers/clientapi.go | 4 +- userapi/consumers/roomserver.go | 58 ++--- userapi/internal/api.go | 101 +++++---- userapi/internal/api_logintoken.go | 2 +- userapi/inthttp/client.go | 3 +- userapi/inthttp/server.go | 15 +- userapi/producers/syncapi.go | 4 +- userapi/storage/interface.go | 66 +++--- .../storage/postgres/account_data_table.go | 33 +-- userapi/storage/postgres/accounts_table.go | 57 ++--- .../deltas/2022110411000000_server_names.go | 81 +++++++ .../deltas/2022110411000001_server_names.go | 28 +++ userapi/storage/postgres/devices_table.go | 88 ++++---- .../storage/postgres/notifications_table.go | 49 ++--- userapi/storage/postgres/openid_table.go | 15 +- userapi/storage/postgres/profile_table.go | 48 +++-- userapi/storage/postgres/pusher_table.go | 32 +-- userapi/storage/postgres/storage.go | 28 ++- userapi/storage/postgres/threepid_table.go | 26 ++- userapi/storage/shared/storage.go | 204 ++++++++++-------- userapi/storage/sqlite3/account_data_table.go | 33 +-- userapi/storage/sqlite3/accounts_table.go | 53 ++--- .../deltas/20200929203058_is_active.go | 1 + .../deltas/20201001204705_last_seen_ts_ip.go | 1 + .../2022021012490600_add_account_type.go | 1 + .../deltas/2022110411000000_server_names.go | 108 ++++++++++ .../deltas/2022110411000001_server_names.go | 28 +++ userapi/storage/sqlite3/devices_table.go | 92 ++++---- .../storage/sqlite3/notifications_table.go | 49 ++--- userapi/storage/sqlite3/openid_table.go | 15 +- userapi/storage/sqlite3/profile_table.go | 52 +++-- userapi/storage/sqlite3/pusher_table.go | 32 +-- userapi/storage/sqlite3/storage.go | 28 ++- userapi/storage/sqlite3/threepid_table.go | 26 ++- userapi/storage/storage_test.go | 132 ++++++------ userapi/storage/tables/interface.go | 68 +++--- userapi/storage/tables/stats_table_test.go | 17 +- userapi/userapi_test.go | 10 +- userapi/util/devices.go | 8 +- userapi/util/notify.go | 7 +- 62 files changed, 1250 insertions(+), 732 deletions(-) create mode 100644 userapi/storage/postgres/deltas/2022110411000000_server_names.go create mode 100644 userapi/storage/postgres/deltas/2022110411000001_server_names.go create mode 100644 userapi/storage/sqlite3/deltas/2022110411000000_server_names.go create mode 100644 userapi/storage/sqlite3/deltas/2022110411000001_server_names.go diff --git a/appservice/appservice.go b/appservice/appservice.go index 0c778b6ca..b3c28dbde 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -32,6 +32,7 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" ) // AddInternalRoutes registers HTTP handlers for internal API calls @@ -74,7 +75,7 @@ func NewInternalAPI( // events to be sent out. for _, appservice := range base.Cfg.Derived.ApplicationServices { // Create bot account for this AS if it doesn't already exist - if err := generateAppServiceAccount(userAPI, appservice); err != nil { + if err := generateAppServiceAccount(userAPI, appservice, base.Cfg.Global.ServerName); err != nil { logrus.WithFields(logrus.Fields{ "appservice": appservice.ID, }).WithError(err).Panicf("failed to generate bot account for appservice") @@ -101,11 +102,13 @@ func NewInternalAPI( func generateAppServiceAccount( userAPI userapi.AppserviceUserAPI, as config.ApplicationService, + serverName gomatrixserverlib.ServerName, ) error { var accRes userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(context.Background(), &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeAppService, Localpart: as.SenderLocalpart, + ServerName: serverName, AppServiceID: as.ID, OnConflict: userapi.ConflictUpdate, }, &accRes) @@ -115,6 +118,7 @@ func generateAppServiceAccount( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(context.Background(), &userapi.PerformDeviceCreationRequest{ Localpart: as.SenderLocalpart, + ServerName: serverName, AccessToken: as.ASToken, DeviceID: &as.SenderLocalpart, DeviceDisplayName: &as.SenderLocalpart, diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 700a72f5d..4de2b443c 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -61,7 +61,7 @@ func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { r := req.(*PasswordRequest) - username := strings.ToLower(r.Username()) + username := r.Username() if username == "" { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, @@ -74,32 +74,43 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, JSON: jsonerror.BadJSON("A password must be supplied."), } } - localpart, _, err := userutil.ParseUsernameParam(username, t.Config.Matrix) + localpart, domain, err := userutil.ParseUsernameParam(username, t.Config.Matrix) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusUnauthorized, JSON: jsonerror.InvalidUsername(err.Error()), } } + if !t.Config.Matrix.IsLocalServerName(domain) { + return nil, &util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.InvalidUsername("The server name is not known."), + } + } // Squash username to all lowercase letters res := &api.QueryAccountByPasswordResponse{} - err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{Localpart: strings.ToLower(localpart), PlaintextPassword: r.Password}, res) + err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ + Localpart: strings.ToLower(localpart), + ServerName: domain, + PlaintextPassword: r.Password, + }, res) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("unable to fetch account by password"), + JSON: jsonerror.Unknown("Unable to fetch account by password."), } } if !res.Exists { err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ Localpart: localpart, + ServerName: domain, PlaintextPassword: r.Password, }, res) if err != nil { return nil, &util.JSONResponse{ Code: http.StatusInternalServerError, - JSON: jsonerror.Unknown("unable to fetch account by password"), + JSON: jsonerror.Unknown("Unable to fetch account by password."), } } // Technically we could tell them if the user does not exist by checking if err == sql.ErrNoRows diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 9088f7716..9ed1f0ca2 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -102,6 +102,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap if err != nil { return util.ErrorResponse(err) } + serverName := cfg.Matrix.ServerName localpart, ok := vars["localpart"] if !ok { return util.JSONResponse{ @@ -109,6 +110,9 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.MissingArgument("Expecting user localpart."), } } + if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil { + localpart, serverName = l, s + } request := struct { Password string `json:"password"` }{} @@ -126,6 +130,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } updateReq := &userapi.PerformPasswordUpdateRequest{ Localpart: localpart, + ServerName: serverName, Password: request.Password, LogoutDevices: true, } diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 7f5a8c4f8..0de324da1 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -100,6 +100,7 @@ func completeAuth( DeviceID: login.DeviceID, AccessToken: token, Localpart: localpart, + ServerName: serverName, IPAddr: ipAddr, UserAgent: userAgent, }, &performRes) diff --git a/clientapi/routing/notification.go b/clientapi/routing/notification.go index 8a424a141..f593e27db 100644 --- a/clientapi/routing/notification.go +++ b/clientapi/routing/notification.go @@ -40,16 +40,17 @@ func GetNotifications( } var queryRes userapi.QueryNotificationsResponse - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") return jsonerror.InternalServerError() } err = userAPI.QueryNotifications(req.Context(), &userapi.QueryNotificationsRequest{ - Localpart: localpart, - From: req.URL.Query().Get("from"), - Limit: int(limit), - Only: req.URL.Query().Get("only"), + Localpart: localpart, + ServerName: domain, + From: req.URL.Query().Get("from"), + Limit: int(limit), + Only: req.URL.Query().Get("only"), }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryNotifications failed") diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 6dc9af508..9772f669a 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -86,7 +86,7 @@ func Password( } // Get the local part. - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() @@ -94,8 +94,9 @@ func Password( // Ask the user API to perform the password change. passwordReq := &api.PerformPasswordUpdateRequest{ - Localpart: localpart, - Password: r.NewPassword, + Localpart: localpart, + ServerName: domain, + Password: r.NewPassword, } passwordRes := &api.PerformPasswordUpdateResponse{} if err := userAPI.PerformPasswordUpdate(req.Context(), passwordReq, passwordRes); err != nil { @@ -122,8 +123,9 @@ func Password( } pushersReq := &api.PerformPusherDeletionRequest{ - Localpart: localpart, - SessionID: device.SessionID, + Localpart: localpart, + ServerName: domain, + SessionID: device.SessionID, } if err := userAPI.PerformPusherDeletion(req.Context(), pushersReq, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("PerformPusherDeletion failed") diff --git a/clientapi/routing/pusher.go b/clientapi/routing/pusher.go index d6a6eb936..89ec824bf 100644 --- a/clientapi/routing/pusher.go +++ b/clientapi/routing/pusher.go @@ -31,13 +31,14 @@ func GetPushers( userAPI userapi.ClientUserAPI, ) util.JSONResponse { var queryRes userapi.QueryPushersResponse - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") return jsonerror.InternalServerError() } err = userAPI.QueryPushers(req.Context(), &userapi.QueryPushersRequest{ - Localpart: localpart, + Localpart: localpart, + ServerName: domain, }, &queryRes) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("QueryPushers failed") @@ -59,7 +60,7 @@ func SetPusher( req *http.Request, device *userapi.Device, userAPI userapi.ClientUserAPI, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("SplitID failed") return jsonerror.InternalServerError() @@ -93,6 +94,7 @@ func SetPusher( } body.Localpart = localpart + body.ServerName = domain body.SessionID = device.SessionID err = userAPI.PerformPusherSet(req.Context(), &body, &struct{}{}) if err != nil { diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index b9ebb0518..9dc63af84 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -588,12 +588,15 @@ func Register( } // Auto generate a numeric username if r.Username is empty if r.Username == "" { - res := &userapi.QueryNumericLocalpartResponse{} - if err := userAPI.QueryNumericLocalpart(req.Context(), res); err != nil { + nreq := &userapi.QueryNumericLocalpartRequest{ + ServerName: cfg.Matrix.ServerName, // TODO: might not be right + } + nres := &userapi.QueryNumericLocalpartResponse{} + if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") return jsonerror.InternalServerError() } - r.Username = strconv.FormatInt(res.ID, 10) + r.Username = strconv.FormatInt(nres.ID, 10) } // Is this an appservice registration? It will be if the access @@ -676,6 +679,7 @@ func handleGuestRegistration( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(req.Context(), &userapi.PerformDeviceCreationRequest{ Localpart: res.Account.Localpart, + ServerName: res.Account.ServerName, DeviceDisplayName: r.InitialDisplayName, AccessToken: token, IPAddr: req.RemoteAddr, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 1b3ef120a..a510761eb 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -157,7 +157,7 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) - dendriteAdminRouter.Handle("/admin/resetPassword/{localpart}", + dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminResetPassword(req, cfg, device, userAPI) }), diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index a6a78061d..a7acee326 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -286,6 +286,7 @@ func getSenderDevice( err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeUser, Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, OnConflict: userapi.ConflictUpdate, }, &accRes) if err != nil { @@ -295,8 +296,9 @@ func getSenderDevice( // Set the avatarurl for the user avatarRes := &userapi.PerformSetAvatarURLResponse{} if err = userAPI.SetAvatarURL(ctx, &userapi.PerformSetAvatarURLRequest{ - Localpart: cfg.Matrix.ServerNotices.LocalPart, - AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, + Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, + AvatarURL: cfg.Matrix.ServerNotices.AvatarURL, }, avatarRes); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetAvatarURL failed") return nil, err @@ -308,6 +310,7 @@ func getSenderDevice( displayNameRes := &userapi.PerformUpdateDisplayNameResponse{} if err = userAPI.SetDisplayName(ctx, &userapi.PerformUpdateDisplayNameRequest{ Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, DisplayName: cfg.Matrix.ServerNotices.DisplayName, }, displayNameRes); err != nil { util.GetLogger(ctx).WithError(err).Error("userAPI.SetDisplayName failed") @@ -353,6 +356,7 @@ func getSenderDevice( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ Localpart: cfg.Matrix.ServerNotices.LocalPart, + ServerName: cfg.Matrix.ServerName, DeviceDisplayName: &cfg.Matrix.ServerNotices.LocalPart, AccessToken: token, NoDeviceListUpdate: true, diff --git a/clientapi/routing/threepid.go b/clientapi/routing/threepid.go index 4b7989ecb..971bfcad3 100644 --- a/clientapi/routing/threepid.go +++ b/clientapi/routing/threepid.go @@ -136,16 +136,17 @@ func CheckAndSave3PIDAssociation( } // Save the association in the database - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() } if err = threePIDAPI.PerformSaveThreePIDAssociation(req.Context(), &api.PerformSaveThreePIDAssociationRequest{ - ThreePID: address, - Localpart: localpart, - Medium: medium, + ThreePID: address, + Localpart: localpart, + ServerName: domain, + Medium: medium, }, &struct{}{}); err != nil { util.GetLogger(req.Context()).WithError(err).Error("threePIDAPI.PerformSaveThreePIDAssociation failed") return jsonerror.InternalServerError() @@ -161,7 +162,7 @@ func CheckAndSave3PIDAssociation( func GetAssociated3PIDs( req *http.Request, threepidAPI api.ClientUserAPI, device *api.Device, ) util.JSONResponse { - localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', device.UserID) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") return jsonerror.InternalServerError() @@ -169,7 +170,8 @@ func GetAssociated3PIDs( res := &api.QueryThreePIDsForLocalpartResponse{} err = threepidAPI.QueryThreePIDsForLocalpart(req.Context(), &api.QueryThreePIDsForLocalpartRequest{ - Localpart: localpart, + Localpart: localpart, + ServerName: domain, }, res) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("threepidAPI.QueryThreePIDsForLocalpart failed") diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index e35b9c7f0..4578e33aa 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -120,15 +120,23 @@ func NewInternalAPI( js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) + signingInfo := map[gomatrixserverlib.ServerName]*queue.SigningInfo{} + for _, serverName := range append( + []gomatrixserverlib.ServerName{base.Cfg.Global.ServerName}, + base.Cfg.Global.SecondaryServerNames..., + ) { + signingInfo[serverName] = &queue.SigningInfo{ + KeyID: cfg.Matrix.KeyID, + PrivateKey: cfg.Matrix.PrivateKey, + ServerName: serverName, + } + } + queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, cfg.Matrix.DisableFederation, cfg.Matrix.ServerName, federation, rsAPI, &stats, - &queue.SigningInfo{ - KeyID: cfg.Matrix.KeyID, - PrivateKey: cfg.Matrix.PrivateKey, - ServerName: cfg.Matrix.ServerName, - }, + signingInfo, ) rsConsumer := consumers.NewOutputRoomEventConsumer( diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 7ccc02f76..3acaa70dc 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -137,7 +137,7 @@ func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err } // Get the keys and JSON-ify them. - keys := routing.LocalKeys(s.config) + keys := routing.LocalKeys(s.config, gomatrixserverlib.ServerName(req.Host)) body, err := json.MarshalIndent(keys.JSON, "", " ") if err != nil { return nil, err diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index a638a5742..bf04ee99a 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -50,7 +50,7 @@ type destinationQueue struct { queues *OutgoingQueues db storage.Database process *process.ProcessContext - signing *SigningInfo + signing map[gomatrixserverlib.ServerName]*SigningInfo rsAPI api.FederationRoomserverAPI client fedapi.FederationClient // federation client origin gomatrixserverlib.ServerName // origin of requests diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index b5d0552c6..68f354993 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -46,7 +46,7 @@ type OutgoingQueues struct { origin gomatrixserverlib.ServerName client fedapi.FederationClient statistics *statistics.Statistics - signing *SigningInfo + signing map[gomatrixserverlib.ServerName]*SigningInfo queuesMutex sync.Mutex // protects the below queues map[gomatrixserverlib.ServerName]*destinationQueue } @@ -91,7 +91,7 @@ func NewOutgoingQueues( client fedapi.FederationClient, rsAPI api.FederationRoomserverAPI, statistics *statistics.Statistics, - signing *SigningInfo, + signing map[gomatrixserverlib.ServerName]*SigningInfo, ) *OutgoingQueues { queues := &OutgoingQueues{ disabled: disabled, @@ -199,11 +199,10 @@ func (oqs *OutgoingQueues) SendEvent( log.Trace("Federation is disabled, not sending event") return nil } - if origin != oqs.origin { - // TODO: Support virtual hosting; gh issue #577. + if _, ok := oqs.signing[origin]; !ok { return fmt.Errorf( - "sendevent: unexpected server to send as: got %q expected %q", - origin, oqs.origin, + "sendevent: unexpected server to send as %q", + origin, ) } @@ -214,7 +213,9 @@ func (oqs *OutgoingQueues) SendEvent( destmap[d] = struct{}{} } delete(destmap, oqs.origin) - delete(destmap, oqs.signing.ServerName) + for local := range oqs.signing { + delete(destmap, local) + } // Check if any of the destinations are prohibited by server ACLs. for destination := range destmap { @@ -288,11 +289,10 @@ func (oqs *OutgoingQueues) SendEDU( log.Trace("Federation is disabled, not sending EDU") return nil } - if origin != oqs.origin { - // TODO: Support virtual hosting; gh issue #577. + if _, ok := oqs.signing[origin]; !ok { return fmt.Errorf( - "sendevent: unexpected server to send as: got %q expected %q", - origin, oqs.origin, + "sendevent: unexpected server to send as %q", + origin, ) } @@ -303,7 +303,9 @@ func (oqs *OutgoingQueues) SendEDU( destmap[d] = struct{}{} } delete(destmap, oqs.origin) - delete(destmap, oqs.signing.ServerName) + for local := range oqs.signing { + delete(destmap, local) + } // There is absolutely no guarantee that the EDU will have a room_id // field, as it is not required by the spec. However, if it *does* diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 7ef4646f7..58745c607 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -350,10 +350,12 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T } rs := &stubFederationRoomServerAPI{} stats := statistics.NewStatistics(db, failuresUntilBlacklist) - signingInfo := &SigningInfo{ - KeyID: "ed21019:auto", - PrivateKey: test.PrivateKeyA, - ServerName: "localhost", + signingInfo := map[gomatrixserverlib.ServerName]*SigningInfo{ + "localhost": { + KeyID: "ed21019:auto", + PrivateKey: test.PrivateKeyA, + ServerName: "localhost", + }, } queues := NewOutgoingQueues(db, processContext, false, "localhost", fc, rs, &stats, signingInfo) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 8931830f3..5650e3d53 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -16,6 +16,7 @@ package routing import ( "encoding/json" + "fmt" "net/http" "time" @@ -134,18 +135,21 @@ func ClaimOneTimeKeys( // LocalKeys returns the local keys for the server. // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys -func LocalKeys(cfg *config.FederationAPI) util.JSONResponse { - keys, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) +func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse { + keys, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) if err != nil { return util.ErrorResponse(err) } return util.JSONResponse{Code: http.StatusOK, JSON: keys} } -func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { +func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys + if !cfg.Matrix.IsLocalServerName(serverName) { + return nil, fmt.Errorf("server name not known") + } - keys.ServerName = cfg.Matrix.ServerName + keys.ServerName = serverName keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil) publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) @@ -172,7 +176,7 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver } keys.Raw, err = gomatrixserverlib.SignJSON( - string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, + string(serverName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, ) if err != nil { return nil, err @@ -186,6 +190,14 @@ func NotaryKeys( fsAPI federationAPI.FederationInternalAPI, req *gomatrixserverlib.PublicKeyNotaryLookupRequest, ) util.JSONResponse { + serverName := gomatrixserverlib.ServerName(httpReq.Host) // TODO: this is not ideal + if !cfg.Matrix.IsLocalServerName(serverName) { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("Server name not known"), + } + } + if req == nil { req = &gomatrixserverlib.PublicKeyNotaryLookupRequest{} if reqErr := clienthttputil.UnmarshalJSONRequest(httpReq, &req); reqErr != nil { @@ -201,7 +213,7 @@ func NotaryKeys( for serverName, kidToCriteria := range req.ServerKeys { var keyList []gomatrixserverlib.ServerKeys if serverName == cfg.Matrix.ServerName { - if k, err := localKeys(cfg, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { + if k, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { keyList = append(keyList, *k) } else { return util.ErrorResponse(err) diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 9f16e5093..0a3ab7a88 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -74,7 +74,7 @@ func Setup( } localKeys := httputil.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse { - return LocalKeys(cfg) + return LocalKeys(cfg, gomatrixserverlib.ServerName(req.Host)) }) notaryKeys := httputil.MakeExternalAPI("notarykeys", func(req *http.Request) util.JSONResponse { diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 92ee80d81..37c55c8f8 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -33,16 +33,17 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/producers" "github.com/matrix-org/dendrite/keyserver/storage" + "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" ) type KeyInternalAPI struct { - DB storage.Database - ThisServer gomatrixserverlib.ServerName - FedClient fedsenderapi.KeyserverFederationAPI - UserAPI userapi.KeyserverUserAPI - Producer *producers.KeyChange - Updater *DeviceListUpdater + DB storage.Database + Cfg *config.KeyServer + FedClient fedsenderapi.KeyserverFederationAPI + UserAPI userapi.KeyserverUserAPI + Producer *producers.KeyChange + Updater *DeviceListUpdater } func (a *KeyInternalAPI) SetUserAPI(i userapi.KeyserverUserAPI) { @@ -95,8 +96,11 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC nested[userID] = val domainToDeviceKeys[string(serverName)] = nested } - // claim local keys - if local, ok := domainToDeviceKeys[string(a.ThisServer)]; ok { + for domain, local := range domainToDeviceKeys { + if !a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { + continue + } + // claim local keys keys, err := a.DB.ClaimKeys(ctx, local) if err != nil { res.Error = &api.KeyError{ @@ -117,7 +121,7 @@ func (a *KeyInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformC res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON } } - delete(domainToDeviceKeys, string(a.ThisServer)) + delete(domainToDeviceKeys, domain) } if len(domainToDeviceKeys) > 0 { a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) @@ -258,7 +262,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } domain := string(serverName) // query local devices - if serverName == a.ThisServer { + if a.Cfg.Matrix.IsLocalServerName(serverName) { deviceKeys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs, false) if err != nil { res.Error = &api.KeyError{ @@ -437,13 +441,13 @@ func (a *KeyInternalAPI) queryRemoteKeys( domains := map[string]struct{}{} for domain := range domainToDeviceKeys { - if domain == string(a.ThisServer) { + if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} } for domain := range domainToCrossSigningKeys { - if domain == string(a.ThisServer) { + if a.Cfg.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) { continue } domains[domain] = struct{}{} @@ -689,7 +693,7 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per if err != nil { continue // ignore invalid users } - if serverName != a.ThisServer { + if !a.Cfg.Matrix.IsLocalServerName(serverName) { continue // ignore remote users } if len(key.KeyJSON) == 0 { diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 9ae4f9ca3..0a4b8fde6 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -53,10 +53,10 @@ func NewInternalAPI( DB: db, } ap := &internal.KeyInternalAPI{ - DB: db, - ThisServer: cfg.Matrix.ServerName, - FedClient: fedClient, - Producer: keyChangeProducer, + DB: db, + Cfg: cfg, + FedClient: fedClient, + Producer: keyChangeProducer, } updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable ap.Updater = updater diff --git a/userapi/api/api.go b/userapi/api/api.go index 8d7f783de..d3f5aefc8 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -78,7 +78,7 @@ type ClientUserAPI interface { QueryAcccessTokenAPI LoginTokenInternalAPI UserLoginAPI - QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error + QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error @@ -335,9 +335,10 @@ type PerformAccountCreationResponse struct { // PerformAccountCreationRequest is the request for PerformAccountCreation type PerformPasswordUpdateRequest struct { - Localpart string // Required: The localpart for this account. - Password string // Required: The new password to set. - LogoutDevices bool // Optional: Whether to log out all user devices. + Localpart string // Required: The localpart for this account. + ServerName gomatrixserverlib.ServerName // Required: The domain for this account. + Password string // Required: The new password to set. + LogoutDevices bool // Optional: Whether to log out all user devices. } // PerformAccountCreationResponse is the response for PerformAccountCreation @@ -518,7 +519,8 @@ const ( ) type QueryPushersRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryPushersResponse struct { @@ -526,14 +528,16 @@ type QueryPushersResponse struct { } type PerformPusherSetRequest struct { - Pusher // Anonymous field because that's how clientapi unmarshals it. - Localpart string - Append bool `json:"append"` + Pusher // Anonymous field because that's how clientapi unmarshals it. + Localpart string + ServerName gomatrixserverlib.ServerName + Append bool `json:"append"` } type PerformPusherDeletionRequest struct { - Localpart string - SessionID int64 + Localpart string + ServerName gomatrixserverlib.ServerName + SessionID int64 } // Pusher represents a push notification subscriber @@ -571,10 +575,11 @@ type QueryPushRulesResponse struct { } type QueryNotificationsRequest struct { - Localpart string `json:"localpart"` // Required. - From string `json:"from,omitempty"` - Limit int `json:"limit,omitempty"` - Only string `json:"only,omitempty"` + Localpart string `json:"localpart"` // Required. + ServerName gomatrixserverlib.ServerName `json:"server_name"` // Required. + From string `json:"from,omitempty"` + Limit int `json:"limit,omitempty"` + Only string `json:"only,omitempty"` } type QueryNotificationsResponse struct { @@ -601,12 +606,17 @@ type PerformSetAvatarURLResponse struct { Changed bool `json:"changed"` } +type QueryNumericLocalpartRequest struct { + ServerName gomatrixserverlib.ServerName +} + type QueryNumericLocalpartResponse struct { ID int64 } type QueryAccountAvailabilityRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryAccountAvailabilityResponse struct { @@ -614,7 +624,9 @@ type QueryAccountAvailabilityResponse struct { } type QueryAccountByPasswordRequest struct { - Localpart, PlaintextPassword string + Localpart string + ServerName gomatrixserverlib.ServerName + PlaintextPassword string } type QueryAccountByPasswordResponse struct { @@ -638,11 +650,13 @@ type QueryLocalpartForThreePIDRequest struct { } type QueryLocalpartForThreePIDResponse struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryThreePIDsForLocalpartRequest struct { - Localpart string + Localpart string + ServerName gomatrixserverlib.ServerName } type QueryThreePIDsForLocalpartResponse struct { @@ -652,5 +666,8 @@ type QueryThreePIDsForLocalpartResponse struct { type PerformForgetThreePIDRequest QueryLocalpartForThreePIDRequest type PerformSaveThreePIDAssociationRequest struct { - ThreePID, Localpart, Medium string + ThreePID string + Localpart string + ServerName gomatrixserverlib.ServerName + Medium string } diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index 90834f7e3..ce661770f 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -156,8 +156,8 @@ func (t *UserInternalAPITrace) SetAvatarURL(ctx context.Context, req *PerformSet return err } -func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, res *QueryNumericLocalpartResponse) error { - err := t.Impl.QueryNumericLocalpart(ctx, res) +func (t *UserInternalAPITrace) QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error { + err := t.Impl.QueryNumericLocalpart(ctx, req, res) util.GetLogger(ctx).Infof("QueryNumericLocalpart req= res=%+v", js(res)) return err } diff --git a/userapi/consumers/clientapi.go b/userapi/consumers/clientapi.go index 79f1bf06f..42ae72e77 100644 --- a/userapi/consumers/clientapi.go +++ b/userapi/consumers/clientapi.go @@ -104,7 +104,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats return false } - updated, err := s.db.SetNotificationsRead(ctx, localpart, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) + updated, err := s.db.SetNotificationsRead(ctx, localpart, domain, roomID, uint64(gomatrixserverlib.AsTimestamp(metadata.Timestamp)), true) if err != nil { log.WithError(err).Error("userapi EDU consumer") return false @@ -118,7 +118,7 @@ func (s *OutputReceiptEventConsumer) onMessage(ctx context.Context, msgs []*nats if !updated { return true } - if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, s.db); err != nil { + if err = util.NotifyUserCountsAsync(ctx, s.pgClient, localpart, domain, s.db); err != nil { log.WithError(err).Error("userapi EDU consumer: NotifyUserCounts failed") return false } diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index b6b30a095..5d8924dda 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -192,25 +192,25 @@ func (s *OutputRoomEventConsumer) storeMessageStats(ctx context.Context, eventTy func (s *OutputRoomEventConsumer) handleRoomUpgrade(ctx context.Context, oldRoomID, newRoomID string, localMembers []*localMembership, roomSize int) error { for _, membership := range localMembers { // Copy any existing push rules from old -> new room - if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { + if err := s.copyPushrules(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil { return err } // preserve m.direct room state - if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, roomSize); err != nil { + if err := s.updateMDirect(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain, roomSize); err != nil { return err } // copy existing m.tag entries, if any - if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart); err != nil { + if err := s.copyTags(ctx, oldRoomID, newRoomID, membership.Localpart, membership.Domain); err != nil { return err } } return nil } -func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string) error { - pushRules, err := s.db.QueryPushRules(ctx, localpart) +func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, newRoomID string, localpart string, serverName gomatrixserverlib.ServerName) error { + pushRules, err := s.db.QueryPushRules(ctx, localpart, serverName) if err != nil { return fmt.Errorf("failed to query pushrules for user: %w", err) } @@ -229,7 +229,7 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, if err != nil { return err } - if err = s.db.SaveAccountData(ctx, localpart, "", "m.push_rules", rules); err != nil { + if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.push_rules", rules); err != nil { return fmt.Errorf("failed to update pushrules: %w", err) } } @@ -237,13 +237,13 @@ func (s *OutputRoomEventConsumer) copyPushrules(ctx context.Context, oldRoomID, } // updateMDirect copies the "is_direct" flag from oldRoomID to newROomID -func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, roomSize int) error { +func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName, roomSize int) error { // this is most likely not a DM, so skip updating m.direct state if roomSize > 2 { return nil } // Get direct message state - directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, "", "m.direct") + directChatsRaw, err := s.db.GetAccountDataByType(ctx, localpart, serverName, "", "m.direct") if err != nil { return fmt.Errorf("failed to get m.direct from database: %w", err) } @@ -267,7 +267,7 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, if err != nil { return true } - if err = s.db.SaveAccountData(ctx, localpart, "", "m.direct", data); err != nil { + if err = s.db.SaveAccountData(ctx, localpart, serverName, "", "m.direct", data); err != nil { return true } } @@ -279,15 +279,15 @@ func (s *OutputRoomEventConsumer) updateMDirect(ctx context.Context, oldRoomID, return nil } -func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string) error { - tag, err := s.db.GetAccountDataByType(ctx, localpart, oldRoomID, "m.tag") +func (s *OutputRoomEventConsumer) copyTags(ctx context.Context, oldRoomID, newRoomID, localpart string, serverName gomatrixserverlib.ServerName) error { + tag, err := s.db.GetAccountDataByType(ctx, localpart, serverName, oldRoomID, "m.tag") if err != nil && !errors.Is(err, sql.ErrNoRows) { return err } if tag == nil { return nil } - return s.db.SaveAccountData(ctx, localpart, newRoomID, "m.tag", tag) + return s.db.SaveAccountData(ctx, localpart, serverName, newRoomID, "m.tag", tag) } func (s *OutputRoomEventConsumer) processMessage(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, streamPos uint64) error { @@ -492,11 +492,11 @@ func unmarshalCanonicalAlias(event *gomatrixserverlib.HeaderedEvent) (string, er func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatrixserverlib.HeaderedEvent, mem *localMembership, roomSize int, roomName string, streamPos uint64) error { actions, err := s.evaluatePushRules(ctx, event, mem, roomSize) if err != nil { - return err + return fmt.Errorf("s.evaluatePushRules: %w", err) } a, tweaks, err := pushrules.ActionsToTweaks(actions) if err != nil { - return err + return fmt.Errorf("pushrules.ActionsToTweaks: %w", err) } // TODO: support coalescing. if a != pushrules.NotifyAction && a != pushrules.CoalesceAction { @@ -508,9 +508,9 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr return nil } - devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, tweaks) + devicesByURLAndFormat, profileTag, err := s.localPushDevices(ctx, mem.Localpart, mem.Domain, tweaks) if err != nil { - return err + return fmt.Errorf("s.localPushDevices: %w", err) } n := &api.Notification{ @@ -527,18 +527,18 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr RoomID: event.RoomID(), TS: gomatrixserverlib.AsTimestamp(time.Now()), } - if err = s.db.InsertNotification(ctx, mem.Localpart, event.EventID(), streamPos, tweaks, n); err != nil { - return err + if err = s.db.InsertNotification(ctx, mem.Localpart, mem.Domain, event.EventID(), streamPos, tweaks, n); err != nil { + return fmt.Errorf("s.db.InsertNotification: %w", err) } if err = s.syncProducer.GetAndSendNotificationData(ctx, mem.UserID, event.RoomID()); err != nil { - return err + return fmt.Errorf("s.syncProducer.GetAndSendNotificationData: %w", err) } // We do this after InsertNotification. Thus, this should always return >=1. - userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, tables.AllNotifications) + userNumUnreadNotifs, err := s.db.GetNotificationCount(ctx, mem.Localpart, mem.Domain, tables.AllNotifications) if err != nil { - return err + return fmt.Errorf("s.db.GetNotificationCount: %w", err) } log.WithFields(log.Fields{ @@ -589,7 +589,7 @@ func (s *OutputRoomEventConsumer) notifyLocal(ctx context.Context, event *gomatr } if len(rejected) > 0 { - s.deleteRejectedPushers(ctx, rejected, mem.Localpart) + s.deleteRejectedPushers(ctx, rejected, mem.Localpart, mem.Domain) } }() @@ -606,7 +606,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * } // Get accountdata to check if the event.Sender() is ignored by mem.LocalPart - data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, "", "m.ignored_user_list") + data, err := s.db.GetAccountDataByType(ctx, mem.Localpart, mem.Domain, "", "m.ignored_user_list") if err != nil { return nil, err } @@ -621,7 +621,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * return nil, fmt.Errorf("user %s is ignored", sender) } } - ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart) + ruleSets, err := s.db.QueryPushRules(ctx, mem.Localpart, mem.Domain) if err != nil { return nil, err } @@ -693,10 +693,10 @@ func (rse *ruleSetEvalContext) HasPowerLevel(userID, levelKey string) (bool, err // localPushDevices pushes to the configured devices of a local // user. The map keys are [url][format]. -func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { - pusherDevices, err := util.GetPushDevices(ctx, localpart, tweaks, s.db) +func (s *OutputRoomEventConsumer) localPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}) (map[string]map[string][]*pushgateway.Device, string, error) { + pusherDevices, err := util.GetPushDevices(ctx, localpart, serverName, tweaks, s.db) if err != nil { - return nil, "", err + return nil, "", fmt.Errorf("util.GetPushDevices: %w", err) } var profileTag string @@ -791,7 +791,7 @@ func (s *OutputRoomEventConsumer) notifyHTTP(ctx context.Context, event *gomatri } // deleteRejectedPushers deletes the pushers associated with the given devices. -func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string) { +func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, devices []*pushgateway.Device, localpart string, serverName gomatrixserverlib.ServerName) { log.WithFields(log.Fields{ "localpart": localpart, "app_id0": devices[0].AppID, @@ -799,7 +799,7 @@ func (s *OutputRoomEventConsumer) deleteRejectedPushers(ctx context.Context, dev }).Warnf("Deleting pushers rejected by the HTTP push gateway") for _, d := range devices { - if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart); err != nil { + if err := s.db.RemovePusher(ctx, d.AppID, d.PushKey, localpart, serverName); err != nil { log.WithFields(log.Fields{ "localpart": localpart, }).WithError(err).Errorf("Unable to delete rejected pusher") diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 9ca76965d..3f256457e 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc if req.DataType == "" { return fmt.Errorf("data type must not be empty") } - if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil { + if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil { util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed") return fmt.Errorf("failed to save account data: %w", err) } @@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } - deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) + deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now()))) if err != nil { logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed") return err @@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun return nil } - if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil { + if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil { logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed") return err } @@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P if serverName == "" { serverName = a.Config.Matrix.ServerName } - // XXXX: Use the server name here - acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType) + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %s is not local", serverName) + } + acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType) if err != nil { if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists switch req.OnConflict { @@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P return nil } - if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil { - return err + if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil { + return fmt.Errorf("a.DB.SetDisplayName: %w", err) } postRegisterJoinRooms(a.Cfg, acc, a.RSAPI) @@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P } func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error { - if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil { + if !a.Config.Matrix.IsLocalServerName(req.ServerName) { + return fmt.Errorf("server name %s is not local", req.ServerName) + } + if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil { return err } if req.LogoutDevices { - if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil { + if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil { return err } } @@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe if serverName == "" { serverName = a.Config.Matrix.ServerName } - _ = serverName - // XXXX: Use the server name here + if !a.Config.Matrix.IsLocalServerName(serverName) { + return fmt.Errorf("server name %s is not local", serverName) + } util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, "display_name": req.DeviceDisplayName, }).Info("PerformDeviceCreation") - dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) + dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent) if err != nil { return err } @@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe deletedDeviceIDs := req.DeviceIDs if len(req.DeviceIDs) == 0 { var devices []api.Device - devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID) + devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID) for _, d := range devices { deletedDeviceIDs = append(deletedDeviceIDs, d.ID) } } else { - err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs) + err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs) } if err != nil { return err @@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate( req *api.PerformLastSeenUpdateRequest, res *api.PerformLastSeenUpdateResponse, ) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } - if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("server name %s is not local", domain) + } + if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil { return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err) } return nil } func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID) if err != nil { util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed") return err } - dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID) + if !a.Config.Matrix.IsLocalServerName(domain) { + return fmt.Errorf("server name %s is not local", domain) + } + dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID) if err == sql.ErrNoRows { res.DeviceExists = false return nil @@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf return nil } - err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName) + err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName) if err != nil { util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed") return err @@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot query profile of remote users (server name %s)", domain) } - prof, err := a.DB.GetProfileByLocalpart(ctx, local) + prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain) if err != nil { if err == sql.ErrNoRows { return nil @@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot query devices of remote users (server name %s)", domain) } - devs, err := a.DB.GetDevicesByLocalpart(ctx, local) + devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain) if err != nil { return err } @@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } if req.DataType != "" { var data json.RawMessage - data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType) + data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType) if err != nil { return err } @@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc } return nil } - global, rooms, err := a.DB.GetAccountData(ctx, local) + global, rooms, err := a.DB.GetAccountData(ctx, local, domain) if err != nil { return err } @@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc if !a.Config.Matrix.IsLocalServerName(domain) { return nil } - acc, err := a.DB.GetAccountByLocalpart(ctx, localPart) + acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain) if err != nil { return err } @@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe AccountType: api.AccountTypeAppService, } - localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) + localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix) if err != nil { return nil, err } if localpart != "" { // AS is masquerading as another user // Verify that the user is registered - account, err := a.DB.GetAccountByLocalpart(ctx, localpart) + account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain) // Verify that the account exists and either appServiceID matches or // it belongs to the appservice user namespaces if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) { @@ -620,7 +632,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a return err } - err := a.DB.DeactivateAccount(ctx, req.Localpart) + err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName) res.AccountDeactivated = err == nil return err } @@ -783,7 +795,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query if req.Only == "highlight" { filter = tables.HighlightNotifications } - notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter) + notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter) if err != nil { return err } @@ -811,23 +823,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform } } if req.Pusher.Kind == "" { - return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart) + return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName) } if req.Pusher.PushKeyTS == 0 { req.Pusher.PushKeyTS = int64(time.Now().Unix()) } - return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart) + return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName) } func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error { - pushers, err := a.DB.GetPushers(ctx, req.Localpart) + pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName) if err != nil { return err } for i := range pushers { logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID) if pushers[i].SessionID != req.SessionID { - err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart) + err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName) if err != nil { return err } @@ -838,7 +850,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error { var err error - res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart) + res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName) return err } @@ -864,11 +876,11 @@ func (a *UserInternalAPI) PerformPushRulesPut( } func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error { - localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID) + localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID) if err != nil { return fmt.Errorf("failed to split user ID %q for push rules", req.UserID) } - pushRules, err := a.DB.QueryPushRules(ctx, localpart) + pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain) if err != nil { return fmt.Errorf("failed to query push rules: %w", err) } @@ -877,14 +889,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush } func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error { - profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL) + profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL) res.Profile = profile res.Changed = changed return err } -func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error { - id, err := a.DB.GetNewNumericLocalpart(ctx) +func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error { + id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName) if err != nil { return err } @@ -894,12 +906,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error { var err error - res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart) + res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName) return err } func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error { - acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword) + acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword) switch err { case sql.ErrNoRows: // user does not exist return nil @@ -915,23 +927,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q } func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error { - profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName) + profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName) res.Profile = profile res.Changed = changed return err } func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error { - localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) + localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium) if err != nil { return err } res.Localpart = localpart + res.ServerName = domain return nil } func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error { - r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart) + r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName) if err != nil { return err } @@ -944,7 +957,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error { - return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium) + return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium) } const pushRulesAccountDataType = "m.push_rules" diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go index 87f25e5e2..3b211db5b 100644 --- a/userapi/internal/api_logintoken.go +++ b/userapi/internal/api_logintoken.go @@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog if !a.Config.Matrix.IsLocalServerName(domain) { return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain) } - if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil { + if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil { res.Data = nil if err == sql.ErrNoRows { return nil diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index aa5d46d9f..87ae058c2 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -355,11 +355,12 @@ func (h *httpUserInternalAPI) SetAvatarURL( func (h *httpUserInternalAPI) QueryNumericLocalpart( ctx context.Context, + request *api.QueryNumericLocalpartRequest, response *api.QueryNumericLocalpartResponse, ) error { return httputil.CallInternalRPCAPI( "QueryNumericLocalpart", h.apiURL+QueryNumericLocalpartPath, - h.httpClient, ctx, &struct{}{}, response, + h.httpClient, ctx, request, response, ) } diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 99148b760..661fecfae 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -15,12 +15,9 @@ package inthttp import ( - "net/http" - "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/util" ) // nolint: gocyclo @@ -152,15 +149,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), ) - // TODO: Look at the shape of this - internalAPIMux.Handle(QueryNumericLocalpartPath, - httputil.MakeInternalAPI("UserAPIQueryNumericLocalpart", func(req *http.Request) util.JSONResponse { - response := api.QueryNumericLocalpartResponse{} - if err := s.QueryNumericLocalpart(req.Context(), &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), + internalAPIMux.Handle( + QueryNumericLocalpartPath, + httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart), ) internalAPIMux.Handle( diff --git a/userapi/producers/syncapi.go b/userapi/producers/syncapi.go index f556ea352..51eaa9856 100644 --- a/userapi/producers/syncapi.go +++ b/userapi/producers/syncapi.go @@ -61,12 +61,12 @@ func (p *SyncAPI) SendAccountData(userID string, data eventutil.AccountData) err // GetAndSendNotificationData reads the database and sends data about unread // notifications to the Sync API server. func (p *SyncAPI) GetAndSendNotificationData(ctx context.Context, userID, roomID string) error { - localpart, _, err := gomatrixserverlib.SplitID('@', userID) + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) if err != nil { return err } - ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, roomID) + ntotal, nhighlight, err := p.db.GetRoomNotificationCounts(ctx, localpart, domain, roomID) if err != nil { return err } diff --git a/userapi/storage/interface.go b/userapi/storage/interface.go index 28ef26559..c22b7658f 100644 --- a/userapi/storage/interface.go +++ b/userapi/storage/interface.go @@ -29,40 +29,40 @@ import ( ) type Profile interface { - GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) + GetProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) - SetAvatarURL(ctx context.Context, localpart string, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, localpart string, displayName string) (*authtypes.Profile, bool, error) + SetAvatarURL(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) } type Account interface { // CreateAccount makes a new account with the given login name and password, and creates an empty profile // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. - CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) - GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) - GetNewNumericLocalpart(ctx context.Context) (int64, error) - CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) - GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) - DeactivateAccount(ctx context.Context, localpart string) (err error) - SetPassword(ctx context.Context, localpart string, plaintextPassword string) error + CreateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) (*api.Account, error) + GetNewNumericLocalpart(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) + CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) + GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) + DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) + SetPassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, plaintextPassword string) error } type AccountData interface { - SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error - GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) + SaveAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error + GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) // GetAccountDataByType returns account data matching a given // localpart, room ID and type. // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval - GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) - QueryPushRules(ctx context.Context, localpart string) (*pushrules.AccountRuleSets, error) + GetAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error) + QueryPushRules(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*pushrules.AccountRuleSets, error) } type Device interface { GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) - GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) + GetDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) + GetDevicesByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Device, error) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) // CreateDevice makes a new device associated with the given user ID localpart. // If there is already a device with the same device ID for this user, that access token will be revoked @@ -70,12 +70,12 @@ type Device interface { // an error will be returned. // If no device ID is given one is generated. // Returns the device on success. - CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) - UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error - UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error - RemoveDevices(ctx context.Context, localpart string, devices []string) error + CreateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) + UpdateDevice(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error + UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error + RemoveDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. - RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + RemoveAllDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) (devices []api.Device, err error) } type KeyBackup interface { @@ -107,26 +107,26 @@ type OpenID interface { } type Pusher interface { - UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error - GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) - RemovePusher(ctx context.Context, appid, pushkey, localpart string) error + UpsertPusher(ctx context.Context, p api.Pusher, localpart string, serverName gomatrixserverlib.ServerName) error + GetPushers(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error) + RemovePusher(ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error RemovePushers(ctx context.Context, appid, pushkey string) error } type ThreePID interface { - SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error) + SaveThreePIDAssociation(ctx context.Context, threepid, localpart string, serverName gomatrixserverlib.ServerName, medium string) (err error) RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error) - GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error) - GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) + GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) + GetThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) } type Notification interface { - InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error - DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) - SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, read bool) (affected bool, err error) - GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) - GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) - GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) + InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error + DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) + SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, read bool) (affected bool, err error) + GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) + GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) + GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) DeleteOldNotifications(ctx context.Context) error } diff --git a/userapi/storage/postgres/account_data_table.go b/userapi/storage/postgres/account_data_table.go index 0b6a3af6d..2a4777d74 100644 --- a/userapi/storage/postgres/account_data_table.go +++ b/userapi/storage/postgres/account_data_table.go @@ -22,6 +22,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -29,27 +30,28 @@ const accountDataSchema = ` CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) room_id TEXT, -- The account data type type TEXT NOT NULL, -- The account data content - content TEXT NOT NULL, - - PRIMARY KEY(localpart, room_id, type) + content TEXT NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type); ` const insertAccountDataSQL = ` - INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) - ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = EXCLUDED.content + INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5) + ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = EXCLUDED.content ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4" type accountDataStatements struct { insertAccountDataStmt *sql.Stmt @@ -71,21 +73,24 @@ func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { } func (s *accountDataStatements) InsertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, content json.RawMessage, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) - _, err = stmt.ExecContext(ctx, localpart, roomID, dataType, content) + _, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content) return } func (s *accountDataStatements) SelectAccountData( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, error, ) { - rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName) if err != nil { return nil, nil, err } @@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData( } func (s *accountDataStatements) SelectAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte stmt := s.selectAccountDataByTypeStmt - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } diff --git a/userapi/storage/postgres/accounts_table.go b/userapi/storage/postgres/accounts_table.go index 7c309eb4f..31a996527 100644 --- a/userapi/storage/postgres/accounts_table.go +++ b/userapi/storage/postgres/accounts_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/matrix-org/gomatrixserverlib" @@ -34,7 +35,8 @@ const accountsSchema = ` -- Stores data about accounts. CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When this account was first created, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The password hash for this account. Can be NULL if this is a passwordless account. @@ -48,25 +50,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( -- TODO: -- upgraded_ts, devices, any email reset stuff? ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name); ` const insertAccountSQL = "" + - "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + - "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3" const deactivateAccountSQL = "" + - "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = TRUE WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" + "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = FALSE" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = FALSE" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$'" + "SELECT COALESCE(MAX(localpart::bigint), 0) FROM userapi_accounts WHERE localpart ~ '^[0-9]{1,}$' AND server_name = $1" type accountsStatements struct { insertAccountStmt *sql.Stmt @@ -117,59 +121,62 @@ func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerNam // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.insertAccountStmt) var err error if accountType != api.AccountTypeAppService { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType) } else { - _, err = stmt.ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = stmt.ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType) } if err != nil { - return nil, err + return nil, fmt.Errorf("insertAccountStmt: %w", err) } return &api.Account{ Localpart: localpart, - UserID: userutil.MakeUserID(localpart, s.serverName), - ServerName: s.serverName, + UserID: userutil.MakeUserID(localpart, serverName), + ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, }, nil } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart, passwordHash string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + passwordHash string, ) (err error) { - _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) return } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -180,19 +187,17 @@ func (s *accountsStatements) SelectAccountByLocalpart( acc.AppServiceID = appserviceIDPtr.String } - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName - + acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName) return &acc, nil } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } - err = stmt.QueryRowContext(ctx).Scan(&id) + err = stmt.QueryRowContext(ctx, serverName).Scan(&id) return id + 1, err } diff --git a/userapi/storage/postgres/deltas/2022110411000000_server_names.go b/userapi/storage/postgres/deltas/2022110411000000_server_names.go new file mode 100644 index 000000000..279e1e5f1 --- /dev/null +++ b/userapi/storage/postgres/deltas/2022110411000000_server_names.go @@ -0,0 +1,81 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +var serverNamesTables = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_devices", + "userapi_notifications", + "userapi_openid_tokens", + "userapi_profiles", + "userapi_pushers", + "userapi_threepids", +} + +// These tables have a PRIMARY KEY constraint which we need to drop so +// that we can recreate a new unique index that contains the server name. +// If the new key doesn't exist (i.e. the database was created before the +// table rename migration) we'll try to drop the old one instead. +var serverNamesDropPK = map[string]string{ + "userapi_accounts": "account_accounts", + "userapi_account_datas": "account_data", + "userapi_profiles": "account_profiles", +} + +// These indices are out of date so let's drop them. They will get recreated +// automatically. +var serverNamesDropIndex = []string{ + "userapi_pusher_localpart_idx", + "userapi_pusher_app_id_pushkey_localpart_idx", +} + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s ADD COLUMN IF NOT EXISTS server_name TEXT NOT NULL DEFAULT '';", + pq.QuoteIdentifier(table), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("add server name to %q error: %w", table, err) + } + } + for newTable, oldTable := range serverNamesDropPK { + q := fmt.Sprintf( + "ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;", + pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(newTable+"_pkey"), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop new PK from %q error: %w", newTable, err) + } + q = fmt.Sprintf( + "ALTER TABLE IF EXISTS %s DROP CONSTRAINT IF EXISTS %s;", + pq.QuoteIdentifier(newTable), pq.QuoteIdentifier(oldTable+"_pkey"), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop old PK from %q error: %w", newTable, err) + } + } + for _, index := range serverNamesDropIndex { + q := fmt.Sprintf( + "DROP INDEX IF EXISTS %s;", + pq.QuoteIdentifier(index), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop index %q error: %w", index, err) + } + } + return nil +} diff --git a/userapi/storage/postgres/deltas/2022110411000001_server_names.go b/userapi/storage/postgres/deltas/2022110411000001_server_names.go new file mode 100644 index 000000000..04a47fa7b --- /dev/null +++ b/userapi/storage/postgres/deltas/2022110411000001_server_names.go @@ -0,0 +1,28 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "UPDATE %s SET server_name = %s WHERE server_name = '';", + pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("write server names to %q error: %w", table, err) + } + } + return nil +} diff --git a/userapi/storage/postgres/devices_table.go b/userapi/storage/postgres/devices_table.go index 8b7fbd6cf..2dd216189 100644 --- a/userapi/storage/postgres/devices_table.go +++ b/userapi/storage/postgres/devices_table.go @@ -17,6 +17,7 @@ package postgres import ( "context" "database/sql" + "fmt" "time" "github.com/lib/pq" @@ -50,6 +51,7 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( -- as it is smaller, makes it clearer that we only manage devices for our own users, and may make -- migration to different domain names easier. localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When this devices was first recognised on the network, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The display name, human friendlier than device_id and updatable @@ -65,39 +67,39 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( ); -- Device IDs must be unique for a given user. -CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, device_id); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_device_localpart_id_idx ON userapi_devices(localpart, server_name, device_id); ` const insertDeviceSQL = "" + - "INSERT INTO userapi_devices(device_id, localpart, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" + + "INSERT INTO userapi_devices(device_id, localpart, server_name, access_token, created_ts, display_name, last_seen_ts, ip, user_agent) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + " RETURNING session_id" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" const deleteDeviceSQL = "" + - "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3" const deleteDevicesSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id = ANY($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = ANY($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" type devicesStatements struct { insertDeviceStmt *sql.Stmt @@ -148,18 +150,19 @@ func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) InsertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt) - if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { - return nil, err + if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil { + return nil, fmt.Errorf("insertDeviceStmt: %w", err) } return &api.Device{ ID: id, - UserID: userutil.MakeUserID(localpart, s.serverName), + UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, SessionID: sessionID, LastSeenTS: createdTimeMS, @@ -170,38 +173,45 @@ func (s *devicesStatements) InsertDevice( // deleteDevice removes a single device by id and user localpart. func (s *devicesStatements) DeleteDevice( - ctx context.Context, txn *sql.Tx, id, localpart string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) + _, err := stmt.ExecContext(ctx, id, localpart, serverName) return err } // deleteDevices removes a single or multiple devices by ids and user localpart. // Returns an error if the execution failed. func (s *devicesStatements) DeleteDevices( - ctx context.Context, txn *sql.Tx, localpart string, devices []string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) - _, err := stmt.ExecContext(ctx, localpart, pq.Array(devices)) + _, err := stmt.ExecContext(ctx, localpart, serverName, pq.Array(devices)) return err } // deleteDevicesByLocalpart removes all devices for the // given user localpart. func (s *devicesStatements) DeleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID) return err } func (s *devicesStatements) UpdateDeviceName( - ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID) return err } @@ -210,10 +220,11 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) dev.AccessToken = accessToken } return &dev, err @@ -222,16 +233,18 @@ func (s *devicesStatements) SelectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { var dev api.Device var displayName, ip sql.NullString var lastseenTS sql.NullInt64 stmt := s.selectDeviceByIDStmt - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) + err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) if displayName.Valid { dev.DisplayName = displayName.String } @@ -254,10 +267,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName var lastseents sql.NullInt64 var displayName sql.NullString for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { @@ -266,17 +280,19 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s if lastseents.Valid { dev.LastSeenTS = lastseents.Int64 } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } func (s *devicesStatements) SelectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID) if err != nil { return devices, err @@ -307,16 +323,16 @@ func (s *devicesStatements) SelectDevicesByLocalpart( dev.UserAgent = useragent.String } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) return err } diff --git a/userapi/storage/postgres/notifications_table.go b/userapi/storage/postgres/notifications_table.go index 24a30b2f5..dc64b1e79 100644 --- a/userapi/storage/postgres/notifications_table.go +++ b/userapi/storage/postgres/notifications_table.go @@ -43,6 +43,7 @@ const notificationSchema = ` CREATE TABLE IF NOT EXISTS userapi_notifications ( id BIGSERIAL PRIMARY KEY, localpart TEXT NOT NULL, + server_name TEXT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, stream_pos BIGINT NOT NULL, @@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications ( read BOOLEAN NOT NULL DEFAULT FALSE ); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id); ` const insertNotificationSQL = "" + - "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const deleteNotificationsUpToSQL = "" + - "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4" const updateNotificationReadSQL = "" + - "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1" const selectNotificationSQL = "" + - "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + - "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + - ") AND NOT read ORDER BY localpart, id LIMIT $4" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" + + "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $5" const selectNotificationCountSQL = "" + - "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + - "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + ") AND NOT read" const selectRoomNotificationCountsSQL = "" + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + - "WHERE localpart = $1 AND room_id = $2 AND NOT read" + "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read" const cleanNotificationsSQL = "" + "DELETE FROM userapi_notifications WHERE" + @@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local if err != nil { return err } - _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs)) return err } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { return nil, 0, err @@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { - err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { - err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/postgres/openid_table.go b/userapi/storage/postgres/openid_table.go index 06ae30d08..68d87f007 100644 --- a/userapi/storage/postgres/openid_table.go +++ b/userapi/storage/postgres/openid_table.go @@ -3,6 +3,7 @@ package postgres import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When the token expires, as a unix timestamp (ms resolution). token_expires_at_ms BIGINT NOT NULL ); ` const insertOpenIDTokenSQL = "" + - "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" + "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { insertTokenStmt *sql.Stmt @@ -54,11 +56,11 @@ func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, + token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS) return } @@ -69,10 +71,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( token string, ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes + var localpart string + var serverName gomatrixserverlib.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( - &openIDTokenAttrs.UserID, + &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, ) + openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve token from the db") diff --git a/userapi/storage/postgres/profile_table.go b/userapi/storage/postgres/profile_table.go index 2753b23d9..df4e0db63 100644 --- a/userapi/storage/postgres/profile_table.go +++ b/userapi/storage/postgres/profile_table.go @@ -23,42 +23,46 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account avatar_url TEXT ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name); ` const insertProfileSQL = "" + - "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2" const setAvatarURLSQL = "" + "UPDATE userapi_profiles AS new" + " SET avatar_url = $1" + " FROM userapi_profiles AS old" + - " WHERE new.localpart = $2" + + " WHERE new.localpart = $2 AND new.server_name = $3" + " RETURNING new.display_name, old.avatar_url <> new.avatar_url" const setDisplayNameSQL = "" + "UPDATE userapi_profiles AS new" + " SET display_name = $1" + " FROM userapi_profiles AS old" + - " WHERE new.localpart = $2" + + " WHERE new.localpart = $2 AND new.server_name = $3" + " RETURNING new.avatar_url, old.display_name <> new.display_name" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { serverNoticesLocalpart string @@ -87,18 +91,20 @@ func NewPostgresProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( - &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( + &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { return nil, err @@ -107,28 +113,34 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ - Localpart: localpart, - AvatarURL: avatarURL, + Localpart: localpart, + ServerName: string(serverName), + AvatarURL: avatarURL, } var changed bool stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err := stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName, &changed) + err := stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName, &changed) return profile, changed, err } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, + ServerName: string(serverName), DisplayName: displayName, } var changed bool stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err := stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL, &changed) + err := stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL, &changed) return profile, changed, err } @@ -146,7 +158,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/postgres/pusher_table.go b/userapi/storage/postgres/pusher_table.go index 6fb714fba..707b3bd2b 100644 --- a/userapi/storage/postgres/pusher_table.go +++ b/userapi/storage/postgres/pusher_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( id BIGSERIAL PRIMARY KEY, -- The Matrix user ID localpart for this pusher localpart TEXT NOT NULL, + server_name TEXT NOT NULL, session_id BIGINT DEFAULT NULL, profile_tag TEXT, kind TEXT NOT NULL, @@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); -- For faster retrieving by localpart. -CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name); -- Pushkey must be unique for a given user and app. -CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name); ` const insertPusherSQL = "" + - "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + - "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + - "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + + "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12" const selectPushersSQL = "" + - "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" const deletePusherSQL = "" + - "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4" const deletePushersByAppIdAndPushKeySQL = "" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" @@ -95,18 +97,19 @@ type pushersStatements struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) - logrus.Debugf("Created pusher %d", session_id) + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err } func (s *pushersStatements) SelectPushers( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} - rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart) + rows, err := sqlutil.TxStmt(txn, s.selectPushersStmt).QueryContext(ctx, localpart, serverName) if err != nil { return pushers, err @@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( - ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, + ctx context.Context, txn *sql.Tx, appid, pushkey, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err } diff --git a/userapi/storage/postgres/storage.go b/userapi/storage/postgres/storage.go index c059e3e60..92dc48081 100644 --- a/userapi/storage/postgres/storage.go +++ b/userapi/storage/postgres/storage.go @@ -15,6 +15,8 @@ package postgres import ( + "context" + "database/sql" "fmt" "time" @@ -43,18 +45,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpRenameTables, Down: deltas.DownRenameTables, }) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + }) if err = m.Up(base.Context()); err != nil { return nil, err } - accountDataTable, err := NewPostgresAccountDataTable(db) - if err != nil { - return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) - } accountsTable, err := NewPostgresAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err) } + accountDataTable, err := NewPostgresAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err) + } devicesTable, err := NewPostgresDevicesTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err) @@ -95,6 +103,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, fmt.Errorf("NewPostgresStatsTable: %w", err) } + + m = sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names populate", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNamesPopulate(ctx, txn, serverName) + }, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, diff --git a/userapi/storage/postgres/threepid_table.go b/userapi/storage/postgres/threepid_table.go index 11af76161..f41c43122 100644 --- a/userapi/storage/postgres/threepid_table.go +++ b/userapi/storage/postgres/threepid_table.go @@ -20,6 +20,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -33,21 +34,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids ( medium TEXT NOT NULL DEFAULT 'email', -- The localpart of the Matrix user ID associated to this 3PID localpart TEXT NOT NULL, + server_name TEXT NOT NULL, PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart); +CREATE INDEX IF NOT EXISTS userapi_threepid_idx ON userapi_threepids(localpart, server_name); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" + "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2" const insertThreePIDSQL = "" + - "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)" const deleteThreePIDSQL = "" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" @@ -75,19 +77,20 @@ func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) - err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { - return "", nil + return "", "", nil } return } func (s *threepidStatements) SelectThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { return } @@ -109,10 +112,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( } func (s *threepidStatements) InsertThreePID( - ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, + ctx context.Context, txn *sql.Tx, threepid, medium, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) return } diff --git a/userapi/storage/shared/storage.go b/userapi/storage/shared/storage.go index f8b8d02c9..f549dcef9 100644 --- a/userapi/storage/shared/storage.go +++ b/userapi/storage/shared/storage.go @@ -68,9 +68,10 @@ const ( // GetAccountByPassword returns the account associated with the given localpart and password. // Returns sql.ErrNoRows if no account exists which matches the given localpart. func (d *Database) GetAccountByPassword( - ctx context.Context, localpart, plaintextPassword string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword string, ) (*api.Account, error) { - hash, err := d.Accounts.SelectPasswordHash(ctx, localpart) + hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName) if err != nil { return nil, err } @@ -80,24 +81,27 @@ func (d *Database) GetAccountByPassword( if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { return nil, err } - return d.Accounts.SelectAccountByLocalpart(ctx, localpart) + return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) } // GetProfileByLocalpart returns the profile associated with the given localpart. // Returns sql.ErrNoRows if no profile exists which matches the given localpart. func (d *Database) GetProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { - return d.Profiles.SelectProfileByLocalpart(ctx, localpart) + return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName) } // SetAvatarURL updates the avatar URL of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetAvatarURL( - ctx context.Context, localpart string, avatarURL string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL) + profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL) return err }) return @@ -106,10 +110,12 @@ func (d *Database) SetAvatarURL( // SetDisplayName updates the display name of the profile associated with the given // localpart. Returns an error if something went wrong with the SQL query func (d *Database) SetDisplayName( - ctx context.Context, localpart string, displayName string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (profile *authtypes.Profile, changed bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName) + profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName) return err }) return @@ -117,14 +123,15 @@ func (d *Database) SetDisplayName( // SetPassword sets the account password to the given hash. func (d *Database) SetPassword( - ctx context.Context, localpart, plaintextPassword string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword string, ) error { hash, err := d.hashPassword(plaintextPassword) if err != nil { return err } return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.Accounts.UpdatePassword(ctx, localpart, hash) + return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash) }) } @@ -132,21 +139,22 @@ func (d *Database) SetPassword( // for this account. If no password is supplied, the account will be a passwordless account. If the // account already exists, it will return nil, ErrUserExists. func (d *Database) CreateAccount( - ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword, appserviceID string, accountType api.AccountType, ) (acc *api.Account, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { // For guest accounts, we create a new numeric local part if accountType == api.AccountTypeGuest { var numLocalpart int64 - numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn) + numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName) if err != nil { - return err + return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err) } localpart = strconv.FormatInt(numLocalpart, 10) plaintextPassword = "" appserviceID = "" } - acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType) + acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType) return err }) return @@ -155,7 +163,9 @@ func (d *Database) CreateAccount( // WARNING! This function assumes that the relevant mutexes have already // been taken out by the caller (e.g. CreateAccount or CreateGuestAccount). func (d *Database) createAccount( - ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + plaintextPassword, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { var err error var account *api.Account @@ -167,28 +177,28 @@ func (d *Database) createAccount( return nil, err } } - if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil { + if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil { return nil, sqlutil.ErrUserExists } - if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil { - return nil, err + if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil { + return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err) } - pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) prbs, err := json.Marshal(pushRuleSets) if err != nil { - return nil, err + return nil, fmt.Errorf("json.Marshal: %w", err) } - if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil { - return nil, err + if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil { + return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err) } return account, nil } func (d *Database) QueryPushRules( ctx context.Context, - localpart string, + localpart string, serverName gomatrixserverlib.ServerName, ) (*pushrules.AccountRuleSets, error) { - data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules") + data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules") if err != nil { return nil, err } @@ -196,13 +206,13 @@ func (d *Database) QueryPushRules( // If we didn't find any default push rules then we should just generate some // fresh ones. if len(data) == 0 { - pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName) + pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName) prbs, err := json.Marshal(pushRuleSets) if err != nil { return nil, fmt.Errorf("failed to marshal default push rules: %w", err) } err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil { + if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil { return fmt.Errorf("failed to save default push rules: %w", dbErr) } return nil @@ -225,22 +235,23 @@ func (d *Database) QueryPushRules( // update the corresponding row with the new content // Returns a SQL error if there was an issue with the insertion/update func (d *Database) SaveAccountData( - ctx context.Context, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, content json.RawMessage, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content) + return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content) }) } // GetAccountData returns account data related to a given localpart // If no account data could be found, returns an empty arrays // Returns an error if there was an issue with the retrieval -func (d *Database) GetAccountData(ctx context.Context, localpart string) ( +func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) ( global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error, ) { - return d.AccountDatas.SelectAccountData(ctx, localpart) + return d.AccountDatas.SelectAccountData(ctx, localpart, serverName) } // GetAccountDataByType returns account data matching a given @@ -248,18 +259,19 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) ( // If no account data could be found, returns nil // Returns an error if there was an issue with the retrieval func (d *Database) GetAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, ) (data json.RawMessage, err error) { return d.AccountDatas.SelectAccountDataByType( - ctx, localpart, roomID, dataType, + ctx, localpart, serverName, roomID, dataType, ) } // GetNewNumericLocalpart generates and returns a new unused numeric localpart func (d *Database) GetNewNumericLocalpart( - ctx context.Context, + ctx context.Context, serverName gomatrixserverlib.ServerName, ) (int64, error) { - return d.Accounts.SelectNewNumericLocalpart(ctx, nil) + return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName) } func (d *Database) hashPassword(plaintext string) (hash string, err error) { @@ -276,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use") // If the third-party identifier is already part of an association, returns Err3PIDInUse. // Returns an error if there was a problem talking to the database. func (d *Database) SaveThreePIDAssociation( - ctx context.Context, threepid, localpart, medium string, + ctx context.Context, threepid string, + localpart string, serverName gomatrixserverlib.ServerName, + medium string, ) (err error) { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - user, err := d.ThreePIDs.SelectLocalpartForThreePID( + user, _, err := d.ThreePIDs.SelectLocalpartForThreePID( ctx, txn, threepid, medium, ) if err != nil { @@ -290,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation( return Err3PIDInUse } - return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart) + return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName) }) } @@ -313,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation( // Returns an error if there was a problem talking to the database. func (d *Database) GetLocalpartForThreePID( ctx context.Context, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium) } @@ -322,16 +336,17 @@ func (d *Database) GetLocalpartForThreePID( // If no association is known for this user, returns an empty slice. // Returns an error if there was an issue talking to the database. func (d *Database) GetThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart) + return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName) } // CheckAccountAvailability checks if the username/localpart is already present // in the database. // If the DB returns sql.ErrNoRows the Localpart isn't taken. -func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) { - _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart) +func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) { + _, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) if err == sql.ErrNoRows { return true, nil } @@ -341,12 +356,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin // GetAccountByLocalpart returns the account associated with the given localpart. // This function assumes the request is authenticated or the account data is used only internally. // Returns sql.ErrNoRows if no account exists which matches the given localpart. -func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, +func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { // try to get the account with lowercase localpart (majority) - acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart)) + acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName) if err == sql.ErrNoRows { - acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request + acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request } return acc, err } @@ -359,20 +374,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi } // DeactivateAccount deactivates the user's account, removing all ability for the user to login again. -func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) { +func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) { return d.Writer.Do(nil, nil, func(txn *sql.Tx) error { - return d.Accounts.DeactivateAccount(ctx, localpart) + return d.Accounts.DeactivateAccount(ctx, localpart, serverName) }) } // CreateOpenIDToken persists a new token that was issued for OpenID Connect func (d *Database) CreateOpenIDToken( ctx context.Context, - token, localpart string, + token, userID string, ) (int64, error) { + localpart, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + return 0, nil + } expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS - err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS) + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS) }) return expiresAtMS, err } @@ -539,16 +558,19 @@ func (d *Database) GetDeviceByAccessToken( // GetDeviceByID returns the device matching the given ID. // Returns sql.ErrNoRows if no matching device was found. func (d *Database) GetDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { - return d.Devices.SelectDeviceByID(ctx, localpart, deviceID) + return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID) } // GetDevicesByLocalpart returns the devices matching the given localpart. func (d *Database) GetDevicesByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Device, error) { - return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "") + return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "") } func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { @@ -562,18 +584,18 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap // If no device ID is given one is generated. // Returns the device on success. func (d *Database) CreateDevice( - ctx context.Context, localpart string, deviceID *string, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string, ) (dev *api.Device, returnErr error) { if deviceID != nil { returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error // Revoke existing tokens for this device - if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil { + if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil { return err } - dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent) + dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) return err }) } else { @@ -588,7 +610,7 @@ func (d *Database) CreateDevice( returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error - dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent) + dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent) return err }) if returnErr == nil { @@ -614,10 +636,12 @@ func generateDeviceID() (string, error) { // UpdateDevice updates the given device with the display name. // Returns SQL error if there are problems and nil on success. func (d *Database) UpdateDevice( - ctx context.Context, localpart, deviceID string, displayName *string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName) + return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName) }) } @@ -626,10 +650,12 @@ func (d *Database) UpdateDevice( // If the devices don't exist, it will not return an error // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveDevices( - ctx context.Context, localpart string, devices []string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows { + if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows { return err } return nil @@ -640,14 +666,16 @@ func (d *Database) RemoveDevices( // database matching the given user ID localpart. // If something went wrong during the deletion, it will return the SQL error. func (d *Database) RemoveAllDevices( - ctx context.Context, localpart, exceptDeviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) (devices []api.Device, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID) + devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID) if err != nil { return err } - if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows { + if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows { return err } return nil @@ -656,9 +684,9 @@ func (d *Database) RemoveAllDevices( } // UpdateDeviceLastSeen updates a last seen timestamp and the ip address. -func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error { +func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent) + return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent) }) } @@ -706,38 +734,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) ( return d.LoginTokens.SelectLoginToken(ctx, token) } -func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { +func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) + return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n) }) } -func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) { +func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos) + affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos) return err }) return } -func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) { +func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b) + affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b) return err }) return } -func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { - return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter) +func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter) } -func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) { - return d.Notifications.SelectCount(ctx, nil, localpart, filter) +func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) { + return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter) } -func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) { - return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID) +func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) { + return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID) } func (d *Database) DeleteOldNotifications(ctx context.Context) error { @@ -747,7 +775,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error { } func (d *Database) UpsertPusher( - ctx context.Context, p api.Pusher, localpart string, + ctx context.Context, p api.Pusher, + localpart string, serverName gomatrixserverlib.ServerName, ) error { data, err := json.Marshal(p.Data) if err != nil { @@ -766,25 +795,26 @@ func (d *Database) UpsertPusher( p.ProfileTag, p.Language, string(data), - localpart) + localpart, + serverName) }) } // GetPushers returns the pushers matching the given localpart. func (d *Database) GetPushers( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { - return d.Pushers.SelectPushers(ctx, nil, localpart) + return d.Pushers.SelectPushers(ctx, nil, localpart, serverName) } // RemovePusher deletes one pusher // Invoked when `append` is true and `kind` is null in // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set func (d *Database) RemovePusher( - ctx context.Context, appid, pushkey, localpart string, + ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart) + err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName) if err == sql.ErrNoRows { return nil } diff --git a/userapi/storage/sqlite3/account_data_table.go b/userapi/storage/sqlite3/account_data_table.go index af12decb3..2fbdc5732 100644 --- a/userapi/storage/sqlite3/account_data_table.go +++ b/userapi/storage/sqlite3/account_data_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const accountDataSchema = ` @@ -28,27 +29,28 @@ const accountDataSchema = ` CREATE TABLE IF NOT EXISTS userapi_account_datas ( -- The Matrix user ID localpart for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The room ID for this data (empty string if not specific to a room) room_id TEXT, -- The account data type type TEXT NOT NULL, -- The account data content - content TEXT NOT NULL, - - PRIMARY KEY(localpart, room_id, type) + content TEXT NOT NULL ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type); ` const insertAccountDataSQL = ` - INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4) - ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4 + INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5) + ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5 ` const selectAccountDataSQL = "" + - "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1" + "SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2" const selectAccountDataByTypeSQL = "" + - "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3" + "SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4" type accountDataStatements struct { db *sql.DB @@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) { } func (s *accountDataStatements) InsertAccountData( - ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, content json.RawMessage, ) error { - _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) + _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content) return err } func (s *accountDataStatements) SelectAccountData( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) ( /* global */ map[string]json.RawMessage, /* rooms */ map[string]map[string]json.RawMessage, error, ) { - rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart) + rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName) if err != nil { return nil, nil, err } @@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData( } func (s *accountDataStatements) SelectAccountDataByType( - ctx context.Context, localpart, roomID, dataType string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + roomID, dataType string, ) (data json.RawMessage, err error) { var bytes []byte stmt := s.selectAccountDataByTypeStmt - if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil { + if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil { if err == sql.ErrNoRows { return nil, nil } diff --git a/userapi/storage/sqlite3/accounts_table.go b/userapi/storage/sqlite3/accounts_table.go index 671c1aa04..f4ebe2158 100644 --- a/userapi/storage/sqlite3/accounts_table.go +++ b/userapi/storage/sqlite3/accounts_table.go @@ -34,7 +34,8 @@ const accountsSchema = ` -- Stores data about accounts. CREATE TABLE IF NOT EXISTS userapi_accounts ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When this account was first created, as a unix timestamp (ms resolution). created_ts BIGINT NOT NULL, -- The password hash for this account. Can be NULL if this is a passwordless account. @@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts ( -- TODO: -- upgraded_ts, devices, any email reset stuff? ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name); ` const insertAccountSQL = "" + - "INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)" + "INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)" const updatePasswordSQL = "" + - "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2" + "UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3" const deactivateAccountSQL = "" + - "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1" + "UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2" const selectAccountByLocalpartSQL = "" + - "SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1" + "SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2" const selectPasswordHashSQL = "" + - "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0" + "SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0" const selectNewNumericLocalpartSQL = "" + - "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0" + "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1" type accountsStatements struct { db *sql.DB @@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // this account will be passwordless. Returns an error if this account already exists. Returns the account // on success. func (s *accountsStatements) InsertAccount( - ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, + ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, + hash, appserviceID string, accountType api.AccountType, ) (*api.Account, error) { createdTimeMS := time.Now().UnixNano() / 1000000 stmt := s.insertAccountStmt var err error if accountType != api.AccountTypeAppService { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType) } else { - _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType) + _, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType) } if err != nil { return nil, err @@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount( return &api.Account{ Localpart: localpart, - UserID: userutil.MakeUserID(localpart, s.serverName), - ServerName: s.serverName, + UserID: userutil.MakeUserID(localpart, serverName), + ServerName: serverName, AppServiceID: appserviceID, AccountType: accountType, }, nil } func (s *accountsStatements) UpdatePassword( - ctx context.Context, localpart, passwordHash string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, + passwordHash string, ) (err error) { - _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) + _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName) return } func (s *accountsStatements) DeactivateAccount( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { - _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) + _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName) return } func (s *accountsStatements) SelectPasswordHash( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (hash string, err error) { - err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) + err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash) return } func (s *accountsStatements) SelectAccountByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, ) (*api.Account, error) { var appserviceIDPtr sql.NullString var acc api.Account stmt := s.selectAccountByLocalpartStmt - err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType) + err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve user from the db") @@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart( acc.AppServiceID = appserviceIDPtr.String } - acc.UserID = userutil.MakeUserID(localpart, s.serverName) - acc.ServerName = s.serverName - + acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName) return &acc, nil } func (s *accountsStatements) SelectNewNumericLocalpart( - ctx context.Context, txn *sql.Tx, + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, ) (id int64, err error) { stmt := s.selectNewNumericLocalpartStmt if txn != nil { stmt = sqlutil.TxStmt(txn, stmt) } - err = stmt.QueryRowContext(ctx).Scan(&id) + err = stmt.QueryRowContext(ctx, serverName).Scan(&id) if err == sql.ErrNoRows { return 1, nil } diff --git a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go index 9158cb365..2de85005f 100644 --- a/userapi/storage/sqlite3/deltas/20200929203058_is_active.go +++ b/userapi/storage/sqlite3/deltas/20200929203058_is_active.go @@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error { ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, + server_name TEXT NOT NULL, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT, diff --git a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go index a9224db6b..636ce4efc 100644 --- a/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go +++ b/userapi/storage/sqlite3/deltas/20201001204705_last_seen_ts_ip.go @@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error { session_id INTEGER, device_id TEXT , localpart TEXT , + server_name TEXT NOT NULL, created_ts BIGINT, display_name TEXT, last_seen_ts BIGINT, diff --git a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go index 230bc1433..471e496cd 100644 --- a/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go +++ b/userapi/storage/sqlite3/deltas/2022021012490600_add_account_type.go @@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error { _, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp; CREATE TABLE userapi_accounts ( localpart TEXT NOT NULL PRIMARY KEY, + server_name TEXT NOT NULL, created_ts BIGINT NOT NULL, password_hash TEXT, appservice_id TEXT, diff --git a/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go new file mode 100644 index 000000000..c11ea6844 --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022110411000000_server_names.go @@ -0,0 +1,108 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +var serverNamesTables = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_devices", + "userapi_notifications", + "userapi_openid_tokens", + "userapi_profiles", + "userapi_pushers", + "userapi_threepids", +} + +// These tables have a PRIMARY KEY constraint which we need to drop so +// that we can recreate a new unique index that contains the server name. +var serverNamesDropPK = []string{ + "userapi_accounts", + "userapi_account_datas", + "userapi_profiles", +} + +// These indices are out of date so let's drop them. They will get recreated +// automatically. +var serverNamesDropIndex = []string{ + "userapi_pusher_localpart_idx", + "userapi_pusher_app_id_pushkey_localpart_idx", +} + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;", + pq.QuoteIdentifier(table), + ) + var c int + if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 { + continue + } + q = fmt.Sprintf( + "SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'", + pq.QuoteIdentifier(table), + ) + if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 { + logrus.Infof("Table %s already has column, skipping", table) + continue + } + if c == 0 { + q = fmt.Sprintf( + "ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';", + pq.QuoteIdentifier(table), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("add server name to %q error: %w", table, err) + } + } + } + for _, table := range serverNamesDropPK { + q := fmt.Sprintf( + "SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;", + pq.QuoteIdentifier(table), + ) + var c int + var sql string + if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 { + continue + } + q = fmt.Sprintf(` + %s; -- create temporary table + INSERT INTO %s SELECT * FROM %s; -- copy data + DROP TABLE %s; -- drop original table + ALTER TABLE %s RENAME TO %s; -- rename new table + `, + strings.Replace(sql, table, table+"_tmp", 1), // create temporary table + table+"_tmp", table, // copy data + table, // drop original table + table+"_tmp", table, // rename new table + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop PK from %q error: %w", table, err) + } + } + for _, index := range serverNamesDropIndex { + q := fmt.Sprintf( + "DROP INDEX IF EXISTS %s;", + pq.QuoteIdentifier(index), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("drop index %q error: %w", index, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go new file mode 100644 index 000000000..04a47fa7b --- /dev/null +++ b/userapi/storage/sqlite3/deltas/2022110411000001_server_names.go @@ -0,0 +1,28 @@ +package deltas + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" +) + +// I know what you're thinking: you're wondering "why doesn't this use $1 +// and pass variadic parameters to ExecContext?" — the answer is because +// PostgreSQL doesn't expect the table name to be specified as a substituted +// argument in that way so it results in a syntax error in the query. + +func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error { + for _, table := range serverNamesTables { + q := fmt.Sprintf( + "UPDATE %s SET server_name = %s WHERE server_name = '';", + pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)), + ) + if _, err := tx.ExecContext(ctx, q); err != nil { + return fmt.Errorf("write server names to %q error: %w", table, err) + } + } + return nil +} diff --git a/userapi/storage/sqlite3/devices_table.go b/userapi/storage/sqlite3/devices_table.go index e53a08062..c5db34bd7 100644 --- a/userapi/storage/sqlite3/devices_table.go +++ b/userapi/storage/sqlite3/devices_table.go @@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices ( session_id INTEGER, device_id TEXT , localpart TEXT , + server_name TEXT NOT NULL, created_ts BIGINT, display_name TEXT, last_seen_ts BIGINT, ip TEXT, user_agent TEXT, - UNIQUE (localpart, device_id) + UNIQUE (localpart, server_name, device_id) ); ` const insertDeviceSQL = "" + - "INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + - " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)" + "INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" + + " VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)" const selectDevicesCountSQL = "" + "SELECT COUNT(access_token) FROM userapi_devices" const selectDeviceByTokenSQL = "" + - "SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1" + "SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1" const selectDeviceByIDSQL = "" + - "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2" + "SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3" const selectDevicesByLocalpartSQL = "" + - "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC" + "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC" const updateDeviceNameSQL = "" + - "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" + "UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4" const deleteDeviceSQL = "" + - "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2" + "DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3" const deleteDevicesByLocalpartSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3" const deleteDevicesSQL = "" + - "DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)" + "DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)" const selectDevicesByIDSQL = "" + - "SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" + "SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC" const updateDeviceLastSeen = "" + - "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5" + "UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6" type devicesStatements struct { db *sql.DB @@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) // Returns an error if the user already has a device with the given device ID. // Returns the device on success. func (s *devicesStatements) InsertDevice( - ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, - displayName *string, ipAddr, userAgent string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, + accessToken string, displayName *string, ipAddr, userAgent string, ) (*api.Device, error) { createdTimeMS := time.Now().UnixNano() / 1000000 var sessionID int64 @@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice( return nil, err } sessionID++ - if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { + if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil { return nil, err } return &api.Device{ ID: id, - UserID: userutil.MakeUserID(localpart, s.serverName), + UserID: userutil.MakeUserID(localpart, serverName), AccessToken: accessToken, SessionID: sessionID, LastSeenTS: createdTimeMS, @@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice( } func (s *devicesStatements) DeleteDevice( - ctx context.Context, txn *sql.Tx, id, localpart string, + ctx context.Context, txn *sql.Tx, id string, + localpart string, serverName gomatrixserverlib.ServerName, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) - _, err := stmt.ExecContext(ctx, id, localpart) + _, err := stmt.ExecContext(ctx, id, localpart, serverName) return err } func (s *devicesStatements) DeleteDevices( - ctx context.Context, txn *sql.Tx, localpart string, devices []string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + devices []string, ) error { - orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) + orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1) prep, err := s.db.Prepare(orig) if err != nil { return err } stmt := sqlutil.TxStmt(txn, prep) - params := make([]interface{}, len(devices)+1) + params := make([]interface{}, len(devices)+2) params[0] = localpart + params[1] = serverName for i, v := range devices { - params[i+1] = v + params[i+2] = v } _, err = stmt.ExecContext(ctx, params...) return err } func (s *devicesStatements) DeleteDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) error { stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) - _, err := stmt.ExecContext(ctx, localpart, exceptDeviceID) + _, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID) return err } func (s *devicesStatements) UpdateDeviceName( - ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, displayName *string, ) error { stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) - _, err := stmt.ExecContext(ctx, displayName, localpart, deviceID) + _, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID) return err } @@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken( ) (*api.Device, error) { var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName stmt := s.selectDeviceByTokenStmt - err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart) + err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName) if err == nil { - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) dev.AccessToken = accessToken } return &dev, err @@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken( // selectDeviceByID retrieves a device from the database with the given user // localpart and deviceID func (s *devicesStatements) SelectDeviceByID( - ctx context.Context, localpart, deviceID string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, + deviceID string, ) (*api.Device, error) { var dev api.Device var displayName, ip sql.NullString stmt := s.selectDeviceByIDStmt var lastseenTS sql.NullInt64 - err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip) + err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip) if err == nil { dev.ID = deviceID - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) if displayName.Valid { dev.DisplayName = displayName.String } @@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID( } func (s *devicesStatements) SelectDevicesByLocalpart( - ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + exceptDeviceID string, ) ([]api.Device, error) { devices := []api.Device{} - rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID) + rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID) if err != nil { return devices, err @@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart( dev.UserAgent = useragent.String } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } @@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s var devices []api.Device var dev api.Device var localpart string + var serverName gomatrixserverlib.ServerName var displayName sql.NullString var lastseents sql.NullInt64 for rows.Next() { - if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil { + if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil { return nil, err } if displayName.Valid { @@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s if lastseents.Valid { dev.LastSeenTS = lastseents.Int64 } - dev.UserID = userutil.MakeUserID(localpart, s.serverName) + dev.UserID = userutil.MakeUserID(localpart, serverName) devices = append(devices, dev) } return devices, rows.Err() } -func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error { +func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error { lastSeenTs := time.Now().UnixNano() / 1000000 stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) - _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID) + _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID) return err } diff --git a/userapi/storage/sqlite3/notifications_table.go b/userapi/storage/sqlite3/notifications_table.go index a35ec7be5..ef39d027c 100644 --- a/userapi/storage/sqlite3/notifications_table.go +++ b/userapi/storage/sqlite3/notifications_table.go @@ -43,6 +43,7 @@ const notificationSchema = ` CREATE TABLE IF NOT EXISTS userapi_notifications ( id INTEGER PRIMARY KEY AUTOINCREMENT, localpart TEXT NOT NULL, + server_name TEXT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL, stream_pos BIGINT NOT NULL, @@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications ( read BOOLEAN NOT NULL DEFAULT FALSE ); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id); -CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id); +CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id); ` const insertNotificationSQL = "" + - "INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)" + "INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)" const deleteNotificationsUpToSQL = "" + - "DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3" + "DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4" const updateNotificationReadSQL = "" + - "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1" + "UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1" const selectNotificationSQL = "" + - "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" + - "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + - ") AND NOT read ORDER BY localpart, id LIMIT $4" + "SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" + + "(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" + + ") AND NOT read ORDER BY localpart, id LIMIT $5" const selectNotificationCountSQL = "" + - "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" + - "(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" + + "SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" + + "(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" + ") AND NOT read" const selectRoomNotificationCountsSQL = "" + "SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " + - "WHERE localpart = $1 AND room_id = $2 AND NOT read" + "WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read" const cleanNotificationsSQL = "" + "DELETE FROM userapi_notifications WHERE" + @@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error } // Insert inserts a notification into the database. -func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error { +func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error { roomID, tsMS := n.RoomID, n.TS nn := *n // Clears out fields that have their own columns to (1) shrink the @@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local if err != nil { return err } - _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs)) + _, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs)) return err } // DeleteUpTo deletes all previous notifications, up to and including the event. -func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos) +func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l } // UpdateRead updates the "read" value for an event. -func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) { - res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos) +func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) { + res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos) if err != nil { return false, err } @@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l return nrows > 0, nil } -func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { - rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit) +func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) { + rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit) if err != nil { return nil, 0, err @@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local return notifs, maxID, rows.Err() } -func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) { - err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count) +func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) { + err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count) return } -func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) { - err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight) +func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) { + err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight) return } diff --git a/userapi/storage/sqlite3/openid_table.go b/userapi/storage/sqlite3/openid_table.go index 875f1a9a5..f06429741 100644 --- a/userapi/storage/sqlite3/openid_table.go +++ b/userapi/storage/sqlite3/openid_table.go @@ -3,6 +3,7 @@ package sqlite3 import ( "context" "database/sql" + "fmt" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" @@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens ( token TEXT NOT NULL PRIMARY KEY, -- The Matrix user ID for this account localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- When the token expires, as a unix timestamp (ms resolution). token_expires_at_ms BIGINT NOT NULL ); ` const insertOpenIDTokenSQL = "" + - "INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" + "INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)" const selectOpenIDTokenSQL = "" + - "SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" + "SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1" type openIDTokenStatements struct { db *sql.DB @@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) ( func (s *openIDTokenStatements) InsertOpenIDToken( ctx context.Context, txn *sql.Tx, - token, localpart string, + token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertTokenStmt) - _, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS) + _, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS) return } @@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes( token string, ) (*api.OpenIDTokenAttributes, error) { var openIDTokenAttrs api.OpenIDTokenAttributes + var localpart string + var serverName gomatrixserverlib.ServerName err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan( - &openIDTokenAttrs.UserID, + &localpart, &serverName, &openIDTokenAttrs.ExpiresAtMS, ) + openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName) if err != nil { if err != sql.ErrNoRows { log.WithError(err).Error("Unable to retrieve token from the db") diff --git a/userapi/storage/sqlite3/profile_table.go b/userapi/storage/sqlite3/profile_table.go index b6130a1e3..867026d7a 100644 --- a/userapi/storage/sqlite3/profile_table.go +++ b/userapi/storage/sqlite3/profile_table.go @@ -23,36 +23,40 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) const profilesSchema = ` -- Stores data about accounts profiles. CREATE TABLE IF NOT EXISTS userapi_profiles ( -- The Matrix user ID localpart for this account - localpart TEXT NOT NULL PRIMARY KEY, + localpart TEXT NOT NULL, + server_name TEXT NOT NULL, -- The display name for this account display_name TEXT, -- The URL of the avatar for this account avatar_url TEXT ); + +CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name); ` const insertProfileSQL = "" + - "INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)" + "INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)" const selectProfileByLocalpartSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2" const setAvatarURLSQL = "" + - "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" + + "UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" + " RETURNING display_name" const setDisplayNameSQL = "" + - "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" + + "UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" + " RETURNING avatar_url" const selectProfilesBySearchSQL = "" + - "SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" + "SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2" type profilesStatements struct { db *sql.DB @@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P } func (s *profilesStatements) InsertProfile( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") + _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "") return err } func (s *profilesStatements) SelectProfileByLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (*authtypes.Profile, error) { var profile authtypes.Profile - err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan( - &profile.Localpart, &profile.DisplayName, &profile.AvatarURL, + err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan( + &profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL, ) if err != nil { return nil, err @@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart( } func (s *profilesStatements) SetAvatarURL( - ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + avatarURL string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ - Localpart: localpart, - AvatarURL: avatarURL, + Localpart: localpart, + ServerName: string(serverName), + AvatarURL: avatarURL, } - old, err := s.SelectProfileByLocalpart(ctx, localpart) + old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) if err != nil { return old, false, err } @@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL( return old, false, nil } stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) - err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName) + err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName) return profile, true, err } func (s *profilesStatements) SetDisplayName( - ctx context.Context, txn *sql.Tx, localpart string, displayName string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, + displayName string, ) (*authtypes.Profile, bool, error) { profile := &authtypes.Profile{ Localpart: localpart, + ServerName: string(serverName), DisplayName: displayName, } - old, err := s.SelectProfileByLocalpart(ctx, localpart) + old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName) if err != nil { return old, false, err } @@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName( return old, false, nil } stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) - err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL) + err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL) return profile, true, err } @@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch( defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed") for rows.Next() { var profile authtypes.Profile - if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil { + if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil { return nil, err } if profile.Localpart != s.serverNoticesLocalpart { diff --git a/userapi/storage/sqlite3/pusher_table.go b/userapi/storage/sqlite3/pusher_table.go index 4de0a9f06..c9d451dc5 100644 --- a/userapi/storage/sqlite3/pusher_table.go +++ b/userapi/storage/sqlite3/pusher_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) // See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers @@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( id INTEGER PRIMARY KEY AUTOINCREMENT, -- The Matrix user ID localpart for this pusher localpart TEXT NOT NULL, + server_name TEXT NOT NULL, session_id BIGINT DEFAULT NULL, profile_tag TEXT, kind TEXT NOT NULL, @@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers ( CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey); -- For faster retrieving by localpart. -CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart); +CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name); -- Pushkey must be unique for a given user and app. -CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart); +CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name); ` const insertPusherSQL = "" + - "INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + - "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" + - "ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11" + "INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" + + "VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" + + "ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12" const selectPushersSQL = "" + - "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1" + "SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2" const deletePusherSQL = "" + - "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4" const deletePushersByAppIdAndPushKeySQL = "" + "DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2" @@ -95,18 +97,19 @@ type pushersStatements struct { // Returns nil error success. func (s *pushersStatements) InsertPusher( ctx context.Context, txn *sql.Tx, session_id int64, - pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, + pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) - logrus.Debugf("Created pusher %d", session_id) + _, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data) return err } func (s *pushersStatements) SelectPushers( - ctx context.Context, txn *sql.Tx, localpart string, + ctx context.Context, txn *sql.Tx, + localpart string, serverName gomatrixserverlib.ServerName, ) ([]api.Pusher, error) { pushers := []api.Pusher{} - rows, err := s.selectPushersStmt.QueryContext(ctx, localpart) + rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName) if err != nil { return pushers, err @@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers( // deletePusher removes a single pusher by pushkey and user localpart. func (s *pushersStatements) DeletePusher( - ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, + ctx context.Context, txn *sql.Tx, appid, pushkey, + localpart string, serverName gomatrixserverlib.ServerName, ) error { - _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart) + _, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName) return err } diff --git a/userapi/storage/sqlite3/storage.go b/userapi/storage/sqlite3/storage.go index dd33dc0cf..85a1f7063 100644 --- a/userapi/storage/sqlite3/storage.go +++ b/userapi/storage/sqlite3/storage.go @@ -15,6 +15,8 @@ package sqlite3 import ( + "context" + "database/sql" "fmt" "time" @@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, Up: deltas.UpRenameTables, Down: deltas.DownRenameTables, }) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNames(ctx, txn, serverName) + }, + }) if err = m.Up(base.Context()); err != nil { return nil, err } - accountDataTable, err := NewSQLiteAccountDataTable(db) - if err != nil { - return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) - } accountsTable, err := NewSQLiteAccountsTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err) } + accountDataTable, err := NewSQLiteAccountDataTable(db) + if err != nil { + return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err) + } devicesTable, err := NewSQLiteDevicesTable(db, serverName) if err != nil { return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err) @@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err) } + + m = sqlutil.NewMigrator(db) + m.AddMigrations(sqlutil.Migration{ + Version: "userapi: server names populate", + Up: func(ctx context.Context, txn *sql.Tx) error { + return deltas.UpServerNamesPopulate(ctx, txn, serverName) + }, + }) + if err = m.Up(base.Context()); err != nil { + return nil, err + } + return &shared.Database{ AccountDatas: accountDataTable, Accounts: accountsTable, diff --git a/userapi/storage/sqlite3/threepid_table.go b/userapi/storage/sqlite3/threepid_table.go index 73af139db..2db7d5887 100644 --- a/userapi/storage/sqlite3/threepid_table.go +++ b/userapi/storage/sqlite3/threepid_table.go @@ -21,6 +21,7 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" ) @@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids ( medium TEXT NOT NULL DEFAULT 'email', -- The localpart of the Matrix user ID associated to this 3PID localpart TEXT NOT NULL, + server_name TEXT NOT NULL, PRIMARY KEY(threepid, medium) ); -CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart); +CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name); ` const selectLocalpartForThreePIDSQL = "" + - "SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2" + "SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2" const selectThreePIDsForLocalpartSQL = "" + - "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1" + "SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2" const insertThreePIDSQL = "" + - "INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)" + "INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)" const deleteThreePIDSQL = "" + "DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2" @@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) { func (s *threepidStatements) SelectLocalpartForThreePID( ctx context.Context, txn *sql.Tx, threepid string, medium string, -) (localpart string, err error) { +) (localpart string, serverName gomatrixserverlib.ServerName, err error) { stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) - err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart) + err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName) if err == sql.ErrNoRows { - return "", nil + return "", "", nil } return } func (s *threepidStatements) SelectThreePIDsForLocalpart( - ctx context.Context, localpart string, + ctx context.Context, + localpart string, serverName gomatrixserverlib.ServerName, ) (threepids []authtypes.ThreePID, err error) { - rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) + rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName) if err != nil { return } @@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart( } func (s *threepidStatements) InsertThreePID( - ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, + ctx context.Context, txn *sql.Tx, threepid, medium, + localpart string, serverName gomatrixserverlib.ServerName, ) (err error) { stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) - _, err = stmt.ExecContext(ctx, threepid, medium, localpart) + _, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName) return err } diff --git a/userapi/storage/storage_test.go b/userapi/storage/storage_test.go index 354f085fc..23aafff03 100644 --- a/userapi/storage/storage_test.go +++ b/userapi/storage/storage_test.go @@ -50,25 +50,25 @@ func Test_AccountData(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() alice := test.NewUser(t) - localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) events := room.Events() contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID())) - err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom) + err = db.SaveAccountData(ctx, localpart, domain, room.ID, "m.fully_read", contentRoom) assert.NoError(t, err, "unable to save account data") contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID)) - err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal) + err = db.SaveAccountData(ctx, localpart, domain, "", "im.vector.setting.breadcrumbs", contentGlobal) assert.NoError(t, err, "unable to save account data") - accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read") + accountData, err := db.GetAccountDataByType(ctx, localpart, domain, room.ID, "m.fully_read") assert.NoError(t, err, "unable to get account data by type") assert.Equal(t, contentRoom, accountData) - globalData, roomData, err := db.GetAccountData(ctx, localpart) + globalData, roomData, err := db.GetAccountData(ctx, localpart, domain) assert.NoError(t, err) assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"]) assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"]) @@ -81,78 +81,78 @@ func Test_Accounts(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) - accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) + accAlice, err := db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") // verify the newly create account is the same as returned by CreateAccount var accGet *api.Account - accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing") + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "testing") assert.NoError(t, err, "failed to get account by password") assert.Equal(t, accAlice, accGet) - accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart) + accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to get account by localpart") assert.Equal(t, accAlice, accGet) // check account availability - available, err := db.CheckAccountAvailability(ctx, aliceLocalpart) + available, err := db.CheckAccountAvailability(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, false, available) - available, err = db.CheckAccountAvailability(ctx, "unusedname") + available, err = db.CheckAccountAvailability(ctx, "unusedname", aliceDomain) assert.NoError(t, err, "failed to checkout account availability") assert.Equal(t, true, available) // get guest account numeric aliceLocalpart - first, err := db.GetNewNumericLocalpart(ctx) + first, err := db.GetNewNumericLocalpart(ctx, aliceDomain) assert.NoError(t, err, "failed to get new numeric localpart") // Create a new account to verify the numeric localpart is updated - _, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest) + _, err = db.CreateAccount(ctx, "", aliceDomain, "testing", "", api.AccountTypeGuest) assert.NoError(t, err, "failed to create account") - second, err := db.GetNewNumericLocalpart(ctx) + second, err := db.GetNewNumericLocalpart(ctx, aliceDomain) assert.NoError(t, err) assert.Greater(t, second, first) // update password for alice - err = db.SetPassword(ctx, aliceLocalpart, "newPassword") + err = db.SetPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.NoError(t, err, "failed to update password") - accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.NoError(t, err, "failed to get account by new password") assert.Equal(t, accAlice, accGet) // deactivate account - err = db.DeactivateAccount(ctx, aliceLocalpart) + err = db.DeactivateAccount(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "failed to deactivate account") // This should fail now, as the account is deactivated - _, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword") + _, err = db.GetAccountByPassword(ctx, aliceLocalpart, aliceDomain, "newPassword") assert.Error(t, err, "expected an error, got none") - _, err = db.GetAccountByLocalpart(ctx, "unusename") + _, err = db.GetAccountByLocalpart(ctx, "unusename", aliceDomain) assert.Error(t, err, "expected an error for non existent localpart") // create an empty localpart; this should never happen, but is required to test getting a numeric localpart // if there's already a user without a localpart in the database - _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeUser) + _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeUser) assert.NoError(t, err) // test getting a numeric localpart, with an existing user without a localpart - _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) + _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest) assert.NoError(t, err) // Create a user with a high numeric localpart, out of range for the Postgres integer (2147483647) type - _, err = db.CreateAccount(ctx, "2147483650", "", "", api.AccountTypeUser) + _, err = db.CreateAccount(ctx, "2147483650", aliceDomain, "", "", api.AccountTypeUser) assert.NoError(t, err) // Now try to create a new guest user - _, err = db.CreateAccount(ctx, "", "", "", api.AccountTypeGuest) + _, err = db.CreateAccount(ctx, "", aliceDomain, "", "", api.AccountTypeGuest) assert.NoError(t, err) }) } func Test_Devices(t *testing.T) { alice := test.NewUser(t) - localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) deviceID := util.RandomString(8) accessToken := util.RandomString(16) @@ -161,10 +161,10 @@ func Test_Devices(t *testing.T) { db, close := mustCreateDatabase(t, dbType) defer close() - deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "") + deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") - gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID) + gotDevice, err := db.GetDeviceByID(ctx, localpart, domain, deviceID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields @@ -174,14 +174,14 @@ func Test_Devices(t *testing.T) { // create a device without existing device ID accessToken = util.RandomString(16) - deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "") + deviceWithoutID, err := db.CreateDevice(ctx, localpart, domain, nil, accessToken, nil, "", "") assert.NoError(t, err, "unable to create deviceWithoutID") - gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID) + gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, domain, deviceWithoutID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields // Get devices - devices, err := db.GetDevicesByLocalpart(ctx, localpart) + devices, err := db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get devices by localpart") assert.Equal(t, 2, len(devices)) deviceIDs := make([]string, 0, len(devices)) @@ -195,15 +195,15 @@ func Test_Devices(t *testing.T) { // Update device newName := "new display name" - err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName) + err = db.UpdateDevice(ctx, localpart, domain, deviceWithID.ID, &newName) assert.NoError(t, err, "unable to update device displayname") updatedAfterTimestamp := time.Now().Unix() - err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1", "Element Web") + err = db.UpdateDeviceLastSeen(ctx, localpart, domain, deviceWithID.ID, "127.0.0.1", "Element Web") assert.NoError(t, err, "unable to update device last seen") deviceWithID.DisplayName = newName deviceWithID.LastSeenIP = "127.0.0.1" - gotDevice, err = db.GetDeviceByID(ctx, localpart, deviceWithID.ID) + gotDevice, err = db.GetDeviceByID(ctx, localpart, domain, deviceWithID.ID) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 2, len(devices)) assert.Equal(t, deviceWithID.DisplayName, gotDevice.DisplayName) @@ -213,20 +213,20 @@ func Test_Devices(t *testing.T) { // create one more device and remove the devices step by step newDeviceID := util.RandomString(16) accessToken = util.RandomString(16) - _, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "") + _, err = db.CreateDevice(ctx, localpart, domain, &newDeviceID, accessToken, nil, "", "") assert.NoError(t, err, "unable to create new device") - devices, err = db.GetDevicesByLocalpart(ctx, localpart) + devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 3, len(devices)) - err = db.RemoveDevices(ctx, localpart, deviceIDs) + err = db.RemoveDevices(ctx, localpart, domain, deviceIDs) assert.NoError(t, err, "unable to remove devices") - devices, err = db.GetDevicesByLocalpart(ctx, localpart) + devices, err = db.GetDevicesByLocalpart(ctx, localpart, domain) assert.NoError(t, err, "unable to get device by id") assert.Equal(t, 1, len(devices)) - deleted, err := db.RemoveAllDevices(ctx, localpart, "") + deleted, err := db.RemoveAllDevices(ctx, localpart, domain, "") assert.NoError(t, err, "unable to remove all devices") assert.Equal(t, 1, len(deleted)) assert.Equal(t, newDeviceID, deleted[0].ID) @@ -364,7 +364,7 @@ func Test_OpenID(t *testing.T) { func Test_Profile(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -372,30 +372,33 @@ func Test_Profile(t *testing.T) { defer close() // create account, which also creates a profile - _, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin) + _, err = db.CreateAccount(ctx, aliceLocalpart, aliceDomain, "testing", "", api.AccountTypeAdmin) assert.NoError(t, err, "failed to create account") - gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart) + gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get profile by localpart") - wantProfile := &authtypes.Profile{Localpart: aliceLocalpart} + wantProfile := &authtypes.Profile{ + Localpart: aliceLocalpart, + ServerName: string(aliceDomain), + } assert.Equal(t, wantProfile, gotProfile) // set avatar & displayname wantProfile.DisplayName = "Alice" - gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, "Alice") + gotProfile, changed, err := db.SetDisplayName(ctx, aliceLocalpart, aliceDomain, "Alice") assert.Equal(t, wantProfile, gotProfile) assert.NoError(t, err, "unable to set displayname") assert.True(t, changed) wantProfile.AvatarURL = "mxc://aliceAvatar" - gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) assert.True(t, changed) // Setting the same avatar again doesn't change anything wantProfile.AvatarURL = "mxc://aliceAvatar" - gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar") + gotProfile, changed, err = db.SetAvatarURL(ctx, aliceLocalpart, aliceDomain, "mxc://aliceAvatar") assert.NoError(t, err, "unable to set avatar url") assert.Equal(t, wantProfile, gotProfile) assert.False(t, changed) @@ -410,7 +413,7 @@ func Test_Profile(t *testing.T) { func Test_Pusher(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -432,11 +435,11 @@ func Test_Pusher(t *testing.T) { ProfileTag: util.RandomString(8), Language: util.RandomString(2), } - err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart) + err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to upsert pusher") // check it was actually persisted - gotPushers, err = db.GetPushers(ctx, aliceLocalpart) + gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, i+1, len(gotPushers)) assert.Equal(t, wantPusher, gotPushers[i]) @@ -444,16 +447,16 @@ func Test_Pusher(t *testing.T) { } // remove single pusher - err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart) + err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to remove pusher") - gotPushers, err := db.GetPushers(ctx, aliceLocalpart) + gotPushers, err := db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, 1, len(gotPushers)) // remove last pusher err = db.RemovePushers(ctx, appID, pushKeys[1]) assert.NoError(t, err, "unable to remove pusher") - gotPushers, err = db.GetPushers(ctx, aliceLocalpart) + gotPushers, err = db.GetPushers(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get pushers") assert.Equal(t, 0, len(gotPushers)) }) @@ -461,7 +464,7 @@ func Test_Pusher(t *testing.T) { func Test_ThreePID(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { @@ -469,15 +472,16 @@ func Test_ThreePID(t *testing.T) { defer close() threePID := util.RandomString(8) medium := util.RandomString(8) - err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium) + err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, aliceDomain, medium) assert.NoError(t, err, "unable to save threepid association") // get the stored threepid - gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium) + gotLocalpart, gotDomain, err := db.GetLocalpartForThreePID(ctx, threePID, medium) assert.NoError(t, err, "unable to get localpart for threepid") assert.Equal(t, aliceLocalpart, gotLocalpart) + assert.Equal(t, aliceDomain, gotDomain) - threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 1, len(threepids)) assert.Equal(t, authtypes.ThreePID{ @@ -490,7 +494,7 @@ func Test_ThreePID(t *testing.T) { assert.NoError(t, err, "unexpected error") // verify it was deleted - threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart) + threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart, aliceDomain) assert.NoError(t, err, "unable to get threepids for localpart") assert.Equal(t, 0, len(threepids)) }) @@ -498,7 +502,7 @@ func Test_ThreePID(t *testing.T) { func Test_Notification(t *testing.T) { alice := test.NewUser(t) - aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID) + aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID) assert.NoError(t, err) room := test.NewRoom(t, alice) room2 := test.NewRoom(t, alice) @@ -526,34 +530,34 @@ func Test_Notification(t *testing.T) { RoomID: roomID, TS: gomatrixserverlib.AsTimestamp(ts), } - err = db.InsertNotification(ctx, aliceLocalpart, eventID, uint64(i+1), nil, notification) + err = db.InsertNotification(ctx, aliceLocalpart, aliceDomain, eventID, uint64(i+1), nil, notification) assert.NoError(t, err, "unable to insert notification") } // get notifications - count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications) + count, err := db.GetNotificationCount(ctx, aliceLocalpart, aliceDomain, tables.AllNotifications) assert.NoError(t, err, "unable to get notification count") assert.Equal(t, int64(10), count) - notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications) + notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, aliceDomain, 0, 15, tables.AllNotifications) assert.NoError(t, err, "unable to get notifications") assert.Equal(t, int64(10), count) assert.Equal(t, 10, len(notifs)) // ... for a specific room - total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(4), total) // mark notification as read - affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true) + affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, aliceDomain, room2.ID, 7, true) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) // this should delete 2 notifications - affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8) + affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, aliceDomain, room2.ID, 8) assert.NoError(t, err, "unable to set notifications read") assert.True(t, affected) - total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(2), total) @@ -562,7 +566,7 @@ func Test_Notification(t *testing.T) { assert.NoError(t, err) // this should now return 0 notifications - total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID) + total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, aliceDomain, room2.ID) assert.NoError(t, err, "unable to get notifications for room") assert.Equal(t, int64(0), total) }) diff --git a/userapi/storage/tables/interface.go b/userapi/storage/tables/interface.go index 5e1dd0971..e14776cf3 100644 --- a/userapi/storage/tables/interface.go +++ b/userapi/storage/tables/interface.go @@ -28,31 +28,31 @@ import ( ) type AccountDataTable interface { - InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error - SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) - SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) + InsertAccountData(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string, content json.RawMessage) error + SelectAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error) + SelectAccountDataByType(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID, dataType string) (data json.RawMessage, err error) } type AccountsTable interface { - InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) - UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error) - DeactivateAccount(ctx context.Context, localpart string) (err error) - SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error) - SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) - SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error) + InsertAccount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, hash, appserviceID string, accountType api.AccountType) (*api.Account, error) + UpdatePassword(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, passwordHash string) (err error) + DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) + SelectPasswordHash(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (hash string, err error) + SelectAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*api.Account, error) + SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (id int64, err error) } type DevicesTable interface { - InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) - DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error - DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error - DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error - UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error + InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error) + DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string, serverName gomatrixserverlib.ServerName) error + DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, devices []string) error + DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) error + UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID string, displayName *string) error SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error) - SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) - SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error) + SelectDeviceByID(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID string) (*api.Device, error) + SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, exceptDeviceID string) ([]api.Device, error) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) - UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error + UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error } type KeyBackupTable interface { @@ -79,40 +79,40 @@ type LoginTokenTable interface { } type OpenIDTable interface { - InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error) + InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, serverName gomatrixserverlib.ServerName, expiresAtMS int64) (err error) SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error) } type ProfileTable interface { - InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error - SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) - SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (*authtypes.Profile, bool, error) - SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (*authtypes.Profile, bool, error) + InsertProfile(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) error + SelectProfileByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (*authtypes.Profile, error) + SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, avatarURL string) (*authtypes.Profile, bool, error) + SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, displayName string) (*authtypes.Profile, bool, error) SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) } type ThreePIDTable interface { - SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error) - SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error) - InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error) + SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, serverName gomatrixserverlib.ServerName, err error) + SelectThreePIDsForLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (threepids []authtypes.ThreePID, err error) + InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, serverName gomatrixserverlib.ServerName) (err error) DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) } type PusherTable interface { - InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string) error - SelectPushers(ctx context.Context, txn *sql.Tx, localpart string) ([]api.Pusher, error) - DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string) error + InsertPusher(ctx context.Context, txn *sql.Tx, session_id int64, pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string, serverName gomatrixserverlib.ServerName) error + SelectPushers(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName) ([]api.Pusher, error) + DeletePusher(ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName) error DeletePushers(ctx context.Context, txn *sql.Tx, appid, pushkey string) error } type NotificationTable interface { Clean(ctx context.Context, txn *sql.Tx) error - Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error - DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) - UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) - Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) - SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter NotificationFilter) (int64, error) - SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, _ error) + Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error + DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) + UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) + Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter NotificationFilter) ([]*api.Notification, int64, error) + SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter NotificationFilter) (int64, error) + SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) } type StatsTable interface { diff --git a/userapi/storage/tables/stats_table_test.go b/userapi/storage/tables/stats_table_test.go index a547423bc..b088d15cd 100644 --- a/userapi/storage/tables/stats_table_test.go +++ b/userapi/storage/tables/stats_table_test.go @@ -79,6 +79,7 @@ func mustMakeAccountAndDevice( accDB tables.AccountsTable, devDB tables.DevicesTable, localpart string, + serverName gomatrixserverlib.ServerName, // nolint:unparam accType api.AccountType, userAgent string, ) { @@ -89,11 +90,11 @@ func mustMakeAccountAndDevice( appServiceID = util.RandomString(16) } - _, err := accDB.InsertAccount(ctx, nil, localpart, "", appServiceID, accType) + _, err := accDB.InsertAccount(ctx, nil, localpart, serverName, "", appServiceID, accType) if err != nil { t.Fatalf("unable to create account: %v", err) } - _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, util.RandomString(16), nil, "", userAgent) + _, err = devDB.InsertDevice(ctx, nil, "deviceID", localpart, serverName, util.RandomString(16), nil, "", userAgent) if err != nil { t.Fatalf("unable to create device: %v", err) } @@ -150,12 +151,12 @@ func Test_UserStatistics(t *testing.T) { }) t.Run("Want Users", func(t *testing.T) { - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", api.AccountTypeUser, "Element Android") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", api.AccountTypeUser, "Element iOS") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", api.AccountTypeUser, "Element web") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", api.AccountTypeGuest, "Element Electron") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", api.AccountTypeAdmin, "gecko") - mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", api.AccountTypeAppService, "gecko") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user1", "localhost", api.AccountTypeUser, "Element Android") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user2", "localhost", api.AccountTypeUser, "Element iOS") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user3", "localhost", api.AccountTypeUser, "Element web") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user4", "localhost", api.AccountTypeGuest, "Element Electron") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user5", "localhost", api.AccountTypeAdmin, "gecko") + mustMakeAccountAndDevice(t, ctx, accDB, devDB, "user6", "localhost", api.AccountTypeAppService, "gecko") gotStats, _, err := statsDB.UserStatistics(ctx, nil) if err != nil { t.Fatalf("unexpected error: %v", err) diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 2a43c0bd4..23acfa6f3 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -80,14 +80,14 @@ func TestQueryProfile(t *testing.T) { // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) defer close() - _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "", api.AccountTypeUser) + _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } - if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", aliceAvatarURL); err != nil { + if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { t.Fatalf("failed to set avatar url: %s", err) } - if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", aliceDisplayName); err != nil { + if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { t.Fatalf("failed to set display name: %s", err) } @@ -164,7 +164,7 @@ func TestPasswordlessLoginFails(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) defer close() - _, err := accountDB.CreateAccount(ctx, "auser", "", "", api.AccountTypeAppService) + _, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService) if err != nil { t.Fatalf("failed to make account: %s", err) } @@ -190,7 +190,7 @@ func TestLoginToken(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) defer close() - _, err := accountDB.CreateAccount(ctx, "auser", "apassword", "", api.AccountTypeUser) + _, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser) if err != nil { t.Fatalf("failed to make account: %s", err) } diff --git a/userapi/util/devices.go b/userapi/util/devices.go index cbf3bd28f..c55fc7999 100644 --- a/userapi/util/devices.go +++ b/userapi/util/devices.go @@ -2,10 +2,12 @@ package util import ( "context" + "fmt" "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage" + "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -17,10 +19,10 @@ type PusherDevice struct { } // GetPushDevices pushes to the configured devices of a local user. -func GetPushDevices(ctx context.Context, localpart string, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { - pushers, err := db.GetPushers(ctx, localpart) +func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) { + pushers, err := db.GetPushers(ctx, localpart, serverName) if err != nil { - return nil, err + return nil, fmt.Errorf("db.GetPushers: %w", err) } devices := make([]*PusherDevice, 0, len(pushers)) diff --git a/userapi/util/notify.go b/userapi/util/notify.go index ff206bd3c..fc0ab39bf 100644 --- a/userapi/util/notify.go +++ b/userapi/util/notify.go @@ -8,6 +8,7 @@ import ( "github.com/matrix-org/dendrite/internal/pushgateway" "github.com/matrix-org/dendrite/userapi/storage" "github.com/matrix-org/dendrite/userapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" ) @@ -16,8 +17,8 @@ import ( // a single goroutine is started when talking to the Push // gateways. There is no way to know when the background goroutine has // finished. -func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, db storage.Database) error { - pusherDevices, err := GetPushDevices(ctx, localpart, nil, db) +func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error { + pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db) if err != nil { return err } @@ -26,7 +27,7 @@ func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, loc return nil } - userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, tables.AllNotifications) + userNumUnreadNotifs, err := db.GetNotificationCount(ctx, localpart, serverName, tables.AllNotifications) if err != nil { return err } From 1e79b0557e177986262f365f6601fcdf56480a07 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 14 Nov 2022 13:06:27 +0100 Subject: [PATCH 15/42] Use a writer to assign state key NIDs (#2877) --- roomserver/storage/shared/storage.go | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 08e912a00..9bf438a23 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -103,6 +103,7 @@ func (d *Database) eventStateKeyNIDs( ctx context.Context, txn *sql.Tx, eventStateKeys []string, ) (map[string]types.EventStateKeyNID, error) { result := make(map[string]types.EventStateKeyNID) + eventStateKeys = util.UniqueStrings(eventStateKeys) nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys) if err != nil { return nil, err @@ -112,15 +113,23 @@ func (d *Database) eventStateKeyNIDs( } // We received some nids, but are still missing some, work out which and create them if len(eventStateKeys) > len(result) { - for _, eventStateKey := range eventStateKeys { - if _, ok := result[eventStateKey]; ok { - continue + var nid types.EventStateKeyNID + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + for _, eventStateKey := range eventStateKeys { + if _, ok := result[eventStateKey]; ok { + continue + } + + nid, err = d.assignStateKeyNID(ctx, txn, eventStateKey) + if err != nil { + return err + } + result[eventStateKey] = nid } - nid, err := d.assignStateKeyNID(ctx, txn, eventStateKey) - if err != nil { - return result, err - } - result[eventStateKey] = nid + return nil + }) + if err != nil { + return nil, err } } return result, nil From 858a4af2244986356576b1cf97572275a8bc001f Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 14 Nov 2022 13:06:41 +0100 Subject: [PATCH 16/42] Try to optimize CI (#2867) Try to optimize CI by using caches --- .github/workflows/dendrite.yml | 80 ++++++++------------- build/scripts/Complement.Dockerfile | 7 +- build/scripts/ComplementPostgres.Dockerfile | 7 +- 3 files changed, 38 insertions(+), 56 deletions(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index f96cbadd4..fa4282384 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -26,22 +26,14 @@ jobs: uses: actions/setup-go@v3 with: go-version: 1.18 - - - uses: actions/cache@v2 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-wasm-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go-wasm + cache: true - name: Install Node uses: actions/setup-node@v2 with: node-version: 14 - - uses: actions/cache@v2 + - uses: actions/cache@v3 with: path: ~/.npm key: ${{ runner.os }}-node-${{ hashFiles('**/package-lock.json') }} @@ -109,19 +101,12 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} + cache: true - name: Set up gotestfmt uses: gotesttools/gotestfmt-action@v2 with: # Optional: pass GITHUB_TOKEN to avoid rate limiting. token: ${{ secrets.GITHUB_TOKEN }} - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-test-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-test- - run: go test -json -v ./... 2>&1 | gotestfmt env: POSTGRES_HOST: localhost @@ -146,17 +131,17 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - - name: Install dependencies x86 - if: ${{ matrix.goarch == '386' }} - run: sudo apt update && sudo apt-get install -y gcc-multilib - uses: actions/cache@v3 with: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goarch }}- + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}- + - name: Install dependencies x86 + if: ${{ matrix.goarch == '386' }} + run: sudo apt update && sudo apt-get install -y gcc-multilib - env: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} @@ -180,16 +165,16 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - - name: Install dependencies - run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc - uses: actions/cache@v3 with: path: | ~/.cache/go-build ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }}-${{ hashFiles('**/go.sum') }} + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}-${{ hashFiles('**/go.sum') }} restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-${{ matrix.goos }} + key: ${{ runner.os }}-go${{ matrix.go }}${{ matrix.goos }}-${{ matrix.goarch }}- + - name: Install dependencies + run: sudo apt update && sudo apt install -y gcc-mingw-w64-x86-64 # install required gcc - env: GOOS: ${{ matrix.goos }} GOARCH: ${{ matrix.goarch }} @@ -221,14 +206,7 @@ jobs: uses: actions/setup-go@v3 with: go-version: "1.18" - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go-upgrade + cache: true - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests - name: Test upgrade (PostgreSQL) @@ -248,14 +226,7 @@ jobs: uses: actions/setup-go@v3 with: go-version: "1.18" - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go-upgrade-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go-upgrade + cache: true - name: Build upgrade-tests run: go build ./cmd/dendrite-upgrade-tests - name: Test upgrade (PostgreSQL) @@ -295,6 +266,8 @@ jobs: image: matrixdotorg/sytest-dendrite:latest volumes: - ${{ github.workspace }}:/src + - /root/.cache/go-build:/github/home/.cache/go-build + - /root/.cache/go-mod:/gopath/pkg/mod env: POSTGRES: ${{ matrix.postgres && 1}} API: ${{ matrix.api && 1 }} @@ -302,6 +275,14 @@ jobs: CGO_ENABLED: ${{ matrix.cgo && 1 }} steps: - uses: actions/checkout@v3 + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + /gopath/pkg/mod + key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go-sytest- - name: Run Sytest run: /bootstrap.sh dendrite working-directory: /src @@ -336,12 +317,14 @@ jobs: matrix: include: - label: SQLite native + cgo: 0 - label: SQLite Cgo cgo: 1 - label: SQLite native, full HTTP APIs api: full-http + cgo: 0 - label: SQLite Cgo, full HTTP APIs api: full-http @@ -349,10 +332,12 @@ jobs: - label: PostgreSQL postgres: Postgres + cgo: 0 - label: PostgreSQL, full HTTP APIs postgres: Postgres api: full-http + cgo: 0 steps: # Env vars are set file a file given by $GITHUB_PATH. We need both Go 1.17 and GOPATH on env to run Complement. # See https://docs.github.com/en/actions/using-workflows/workflow-commands-for-github-actions#adding-a-system-path @@ -360,14 +345,12 @@ jobs: run: | echo "$GOROOT_1_17_X64/bin" >> $GITHUB_PATH echo "~/go/bin" >> $GITHUB_PATH - - name: "Install Complement Dependencies" # We don't need to install Go because it is included on the Ubuntu 20.04 image: # See https://github.com/actions/virtual-environments/blob/main/images/linux/Ubuntu2004-Readme.md specifically GOROOT_1_17_X64 run: | sudo apt-get update && sudo apt-get install -y libolm3 libolm-dev go get -v github.com/gotesttools/gotestfmt/v2/cmd/gotestfmt@latest - - name: Run actions/checkout@v3 for dendrite uses: actions/checkout@v3 with: @@ -393,12 +376,10 @@ jobs: if [[ -z "$BRANCH_NAME" || $BRANCH_NAME =~ ^refs/pull/.* ]]; then continue fi - (wget -O - "https://github.com/matrix-org/complement/archive/$BRANCH_NAME.tar.gz" | tar -xz --strip-components=1 -C complement) && break done - # Build initial Dendrite image - - run: docker build -t complement-dendrite -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . + - run: docker build --build-arg=CGO=${{ matrix.cgo }} -t complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} -f build/scripts/Complement${{ matrix.postgres }}.Dockerfile . working-directory: dendrite env: DOCKER_BUILDKIT: 1 @@ -410,9 +391,8 @@ jobs: shell: bash name: Run Complement Tests env: - COMPLEMENT_BASE_IMAGE: complement-dendrite:latest + COMPLEMENT_BASE_IMAGE: complement-dendrite:${{ matrix.postgres }}${{ matrix.api }}${{ matrix.cgo }} API: ${{ matrix.api && 1 }} - CGO_ENABLED: ${{ matrix.cgo && 1 }} working-directory: complement integration-tests-done: diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 14b28498b..79422e645 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -10,12 +10,13 @@ RUN mkdir /dendrite # Utilise Docker caching when downloading dependencies, this stops us needlessly # downloading dependencies every time. +ARG CGO RUN --mount=target=. \ --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ - go build -o /dendrite ./cmd/generate-config && \ - go build -o /dendrite ./cmd/generate-keys && \ - go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index 785090b0b..3faf43cc7 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -28,12 +28,13 @@ RUN mkdir /dendrite # Utilise Docker caching when downloading dependencies, this stops us needlessly # downloading dependencies every time. +ARG CGO RUN --mount=target=. \ --mount=type=cache,target=/go/pkg/mod \ --mount=type=cache,target=/root/.cache/go-build \ - go build -o /dendrite ./cmd/generate-config && \ - go build -o /dendrite ./cmd/generate-keys && \ - go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem From 2a77a910eb16333e56b85a7fdb3d6254acc9b6d7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 14 Nov 2022 13:07:13 +0100 Subject: [PATCH 17/42] Handle remote room upgrades (#2866) Makes the following tests pass ``` /upgrade moves remote aliases to the new room Local and remote users' homeservers remove a room from their public directory on upgrade ``` --- roomserver/internal/input/input_events.go | 16 ++++++++++++ roomserver/storage/interface.go | 1 + roomserver/storage/shared/storage.go | 30 +++++++++++++++++++++++ sytest-whitelist | 4 ++- 4 files changed, 50 insertions(+), 1 deletion(-) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 60160e8e5..682aa2b12 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -23,6 +23,8 @@ import ( "fmt" "time" + "github.com/tidwall/gjson" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/opentracing/opentracing-go" @@ -409,6 +411,13 @@ func (r *Inputer) processRoomEvent( } } + // Handle remote room upgrades, e.g. remove published room + if event.Type() == "m.room.tombstone" && event.StateKeyEquals("") && !r.Cfg.Matrix.IsLocalServerName(senderDomain) { + if err = r.handleRemoteRoomUpgrade(ctx, event); err != nil { + return fmt.Errorf("failed to handle remote room upgrade: %w", err) + } + } + // processing this event resulted in an event (which may not be the one we're processing) // being redacted. We are guaranteed to have both sides (the redaction/redacted event), // so notify downstream components to redact this event - they should have it if they've @@ -434,6 +443,13 @@ func (r *Inputer) processRoomEvent( return nil } +// handleRemoteRoomUpgrade updates published rooms and room aliases +func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixserverlib.Event) error { + oldRoomID := event.RoomID() + newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str + return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender()) +} + // processStateBefore works out what the state is before the event and // then checks the event auths against the state at the time. It also // tries to determine what the history visibility was of the event at diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index c39a8cbba..06db4b2d8 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -172,4 +172,5 @@ type Database interface { ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) + UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 9bf438a23..16898bcb1 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1408,6 +1408,36 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget }) } +func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error { + + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // un-publish old room + if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil { + return fmt.Errorf("failed to unpublish room: %w", err) + } + // publish new room + if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil { + return fmt.Errorf("failed to publish room: %w", err) + } + + // Migrate any existing room aliases + aliases, err := d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, txn, oldRoomID) + if err != nil { + return fmt.Errorf("failed to get room aliases: %w", err) + } + + for _, alias := range aliases { + if err = d.RoomAliasesTable.DeleteRoomAlias(ctx, txn, alias); err != nil { + return fmt.Errorf("failed to remove room alias: %w", err) + } + if err = d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, newRoomID, eventSender); err != nil { + return fmt.Errorf("failed to set room alias: %w", err) + } + } + return nil + }) +} + // FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops // it should live in this package! diff --git a/sytest-whitelist b/sytest-whitelist index bb4f0a279..49ffb8fe8 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -761,4 +761,6 @@ AS can publish rooms in their own list AS and main public room lists are separate /upgrade preserves direct room state local user has tags copied to the new room -remote user has tags copied to the new room \ No newline at end of file +remote user has tags copied to the new room +/upgrade moves remote aliases to the new room +Local and remote users' homeservers remove a room from their public directory on upgrade \ No newline at end of file From f4ee3977340c84d321767d347795b1dcd05ac459 Mon Sep 17 00:00:00 2001 From: Omar Kotb Date: Mon, 14 Nov 2022 19:15:39 +0200 Subject: [PATCH 18/42] Fix Caddy config well-known delegation example (#2879) Signed-off-by: Omar Kotb Signed-off-by: Omar Kotb --- docs/installation/2_domainname.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/installation/2_domainname.md b/docs/installation/2_domainname.md index e7b3495f7..545a2daf6 100644 --- a/docs/installation/2_domainname.md +++ b/docs/installation/2_domainname.md @@ -90,7 +90,7 @@ For example, this can be done with the following Caddy config: handle /.well-known/matrix/server { header Content-Type application/json header Access-Control-Allow-Origin * - respond `"m.server": "matrix.example.com:8448"` + respond `{"m.server": "matrix.example.com:8448"}` } handle /.well-known/matrix/client { From 6650712a1c0dec282b47b7ba14bc8c2e06a385d8 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Nov 2022 15:05:23 +0000 Subject: [PATCH 19/42] Federation fixes for virtual hosting --- clientapi/routing/admin.go | 2 +- clientapi/routing/createroom.go | 2 +- clientapi/routing/directory.go | 2 +- clientapi/routing/directory_public.go | 2 +- clientapi/routing/membership.go | 8 +- clientapi/routing/profile.go | 14 +++- clientapi/routing/redaction.go | 9 +- clientapi/routing/register.go | 39 ++++++--- clientapi/routing/sendevent.go | 9 +- clientapi/routing/server_notices.go | 1 + clientapi/routing/userdirectory.go | 2 +- clientapi/threepid/invites.go | 8 +- cmd/dendrite-demo-pinecone/conn/client.go | 4 +- cmd/dendrite-demo-pinecone/rooms/rooms.go | 8 +- cmd/dendrite-demo-yggdrasil/yggconn/client.go | 3 +- .../yggrooms/yggrooms.go | 9 +- cmd/furl/main.go | 12 ++- federationapi/api/api.go | 54 ++++++------ federationapi/federationapi.go | 12 +-- federationapi/federationapi_keys_test.go | 2 +- federationapi/federationapi_test.go | 8 +- federationapi/internal/federationclient.go | 44 +++++----- federationapi/internal/perform.go | 38 +++++++-- federationapi/inthttp/client.go | 52 ++++++++---- federationapi/inthttp/server.go | 22 ++--- federationapi/queue/destinationqueue.go | 2 +- federationapi/queue/queue.go | 18 ++-- federationapi/queue/queue_test.go | 4 +- federationapi/routing/backfill.go | 5 +- federationapi/routing/invite.go | 17 +++- federationapi/routing/join.go | 12 ++- federationapi/routing/keys.go | 67 +++++++++------ federationapi/routing/leave.go | 13 ++- federationapi/routing/profile.go | 10 +-- federationapi/routing/query.go | 2 +- federationapi/routing/send.go | 9 +- federationapi/routing/send_test.go | 8 +- federationapi/routing/threepid.go | 27 +++++- go.mod | 2 +- go.sum | 4 +- internal/eventutil/events.go | 15 ++-- keyserver/consumers/devicelistupdate.go | 26 +++--- keyserver/consumers/signingkeyupdate.go | 28 ++++--- keyserver/internal/device_list_update.go | 6 +- keyserver/internal/device_list_update_test.go | 14 +++- keyserver/internal/internal.go | 4 +- keyserver/keyserver.go | 2 +- roomserver/api/input.go | 5 +- roomserver/api/perform.go | 2 + roomserver/api/wrapper.go | 11 ++- roomserver/internal/alias.go | 19 ++++- roomserver/internal/api.go | 29 ++++--- roomserver/internal/input/input.go | 7 +- roomserver/internal/input/input_events.go | 27 +++--- roomserver/internal/input/input_missing.go | 17 ++-- roomserver/internal/perform/perform_admin.go | 25 +++++- .../internal/perform/perform_backfill.go | 50 ++++++----- roomserver/internal/perform/perform_join.go | 24 ++++-- roomserver/internal/perform/perform_leave.go | 14 ++-- .../internal/perform/perform_upgrade.go | 21 ++++- roomserver/internal/query/query.go | 10 +-- roomserver/roomserver_test.go | 2 +- setup/base/base.go | 7 +- setup/config/config.go | 15 ++++ setup/config/config_global.go | 82 ++++++++++++++++++- setup/mscs/msc2836/msc2836.go | 6 +- setup/mscs/msc2946/msc2946.go | 2 +- syncapi/consumers/keychange.go | 35 ++++---- syncapi/consumers/receipts.go | 30 ++++--- syncapi/consumers/roomserver.go | 10 +-- syncapi/consumers/sendtodevice.go | 40 ++++----- syncapi/routing/messages.go | 1 + syncapi/syncapi_test.go | 4 +- 73 files changed, 736 insertions(+), 420 deletions(-) diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 9ed1f0ca2..be8073c33 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -110,7 +110,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.MissingArgument("Expecting user localpart."), } } - if l, s, err := gomatrixserverlib.SplitID('@', localpart); err == nil { + if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil { localpart, serverName = l, s } request := struct { diff --git a/clientapi/routing/createroom.go b/clientapi/routing/createroom.go index eefe8e24b..a0d80903d 100644 --- a/clientapi/routing/createroom.go +++ b/clientapi/routing/createroom.go @@ -477,7 +477,7 @@ func createRoom( SendAsServer: roomserverAPI.DoNotSendToOtherServers, }) } - if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, inputs, false); err != nil { + if err = roomserverAPI.SendInputRoomEvents(ctx, rsAPI, device.UserDomain(), inputs, false); err != nil { util.GetLogger(ctx).WithError(err).Error("roomserverAPI.SendInputRoomEvents failed") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/directory.go b/clientapi/routing/directory.go index ce14745aa..b3c5aae45 100644 --- a/clientapi/routing/directory.go +++ b/clientapi/routing/directory.go @@ -77,7 +77,7 @@ func DirectoryRoom( // If we don't know it locally, do a federation query. // But don't send the query to ourselves. if !cfg.Matrix.IsLocalServerName(domain) { - fedRes, fedErr := federation.LookupRoomAlias(req.Context(), domain, roomAlias) + fedRes, fedErr := federation.LookupRoomAlias(req.Context(), cfg.Matrix.ServerName, domain, roomAlias) if fedErr != nil { // TODO: Return 502 if the remote server errored. // TODO: Return 504 if the remote server timed out. diff --git a/clientapi/routing/directory_public.go b/clientapi/routing/directory_public.go index b1043e994..606744767 100644 --- a/clientapi/routing/directory_public.go +++ b/clientapi/routing/directory_public.go @@ -74,7 +74,7 @@ func GetPostPublicRooms( serverName := gomatrixserverlib.ServerName(request.Server) if serverName != "" && !cfg.Matrix.IsLocalServerName(serverName) { res, err := federation.GetPublicRoomsFiltered( - req.Context(), serverName, + req.Context(), cfg.Matrix.ServerName, serverName, int(request.Limit), request.Since, request.Filter.SearchTerms, false, "", diff --git a/clientapi/routing/membership.go b/clientapi/routing/membership.go index 94ba17a02..482c1f5f7 100644 --- a/clientapi/routing/membership.go +++ b/clientapi/routing/membership.go @@ -110,6 +110,7 @@ func sendMembership(ctx context.Context, profileAPI userapi.ClientUserAPI, devic ctx, rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{event.Event.Headered(roomVer)}, + device.UserDomain(), serverName, serverName, nil, @@ -322,7 +323,12 @@ func buildMembershipEvent( return nil, err } - return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return nil, err + } + + return eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil) } // loadProfile lookups the profile of a given user from the database and returns diff --git a/clientapi/routing/profile.go b/clientapi/routing/profile.go index 4d9e1f8a5..92a75fc78 100644 --- a/clientapi/routing/profile.go +++ b/clientapi/routing/profile.go @@ -284,7 +284,7 @@ func updateProfile( } events, err := buildMembershipEvents( - ctx, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, + ctx, device, res.RoomIDs, *profile, userID, cfg, evTime, rsAPI, ) switch e := err.(type) { case nil: @@ -298,7 +298,7 @@ func updateProfile( return jsonerror.InternalServerError(), e } - if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, domain, domain, nil, true); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, events, device.UserDomain(), domain, domain, nil, true); err != nil { util.GetLogger(ctx).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError(), err } @@ -321,7 +321,7 @@ func getProfile( } if !cfg.Matrix.IsLocalServerName(domain) { - profile, fedErr := federation.LookupProfile(ctx, domain, userID, "") + profile, fedErr := federation.LookupProfile(ctx, cfg.Matrix.ServerName, domain, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { @@ -349,6 +349,7 @@ func getProfile( func buildMembershipEvents( ctx context.Context, + device *userapi.Device, roomIDs []string, newProfile authtypes.Profile, userID string, cfg *config.ClientAPI, evTime time.Time, rsAPI api.ClientRoomserverAPI, @@ -380,7 +381,12 @@ func buildMembershipEvents( return nil, err } - event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, nil) + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return nil, err + } + + event, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, nil) if err != nil { return nil, err } diff --git a/clientapi/routing/redaction.go b/clientapi/routing/redaction.go index 778a02fd4..7841b3b07 100644 --- a/clientapi/routing/redaction.go +++ b/clientapi/routing/redaction.go @@ -123,8 +123,13 @@ func SendRedaction( return jsonerror.InternalServerError() } + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return jsonerror.InternalServerError() + } + var queryRes roomserverAPI.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(req.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, @@ -132,7 +137,7 @@ func SendRedaction( } } domain := device.UserDomain() - if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, domain, domain, nil, false); err != nil { + if err = roomserverAPI.SendEvents(context.Background(), rsAPI, roomserverAPI.KindNew, []*gomatrixserverlib.HeaderedEvent{e}, device.UserDomain(), domain, domain, nil, false); err != nil { util.GetLogger(req.Context()).WithError(err).Errorf("failed to SendEvents") return jsonerror.InternalServerError() } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 9dc63af84..a92513b8b 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -211,9 +211,10 @@ var ( // previous parameters with the ones supplied. This mean you cannot "build up" request params. type registerRequest struct { // registration parameters - Password string `json:"password"` - Username string `json:"username"` - Admin bool `json:"admin"` + Password string `json:"password"` + Username string `json:"username"` + ServerName gomatrixserverlib.ServerName `json:"-"` + Admin bool `json:"admin"` // user-interactive auth params Auth authDict `json:"auth"` @@ -570,11 +571,14 @@ func Register( JSON: response, } } - } if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } + r.ServerName = cfg.Matrix.ServerName + if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { + r.Username, r.ServerName = l, d + } if req.URL.Query().Get("kind") == "guest" { return handleGuestRegistration(req, r, cfg, userAPI) } @@ -589,7 +593,7 @@ func Register( // Auto generate a numeric username if r.Username is empty if r.Username == "" { nreq := &userapi.QueryNumericLocalpartRequest{ - ServerName: cfg.Matrix.ServerName, // TODO: might not be right + ServerName: r.ServerName, } nres := &userapi.QueryNumericLocalpartResponse{} if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { @@ -609,7 +613,7 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr := validateApplicationServiceUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { + if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { return *resErr } case accessTokenErr == nil: @@ -622,7 +626,7 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr := validateUsername(r.Username, cfg.Matrix.ServerName); resErr != nil { + if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { return *resErr } } @@ -1023,13 +1027,27 @@ func RegisterAvailable( // Squash username to all lowercase letters username = strings.ToLower(username) + domain := cfg.Matrix.ServerName + if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil { + username, domain = u, l + } + for _, v := range cfg.Matrix.VirtualHosts { + if v.ServerName == domain && !v.AllowRegistration { + return util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden( + fmt.Sprintf("Registration is not allowed on %q", string(v.ServerName)), + ), + } + } + } - if err := validateUsername(username, cfg.Matrix.ServerName); err != nil { + if err := validateUsername(username, domain); err != nil { return *err } // Check if this username is reserved by an application service - userID := userutil.MakeUserID(username, cfg.Matrix.ServerName) + userID := userutil.MakeUserID(username, domain) for _, appservice := range cfg.Derived.ApplicationServices { if appservice.OwnsNamespaceCoveringUserId(userID) { return util.JSONResponse{ @@ -1041,7 +1059,8 @@ func RegisterAvailable( res := &userapi.QueryAccountAvailabilityResponse{} err := registerAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ - Localpart: username, + Localpart: username, + ServerName: domain, }, res) if err != nil { return util.JSONResponse{ diff --git a/clientapi/routing/sendevent.go b/clientapi/routing/sendevent.go index bb66cf6fc..90af9ac4d 100644 --- a/clientapi/routing/sendevent.go +++ b/clientapi/routing/sendevent.go @@ -186,6 +186,7 @@ func SendEvent( []*gomatrixserverlib.HeaderedEvent{ e.Headered(verRes.RoomVersion), }, + device.UserDomain(), domain, domain, txnAndSessionID, @@ -275,8 +276,14 @@ func generateSendEvent( return nil, &resErr } + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + resErr := jsonerror.InternalServerError() + return nil, &resErr + } + var queryRes api.QueryLatestEventsAndStateResponse - e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, evTime, rsAPI, &queryRes) + e, err := eventutil.QueryAndBuildEvent(ctx, &builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &util.JSONResponse{ Code: http.StatusNotFound, diff --git a/clientapi/routing/server_notices.go b/clientapi/routing/server_notices.go index a7acee326..fb93d8783 100644 --- a/clientapi/routing/server_notices.go +++ b/clientapi/routing/server_notices.go @@ -231,6 +231,7 @@ func SendServerNotice( []*gomatrixserverlib.HeaderedEvent{ e.Headered(roomVersion), }, + device.UserDomain(), cfgClient.Matrix.ServerName, cfgClient.Matrix.ServerName, txnAndSessionID, diff --git a/clientapi/routing/userdirectory.go b/clientapi/routing/userdirectory.go index d3d1c22e4..62af9efa4 100644 --- a/clientapi/routing/userdirectory.go +++ b/clientapi/routing/userdirectory.go @@ -106,7 +106,7 @@ knownUsersLoop: continue } // TODO: We should probably cache/store this - fedProfile, fedErr := federation.LookupProfile(ctx, serverName, userID, "") + fedProfile, fedErr := federation.LookupProfile(ctx, localServerName, serverName, userID, "") if fedErr != nil { if x, ok := fedErr.(gomatrix.HTTPError); ok { if x.Code == http.StatusNotFound { diff --git a/clientapi/threepid/invites.go b/clientapi/threepid/invites.go index 99fb8171d..1f294a032 100644 --- a/clientapi/threepid/invites.go +++ b/clientapi/threepid/invites.go @@ -359,8 +359,13 @@ func emit3PIDInviteEvent( return err } + identity, err := cfg.Matrix.SigningIdentityFor(device.UserDomain()) + if err != nil { + return err + } + queryRes := api.QueryLatestEventsAndStateResponse{} - event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, evTime, rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(ctx, builder, cfg.Matrix, identity, evTime, rsAPI, &queryRes) if err != nil { return err } @@ -371,6 +376,7 @@ func emit3PIDInviteEvent( []*gomatrixserverlib.HeaderedEvent{ event.Headered(queryRes.RoomVersion), }, + device.UserDomain(), cfg.Matrix.ServerName, cfg.Matrix.ServerName, nil, diff --git a/cmd/dendrite-demo-pinecone/conn/client.go b/cmd/dendrite-demo-pinecone/conn/client.go index 27e18c2a3..a91434f62 100644 --- a/cmd/dendrite-demo-pinecone/conn/client.go +++ b/cmd/dendrite-demo-pinecone/conn/client.go @@ -101,9 +101,7 @@ func CreateFederationClient( base *base.BaseDendrite, s *pineconeSessions.Sessions, ) *gomatrixserverlib.FederationClient { return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.ServerName, - base.Cfg.Global.KeyID, - base.Cfg.Global.PrivateKey, + base.Cfg.Global.SigningIdentities(), gomatrixserverlib.WithTransport(createTransport(s)), ) } diff --git a/cmd/dendrite-demo-pinecone/rooms/rooms.go b/cmd/dendrite-demo-pinecone/rooms/rooms.go index 0fafbedc3..0ac705cc1 100644 --- a/cmd/dendrite-demo-pinecone/rooms/rooms.go +++ b/cmd/dendrite-demo-pinecone/rooms/rooms.go @@ -58,13 +58,17 @@ func (p *PineconeRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { for _, k := range p.r.Peers() { list[gomatrixserverlib.ServerName(k.PublicKey)] = struct{}{} } - return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, list) + return bulkFetchPublicRoomsFromServers( + context.Background(), p.fedClient, + gomatrixserverlib.ServerName(p.r.PublicKey().String()), list, + ) } // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( ctx context.Context, fedClient *gomatrixserverlib.FederationClient, + origin gomatrixserverlib.ServerName, homeservers map[gomatrixserverlib.ServerName]struct{}, ) (publicRooms []gomatrixserverlib.PublicRoom) { limit := 200 @@ -82,7 +86,7 @@ func bulkFetchPublicRoomsFromServers( go func(homeserverDomain gomatrixserverlib.ServerName) { defer wg.Done() util.GetLogger(reqctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") - fres, err := fedClient.GetPublicRooms(reqctx, homeserverDomain, int(limit), "", false, "") + fres, err := fedClient.GetPublicRooms(reqctx, origin, homeserverDomain, int(limit), "", false, "") if err != nil { util.GetLogger(reqctx).WithError(err).WithField("hs", homeserverDomain).Warn( "bulkFetchPublicRoomsFromServers: failed to query hs", diff --git a/cmd/dendrite-demo-yggdrasil/yggconn/client.go b/cmd/dendrite-demo-yggdrasil/yggconn/client.go index 358d3725e..41a9ec123 100644 --- a/cmd/dendrite-demo-yggdrasil/yggconn/client.go +++ b/cmd/dendrite-demo-yggdrasil/yggconn/client.go @@ -55,8 +55,7 @@ func (n *Node) CreateFederationClient( }, ) return gomatrixserverlib.NewFederationClient( - base.Cfg.Global.ServerName, base.Cfg.Global.KeyID, - base.Cfg.Global.PrivateKey, + base.Cfg.Global.SigningIdentities(), gomatrixserverlib.WithTransport(tr), ) } diff --git a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go index 402b86ed3..0de64755e 100644 --- a/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go +++ b/cmd/dendrite-demo-yggdrasil/yggrooms/yggrooms.go @@ -43,13 +43,18 @@ func NewYggdrasilRoomProvider( } func (p *YggdrasilRoomProvider) Rooms() []gomatrixserverlib.PublicRoom { - return bulkFetchPublicRoomsFromServers(context.Background(), p.fedClient, p.node.KnownNodes()) + return bulkFetchPublicRoomsFromServers( + context.Background(), p.fedClient, + gomatrixserverlib.ServerName(p.node.DerivedServerName()), + p.node.KnownNodes(), + ) } // bulkFetchPublicRoomsFromServers fetches public rooms from the list of homeservers. // Returns a list of public rooms. func bulkFetchPublicRoomsFromServers( ctx context.Context, fedClient *gomatrixserverlib.FederationClient, + origin gomatrixserverlib.ServerName, homeservers []gomatrixserverlib.ServerName, ) (publicRooms []gomatrixserverlib.PublicRoom) { limit := 200 @@ -66,7 +71,7 @@ func bulkFetchPublicRoomsFromServers( go func(homeserverDomain gomatrixserverlib.ServerName) { defer wg.Done() util.GetLogger(ctx).WithField("hs", homeserverDomain).Info("Querying HS for public rooms") - fres, err := fedClient.GetPublicRooms(ctx, homeserverDomain, int(limit), "", false, "") + fres, err := fedClient.GetPublicRooms(ctx, origin, homeserverDomain, int(limit), "", false, "") if err != nil { util.GetLogger(ctx).WithError(err).WithField("hs", homeserverDomain).Warn( "bulkFetchPublicRoomsFromServers: failed to query hs", diff --git a/cmd/furl/main.go b/cmd/furl/main.go index f59f9c8ce..b208ba868 100644 --- a/cmd/furl/main.go +++ b/cmd/furl/main.go @@ -48,10 +48,15 @@ func main() { panic("unexpected key block") } + serverName := gomatrixserverlib.ServerName(*requestFrom) client := gomatrixserverlib.NewFederationClient( - gomatrixserverlib.ServerName(*requestFrom), - gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), - privateKey, + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: serverName, + KeyID: gomatrixserverlib.KeyID(keyBlock.Headers["Key-ID"]), + PrivateKey: privateKey, + }, + }, ) u, err := url.Parse(flag.Arg(0)) @@ -79,6 +84,7 @@ func main() { req := gomatrixserverlib.NewFederationRequest( method, + serverName, gomatrixserverlib.ServerName(u.Host), u.RequestURI(), ) diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 362333fc9..e34c9e8b6 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -21,8 +21,8 @@ type FederationInternalAPI interface { QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) - MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) // Broadcasts an EDU to all servers in rooms we are joined to. Used in the yggdrasil demos. PerformBroadcastEDU( @@ -60,18 +60,18 @@ type RoomserverFederationAPI interface { // containing only the server names (without information for membership events). // The response will include this server if they are joined to the room. QueryJoinedHostServerNamesInRoom(ctx context.Context, request *QueryJoinedHostServerNamesInRoomRequest, response *QueryJoinedHostServerNamesInRoomResponse) error - GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) + GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError type KeyserverFederationAPI interface { - GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) - ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) - QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) + GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error) + ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error) + QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error) } // an interface for gmsl.FederationClient - contains functions called by federationapi only. @@ -80,28 +80,28 @@ type FederationClient interface { SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) // Perform operations - LookupRoomAlias(ctx context.Context, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) - Peek(ctx context.Context, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) - MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) - SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) - MakeLeave(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) - SendLeave(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) - SendInviteV2(ctx context.Context, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) + LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) + Peek(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, peekID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespPeek, err error) + MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) + SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) + MakeLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string) (res gomatrixserverlib.RespMakeLeave, err error) + SendLeave(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (err error) + SendInviteV2(ctx context.Context, origin, s gomatrixserverlib.ServerName, request gomatrixserverlib.InviteV2Request) (res gomatrixserverlib.RespInviteV2, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - GetEventAuth(ctx context.Context, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) - GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) - ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) - QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) - Backfill(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) - MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) - MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) + GetEventAuth(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) (res gomatrixserverlib.RespEventAuth, err error) + GetUserDevices(ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string) (gomatrixserverlib.RespUserDevices, error) + ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) + QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) + Backfill(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string) (res gomatrixserverlib.Transaction, err error) + MSC2836EventRelationships(ctx context.Context, origin, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) + MSC2946Spaces(ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool) (res gomatrixserverlib.MSC2946SpacesResponse, err error) - ExchangeThirdPartyInvite(ctx context.Context, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) - LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) - LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) + ExchangeThirdPartyInvite(ctx context.Context, origin, s gomatrixserverlib.ServerName, builder gomatrixserverlib.EventBuilder) (err error) + LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespState, err error) + LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } // FederationClientError is returned from FederationClient methods in the event of a problem. diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 4578e33aa..854251220 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -120,17 +120,7 @@ func NewInternalAPI( js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) - signingInfo := map[gomatrixserverlib.ServerName]*queue.SigningInfo{} - for _, serverName := range append( - []gomatrixserverlib.ServerName{base.Cfg.Global.ServerName}, - base.Cfg.Global.SecondaryServerNames..., - ) { - signingInfo[serverName] = &queue.SigningInfo{ - KeyID: cfg.Matrix.KeyID, - PrivateKey: cfg.Matrix.PrivateKey, - ServerName: serverName, - } - } + signingInfo := base.Cfg.Global.SigningIdentities() queues := queue.NewOutgoingQueues( federationDB, base.ProcessContext, diff --git a/federationapi/federationapi_keys_test.go b/federationapi/federationapi_keys_test.go index 3acaa70dc..cc03cdece 100644 --- a/federationapi/federationapi_keys_test.go +++ b/federationapi/federationapi_keys_test.go @@ -104,7 +104,7 @@ func TestMain(m *testing.M) { // Create the federation client. s.fedclient = gomatrixserverlib.NewFederationClient( - s.config.Matrix.ServerName, serverKeyID, testPriv, + s.config.Matrix.SigningIdentities(), gomatrixserverlib.WithTransport(transport), ) diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index c37bc87c2..68a06a033 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -103,7 +103,7 @@ func (f *fedClient) GetServerKeys(ctx context.Context, matrixServer gomatrixserv return keys, nil } -func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { +func (f *fedClient) MakeJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, userID string, roomVersions []gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMakeJoin, err error) { for _, r := range f.allowJoins { if r.ID == roomID { res.RoomVersion = r.Version @@ -127,7 +127,7 @@ func (f *fedClient) MakeJoin(ctx context.Context, s gomatrixserverlib.ServerName } return } -func (f *fedClient) SendJoin(ctx context.Context, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { +func (f *fedClient) SendJoin(ctx context.Context, origin, s gomatrixserverlib.ServerName, event *gomatrixserverlib.Event) (res gomatrixserverlib.RespSendJoin, err error) { f.fedClientMutex.Lock() defer f.fedClientMutex.Unlock() for _, r := range f.allowJoins { @@ -283,7 +283,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) fedCli := gomatrixserverlib.NewFederationClient( - serverName, cfg.Global.KeyID, cfg.Global.PrivateKey, + cfg.Global.SigningIdentities(), gomatrixserverlib.WithSkipVerify(true), ) @@ -326,7 +326,7 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { t.Errorf("failed to create invite v2 request: %s", err) continue } - _, err = fedCli.SendInviteV2(context.Background(), serverName, invReq) + _, err = fedCli.SendInviteV2(context.Background(), cfg.Global.ServerName, serverName, invReq) if err == nil { t.Errorf("expected an error, got none") continue diff --git a/federationapi/internal/federationclient.go b/federationapi/internal/federationclient.go index 2636b7fa0..db6348ec1 100644 --- a/federationapi/internal/federationclient.go +++ b/federationapi/internal/federationclient.go @@ -11,13 +11,13 @@ import ( // client. func (a *FederationInternalAPI) GetEventAuth( - ctx context.Context, s gomatrixserverlib.ServerName, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (res gomatrixserverlib.RespEventAuth, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetEventAuth(ctx, s, roomVersion, roomID, eventID) + return a.federation.GetEventAuth(ctx, origin, s, roomVersion, roomID, eventID) }) if err != nil { return gomatrixserverlib.RespEventAuth{}, err @@ -26,12 +26,12 @@ func (a *FederationInternalAPI) GetEventAuth( } func (a *FederationInternalAPI) GetUserDevices( - ctx context.Context, s gomatrixserverlib.ServerName, userID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetUserDevices(ctx, s, userID) + return a.federation.GetUserDevices(ctx, origin, s, userID) }) if err != nil { return gomatrixserverlib.RespUserDevices{}, err @@ -40,12 +40,12 @@ func (a *FederationInternalAPI) GetUserDevices( } func (a *FederationInternalAPI) ClaimKeys( - ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.ClaimKeys(ctx, s, oneTimeKeys) + return a.federation.ClaimKeys(ctx, origin, s, oneTimeKeys) }) if err != nil { return gomatrixserverlib.RespClaimKeys{}, err @@ -54,10 +54,10 @@ func (a *FederationInternalAPI) ClaimKeys( } func (a *FederationInternalAPI) QueryKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { ires, err := a.doRequestIfNotBackingOffOrBlacklisted(s, func() (interface{}, error) { - return a.federation.QueryKeys(ctx, s, keys) + return a.federation.QueryKeys(ctx, origin, s, keys) }) if err != nil { return gomatrixserverlib.RespQueryKeys{}, err @@ -66,12 +66,12 @@ func (a *FederationInternalAPI) QueryKeys( } func (a *FederationInternalAPI) Backfill( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.Backfill(ctx, s, roomID, limit, eventIDs) + return a.federation.Backfill(ctx, origin, s, roomID, limit, eventIDs) }) if err != nil { return gomatrixserverlib.Transaction{}, err @@ -80,12 +80,12 @@ func (a *FederationInternalAPI) Backfill( } func (a *FederationInternalAPI) LookupState( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespState, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion) + return a.federation.LookupState(ctx, origin, s, roomID, eventID, roomVersion) }) if err != nil { return gomatrixserverlib.RespState{}, err @@ -94,12 +94,12 @@ func (a *FederationInternalAPI) LookupState( } func (a *FederationInternalAPI) LookupStateIDs( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, ) (res gomatrixserverlib.RespStateIDs, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupStateIDs(ctx, s, roomID, eventID) + return a.federation.LookupStateIDs(ctx, origin, s, roomID, eventID) }) if err != nil { return gomatrixserverlib.RespStateIDs{}, err @@ -108,13 +108,13 @@ func (a *FederationInternalAPI) LookupStateIDs( } func (a *FederationInternalAPI) LookupMissingEvents( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.LookupMissingEvents(ctx, s, roomID, missing, roomVersion) + return a.federation.LookupMissingEvents(ctx, origin, s, roomID, missing, roomVersion) }) if err != nil { return gomatrixserverlib.RespMissingEvents{}, err @@ -123,12 +123,12 @@ func (a *FederationInternalAPI) LookupMissingEvents( } func (a *FederationInternalAPI) GetEvent( - ctx context.Context, s gomatrixserverlib.ServerName, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, ) (res gomatrixserverlib.Transaction, err error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.GetEvent(ctx, s, eventID) + return a.federation.GetEvent(ctx, origin, s, eventID) }) if err != nil { return gomatrixserverlib.Transaction{}, err @@ -151,13 +151,13 @@ func (a *FederationInternalAPI) LookupServerKeys( } func (a *FederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion) + return a.federation.MSC2836EventRelationships(ctx, origin, s, r, roomVersion) }) if err != nil { return res, err @@ -166,12 +166,12 @@ func (a *FederationInternalAPI) MSC2836EventRelationships( } func (a *FederationInternalAPI) MSC2946Spaces( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() ires, err := a.doRequestIfNotBlacklisted(s, func() (interface{}, error) { - return a.federation.MSC2946Spaces(ctx, s, roomID, suggestedOnly) + return a.federation.MSC2946Spaces(ctx, origin, s, roomID, suggestedOnly) }) if err != nil { return res, err diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 1b61ec711..603e2a9c9 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -26,6 +26,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( ) (err error) { dir, err := r.federation.LookupRoomAlias( ctx, + r.cfg.Matrix.ServerName, request.ServerName, request.RoomAlias, ) @@ -143,10 +144,16 @@ func (r *FederationInternalAPI) performJoinUsingServer( supportedVersions []gomatrixserverlib.RoomVersion, unsigned map[string]interface{}, ) error { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) + if err != nil { + return err + } + // Try to perform a make_join using the information supplied in the // request. respMakeJoin, err := r.federation.MakeJoin( ctx, + origin, serverName, roomID, userID, @@ -192,7 +199,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // Build the join event. event, err := respMakeJoin.JoinEvent.Build( time.Now(), - r.cfg.Matrix.ServerName, + origin, r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, respMakeJoin.RoomVersion, @@ -204,6 +211,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( // Try to perform a send_join using the newly built event. respSendJoin, err := r.federation.SendJoin( context.Background(), + origin, serverName, event, ) @@ -246,7 +254,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( respMakeJoin.RoomVersion, r.keyRing, event, - federatedAuthProvider(ctx, r.federation, r.keyRing, serverName), + federatedAuthProvider(ctx, r.federation, r.keyRing, origin, serverName), ) if err != nil { return fmt.Errorf("respSendJoin.Check: %w", err) @@ -281,6 +289,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( if err = roomserverAPI.SendEventWithState( context.Background(), r.rsAPI, + origin, roomserverAPI.KindNew, respState, event.Headered(respMakeJoin.RoomVersion), @@ -427,6 +436,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // request. respPeek, err := r.federation.Peek( ctx, + r.cfg.Matrix.ServerName, serverName, roomID, peekID, @@ -453,7 +463,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // authenticate the state returned (check its auth events etc) // the equivalent of CheckSendJoinResponse() - authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)) + authEvents, _, err := respState.Check(ctx, respPeek.RoomVersion, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, r.cfg.Matrix.ServerName, serverName)) if err != nil { return fmt.Errorf("error checking state returned from peeking: %w", err) } @@ -475,7 +485,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( // logrus.Warnf("got respPeek %#v", respPeek) // Send the newly returned state to the roomserver to update our local view. if err = roomserverAPI.SendEventWithState( - ctx, r.rsAPI, + ctx, r.rsAPI, r.cfg.Matrix.ServerName, roomserverAPI.KindNew, &respState, respPeek.LatestEvent.Headered(respPeek.RoomVersion), @@ -495,6 +505,11 @@ func (r *FederationInternalAPI) PerformLeave( request *api.PerformLeaveRequest, response *api.PerformLeaveResponse, ) (err error) { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.UserID) + if err != nil { + return err + } + // Deduplicate the server names we were provided. util.SortAndUnique(request.ServerNames) @@ -505,6 +520,7 @@ func (r *FederationInternalAPI) PerformLeave( // request. respMakeLeave, err := r.federation.MakeLeave( ctx, + origin, serverName, request.RoomID, request.UserID, @@ -546,7 +562,7 @@ func (r *FederationInternalAPI) PerformLeave( // Build the leave event. event, err := respMakeLeave.LeaveEvent.Build( time.Now(), - r.cfg.Matrix.ServerName, + origin, r.cfg.Matrix.KeyID, r.cfg.Matrix.PrivateKey, respMakeLeave.RoomVersion, @@ -559,6 +575,7 @@ func (r *FederationInternalAPI) PerformLeave( // Try to perform a send_leave using the newly built event. err = r.federation.SendLeave( ctx, + origin, serverName, event, ) @@ -585,6 +602,11 @@ func (r *FederationInternalAPI) PerformInvite( request *api.PerformInviteRequest, response *api.PerformInviteResponse, ) (err error) { + _, origin, err := r.cfg.Matrix.SplitLocalID('@', request.Event.Sender()) + if err != nil { + return err + } + if request.Event.StateKey() == nil { return errors.New("invite must be a state event") } @@ -607,7 +629,7 @@ func (r *FederationInternalAPI) PerformInvite( return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err) } - inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq) + inviteRes, err := r.federation.SendInviteV2(ctx, origin, destination, inviteReq) if err != nil { return fmt.Errorf("r.federation.SendInviteV2: failed to send invite: %w", err) } @@ -708,7 +730,7 @@ func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder // FederatedAuthProvider is an auth chain provider which fetches events from the server provided func federatedAuthProvider( ctx context.Context, federation api.FederationClient, - keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName, + keyRing gomatrixserverlib.JSONVerifier, origin, server gomatrixserverlib.ServerName, ) gomatrixserverlib.AuthChainProvider { // A list of events that we have retried, if they were not included in // the auth events supplied in the send_join. @@ -738,7 +760,7 @@ func federatedAuthProvider( // Try to retrieve the event from the server that sent us the send // join response. - tx, txerr := federation.GetEvent(ctx, server, eventID) + tx, txerr := federation.GetEvent(ctx, origin, server, eventID) if txerr != nil { return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr) } diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 812d3c6da..6c37a1f5a 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -152,16 +152,18 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU( type getUserDevices struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName UserID string } func (h *httpFederationInternalAPI) GetUserDevices( - ctx context.Context, s gomatrixserverlib.ServerName, userID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, userID string, ) (gomatrixserverlib.RespUserDevices, error) { return httputil.CallInternalProxyAPI[getUserDevices, gomatrixserverlib.RespUserDevices, *api.FederationClientError]( "GetUserDevices", h.federationAPIURL+FederationAPIGetUserDevicesPath, h.httpClient, ctx, &getUserDevices{ S: s, + Origin: origin, UserID: userID, }, ) @@ -169,52 +171,58 @@ func (h *httpFederationInternalAPI) GetUserDevices( type claimKeys struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName OneTimeKeys map[string]map[string]string } func (h *httpFederationInternalAPI) ClaimKeys( - ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string, ) (gomatrixserverlib.RespClaimKeys, error) { return httputil.CallInternalProxyAPI[claimKeys, gomatrixserverlib.RespClaimKeys, *api.FederationClientError]( "ClaimKeys", h.federationAPIURL+FederationAPIClaimKeysPath, h.httpClient, ctx, &claimKeys{ S: s, + Origin: origin, OneTimeKeys: oneTimeKeys, }, ) } type queryKeys struct { - S gomatrixserverlib.ServerName - Keys map[string][]string + S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName + Keys map[string][]string } func (h *httpFederationInternalAPI) QueryKeys( - ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string, ) (gomatrixserverlib.RespQueryKeys, error) { return httputil.CallInternalProxyAPI[queryKeys, gomatrixserverlib.RespQueryKeys, *api.FederationClientError]( "QueryKeys", h.federationAPIURL+FederationAPIQueryKeysPath, h.httpClient, ctx, &queryKeys{ - S: s, - Keys: keys, + S: s, + Origin: origin, + Keys: keys, }, ) } type backfill struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string Limit int EventIDs []string } func (h *httpFederationInternalAPI) Backfill( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string, ) (gomatrixserverlib.Transaction, error) { return httputil.CallInternalProxyAPI[backfill, gomatrixserverlib.Transaction, *api.FederationClientError]( "Backfill", h.federationAPIURL+FederationAPIBackfillPath, h.httpClient, ctx, &backfill{ S: s, + Origin: origin, RoomID: roomID, Limit: limit, EventIDs: eventIDs, @@ -224,18 +232,20 @@ func (h *httpFederationInternalAPI) Backfill( type lookupState struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string EventID string RoomVersion gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) LookupState( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion, ) (gomatrixserverlib.RespState, error) { return httputil.CallInternalProxyAPI[lookupState, gomatrixserverlib.RespState, *api.FederationClientError]( "LookupState", h.federationAPIURL+FederationAPILookupStatePath, h.httpClient, ctx, &lookupState{ S: s, + Origin: origin, RoomID: roomID, EventID: eventID, RoomVersion: roomVersion, @@ -245,17 +255,19 @@ func (h *httpFederationInternalAPI) LookupState( type lookupStateIDs struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string EventID string } func (h *httpFederationInternalAPI) LookupStateIDs( - ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID, eventID string, ) (gomatrixserverlib.RespStateIDs, error) { return httputil.CallInternalProxyAPI[lookupStateIDs, gomatrixserverlib.RespStateIDs, *api.FederationClientError]( "LookupStateIDs", h.federationAPIURL+FederationAPILookupStateIDsPath, h.httpClient, ctx, &lookupStateIDs{ S: s, + Origin: origin, RoomID: roomID, EventID: eventID, }, @@ -264,19 +276,21 @@ func (h *httpFederationInternalAPI) LookupStateIDs( type lookupMissingEvents struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomID string Missing gomatrixserverlib.MissingEvents RoomVersion gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) LookupMissingEvents( - ctx context.Context, s gomatrixserverlib.ServerName, roomID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.RespMissingEvents, err error) { return httputil.CallInternalProxyAPI[lookupMissingEvents, gomatrixserverlib.RespMissingEvents, *api.FederationClientError]( "LookupMissingEvents", h.federationAPIURL+FederationAPILookupMissingEventsPath, h.httpClient, ctx, &lookupMissingEvents{ S: s, + Origin: origin, RoomID: roomID, Missing: missing, RoomVersion: roomVersion, @@ -286,16 +300,18 @@ func (h *httpFederationInternalAPI) LookupMissingEvents( type getEvent struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName EventID string } func (h *httpFederationInternalAPI) GetEvent( - ctx context.Context, s gomatrixserverlib.ServerName, eventID string, + ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string, ) (gomatrixserverlib.Transaction, error) { return httputil.CallInternalProxyAPI[getEvent, gomatrixserverlib.Transaction, *api.FederationClientError]( "GetEvent", h.federationAPIURL+FederationAPIGetEventPath, h.httpClient, ctx, &getEvent{ S: s, + Origin: origin, EventID: eventID, }, ) @@ -303,19 +319,21 @@ func (h *httpFederationInternalAPI) GetEvent( type getEventAuth struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName RoomVersion gomatrixserverlib.RoomVersion RoomID string EventID string } func (h *httpFederationInternalAPI) GetEventAuth( - ctx context.Context, s gomatrixserverlib.ServerName, + ctx context.Context, origin, s gomatrixserverlib.ServerName, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string, ) (gomatrixserverlib.RespEventAuth, error) { return httputil.CallInternalProxyAPI[getEventAuth, gomatrixserverlib.RespEventAuth, *api.FederationClientError]( "GetEventAuth", h.federationAPIURL+FederationAPIGetEventAuthPath, h.httpClient, ctx, &getEventAuth{ S: s, + Origin: origin, RoomVersion: roomVersion, RoomID: roomID, EventID: eventID, @@ -351,18 +369,20 @@ func (h *httpFederationInternalAPI) LookupServerKeys( type eventRelationships struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName Req gomatrixserverlib.MSC2836EventRelationshipsRequest RoomVer gomatrixserverlib.RoomVersion } func (h *httpFederationInternalAPI) MSC2836EventRelationships( - ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, + ctx context.Context, origin, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion, ) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) { return httputil.CallInternalProxyAPI[eventRelationships, gomatrixserverlib.MSC2836EventRelationshipsResponse, *api.FederationClientError]( "MSC2836EventRelationships", h.federationAPIURL+FederationAPIEventRelationshipsPath, h.httpClient, ctx, &eventRelationships{ S: s, + Origin: origin, Req: r, RoomVer: roomVersion, }, @@ -371,17 +391,19 @@ func (h *httpFederationInternalAPI) MSC2836EventRelationships( type spacesReq struct { S gomatrixserverlib.ServerName + Origin gomatrixserverlib.ServerName SuggestedOnly bool RoomID string } func (h *httpFederationInternalAPI) MSC2946Spaces( - ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, + ctx context.Context, origin, dst gomatrixserverlib.ServerName, roomID string, suggestedOnly bool, ) (res gomatrixserverlib.MSC2946SpacesResponse, err error) { return httputil.CallInternalProxyAPI[spacesReq, gomatrixserverlib.MSC2946SpacesResponse, *api.FederationClientError]( "MSC2836EventRelationships", h.federationAPIURL+FederationAPISpacesSummaryPath, h.httpClient, ctx, &spacesReq{ S: dst, + Origin: origin, SuggestedOnly: suggestedOnly, RoomID: roomID, }, diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 7aa0e4801..7b3edb2a7 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -59,7 +59,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetUserDevices", func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { - res, err := intAPI.GetUserDevices(ctx, req.S, req.UserID) + res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID) return &res, federationClientError(err) }, ), @@ -70,7 +70,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIClaimKeys", func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { - res, err := intAPI.ClaimKeys(ctx, req.S, req.OneTimeKeys) + res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys) return &res, federationClientError(err) }, ), @@ -81,7 +81,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIQueryKeys", func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { - res, err := intAPI.QueryKeys(ctx, req.S, req.Keys) + res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys) return &res, federationClientError(err) }, ), @@ -92,7 +92,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIBackfill", func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.Backfill(ctx, req.S, req.RoomID, req.Limit, req.EventIDs) + res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs) return &res, federationClientError(err) }, ), @@ -103,7 +103,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupState", func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { - res, err := intAPI.LookupState(ctx, req.S, req.RoomID, req.EventID, req.RoomVersion) + res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion) return &res, federationClientError(err) }, ), @@ -114,7 +114,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupStateIDs", func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { - res, err := intAPI.LookupStateIDs(ctx, req.S, req.RoomID, req.EventID) + res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID) return &res, federationClientError(err) }, ), @@ -125,7 +125,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPILookupMissingEvents", func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { - res, err := intAPI.LookupMissingEvents(ctx, req.S, req.RoomID, req.Missing, req.RoomVersion) + res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion) return &res, federationClientError(err) }, ), @@ -136,7 +136,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetEvent", func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { - res, err := intAPI.GetEvent(ctx, req.S, req.EventID) + res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID) return &res, federationClientError(err) }, ), @@ -147,7 +147,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIGetEventAuth", func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { - res, err := intAPI.GetEventAuth(ctx, req.S, req.RoomVersion, req.RoomID, req.EventID) + res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID) return &res, federationClientError(err) }, ), @@ -174,7 +174,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIMSC2836EventRelationships", func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { - res, err := intAPI.MSC2836EventRelationships(ctx, req.S, req.Req, req.RoomVer) + res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer) return &res, federationClientError(err) }, ), @@ -185,7 +185,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalProxyAPI( "FederationAPIMSC2946SpacesSummary", func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { - res, err := intAPI.MSC2946Spaces(ctx, req.S, req.RoomID, req.SuggestedOnly) + res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly) return &res, federationClientError(err) }, ), diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index bf04ee99a..63fc59c89 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -50,7 +50,7 @@ type destinationQueue struct { queues *OutgoingQueues db storage.Database process *process.ProcessContext - signing map[gomatrixserverlib.ServerName]*SigningInfo + signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity rsAPI api.FederationRoomserverAPI client fedapi.FederationClient // federation client origin gomatrixserverlib.ServerName // origin of requests diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 68f354993..31124e0c1 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -15,7 +15,6 @@ package queue import ( - "crypto/ed25519" "encoding/json" "fmt" "sync" @@ -46,7 +45,7 @@ type OutgoingQueues struct { origin gomatrixserverlib.ServerName client fedapi.FederationClient statistics *statistics.Statistics - signing map[gomatrixserverlib.ServerName]*SigningInfo + signing map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity queuesMutex sync.Mutex // protects the below queues map[gomatrixserverlib.ServerName]*destinationQueue } @@ -91,7 +90,7 @@ func NewOutgoingQueues( client fedapi.FederationClient, rsAPI api.FederationRoomserverAPI, statistics *statistics.Statistics, - signing map[gomatrixserverlib.ServerName]*SigningInfo, + signing []*gomatrixserverlib.SigningIdentity, ) *OutgoingQueues { queues := &OutgoingQueues{ disabled: disabled, @@ -101,9 +100,12 @@ func NewOutgoingQueues( origin: origin, client: client, statistics: statistics, - signing: signing, + signing: map[gomatrixserverlib.ServerName]*gomatrixserverlib.SigningIdentity{}, queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, } + for _, identity := range signing { + queues.signing[identity.ServerName] = identity + } // Look up which servers we have pending items for and then rehydrate those queues. if !disabled { serverNames := map[gomatrixserverlib.ServerName]struct{}{} @@ -135,14 +137,6 @@ func NewOutgoingQueues( return queues } -// TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables -// around together -type SigningInfo struct { - ServerName gomatrixserverlib.ServerName - KeyID gomatrixserverlib.KeyID - PrivateKey ed25519.PrivateKey -} - type queuedPDU struct { receipt *shared.Receipt pdu *gomatrixserverlib.HeaderedEvent diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index 58745c607..b2ec4b836 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -350,8 +350,8 @@ func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T } rs := &stubFederationRoomServerAPI{} stats := statistics.NewStatistics(db, failuresUntilBlacklist) - signingInfo := map[gomatrixserverlib.ServerName]*SigningInfo{ - "localhost": { + signingInfo := []*gomatrixserverlib.SigningIdentity{ + { KeyID: "ed21019:auto", PrivateKey: test.PrivateKeyA, ServerName: "localhost", diff --git a/federationapi/routing/backfill.go b/federationapi/routing/backfill.go index 7b9ca66f6..272f5e9d8 100644 --- a/federationapi/routing/backfill.go +++ b/federationapi/routing/backfill.go @@ -82,7 +82,8 @@ func Backfill( BackwardsExtremities: map[string][]string{ "": eIDs, }, - ServerName: request.Origin(), + ServerName: request.Origin(), + VirtualHost: request.Destination(), } if req.Limit, err = strconv.Atoi(limit); err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("strconv.Atoi failed") @@ -123,7 +124,7 @@ func Backfill( } txn := gomatrixserverlib.Transaction{ - Origin: cfg.Matrix.ServerName, + Origin: request.Destination(), PDUs: eventJSONs, OriginServerTS: gomatrixserverlib.AsTimestamp(time.Now()), } diff --git a/federationapi/routing/invite.go b/federationapi/routing/invite.go index 504204504..f424fcacd 100644 --- a/federationapi/routing/invite.go +++ b/federationapi/routing/invite.go @@ -140,6 +140,21 @@ func processInvite( } } + if event.StateKey() == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("The invite event has no state key"), + } + } + + _, domain, err := cfg.Matrix.SplitLocalID('@', *event.StateKey()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("The user ID is invalid or domain %q does not belong to this server", domain)), + } + } + // Check that the event is signed by the server sending the request. redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) if err != nil { @@ -175,7 +190,7 @@ func processInvite( // Sign the event so that other servers will know that we have received the invite. signedEvent := event.Sign( - string(cfg.Matrix.ServerName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, + string(domain), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, ) // Add the invite event to the roomserver. diff --git a/federationapi/routing/join.go b/federationapi/routing/join.go index 74d065e59..03809df75 100644 --- a/federationapi/routing/join.go +++ b/federationapi/routing/join.go @@ -131,10 +131,20 @@ func MakeJoin( return jsonerror.InternalServerError() } + identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound( + fmt.Sprintf("Server name %q does not exist", request.Destination()), + ), + } + } + queryRes := api.QueryLatestEventsAndStateResponse{ RoomVersion: verRes.RoomVersion, } - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 5650e3d53..4fd3720f1 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -16,7 +16,6 @@ package routing import ( "encoding/json" - "fmt" "net/http" "time" @@ -136,38 +135,52 @@ func ClaimOneTimeKeys( // LocalKeys returns the local keys for the server. // See https://matrix.org/docs/spec/server_server/unstable.html#publishing-keys func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) util.JSONResponse { - keys, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)) + keys, err := localKeys(cfg, serverName) if err != nil { - return util.ErrorResponse(err) + return util.MessageResponse(http.StatusNotFound, err.Error()) } return util.JSONResponse{Code: http.StatusOK, JSON: keys} } -func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName, validUntil time.Time) (*gomatrixserverlib.ServerKeys, error) { +func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys - if !cfg.Matrix.IsLocalServerName(serverName) { - return nil, fmt.Errorf("server name not known") - } - - keys.ServerName = serverName - keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(validUntil) - - publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) - - keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ - cfg.Matrix.KeyID: { - Key: gomatrixserverlib.Base64Bytes(publicKey), - }, - } - - keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} - for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys { - keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{ - VerifyKey: gomatrixserverlib.VerifyKey{ - Key: oldVerifyKey.PublicKey, - }, - ExpiredTS: oldVerifyKey.ExpiredAt, + var virtualHost *config.VirtualHost + for _, v := range cfg.Matrix.VirtualHosts { + if v.ServerName != serverName { + continue } + virtualHost = v + } + + if virtualHost == nil { + publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) + keys.ServerName = cfg.Matrix.ServerName + keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) + keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ + cfg.Matrix.KeyID: { + Key: gomatrixserverlib.Base64Bytes(publicKey), + }, + } + keys.OldVerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.OldVerifyKey{} + for _, oldVerifyKey := range cfg.Matrix.OldVerifyKeys { + keys.OldVerifyKeys[oldVerifyKey.KeyID] = gomatrixserverlib.OldVerifyKey{ + VerifyKey: gomatrixserverlib.VerifyKey{ + Key: oldVerifyKey.PublicKey, + }, + ExpiredTS: oldVerifyKey.ExpiredAt, + } + } + } else { + publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey) + keys.ServerName = virtualHost.ServerName + keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod)) + keys.VerifyKeys = map[gomatrixserverlib.KeyID]gomatrixserverlib.VerifyKey{ + virtualHost.KeyID: { + Key: gomatrixserverlib.Base64Bytes(publicKey), + }, + } + // TODO: Virtual hosts probably want to be able to specify old signing + // keys too, just in case } toSign, err := json.Marshal(keys.ServerKeyFields) @@ -213,7 +226,7 @@ func NotaryKeys( for serverName, kidToCriteria := range req.ServerKeys { var keyList []gomatrixserverlib.ServerKeys if serverName == cfg.Matrix.ServerName { - if k, err := localKeys(cfg, serverName, time.Now().Add(cfg.Matrix.KeyValidityPeriod)); err == nil { + if k, err := localKeys(cfg, serverName); err == nil { keyList = append(keyList, *k) } else { return util.ErrorResponse(err) diff --git a/federationapi/routing/leave.go b/federationapi/routing/leave.go index a67e4e28b..f1e9f49ba 100644 --- a/federationapi/routing/leave.go +++ b/federationapi/routing/leave.go @@ -13,6 +13,7 @@ package routing import ( + "fmt" "net/http" "time" @@ -60,8 +61,18 @@ func MakeLeave( return jsonerror.InternalServerError() } + identity, err := cfg.Matrix.SigningIdentityFor(request.Destination()) + if err != nil { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound( + fmt.Sprintf("Server name %q does not exist", request.Destination()), + ), + } + } + var queryRes api.QueryLatestEventsAndStateResponse - event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, time.Now(), rsAPI, &queryRes) + event, err := eventutil.QueryAndBuildEvent(httpReq.Context(), &builder, cfg.Matrix, identity, time.Now(), rsAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return util.JSONResponse{ Code: http.StatusNotFound, diff --git a/federationapi/routing/profile.go b/federationapi/routing/profile.go index f672811af..e4d2230ad 100644 --- a/federationapi/routing/profile.go +++ b/federationapi/routing/profile.go @@ -22,7 +22,6 @@ import ( "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -42,16 +41,9 @@ func GetProfile( } } - _, domain, err := gomatrixserverlib.SplitID('@', userID) + _, domain, err := cfg.Matrix.SplitLocalID('@', userID) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed") - return util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument(fmt.Sprintf("Format of user ID %q is invalid", userID)), - } - } - - if domain != cfg.Matrix.ServerName { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidArgumentValue(fmt.Sprintf("Domain %q does not match this server", domain)), diff --git a/federationapi/routing/query.go b/federationapi/routing/query.go index 316c61a14..e6dc52601 100644 --- a/federationapi/routing/query.go +++ b/federationapi/routing/query.go @@ -83,7 +83,7 @@ func RoomAliasToID( } } } else { - resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, roomAlias) + resp, err = federation.LookupRoomAlias(httpReq.Context(), domain, cfg.Matrix.ServerName, roomAlias) if err != nil { switch x := err.(type) { case gomatrix.HTTPError: diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index b3bbaa394..a146d85bd 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -197,12 +197,12 @@ type txnReq struct { // A subset of FederationClient functionality that txn requires. Useful for testing. type txnFederationClient interface { - LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( + LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) - LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, + LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) + GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) + LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } @@ -287,6 +287,7 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res []*gomatrixserverlib.HeaderedEvent{ event.Headered(roomVersion), }, + t.Destination, t.Origin, api.DoNotSendToOtherServers, nil, diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index 1c796f542..b8bfe0221 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -147,7 +147,7 @@ type txnFedClient struct { getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) } -func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( +func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( res gomatrixserverlib.RespState, err error, ) { fmt.Println("testFederationClient.LookupState", eventID) @@ -159,7 +159,7 @@ func (c *txnFedClient) LookupState(ctx context.Context, s gomatrixserverlib.Serv res = r return } -func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { +func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { fmt.Println("testFederationClient.LookupStateIDs", eventID) r, ok := c.stateIDs[eventID] if !ok { @@ -169,7 +169,7 @@ func (c *txnFedClient) LookupStateIDs(ctx context.Context, s gomatrixserverlib.S res = r return } -func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { +func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { fmt.Println("testFederationClient.GetEvent", eventID) r, ok := c.getEvent[eventID] if !ok { @@ -179,7 +179,7 @@ func (c *txnFedClient) GetEvent(ctx context.Context, s gomatrixserverlib.ServerN res = r return } -func (c *txnFedClient) LookupMissingEvents(ctx context.Context, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, +func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { return c.getMissingEvents(missing) } diff --git a/federationapi/routing/threepid.go b/federationapi/routing/threepid.go index ccde9168e..d07faef39 100644 --- a/federationapi/routing/threepid.go +++ b/federationapi/routing/threepid.go @@ -90,7 +90,17 @@ func CreateInvitesFrom3PIDInvites( } // Send all the events - if err := api.SendEvents(req.Context(), rsAPI, api.KindNew, evs, "TODO", cfg.Matrix.ServerName, nil, false); err != nil { + if err := api.SendEvents( + req.Context(), + rsAPI, + api.KindNew, + evs, + cfg.Matrix.ServerName, // TODO: which virtual host? + "TODO", + cfg.Matrix.ServerName, + nil, + false, + ); err != nil { util.GetLogger(req.Context()).WithError(err).Error("SendEvents failed") return jsonerror.InternalServerError() } @@ -126,6 +136,14 @@ func ExchangeThirdPartyInvite( } } + _, senderDomain, err := cfg.Matrix.SplitLocalID('@', builder.Sender) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Invalid sender ID: " + err.Error()), + } + } + // Check that the state key is correct. _, targetDomain, err := gomatrixserverlib.SplitID('@', *builder.StateKey) if err != nil { @@ -171,7 +189,7 @@ func ExchangeThirdPartyInvite( util.GetLogger(httpReq.Context()).WithError(err).Error("failed to make invite v2 request") return jsonerror.InternalServerError() } - signedEvent, err := federation.SendInviteV2(httpReq.Context(), request.Origin(), inviteReq) + signedEvent, err := federation.SendInviteV2(httpReq.Context(), senderDomain, request.Origin(), inviteReq) if err != nil { util.GetLogger(httpReq.Context()).WithError(err).Error("federation.SendInvite failed") return jsonerror.InternalServerError() @@ -189,6 +207,7 @@ func ExchangeThirdPartyInvite( []*gomatrixserverlib.HeaderedEvent{ inviteEvent.Headered(verRes.RoomVersion), }, + request.Destination(), request.Origin(), cfg.Matrix.ServerName, nil, @@ -341,7 +360,7 @@ func buildMembershipEvent( // them responded with an error. func sendToRemoteServer( ctx context.Context, inv invite, - federation federationAPI.FederationClient, _ *config.FederationAPI, + federation federationAPI.FederationClient, cfg *config.FederationAPI, builder gomatrixserverlib.EventBuilder, ) (err error) { remoteServers := make([]gomatrixserverlib.ServerName, 2) @@ -357,7 +376,7 @@ func sendToRemoteServer( } for _, server := range remoteServers { - err = federation.ExchangeThirdPartyInvite(ctx, server, builder) + err = federation.ExchangeThirdPartyInvite(ctx, cfg.Matrix.ServerName, server, builder) if err == nil { return } diff --git a/go.mod b/go.mod index 47c43fbc7..fce538b4c 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index a805a5bf9..224d5aa86 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2 h1:Bet5n+//Yh+A2SuPHD67N8jrOhC/EIKvEgisfsKhTss= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221109092408-715dc88e62e2/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf h1:OPM8urHlGC8m1MfusZg6Oi9zxJiDNuSVh3GQuvZeOHw= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/internal/eventutil/events.go b/internal/eventutil/events.go index d96231963..c572d8830 100644 --- a/internal/eventutil/events.go +++ b/internal/eventutil/events.go @@ -38,7 +38,8 @@ var ErrRoomNoExists = errors.New("room does not exist") // Returns an error if something else went wrong func QueryAndBuildEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, + builder *gomatrixserverlib.EventBuilder, cfg *config.Global, + identity *gomatrixserverlib.SigningIdentity, evTime time.Time, rsAPI api.QueryLatestEventsAndStateAPI, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.HeaderedEvent, error) { if queryRes == nil { @@ -50,24 +51,24 @@ func QueryAndBuildEvent( // This can pass through a ErrRoomNoExists to the caller return nil, err } - return BuildEvent(ctx, builder, cfg, evTime, eventsNeeded, queryRes) + return BuildEvent(ctx, builder, cfg, identity, evTime, eventsNeeded, queryRes) } // BuildEvent builds a Matrix event from the builder and QueryLatestEventsAndStateResponse // provided. func BuildEvent( ctx context.Context, - builder *gomatrixserverlib.EventBuilder, cfg *config.Global, evTime time.Time, + builder *gomatrixserverlib.EventBuilder, cfg *config.Global, + identity *gomatrixserverlib.SigningIdentity, evTime time.Time, eventsNeeded *gomatrixserverlib.StateNeeded, queryRes *api.QueryLatestEventsAndStateResponse, ) (*gomatrixserverlib.HeaderedEvent, error) { - err := addPrevEventsToEvent(builder, eventsNeeded, queryRes) - if err != nil { + if err := addPrevEventsToEvent(builder, eventsNeeded, queryRes); err != nil { return nil, err } event, err := builder.Build( - evTime, cfg.ServerName, cfg.KeyID, - cfg.PrivateKey, queryRes.RoomVersion, + evTime, identity.ServerName, identity.KeyID, + identity.PrivateKey, queryRes.RoomVersion, ) if err != nil { return nil, err diff --git a/keyserver/consumers/devicelistupdate.go b/keyserver/consumers/devicelistupdate.go index 575e41281..cd911f8c6 100644 --- a/keyserver/consumers/devicelistupdate.go +++ b/keyserver/consumers/devicelistupdate.go @@ -30,12 +30,12 @@ import ( // DeviceListUpdateConsumer consumes device list updates that came in over federation. type DeviceListUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - updater *internal.DeviceListUpdater - serverName gomatrixserverlib.ServerName + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + updater *internal.DeviceListUpdater + isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers. @@ -46,12 +46,12 @@ func NewDeviceListUpdateConsumer( updater *internal.DeviceListUpdater, ) *DeviceListUpdateConsumer { return &DeviceListUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), - updater: updater, - serverName: cfg.Matrix.ServerName, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + updater: updater, + isLocalServerName: cfg.Matrix.IsLocalServerName, } } @@ -75,7 +75,7 @@ func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M origin := gomatrixserverlib.ServerName(msg.Header.Get("origin")) if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil { return true - } else if serverName == t.serverName { + } else if t.isLocalServerName(serverName) { return true } else if serverName != origin { return true diff --git a/keyserver/consumers/signingkeyupdate.go b/keyserver/consumers/signingkeyupdate.go index 366e259b4..bcceaad15 100644 --- a/keyserver/consumers/signingkeyupdate.go +++ b/keyserver/consumers/signingkeyupdate.go @@ -31,12 +31,13 @@ import ( // SigningKeyUpdateConsumer consumes signing key updates that came in over federation. type SigningKeyUpdateConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - keyAPI *internal.KeyInternalAPI - cfg *config.KeyServer + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + keyAPI *internal.KeyInternalAPI + cfg *config.KeyServer + isLocalServerName func(gomatrixserverlib.ServerName) bool } // NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers. @@ -47,12 +48,13 @@ func NewSigningKeyUpdateConsumer( keyAPI *internal.KeyInternalAPI, ) *SigningKeyUpdateConsumer { return &SigningKeyUpdateConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), - topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), - keyAPI: keyAPI, - cfg: cfg, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"), + topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + keyAPI: keyAPI, + cfg: cfg, + isLocalServerName: cfg.Matrix.IsLocalServerName, } } @@ -77,7 +79,7 @@ func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.M if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil { logrus.WithError(err).Error("failed to split user id") return true - } else if serverName == t.cfg.Matrix.ServerName { + } else if t.isLocalServerName(serverName) { logrus.Warn("dropping device key update from ourself") return true } else if serverName != origin { diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 3f7c0d8b4..8ff9dfc31 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -96,6 +96,7 @@ type DeviceListUpdater struct { producer KeyChangeProducer fedClient fedsenderapi.KeyserverFederationAPI workerChans []chan gomatrixserverlib.ServerName + thisServer gomatrixserverlib.ServerName // When device lists are stale for a user, they get inserted into this map with a channel which `Update` will // block on or timeout via a select. @@ -139,6 +140,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, + thisServer gomatrixserverlib.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -148,6 +150,7 @@ func NewDeviceListUpdater( api: api, producer: producer, fedClient: fedClient, + thisServer: thisServer, workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, @@ -435,8 +438,7 @@ func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName go "server_name": serverName, "user_id": userID, }) - - res, err := u.fedClient.GetUserDevices(ctx, serverName, userID) + res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID) if err != nil { if errors.Is(err, context.DeadlineExceeded) { return time.Minute * 10, err diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 28a13a0a0..a374c9516 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -129,7 +129,13 @@ func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) { func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient { _, pkey, _ := ed25519.GenerateKey(nil) fedClient := gomatrixserverlib.NewFederationClient( - gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey, + []*gomatrixserverlib.SigningIdentity{ + { + ServerName: gomatrixserverlib.ServerName("example.test"), + KeyID: gomatrixserverlib.KeyID("ed25519:test"), + PrivateKey: pkey, + }, + }, ) fedClient.Client = *gomatrixserverlib.NewClient( gomatrixserverlib.WithTransport(&roundTripper{tripper}), @@ -147,7 +153,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost") event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -219,7 +225,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -288,7 +294,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1) + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 37c55c8f8..9a08a0bb7 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -146,7 +146,7 @@ func (a *KeyInternalAPI) claimRemoteKeys( defer cancel() defer wg.Done() - claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, gomatrixserverlib.ServerName(domain), keysToClaim) + claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim) mu.Lock() defer mu.Unlock() @@ -559,7 +559,7 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer( if len(devKeys) == 0 { return } - queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, gomatrixserverlib.ServerName(serverName), devKeys) + queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Cfg.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys) if err == nil { resultCh <- &queryKeysResp return diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 0a4b8fde6..a86c2da4e 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -58,7 +58,7 @@ func NewInternalAPI( FedClient: fedClient, Producer: keyChangeProducer, } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable ap.Updater = updater go func() { if err := updater.Start(); err != nil { diff --git a/roomserver/api/input.go b/roomserver/api/input.go index 45a9ef497..88d523270 100644 --- a/roomserver/api/input.go +++ b/roomserver/api/input.go @@ -94,8 +94,9 @@ type TransactionID struct { // InputRoomEventsRequest is a request to InputRoomEvents type InputRoomEventsRequest struct { - InputRoomEvents []InputRoomEvent `json:"input_room_events"` - Asynchronous bool `json:"async"` + InputRoomEvents []InputRoomEvent `json:"input_room_events"` + Asynchronous bool `json:"async"` + VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` } // InputRoomEventsResponse is a response to InputRoomEvents diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index 2536a4bb5..e70e5ea9c 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -148,6 +148,8 @@ type PerformBackfillRequest struct { Limit int `json:"limit"` // The server interested in the events. ServerName gomatrixserverlib.ServerName `json:"server_name"` + // Which virtual host are we doing this for? + VirtualHost gomatrixserverlib.ServerName `json:"virtual_host"` } // PrevEventIDs returns the prev_event IDs of all backwards extremities, de-duplicated in a lexicographically sorted order. diff --git a/roomserver/api/wrapper.go b/roomserver/api/wrapper.go index 8b031982c..252be557f 100644 --- a/roomserver/api/wrapper.go +++ b/roomserver/api/wrapper.go @@ -26,7 +26,7 @@ import ( func SendEvents( ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, events []*gomatrixserverlib.HeaderedEvent, - origin gomatrixserverlib.ServerName, + virtualHost, origin gomatrixserverlib.ServerName, sendAsServer gomatrixserverlib.ServerName, txnID *TransactionID, async bool, ) error { @@ -40,14 +40,15 @@ func SendEvents( TransactionID: txnID, } } - return SendInputRoomEvents(ctx, rsAPI, ires, async) + return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async) } // SendEventWithState writes an event with the specified kind to the roomserver // with the state at the event as KindOutlier before it. Will not send any event that is // marked as `true` in haveEventIDs. func SendEventWithState( - ctx context.Context, rsAPI InputRoomEventsAPI, kind Kind, + ctx context.Context, rsAPI InputRoomEventsAPI, + virtualHost gomatrixserverlib.ServerName, kind Kind, state *gomatrixserverlib.RespState, event *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName, haveEventIDs map[string]bool, async bool, ) error { @@ -85,17 +86,19 @@ func SendEventWithState( StateEventIDs: stateEventIDs, }) - return SendInputRoomEvents(ctx, rsAPI, ires, async) + return SendInputRoomEvents(ctx, rsAPI, virtualHost, ires, async) } // SendInputRoomEvents to the roomserver. func SendInputRoomEvents( ctx context.Context, rsAPI InputRoomEventsAPI, + virtualHost gomatrixserverlib.ServerName, ires []InputRoomEvent, async bool, ) error { request := InputRoomEventsRequest{ InputRoomEvents: ires, Asynchronous: async, + VirtualHost: virtualHost, } var response InputRoomEventsResponse if err := rsAPI.InputRoomEvents(ctx, &request, &response); err != nil { diff --git a/roomserver/internal/alias.go b/roomserver/internal/alias.go index 175bb9310..329e6af7f 100644 --- a/roomserver/internal/alias.go +++ b/roomserver/internal/alias.go @@ -137,6 +137,11 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( request *api.RemoveRoomAliasRequest, response *api.RemoveRoomAliasResponse, ) error { + _, virtualHost, err := r.Cfg.Matrix.SplitLocalID('@', request.UserID) + if err != nil { + return err + } + roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias) if err != nil { return fmt.Errorf("r.DB.GetRoomIDForAlias: %w", err) @@ -190,6 +195,16 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( sender = ev.Sender() } + _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', sender) + if err != nil { + return err + } + + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + return err + } + builder := &gomatrixserverlib.EventBuilder{ Sender: sender, RoomID: ev.RoomID(), @@ -211,12 +226,12 @@ func (r *RoomserverInternalAPI) RemoveRoomAlias( return err } - newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, stateRes) + newEvent, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, stateRes) if err != nil { return err } - err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, r.ServerName, r.ServerName, nil, false) + err = api.SendEvents(ctx, r, api.KindNew, []*gomatrixserverlib.HeaderedEvent{newEvent}, virtualHost, r.ServerName, r.ServerName, nil, false) if err != nil { return err } diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1a11586a5..1a3626609 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -87,10 +87,10 @@ func NewRoomserverAPI( Durable: base.Cfg.Global.JetStream.Durable("RoomserverInputConsumer"), ServerACLs: serverACLs, Queryer: &query.Queryer{ - DB: roomserverDB, - Cache: base.Caches, - ServerName: base.Cfg.Global.ServerName, - ServerACLs: serverACLs, + DB: roomserverDB, + Cache: base.Caches, + IsLocalServerName: base.Cfg.Global.IsLocalServerName, + ServerACLs: serverACLs, }, // perform-er structs get initialised when we have a federation sender to use } @@ -127,13 +127,12 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Inputer: r.Inputer, } r.Joiner = &perform.Joiner{ - ServerName: r.Cfg.Matrix.ServerName, - Cfg: r.Cfg, - DB: r.DB, - FSAPI: r.fsAPI, - RSAPI: r, - Inputer: r.Inputer, - Queryer: r.Queryer, + Cfg: r.Cfg, + DB: r.DB, + FSAPI: r.fsAPI, + RSAPI: r, + Inputer: r.Inputer, + Queryer: r.Queryer, } r.Peeker = &perform.Peeker{ ServerName: r.Cfg.Matrix.ServerName, @@ -163,10 +162,10 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio DB: r.DB, } r.Backfiller = &perform.Backfiller{ - ServerName: r.ServerName, - DB: r.DB, - FSAPI: r.fsAPI, - KeyRing: r.KeyRing, + IsLocalServerName: r.Cfg.Matrix.IsLocalServerName, + DB: r.DB, + FSAPI: r.fsAPI, + KeyRing: r.KeyRing, // Perspective servers are trusted to not lie about server keys, so we will also // prefer these servers when backfilling (assuming they are in the room) rather // than trying random servers diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index f5099ca11..e965691c9 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -278,7 +278,11 @@ func (w *worker) _next() { // a string, because we might want to return that to the caller if // it was a synchronous request. var errString string - if err = w.r.processRoomEvent(w.r.ProcessContext.Context(), &inputRoomEvent); err != nil { + if err = w.r.processRoomEvent( + w.r.ProcessContext.Context(), + gomatrixserverlib.ServerName(msg.Header.Get("virtual_host")), + &inputRoomEvent, + ); err != nil { switch err.(type) { case types.RejectedError: // Don't send events that were rejected to Sentry @@ -358,6 +362,7 @@ func (r *Inputer) queueInputRoomEvents( if replyTo != "" { msg.Header.Set("sync", replyTo) } + msg.Header.Set("virtual_host", string(request.VirtualHost)) msg.Data, err = json.Marshal(e) if err != nil { return nil, fmt.Errorf("json.Marshal: %w", err) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 682aa2b12..72c285f8e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -69,6 +69,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec( // nolint:gocyclo func (r *Inputer) processRoomEvent( ctx context.Context, + virtualHost gomatrixserverlib.ServerName, input *api.InputRoomEvent, ) error { select { @@ -200,7 +201,7 @@ func (r *Inputer) processRoomEvent( isRejected := false authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, roomInfo, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { return fmt.Errorf("r.fetchAuthEvents: %w", err) } @@ -265,16 +266,17 @@ func (r *Inputer) processRoomEvent( // processRoomEvent. if len(serverRes.ServerNames) > 0 { missingState := missingStateReq{ - origin: input.Origin, - inputer: r, - db: r.DB, - roomInfo: roomInfo, - federation: r.FSAPI, - keys: r.KeyRing, - roomsMu: internal.NewMutexByRoom(), - servers: serverRes.ServerNames, - hadEvents: map[string]bool{}, - haveEvents: map[string]*gomatrixserverlib.Event{}, + origin: input.Origin, + virtualHost: virtualHost, + inputer: r, + db: r.DB, + roomInfo: roomInfo, + federation: r.FSAPI, + keys: r.KeyRing, + roomsMu: internal.NewMutexByRoom(), + servers: serverRes.ServerNames, + hadEvents: map[string]bool{}, + haveEvents: map[string]*gomatrixserverlib.Event{}, } var stateSnapshot *parsedRespState if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { @@ -555,6 +557,7 @@ func (r *Inputer) fetchAuthEvents( ctx context.Context, logger *logrus.Entry, roomInfo *types.RoomInfo, + virtualHost gomatrixserverlib.ServerName, event *gomatrixserverlib.HeaderedEvent, auth *gomatrixserverlib.AuthEvents, known map[string]*types.Event, @@ -605,7 +608,7 @@ func (r *Inputer) fetchAuthEvents( // Request the entire auth chain for the event in question. This should // contain all of the auth events — including ones that we already know — // so we'll need to filter through those in the next section. - res, err = r.FSAPI.GetEventAuth(ctx, serverName, event.RoomVersion, event.RoomID(), event.EventID()) + res, err = r.FSAPI.GetEventAuth(ctx, virtualHost, serverName, event.RoomVersion, event.RoomID(), event.EventID()) if err != nil { logger.WithError(err).Warnf("Failed to get event auth from federation for %q: %s", event.EventID(), err) continue diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index f788c5b3a..03ac2b38d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -41,6 +41,7 @@ func (p *parsedRespState) Events() []*gomatrixserverlib.Event { type missingStateReq struct { log *logrus.Entry + virtualHost gomatrixserverlib.ServerName origin gomatrixserverlib.ServerName db storage.Database roomInfo *types.RoomInfo @@ -101,7 +102,7 @@ func (t *missingStateReq) processEventWithMissingState( // we can just inject all the newEvents as new as we may have only missed 1 or 2 events and have filled // in the gap in the DAG for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -157,7 +158,7 @@ func (t *missingStateReq) processEventWithMissingState( }) } for _, ire := range outlierRoomEvents { - if err = t.inputer.processRoomEvent(ctx, &ire); err != nil { + if err = t.inputer.processRoomEvent(ctx, t.virtualHost, &ire); err != nil { if _, ok := err.(types.RejectedError); !ok { return fmt.Errorf("t.inputer.processRoomEvent (outlier): %w", err) } @@ -180,7 +181,7 @@ func (t *missingStateReq) processEventWithMissingState( stateIDs = append(stateIDs, event.EventID()) } - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: backwardsExtremity.Headered(roomVersion), Origin: t.origin, @@ -199,7 +200,7 @@ func (t *missingStateReq) processEventWithMissingState( // they will automatically fast-forward based on the room state at the // extremity in the last step. for _, newEvent := range newEvents { - err = t.inputer.processRoomEvent(ctx, &api.InputRoomEvent{ + err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{ Kind: api.KindOld, Event: newEvent.Headered(roomVersion), Origin: t.origin, @@ -519,7 +520,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve var missingResp *gomatrixserverlib.RespMissingEvents for _, server := range t.servers { var m gomatrixserverlib.RespMissingEvents - if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ + if m, err = t.federation.LookupMissingEvents(ctx, t.virtualHost, server, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. EarliestEvents: latestEvents, @@ -635,7 +636,7 @@ func (t *missingStateReq) lookupMissingStateViaState( span, ctx := opentracing.StartSpanFromContext(ctx, "lookupMissingStateViaState") defer span.Finish() - state, err := t.federation.LookupState(ctx, t.origin, roomID, eventID, roomVersion) + state, err := t.federation.LookupState(ctx, t.virtualHost, t.origin, roomID, eventID, roomVersion) if err != nil { return nil, err } @@ -675,7 +676,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo totalctx, totalcancel := context.WithTimeout(ctx, time.Minute*5) for _, serverName := range t.servers { reqctx, reqcancel := context.WithTimeout(totalctx, time.Second*20) - stateIDs, err = t.federation.LookupStateIDs(reqctx, serverName, roomID, eventID) + stateIDs, err = t.federation.LookupStateIDs(reqctx, t.virtualHost, serverName, roomID, eventID) reqcancel() if err == nil { break @@ -855,7 +856,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs for _, serverName := range t.servers { reqctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() - txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) + txn, err := t.federation.GetEvent(reqctx, t.virtualHost, serverName, missingEventID) if err != nil || len(txn.PDUs) == 0 { t.log.WithError(err).WithField("missing_event_id", missingEventID).Warn("Failed to get missing /event for event ID") if errors.Is(err, context.DeadlineExceeded) { diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index 0406fc8f4..d42f4e45d 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -139,7 +139,12 @@ func (r *Admin) PerformAdminEvacuateRoom( return nil } - event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, time.Now(), &eventsNeeded, latestRes) + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + continue + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, latestRes) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -242,6 +247,15 @@ func (r *Admin) PerformAdminDownloadState( req *api.PerformAdminDownloadStateRequest, res *api.PerformAdminDownloadStateResponse, ) error { + _, senderDomain, err := r.Cfg.Matrix.SplitLocalID('@', req.UserID) + if err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("r.Cfg.Matrix.SplitLocalID: %s", err), + } + return nil + } + roomInfo, err := r.DB.RoomInfo(ctx, req.RoomID) if err != nil { res.Error = &api.PerformError{ @@ -273,7 +287,7 @@ func (r *Admin) PerformAdminDownloadState( for _, fwdExtremity := range fwdExtremities { var state gomatrixserverlib.RespState - state, err = r.Inputer.FSAPI.LookupState(ctx, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) + state, err = r.Inputer.FSAPI.LookupState(ctx, r.Inputer.ServerName, req.ServerName, req.RoomID, fwdExtremity.EventID, roomInfo.RoomVersion) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, @@ -331,7 +345,12 @@ func (r *Admin) PerformAdminDownloadState( Depth: depth, } - ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, time.Now(), &eventsNeeded, queryRes) + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + return err + } + + ev, err := eventutil.BuildEvent(ctx, builder, r.Cfg.Matrix, identity, time.Now(), &eventsNeeded, queryRes) if err != nil { res.Error = &api.PerformError{ Code: api.PerformErrorBadRequest, diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 57e121ea2..069f017a9 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -37,10 +37,10 @@ import ( const maxBackfillServers = 5 type Backfiller struct { - ServerName gomatrixserverlib.ServerName - DB storage.Database - FSAPI federationAPI.RoomserverFederationAPI - KeyRing gomatrixserverlib.JSONVerifier + IsLocalServerName func(gomatrixserverlib.ServerName) bool + DB storage.Database + FSAPI federationAPI.RoomserverFederationAPI + KeyRing gomatrixserverlib.JSONVerifier // The servers which should be preferred above other servers when backfilling PreferServers []gomatrixserverlib.ServerName @@ -55,7 +55,7 @@ func (r *Backfiller) PerformBackfill( // if we are requesting the backfill then we need to do a federation hit // TODO: we could be more sensible and fetch as many events we already have then request the rest // which is what the syncapi does already. - if request.ServerName == r.ServerName { + if r.IsLocalServerName(request.ServerName) { return r.backfillViaFederation(ctx, request, response) } // someone else is requesting the backfill, try to service their request. @@ -112,16 +112,18 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform if info == nil || info.IsStub() { return fmt.Errorf("backfillViaFederation: missing room info for room %s", req.RoomID) } - requester := newBackfillRequester(r.DB, r.FSAPI, r.ServerName, req.BackwardsExtremities, r.PreferServers) + requester := newBackfillRequester(r.DB, r.FSAPI, req.VirtualHost, r.IsLocalServerName, req.BackwardsExtremities, r.PreferServers) // Request 100 items regardless of what the query asks for. // We don't want to go much higher than this. // We can't honour exactly the limit as some sytests rely on requesting more for tests to pass // (so we don't need to hit /state_ids which the test has no listener for) // Specifically the test "Outbound federation can backfill events" events, err := gomatrixserverlib.RequestBackfill( - ctx, requester, - r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100) + ctx, req.VirtualHost, requester, + r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, + ) if err != nil { + logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed") return err } logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) @@ -144,7 +146,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform var entries []types.StateEntry if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true); err != nil { // attempt to fetch the missing events - r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs) + r.fetchAndStoreMissingEvents(ctx, info.RoomVersion, requester, stateIDs, req.VirtualHost) // try again entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs, true) if err != nil { @@ -173,7 +175,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform // fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just // best effort. func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion, - backfillRequester *backfillRequester, stateIDs []string) { + backfillRequester *backfillRequester, stateIDs []string, virtualHost gomatrixserverlib.ServerName) { servers := backfillRequester.servers @@ -198,7 +200,7 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom continue // already found } logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id) - res, err := r.FSAPI.GetEvent(ctx, srv, id) + res, err := r.FSAPI.GetEvent(ctx, virtualHost, srv, id) if err != nil { logger.WithError(err).Warn("failed to get event from server") continue @@ -241,11 +243,12 @@ func (r *Backfiller) fetchAndStoreMissingEvents(ctx context.Context, roomVer gom // backfillRequester implements gomatrixserverlib.BackfillRequester type backfillRequester struct { - db storage.Database - fsAPI federationAPI.RoomserverFederationAPI - thisServer gomatrixserverlib.ServerName - preferServer map[gomatrixserverlib.ServerName]bool - bwExtrems map[string][]string + db storage.Database + fsAPI federationAPI.RoomserverFederationAPI + virtualHost gomatrixserverlib.ServerName + isLocalServerName func(gomatrixserverlib.ServerName) bool + preferServer map[gomatrixserverlib.ServerName]bool + bwExtrems map[string][]string // per-request state servers []gomatrixserverlib.ServerName @@ -255,7 +258,9 @@ type backfillRequester struct { } func newBackfillRequester( - db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, thisServer gomatrixserverlib.ServerName, + db storage.Database, fsAPI federationAPI.RoomserverFederationAPI, + virtualHost gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, bwExtrems map[string][]string, preferServers []gomatrixserverlib.ServerName, ) *backfillRequester { preferServer := make(map[gomatrixserverlib.ServerName]bool) @@ -265,7 +270,8 @@ func newBackfillRequester( return &backfillRequester{ db: db, fsAPI: fsAPI, - thisServer: thisServer, + virtualHost: virtualHost, + isLocalServerName: isLocalServerName, eventIDToBeforeStateIDs: make(map[string][]string), eventIDMap: make(map[string]*gomatrixserverlib.Event), bwExtrems: bwExtrems, @@ -450,7 +456,7 @@ FindSuccessor: } // possibly return all joined servers depending on history visiblity - memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.thisServer) + memberEventsFromVis, visibility, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries, b.virtualHost) b.historyVisiblity = visibility if err != nil { logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules") @@ -477,7 +483,7 @@ FindSuccessor: } var servers []gomatrixserverlib.ServerName for server := range serverSet { - if server == b.thisServer { + if b.isLocalServerName(server) { continue } if b.preferServer[server] { // insert at the front @@ -496,10 +502,10 @@ FindSuccessor: // Backfill performs a backfill request to the given server. // https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid -func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string, +func (b *backfillRequester) Backfill(ctx context.Context, origin, server gomatrixserverlib.ServerName, roomID string, limit int, fromEventIDs []string) (gomatrixserverlib.Transaction, error) { - tx, err := b.fsAPI.Backfill(ctx, server, roomID, limit, fromEventIDs) + tx, err := b.fsAPI.Backfill(ctx, origin, server, roomID, limit, fromEventIDs) return tx, err } diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 9d596ab30..4de008c66 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -39,11 +39,10 @@ import ( ) type Joiner struct { - ServerName gomatrixserverlib.ServerName - Cfg *config.RoomServer - FSAPI fsAPI.RoomserverFederationAPI - RSAPI rsAPI.RoomserverInternalAPI - DB storage.Database + Cfg *config.RoomServer + FSAPI fsAPI.RoomserverFederationAPI + RSAPI rsAPI.RoomserverInternalAPI + DB storage.Database Inputer *input.Inputer Queryer *query.Queryer @@ -197,7 +196,7 @@ func (r *Joiner) performJoinRoomByID( // Prepare the template for the join event. userID := req.UserID - _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) if err != nil { return "", "", &rsAPI.PerformError{ Code: rsAPI.PerformErrorBadRequest, @@ -283,7 +282,7 @@ func (r *Joiner) performJoinRoomByID( // locally on the homeserver. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) + event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, userDomain, &eb) switch err { case nil: @@ -410,7 +409,9 @@ func (r *Joiner) populateAuthorisedViaUserForRestrictedJoin( } func buildEvent( - ctx context.Context, db storage.Database, cfg *config.Global, builder *gomatrixserverlib.EventBuilder, + ctx context.Context, db storage.Database, cfg *config.Global, + senderDomain gomatrixserverlib.ServerName, + builder *gomatrixserverlib.EventBuilder, ) (*gomatrixserverlib.HeaderedEvent, *rsAPI.QueryLatestEventsAndStateResponse, error) { eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(builder) if err != nil { @@ -438,7 +439,12 @@ func buildEvent( } } - ev, err := eventutil.BuildEvent(ctx, builder, cfg, time.Now(), &eventsNeeded, &queryRes) + identity, err := cfg.SigningIdentityFor(senderDomain) + if err != nil { + return nil, nil, err + } + + ev, err := eventutil.BuildEvent(ctx, builder, cfg, identity, time.Now(), &eventsNeeded, &queryRes) if err != nil { return nil, nil, err } diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index 49e4b479a..fa998e3e1 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -162,21 +162,21 @@ func (r *Leaver) performLeaveRoomByID( return nil, fmt.Errorf("eb.SetUnsigned: %w", err) } + // Get the sender domain. + _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', eb.Sender) + if serr != nil { + return nil, fmt.Errorf("sender %q is invalid", eb.Sender) + } + // We know that the user is in the room at this point so let's build // a leave event. // TODO: Check what happens if the room exists on the server // but everyone has since left. I suspect it does the wrong thing. - event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, &eb) + event, buildRes, err := buildEvent(ctx, r.DB, r.Cfg.Matrix, senderDomain, &eb) if err != nil { return nil, fmt.Errorf("eventutil.BuildEvent: %w", err) } - // Get the sender domain. - _, senderDomain, serr := gomatrixserverlib.SplitID('@', event.Sender()) - if serr != nil { - return nil, fmt.Errorf("sender %q is invalid", event.Sender()) - } - // Give our leave event to the roomserver input stream. The // roomserver will process the membership change and notify // downstream automatically. diff --git a/roomserver/internal/perform/perform_upgrade.go b/roomserver/internal/perform/perform_upgrade.go index 38abe323c..02a19911c 100644 --- a/roomserver/internal/perform/perform_upgrade.go +++ b/roomserver/internal/perform/perform_upgrade.go @@ -60,7 +60,7 @@ func (r *Upgrader) performRoomUpgrade( ) (string, *api.PerformError) { roomID := req.RoomID userID := req.UserID - _, userDomain, err := gomatrixserverlib.SplitID('@', userID) + _, userDomain, err := r.Cfg.Matrix.SplitLocalID('@', userID) if err != nil { return "", &api.PerformError{ Code: api.PerformErrorNotAllowed, @@ -558,7 +558,7 @@ func (r *Upgrader) sendInitialEvents(ctx context.Context, evTime time.Time, user SendAsServer: api.DoNotSendToOtherServers, }) } - if err = api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { + if err = api.SendInputRoomEvents(ctx, r.URSAPI, userDomain, inputs, false); err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to send new room %q to roomserver: %s", newRoomID, err), } @@ -595,8 +595,21 @@ func (r *Upgrader) makeHeaderedEvent(ctx context.Context, evTime time.Time, user Msg: fmt.Sprintf("Failed to set new %q event content: %s", builder.Type, err), } } + // Get the sender domain. + _, senderDomain, serr := r.Cfg.Matrix.SplitLocalID('@', builder.Sender) + if serr != nil { + return nil, &api.PerformError{ + Msg: fmt.Sprintf("Failed to split user ID %q: %s", builder.Sender, err), + } + } + identity, err := r.Cfg.Matrix.SigningIdentityFor(senderDomain) + if err != nil { + return nil, &api.PerformError{ + Msg: fmt.Sprintf("Failed to get signing identity for %q: %s", senderDomain, err), + } + } var queryRes api.QueryLatestEventsAndStateResponse - headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, evTime, r.URSAPI, &queryRes) + headeredEvent, err := eventutil.QueryAndBuildEvent(ctx, &builder, r.Cfg.Matrix, identity, evTime, r.URSAPI, &queryRes) if err == eventutil.ErrRoomNoExists { return nil, &api.PerformError{ Code: api.PerformErrorNoRoom, @@ -686,7 +699,7 @@ func (r *Upgrader) sendHeaderedEvent( Origin: serverName, SendAsServer: sendAsServer, }) - if err := api.SendInputRoomEvents(ctx, r.URSAPI, inputs, false); err != nil { + if err := api.SendInputRoomEvents(ctx, r.URSAPI, serverName, inputs, false); err != nil { return &api.PerformError{ Msg: fmt.Sprintf("Failed to send new %q event to roomserver: %s", headeredEvent.Type(), err), } diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 8850e5c46..d8456fb43 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -37,10 +37,10 @@ import ( ) type Queryer struct { - DB storage.Database - Cache caching.RoomServerCaches - ServerName gomatrixserverlib.ServerName - ServerACLs *acls.ServerACLs + DB storage.Database + Cache caching.RoomServerCaches + IsLocalServerName func(gomatrixserverlib.ServerName) bool + ServerACLs *acls.ServerACLs } // QueryLatestEventsAndState implements api.RoomserverInternalAPI @@ -392,7 +392,7 @@ func (r *Queryer) QueryServerJoinedToRoom( } response.RoomExists = true - if request.ServerName == r.ServerName || request.ServerName == "" { + if r.IsLocalServerName(request.ServerName) || request.ServerName == "" { response.IsInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID) if err != nil { return fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err) diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 4e98af853..24b5515e5 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -44,7 +44,7 @@ func Test_SharedUsers(t *testing.T) { // SetFederationAPI starts the room event input consumer rsAPI.SetFederationAPI(nil, nil) // Create the room - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } diff --git a/setup/base/base.go b/setup/base/base.go index 2e3a3a195..14edadd96 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -364,10 +364,10 @@ func (b *BaseDendrite) CreateClient() *gomatrixserverlib.Client { // CreateFederationClient creates a new federation client. Should only be called // once per component. func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationClient { + identities := b.Cfg.Global.SigningIdentities() if b.Cfg.Global.DisableFederation { return gomatrixserverlib.NewFederationClient( - b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, b.Cfg.Global.PrivateKey, - gomatrixserverlib.WithTransport(noOpHTTPTransport), + identities, gomatrixserverlib.WithTransport(noOpHTTPTransport), ) } opts := []gomatrixserverlib.ClientOption{ @@ -379,8 +379,7 @@ func (b *BaseDendrite) CreateFederationClient() *gomatrixserverlib.FederationCli opts = append(opts, gomatrixserverlib.WithDNSCache(b.DNSCache)) } client := gomatrixserverlib.NewFederationClient( - b.Cfg.Global.ServerName, b.Cfg.Global.KeyID, - b.Cfg.Global.PrivateKey, opts..., + identities, opts..., ) client.SetUserAgent(fmt.Sprintf("Dendrite/%s", internal.VersionString())) return client diff --git a/setup/config/config.go b/setup/config/config.go index e99852ec9..918bcbe3b 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -231,6 +231,21 @@ func loadConfig( return nil, err } + for _, v := range c.Global.VirtualHosts { + if v.KeyValidityPeriod == 0 { + v.KeyValidityPeriod = c.Global.KeyValidityPeriod + } + if v.PrivateKeyPath == "" { + v.KeyID = c.Global.KeyID + v.PrivateKey = c.Global.PrivateKey + continue + } + privateKeyPath := absPath(basePath, v.PrivateKeyPath) + if v.KeyID, v.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { + return nil, err + } + } + for _, key := range c.Global.OldVerifyKeys { switch { case key.PrivateKeyPath != "": diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 825772827..f2fdd021e 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -1,6 +1,7 @@ package config import ( + "fmt" "math/rand" "strconv" "strings" @@ -15,7 +16,7 @@ type Global struct { ServerName gomatrixserverlib.ServerName `yaml:"server_name"` // The secondary server names, used for virtual hosting. - SecondaryServerNames []gomatrixserverlib.ServerName `yaml:"-"` + VirtualHosts []*VirtualHost `yaml:"virtual_hosts"` // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` @@ -114,6 +115,10 @@ func (c *Global) Verify(configErrs *ConfigErrors, isMonolith bool) { checkNotEmpty(configErrs, "global.server_name", string(c.ServerName)) checkNotEmpty(configErrs, "global.private_key", string(c.PrivateKeyPath)) + for _, v := range c.VirtualHosts { + v.Verify(configErrs) + } + c.JetStream.Verify(configErrs, isMonolith) c.Metrics.Verify(configErrs, isMonolith) c.Sentry.Verify(configErrs, isMonolith) @@ -127,14 +132,85 @@ func (c *Global) IsLocalServerName(serverName gomatrixserverlib.ServerName) bool if c.ServerName == serverName { return true } - for _, secondaryName := range c.SecondaryServerNames { - if secondaryName == serverName { + for _, v := range c.VirtualHosts { + if v.ServerName == serverName { return true } } return false } +func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib.ServerName, error) { + u, s, err := gomatrixserverlib.SplitID(sigil, id) + if err != nil { + return u, s, err + } + if !c.IsLocalServerName(s) { + return u, s, fmt.Errorf("server name %q not known", s) + } + return u, s, nil +} + +func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.SigningIdentity, error) { + for _, id := range c.SigningIdentities() { + if id.ServerName == serverName { + return id, nil + } + } + return nil, fmt.Errorf("no signing identity %q", serverName) +} + +func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { + identities := make([]*gomatrixserverlib.SigningIdentity, 0, len(c.VirtualHosts)+1) + identities = append(identities, &gomatrixserverlib.SigningIdentity{ + ServerName: c.ServerName, + KeyID: c.KeyID, + PrivateKey: c.PrivateKey, + }) + for _, v := range c.VirtualHosts { + identities = append(identities, v.SigningIdentity()) + } + return identities +} + +type VirtualHost struct { + // The server name of the virtual host. + ServerName gomatrixserverlib.ServerName `yaml:"server_name"` + + // The key ID of the private key. If not specified, the default global key ID + // will be used instead. + KeyID gomatrixserverlib.KeyID `yaml:"key_id"` + + // Path to the private key. If not specified, the default global private key + // will be used instead. + PrivateKeyPath Path `yaml:"private_key"` + + // The private key itself. + PrivateKey ed25519.PrivateKey `yaml:"-"` + + // How long a remote server can cache our server key for before requesting it again. + // Increasing this number will reduce the number of requests made by remote servers + // for our key, but increases the period a compromised key will be considered valid + // by remote servers. + // Defaults to 24 hours. + KeyValidityPeriod time.Duration `yaml:"key_validity_period"` + + // Is registration enabled on this virtual host? + AllowRegistration bool `json:"allow_registration"` +} + +func (v *VirtualHost) Verify(configErrs *ConfigErrors) { + checkNotEmpty(configErrs, "virtual_host.*.server_name", string(v.ServerName)) +} + +func (v *VirtualHost) SigningIdentity() *gomatrixserverlib.SigningIdentity { + return &gomatrixserverlib.SigningIdentity{ + ServerName: v.ServerName, + KeyID: v.KeyID, + PrivateKey: v.PrivateKey, + } +} + type OldVerifyKeys struct { // Path to the private key. PrivateKeyPath Path `yaml:"private_key"` diff --git a/setup/mscs/msc2836/msc2836.go b/setup/mscs/msc2836/msc2836.go index 98502f5cb..bc369c166 100644 --- a/setup/mscs/msc2836/msc2836.go +++ b/setup/mscs/msc2836/msc2836.go @@ -397,7 +397,7 @@ func (rc *reqCtx) includeChildren(db Database, parentID string, limit int, recen serversToQuery := rc.getServersForEventID(parentID) var result *MSC2836EventRelationshipsResponse for _, srv := range serversToQuery { - res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: parentID, Direction: "down", Limit: 100, @@ -484,7 +484,7 @@ func walkThread( // MSC2836EventRelationships performs an /event_relationships request to a remote server func (rc *reqCtx) MSC2836EventRelationships(eventID string, srv gomatrixserverlib.ServerName, ver gomatrixserverlib.RoomVersion) (*MSC2836EventRelationshipsResponse, error) { - res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ + res, err := rc.fsAPI.MSC2836EventRelationships(rc.ctx, rc.serverName, srv, gomatrixserverlib.MSC2836EventRelationshipsRequest{ EventID: eventID, DepthFirst: rc.req.DepthFirst, Direction: rc.req.Direction, @@ -665,7 +665,7 @@ func (rc *reqCtx) injectResponseToRoomserver(res *MSC2836EventRelationshipsRespo }) } // we've got the data by this point so use a background context - err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, ires, false) + err := roomserver.SendInputRoomEvents(context.Background(), rc.rsAPI, rc.serverName, ires, false) if err != nil { util.GetLogger(rc.ctx).WithError(err).Error("failed to inject MSC2836EventRelationshipsResponse into the roomserver") } diff --git a/setup/mscs/msc2946/msc2946.go b/setup/mscs/msc2946/msc2946.go index d7c022544..56c063598 100644 --- a/setup/mscs/msc2946/msc2946.go +++ b/setup/mscs/msc2946/msc2946.go @@ -433,7 +433,7 @@ func (w *walker) federatedRoomInfo(roomID string, vias []string) *gomatrixserver if serverName == string(w.thisServer) { continue } - res, err := w.fsAPI.MSC2946Spaces(ctx, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) + res, err := w.fsAPI.MSC2946Spaces(ctx, w.thisServer, gomatrixserverlib.ServerName(serverName), roomID, w.suggestedOnly) if err != nil { util.GetLogger(w.ctx).WithError(err).Warnf("failed to call MSC2946Spaces on server %s", serverName) continue diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 96ebba7ef..92f081500 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -28,22 +28,20 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/streams" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // OutputKeyChangeEventConsumer consumes events that originated in the key server. type OutputKeyChangeEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - notifier *notifier.Notifier - stream streams.StreamProvider - serverName gomatrixserverlib.ServerName // our server name - rsAPI roomserverAPI.SyncRoomserverAPI + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + notifier *notifier.Notifier + stream streams.StreamProvider + rsAPI roomserverAPI.SyncRoomserverAPI } // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer. @@ -59,15 +57,14 @@ func NewOutputKeyChangeEventConsumer( stream streams.StreamProvider, ) *OutputKeyChangeEventConsumer { s := &OutputKeyChangeEventConsumer{ - ctx: process.Context(), - jetstream: js, - durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"), - topic: topic, - db: store, - serverName: cfg.Matrix.ServerName, - rsAPI: rsAPI, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + durable: cfg.Matrix.JetStream.Durable("SyncAPIKeyChangeConsumer"), + topic: topic, + db: store, + rsAPI: rsAPI, + notifier: notifier, + stream: stream, } return s diff --git a/syncapi/consumers/receipts.go b/syncapi/consumers/receipts.go index 8aaa65730..e39d43f94 100644 --- a/syncapi/consumers/receipts.go +++ b/syncapi/consumers/receipts.go @@ -34,14 +34,13 @@ import ( // OutputReceiptEventConsumer consumes events that originated in the EDU server. type OutputReceiptEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - stream streams.StreamProvider - notifier *notifier.Notifier - serverName gomatrixserverlib.ServerName + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + stream streams.StreamProvider + notifier *notifier.Notifier } // NewOutputReceiptEventConsumer creates a new OutputReceiptEventConsumer. @@ -55,14 +54,13 @@ func NewOutputReceiptEventConsumer( stream streams.StreamProvider, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), - durable: cfg.Matrix.JetStream.Durable("SyncAPIReceiptConsumer"), - db: store, - notifier: notifier, - stream: stream, - serverName: cfg.Matrix.ServerName, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputReceiptEvent), + durable: cfg.Matrix.JetStream.Durable("SyncAPIReceiptConsumer"), + db: store, + notifier: notifier, + stream: stream, } } diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index f767615c8..1b67f5684 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -364,11 +364,7 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom // TODO: check that it's a join and not a profile change (means unmarshalling prev_content) if membership == gomatrixserverlib.Join { // check it's a local join - _, domain, err := gomatrixserverlib.SplitID('@', *ev.StateKey()) - if err != nil { - return sp, fmt.Errorf("gomatrixserverlib.SplitID: %w", err) - } - if domain != s.cfg.Matrix.ServerName { + if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil { return sp, nil } @@ -390,9 +386,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent( if msg.Event.StateKey() == nil { return } - if _, serverName, err := gomatrixserverlib.SplitID('@', *msg.Event.StateKey()); err != nil { - return - } else if serverName != s.cfg.Matrix.ServerName { + if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil { return } pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) diff --git a/syncapi/consumers/sendtodevice.go b/syncapi/consumers/sendtodevice.go index 49d84cca3..356e83263 100644 --- a/syncapi/consumers/sendtodevice.go +++ b/syncapi/consumers/sendtodevice.go @@ -37,15 +37,15 @@ import ( // OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. type OutputSendToDeviceEventConsumer struct { - ctx context.Context - jetstream nats.JetStreamContext - durable string - topic string - db storage.Database - keyAPI keyapi.SyncKeyAPI - serverName gomatrixserverlib.ServerName // our server name - stream streams.StreamProvider - notifier *notifier.Notifier + ctx context.Context + jetstream nats.JetStreamContext + durable string + topic string + db storage.Database + keyAPI keyapi.SyncKeyAPI + isLocalServerName func(gomatrixserverlib.ServerName) bool + stream streams.StreamProvider + notifier *notifier.Notifier } // NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer. @@ -60,15 +60,15 @@ func NewOutputSendToDeviceEventConsumer( stream streams.StreamProvider, ) *OutputSendToDeviceEventConsumer { return &OutputSendToDeviceEventConsumer{ - ctx: process.Context(), - jetstream: js, - topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), - durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), - db: store, - keyAPI: keyAPI, - serverName: cfg.Matrix.ServerName, - notifier: notifier, - stream: stream, + ctx: process.Context(), + jetstream: js, + topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + durable: cfg.Matrix.JetStream.Durable("SyncAPISendToDeviceConsumer"), + db: store, + keyAPI: keyAPI, + isLocalServerName: cfg.Matrix.IsLocalServerName, + notifier: notifier, + stream: stream, } } @@ -89,7 +89,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] log.WithError(err).Errorf("send-to-device: failed to split user id, dropping message") return true } - if domain != s.serverName { + if !s.isLocalServerName(domain) { log.Tracef("ignoring send-to-device event with destination %s", domain) return true } @@ -114,7 +114,7 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(ctx context.Context, msgs [] if output.Type == "m.room_key_request" { requestingDeviceID := gjson.GetBytes(output.SendToDeviceEvent.Content, "requesting_device_id").Str _, senderDomain, _ := gomatrixserverlib.SplitID('@', output.Sender) - if requestingDeviceID != "" && senderDomain != s.serverName { + if requestingDeviceID != "" && !s.isLocalServerName(senderDomain) { // Mark the requesting device as stale, if we don't know about it. if err = s.keyAPI.PerformMarkAsStaleIfNeeded(ctx, &keyapi.PerformMarkAsStaleRequest{ UserID: output.Sender, Domain: senderDomain, DeviceID: requestingDeviceID, diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 86cf8e736..0d740ebfc 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -528,6 +528,7 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] BackwardsExtremities: backwardsExtremities, Limit: limit, ServerName: r.cfg.Matrix.ServerName, + VirtualHost: r.device.UserDomain(), }, &res) if err != nil { return nil, fmt.Errorf("PerformBackfill failed: %w", err) diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index a4985dbf4..483274481 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -433,7 +433,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { beforeJoinBody := fmt.Sprintf("Before invite in a %s room", tc.historyVisibility) beforeJoinEv := room.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": beforeJoinBody}) eventsToSend := append(room.Events(), beforeJoinEv) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } syncUntil(t, base, aliceDev.AccessToken, false, @@ -472,7 +472,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { eventsToSend = append([]*gomatrixserverlib.HeaderedEvent{}, inviteEv, afterInviteEv, joinEv, msgEv) - if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", nil, false); err != nil { + if err := api.SendEvents(ctx, rsAPI, api.KindNew, eventsToSend, "test", "test", "test", nil, false); err != nil { t.Fatalf("failed to send events: %v", err) } syncUntil(t, base, aliceDev.AccessToken, false, From 5c9aed6af90dc77a7d82a4d16a165715cdd560ca Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Nov 2022 15:11:08 +0000 Subject: [PATCH 20/42] Update to matrix-org/gomatrixserverlib@900369e --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index fce538b4c..1889bf891 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf + github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 224d5aa86..321f00d4a 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf h1:OPM8urHlGC8m1MfusZg6Oi9zxJiDNuSVh3GQuvZeOHw= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221111172904-e8fa860b1cdf/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 h1:VapUpY3oSbEGhfSpnnTKh7bz6AA5R4tVB5lwdlcA6jE= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= From 9b8bb55430e658c8752c0b093416eed6d977cb0d Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 15 Nov 2022 17:21:16 +0000 Subject: [PATCH 21/42] Don't get blacklisted hosts when querying joined servers (#2880) Otherwise we just waste time/CPU. --- federationapi/api/api.go | 5 +-- federationapi/consumers/keychange.go | 4 +-- federationapi/consumers/presence.go | 2 +- federationapi/internal/query.go | 2 +- federationapi/storage/interface.go | 2 +- .../storage/postgres/joined_hosts_table.go | 31 +++++++++++++------ federationapi/storage/postgres/storage.go | 8 ++--- federationapi/storage/shared/storage.go | 4 +-- .../storage/sqlite3/joined_hosts_table.go | 15 +++++++-- federationapi/storage/sqlite3/storage.go | 8 ++--- federationapi/storage/tables/interface.go | 2 +- roomserver/internal/input/input_events.go | 5 +-- 12 files changed, 56 insertions(+), 32 deletions(-) diff --git a/federationapi/api/api.go b/federationapi/api/api.go index e34c9e8b6..f4be53b9a 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -198,8 +198,9 @@ type PerformInviteResponse struct { // QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames type QueryJoinedHostServerNamesInRoomRequest struct { - RoomID string `json:"room_id"` - ExcludeSelf bool `json:"exclude_self"` + RoomID string `json:"room_id"` + ExcludeSelf bool `json:"exclude_self"` + ExcludeBlacklisted bool `json:"exclude_blacklisted"` } // QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 7d1ae0f81..601257d4b 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool { } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { sentry.CaptureException(err) logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") @@ -189,7 +189,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool { return true } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { sentry.CaptureException(err) logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") diff --git a/federationapi/consumers/presence.go b/federationapi/consumers/presence.go index 153fc40b5..29b16f373 100644 --- a/federationapi/consumers/presence.go +++ b/federationapi/consumers/presence.go @@ -111,7 +111,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg } // send this presence to all servers who share rooms with this user. - joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true) + joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true) if err != nil { log.WithError(err).Error("failed to get joined hosts") return true diff --git a/federationapi/internal/query.go b/federationapi/internal/query.go index b0a76eeb7..688afa8ea 100644 --- a/federationapi/internal/query.go +++ b/federationapi/internal/query.go @@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom( request *api.QueryJoinedHostServerNamesInRoomRequest, response *api.QueryJoinedHostServerNamesInRoomResponse, ) (err error) { - joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf) + joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf, request.ExcludeBlacklisted) if err != nil { return } diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 09098cd1e..b15b8bfae 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -32,7 +32,7 @@ type Database interface { GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. - GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) + GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationapi/storage/postgres/joined_hosts_table.go b/federationapi/storage/postgres/joined_hosts_table.go index 5c95b72a8..9a3977560 100644 --- a/federationapi/storage/postgres/joined_hosts_table.go +++ b/federationapi/storage/postgres/joined_hosts_table.go @@ -66,14 +66,20 @@ const selectAllJoinedHostsSQL = "" + const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" +const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id = ANY($1) AND NOT EXISTS (" + + " SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" + + ");" + type joinedHostsStatements struct { - db *sql.DB - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - deleteJoinedHostsForRoomStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt - selectAllJoinedHostsStmt *sql.Stmt - selectJoinedHostsForRoomsStmt *sql.Stmt + db *sql.DB + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + deleteJoinedHostsForRoomStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt + selectJoinedHostsForRoomsStmt *sql.Stmt + selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt } func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -102,6 +108,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { return } + if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil { + return + } return } @@ -167,9 +176,13 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( } func (s *joinedHostsStatements) SelectJoinedHostsForRooms( - ctx context.Context, roomIDs []string, + ctx context.Context, roomIDs []string, excludingBlacklisted bool, ) ([]gomatrixserverlib.ServerName, error) { - rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) + stmt := s.selectJoinedHostsForRoomsStmt + if excludingBlacklisted { + stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt + } + rows, err := stmt.QueryContext(ctx, pq.StringArray(roomIDs)) if err != nil { return nil, err } diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index a33fa4a43..fe84e932e 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -42,6 +42,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { return nil, err } + blacklist, err := NewPostgresBlacklistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewPostgresJoinedHostsTable(d.db) if err != nil { return nil, err @@ -58,10 +62,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } - blacklist, err := NewPostgresBlacklistTable(d.db) - if err != nil { - return nil, err - } inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) if err != nil { return nil, err diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 4fabff7d4..7d6e8a684 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -117,8 +117,8 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) { - servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) +func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { + servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err } diff --git a/federationapi/storage/sqlite3/joined_hosts_table.go b/federationapi/storage/sqlite3/joined_hosts_table.go index e0e0f2873..83112c150 100644 --- a/federationapi/storage/sqlite3/joined_hosts_table.go +++ b/federationapi/storage/sqlite3/joined_hosts_table.go @@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" + const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" +const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id IN ($1) AND NOT EXISTS (" + + " SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" + + ");" + type joinedHostsStatements struct { db *sql.DB insertJoinedHostsStmt *sql.Stmt @@ -74,6 +79,7 @@ type joinedHostsStatements struct { selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic + // selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt - prepared at runtime due to variadic } func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -168,14 +174,17 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts( } func (s *joinedHostsStatements) SelectJoinedHostsForRooms( - ctx context.Context, roomIDs []string, + ctx context.Context, roomIDs []string, excludingBlacklisted bool, ) ([]gomatrixserverlib.ServerName, error) { iRoomIDs := make([]interface{}, len(roomIDs)) for i := range roomIDs { iRoomIDs[i] = roomIDs[i] } - - sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) + query := selectJoinedHostsForRoomsSQL + if excludingBlacklisted { + query = selectJoinedHostsForRoomsExcludingBlacklistedSQL + } + sql := strings.Replace(query, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1) rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...) if err != nil { return nil, err diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index e86ac817b..d13b5defc 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -41,6 +41,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { return nil, err } + blacklist, err := NewSQLiteBlacklistTable(d.db) + if err != nil { + return nil, err + } joinedHosts, err := NewSQLiteJoinedHostsTable(d.db) if err != nil { return nil, err @@ -57,10 +61,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } - blacklist, err := NewSQLiteBlacklistTable(d.db) - if err != nil { - return nil, err - } outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) if err != nil { return nil, err diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 3c116a1d0..9f4e86a6e 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -58,7 +58,7 @@ type FederationJoinedHosts interface { SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error) SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) - SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) + SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error) } type FederationBlacklist interface { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 72c285f8e..10b8ee27f 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -165,8 +165,9 @@ func (r *Inputer) processRoomEvent( if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ - RoomID: event.RoomID(), - ExcludeSelf: true, + RoomID: event.RoomID(), + ExcludeSelf: true, + ExcludeBlacklisted: true, } if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) From deddf686b9e9be236ca4e3cc265c7cfba38ed4ee Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 09:16:07 +0000 Subject: [PATCH 22/42] Tweak `/key/v2/server` --- federationapi/routing/keys.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 4fd3720f1..693f5b9fe 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -146,10 +146,10 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam var keys gomatrixserverlib.ServerKeys var virtualHost *config.VirtualHost for _, v := range cfg.Matrix.VirtualHosts { - if v.ServerName != serverName { - continue + if v.ServerName == serverName { + virtualHost = v + break } - virtualHost = v } if virtualHost == nil { @@ -188,14 +188,15 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam return nil, err } - keys.Raw, err = gomatrixserverlib.SignJSON( - string(serverName), cfg.Matrix.KeyID, cfg.Matrix.PrivateKey, toSign, - ) + identity, err := cfg.Matrix.SigningIdentityFor(serverName) if err != nil { return nil, err } - return &keys, nil + keys.Raw, err = gomatrixserverlib.SignJSON( + string(identity.ServerName), identity.KeyID, identity.PrivateKey, toSign, + ) + return &keys, err } func NotaryKeys( From d558da1c875704b101dbb26d56e08cd0f6369d35 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 09:34:09 +0000 Subject: [PATCH 23/42] Virtual host server name workaround --- federationapi/routing/keys.go | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 693f5b9fe..8194c9905 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -16,6 +16,7 @@ package routing import ( "encoding/json" + "net" "net/http" "time" @@ -190,7 +191,16 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam identity, err := cfg.Matrix.SigningIdentityFor(serverName) if err != nil { - return nil, err + // TODO: This is a bit of a hack because the Host header can contain a port + // number if it's specified in the well-known file. Try getting a signing + // identity without it to see if that helps. + var h string + if h, _, err = net.SplitHostPort(string(serverName)); err == nil { + identity, err = cfg.Matrix.SigningIdentityFor(gomatrixserverlib.ServerName(h)) + } + if err != nil { + return nil, err + } } keys.Raw, err = gomatrixserverlib.SignJSON( From a2f72dd966fb53b265ed0df0d792027af704a08b Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 09:39:19 +0000 Subject: [PATCH 24/42] Fix slice out of bounds in federation API --- federationapi/storage/shared/storage.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 7d6e8a684..f86ae2e0d 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -125,7 +125,9 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, if excludeSelf { for i, server := range servers { if d.IsLocalServerName(server) { - servers = append(servers[:i], servers[i+1:]...) + copy(servers[:i], servers[i+1:]) + servers = servers[:len(servers)-1] + break } } } From 1e714bc3b641c09d958e78060ff1a7b90c79b500 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 10:05:59 +0000 Subject: [PATCH 25/42] Update to NATS Server 2.9.6 and nats.go 1.20.0 --- go.mod | 4 ++-- go.sum | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/go.mod b/go.mod index 1889bf891..97dcdb293 100644 --- a/go.mod +++ b/go.mod @@ -26,8 +26,8 @@ require ( github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 - github.com/nats-io/nats-server/v2 v2.9.4 - github.com/nats-io/nats.go v1.19.0 + github.com/nats-io/nats-server/v2 v2.9.6 + github.com/nats-io/nats.go v1.20.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79 diff --git a/go.sum b/go.sum index 321f00d4a..25331b2f6 100644 --- a/go.sum +++ b/go.sum @@ -385,10 +385,10 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.4 h1:GvRgv1936J/zYUwMg/cqtYaJ6L+bgeIOIvPslbesdow= -github.com/nats-io/nats-server/v2 v2.9.4/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= -github.com/nats-io/nats.go v1.19.0 h1:H6j8aBnTQFoVrTGB6Xjd903UMdE7jz6DS4YkmAqgZ9Q= -github.com/nats-io/nats.go v1.19.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= +github.com/nats-io/nats-server/v2 v2.9.6 h1:RTtK+rv/4CcliOuqGsy58g7MuWkBaWmF5TUNwuUo9Uw= +github.com/nats-io/nats-server/v2 v2.9.6/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= +github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM= +github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= github.com/nats-io/nkeys v0.3.0/go.mod h1:gvUNGjVcM2IPr5rCsRsC6Wb3Hr2CQAm08dsxtV6A5y4= github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= From a916b041b16e3369c4784b648343c241f4fa6494 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 10:28:22 +0000 Subject: [PATCH 26/42] Detect consumer being deleted in `JetStreamConsumer` --- setup/jetstream/helpers.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index 1ec860b04..c1ce9583f 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -2,6 +2,7 @@ package jetstream import ( "context" + "errors" "fmt" "github.com/getsentry/sentry-go" @@ -72,6 +73,9 @@ func JetStreamConsumer( // just timed out and we should try again. continue } + } else if errors.Is(err, nats.ErrConsumerDeleted) { + // The consumer was deleted so stop. + return } else { // Something else went wrong, so we'll panic. sentry.CaptureException(err) From 163dabc498a2b45e9b3f8fee6e24114247192bfa Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 16 Nov 2022 15:10:33 +0000 Subject: [PATCH 27/42] Fix bug in a2f72dd9 --- federationapi/storage/shared/storage.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index f86ae2e0d..1e1ea9e17 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -125,7 +125,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, if excludeSelf { for i, server := range servers { if d.IsLocalServerName(server) { - copy(servers[:i], servers[i+1:]) + copy(servers[i:], servers[i+1:]) servers = servers[:len(servers)-1] break } From df76a172344facfa2d03a910fc4d5b1a7b02dd20 Mon Sep 17 00:00:00 2001 From: devonh Date: Wed, 16 Nov 2022 22:02:25 +0000 Subject: [PATCH 28/42] Add test code coverage reporting (#2871) --- .github/workflows/schedules.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index c07917248..e304c2a59 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -53,12 +53,14 @@ jobs: key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} restore-keys: | ${{ runner.os }}-go${{ matrix.go }}-test-race- - - run: go test -race ./... + - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt env: POSTGRES_HOST: localhost POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres POSTGRES_DB: dendrite + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 # Dummy step to gate other tests on without repeating the whole list initial-tests-done: From 607819f42507d6a3b18ef7c44f98ed8f862a7f78 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Nov 2022 09:26:56 +0000 Subject: [PATCH 29/42] Fix `/key/v2/server`, add HTTP `Host` matching --- federationapi/routing/keys.go | 31 ++++++++++++++----------------- setup/config/config_global.go | 5 +++++ 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index 8194c9905..b2ef1dba4 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -16,7 +16,6 @@ package routing import ( "encoding/json" - "net" "net/http" "time" @@ -146,14 +145,26 @@ func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys var virtualHost *config.VirtualHost +loop: for _, v := range cfg.Matrix.VirtualHosts { if v.ServerName == serverName { virtualHost = v - break + break loop + } + for _, httpHost := range v.MatchHTTPHosts { + if httpHost == serverName { + virtualHost = v + break loop + } } } - if virtualHost == nil { + identity, err := cfg.Matrix.SigningIdentityFor(serverName) + if err != nil { + identity, _ = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName) + } + + if identity.ServerName == serverName { publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) keys.ServerName = cfg.Matrix.ServerName keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) @@ -189,20 +200,6 @@ func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam return nil, err } - identity, err := cfg.Matrix.SigningIdentityFor(serverName) - if err != nil { - // TODO: This is a bit of a hack because the Host header can contain a port - // number if it's specified in the well-known file. Try getting a signing - // identity without it to see if that helps. - var h string - if h, _, err = net.SplitHostPort(string(serverName)); err == nil { - identity, err = cfg.Matrix.SigningIdentityFor(gomatrixserverlib.ServerName(h)) - } - if err != nil { - return nil, err - } - } - keys.Raw, err = gomatrixserverlib.SignJSON( string(identity.ServerName), identity.KeyID, identity.PrivateKey, toSign, ) diff --git a/setup/config/config_global.go b/setup/config/config_global.go index f2fdd021e..722230d9a 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -195,6 +195,11 @@ type VirtualHost struct { // Defaults to 24 hours. KeyValidityPeriod time.Duration `yaml:"key_validity_period"` + // Match these HTTP Host headers on the `/key/v2/server` endpoint, this needs + // to match all delegated names, likely including the port number too if + // the well-known delegation includes that also. + MatchHTTPHosts []gomatrixserverlib.ServerName `yaml:"match_http_hosts"` + // Is registration enabled on this virtual host? AllowRegistration bool `json:"allow_registration"` } From 16325203af267685e2d664524a3e2976ca1a4bbf Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 17 Nov 2022 09:32:19 +0000 Subject: [PATCH 30/42] Try that again --- federationapi/routing/keys.go | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index b2ef1dba4..ee25ffbb2 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -159,12 +159,12 @@ loop: } } - identity, err := cfg.Matrix.SigningIdentityFor(serverName) - if err != nil { - identity, _ = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName) - } - - if identity.ServerName == serverName { + var identity *gomatrixserverlib.SigningIdentity + var err error + if virtualHost == nil { + if identity, err = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName); err != nil { + return nil, err + } publicKey := cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey) keys.ServerName = cfg.Matrix.ServerName keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(cfg.Matrix.KeyValidityPeriod)) @@ -183,6 +183,9 @@ loop: } } } else { + if identity, err = cfg.Matrix.SigningIdentityFor(virtualHost.ServerName); err != nil { + return nil, err + } publicKey := virtualHost.PrivateKey.Public().(ed25519.PublicKey) keys.ServerName = virtualHost.ServerName keys.ValidUntilTS = gomatrixserverlib.AsTimestamp(time.Now().Add(virtualHost.KeyValidityPeriod)) From ffd8e21ce52bbf542e0e2ed74f032af6a163c56c Mon Sep 17 00:00:00 2001 From: devonh Date: Thu, 17 Nov 2022 15:30:23 +0000 Subject: [PATCH 31/42] Fix nightly code coverage (#2881) --- .github/workflows/schedules.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index e304c2a59..ff4d47187 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -45,6 +45,11 @@ jobs: uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} + - name: Set up gotestfmt + uses: gotesttools/gotestfmt-action@v2 + with: + # Optional: pass GITHUB_TOKEN to avoid rate limiting. + token: ${{ secrets.GITHUB_TOKEN }} - uses: actions/cache@v3 with: path: | From a8e7ffc7ab147ebced766da8e0e1ebb1d75f846a Mon Sep 17 00:00:00 2001 From: devonh Date: Fri, 18 Nov 2022 00:29:23 +0000 Subject: [PATCH 32/42] Add p2p wakeup broadcast handling to pinecone demos (#2841) Adds wakeup broadcast handling to the pinecone demos. This will reset their blacklist status and interrupt any ongoing federation queue backoffs currently in progress for this peer. The end result is that any queued events will quickly be sent to the peer if they had disconnected while attempting to send events to them. --- build/gobind-pinecone/monolith.go | 35 +++++++++++++++++++++++++ cmd/dendrite-demo-pinecone/main.go | 35 +++++++++++++++++++++++++ federationapi/api/api.go | 13 +++++++++ federationapi/internal/perform.go | 16 ++++++++++- federationapi/inthttp/client.go | 13 +++++++++ federationapi/inthttp/server.go | 5 ++++ federationapi/queue/destinationqueue.go | 31 ++++++++++++++++++---- federationapi/queue/queue.go | 16 ++++++++--- go.mod | 4 ++- go.sum | 16 +++++++++-- 10 files changed, 172 insertions(+), 12 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index adb4e40a6..9100ebf0f 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -40,6 +40,7 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" @@ -58,6 +59,7 @@ import ( pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" + pineconeEvents "github.com/matrix-org/pinecone/router/events" pineconeSessions "github.com/matrix-org/pinecone/sessions" "github.com/matrix-org/pinecone/types" @@ -295,7 +297,12 @@ func (m *DendriteMonolith) Start() { m.logger.SetOutput(BindLogger{}) logrus.SetOutput(BindLogger{}) + pineconeEventChannel := make(chan pineconeEvents.Event) m.PineconeRouter = pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) + m.PineconeRouter.EnableHopLimiting() + m.PineconeRouter.EnableWakeupBroadcasts() + m.PineconeRouter.Subscribe(pineconeEventChannel) + m.PineconeQUIC = pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), m.PineconeRouter, []string{"matrix"}) m.PineconeMulticast = pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), m.PineconeRouter) m.PineconeManager = pineconeConnections.NewConnectionManager(m.PineconeRouter, nil) @@ -423,6 +430,34 @@ func (m *DendriteMonolith) Start() { m.logger.Fatal(err) } }() + + go func(ch <-chan pineconeEvents.Event) { + eLog := logrus.WithField("pinecone", "events") + + for event := range ch { + switch e := event.(type) { + case pineconeEvents.PeerAdded: + case pineconeEvents.PeerRemoved: + case pineconeEvents.TreeParentUpdate: + case pineconeEvents.SnakeDescUpdate: + case pineconeEvents.TreeRootAnnUpdate: + case pineconeEvents.SnakeEntryAdded: + case pineconeEvents.SnakeEntryRemoved: + case pineconeEvents.BroadcastReceived: + eLog.Info("Broadcast received from: ", e.PeerID) + + req := &api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + } + res := &api.PerformWakeupServersResponse{} + if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + } + case pineconeEvents.BandwidthReport: + default: + } + } + }(pineconeEventChannel) } func (m *DendriteMonolith) Stop() { diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 6c719a1ee..421b17d56 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -37,6 +37,7 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-pinecone/users" "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" @@ -51,6 +52,7 @@ import ( pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" pineconeRouter "github.com/matrix-org/pinecone/router" + pineconeEvents "github.com/matrix-org/pinecone/router/events" pineconeSessions "github.com/matrix-org/pinecone/sessions" "github.com/sirupsen/logrus" @@ -155,7 +157,12 @@ func main() { base := base.NewBaseDendrite(cfg, "Monolith") defer base.Close() // nolint: errcheck + pineconeEventChannel := make(chan pineconeEvents.Event) pRouter := pineconeRouter.NewRouter(logrus.WithField("pinecone", "router"), sk) + pRouter.EnableHopLimiting() + pRouter.EnableWakeupBroadcasts() + pRouter.Subscribe(pineconeEventChannel) + pQUIC := pineconeSessions.NewSessions(logrus.WithField("pinecone", "sessions"), pRouter, []string{"matrix"}) pMulticast := pineconeMulticast.NewMulticast(logrus.WithField("pinecone", "multicast"), pRouter) pManager := pineconeConnections.NewConnectionManager(pRouter, nil) @@ -293,5 +300,33 @@ func main() { logrus.Fatal(http.ListenAndServe(httpBindAddr, httpRouter)) }() + go func(ch <-chan pineconeEvents.Event) { + eLog := logrus.WithField("pinecone", "events") + + for event := range ch { + switch e := event.(type) { + case pineconeEvents.PeerAdded: + case pineconeEvents.PeerRemoved: + case pineconeEvents.TreeParentUpdate: + case pineconeEvents.SnakeDescUpdate: + case pineconeEvents.TreeRootAnnUpdate: + case pineconeEvents.SnakeEntryAdded: + case pineconeEvents.SnakeEntryRemoved: + case pineconeEvents.BroadcastReceived: + eLog.Info("Broadcast received from: ", e.PeerID) + + req := &api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, + } + res := &api.PerformWakeupServersResponse{} + if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + } + case pineconeEvents.BandwidthReport: + default: + } + } + }(pineconeEventChannel) + base.WaitForShutdown() } diff --git a/federationapi/api/api.go b/federationapi/api/api.go index f4be53b9a..50d0339e4 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -30,6 +30,12 @@ type FederationInternalAPI interface { request *PerformBroadcastEDURequest, response *PerformBroadcastEDUResponse, ) error + + PerformWakeupServers( + ctx context.Context, + request *PerformWakeupServersRequest, + response *PerformWakeupServersResponse, + ) error } type ClientFederationAPI interface { @@ -214,6 +220,13 @@ type PerformBroadcastEDURequest struct { type PerformBroadcastEDUResponse struct { } +type PerformWakeupServersRequest struct { + ServerNames []gomatrixserverlib.ServerName `json:"server_names"` +} + +type PerformWakeupServersResponse struct { +} + type InputPublicKeysRequest struct { Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"` } diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index 603e2a9c9..d86d07e03 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -670,9 +670,23 @@ func (r *FederationInternalAPI) PerformBroadcastEDU( return nil } +// PerformWakeupServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) PerformWakeupServers( + ctx context.Context, + request *api.PerformWakeupServersRequest, + response *api.PerformWakeupServersResponse, +) (err error) { + r.MarkServersAlive(request.ServerNames) + return nil +} + func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { for _, srv := range destinations { - _ = r.db.RemoveServerFromBlacklist(srv) + // Check the statistics cache for the blacklist status to prevent hitting + // the database unnecessarily. + if r.queues.IsServerBlacklisted(srv) { + _ = r.db.RemoveServerFromBlacklist(srv) + } r.queues.RetryServer(srv) } } diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 6c37a1f5a..6eefdc7cd 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -23,6 +23,7 @@ const ( FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest" FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" + FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" @@ -150,6 +151,18 @@ func (h *httpFederationInternalAPI) PerformBroadcastEDU( ) } +// Handle an instruction to remove the respective servers from being blacklisted. +func (h *httpFederationInternalAPI) PerformWakeupServers( + ctx context.Context, + request *api.PerformWakeupServersRequest, + response *api.PerformWakeupServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformWakeupServers", h.federationAPIURL+FederationAPIPerformWakeupServers, + h.httpClient, ctx, request, response, + ) +} + type getUserDevices struct { S gomatrixserverlib.ServerName Origin gomatrixserverlib.ServerName diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 7b3edb2a7..21a070392 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -43,6 +43,11 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), ) + internalAPIMux.Handle( + FederationAPIPerformWakeupServers, + httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers), + ) + internalAPIMux.Handle( FederationAPIPerformJoinRequestPath, httputil.MakeInternalRPCAPI( diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index 63fc59c89..a4a87fe99 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -141,23 +141,44 @@ func (oq *destinationQueue) handleBackoffNotifier() { } } +// wakeQueueIfEventsPending calls wakeQueueAndNotify only if there are +// pending events or if forceWakeup is true. This prevents starting the +// queue unnecessarily. +func (oq *destinationQueue) wakeQueueIfEventsPending(forceWakeup bool) { + eventsPending := func() bool { + oq.pendingMutex.Lock() + defer oq.pendingMutex.Unlock() + return len(oq.pendingPDUs) > 0 || len(oq.pendingEDUs) > 0 + } + + // NOTE : Only wakeup and notify the queue if there are pending events + // or if forceWakeup is true. Otherwise there is no reason to start the + // queue goroutine and waste resources. + if forceWakeup || eventsPending() { + logrus.Info("Starting queue due to pending events or forceWakeup") + oq.wakeQueueAndNotify() + } +} + // wakeQueueAndNotify ensures the destination queue is running and notifies it // that there is pending work. func (oq *destinationQueue) wakeQueueAndNotify() { - // Wake up the queue if it's asleep. - oq.wakeQueueIfNeeded() + // NOTE : Send notification before waking queue to prevent a race + // where the queue was running and stops due to a timeout in between + // checking it and sending the notification. // Notify the queue that there are events ready to send. select { case oq.notify <- struct{}{}: default: } + + // Wake up the queue if it's asleep. + oq.wakeQueueIfNeeded() } // wakeQueueIfNeeded will wake up the destination queue if it is -// not already running. If it is running but it is backing off -// then we will interrupt the backoff, causing any federation -// requests to retry. +// not already running. func (oq *destinationQueue) wakeQueueIfNeeded() { // Clear the backingOff flag and update the backoff metrics if it was set. if oq.backingOff.CompareAndSwap(true, false) { diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 31124e0c1..75b1b36be 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -374,14 +374,24 @@ func (oqs *OutgoingQueues) SendEDU( return nil } +// IsServerBlacklisted returns whether or not the provided server is currently +// blacklisted. +func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool { + return oqs.statistics.ForServer(srv).Blacklisted() +} + // RetryServer attempts to resend events to the given server if we had given up. func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { if oqs.disabled { return } - oqs.statistics.ForServer(srv).RemoveBlacklist() + + serverStatistics := oqs.statistics.ForServer(srv) + forceWakeup := serverStatistics.Blacklisted() + serverStatistics.RemoveBlacklist() + serverStatistics.ClearBackoff() + if queue := oqs.getQueue(srv); queue != nil { - queue.statistics.ClearBackoff() - queue.wakeQueueIfNeeded() + queue.wakeQueueIfEventsPending(forceWakeup) } } diff --git a/go.mod b/go.mod index 97dcdb293..adf319127 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 - github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 + github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 github.com/nats-io/nats-server/v2 v2.9.6 @@ -109,6 +109,8 @@ require ( github.com/nats-io/jwt/v2 v2.3.0 // indirect github.com/nats-io/nkeys v0.3.0 // indirect github.com/nats-io/nuid v1.0.1 // indirect + github.com/nxadm/tail v1.4.8 // indirect + github.com/onsi/ginkgo v1.16.5 // indirect github.com/onsi/ginkgo/v2 v2.3.0 // indirect github.com/onsi/gomega v1.22.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect diff --git a/go.sum b/go.sum index 25331b2f6..73741136d 100644 --- a/go.sum +++ b/go.sum @@ -335,8 +335,12 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= +github.com/lucas-clemente/quic-go v0.29.2 h1:O8Mt0O6LpvEW+wfC40vZdcw0DngwYzoxq5xULZNzSI8= +github.com/lucas-clemente/quic-go v0.29.2/go.mod h1:g6/h9YMmLuU54tL1gW25uIi3VlBp3uv+sBihplIuskE= github.com/lucas-clemente/quic-go v0.30.0 h1:nwLW0h8ahVQ5EPTIM7uhl/stHqQDea15oRlYKZmw2O0= github.com/lucas-clemente/quic-go v0.30.0/go.mod h1:ssOrRsOmdxa768Wr78vnh2B8JozgLsMzG/g+0qEC7uk= +github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= @@ -350,8 +354,10 @@ github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 h1:VapUpY3oSbEGhfSpnnTKh7bz6AA5R4tVB5lwdlcA6jE= github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37 h1:CQWFrgH9TJOU2f2qCDhGwaSdAnmgSu3/f+2xcf/Fse4= -github.com/matrix-org/pinecone v0.0.0-20221103125849-37f2e9b9ba37/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= +github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6 h1:nAT5w41Q9uWTSnpKW55/hBwP91j2IFYPDRs0jJ8TyFI= +github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= +github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d h1:RRJg8TP3bYYtVyFIMv/ebJR+ZTWv7Xtci5zk1D3GD10= +github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= @@ -402,6 +408,12 @@ github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79/go.mod h1:E26fwEtRNigB github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= +github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= +github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= +github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.3.0 h1:kUMoxMoQG3ogk/QWyKh3zibV7BKZ+xBpWil1cTylVqc= github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= From 8299da590542a982437ad9dd30115d23c3d9d075 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 18 Nov 2022 13:24:02 +0000 Subject: [PATCH 33/42] Fix registration for virtual hosting --- clientapi/auth/login_test.go | 9 +++- clientapi/auth/user_interactive_test.go | 4 +- clientapi/routing/register.go | 52 ++++++++++++++---- clientapi/userutil/userutil_test.go | 16 ++++-- federationapi/routing/keys.go | 17 +----- go.mod | 4 +- go.sum | 16 +----- setup/config/config.go | 2 +- setup/config/config_global.go | 72 ++++++++++++++----------- userapi/userapi_test.go | 4 +- 10 files changed, 113 insertions(+), 83 deletions(-) diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index 5085f0170..b79c573aa 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) @@ -66,7 +67,9 @@ func TestLoginFromJSONReader(t *testing.T) { var userAPI fakeUserInternalAPI cfg := &config.ClientAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } login, cleanup, err := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) @@ -144,7 +147,9 @@ func TestBadLoginFromJSONReader(t *testing.T) { var userAPI fakeUserInternalAPI cfg := &config.ClientAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } _, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &userAPI, &userAPI, cfg) diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 001b1a6d4..5d97b31ce 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -47,7 +47,9 @@ func (d *fakeAccountDatabase) QueryAccountByPassword(ctx context.Context, req *a func setup() *UserInteractive { cfg := &config.ClientAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } return NewUserInteractive(&fakeAccountDatabase{}, cfg) diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index a92513b8b..801000f61 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -551,6 +551,12 @@ func Register( } var r registerRequest + host := gomatrixserverlib.ServerName(req.Host) + if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil { + r.ServerName = v.ServerName + } else { + r.ServerName = cfg.Matrix.ServerName + } sessionID := gjson.GetBytes(reqBody, "auth.session").String() if sessionID == "" { // Generate a new, random session ID @@ -560,6 +566,7 @@ func Register( // Some of these might end up being overwritten if the // values are specified again in the request body. r.Username = data.Username + r.ServerName = data.ServerName r.Password = data.Password r.DeviceID = data.DeviceID r.InitialDisplayName = data.InitialDisplayName @@ -575,7 +582,6 @@ func Register( if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } - r.ServerName = cfg.Matrix.ServerName if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { r.Username, r.ServerName = l, d } @@ -650,16 +656,25 @@ func handleGuestRegistration( cfg *config.ClientAPI, userAPI userapi.ClientUserAPI, ) util.JSONResponse { - if cfg.RegistrationDisabled || cfg.GuestsDisabled { + registrationEnabled := !cfg.RegistrationDisabled + guestsEnabled := !cfg.GuestsDisabled + if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil { + registrationEnabled, guestsEnabled = v.RegistrationAllowed() + } + + if !registrationEnabled || !guestsEnabled { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Guest registration is disabled"), + JSON: jsonerror.Forbidden( + fmt.Sprintf("Guest registration is disabled on %q", r.ServerName), + ), } } var res userapi.PerformAccountCreationResponse err := userAPI.PerformAccountCreation(req.Context(), &userapi.PerformAccountCreationRequest{ AccountType: userapi.AccountTypeGuest, + ServerName: r.ServerName, }, &res) if err != nil { return util.JSONResponse{ @@ -736,10 +751,16 @@ func handleRegistrationFlow( ) } - if cfg.RegistrationDisabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { + registrationEnabled := !cfg.RegistrationDisabled + if v := cfg.Matrix.VirtualHost(r.ServerName); v != nil { + registrationEnabled, _ = v.RegistrationAllowed() + } + if !registrationEnabled && r.Auth.Type != authtypes.LoginTypeSharedSecret { return util.JSONResponse{ Code: http.StatusForbidden, - JSON: jsonerror.Forbidden("Registration is disabled"), + JSON: jsonerror.Forbidden( + fmt.Sprintf("Registration is disabled on %q", r.ServerName), + ), } } @@ -827,8 +848,9 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, + req.Context(), userAPI, r.Username, r.ServerName, "", appserviceID, req.RemoteAddr, + req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, + userapi.AccountTypeAppService, ) } @@ -846,8 +868,9 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, - r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, + req.Context(), userAPI, r.Username, r.ServerName, r.Password, "", req.RemoteAddr, + req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, + userapi.AccountTypeUser, ) } sessions.addParams(sessionID, r) @@ -869,7 +892,8 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.ClientUserAPI, - username, password, appserviceID, ipAddr, userAgent, sessionID string, + username string, serverName gomatrixserverlib.ServerName, + password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, displayName, deviceID *string, accType userapi.AccountType, @@ -891,6 +915,7 @@ func completeRegistration( err := userAPI.PerformAccountCreation(ctx, &userapi.PerformAccountCreationRequest{ AppServiceID: appserviceID, Localpart: username, + ServerName: serverName, Password: password, AccountType: accType, OnConflict: userapi.ConflictAbort, @@ -934,6 +959,7 @@ func completeRegistration( var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ Localpart: username, + ServerName: serverName, AccessToken: token, DeviceDisplayName: displayName, DeviceID: deviceID, @@ -1028,6 +1054,10 @@ func RegisterAvailable( // Squash username to all lowercase letters username = strings.ToLower(username) domain := cfg.Matrix.ServerName + host := gomatrixserverlib.ServerName(req.Host) + if v := cfg.Matrix.VirtualHostForHTTPHost(host); v != nil { + domain = v.ServerName + } if u, l, err := cfg.Matrix.SplitLocalID('@', username); err == nil { username, domain = u, l } @@ -1117,5 +1147,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/userutil/userutil_test.go b/clientapi/userutil/userutil_test.go index ccd6647b2..ee6bf8a01 100644 --- a/clientapi/userutil/userutil_test.go +++ b/clientapi/userutil/userutil_test.go @@ -30,7 +30,9 @@ var ( // TestGoodUserID checks that correct localpart is returned for a valid user ID. func TestGoodUserID(t *testing.T) { cfg := &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, } lp, _, err := ParseUsernameParam(goodUserID, cfg) @@ -47,7 +49,9 @@ func TestGoodUserID(t *testing.T) { // TestWithLocalpartOnly checks that localpart is returned when usernameParam contains only localpart. func TestWithLocalpartOnly(t *testing.T) { cfg := &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, } lp, _, err := ParseUsernameParam(localpart, cfg) @@ -64,7 +68,9 @@ func TestWithLocalpartOnly(t *testing.T) { // TestIncorrectDomain checks for error when there's server name mismatch. func TestIncorrectDomain(t *testing.T) { cfg := &config.Global{ - ServerName: invalidServerName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: invalidServerName, + }, } _, _, err := ParseUsernameParam(goodUserID, cfg) @@ -77,7 +83,9 @@ func TestIncorrectDomain(t *testing.T) { // TestBadUserID checks that ParseUsernameParam fails for invalid user ID func TestBadUserID(t *testing.T) { cfg := &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, } _, _, err := ParseUsernameParam(badUserID, cfg) diff --git a/federationapi/routing/keys.go b/federationapi/routing/keys.go index ee25ffbb2..dc262cfde 100644 --- a/federationapi/routing/keys.go +++ b/federationapi/routing/keys.go @@ -144,24 +144,9 @@ func LocalKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerNam func localKeys(cfg *config.FederationAPI, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { var keys gomatrixserverlib.ServerKeys - var virtualHost *config.VirtualHost -loop: - for _, v := range cfg.Matrix.VirtualHosts { - if v.ServerName == serverName { - virtualHost = v - break loop - } - for _, httpHost := range v.MatchHTTPHosts { - if httpHost == serverName { - virtualHost = v - break loop - } - } - } - var identity *gomatrixserverlib.SigningIdentity var err error - if virtualHost == nil { + if virtualHost := cfg.Matrix.VirtualHostForHTTPHost(serverName); virtualHost == nil { if identity, err = cfg.Matrix.SigningIdentityFor(cfg.Matrix.ServerName); err != nil { return nil, err } diff --git a/go.mod b/go.mod index adf319127..7961a4b08 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 @@ -109,8 +109,6 @@ require ( github.com/nats-io/jwt/v2 v2.3.0 // indirect github.com/nats-io/nkeys v0.3.0 // indirect github.com/nats-io/nuid v1.0.1 // indirect - github.com/nxadm/tail v1.4.8 // indirect - github.com/onsi/ginkgo v1.16.5 // indirect github.com/onsi/ginkgo/v2 v2.3.0 // indirect github.com/onsi/gomega v1.22.1 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect diff --git a/go.sum b/go.sum index 73741136d..c5184ff06 100644 --- a/go.sum +++ b/go.sum @@ -335,12 +335,8 @@ github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgx github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/lucas-clemente/quic-go v0.28.1/go.mod h1:oGz5DKK41cJt5+773+BSO9BXDsREY4HLf7+0odGAPO0= -github.com/lucas-clemente/quic-go v0.29.2 h1:O8Mt0O6LpvEW+wfC40vZdcw0DngwYzoxq5xULZNzSI8= -github.com/lucas-clemente/quic-go v0.29.2/go.mod h1:g6/h9YMmLuU54tL1gW25uIi3VlBp3uv+sBihplIuskE= github.com/lucas-clemente/quic-go v0.30.0 h1:nwLW0h8ahVQ5EPTIM7uhl/stHqQDea15oRlYKZmw2O0= github.com/lucas-clemente/quic-go v0.30.0/go.mod h1:ssOrRsOmdxa768Wr78vnh2B8JozgLsMzG/g+0qEC7uk= -github.com/lunixbochs/vtclean v1.0.0/go.mod h1:pHhQNgMf3btfWnGBVipUOjRYhoOsdGqdm/+2c2E2WMI= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/marten-seemann/qtls-go1-18 v0.1.3 h1:R4H2Ks8P6pAtUagjFty2p7BVHn3XiwDAl7TTQf5h7TI= github.com/marten-seemann/qtls-go1-18 v0.1.3/go.mod h1:mJttiymBAByA49mhlNZZGrH5u1uXYZJ+RW28Py7f4m4= @@ -352,10 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39 h1:VapUpY3oSbEGhfSpnnTKh7bz6AA5R4tVB5lwdlcA6jE= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221115151040-900369eadf39/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6 h1:nAT5w41Q9uWTSnpKW55/hBwP91j2IFYPDRs0jJ8TyFI= -github.com/matrix-org/pinecone v0.0.0-20221026160848-639feeff74d6/go.mod h1:K0N1ixHQxXoCyqolDqVxPM3ArrDtcMs8yegOx2Lfv9k= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 h1:S2TNN7C00CZlE1Af31LzxkOsAEkFt0RYZ7/3VdR1D5U= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d h1:RRJg8TP3bYYtVyFIMv/ebJR+ZTWv7Xtci5zk1D3GD10= github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= @@ -408,12 +402,6 @@ github.com/ngrok/sqlmw v0.0.0-20220520173518-97c9c04efc79/go.mod h1:E26fwEtRNigB github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= -github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= -github.com/onsi/ginkgo v1.14.0/go.mod h1:iSB4RoI2tjJc9BBv4NKIKWKya62Rps+oPG/Lv9klQyY= -github.com/onsi/ginkgo v1.16.2/go.mod h1:CObGmKUOKaSC0RjmoAK7tKyn4Azo5P2IWuoMnvwxz1E= -github.com/onsi/ginkgo v1.16.4/go.mod h1:dX+/inL/fNMqNlz0e9LfyB9TswhZpCVdJM/Z6Vvnwo0= -github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= -github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU= github.com/onsi/ginkgo/v2 v2.3.0 h1:kUMoxMoQG3ogk/QWyKh3zibV7BKZ+xBpWil1cTylVqc= github.com/onsi/ginkgo/v2 v2.3.0/go.mod h1:Eew0uilEqZmIEZr8JrvYlvOM7Rr6xzTmMV8AyFNU9d0= github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY= diff --git a/setup/config/config.go b/setup/config/config.go index 918bcbe3b..7e7ed1aa1 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -235,7 +235,7 @@ func loadConfig( if v.KeyValidityPeriod == 0 { v.KeyValidityPeriod = c.Global.KeyValidityPeriod } - if v.PrivateKeyPath == "" { + if v.PrivateKeyPath == "" || v.PrivateKey == nil || v.KeyID == "" { v.KeyID = c.Global.KeyID v.PrivateKey = c.Global.PrivateKey continue diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 722230d9a..801c68450 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -12,8 +12,9 @@ import ( ) type Global struct { - // The name of the server. This is usually the domain name, e.g 'matrix.org', 'localhost'. - ServerName gomatrixserverlib.ServerName `yaml:"server_name"` + // Signing identity contains the server name, private key and key ID of + // the deployment. + gomatrixserverlib.SigningIdentity `yaml:",inline"` // The secondary server names, used for virtual hosting. VirtualHosts []*VirtualHost `yaml:"virtual_hosts"` @@ -21,13 +22,6 @@ type Global struct { // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` - // The private key which will be used to sign requests and events. - PrivateKey ed25519.PrivateKey `yaml:"-"` - - // An arbitrary string used to uniquely identify the PrivateKey. Must start with the - // prefix "ed25519:". - KeyID gomatrixserverlib.KeyID `yaml:"-"` - // Information about old private keys that used to be used to sign requests and // events on this domain. They will not be used but will be advertised to other // servers that ask for them to help verify old events. @@ -151,6 +145,29 @@ func (c *Global) SplitLocalID(sigil byte, id string) (string, gomatrixserverlib. return u, s, nil } +func (c *Global) VirtualHost(serverName gomatrixserverlib.ServerName) *VirtualHost { + for _, v := range c.VirtualHosts { + if v.ServerName == serverName { + return v + } + } + return nil +} + +func (c *Global) VirtualHostForHTTPHost(serverName gomatrixserverlib.ServerName) *VirtualHost { + for _, v := range c.VirtualHosts { + if v.ServerName == serverName { + return v + } + for _, h := range v.MatchHTTPHosts { + if h == serverName { + return v + } + } + } + return nil +} + func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.SigningIdentity, error) { for _, id := range c.SigningIdentities() { if id.ServerName == serverName { @@ -162,32 +179,22 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { identities := make([]*gomatrixserverlib.SigningIdentity, 0, len(c.VirtualHosts)+1) - identities = append(identities, &gomatrixserverlib.SigningIdentity{ - ServerName: c.ServerName, - KeyID: c.KeyID, - PrivateKey: c.PrivateKey, - }) + identities = append(identities, &c.SigningIdentity) for _, v := range c.VirtualHosts { - identities = append(identities, v.SigningIdentity()) + identities = append(identities, &v.SigningIdentity) } return identities } type VirtualHost struct { - // The server name of the virtual host. - ServerName gomatrixserverlib.ServerName `yaml:"server_name"` - - // The key ID of the private key. If not specified, the default global key ID - // will be used instead. - KeyID gomatrixserverlib.KeyID `yaml:"key_id"` + // Signing identity contains the server name, private key and key ID of + // the virtual host. + gomatrixserverlib.SigningIdentity `yaml:",inline"` // Path to the private key. If not specified, the default global private key // will be used instead. PrivateKeyPath Path `yaml:"private_key"` - // The private key itself. - PrivateKey ed25519.PrivateKey `yaml:"-"` - // How long a remote server can cache our server key for before requesting it again. // Increasing this number will reduce the number of requests made by remote servers // for our key, but increases the period a compromised key will be considered valid @@ -201,19 +208,24 @@ type VirtualHost struct { MatchHTTPHosts []gomatrixserverlib.ServerName `yaml:"match_http_hosts"` // Is registration enabled on this virtual host? - AllowRegistration bool `json:"allow_registration"` + AllowRegistration bool `yaml:"allow_registration"` + + // Is guest registration enabled on this virtual host? + AllowGuests bool `yaml:"allow_guests"` } func (v *VirtualHost) Verify(configErrs *ConfigErrors) { checkNotEmpty(configErrs, "virtual_host.*.server_name", string(v.ServerName)) } -func (v *VirtualHost) SigningIdentity() *gomatrixserverlib.SigningIdentity { - return &gomatrixserverlib.SigningIdentity{ - ServerName: v.ServerName, - KeyID: v.KeyID, - PrivateKey: v.PrivateKey, +// RegistrationAllowed returns two bools, the first states whether registration +// is allowed for this virtual host and the second states whether guests are +// allowed for this virtual host. +func (v *VirtualHost) RegistrationAllowed() (bool, bool) { + if v == nil { + return false, false } + return v.AllowRegistration, v.AllowGuests } type OldVerifyKeys struct { diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 23acfa6f3..25fa75ee2 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -61,7 +61,9 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap cfg := &config.UserAPI{ Matrix: &config.Global{ - ServerName: serverName, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: serverName, + }, }, } From 7ad87eace3b88b3aaa88dc8a0acec9b64ffcc52c Mon Sep 17 00:00:00 2001 From: devonh Date: Fri, 18 Nov 2022 19:37:13 +0000 Subject: [PATCH 34/42] Update pinecone version (#2884) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 7961a4b08..7b3804a6e 100644 --- a/go.mod +++ b/go.mod @@ -23,7 +23,7 @@ require ( github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 - github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d + github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 github.com/nats-io/nats-server/v2 v2.9.6 diff --git a/go.sum b/go.sum index c5184ff06..14a8b0e88 100644 --- a/go.sum +++ b/go.sum @@ -350,8 +350,8 @@ github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 h1:S2TNN7C00CZlE1Af31LzxkOsAEkFt0RYZ7/3VdR1D5U= github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d h1:RRJg8TP3bYYtVyFIMv/ebJR+ZTWv7Xtci5zk1D3GD10= -github.com/matrix-org/pinecone v0.0.0-20221117214503-218c39e0cd6d/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= +github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= +github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= From 31f56ac3f43b692a56718fd1acfcb5a46e485c57 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 22 Nov 2022 21:38:27 +0000 Subject: [PATCH 35/42] Never filter out a user's own membership when using LL (#2887) --- syncapi/streams/stream_pdu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 65ca8e2a3..dd7845574 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -588,7 +588,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( isGappedIncremental := limited && incremental // We want this users membership event, keep it in the list stateKey := *event.StateKey() - if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental { + if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID { newStateEvents = append(newStateEvents, event) if !stateFilter.IncludeRedundantMembers { p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID()) From 5e4b461e0158daa76872956729b73ccdccefbf10 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 28 Nov 2022 11:26:03 +0100 Subject: [PATCH 36/42] Return empty JSON if we don't have any protocols to return (#2892) This should help with Element reporting `The homeserver may be too old to support third party networks.` --- clientapi/routing/thirdparty.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/clientapi/routing/thirdparty.go b/clientapi/routing/thirdparty.go index e757cd411..7a62da449 100644 --- a/clientapi/routing/thirdparty.go +++ b/clientapi/routing/thirdparty.go @@ -36,9 +36,15 @@ func Protocols(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, dev return jsonerror.InternalServerError() } if !resp.Exists { + if protocol != "" { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The protocol is unknown."), + } + } return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The protocol is unknown."), + Code: http.StatusOK, + JSON: struct{}{}, } } if protocol != "" { From f6f1445cfaccc7c4540775d9cb7083522d7c27d2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 09:58:22 +0000 Subject: [PATCH 37/42] Tweak event auth logging and cases (update to matrix-org/gomatrixserverlib@8835f6d) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 7b3804a6e..fd662f6e5 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 14a8b0e88..f72aa739e 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 h1:S2TNN7C00CZlE1Af31LzxkOsAEkFt0RYZ7/3VdR1D5U= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= From 1ed5fb5e98220c2a96f1743490d7fef821bd9114 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 10:37:57 +0000 Subject: [PATCH 38/42] Update NATS Server to 2.9.8 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index fd662f6e5..d3eb4890a 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 - github.com/nats-io/nats-server/v2 v2.9.6 + github.com/nats-io/nats-server/v2 v2.9.8 github.com/nats-io/nats.go v1.20.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index f72aa739e..ad9372c84 100644 --- a/go.sum +++ b/go.sum @@ -385,8 +385,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.6 h1:RTtK+rv/4CcliOuqGsy58g7MuWkBaWmF5TUNwuUo9Uw= -github.com/nats-io/nats-server/v2 v2.9.6/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= +github.com/nats-io/nats-server/v2 v2.9.8 h1:jgxZsv+A3Reb3MgwxaINcNq/za8xZInKhDg9Q0cGN1o= +github.com/nats-io/nats-server/v2 v2.9.8/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM= github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= From 1990c154e920a350654d0ae6e02071950e595a01 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 11:11:08 +0000 Subject: [PATCH 39/42] Update configuration --- setup/config/config_global.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 801c68450..511951fe6 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -17,7 +17,7 @@ type Global struct { gomatrixserverlib.SigningIdentity `yaml:",inline"` // The secondary server names, used for virtual hosting. - VirtualHosts []*VirtualHost `yaml:"virtual_hosts"` + VirtualHosts []*VirtualHost `yaml:"-"` // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` From f8d1dc521d401b78163840d8ff978cb2cd2718d7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 29 Nov 2022 15:46:28 +0100 Subject: [PATCH 40/42] Fix `m.receipt`s causing notifications (#2893) Fixes https://github.com/matrix-org/dendrite/issues/2353 --- syncapi/streams/stream_receipt.go | 3 +- syncapi/types/types.go | 7 +++ syncapi/types/types_test.go | 100 ++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 2 deletions(-) diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index 977815078..16a81e833 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -87,8 +87,7 @@ func (p *ReceiptStreamProvider) IncrementalSync( } ev := gomatrixserverlib.ClientEvent{ - Type: gomatrixserverlib.MReceipt, - RoomID: roomID, + Type: gomatrixserverlib.MReceipt, } content := make(map[string]ReceiptMRead) for _, receipt := range receipts { diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 295187acc..9fbadc06c 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -480,6 +480,13 @@ func (jr JoinResponse) MarshalJSON() ([]byte, error) { if jr.Ephemeral != nil && len(jr.Ephemeral.Events) == 0 { a.Ephemeral = nil } + if jr.Ephemeral != nil { + // Remove the room_id from EDUs, as this seems to cause Element Web + // to trigger notifications - https://github.com/vector-im/element-web/issues/17263 + for i := range jr.Ephemeral.Events { + jr.Ephemeral.Events[i].RoomID = "" + } + } if jr.AccountData != nil && len(jr.AccountData.Events) == 0 { a.AccountData = nil } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 19fcfc150..74246d964 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -2,6 +2,7 @@ package types import ( "encoding/json" + "reflect" "testing" "github.com/matrix-org/gomatrixserverlib" @@ -63,3 +64,102 @@ func TestNewInviteResponse(t *testing.T) { t.Fatalf("Invite response didn't contain correct info") } } + +func TestJoinResponse_MarshalJSON(t *testing.T) { + type fields struct { + Summary *Summary + State *ClientEvents + Timeline *Timeline + Ephemeral *ClientEvents + AccountData *ClientEvents + UnreadNotifications *UnreadNotifications + } + tests := []struct { + name string + fields fields + want []byte + wantErr bool + }{ + { + name: "empty state is removed", + fields: fields{ + State: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty accountdata is removed", + fields: fields{ + AccountData: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty ephemeral is removed", + fields: fields{ + Ephemeral: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty timeline is removed", + fields: fields{ + Timeline: &Timeline{}, + }, + want: []byte("{}"), + }, + { + name: "empty summary is removed", + fields: fields{ + Summary: &Summary{}, + }, + want: []byte("{}"), + }, + { + name: "unread notifications are removed, if everything else is empty", + fields: fields{ + UnreadNotifications: &UnreadNotifications{}, + }, + want: []byte("{}"), + }, + { + name: "unread notifications are NOT removed, if state is set", + fields: fields{ + State: &ClientEvents{Events: []gomatrixserverlib.ClientEvent{{Content: []byte("{}")}}}, + UnreadNotifications: &UnreadNotifications{NotificationCount: 1}, + }, + want: []byte(`{"state":{"events":[{"content":{},"type":""}]},"unread_notifications":{"highlight_count":0,"notification_count":1}}`), + }, + { + name: "roomID is removed from EDUs", + fields: fields{ + Ephemeral: &ClientEvents{ + Events: []gomatrixserverlib.ClientEvent{ + {RoomID: "!someRandomRoomID:test", Content: []byte("{}")}, + }, + }, + }, + want: []byte(`{"ephemeral":{"events":[{"content":{},"type":""}]}}`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jr := JoinResponse{ + Summary: tt.fields.Summary, + State: tt.fields.State, + Timeline: tt.fields.Timeline, + Ephemeral: tt.fields.Ephemeral, + AccountData: tt.fields.AccountData, + UnreadNotifications: tt.fields.UnreadNotifications, + } + got, err := jr.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", string(got), string(tt.want)) + } + }) + } +} From ed497aa8b2d24b5d5ac00df773dab5087ddcf5ca Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 16:26:33 +0000 Subject: [PATCH 41/42] Version 0.10.8 --- .github/workflows/docker.yml | 48 ++++++++++++++++++------------------ CHANGES.md | 23 +++++++++++++++++ internal/version.go | 2 +- 3 files changed, 48 insertions(+), 25 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 846844173..2e17539d8 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -68,18 +68,6 @@ jobs: ${{ env.DOCKER_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} - - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@master - with: - image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} - format: "sarif" - output: "trivy-results.sarif" - - - name: Upload Trivy scan results to GitHub Security tab - uses: github/codeql-action/upload-sarif@v2 - with: - sarif_file: "trivy-results.sarif" - - name: Build release monolith image if: github.event_name == 'release' # Only for GitHub releases id: docker_build_monolith_release @@ -98,6 +86,18 @@ jobs: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ env.RELEASE_VERSION }} + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@master + with: + image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} + format: "sarif" + output: "trivy-results.sarif" + + - name: Upload Trivy scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: "trivy-results.sarif" + polylith: name: Polylith image runs-on: ubuntu-latest @@ -148,18 +148,6 @@ jobs: ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@master - with: - image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - format: "sarif" - output: "trivy-results.sarif" - - - name: Upload Trivy scan results to GitHub Security tab - uses: github/codeql-action/upload-sarif@v2 - with: - sarif_file: "trivy-results.sarif" - - name: Build release polylith image if: github.event_name == 'release' # Only for GitHub releases id: docker_build_polylith_release @@ -178,6 +166,18 @@ jobs: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@master + with: + image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} + format: "sarif" + output: "trivy-results.sarif" + + - name: Upload Trivy scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: "trivy-results.sarif" + demo-pinecone: name: Pinecone demo image runs-on: ubuntu-latest diff --git a/CHANGES.md b/CHANGES.md index cdeb1dea3..f5a82cfe2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,28 @@ # Changelog +## Dendrite 0.10.8 (2022-11-29) + +### Features + +* The built-in NATS Server has been updated to version 2.9.8 +* A number of under-the-hood changes have been merged for future virtual hosting support in Dendrite (running multiple domain names on the same Dendrite deployment) + +### Fixes + +* Event auth handling of invites has been refactored, which should fix some edge cases being handled incorrectly +* Fix a bug when returning an empty protocol list, which could cause Element to display "The homeserver may be too old to support third party networks" when opening the public room directory +* The sync API will no longer filter out the user's own membership when using lazy-loading +* Dendrite will now correctly detect JetStream consumers being deleted, stopping the consumer goroutine as needed +* A panic in the federation API where the server list could go out of bounds has been fixed +* Blacklisted servers will now be excluded when querying joined servers, which improves CPU usage and performs less unnecessary outbound requests +* A database writer will now be used to assign state key NIDs when requesting NIDs that may not exist yet +* Dendrite will now correctly move local aliases for an upgraded room when the room is upgraded remotely +* Dendrite will now correctly move account data for an upgraded room when the room is upgraded remotely +* Missing state key NIDs will now be allocated on request rather than returning an error +* Guest access is now correctly denied on a number of endpoints +* Presence information will now be correctly sent for new private chats +* A number of unspecced fields have been removed from outbound `/send` transactions + ## Dendrite 0.10.7 (2022-11-04) ### Features diff --git a/internal/version.go b/internal/version.go index 85b19046e..685237b9e 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 10 - VersionPatch = 7 + VersionPatch = 8 VersionTag = "" // example: "rc1" ) From ac514b406c2288e27565ed0a84d4e1dfd70c4848 Mon Sep 17 00:00:00 2001 From: danielaloni <1danielaloni@gmail.com> Date: Wed, 4 Jan 2023 13:19:01 +0200 Subject: [PATCH 42/42] =?UTF-8?q?=F0=9F=90=9B=20Migration=20to=20have=20th?= =?UTF-8?q?e=20correct=20composite=20primary=20key=20in=20roomserver=5Fpub?= =?UTF-8?q?lished.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../postgres/deltas/20221027084407_published_appservice.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/roomserver/storage/postgres/deltas/20221027084407_published_appservice.go b/roomserver/storage/postgres/deltas/20221027084407_published_appservice.go index 687ee9024..077234b6e 100644 --- a/roomserver/storage/postgres/deltas/20221027084407_published_appservice.go +++ b/roomserver/storage/postgres/deltas/20221027084407_published_appservice.go @@ -29,6 +29,13 @@ func UpPulishedAppservice(ctx context.Context, tx *sql.Tx) error { if err != nil { return fmt.Errorf("failed to execute upgrade: %w", err) } + _, err = tx.ExecContext(ctx, ` + ALTER TABLE roomserver_published DROP CONSTRAINT IF EXISTS roomserver_published_pkey; + ALTER TABLE roomserver_published ADD PRIMARY KEY (room_id, appservice_id, network_id); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } return nil }