From 31f56ac3f43b692a56718fd1acfcb5a46e485c57 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 22 Nov 2022 21:38:27 +0000 Subject: [PATCH 01/67] Never filter out a user's own membership when using LL (#2887) --- syncapi/streams/stream_pdu.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 65ca8e2a3..dd7845574 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -588,7 +588,7 @@ func (p *PDUStreamProvider) lazyLoadMembers( isGappedIncremental := limited && incremental // We want this users membership event, keep it in the list stateKey := *event.StateKey() - if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental { + if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID { newStateEvents = append(newStateEvents, event) if !stateFilter.IncludeRedundantMembers { p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID()) From 5e4b461e0158daa76872956729b73ccdccefbf10 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 28 Nov 2022 11:26:03 +0100 Subject: [PATCH 02/67] Return empty JSON if we don't have any protocols to return (#2892) This should help with Element reporting `The homeserver may be too old to support third party networks.` --- clientapi/routing/thirdparty.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/clientapi/routing/thirdparty.go b/clientapi/routing/thirdparty.go index e757cd411..7a62da449 100644 --- a/clientapi/routing/thirdparty.go +++ b/clientapi/routing/thirdparty.go @@ -36,9 +36,15 @@ func Protocols(req *http.Request, asAPI appserviceAPI.AppServiceInternalAPI, dev return jsonerror.InternalServerError() } if !resp.Exists { + if protocol != "" { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.NotFound("The protocol is unknown."), + } + } return util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("The protocol is unknown."), + Code: http.StatusOK, + JSON: struct{}{}, } } if protocol != "" { From f6f1445cfaccc7c4540775d9cb7083522d7c27d2 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 09:58:22 +0000 Subject: [PATCH 03/67] Tweak event auth logging and cases (update to matrix-org/gomatrixserverlib@8835f6d) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 7b3804a6e..fd662f6e5 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 + github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index 14a8b0e88..f72aa739e 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7 h1:S2TNN7C00CZlE1Af31LzxkOsAEkFt0RYZ7/3VdR1D5U= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221118122129-9b9340bf29d7/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= From 1ed5fb5e98220c2a96f1743490d7fef821bd9114 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 10:37:57 +0000 Subject: [PATCH 04/67] Update NATS Server to 2.9.8 --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index fd662f6e5..d3eb4890a 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 - github.com/nats-io/nats-server/v2 v2.9.6 + github.com/nats-io/nats-server/v2 v2.9.8 github.com/nats-io/nats.go v1.20.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index f72aa739e..ad9372c84 100644 --- a/go.sum +++ b/go.sum @@ -385,8 +385,8 @@ github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRW github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/nats-io/jwt/v2 v2.3.0 h1:z2mA1a7tIf5ShggOFlR1oBPgd6hGqcDYsISxZByUzdI= github.com/nats-io/jwt/v2 v2.3.0/go.mod h1:0tqz9Hlu6bCBFLWAASKhE5vUA4c24L9KPUUgvwumE/k= -github.com/nats-io/nats-server/v2 v2.9.6 h1:RTtK+rv/4CcliOuqGsy58g7MuWkBaWmF5TUNwuUo9Uw= -github.com/nats-io/nats-server/v2 v2.9.6/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= +github.com/nats-io/nats-server/v2 v2.9.8 h1:jgxZsv+A3Reb3MgwxaINcNq/za8xZInKhDg9Q0cGN1o= +github.com/nats-io/nats-server/v2 v2.9.8/go.mod h1:AB6hAnGZDlYfqb7CTAm66ZKMZy9DpfierY1/PbpvI2g= github.com/nats-io/nats.go v1.20.0 h1:T8JJnQfVSdh1CzGiwAOv5hEobYCBho/0EupGznYw0oM= github.com/nats-io/nats.go v1.20.0/go.mod h1:tLqubohF7t4z3du1QDPYJIQQyhb4wl6DhjxEajSI7UA= github.com/nats-io/nkeys v0.3.0 h1:cgM5tL53EvYRU+2YLXIK0G2mJtK12Ft9oeooSZMA2G8= From 1990c154e920a350654d0ae6e02071950e595a01 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 11:11:08 +0000 Subject: [PATCH 05/67] Update configuration --- setup/config/config_global.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 801c68450..511951fe6 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -17,7 +17,7 @@ type Global struct { gomatrixserverlib.SigningIdentity `yaml:",inline"` // The secondary server names, used for virtual hosting. - VirtualHosts []*VirtualHost `yaml:"virtual_hosts"` + VirtualHosts []*VirtualHost `yaml:"-"` // Path to the private key which will be used to sign requests and events. PrivateKeyPath Path `yaml:"private_key"` From f8d1dc521d401b78163840d8ff978cb2cd2718d7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 29 Nov 2022 15:46:28 +0100 Subject: [PATCH 06/67] Fix `m.receipt`s causing notifications (#2893) Fixes https://github.com/matrix-org/dendrite/issues/2353 --- syncapi/streams/stream_receipt.go | 3 +- syncapi/types/types.go | 7 +++ syncapi/types/types_test.go | 100 ++++++++++++++++++++++++++++++ 3 files changed, 108 insertions(+), 2 deletions(-) diff --git a/syncapi/streams/stream_receipt.go b/syncapi/streams/stream_receipt.go index 977815078..16a81e833 100644 --- a/syncapi/streams/stream_receipt.go +++ b/syncapi/streams/stream_receipt.go @@ -87,8 +87,7 @@ func (p *ReceiptStreamProvider) IncrementalSync( } ev := gomatrixserverlib.ClientEvent{ - Type: gomatrixserverlib.MReceipt, - RoomID: roomID, + Type: gomatrixserverlib.MReceipt, } content := make(map[string]ReceiptMRead) for _, receipt := range receipts { diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 295187acc..9fbadc06c 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -480,6 +480,13 @@ func (jr JoinResponse) MarshalJSON() ([]byte, error) { if jr.Ephemeral != nil && len(jr.Ephemeral.Events) == 0 { a.Ephemeral = nil } + if jr.Ephemeral != nil { + // Remove the room_id from EDUs, as this seems to cause Element Web + // to trigger notifications - https://github.com/vector-im/element-web/issues/17263 + for i := range jr.Ephemeral.Events { + jr.Ephemeral.Events[i].RoomID = "" + } + } if jr.AccountData != nil && len(jr.AccountData.Events) == 0 { a.AccountData = nil } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 19fcfc150..74246d964 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -2,6 +2,7 @@ package types import ( "encoding/json" + "reflect" "testing" "github.com/matrix-org/gomatrixserverlib" @@ -63,3 +64,102 @@ func TestNewInviteResponse(t *testing.T) { t.Fatalf("Invite response didn't contain correct info") } } + +func TestJoinResponse_MarshalJSON(t *testing.T) { + type fields struct { + Summary *Summary + State *ClientEvents + Timeline *Timeline + Ephemeral *ClientEvents + AccountData *ClientEvents + UnreadNotifications *UnreadNotifications + } + tests := []struct { + name string + fields fields + want []byte + wantErr bool + }{ + { + name: "empty state is removed", + fields: fields{ + State: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty accountdata is removed", + fields: fields{ + AccountData: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty ephemeral is removed", + fields: fields{ + Ephemeral: &ClientEvents{}, + }, + want: []byte("{}"), + }, + { + name: "empty timeline is removed", + fields: fields{ + Timeline: &Timeline{}, + }, + want: []byte("{}"), + }, + { + name: "empty summary is removed", + fields: fields{ + Summary: &Summary{}, + }, + want: []byte("{}"), + }, + { + name: "unread notifications are removed, if everything else is empty", + fields: fields{ + UnreadNotifications: &UnreadNotifications{}, + }, + want: []byte("{}"), + }, + { + name: "unread notifications are NOT removed, if state is set", + fields: fields{ + State: &ClientEvents{Events: []gomatrixserverlib.ClientEvent{{Content: []byte("{}")}}}, + UnreadNotifications: &UnreadNotifications{NotificationCount: 1}, + }, + want: []byte(`{"state":{"events":[{"content":{},"type":""}]},"unread_notifications":{"highlight_count":0,"notification_count":1}}`), + }, + { + name: "roomID is removed from EDUs", + fields: fields{ + Ephemeral: &ClientEvents{ + Events: []gomatrixserverlib.ClientEvent{ + {RoomID: "!someRandomRoomID:test", Content: []byte("{}")}, + }, + }, + }, + want: []byte(`{"ephemeral":{"events":[{"content":{},"type":""}]}}`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + jr := JoinResponse{ + Summary: tt.fields.Summary, + State: tt.fields.State, + Timeline: tt.fields.Timeline, + Ephemeral: tt.fields.Ephemeral, + AccountData: tt.fields.AccountData, + UnreadNotifications: tt.fields.UnreadNotifications, + } + got, err := jr.MarshalJSON() + if (err != nil) != tt.wantErr { + t.Errorf("MarshalJSON() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("MarshalJSON() got = %v, want %v", string(got), string(tt.want)) + } + }) + } +} From ed497aa8b2d24b5d5ac00df773dab5087ddcf5ca Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 29 Nov 2022 16:26:33 +0000 Subject: [PATCH 07/67] Version 0.10.8 --- .github/workflows/docker.yml | 48 ++++++++++++++++++------------------ CHANGES.md | 23 +++++++++++++++++ internal/version.go | 2 +- 3 files changed, 48 insertions(+), 25 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 846844173..2e17539d8 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -68,18 +68,6 @@ jobs: ${{ env.DOCKER_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} - - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@master - with: - image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} - format: "sarif" - output: "trivy-results.sarif" - - - name: Upload Trivy scan results to GitHub Security tab - uses: github/codeql-action/upload-sarif@v2 - with: - sarif_file: "trivy-results.sarif" - - name: Build release monolith image if: github.event_name == 'release' # Only for GitHub releases id: docker_build_monolith_release @@ -98,6 +86,18 @@ jobs: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ env.RELEASE_VERSION }} + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@master + with: + image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-monolith:${{ github.ref_name }} + format: "sarif" + output: "trivy-results.sarif" + + - name: Upload Trivy scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: "trivy-results.sarif" + polylith: name: Polylith image runs-on: ubuntu-latest @@ -148,18 +148,6 @@ jobs: ${{ env.DOCKER_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - - name: Run Trivy vulnerability scanner - uses: aquasecurity/trivy-action@master - with: - image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} - format: "sarif" - output: "trivy-results.sarif" - - - name: Upload Trivy scan results to GitHub Security tab - uses: github/codeql-action/upload-sarif@v2 - with: - sarif_file: "trivy-results.sarif" - - name: Build release polylith image if: github.event_name == 'release' # Only for GitHub releases id: docker_build_polylith_release @@ -178,6 +166,18 @@ jobs: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:latest ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ env.RELEASE_VERSION }} + - name: Run Trivy vulnerability scanner + uses: aquasecurity/trivy-action@master + with: + image-ref: ghcr.io/${{ env.GHCR_NAMESPACE }}/dendrite-polylith:${{ github.ref_name }} + format: "sarif" + output: "trivy-results.sarif" + + - name: Upload Trivy scan results to GitHub Security tab + uses: github/codeql-action/upload-sarif@v2 + with: + sarif_file: "trivy-results.sarif" + demo-pinecone: name: Pinecone demo image runs-on: ubuntu-latest diff --git a/CHANGES.md b/CHANGES.md index cdeb1dea3..f5a82cfe2 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,28 @@ # Changelog +## Dendrite 0.10.8 (2022-11-29) + +### Features + +* The built-in NATS Server has been updated to version 2.9.8 +* A number of under-the-hood changes have been merged for future virtual hosting support in Dendrite (running multiple domain names on the same Dendrite deployment) + +### Fixes + +* Event auth handling of invites has been refactored, which should fix some edge cases being handled incorrectly +* Fix a bug when returning an empty protocol list, which could cause Element to display "The homeserver may be too old to support third party networks" when opening the public room directory +* The sync API will no longer filter out the user's own membership when using lazy-loading +* Dendrite will now correctly detect JetStream consumers being deleted, stopping the consumer goroutine as needed +* A panic in the federation API where the server list could go out of bounds has been fixed +* Blacklisted servers will now be excluded when querying joined servers, which improves CPU usage and performs less unnecessary outbound requests +* A database writer will now be used to assign state key NIDs when requesting NIDs that may not exist yet +* Dendrite will now correctly move local aliases for an upgraded room when the room is upgraded remotely +* Dendrite will now correctly move account data for an upgraded room when the room is upgraded remotely +* Missing state key NIDs will now be allocated on request rather than returning an error +* Guest access is now correctly denied on a number of endpoints +* Presence information will now be correctly sent for new private chats +* A number of unspecced fields have been removed from outbound `/send` transactions + ## Dendrite 0.10.7 (2022-11-04) ### Features diff --git a/internal/version.go b/internal/version.go index 85b19046e..685237b9e 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 10 - VersionPatch = 7 + VersionPatch = 8 VersionTag = "" // example: "rc1" ) From ac5f3f025eaf157d2a2bb15adcd19a5732325608 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Wed, 30 Nov 2022 12:40:36 +0100 Subject: [PATCH 08/67] Calculate correct room member count for push rule evaluation (#2894) Fixes a bug where we would return only the local member count, which could result in wrongly calculated push rules. --- userapi/consumers/roomserver.go | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 5d8924dda..3ce5af621 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -385,7 +385,6 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s req := &rsapi.QueryMembershipsForRoomRequest{ RoomID: roomID, JoinedOnly: true, - LocalOnly: true, } var res rsapi.QueryMembershipsForRoomResponse @@ -396,8 +395,23 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s } var members []*localMembership - var ntotal int for _, event := range res.JoinEvents { + // Filter out invalid join events + if event.StateKey == nil { + continue + } + if *event.StateKey == "" { + continue + } + _, serverName, err := gomatrixserverlib.SplitID('@', *event.StateKey) + if err != nil { + log.WithError(err).Error("failed to get servername from statekey") + continue + } + // Only get memberships for our server + if serverName != s.serverName { + continue + } member, err := newLocalMembership(&event) if err != nil { log.WithError(err).Errorf("Parsing MemberContent") @@ -410,11 +424,10 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s continue } - ntotal++ members = append(members, member) } - return members, ntotal, nil + return members, len(res.JoinEvents), nil } // roomName returns the name in the event (if type==m.room.name), or @@ -641,7 +654,7 @@ func (s *OutputRoomEventConsumer) evaluatePushRules(ctx context.Context, event * if rule == nil { // SPEC: If no rules match an event, the homeserver MUST NOT // notify the Push Gateway for that event. - return nil, err + return nil, nil } log.WithFields(log.Fields{ From f009e541816cbae1519fede2b40b0af68020b3ab Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 30 Nov 2022 12:54:37 +0000 Subject: [PATCH 09/67] Push rule evaluation tweaks (#2897) This tweaks push rule evaluation: 1. to be more strict around pattern matching and to not match empty patterns 3. to bail if we come across a `dont_notify`, since cycles after that are wasted 4. refactors `ActionsToTweaks` to make a bit more sense --- internal/pushrules/evaluate.go | 13 ++++++++++++ internal/pushrules/evaluate_test.go | 13 ++++++------ internal/pushrules/util.go | 31 +++++++++++++++++------------ internal/pushrules/util_test.go | 1 + 4 files changed, 39 insertions(+), 19 deletions(-) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index df22cb042..4ff9939a6 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -145,6 +145,11 @@ func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec Evalua } func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, error) { + // It doesn't make sense for an empty pattern to match anything. + if pattern == "" { + return false, nil + } + re, err := globToRegexp(pattern) if err != nil { return false, err @@ -154,12 +159,20 @@ func patternMatches(key, pattern string, event *gomatrixserverlib.Event) (bool, if err = json.Unmarshal(event.JSON(), &eventMap); err != nil { return false, fmt.Errorf("parsing event: %w", err) } + // From the spec: + // "If the property specified by key is completely absent from + // the event, or does not have a string value, then the condition + // will not match, even if pattern is *." v, err := lookupMapPath(strings.Split(key, "."), eventMap) if err != nil { // An unknown path is a benign error that shouldn't stop rule // processing. It's just a non-match. return false, nil } + if _, ok := v.(string); !ok { + // A non-string never matches. + return false, nil + } return re.MatchString(fmt.Sprint(v)), nil } diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index eabd02415..c5d5abd2a 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -111,7 +111,10 @@ func TestConditionMatches(t *testing.T) { {"empty", Condition{}, `{}`, false}, {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, - {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, true}, + // Neither of these should match because `content` is not a full string match, + // and `content.body` is not a string value. + {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, false}, + {"eventBodyMatch", Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3"}, `{"content":{"body": 3}}`, false}, {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, @@ -137,7 +140,7 @@ func TestConditionMatches(t *testing.T) { t.Fatalf("conditionMatches failed: %v", err) } if got != tst.Want { - t.Errorf("conditionMatches: got %v, want %v", got, tst.Want) + t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.Want, tst.Name) } }) } @@ -161,9 +164,7 @@ func TestPatternMatches(t *testing.T) { }{ {"empty", "", "", `{}`, false}, - // Note that an empty pattern contains no wildcard characters, - // which implicitly means "*". - {"patternEmpty", "content", "", `{"content":{}}`, true}, + {"patternEmpty", "content", "", `{"content":{}}`, false}, {"literal", "content.creator", "acreator", `{"content":{"creator":"acreator"}}`, true}, {"substring", "content.creator", "reat", `{"content":{"creator":"acreator"}}`, true}, @@ -178,7 +179,7 @@ func TestPatternMatches(t *testing.T) { t.Fatalf("patternMatches failed: %v", err) } if got != tst.Want { - t.Errorf("patternMatches: got %v, want %v", got, tst.Want) + t.Errorf("patternMatches: got %v, want %v on %s", got, tst.Want, tst.Name) } }) } diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go index 8ab4eab94..fb9c05be2 100644 --- a/internal/pushrules/util.go +++ b/internal/pushrules/util.go @@ -11,22 +11,27 @@ import ( // kind and a tweaks map. Returns a nil map if it would have been // empty. func ActionsToTweaks(as []*Action) (ActionKind, map[string]interface{}, error) { - kind := UnknownAction - tweaks := map[string]interface{}{} + var kind ActionKind + var tweaks map[string]interface{} for _, a := range as { - if a.Kind == SetTweakAction { - tweaks[string(a.Tweak)] = a.Value - continue - } - if kind != UnknownAction { - return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) - } - kind = a.Kind - } + switch a.Kind { + case DontNotifyAction: + // Don't bother processing any further + return DontNotifyAction, nil, nil - if len(tweaks) == 0 { - tweaks = nil + case SetTweakAction: + if tweaks == nil { + tweaks = map[string]interface{}{} + } + tweaks[string(a.Tweak)] = a.Value + + default: + if kind != UnknownAction { + return UnknownAction, nil, fmt.Errorf("got multiple primary actions: already had %q, got %s", kind, a.Kind) + } + kind = a.Kind + } } return kind, tweaks, nil diff --git a/internal/pushrules/util_test.go b/internal/pushrules/util_test.go index a951c55a2..89f8243d9 100644 --- a/internal/pushrules/util_test.go +++ b/internal/pushrules/util_test.go @@ -17,6 +17,7 @@ func TestActionsToTweaks(t *testing.T) { {"empty", nil, UnknownAction, nil}, {"zero", []*Action{{}}, UnknownAction, nil}, {"onlyPrimary", []*Action{{Kind: NotifyAction}}, NotifyAction, nil}, + {"onlyPrimaryDontNotify", []*Action{{Kind: DontNotifyAction}}, DontNotifyAction, nil}, {"onlyTweak", []*Action{{Kind: SetTweakAction, Tweak: HighlightTweak}}, UnknownAction, map[string]interface{}{"highlight": nil}}, {"onlyTweakWithValue", []*Action{{Kind: SetTweakAction, Tweak: SoundTweak, Value: "default"}}, UnknownAction, map[string]interface{}{"sound": "default"}}, { From 6f000e980155d6651b01c9b4dbdb316e9f501a45 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 1 Dec 2022 10:14:26 +0000 Subject: [PATCH 10/67] Make `create-account` more verbose --- cmd/create-account/main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index c8e239f29..15b043ed5 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -177,7 +177,7 @@ func sharedSecretRegister(sharedSecret, serverURL, localpart, password string, a defer regResp.Body.Close() // nolint: errcheck if regResp.StatusCode < 200 || regResp.StatusCode >= 300 { body, _ = io.ReadAll(regResp.Body) - return "", fmt.Errorf(gjson.GetBytes(body, "error").Str) + return "", fmt.Errorf("got HTTP %d error from server: %s", regResp.StatusCode, string(body)) } r, err := io.ReadAll(regResp.Body) if err != nil { From 1be0afa1810ecbfb8348a81c8b49ef2a15114055 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 1 Dec 2022 10:24:17 +0000 Subject: [PATCH 11/67] Expose `/_dendrite` and `/_synapse` on the P2P demo HTTP muxes --- build/gobind-pinecone/monolith.go | 2 ++ build/gobind-yggdrasil/monolith.go | 2 ++ cmd/dendrite-demo-pinecone/main.go | 2 ++ cmd/dendrite-demo-yggdrasil/main.go | 2 ++ 4 files changed, 8 insertions(+) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index 9100ebf0f..b2fb70ce4 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -382,6 +382,8 @@ func (m *DendriteMonolith) Start() { httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) httpRouter.HandleFunc("/pinecone", m.PineconeRouter.ManholeHandler) pMux := mux.NewRouter().SkipClean(true).UseEncodedPath() diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 248b6c324..4cbe983ec 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -196,6 +196,8 @@ func (m *DendriteMonolith) Start() { httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) yggRouter := mux.NewRouter() yggRouter.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 421b17d56..2ceb7c17b 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -270,6 +270,8 @@ func main() { pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + pMux.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + pMux.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) pHTTP := pQUIC.Protocol("matrix").HTTP() pHTTP.Mux().Handle(users.PublicURL, pMux) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 1226496c3..e1a04abf5 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -198,6 +198,8 @@ func main() { httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) embed.Embed(httpRouter, *instancePort, "Yggdrasil Demo") yggRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() From 934056f21fb91b59c9a9c4413318a9387abf658e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 1 Dec 2022 10:45:15 +0000 Subject: [PATCH 12/67] Fix `dendrite-demo-pinecone`, `/_dendrite` namespace setup --- build/gobind-pinecone/monolith.go | 1 + build/gobind-yggdrasil/monolith.go | 1 + cmd/dendrite-demo-pinecone/main.go | 5 +++-- cmd/dendrite-demo-yggdrasil/main.go | 1 + setup/base/base.go | 34 ++++++++++++++++------------- 5 files changed, 25 insertions(+), 17 deletions(-) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index b2fb70ce4..e8ed8fe85 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -336,6 +336,7 @@ func (m *DendriteMonolith) Start() { } base := base.NewBaseDendrite(cfg, "Monolith") + base.ConfigureAdminEndpoints() defer base.Close() // nolint: errcheck federation := conn.CreateFederationClient(base, m.PineconeQUIC) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 4cbe983ec..9a3ac5d7b 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -150,6 +150,7 @@ func (m *DendriteMonolith) Start() { } base := base.NewBaseDendrite(cfg, "Monolith") + base.ConfigureAdminEndpoints() m.processContext = base.ProcessContext defer base.Close() // nolint: errcheck diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 2ceb7c17b..2f647a41b 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -155,6 +155,7 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) base := base.NewBaseDendrite(cfg, "Monolith") + base.ConfigureAdminEndpoints() defer base.Close() // nolint: errcheck pineconeEventChannel := make(chan pineconeEvents.Event) @@ -248,6 +249,8 @@ func main() { httpRouter.PathPrefix(httputil.InternalPathPrefix).Handler(base.InternalAPIMux) httpRouter.PathPrefix(httputil.PublicClientPathPrefix).Handler(base.PublicClientAPIMux) httpRouter.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) + httpRouter.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) + httpRouter.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) httpRouter.HandleFunc("/ws", func(w http.ResponseWriter, r *http.Request) { c, err := wsUpgrader.Upgrade(w, r, nil) if err != nil { @@ -270,8 +273,6 @@ func main() { pMux.PathPrefix(users.PublicURL).HandlerFunc(userProvider.FederatedUserProfiles) pMux.PathPrefix(httputil.PublicFederationPathPrefix).Handler(base.PublicFederationAPIMux) pMux.PathPrefix(httputil.PublicMediaPathPrefix).Handler(base.PublicMediaAPIMux) - pMux.PathPrefix(httputil.DendriteAdminPathPrefix).Handler(base.DendriteAdminMux) - pMux.PathPrefix(httputil.SynapseAdminPathPrefix).Handler(base.SynapseAdminMux) pHTTP := pQUIC.Protocol("matrix").HTTP() pHTTP.Mux().Handle(users.PublicURL, pMux) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index e1a04abf5..5dd61b1b7 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -144,6 +144,7 @@ func main() { cfg.Global.KeyID = gomatrixserverlib.KeyID(signing.KeyID) base := base.NewBaseDendrite(cfg, "Monolith") + base.ConfigureAdminEndpoints() defer base.Close() // nolint: errcheck ygg, err := yggconn.Setup(sk, *instanceName, ".", *instancePeer, *instanceListen) diff --git a/setup/base/base.go b/setup/base/base.go index 14edadd96..d3adbf53f 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -413,6 +413,24 @@ func (b *BaseDendrite) configureHTTPErrors() { b.PublicClientAPIMux.MethodNotAllowedHandler = http.HandlerFunc(clientNotFoundHandler) } +func (b *BaseDendrite) ConfigureAdminEndpoints() { + b.DendriteAdminMux.HandleFunc("/monitor/up", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + }) + b.DendriteAdminMux.HandleFunc("/monitor/health", func(w http.ResponseWriter, r *http.Request) { + if isDegraded, reasons := b.ProcessContext.IsDegraded(); isDegraded { + w.WriteHeader(503) + _ = json.NewEncoder(w).Encode(struct { + Warnings []string `json:"warnings"` + }{ + Warnings: reasons, + }) + return + } + w.WriteHeader(200) + }) +} + // SetupAndServeHTTP sets up the HTTP server to serve endpoints registered on // ApiMux under /api/ and adds a prometheus handler under /metrics. func (b *BaseDendrite) SetupAndServeHTTP( @@ -463,21 +481,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( internalRouter.Handle("/metrics", httputil.WrapHandlerInBasicAuth(promhttp.Handler(), b.Cfg.Global.Metrics.BasicAuth)) } - b.DendriteAdminMux.HandleFunc("/monitor/up", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) - }) - b.DendriteAdminMux.HandleFunc("/monitor/health", func(w http.ResponseWriter, r *http.Request) { - if isDegraded, reasons := b.ProcessContext.IsDegraded(); isDegraded { - w.WriteHeader(503) - _ = json.NewEncoder(w).Encode(struct { - Warnings []string `json:"warnings"` - }{ - Warnings: reasons, - }) - return - } - w.WriteHeader(200) - }) + b.ConfigureAdminEndpoints() var clientHandler http.Handler clientHandler = b.PublicClientAPIMux From 9a46d8d95c937bdb356f9e373424e1a37aa37f04 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 2 Dec 2022 11:44:20 +0100 Subject: [PATCH 13/67] Test and CI related changes (#2896) In an attempt to: - make on-boarding a bit easier (`go test ./...` should now not need additional postgres setup) - get code coverage faster, not only scheduled at night - test the `create-account` binary --- .github/workflows/dendrite.yml | 75 ++++++++++++++++++++++++-- .github/workflows/schedules.yaml | 82 ++++------------------------- cmd/dendrite-upgrade-tests/main.go | 43 +++++++++++++++ docs/CONTRIBUTING.md | 15 +++++- internal/pushgateway/client.go | 5 +- internal/pushgateway/client_test.go | 54 +++++++++++++++++++ 6 files changed, 197 insertions(+), 77 deletions(-) create mode 100644 internal/pushgateway/client_test.go diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index fa4282384..593012ef3 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -68,7 +68,7 @@ jobs: # run go test with different go versions test: - timeout-minutes: 5 + timeout-minutes: 10 name: Unit tests (Go ${{ matrix.go }}) runs-on: ubuntu-latest # Service containers to run with `container-job` @@ -94,14 +94,22 @@ jobs: strategy: fail-fast: false matrix: - go: ["1.18", "1.19"] + go: ["1.19"] steps: - uses: actions/checkout@v3 - name: Setup go uses: actions/setup-go@v3 with: go-version: ${{ matrix.go }} - cache: true + - uses: actions/cache@v3 + # manually set up caches, as they otherwise clash with different steps using setup-go with cache=true + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go${{ matrix.go }}-unit-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go${{ matrix.go }}-unit- - name: Set up gotestfmt uses: gotesttools/gotestfmt-action@v2 with: @@ -194,6 +202,66 @@ jobs: with: jobs: ${{ toJSON(needs) }} + # run go test with different go versions + integration: + timeout-minutes: 20 + needs: initial-tests-done + name: Integration tests (Go ${{ matrix.go }}) + runs-on: ubuntu-latest + # Service containers to run with `container-job` + services: + # Label used to access the service container + postgres: + # Docker Hub image + image: postgres:13-alpine + # Provide the password for postgres + env: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: dendrite + ports: + # Maps tcp port 5432 on service container to the host + - 5432:5432 + # Set health checks to wait until postgres has started + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + strategy: + fail-fast: false + matrix: + go: ["1.19"] + steps: + - uses: actions/checkout@v3 + - name: Setup go + uses: actions/setup-go@v3 + with: + go-version: ${{ matrix.go }} + - name: Set up gotestfmt + uses: gotesttools/gotestfmt-action@v2 + with: + # Optional: pass GITHUB_TOKEN to avoid rate limiting. + token: ${{ secrets.GITHUB_TOKEN }} + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go${{ matrix.go }}-test-race- + - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt + env: + POSTGRES_HOST: localhost + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: dendrite + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + flags: unittests + # run database upgrade tests upgrade_test: name: Upgrade tests @@ -404,6 +472,7 @@ jobs: upgrade_test_direct, sytest, complement, + integration ] runs-on: ubuntu-latest if: ${{ !cancelled() }} # Run this even if prior jobs were skipped diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index ff4d47187..d2a1f6e1f 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -10,79 +10,9 @@ concurrency: cancel-in-progress: true jobs: - # run go test with different go versions - test: - timeout-minutes: 20 - name: Unit tests (Go ${{ matrix.go }}) - runs-on: ubuntu-latest - # Service containers to run with `container-job` - services: - # Label used to access the service container - postgres: - # Docker Hub image - image: postgres:13-alpine - # Provide the password for postgres - env: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: dendrite - ports: - # Maps tcp port 5432 on service container to the host - - 5432:5432 - # Set health checks to wait until postgres has started - options: >- - --health-cmd pg_isready - --health-interval 10s - --health-timeout 5s - --health-retries 5 - strategy: - fail-fast: false - matrix: - go: ["1.18", "1.19"] - steps: - - uses: actions/checkout@v3 - - name: Setup go - uses: actions/setup-go@v3 - with: - go-version: ${{ matrix.go }} - - name: Set up gotestfmt - uses: gotesttools/gotestfmt-action@v2 - with: - # Optional: pass GITHUB_TOKEN to avoid rate limiting. - token: ${{ secrets.GITHUB_TOKEN }} - - uses: actions/cache@v3 - with: - path: | - ~/.cache/go-build - ~/go/pkg/mod - key: ${{ runner.os }}-go${{ matrix.go }}-test-race-${{ hashFiles('**/go.sum') }} - restore-keys: | - ${{ runner.os }}-go${{ matrix.go }}-test-race- - - run: go test -race -json -v -coverpkg=./... -coverprofile=cover.out $(go list ./... | grep -v /cmd/dendrite*) 2>&1 | gotestfmt - env: - POSTGRES_HOST: localhost - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: dendrite - - name: Upload coverage to Codecov - uses: codecov/codecov-action@v3 - - # Dummy step to gate other tests on without repeating the whole list - initial-tests-done: - name: Initial tests passed - needs: [test] - runs-on: ubuntu-latest - if: ${{ !cancelled() }} # Run this even if prior jobs were skipped - steps: - - name: Check initial tests passed - uses: re-actors/alls-green@release/v1 - with: - jobs: ${{ toJSON(needs) }} - # run Sytest in different variations sytest: timeout-minutes: 60 - needs: initial-tests-done name: "Sytest (${{ matrix.label }})" runs-on: ubuntu-latest strategy: @@ -104,13 +34,23 @@ jobs: image: matrixdotorg/sytest-dendrite:latest volumes: - ${{ github.workspace }}:/src + - /root/.cache/go-build:/github/home/.cache/go-build + - /root/.cache/go-mod:/gopath/pkg/mod env: POSTGRES: ${{ matrix.postgres && 1}} API: ${{ matrix.api && 1 }} SYTEST_BRANCH: ${{ github.head_ref }} RACE_DETECTION: 1 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + /gopath/pkg/mod + key: ${{ runner.os }}-go-sytest-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go-sytest- - name: Run Sytest run: /bootstrap.sh dendrite working-directory: /src diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 75446d18c..39b9320cb 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -7,6 +7,7 @@ import ( "flag" "fmt" "io" + "io/ioutil" "log" "net/http" "os" @@ -61,6 +62,7 @@ COPY . . RUN go build ./cmd/dendrite-monolith-server RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config +RUN go build ./cmd/create-account RUN ./generate-config --ci > dendrite.yaml RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key @@ -104,6 +106,7 @@ COPY . . RUN go build ./cmd/dendrite-monolith-server RUN go build ./cmd/generate-keys RUN go build ./cmd/generate-config +RUN go build ./cmd/create-account RUN ./generate-config --ci > dendrite.yaml RUN ./generate-keys --private-key matrix_key.pem --tls-cert server.crt --tls-key server.key @@ -458,6 +461,46 @@ func loadAndRunTests(dockerClient *client.Client, volumeName, v string, branchTo if err = runTests(csAPIURL, v); err != nil { return fmt.Errorf("failed to run tests on version %s: %s", v, err) } + + err = testCreateAccount(dockerClient, v, containerID) + if err != nil { + return err + } + return nil +} + +// test that create-account is working +func testCreateAccount(dockerClient *client.Client, v string, containerID string) error { + createUser := strings.ToLower("createaccountuser-" + v) + log.Printf("%s: Creating account %s with create-account\n", v, createUser) + + respID, err := dockerClient.ContainerExecCreate(context.Background(), containerID, types.ExecConfig{ + AttachStderr: true, + AttachStdout: true, + Cmd: []string{ + "/build/create-account", + "-username", createUser, + "-password", "someRandomPassword", + }, + }) + if err != nil { + return fmt.Errorf("failed to ContainerExecCreate: %w", err) + } + + response, err := dockerClient.ContainerExecAttach(context.Background(), respID.ID, types.ExecStartCheck{}) + if err != nil { + return fmt.Errorf("failed to attach to container: %w", err) + } + defer response.Close() + + data, err := ioutil.ReadAll(response.Reader) + if err != nil { + return err + } + + if !bytes.Contains(data, []byte("AccessToken")) { + return fmt.Errorf("failed to create-account: %s", string(data)) + } return nil } diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 6ba05f46f..262a93a7c 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -75,7 +75,20 @@ comment. Please avoid doing this if you can. We also have unit tests which we run via: ```bash -go test --race ./... +DENDRITE_TEST_SKIP_NODB=1 go test --race ./... +``` + +This only runs SQLite database tests. If you wish to execute Postgres tests as well, you'll either need to +have Postgres installed locally (`createdb` will be used) or have a remote/containerized Postgres instance +available. + +To configure the connection to a remote Postgres, you can use the following enviroment variables: + +```bash +POSTGRES_USER=postgres +POSTGERS_PASSWORD=yourPostgresPassword +POSTGRES_HOST=localhost +POSTGRES_DB=postgres # the superuser database to use ``` In general, we like submissions that come with tests. Anything that proves that the diff --git a/internal/pushgateway/client.go b/internal/pushgateway/client.go index 95f5afd90..259239b87 100644 --- a/internal/pushgateway/client.go +++ b/internal/pushgateway/client.go @@ -9,6 +9,8 @@ import ( "net/http" "time" + "github.com/matrix-org/dendrite/internal" + "github.com/opentracing/opentracing-go" ) @@ -50,8 +52,7 @@ func (h *httpClient) Notify(ctx context.Context, url string, req *NotifyRequest, return err } - //nolint:errcheck - defer hresp.Body.Close() + defer internal.CloseAndLogIfError(ctx, hresp.Body, "failed to close response body") if hresp.StatusCode == http.StatusOK { return json.NewDecoder(hresp.Body).Decode(resp) diff --git a/internal/pushgateway/client_test.go b/internal/pushgateway/client_test.go new file mode 100644 index 000000000..bd0dca470 --- /dev/null +++ b/internal/pushgateway/client_test.go @@ -0,0 +1,54 @@ +package pushgateway + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestNotify(t *testing.T) { + wantResponse := NotifyResponse{ + Rejected: []string{"testing"}, + } + + var i = 0 + + svr := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // /notify only accepts POST requests + if r.Method != http.MethodPost { + w.WriteHeader(http.StatusNotImplemented) + return + } + + if i != 0 { // error path + w.WriteHeader(http.StatusBadRequest) + return + } + + // happy path + json.NewEncoder(w).Encode(wantResponse) + })) + defer svr.Close() + + cl := NewHTTPClient(true) + gotResponse := NotifyResponse{} + + // Test happy path + err := cl.Notify(context.Background(), svr.URL, &NotifyRequest{}, &gotResponse) + if err != nil { + t.Errorf("failed to notify client") + } + if !reflect.DeepEqual(gotResponse, wantResponse) { + t.Errorf("expected response %+v, got %+v", wantResponse, gotResponse) + } + + // Test error path + i++ + err = cl.Notify(context.Background(), svr.URL, &NotifyRequest{}, &gotResponse) + if err == nil { + t.Errorf("expected notifying the pushgateway to fail, but it succeeded") + } +} From b65f89e61e95b295e46ac3ade3c860b56126fa90 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 2 Dec 2022 16:42:23 +0100 Subject: [PATCH 14/67] Add tests for the AS internal API (#2898) --- appservice/appservice_test.go | 224 ++++++++++++++++++++++++++++++++++ 1 file changed, 224 insertions(+) create mode 100644 appservice/appservice_test.go diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go new file mode 100644 index 000000000..5a3a9aef7 --- /dev/null +++ b/appservice/appservice_test.go @@ -0,0 +1,224 @@ +package appservice_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "reflect" + "regexp" + "strings" + "testing" + + "github.com/gorilla/mux" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/appservice/api" + "github.com/matrix-org/dendrite/appservice/inthttp" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/userapi" + + "github.com/matrix-org/dendrite/test/testrig" +) + +func TestAppserviceInternalAPI(t *testing.T) { + + // Set expected results + existingProtocol := "irc" + wantLocationResponse := []api.ASLocationResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} + wantUserResponse := []api.ASUserResponse{{Protocol: existingProtocol, Fields: []byte("{}")}} + wantProtocolResponse := api.ASProtocolResponse{Instances: []api.ProtocolInstance{{Fields: []byte("{}")}}} + wantProtocolResult := map[string]api.ASProtocolResponse{ + existingProtocol: wantProtocolResponse, + } + + // create a dummy AS url, handling some cases + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "location"): + // Check if we've got an existing protocol, if so, return a proper response. + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantLocationResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode([]api.ASLocationResponse{}); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + case strings.Contains(r.URL.Path, "user"): + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantUserResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode([]api.UserResponse{}); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + case strings.Contains(r.URL.Path, "protocol"): + if r.URL.Path[len(r.URL.Path)-len(existingProtocol):] == existingProtocol { + if err := json.NewEncoder(w).Encode(wantProtocolResponse); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + } + if err := json.NewEncoder(w).Encode(nil); err != nil { + t.Fatalf("failed to encode response: %s", err) + } + return + default: + t.Logf("hit location: %s", r.URL.Path) + } + })) + + // TODO: use test.WithAllDatabases + // only one DBType, since appservice.AddInternalRoutes complains about multiple prometheus counters added + base, closeBase := testrig.CreateBaseDendrite(t, test.DBTypeSQLite) + defer closeBase() + + // Create a dummy application service + base.Cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ + { + ID: "someID", + URL: srv.URL, + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, + "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, + }, + Protocols: []string{existingProtocol}, + }, + } + + // Create required internal APIs + rsAPI := roomserver.NewInternalAPI(base) + usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) + + // The test cases to run + runCases := func(t *testing.T, testAPI api.AppServiceInternalAPI) { + t.Run("UserIDExists", func(t *testing.T) { + testUserIDExists(t, testAPI, "@as-testing:test", true) + testUserIDExists(t, testAPI, "@as1-testing:test", false) + }) + + t.Run("AliasExists", func(t *testing.T) { + testAliasExists(t, testAPI, "@asroom-testing:test", true) + testAliasExists(t, testAPI, "@asroom1-testing:test", false) + }) + + t.Run("Locations", func(t *testing.T) { + testLocations(t, testAPI, existingProtocol, wantLocationResponse) + testLocations(t, testAPI, "abc", nil) + }) + + t.Run("User", func(t *testing.T) { + testUser(t, testAPI, existingProtocol, wantUserResponse) + testUser(t, testAPI, "abc", nil) + }) + + t.Run("Protocols", func(t *testing.T) { + testProtocol(t, testAPI, existingProtocol, wantProtocolResult) + testProtocol(t, testAPI, existingProtocol, wantProtocolResult) // tests the cache + testProtocol(t, testAPI, "", wantProtocolResult) // tests getting all protocols + testProtocol(t, testAPI, "abc", nil) + }) + } + + // Finally execute the tests + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + appservice.AddInternalRoutes(router, asAPI) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + + asHTTPApi, err := inthttp.NewAppserviceClient(apiURL, &http.Client{}) + if err != nil { + t.Fatalf("failed to create HTTP client: %s", err) + } + runCases(t, asHTTPApi) + }) + + t.Run("Monolith", func(t *testing.T) { + runCases(t, asAPI) + }) + +} + +func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) { + ctx := context.Background() + userResp := &api.UserIDExistsResponse{} + + if err := asAPI.UserIDExists(ctx, &api.UserIDExistsRequest{ + UserID: userID, + }, userResp); err != nil { + t.Errorf("failed to get userID: %s", err) + } + if userResp.UserIDExists != wantExists { + t.Errorf("unexpected result for UserIDExists(%s): %v, expected %v", userID, userResp.UserIDExists, wantExists) + } +} + +func testAliasExists(t *testing.T, asAPI api.AppServiceInternalAPI, alias string, wantExists bool) { + ctx := context.Background() + aliasResp := &api.RoomAliasExistsResponse{} + + if err := asAPI.RoomAliasExists(ctx, &api.RoomAliasExistsRequest{ + Alias: alias, + }, aliasResp); err != nil { + t.Errorf("failed to get alias: %s", err) + } + if aliasResp.AliasExists != wantExists { + t.Errorf("unexpected result for RoomAliasExists(%s): %v, expected %v", alias, aliasResp.AliasExists, wantExists) + } +} + +func testLocations(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASLocationResponse) { + ctx := context.Background() + locationResp := &api.LocationResponse{} + + if err := asAPI.Locations(ctx, &api.LocationRequest{ + Protocol: proto, + }, locationResp); err != nil { + t.Errorf("failed to get locations: %s", err) + } + if !reflect.DeepEqual(locationResp.Locations, wantResult) { + t.Errorf("unexpected result for Locations(%s): %+v, expected %+v", proto, locationResp.Locations, wantResult) + } +} + +func testUser(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult []api.ASUserResponse) { + ctx := context.Background() + userResp := &api.UserResponse{} + + if err := asAPI.User(ctx, &api.UserRequest{ + Protocol: proto, + }, userResp); err != nil { + t.Errorf("failed to get user: %s", err) + } + if !reflect.DeepEqual(userResp.Users, wantResult) { + t.Errorf("unexpected result for User(%s): %+v, expected %+v", proto, userResp.Users, wantResult) + } +} + +func testProtocol(t *testing.T, asAPI api.AppServiceInternalAPI, proto string, wantResult map[string]api.ASProtocolResponse) { + ctx := context.Background() + protoResp := &api.ProtocolResponse{} + + if err := asAPI.Protocols(ctx, &api.ProtocolRequest{ + Protocol: proto, + }, protoResp); err != nil { + t.Errorf("failed to get Protocols: %s", err) + } + if !reflect.DeepEqual(protoResp.Protocols, wantResult) { + t.Errorf("unexpected result for Protocols(%s): %+v, expected %+v", proto, protoResp.Protocols[proto], wantResult) + } +} From e245a26f6bcb4d134015f49f621b6d639a78707f Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 5 Dec 2022 13:53:36 +0100 Subject: [PATCH 15/67] Enable/Disable internal metrics (#2899) Basically enables us to use `test.WithAllDatabases` when testing internal HTTP APIs, as this would otherwise result in Prometheus complaining about already registered metric names. --- appservice/appservice.go | 7 +- appservice/inthttp/server.go | 12 +-- cmd/dendrite-monolith-server/main.go | 13 +-- .../personalities/appservice.go | 2 +- .../personalities/federationapi.go | 2 +- .../personalities/keyserver.go | 2 +- .../personalities/roomserver.go | 2 +- .../personalities/userapi.go | 2 +- federationapi/federationapi.go | 4 +- federationapi/inthttp/server.go | 46 +++++------ internal/httputil/httpapi.go | 15 ++-- internal/httputil/internalapi.go | 10 +-- keyserver/inthttp/server.go | 24 +++--- keyserver/keyserver.go | 4 +- roomserver/inthttp/server.go | 80 +++++++++---------- roomserver/roomserver.go | 7 +- userapi/inthttp/server.go | 71 ++++++++-------- userapi/inthttp/server_logintoken.go | 9 ++- userapi/userapi.go | 4 +- userapi/userapi_test.go | 2 +- 20 files changed, 164 insertions(+), 154 deletions(-) diff --git a/appservice/appservice.go b/appservice/appservice.go index b3c28dbde..753850de7 100644 --- a/appservice/appservice.go +++ b/appservice/appservice.go @@ -24,6 +24,8 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/appservice/consumers" "github.com/matrix-org/dendrite/appservice/inthttp" @@ -32,12 +34,11 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) // AddInternalRoutes registers HTTP handlers for internal API calls -func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceInternalAPI) { - inthttp.AddRoutes(queryAPI, router) +func AddInternalRoutes(router *mux.Router, queryAPI appserviceAPI.AppServiceInternalAPI, enableMetrics bool) { + inthttp.AddRoutes(queryAPI, router, enableMetrics) } // NewInternalAPI returns a concerete implementation of the internal API. Callers diff --git a/appservice/inthttp/server.go b/appservice/inthttp/server.go index ccf5c83d8..b70fad673 100644 --- a/appservice/inthttp/server.go +++ b/appservice/inthttp/server.go @@ -8,29 +8,29 @@ import ( ) // AddRoutes adds the AppServiceQueryAPI handlers to the http.ServeMux. -func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router) { +func AddRoutes(a api.AppServiceInternalAPI, internalAPIMux *mux.Router, enableMetrics bool) { internalAPIMux.Handle( AppServiceRoomAliasExistsPath, - httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", a.RoomAliasExists), + httputil.MakeInternalRPCAPI("AppserviceRoomAliasExists", enableMetrics, a.RoomAliasExists), ) internalAPIMux.Handle( AppServiceUserIDExistsPath, - httputil.MakeInternalRPCAPI("AppserviceUserIDExists", a.UserIDExists), + httputil.MakeInternalRPCAPI("AppserviceUserIDExists", enableMetrics, a.UserIDExists), ) internalAPIMux.Handle( AppServiceProtocolsPath, - httputil.MakeInternalRPCAPI("AppserviceProtocols", a.Protocols), + httputil.MakeInternalRPCAPI("AppserviceProtocols", enableMetrics, a.Protocols), ) internalAPIMux.Handle( AppServiceLocationsPath, - httputil.MakeInternalRPCAPI("AppserviceLocations", a.Locations), + httputil.MakeInternalRPCAPI("AppserviceLocations", enableMetrics, a.Locations), ) internalAPIMux.Handle( AppServiceUserPath, - httputil.MakeInternalRPCAPI("AppserviceUser", a.User), + httputil.MakeInternalRPCAPI("AppserviceUser", enableMetrics, a.User), ) } diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index ff980dc1c..2d2f32b00 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -18,6 +18,8 @@ import ( "flag" "os" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/appservice" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/keyserver" @@ -29,7 +31,6 @@ import ( "github.com/matrix-org/dendrite/setup/mscs" "github.com/matrix-org/dendrite/userapi" uapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/sirupsen/logrus" ) var ( @@ -75,7 +76,7 @@ func main() { // call functions directly on the impl unless running in HTTP mode rsAPI := rsImpl if base.UseHTTPAPIs { - roomserver.AddInternalRoutes(base.InternalAPIMux, rsImpl) + roomserver.AddInternalRoutes(base.InternalAPIMux, rsImpl, base.EnableMetrics) rsAPI = base.RoomserverHTTPClient() } if traceInternal { @@ -89,7 +90,7 @@ func main() { ) fsImplAPI := fsAPI if base.UseHTTPAPIs { - federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI) + federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI, base.EnableMetrics) fsAPI = base.FederationAPIHTTPClient() } keyRing := fsAPI.KeyRing() @@ -97,7 +98,7 @@ func main() { keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) keyAPI := keyImpl if base.UseHTTPAPIs { - keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI) + keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics) keyAPI = base.KeyServerHTTPClient() } @@ -105,7 +106,7 @@ func main() { userImpl := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, pgClient) userAPI := userImpl if base.UseHTTPAPIs { - userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) + userapi.AddInternalRoutes(base.InternalAPIMux, userAPI, base.EnableMetrics) userAPI = base.UserAPIClient() } if traceInternal { @@ -119,7 +120,7 @@ func main() { // before the listeners are up. asAPI := appservice.NewInternalAPI(base, userImpl, rsAPI) if base.UseHTTPAPIs { - appservice.AddInternalRoutes(base.InternalAPIMux, asAPI) + appservice.AddInternalRoutes(base.InternalAPIMux, asAPI, base.EnableMetrics) asAPI = base.AppserviceHTTPClient() } diff --git a/cmd/dendrite-polylith-multi/personalities/appservice.go b/cmd/dendrite-polylith-multi/personalities/appservice.go index 4f74434a4..0547d57f0 100644 --- a/cmd/dendrite-polylith-multi/personalities/appservice.go +++ b/cmd/dendrite-polylith-multi/personalities/appservice.go @@ -26,7 +26,7 @@ func Appservice(base *base.BaseDendrite, cfg *config.Dendrite) { rsAPI := base.RoomserverHTTPClient() intAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) - appservice.AddInternalRoutes(base.InternalAPIMux, intAPI) + appservice.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics) base.SetupAndServeHTTP( base.Cfg.AppServiceAPI.InternalAPI.Listen, // internal listener diff --git a/cmd/dendrite-polylith-multi/personalities/federationapi.go b/cmd/dendrite-polylith-multi/personalities/federationapi.go index 6377ce9e3..48da42fbf 100644 --- a/cmd/dendrite-polylith-multi/personalities/federationapi.go +++ b/cmd/dendrite-polylith-multi/personalities/federationapi.go @@ -34,7 +34,7 @@ func FederationAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { rsAPI, fsAPI, keyAPI, nil, ) - federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI) + federationapi.AddInternalRoutes(base.InternalAPIMux, fsAPI, base.EnableMetrics) base.SetupAndServeHTTP( base.Cfg.FederationAPI.InternalAPI.Listen, diff --git a/cmd/dendrite-polylith-multi/personalities/keyserver.go b/cmd/dendrite-polylith-multi/personalities/keyserver.go index f8aa57b86..d2924b892 100644 --- a/cmd/dendrite-polylith-multi/personalities/keyserver.go +++ b/cmd/dendrite-polylith-multi/personalities/keyserver.go @@ -25,7 +25,7 @@ func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) intAPI.SetUserAPI(base.UserAPIClient()) - keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI) + keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics) base.SetupAndServeHTTP( base.Cfg.KeyServer.InternalAPI.Listen, // internal listener diff --git a/cmd/dendrite-polylith-multi/personalities/roomserver.go b/cmd/dendrite-polylith-multi/personalities/roomserver.go index 1deb51ce0..974559bd2 100644 --- a/cmd/dendrite-polylith-multi/personalities/roomserver.go +++ b/cmd/dendrite-polylith-multi/personalities/roomserver.go @@ -26,7 +26,7 @@ func RoomServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { rsAPI := roomserver.NewInternalAPI(base) rsAPI.SetFederationAPI(fsAPI, fsAPI.KeyRing()) rsAPI.SetAppserviceAPI(asAPI) - roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI) + roomserver.AddInternalRoutes(base.InternalAPIMux, rsAPI, base.EnableMetrics) base.SetupAndServeHTTP( base.Cfg.RoomServer.InternalAPI.Listen, // internal listener diff --git a/cmd/dendrite-polylith-multi/personalities/userapi.go b/cmd/dendrite-polylith-multi/personalities/userapi.go index 3fe5a43d7..1bc88cb5f 100644 --- a/cmd/dendrite-polylith-multi/personalities/userapi.go +++ b/cmd/dendrite-polylith-multi/personalities/userapi.go @@ -27,7 +27,7 @@ func UserAPI(base *basepkg.BaseDendrite, cfg *config.Dendrite) { base.PushGatewayHTTPClient(), ) - userapi.AddInternalRoutes(base.InternalAPIMux, userAPI) + userapi.AddInternalRoutes(base.InternalAPIMux, userAPI, base.EnableMetrics) base.SetupAndServeHTTP( base.Cfg.UserAPI.InternalAPI.Listen, // internal listener diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 854251220..87eb751f5 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -43,8 +43,8 @@ import ( // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.FederationInternalAPI) { - inthttp.AddRoutes(intAPI, router) +func AddInternalRoutes(router *mux.Router, intAPI api.FederationInternalAPI, enableMetrics bool) { + inthttp.AddRoutes(intAPI, router, enableMetrics) } // AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. diff --git a/federationapi/inthttp/server.go b/federationapi/inthttp/server.go index 21a070392..9068dc400 100644 --- a/federationapi/inthttp/server.go +++ b/federationapi/inthttp/server.go @@ -17,41 +17,41 @@ import ( // AddRoutes adds the FederationInternalAPI handlers to the http.ServeMux. // nolint:gocyclo -func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { +func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router, enableMetrics bool) { internalAPIMux.Handle( FederationAPIQueryJoinedHostServerNamesInRoomPath, - httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", intAPI.QueryJoinedHostServerNamesInRoom), + httputil.MakeInternalRPCAPI("FederationAPIQueryJoinedHostServerNamesInRoom", enableMetrics, intAPI.QueryJoinedHostServerNamesInRoom), ) internalAPIMux.Handle( FederationAPIPerformInviteRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", intAPI.PerformInvite), + httputil.MakeInternalRPCAPI("FederationAPIPerformInvite", enableMetrics, intAPI.PerformInvite), ) internalAPIMux.Handle( FederationAPIPerformLeaveRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", intAPI.PerformLeave), + httputil.MakeInternalRPCAPI("FederationAPIPerformLeave", enableMetrics, intAPI.PerformLeave), ) internalAPIMux.Handle( FederationAPIPerformDirectoryLookupRequestPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", intAPI.PerformDirectoryLookup), + httputil.MakeInternalRPCAPI("FederationAPIPerformDirectoryLookupRequest", enableMetrics, intAPI.PerformDirectoryLookup), ) internalAPIMux.Handle( FederationAPIPerformBroadcastEDUPath, - httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", intAPI.PerformBroadcastEDU), + httputil.MakeInternalRPCAPI("FederationAPIPerformBroadcastEDU", enableMetrics, intAPI.PerformBroadcastEDU), ) internalAPIMux.Handle( FederationAPIPerformWakeupServers, - httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", intAPI.PerformWakeupServers), + httputil.MakeInternalRPCAPI("FederationAPIPerformWakeupServers", enableMetrics, intAPI.PerformWakeupServers), ) internalAPIMux.Handle( FederationAPIPerformJoinRequestPath, httputil.MakeInternalRPCAPI( - "FederationAPIPerformJoinRequest", + "FederationAPIPerformJoinRequest", enableMetrics, func(ctx context.Context, req *api.PerformJoinRequest, res *api.PerformJoinResponse) error { intAPI.PerformJoin(ctx, req, res) return nil @@ -62,7 +62,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIGetUserDevicesPath, httputil.MakeInternalProxyAPI( - "FederationAPIGetUserDevices", + "FederationAPIGetUserDevices", enableMetrics, func(ctx context.Context, req *getUserDevices) (*gomatrixserverlib.RespUserDevices, error) { res, err := intAPI.GetUserDevices(ctx, req.Origin, req.S, req.UserID) return &res, federationClientError(err) @@ -73,7 +73,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIClaimKeysPath, httputil.MakeInternalProxyAPI( - "FederationAPIClaimKeys", + "FederationAPIClaimKeys", enableMetrics, func(ctx context.Context, req *claimKeys) (*gomatrixserverlib.RespClaimKeys, error) { res, err := intAPI.ClaimKeys(ctx, req.Origin, req.S, req.OneTimeKeys) return &res, federationClientError(err) @@ -84,7 +84,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIQueryKeysPath, httputil.MakeInternalProxyAPI( - "FederationAPIQueryKeys", + "FederationAPIQueryKeys", enableMetrics, func(ctx context.Context, req *queryKeys) (*gomatrixserverlib.RespQueryKeys, error) { res, err := intAPI.QueryKeys(ctx, req.Origin, req.S, req.Keys) return &res, federationClientError(err) @@ -95,7 +95,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIBackfillPath, httputil.MakeInternalProxyAPI( - "FederationAPIBackfill", + "FederationAPIBackfill", enableMetrics, func(ctx context.Context, req *backfill) (*gomatrixserverlib.Transaction, error) { res, err := intAPI.Backfill(ctx, req.Origin, req.S, req.RoomID, req.Limit, req.EventIDs) return &res, federationClientError(err) @@ -106,7 +106,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPILookupStatePath, httputil.MakeInternalProxyAPI( - "FederationAPILookupState", + "FederationAPILookupState", enableMetrics, func(ctx context.Context, req *lookupState) (*gomatrixserverlib.RespState, error) { res, err := intAPI.LookupState(ctx, req.Origin, req.S, req.RoomID, req.EventID, req.RoomVersion) return &res, federationClientError(err) @@ -117,7 +117,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPILookupStateIDsPath, httputil.MakeInternalProxyAPI( - "FederationAPILookupStateIDs", + "FederationAPILookupStateIDs", enableMetrics, func(ctx context.Context, req *lookupStateIDs) (*gomatrixserverlib.RespStateIDs, error) { res, err := intAPI.LookupStateIDs(ctx, req.Origin, req.S, req.RoomID, req.EventID) return &res, federationClientError(err) @@ -128,7 +128,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPILookupMissingEventsPath, httputil.MakeInternalProxyAPI( - "FederationAPILookupMissingEvents", + "FederationAPILookupMissingEvents", enableMetrics, func(ctx context.Context, req *lookupMissingEvents) (*gomatrixserverlib.RespMissingEvents, error) { res, err := intAPI.LookupMissingEvents(ctx, req.Origin, req.S, req.RoomID, req.Missing, req.RoomVersion) return &res, federationClientError(err) @@ -139,7 +139,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIGetEventPath, httputil.MakeInternalProxyAPI( - "FederationAPIGetEvent", + "FederationAPIGetEvent", enableMetrics, func(ctx context.Context, req *getEvent) (*gomatrixserverlib.Transaction, error) { res, err := intAPI.GetEvent(ctx, req.Origin, req.S, req.EventID) return &res, federationClientError(err) @@ -150,7 +150,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIGetEventAuthPath, httputil.MakeInternalProxyAPI( - "FederationAPIGetEventAuth", + "FederationAPIGetEventAuth", enableMetrics, func(ctx context.Context, req *getEventAuth) (*gomatrixserverlib.RespEventAuth, error) { res, err := intAPI.GetEventAuth(ctx, req.Origin, req.S, req.RoomVersion, req.RoomID, req.EventID) return &res, federationClientError(err) @@ -160,13 +160,13 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIQueryServerKeysPath, - httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", intAPI.QueryServerKeys), + httputil.MakeInternalRPCAPI("FederationAPIQueryServerKeys", enableMetrics, intAPI.QueryServerKeys), ) internalAPIMux.Handle( FederationAPILookupServerKeysPath, httputil.MakeInternalProxyAPI( - "FederationAPILookupServerKeys", + "FederationAPILookupServerKeys", enableMetrics, func(ctx context.Context, req *lookupServerKeys) (*[]gomatrixserverlib.ServerKeys, error) { res, err := intAPI.LookupServerKeys(ctx, req.S, req.KeyRequests) return &res, federationClientError(err) @@ -177,7 +177,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPIEventRelationshipsPath, httputil.MakeInternalProxyAPI( - "FederationAPIMSC2836EventRelationships", + "FederationAPIMSC2836EventRelationships", enableMetrics, func(ctx context.Context, req *eventRelationships) (*gomatrixserverlib.MSC2836EventRelationshipsResponse, error) { res, err := intAPI.MSC2836EventRelationships(ctx, req.Origin, req.S, req.Req, req.RoomVer) return &res, federationClientError(err) @@ -188,7 +188,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { internalAPIMux.Handle( FederationAPISpacesSummaryPath, httputil.MakeInternalProxyAPI( - "FederationAPIMSC2946SpacesSummary", + "FederationAPIMSC2946SpacesSummary", enableMetrics, func(ctx context.Context, req *spacesReq) (*gomatrixserverlib.MSC2946SpacesResponse, error) { res, err := intAPI.MSC2946Spaces(ctx, req.Origin, req.S, req.RoomID, req.SuggestedOnly) return &res, federationClientError(err) @@ -198,7 +198,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { // TODO: Look at this shape internalAPIMux.Handle(FederationAPIQueryPublicKeyPath, - httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("FederationAPIQueryPublicKeys", enableMetrics, func(req *http.Request) util.JSONResponse { request := api.QueryPublicKeysRequest{} response := api.QueryPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -215,7 +215,7 @@ func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) { // TODO: Look at this shape internalAPIMux.Handle(FederationAPIInputPublicKeyPath, - httputil.MakeInternalAPI("FederationAPIInputPublicKeys", func(req *http.Request) util.JSONResponse { + httputil.MakeInternalAPI("FederationAPIInputPublicKeys", enableMetrics, func(req *http.Request) util.JSONResponse { request := api.InputPublicKeysRequest{} response := api.InputPublicKeysResponse{} if err := json.NewDecoder(req.Body).Decode(&request); err != nil { diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 4f33a3f79..127d1fac7 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -24,16 +24,17 @@ import ( "strings" "github.com/getsentry/sentry-go" - "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/jsonerror" - userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" - opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go/ext" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" + + "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + userapi "github.com/matrix-org/dendrite/userapi/api" ) // BasicAuth is used for authorization on /metrics handlers @@ -227,7 +228,7 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) // This is used for APIs that are internal to dendrite. // If we are passed a tracing context in the request headers then we use that // as the parent of any tracing spans we create. -func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler { +func MakeInternalAPI(metricsName string, enableMetrics bool, f func(*http.Request) util.JSONResponse) http.Handler { h := util.MakeJSONAPI(util.NewJSONRequestHandler(f)) withSpan := func(w http.ResponseWriter, req *http.Request) { carrier := opentracing.HTTPHeadersCarrier(req.Header) @@ -246,6 +247,10 @@ func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse h.ServeHTTP(w, req) } + if !enableMetrics { + return http.HandlerFunc(withSpan) + } + return promhttp.InstrumentHandlerCounter( promauto.NewCounterVec( prometheus.CounterOpts{ diff --git a/internal/httputil/internalapi.go b/internal/httputil/internalapi.go index 385092d9c..22f436e38 100644 --- a/internal/httputil/internalapi.go +++ b/internal/httputil/internalapi.go @@ -22,7 +22,7 @@ import ( "reflect" "github.com/matrix-org/util" - opentracing "github.com/opentracing/opentracing-go" + "github.com/opentracing/opentracing-go" ) type InternalAPIError struct { @@ -34,8 +34,8 @@ func (e InternalAPIError) Error() string { return fmt.Sprintf("internal API returned %q error: %s", e.Type, e.Message) } -func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype, *restype) error) http.Handler { - return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse { +func MakeInternalRPCAPI[reqtype, restype any](metricsName string, enableMetrics bool, f func(context.Context, *reqtype, *restype) error) http.Handler { + return MakeInternalAPI(metricsName, enableMetrics, func(req *http.Request) util.JSONResponse { var request reqtype var response restype if err := json.NewDecoder(req.Body).Decode(&request); err != nil { @@ -57,8 +57,8 @@ func MakeInternalRPCAPI[reqtype, restype any](metricsName string, f func(context }) } -func MakeInternalProxyAPI[reqtype, restype any](metricsName string, f func(context.Context, *reqtype) (*restype, error)) http.Handler { - return MakeInternalAPI(metricsName, func(req *http.Request) util.JSONResponse { +func MakeInternalProxyAPI[reqtype, restype any](metricsName string, enableMetrics bool, f func(context.Context, *reqtype) (*restype, error)) http.Handler { + return MakeInternalAPI(metricsName, enableMetrics, func(req *http.Request) util.JSONResponse { var request reqtype if err := json.NewDecoder(req.Body).Decode(&request); err != nil { return util.MessageResponse(http.StatusBadRequest, err.Error()) diff --git a/keyserver/inthttp/server.go b/keyserver/inthttp/server.go index 7af0ff6e5..443269f73 100644 --- a/keyserver/inthttp/server.go +++ b/keyserver/inthttp/server.go @@ -21,59 +21,59 @@ import ( "github.com/matrix-org/dendrite/keyserver/api" ) -func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) { +func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI, enableMetrics bool) { internalAPIMux.Handle( PerformClaimKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", s.PerformClaimKeys), + httputil.MakeInternalRPCAPI("KeyserverPerformClaimKeys", enableMetrics, s.PerformClaimKeys), ) internalAPIMux.Handle( PerformDeleteKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", s.PerformDeleteKeys), + httputil.MakeInternalRPCAPI("KeyserverPerformDeleteKeys", enableMetrics, s.PerformDeleteKeys), ) internalAPIMux.Handle( PerformUploadKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", s.PerformUploadKeys), + httputil.MakeInternalRPCAPI("KeyserverPerformUploadKeys", enableMetrics, s.PerformUploadKeys), ) internalAPIMux.Handle( PerformUploadDeviceKeysPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", s.PerformUploadDeviceKeys), + httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceKeys", enableMetrics, s.PerformUploadDeviceKeys), ) internalAPIMux.Handle( PerformUploadDeviceSignaturesPath, - httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", s.PerformUploadDeviceSignatures), + httputil.MakeInternalRPCAPI("KeyserverPerformUploadDeviceSignatures", enableMetrics, s.PerformUploadDeviceSignatures), ) internalAPIMux.Handle( QueryKeysPath, - httputil.MakeInternalRPCAPI("KeyserverQueryKeys", s.QueryKeys), + httputil.MakeInternalRPCAPI("KeyserverQueryKeys", enableMetrics, s.QueryKeys), ) internalAPIMux.Handle( QueryOneTimeKeysPath, - httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", s.QueryOneTimeKeys), + httputil.MakeInternalRPCAPI("KeyserverQueryOneTimeKeys", enableMetrics, s.QueryOneTimeKeys), ) internalAPIMux.Handle( QueryDeviceMessagesPath, - httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", s.QueryDeviceMessages), + httputil.MakeInternalRPCAPI("KeyserverQueryDeviceMessages", enableMetrics, s.QueryDeviceMessages), ) internalAPIMux.Handle( QueryKeyChangesPath, - httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", s.QueryKeyChanges), + httputil.MakeInternalRPCAPI("KeyserverQueryKeyChanges", enableMetrics, s.QueryKeyChanges), ) internalAPIMux.Handle( QuerySignaturesPath, - httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", s.QuerySignatures), + httputil.MakeInternalRPCAPI("KeyserverQuerySignatures", enableMetrics, s.QuerySignatures), ) internalAPIMux.Handle( PerformMarkAsStalePath, - httputil.MakeInternalRPCAPI("KeyserverMarkAsStale", s.PerformMarkAsStaleIfNeeded), + httputil.MakeInternalRPCAPI("KeyserverMarkAsStale", enableMetrics, s.PerformMarkAsStaleIfNeeded), ) } diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index a86c2da4e..5360c06fd 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -32,8 +32,8 @@ import ( // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { - inthttp.AddRoutes(router, intAPI) +func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI, enableMetrics bool) { + inthttp.AddRoutes(router, intAPI, enableMetrics) } // NewInternalAPI returns a concerete implementation of the internal API. Callers diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 4d37e90b5..6e7c2d985 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -9,198 +9,198 @@ import ( // AddRoutes adds the RoomserverInternalAPI handlers to the http.ServeMux. // nolint: gocyclo -func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { +func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMetrics bool) { internalAPIMux.Handle( RoomserverInputRoomEventsPath, - httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", r.InputRoomEvents), + httputil.MakeInternalRPCAPI("RoomserverInputRoomEvents", enableMetrics, r.InputRoomEvents), ) internalAPIMux.Handle( RoomserverPerformInvitePath, - httputil.MakeInternalRPCAPI("RoomserverPerformInvite", r.PerformInvite), + httputil.MakeInternalRPCAPI("RoomserverPerformInvite", enableMetrics, r.PerformInvite), ) internalAPIMux.Handle( RoomserverPerformJoinPath, - httputil.MakeInternalRPCAPI("RoomserverPerformJoin", r.PerformJoin), + httputil.MakeInternalRPCAPI("RoomserverPerformJoin", enableMetrics, r.PerformJoin), ) internalAPIMux.Handle( RoomserverPerformLeavePath, - httputil.MakeInternalRPCAPI("RoomserverPerformLeave", r.PerformLeave), + httputil.MakeInternalRPCAPI("RoomserverPerformLeave", enableMetrics, r.PerformLeave), ) internalAPIMux.Handle( RoomserverPerformPeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformPeek", r.PerformPeek), + httputil.MakeInternalRPCAPI("RoomserverPerformPeek", enableMetrics, r.PerformPeek), ) internalAPIMux.Handle( RoomserverPerformInboundPeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", r.PerformInboundPeek), + httputil.MakeInternalRPCAPI("RoomserverPerformInboundPeek", enableMetrics, r.PerformInboundPeek), ) internalAPIMux.Handle( RoomserverPerformUnpeekPath, - httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", r.PerformUnpeek), + httputil.MakeInternalRPCAPI("RoomserverPerformUnpeek", enableMetrics, r.PerformUnpeek), ) internalAPIMux.Handle( RoomserverPerformRoomUpgradePath, - httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", r.PerformRoomUpgrade), + httputil.MakeInternalRPCAPI("RoomserverPerformRoomUpgrade", enableMetrics, r.PerformRoomUpgrade), ) internalAPIMux.Handle( RoomserverPerformPublishPath, - httputil.MakeInternalRPCAPI("RoomserverPerformPublish", r.PerformPublish), + httputil.MakeInternalRPCAPI("RoomserverPerformPublish", enableMetrics, r.PerformPublish), ) internalAPIMux.Handle( RoomserverPerformAdminEvacuateRoomPath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", r.PerformAdminEvacuateRoom), + httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateRoom", enableMetrics, r.PerformAdminEvacuateRoom), ) internalAPIMux.Handle( RoomserverPerformAdminEvacuateUserPath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", r.PerformAdminEvacuateUser), + httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", enableMetrics, r.PerformAdminEvacuateUser), ) internalAPIMux.Handle( RoomserverPerformAdminDownloadStatePath, - httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", r.PerformAdminDownloadState), + httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", enableMetrics, r.PerformAdminDownloadState), ) internalAPIMux.Handle( RoomserverQueryPublishedRoomsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", r.QueryPublishedRooms), + httputil.MakeInternalRPCAPI("RoomserverQueryPublishedRooms", enableMetrics, r.QueryPublishedRooms), ) internalAPIMux.Handle( RoomserverQueryLatestEventsAndStatePath, - httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", r.QueryLatestEventsAndState), + httputil.MakeInternalRPCAPI("RoomserverQueryLatestEventsAndState", enableMetrics, r.QueryLatestEventsAndState), ) internalAPIMux.Handle( RoomserverQueryStateAfterEventsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", r.QueryStateAfterEvents), + httputil.MakeInternalRPCAPI("RoomserverQueryStateAfterEvents", enableMetrics, r.QueryStateAfterEvents), ) internalAPIMux.Handle( RoomserverQueryEventsByIDPath, - httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", r.QueryEventsByID), + httputil.MakeInternalRPCAPI("RoomserverQueryEventsByID", enableMetrics, r.QueryEventsByID), ) internalAPIMux.Handle( RoomserverQueryMembershipForUserPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", r.QueryMembershipForUser), + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipForUser", enableMetrics, r.QueryMembershipForUser), ) internalAPIMux.Handle( RoomserverQueryMembershipsForRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", r.QueryMembershipsForRoom), + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipsForRoom", enableMetrics, r.QueryMembershipsForRoom), ) internalAPIMux.Handle( RoomserverQueryServerJoinedToRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", r.QueryServerJoinedToRoom), + httputil.MakeInternalRPCAPI("RoomserverQueryServerJoinedToRoom", enableMetrics, r.QueryServerJoinedToRoom), ) internalAPIMux.Handle( RoomserverQueryServerAllowedToSeeEventPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", r.QueryServerAllowedToSeeEvent), + httputil.MakeInternalRPCAPI("RoomserverQueryServerAllowedToSeeEvent", enableMetrics, r.QueryServerAllowedToSeeEvent), ) internalAPIMux.Handle( RoomserverQueryMissingEventsPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", r.QueryMissingEvents), + httputil.MakeInternalRPCAPI("RoomserverQueryMissingEvents", enableMetrics, r.QueryMissingEvents), ) internalAPIMux.Handle( RoomserverQueryStateAndAuthChainPath, - httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", r.QueryStateAndAuthChain), + httputil.MakeInternalRPCAPI("RoomserverQueryStateAndAuthChain", enableMetrics, r.QueryStateAndAuthChain), ) internalAPIMux.Handle( RoomserverPerformBackfillPath, - httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", r.PerformBackfill), + httputil.MakeInternalRPCAPI("RoomserverPerformBackfill", enableMetrics, r.PerformBackfill), ) internalAPIMux.Handle( RoomserverPerformForgetPath, - httputil.MakeInternalRPCAPI("RoomserverPerformForget", r.PerformForget), + httputil.MakeInternalRPCAPI("RoomserverPerformForget", enableMetrics, r.PerformForget), ) internalAPIMux.Handle( RoomserverQueryRoomVersionCapabilitiesPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", r.QueryRoomVersionCapabilities), + httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionCapabilities", enableMetrics, r.QueryRoomVersionCapabilities), ) internalAPIMux.Handle( RoomserverQueryRoomVersionForRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", r.QueryRoomVersionForRoom), + httputil.MakeInternalRPCAPI("RoomserverQueryRoomVersionForRoom", enableMetrics, r.QueryRoomVersionForRoom), ) internalAPIMux.Handle( RoomserverSetRoomAliasPath, - httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", r.SetRoomAlias), + httputil.MakeInternalRPCAPI("RoomserverSetRoomAlias", enableMetrics, r.SetRoomAlias), ) internalAPIMux.Handle( RoomserverGetRoomIDForAliasPath, - httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", r.GetRoomIDForAlias), + httputil.MakeInternalRPCAPI("RoomserverGetRoomIDForAlias", enableMetrics, r.GetRoomIDForAlias), ) internalAPIMux.Handle( RoomserverGetAliasesForRoomIDPath, - httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", r.GetAliasesForRoomID), + httputil.MakeInternalRPCAPI("RoomserverGetAliasesForRoomID", enableMetrics, r.GetAliasesForRoomID), ) internalAPIMux.Handle( RoomserverRemoveRoomAliasPath, - httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", r.RemoveRoomAlias), + httputil.MakeInternalRPCAPI("RoomserverRemoveRoomAlias", enableMetrics, r.RemoveRoomAlias), ) internalAPIMux.Handle( RoomserverQueryCurrentStatePath, - httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", r.QueryCurrentState), + httputil.MakeInternalRPCAPI("RoomserverQueryCurrentState", enableMetrics, r.QueryCurrentState), ) internalAPIMux.Handle( RoomserverQueryRoomsForUserPath, - httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", r.QueryRoomsForUser), + httputil.MakeInternalRPCAPI("RoomserverQueryRoomsForUser", enableMetrics, r.QueryRoomsForUser), ) internalAPIMux.Handle( RoomserverQueryBulkStateContentPath, - httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", r.QueryBulkStateContent), + httputil.MakeInternalRPCAPI("RoomserverQueryBulkStateContent", enableMetrics, r.QueryBulkStateContent), ) internalAPIMux.Handle( RoomserverQuerySharedUsersPath, - httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", r.QuerySharedUsers), + httputil.MakeInternalRPCAPI("RoomserverQuerySharedUsers", enableMetrics, r.QuerySharedUsers), ) internalAPIMux.Handle( RoomserverQueryKnownUsersPath, - httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", r.QueryKnownUsers), + httputil.MakeInternalRPCAPI("RoomserverQueryKnownUsers", enableMetrics, r.QueryKnownUsers), ) internalAPIMux.Handle( RoomserverQueryServerBannedFromRoomPath, - httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", r.QueryServerBannedFromRoom), + httputil.MakeInternalRPCAPI("RoomserverQueryServerBannedFromRoom", enableMetrics, r.QueryServerBannedFromRoom), ) internalAPIMux.Handle( RoomserverQueryAuthChainPath, - httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", r.QueryAuthChain), + httputil.MakeInternalRPCAPI("RoomserverQueryAuthChain", enableMetrics, r.QueryAuthChain), ) internalAPIMux.Handle( RoomserverQueryRestrictedJoinAllowed, - httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", r.QueryRestrictedJoinAllowed), + httputil.MakeInternalRPCAPI("RoomserverQueryRestrictedJoinAllowed", enableMetrics, r.QueryRestrictedJoinAllowed), ) internalAPIMux.Handle( RoomserverQueryMembershipAtEventPath, - httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", r.QueryMembershipAtEvent), + httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent), ) } diff --git a/roomserver/roomserver.go b/roomserver/roomserver.go index 1f707735b..0f6b48bf9 100644 --- a/roomserver/roomserver.go +++ b/roomserver/roomserver.go @@ -16,18 +16,19 @@ package roomserver import ( "github.com/gorilla/mux" + "github.com/sirupsen/logrus" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal" "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/base" - "github.com/sirupsen/logrus" ) // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI) { - inthttp.AddRoutes(intAPI, router) +func AddInternalRoutes(router *mux.Router, intAPI api.RoomserverInternalAPI, enableMetrics bool) { + inthttp.AddRoutes(intAPI, router, enableMetrics) } // NewInternalAPI returns a concerete implementation of the internal API. Callers diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 661fecfae..f0579079f 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -16,176 +16,177 @@ package inthttp import ( "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" ) // nolint: gocyclo -func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { - addRoutesLoginToken(internalAPIMux, s) +func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics bool) { + addRoutesLoginToken(internalAPIMux, s, enableMetrics) internalAPIMux.Handle( PerformAccountCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", s.PerformAccountCreation), + httputil.MakeInternalRPCAPI("UserAPIPerformAccountCreation", enableMetrics, s.PerformAccountCreation), ) internalAPIMux.Handle( PerformPasswordUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", s.PerformPasswordUpdate), + httputil.MakeInternalRPCAPI("UserAPIPerformPasswordUpdate", enableMetrics, s.PerformPasswordUpdate), ) internalAPIMux.Handle( PerformDeviceCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", s.PerformDeviceCreation), + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceCreation", enableMetrics, s.PerformDeviceCreation), ) internalAPIMux.Handle( PerformLastSeenUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", s.PerformLastSeenUpdate), + httputil.MakeInternalRPCAPI("UserAPIPerformLastSeenUpdate", enableMetrics, s.PerformLastSeenUpdate), ) internalAPIMux.Handle( PerformDeviceUpdatePath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", s.PerformDeviceUpdate), + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceUpdate", enableMetrics, s.PerformDeviceUpdate), ) internalAPIMux.Handle( PerformDeviceDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", s.PerformDeviceDeletion), + httputil.MakeInternalRPCAPI("UserAPIPerformDeviceDeletion", enableMetrics, s.PerformDeviceDeletion), ) internalAPIMux.Handle( PerformAccountDeactivationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", s.PerformAccountDeactivation), + httputil.MakeInternalRPCAPI("UserAPIPerformAccountDeactivation", enableMetrics, s.PerformAccountDeactivation), ) internalAPIMux.Handle( PerformOpenIDTokenCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", s.PerformOpenIDTokenCreation), + httputil.MakeInternalRPCAPI("UserAPIPerformOpenIDTokenCreation", enableMetrics, s.PerformOpenIDTokenCreation), ) internalAPIMux.Handle( QueryProfilePath, - httputil.MakeInternalRPCAPI("UserAPIQueryProfile", s.QueryProfile), + httputil.MakeInternalRPCAPI("UserAPIQueryProfile", enableMetrics, s.QueryProfile), ) internalAPIMux.Handle( QueryAccessTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", s.QueryAccessToken), + httputil.MakeInternalRPCAPI("UserAPIQueryAccessToken", enableMetrics, s.QueryAccessToken), ) internalAPIMux.Handle( QueryDevicesPath, - httputil.MakeInternalRPCAPI("UserAPIQueryDevices", s.QueryDevices), + httputil.MakeInternalRPCAPI("UserAPIQueryDevices", enableMetrics, s.QueryDevices), ) internalAPIMux.Handle( QueryAccountDataPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", s.QueryAccountData), + httputil.MakeInternalRPCAPI("UserAPIQueryAccountData", enableMetrics, s.QueryAccountData), ) internalAPIMux.Handle( QueryDeviceInfosPath, - httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", s.QueryDeviceInfos), + httputil.MakeInternalRPCAPI("UserAPIQueryDeviceInfos", enableMetrics, s.QueryDeviceInfos), ) internalAPIMux.Handle( QuerySearchProfilesPath, - httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", s.QuerySearchProfiles), + httputil.MakeInternalRPCAPI("UserAPIQuerySearchProfiles", enableMetrics, s.QuerySearchProfiles), ) internalAPIMux.Handle( QueryOpenIDTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", s.QueryOpenIDToken), + httputil.MakeInternalRPCAPI("UserAPIQueryOpenIDToken", enableMetrics, s.QueryOpenIDToken), ) internalAPIMux.Handle( InputAccountDataPath, - httputil.MakeInternalRPCAPI("UserAPIInputAccountData", s.InputAccountData), + httputil.MakeInternalRPCAPI("UserAPIInputAccountData", enableMetrics, s.InputAccountData), ) internalAPIMux.Handle( QueryKeyBackupPath, - httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", s.QueryKeyBackup), + httputil.MakeInternalRPCAPI("UserAPIQueryKeyBackup", enableMetrics, s.QueryKeyBackup), ) internalAPIMux.Handle( PerformKeyBackupPath, - httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", s.PerformKeyBackup), + httputil.MakeInternalRPCAPI("UserAPIPerformKeyBackup", enableMetrics, s.PerformKeyBackup), ) internalAPIMux.Handle( QueryNotificationsPath, - httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", s.QueryNotifications), + httputil.MakeInternalRPCAPI("UserAPIQueryNotifications", enableMetrics, s.QueryNotifications), ) internalAPIMux.Handle( PerformPusherSetPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", s.PerformPusherSet), + httputil.MakeInternalRPCAPI("UserAPIPerformPusherSet", enableMetrics, s.PerformPusherSet), ) internalAPIMux.Handle( PerformPusherDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", s.PerformPusherDeletion), + httputil.MakeInternalRPCAPI("UserAPIPerformPusherDeletion", enableMetrics, s.PerformPusherDeletion), ) internalAPIMux.Handle( QueryPushersPath, - httputil.MakeInternalRPCAPI("UserAPIQueryPushers", s.QueryPushers), + httputil.MakeInternalRPCAPI("UserAPIQueryPushers", enableMetrics, s.QueryPushers), ) internalAPIMux.Handle( PerformPushRulesPutPath, - httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", s.PerformPushRulesPut), + httputil.MakeInternalRPCAPI("UserAPIPerformPushRulesPut", enableMetrics, s.PerformPushRulesPut), ) internalAPIMux.Handle( QueryPushRulesPath, - httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", s.QueryPushRules), + httputil.MakeInternalRPCAPI("UserAPIQueryPushRules", enableMetrics, s.QueryPushRules), ) internalAPIMux.Handle( PerformSetAvatarURLPath, - httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", s.SetAvatarURL), + httputil.MakeInternalRPCAPI("UserAPIPerformSetAvatarURL", enableMetrics, s.SetAvatarURL), ) internalAPIMux.Handle( QueryNumericLocalpartPath, - httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", s.QueryNumericLocalpart), + httputil.MakeInternalRPCAPI("UserAPIQueryNumericLocalpart", enableMetrics, s.QueryNumericLocalpart), ) internalAPIMux.Handle( QueryAccountAvailabilityPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", s.QueryAccountAvailability), + httputil.MakeInternalRPCAPI("UserAPIQueryAccountAvailability", enableMetrics, s.QueryAccountAvailability), ) internalAPIMux.Handle( QueryAccountByPasswordPath, - httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", s.QueryAccountByPassword), + httputil.MakeInternalRPCAPI("UserAPIQueryAccountByPassword", enableMetrics, s.QueryAccountByPassword), ) internalAPIMux.Handle( PerformSetDisplayNamePath, - httputil.MakeInternalRPCAPI("UserAPISetDisplayName", s.SetDisplayName), + httputil.MakeInternalRPCAPI("UserAPISetDisplayName", enableMetrics, s.SetDisplayName), ) internalAPIMux.Handle( QueryLocalpartForThreePIDPath, - httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", s.QueryLocalpartForThreePID), + httputil.MakeInternalRPCAPI("UserAPIQueryLocalpartForThreePID", enableMetrics, s.QueryLocalpartForThreePID), ) internalAPIMux.Handle( QueryThreePIDsForLocalpartPath, - httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", s.QueryThreePIDsForLocalpart), + httputil.MakeInternalRPCAPI("UserAPIQueryThreePIDsForLocalpart", enableMetrics, s.QueryThreePIDsForLocalpart), ) internalAPIMux.Handle( PerformForgetThreePIDPath, - httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", s.PerformForgetThreePID), + httputil.MakeInternalRPCAPI("UserAPIPerformForgetThreePID", enableMetrics, s.PerformForgetThreePID), ) internalAPIMux.Handle( PerformSaveThreePIDAssociationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", s.PerformSaveThreePIDAssociation), + httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), ) } diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go index b57348413..dc116428b 100644 --- a/userapi/inthttp/server_logintoken.go +++ b/userapi/inthttp/server_logintoken.go @@ -16,24 +16,25 @@ package inthttp import ( "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/userapi/api" ) // addRoutesLoginToken adds routes for all login token API calls. -func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { +func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics bool) { internalAPIMux.Handle( PerformLoginTokenCreationPath, - httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", s.PerformLoginTokenCreation), + httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenCreation", enableMetrics, s.PerformLoginTokenCreation), ) internalAPIMux.Handle( PerformLoginTokenDeletionPath, - httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", s.PerformLoginTokenDeletion), + httputil.MakeInternalRPCAPI("UserAPIPerformLoginTokenDeletion", enableMetrics, s.PerformLoginTokenDeletion), ) internalAPIMux.Handle( QueryLoginTokenPath, - httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", s.QueryLoginToken), + httputil.MakeInternalRPCAPI("UserAPIQueryLoginToken", enableMetrics, s.QueryLoginToken), ) } diff --git a/userapi/userapi.go b/userapi/userapi.go index e46a8e76e..183ca3123 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -37,8 +37,8 @@ import ( // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. -func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { - inthttp.AddRoutes(router, intAPI) +func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI, enableMetrics bool) { + inthttp.AddRoutes(router, intAPI, enableMetrics) } // NewInternalAPI returns a concerete implementation of the internal API. Callers diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 25fa75ee2..60dd730fd 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -144,7 +144,7 @@ func TestQueryProfile(t *testing.T) { t.Run("HTTP API", func(t *testing.T) { router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI) + userapi.AddInternalRoutes(router, userAPI, false) apiURL, cancel := test.ListenAndServe(t, router, false) defer cancel() httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) From 07e8ed13f61fb04f6cac5a6d42263a9ddba49d32 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 5 Dec 2022 15:09:59 +0100 Subject: [PATCH 16/67] Fix CI and test.WithAllDatabases --- appservice/appservice_test.go | 79 +++++++++++++++++------------------ 1 file changed, 39 insertions(+), 40 deletions(-) diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 5a3a9aef7..72910d8d1 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -77,32 +77,6 @@ func TestAppserviceInternalAPI(t *testing.T) { } })) - // TODO: use test.WithAllDatabases - // only one DBType, since appservice.AddInternalRoutes complains about multiple prometheus counters added - base, closeBase := testrig.CreateBaseDendrite(t, test.DBTypeSQLite) - defer closeBase() - - // Create a dummy application service - base.Cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ - { - ID: "someID", - URL: srv.URL, - ASToken: "", - HSToken: "", - SenderLocalpart: "senderLocalPart", - NamespaceMap: map[string][]config.ApplicationServiceNamespace{ - "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, - "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, - }, - Protocols: []string{existingProtocol}, - }, - } - - // Create required internal APIs - rsAPI := roomserver.NewInternalAPI(base) - usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) - asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) - // The test cases to run runCases := func(t *testing.T, testAPI api.AppServiceInternalAPI) { t.Run("UserIDExists", func(t *testing.T) { @@ -133,24 +107,49 @@ func TestAppserviceInternalAPI(t *testing.T) { }) } - // Finally execute the tests - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - appservice.AddInternalRoutes(router, asAPI) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, closeBase := testrig.CreateBaseDendrite(t, test.DBTypeSQLite) + defer closeBase() - asHTTPApi, err := inthttp.NewAppserviceClient(apiURL, &http.Client{}) - if err != nil { - t.Fatalf("failed to create HTTP client: %s", err) + // Create a dummy application service + base.Cfg.AppServiceAPI.Derived.ApplicationServices = []config.ApplicationService{ + { + ID: "someID", + URL: srv.URL, + ASToken: "", + HSToken: "", + SenderLocalpart: "senderLocalPart", + NamespaceMap: map[string][]config.ApplicationServiceNamespace{ + "users": {{RegexpObject: regexp.MustCompile("as-.*")}}, + "aliases": {{RegexpObject: regexp.MustCompile("asroom-.*")}}, + }, + Protocols: []string{existingProtocol}, + }, } - runCases(t, asHTTPApi) - }) - t.Run("Monolith", func(t *testing.T) { - runCases(t, asAPI) - }) + // Create required internal APIs + rsAPI := roomserver.NewInternalAPI(base) + usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, usrAPI, rsAPI) + // Finally execute the tests + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + appservice.AddInternalRoutes(router, asAPI, base.EnableMetrics) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + + asHTTPApi, err := inthttp.NewAppserviceClient(apiURL, &http.Client{}) + if err != nil { + t.Fatalf("failed to create HTTP client: %s", err) + } + runCases(t, asHTTPApi) + }) + + t.Run("Monolith", func(t *testing.T) { + runCases(t, asAPI) + }) + }) } func testUserIDExists(t *testing.T, asAPI api.AppServiceInternalAPI, userID string, wantExists bool) { From 0e6d94757b81b8e52479706ee5f857fc34023a5b Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 5 Dec 2022 15:24:36 +0100 Subject: [PATCH 17/67] Enforce coverage --- .github/codecov.yaml | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 .github/codecov.yaml diff --git a/.github/codecov.yaml b/.github/codecov.yaml new file mode 100644 index 000000000..e6a38a8be --- /dev/null +++ b/.github/codecov.yaml @@ -0,0 +1,13 @@ +flag_management: + default_rules: + carryforward: true + +coverage: + status: + project: + default: + target: 75% + threshold: 0% + base: auto + flags: + - unittests \ No newline at end of file From 3dc06bea81d58d26a2d52ffb04ceb04a8c70e928 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 5 Dec 2022 15:49:11 +0100 Subject: [PATCH 18/67] Differentiate between project and patch --- .github/codecov.yaml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/codecov.yaml b/.github/codecov.yaml index e6a38a8be..78122c990 100644 --- a/.github/codecov.yaml +++ b/.github/codecov.yaml @@ -5,6 +5,13 @@ flag_management: coverage: status: project: + default: + target: auto + threshold: 0% + base: auto + flags: + - unittests + patch: default: target: 75% threshold: 0% From b99349b18c28a1c27b5bd5df30853a3b7c689d02 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Mon, 5 Dec 2022 16:00:02 +0100 Subject: [PATCH 19/67] Use test.WithAllDatabases --- appservice/appservice_test.go | 2 +- userapi/userapi_test.go | 55 ++++++++++++++++++----------------- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/appservice/appservice_test.go b/appservice/appservice_test.go index 72910d8d1..83c551fea 100644 --- a/appservice/appservice_test.go +++ b/appservice/appservice_test.go @@ -108,7 +108,7 @@ func TestAppserviceInternalAPI(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, closeBase := testrig.CreateBaseDendrite(t, test.DBTypeSQLite) + base, closeBase := testrig.CreateBaseDendrite(t, dbType) defer closeBase() // Create a dummy application service diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 60dd730fd..8a19af195 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -27,14 +27,13 @@ import ( "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" - "github.com/matrix-org/dendrite/userapi/inthttp" - - "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/internal" + "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage" ) @@ -79,19 +78,6 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - // only one DBType, since userapi.AddInternalRoutes complains about multiple prometheus counters added - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, test.DBTypeSQLite) - defer close() - _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) - if err != nil { - t.Fatalf("failed to make account: %s", err) - } - if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { - t.Fatalf("failed to set avatar url: %s", err) - } - if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { - t.Fatalf("failed to set display name: %s", err) - } testCases := []struct { req api.QueryProfileRequest @@ -142,19 +128,34 @@ func TestQueryProfile(t *testing.T) { } } - t.Run("HTTP API", func(t *testing.T) { - router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI, false) - apiURL, cancel := test.ListenAndServe(t, router, false) - defer cancel() - httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() + _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) if err != nil { - t.Fatalf("failed to create HTTP client") + t.Fatalf("failed to make account: %s", err) } - runCases(httpAPI, true) - }) - t.Run("Monolith", func(t *testing.T) { - runCases(userAPI, false) + if _, _, err := accountDB.SetAvatarURL(context.TODO(), "alice", serverName, aliceAvatarURL); err != nil { + t.Fatalf("failed to set avatar url: %s", err) + } + if _, _, err := accountDB.SetDisplayName(context.TODO(), "alice", serverName, aliceDisplayName); err != nil { + t.Fatalf("failed to set display name: %s", err) + } + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + userapi.AddInternalRoutes(router, userAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) + if err != nil { + t.Fatalf("failed to create HTTP client") + } + runCases(httpAPI, true) + }) + t.Run("Monolith", func(t *testing.T) { + runCases(userAPI, false) + }) }) } From 75834783055b0c70f8b411d9c3741e57461832f0 Mon Sep 17 00:00:00 2001 From: kegsay Date: Mon, 5 Dec 2022 16:54:01 +0000 Subject: [PATCH 20/67] Update contributing guidelines (#2904) --- docs/CONTRIBUTING.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/docs/CONTRIBUTING.md b/docs/CONTRIBUTING.md index 262a93a7c..21b0e7ab1 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/CONTRIBUTING.md @@ -9,6 +9,28 @@ permalink: /development/contributing Everyone is welcome to contribute to Dendrite! We aim to make it as easy as possible to get started. + ## Contribution types + +We are a small team maintaining a large project. As a result, we cannot merge every feature, even if it +is bug-free and useful, because we then commit to maintaining it indefinitely. We will always accept: + - bug fixes + - security fixes (please responsibly disclose via security@matrix.org *before* creating pull requests) + +We will accept the following with caveats: + - documentation fixes, provided they do not add additional instructions which can end up going out-of-date, + e.g example configs, shell commands. + - performance fixes, provided they do not add significantly more maintenance burden. + - additional functionality on existing features, provided the functionality is small and maintainable. + - additional functionality that, in its absence, would impact the ecosystem e.g spam and abuse mitigations + - test-only changes, provided they help improve coverage or test tricky code. + +The following items are at risk of not being accepted: + - Configuration or CLI changes, particularly ones which increase the overall configuration surface. + +The following items are unlikely to be accepted into a main Dendrite release for now: + - New MSC implementations. + - New features which are not in the specification. + ## Sign off We require that everyone who contributes to the project signs off their contributions From ded43e0f2d07adc399da5962454d9873021b59ac Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Tue, 6 Dec 2022 13:27:33 +0100 Subject: [PATCH 21/67] Fix issue with sending presence events to invalid servers --- federationapi/consumers/roomserver.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index d16af6626..0c1080afa 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -232,7 +232,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } func (s *OutputRoomEventConsumer) sendPresence(roomID string, addedJoined []types.JoinedHost) { - joined := make([]gomatrixserverlib.ServerName, len(addedJoined)) + joined := make([]gomatrixserverlib.ServerName, 0, len(addedJoined)) for _, added := range addedJoined { joined = append(joined, added.ServerName) } From ba2ffb7da9b86b64dc8091d90645ab80cf2831db Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 6 Dec 2022 18:16:17 +0000 Subject: [PATCH 22/67] Repeatable reads for `/sync` (#2783) This puts repeatable reads into all sync streams. Co-authored-by: kegsay --- syncapi/storage/shared/storage_consumer.go | 42 +++++++++------------- 1 file changed, 17 insertions(+), 25 deletions(-) diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index 23f53d11f..f2064fb89 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -57,31 +57,23 @@ type Database struct { } func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) { - return d.NewDatabaseTransaction(ctx) - - /* - TODO: Repeatable read is probably the right thing to do here, - but it seems to cause some problems with the invite tests, so - need to investigate that further. - - txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ - // Set the isolation level so that we see a snapshot of the database. - // In PostgreSQL repeatable read transactions will see a snapshot taken - // at the first query, and since the transaction is read-only it can't - // run into any serialisation errors. - // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ - Isolation: sql.LevelRepeatableRead, - ReadOnly: true, - }) - if err != nil { - return nil, err - } - return &DatabaseTransaction{ - Database: d, - ctx: ctx, - txn: txn, - }, nil - */ + txn, err := d.DB.BeginTx(ctx, &sql.TxOptions{ + // Set the isolation level so that we see a snapshot of the database. + // In PostgreSQL repeatable read transactions will see a snapshot taken + // at the first query, and since the transaction is read-only it can't + // run into any serialisation errors. + // https://www.postgresql.org/docs/9.5/static/transaction-iso.html#XACT-REPEATABLE-READ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + }) + if err != nil { + return nil, err + } + return &DatabaseTransaction{ + Database: d, + ctx: ctx, + txn: txn, + }, nil } func (d *Database) NewDatabaseTransaction(ctx context.Context) (*DatabaseTransaction, error) { From 27a1dea5225c828ad787098aadafa83ed73c4940 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 8 Dec 2022 08:24:06 +0100 Subject: [PATCH 23/67] Fix issue with multiple/duplicate log entries during tests (#2906) --- internal/log.go | 5 +++++ internal/log_unix.go | 13 +++++++------ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/internal/log.go b/internal/log.go index a171555ab..d7e852c81 100644 --- a/internal/log.go +++ b/internal/log.go @@ -33,6 +33,11 @@ import ( "github.com/matrix-org/dendrite/setup/config" ) +// logrus is using a global variable when we're using `logrus.AddHook` +// this unfortunately results in us adding the same hook multiple times. +// This map ensures we only ever add one level hook. +var stdLevelLogAdded = make(map[logrus.Level]bool) + type utcFormatter struct { logrus.Formatter } diff --git a/internal/log_unix.go b/internal/log_unix.go index 75332af73..b38e7c2e8 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -22,16 +22,16 @@ import ( "log/syslog" "github.com/MFAshby/stdemuxerhook" - "github.com/matrix-org/dendrite/setup/config" "github.com/sirupsen/logrus" lSyslog "github.com/sirupsen/logrus/hooks/syslog" + + "github.com/matrix-org/dendrite/setup/config" ) // SetupHookLogging configures the logging hooks defined in the configuration. // If something fails here it means that the logging was improperly configured, // so we just exit with the error func SetupHookLogging(hooks []config.LogrusHook, componentName string) { - stdLogAdded := false for _, hook := range hooks { // Check we received a proper logging level level, err := logrus.ParseLevel(hook.Level) @@ -54,14 +54,11 @@ func SetupHookLogging(hooks []config.LogrusHook, componentName string) { setupSyslogHook(hook, level, componentName) case "std": setupStdLogHook(level) - stdLogAdded = true default: logrus.Fatalf("Unrecognised logging hook type: %s", hook.Type) } } - if !stdLogAdded { - setupStdLogHook(logrus.InfoLevel) - } + setupStdLogHook(logrus.InfoLevel) // Hooks are now configured for stdout/err, so throw away the default logger output logrus.SetOutput(io.Discard) } @@ -88,7 +85,11 @@ func checkSyslogHookParams(params map[string]interface{}) { } func setupStdLogHook(level logrus.Level) { + if stdLevelLogAdded[level] { + return + } logrus.AddHook(&logLevelHook{level, stdemuxerhook.New(logrus.StandardLogger())}) + stdLevelLogAdded[level] = true } func setupSyslogHook(hook config.LogrusHook, level logrus.Level, componentName string) { From 0351618ff4e7d569e14a165be59a1a7e9e979684 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 8 Dec 2022 08:24:24 +0100 Subject: [PATCH 24/67] Add UserAPI util tests (#2907) This adds some `userapi/util` tests. --- test/testrig/base.go | 2 +- userapi/util/notify_test.go | 119 ++++++++++++++++++++++++++++ userapi/util/phonehomestats.go | 10 +-- userapi/util/phonehomestats_test.go | 84 ++++++++++++++++++++ 4 files changed, 208 insertions(+), 7 deletions(-) create mode 100644 userapi/util/notify_test.go create mode 100644 userapi/util/phonehomestats_test.go diff --git a/test/testrig/base.go b/test/testrig/base.go index 15fb5c370..7bc26a5c5 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -108,7 +108,7 @@ func Base(cfg *config.Dendrite) (*base.BaseDendrite, nats.JetStreamContext, *nat cfg.Global.JetStream.InMemory = true cfg.SyncAPI.Fulltext.InMemory = true cfg.FederationAPI.KeyPerspectives = nil - base := base.NewBaseDendrite(cfg, "Tests") + base := base.NewBaseDendrite(cfg, "Tests", base.DisableMetrics) js, jc := base.NATS.Prepare(base.ProcessContext, &cfg.Global.JetStream) return base, js, jc } diff --git a/userapi/util/notify_test.go b/userapi/util/notify_test.go new file mode 100644 index 000000000..f1d20259c --- /dev/null +++ b/userapi/util/notify_test.go @@ -0,0 +1,119 @@ +package util_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/internal/pushgateway" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/userapi/storage" + userUtil "github.com/matrix-org/dendrite/userapi/util" +) + +func TestNotifyUserCountsAsync(t *testing.T) { + alice := test.NewUser(t) + aliceLocalpart, serverName, err := gomatrixserverlib.SplitID('@', alice.ID) + if err != nil { + t.Error(err) + } + ctx := context.Background() + + // Create a test room, just used to provide events + room := test.NewRoom(t, alice) + dummyEvent := room.Events()[len(room.Events())-1] + + appID := util.RandomString(8) + pushKey := util.RandomString(8) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + receivedRequest := make(chan bool, 1) + // create a test server which responds to our /notify call + srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data pushgateway.NotifyRequest + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + t.Error(err) + } + notification := data.Notification + // Validate the request + if notification.Counts == nil { + t.Fatal("no unread notification counts in request") + } + if unread := notification.Counts.Unread; unread != 1 { + t.Errorf("expected one unread notification, got %d", unread) + } + + if len(notification.Devices) == 0 { + t.Fatal("expected devices in request") + } + + // We only created one push device, so access it directly + device := notification.Devices[0] + if device.AppID != appID { + t.Errorf("unexpected app_id: %s, want %s", device.AppID, appID) + } + if device.PushKey != pushKey { + t.Errorf("unexpected push_key: %s, want %s", device.PushKey, pushKey) + } + + // Return empty result, otherwise the call is handled as failed + if _, err := w.Write([]byte("{}")); err != nil { + t.Error(err) + } + close(receivedRequest) + })) + defer srv.Close() + + // Create DB and Dendrite base + connStr, close := test.PrepareDBConnectionString(t, dbType) + defer close() + base, _, _ := testrig.Base(nil) + defer base.Close() + db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, "test", bcrypt.MinCost, 0, 0, "") + if err != nil { + t.Error(err) + } + + // Prepare pusher with our test server URL + if err := db.UpsertPusher(ctx, api.Pusher{ + Kind: api.HTTPKind, + AppID: appID, + PushKey: pushKey, + Data: map[string]interface{}{ + "url": srv.URL, + }, + }, aliceLocalpart, serverName); err != nil { + t.Error(err) + } + + // Insert a dummy event + if err := db.InsertNotification(ctx, aliceLocalpart, serverName, dummyEvent.EventID(), 0, nil, &api.Notification{ + Event: gomatrixserverlib.HeaderedToClientEvent(dummyEvent, gomatrixserverlib.FormatAll), + }); err != nil { + t.Error(err) + } + + // Notify the user about a new notification + if err := userUtil.NotifyUserCountsAsync(ctx, pushgateway.NewHTTPClient(true), aliceLocalpart, serverName, db); err != nil { + t.Error(err) + } + select { + case <-time.After(time.Second * 5): + t.Error("timed out waiting for response") + case <-receivedRequest: + } + }) + +} diff --git a/userapi/util/phonehomestats.go b/userapi/util/phonehomestats.go index 6f36568c9..42c8f5d7c 100644 --- a/userapi/util/phonehomestats.go +++ b/userapi/util/phonehomestats.go @@ -97,12 +97,10 @@ func (p *phoneHomeStats) collect() { // configuration information p.stats["federation_disabled"] = p.cfg.Global.DisableFederation - p.stats["nats_embedded"] = true - p.stats["nats_in_memory"] = p.cfg.Global.JetStream.InMemory - if len(p.cfg.Global.JetStream.Addresses) > 0 { - p.stats["nats_embedded"] = false - p.stats["nats_in_memory"] = false // probably - } + natsEmbedded := len(p.cfg.Global.JetStream.Addresses) == 0 + p.stats["nats_embedded"] = natsEmbedded + p.stats["nats_in_memory"] = p.cfg.Global.JetStream.InMemory && natsEmbedded + if len(p.cfg.Logging) > 0 { p.stats["log_level"] = p.cfg.Logging[0].Level } else { diff --git a/userapi/util/phonehomestats_test.go b/userapi/util/phonehomestats_test.go new file mode 100644 index 000000000..6e62210e8 --- /dev/null +++ b/userapi/util/phonehomestats_test.go @@ -0,0 +1,84 @@ +package util + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "golang.org/x/crypto/bcrypt" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi/storage" +) + +func TestCollect(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + b, _, _ := testrig.Base(nil) + connStr, closeDB := test.PrepareDBConnectionString(t, dbType) + defer closeDB() + db, err := storage.NewUserAPIDatabase(b, &config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, "localhost", bcrypt.MinCost, 1000, 1000, "") + if err != nil { + t.Error(err) + } + + receivedRequest := make(chan struct{}, 1) + // create a test server which responds to our call + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var data map[string]interface{} + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + t.Error(err) + } + defer r.Body.Close() + if _, err := w.Write([]byte("{}")); err != nil { + t.Error(err) + } + + // verify the received data matches our expectations + dbEngine, ok := data["database_engine"] + if !ok { + t.Errorf("missing database_engine in JSON request: %+v", data) + } + version, ok := data["version"] + if !ok { + t.Errorf("missing version in JSON request: %+v", data) + } + if version != internal.VersionString() { + t.Errorf("unexpected version: %q, expected %q", version, internal.VersionString()) + } + switch { + case dbType == test.DBTypeSQLite && dbEngine != "SQLite": + t.Errorf("unexpected database_engine: %s", dbEngine) + case dbType == test.DBTypePostgres && dbEngine != "Postgres": + t.Errorf("unexpected database_engine: %s", dbEngine) + } + close(receivedRequest) + })) + defer srv.Close() + + b.Cfg.Global.ReportStats.Endpoint = srv.URL + stats := phoneHomeStats{ + prevData: timestampToRUUsage{}, + serverName: "localhost", + startTime: time.Now(), + cfg: b.Cfg, + db: db, + isMonolith: false, + client: &http.Client{Timeout: time.Second}, + } + + stats.collect() + + select { + case <-time.After(time.Second * 5): + t.Error("timed out waiting for response") + case <-receivedRequest: + } + }) +} From c136a450d5196cf22a91419f493bb73c29481122 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 8 Dec 2022 08:25:03 +0100 Subject: [PATCH 25/67] Fix newly joined users presence (#2854) Fixes #2803 Also refactors the presence stream to not hit the database for every user, instead queries all users at once now. --- syncapi/consumers/presence.go | 12 +- syncapi/storage/interface.go | 4 +- syncapi/storage/postgres/presence_table.go | 38 +++-- syncapi/storage/shared/storage_consumer.go | 4 +- syncapi/storage/shared/storage_sync.go | 4 +- syncapi/storage/sqlite3/presence_table.go | 48 +++++-- syncapi/storage/tables/interface.go | 2 +- syncapi/storage/tables/presence_table_test.go | 136 ++++++++++++++++++ syncapi/streams/stream_presence.go | 80 +++++++---- syncapi/sync/requestpool.go | 6 +- syncapi/sync/requestpool_test.go | 4 +- 11 files changed, 263 insertions(+), 75 deletions(-) create mode 100644 syncapi/storage/tables/presence_table_test.go diff --git a/syncapi/consumers/presence.go b/syncapi/consumers/presence.go index 145059c2d..6e3150c29 100644 --- a/syncapi/consumers/presence.go +++ b/syncapi/consumers/presence.go @@ -78,7 +78,7 @@ func (s *PresenceConsumer) Start() error { // Normal NATS subscription, used by Request/Reply _, err := s.nats.Subscribe(s.requestTopic, func(msg *nats.Msg) { userID := msg.Header.Get(jetstream.UserID) - presence, err := s.db.GetPresence(context.Background(), userID) + presences, err := s.db.GetPresences(context.Background(), []string{userID}) m := &nats.Msg{ Header: nats.Header{}, } @@ -89,10 +89,12 @@ func (s *PresenceConsumer) Start() error { } return } - if presence == nil { - presence = &types.PresenceInternal{ - UserID: userID, - } + + presence := &types.PresenceInternal{ + UserID: userID, + } + if len(presences) > 0 { + presence = presences[0] } deviceRes := api.QueryDevicesResponse{} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 97c2ced49..75afbce15 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -106,7 +106,7 @@ type DatabaseTransaction interface { SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) // getUserUnreadNotificationCountsForRooms returns the unread notifications for the given rooms GetUserUnreadNotificationCountsForRooms(ctx context.Context, userID string, roomIDs map[string]string) (map[string]*eventutil.NotificationData, error) - GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (events []types.StreamEvent, prevBatch, nextBatch string, err error) } @@ -186,7 +186,7 @@ type Database interface { } type Presence interface { - GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) + GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) UpdatePresence(ctx context.Context, userID string, presence types.Presence, statusMsg *string, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (types.StreamPosition, error) } diff --git a/syncapi/storage/postgres/presence_table.go b/syncapi/storage/postgres/presence_table.go index 7194afea6..a3f7c5213 100644 --- a/syncapi/storage/postgres/presence_table.go +++ b/syncapi/storage/postgres/presence_table.go @@ -19,10 +19,12 @@ import ( "database/sql" "time" + "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const presenceSchema = ` @@ -63,9 +65,9 @@ const upsertPresenceFromSyncSQL = "" + " RETURNING id" const selectPresenceForUserSQL = "" + - "SELECT presence, status_msg, last_active_ts" + + "SELECT user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE user_id = $1 LIMIT 1" + " WHERE user_id = ANY($1)" const selectMaxPresenceSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" @@ -119,20 +121,28 @@ func (p *presenceStatements) UpsertPresence( return } -// GetPresenceForUser returns the current presence of a user. -func (p *presenceStatements) GetPresenceForUser( +// GetPresenceForUsers returns the current presence for a list of users. +// If the user doesn't have a presence status yet, it is omitted from the response. +func (p *presenceStatements) GetPresenceForUsers( ctx context.Context, txn *sql.Tx, - userID string, -) (*types.PresenceInternal, error) { - result := &types.PresenceInternal{ - UserID: userID, - } + userIDs []string, +) ([]*types.PresenceInternal, error) { + result := make([]*types.PresenceInternal, 0, len(userIDs)) stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) - err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) - if err == sql.ErrNoRows { - return nil, nil + rows, err := stmt.QueryContext(ctx, pq.Array(userIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed") + + for rows.Next() { + presence := &types.PresenceInternal{} + if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil { + return nil, err + } + presence.ClientFields.Presence = presence.Presence.String() + result = append(result, presence) } - result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index f2064fb89..df2338cf8 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -564,8 +564,8 @@ func (d *Database) UpdatePresence(ctx context.Context, userID string, presence t return pos, err } -func (d *Database) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, nil, userID) +func (d *Database) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUsers(ctx, nil, userIDs) } func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) { diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index c3763521c..77afa0290 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -596,8 +596,8 @@ func (d *DatabaseTransaction) GetUserUnreadNotificationCountsForRooms(ctx contex return d.NotificationData.SelectUserUnreadCountsForRooms(ctx, d.txn, userID, roomIDs) } -func (d *DatabaseTransaction) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return d.Presence.GetPresenceForUser(ctx, d.txn, userID) +func (d *DatabaseTransaction) GetPresences(ctx context.Context, userIDs []string) ([]*types.PresenceInternal, error) { + return d.Presence.GetPresenceForUsers(ctx, d.txn, userIDs) } func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { diff --git a/syncapi/storage/sqlite3/presence_table.go b/syncapi/storage/sqlite3/presence_table.go index b61a825df..7641de92f 100644 --- a/syncapi/storage/sqlite3/presence_table.go +++ b/syncapi/storage/sqlite3/presence_table.go @@ -17,12 +17,14 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const presenceSchema = ` @@ -62,9 +64,9 @@ const upsertPresenceFromSyncSQL = "" + " RETURNING id" const selectPresenceForUserSQL = "" + - "SELECT presence, status_msg, last_active_ts" + + "SELECT user_id, presence, status_msg, last_active_ts" + " FROM syncapi_presence" + - " WHERE user_id = $1 LIMIT 1" + " WHERE user_id IN ($1)" const selectMaxPresenceSQL = "" + "SELECT COALESCE(MAX(id), 0) FROM syncapi_presence" @@ -134,20 +136,38 @@ func (p *presenceStatements) UpsertPresence( return } -// GetPresenceForUser returns the current presence of a user. -func (p *presenceStatements) GetPresenceForUser( +// GetPresenceForUsers returns the current presence for a list of users. +// If the user doesn't have a presence status yet, it is omitted from the response. +func (p *presenceStatements) GetPresenceForUsers( ctx context.Context, txn *sql.Tx, - userID string, -) (*types.PresenceInternal, error) { - result := &types.PresenceInternal{ - UserID: userID, + userIDs []string, +) ([]*types.PresenceInternal, error) { + qry := strings.Replace(selectPresenceForUserSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + prepStmt, err := p.db.Prepare(qry) + if err != nil { + return nil, err } - stmt := sqlutil.TxStmt(txn, p.selectPresenceForUsersStmt) - err := stmt.QueryRowContext(ctx, userID).Scan(&result.Presence, &result.ClientFields.StatusMsg, &result.LastActiveTS) - if err == sql.ErrNoRows { - return nil, nil + defer internal.CloseAndLogIfError(ctx, prepStmt, "GetPresenceForUsers: stmt.close() failed") + + params := make([]interface{}, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + rows, err := sqlutil.TxStmt(txn, prepStmt).QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "GetPresenceForUsers: rows.close() failed") + result := make([]*types.PresenceInternal, 0, len(userIDs)) + for rows.Next() { + presence := &types.PresenceInternal{} + if err = rows.Scan(&presence.UserID, &presence.Presence, &presence.ClientFields.StatusMsg, &presence.LastActiveTS); err != nil { + return nil, err + } + presence.ClientFields.Presence = presence.Presence.String() + result = append(result, presence) } - result.ClientFields.Presence = result.Presence.String() return result, err } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 2c4f04ec2..a0574b257 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -207,7 +207,7 @@ type Ignores interface { type Presence interface { UpsertPresence(ctx context.Context, txn *sql.Tx, userID string, statusMsg *string, presence types.Presence, lastActiveTS gomatrixserverlib.Timestamp, fromSync bool) (pos types.StreamPosition, err error) - GetPresenceForUser(ctx context.Context, txn *sql.Tx, userID string) (presence *types.PresenceInternal, err error) + GetPresenceForUsers(ctx context.Context, txn *sql.Tx, userIDs []string) (presence []*types.PresenceInternal, err error) GetMaxPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) GetPresenceAfter(ctx context.Context, txn *sql.Tx, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (presences map[string]*types.PresenceInternal, err error) } diff --git a/syncapi/storage/tables/presence_table_test.go b/syncapi/storage/tables/presence_table_test.go new file mode 100644 index 000000000..dce0c695a --- /dev/null +++ b/syncapi/storage/tables/presence_table_test.go @@ -0,0 +1,136 @@ +package tables_test + +import ( + "context" + "database/sql" + "reflect" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi/storage/postgres" + "github.com/matrix-org/dendrite/syncapi/storage/sqlite3" + "github.com/matrix-org/dendrite/syncapi/storage/tables" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" +) + +func mustPresenceTable(t *testing.T, dbType test.DBType) (tables.Presence, func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open db: %s", err) + } + + var tab tables.Presence + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresPresenceTable(db) + case test.DBTypeSQLite: + var stream sqlite3.StreamIDStatements + if err = stream.Prepare(db); err != nil { + t.Fatalf("failed to prepare stream stmts: %s", err) + } + tab, err = sqlite3.NewSqlitePresenceTable(db, &stream) + } + if err != nil { + t.Fatalf("failed to make new table: %s", err) + } + return tab, close +} + +func TestPresence(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + ctx := context.Background() + + statusMsg := "Hello World!" + timestamp := gomatrixserverlib.AsTimestamp(time.Now()) + + var txn *sql.Tx + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustPresenceTable(t, dbType) + defer closeDB() + + // Insert some presences + pos, err := tab.UpsertPresence(ctx, txn, alice.ID, &statusMsg, types.PresenceOnline, timestamp, false) + if err != nil { + t.Error(err) + } + wantPos := types.StreamPosition(1) + if pos != wantPos { + t.Errorf("expected pos to be %d, got %d", wantPos, pos) + } + pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, false) + if err != nil { + t.Error(err) + } + wantPos = 2 + if pos != wantPos { + t.Errorf("expected pos to be %d, got %d", wantPos, pos) + } + + // verify the expected max presence ID + maxPos, err := tab.GetMaxPresenceID(ctx, txn) + if err != nil { + t.Error(err) + } + if maxPos != wantPos { + t.Errorf("expected max pos to be %d, got %d", wantPos, maxPos) + } + + // This should increment the position + pos, err = tab.UpsertPresence(ctx, txn, bob.ID, &statusMsg, types.PresenceOnline, timestamp, true) + if err != nil { + t.Error(err) + } + wantPos = pos + if wantPos <= maxPos { + t.Errorf("expected pos to be %d incremented, got %d", wantPos, pos) + } + + // This should return only Bobs status + presences, err := tab.GetPresenceAfter(ctx, txn, maxPos, gomatrixserverlib.EventFilter{Limit: 10}) + if err != nil { + t.Error(err) + } + + if c := len(presences); c > 1 { + t.Errorf("expected only one presence, got %d", c) + } + + // Validate the response + wantPresence := &types.PresenceInternal{ + UserID: bob.ID, + Presence: types.PresenceOnline, + StreamPos: wantPos, + LastActiveTS: timestamp, + ClientFields: types.PresenceClientResponse{ + LastActiveAgo: 0, + Presence: types.PresenceOnline.String(), + StatusMsg: &statusMsg, + }, + } + if !reflect.DeepEqual(wantPresence, presences[bob.ID]) { + t.Errorf("unexpected presence result:\n%+v, want\n%+v", presences[bob.ID], wantPresence) + } + + // Try getting presences for existing and non-existing users + getUsers := []string{alice.ID, bob.ID, "@doesntexist:test"} + presencesForUsers, err := tab.GetPresenceForUsers(ctx, nil, getUsers) + if err != nil { + t.Error(err) + } + + if len(presencesForUsers) >= len(getUsers) { + t.Errorf("expected less presences, but they are the same/more as requested: %d >= %d", len(presencesForUsers), len(getUsers)) + } + }) + +} diff --git a/syncapi/streams/stream_presence.go b/syncapi/streams/stream_presence.go index 030b7c5d5..445e46b3a 100644 --- a/syncapi/streams/stream_presence.go +++ b/syncapi/streams/stream_presence.go @@ -17,6 +17,7 @@ package streams import ( "context" "encoding/json" + "fmt" "sync" "github.com/matrix-org/gomatrixserverlib" @@ -70,39 +71,25 @@ func (p *PresenceStreamProvider) IncrementalSync( return from } - if len(presences) == 0 { + getPresenceForUsers, err := p.getNeededUsersFromRequest(ctx, req, presences) + if err != nil { + req.Log.WithError(err).Error("getNeededUsersFromRequest failed") + return from + } + + // Got no presence between range and no presence to get from the database + if len(getPresenceForUsers) == 0 && len(presences) == 0 { return to } - // add newly joined rooms user presences - newlyJoined := joinedRooms(req.Response, req.Device.UserID) - if len(newlyJoined) > 0 { - // TODO: Check if this is working better than before. - if err = p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { - req.Log.WithError(err).Error("unable to refresh notifier lists") - return from - } - NewlyJoinedLoop: - for _, roomID := range newlyJoined { - roomUsers := p.notifier.JoinedUsers(roomID) - for i := range roomUsers { - // we already got a presence from this user - if _, ok := presences[roomUsers[i]]; ok { - continue - } - // Bear in mind that this might return nil, but at least populating - // a nil means that there's a map entry so we won't repeat this call. - presences[roomUsers[i]], err = snapshot.GetPresence(ctx, roomUsers[i]) - if err != nil { - req.Log.WithError(err).Error("unable to query presence for user") - _ = snapshot.Rollback() - return from - } - if len(presences) > req.Filter.Presence.Limit { - break NewlyJoinedLoop - } - } - } + dbPresences, err := snapshot.GetPresences(ctx, getPresenceForUsers) + if err != nil { + req.Log.WithError(err).Error("unable to query presence for user") + _ = snapshot.Rollback() + return from + } + for _, presence := range dbPresences { + presences[presence.UserID] = presence } lastPos := from @@ -164,6 +151,39 @@ func (p *PresenceStreamProvider) IncrementalSync( return lastPos } +func (p *PresenceStreamProvider) getNeededUsersFromRequest(ctx context.Context, req *types.SyncRequest, presences map[string]*types.PresenceInternal) ([]string, error) { + getPresenceForUsers := []string{} + // Add presence for users which newly joined a room + for userID := range req.MembershipChanges { + if _, ok := presences[userID]; ok { + continue + } + getPresenceForUsers = append(getPresenceForUsers, userID) + } + + // add newly joined rooms user presences + newlyJoined := joinedRooms(req.Response, req.Device.UserID) + if len(newlyJoined) == 0 { + return getPresenceForUsers, nil + } + + // TODO: Check if this is working better than before. + if err := p.notifier.LoadRooms(ctx, p.DB, newlyJoined); err != nil { + return getPresenceForUsers, fmt.Errorf("unable to refresh notifier lists: %w", err) + } + for _, roomID := range newlyJoined { + roomUsers := p.notifier.JoinedUsers(roomID) + for i := range roomUsers { + // we already got a presence from this user + if _, ok := presences[roomUsers[i]]; ok { + continue + } + getPresenceForUsers = append(getPresenceForUsers, roomUsers[i]) + } + } + return getPresenceForUsers, nil +} + func joinedRooms(res *types.Response, userID string) []string { var roomIDs []string for roomID, join := range res.Rooms.Join { diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index 29d92b293..b086567b8 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -145,12 +145,12 @@ func (rp *RequestPool) updatePresence(db storage.Presence, presence string, user } // ensure we also send the current status_msg to federated servers and not nil - dbPresence, err := db.GetPresence(context.Background(), userID) + dbPresence, err := db.GetPresences(context.Background(), []string{userID}) if err != nil && err != sql.ErrNoRows { return } - if dbPresence != nil { - newPresence.ClientFields = dbPresence.ClientFields + if len(dbPresence) > 0 && dbPresence[0] != nil { + newPresence.ClientFields = dbPresence[0].ClientFields } newPresence.ClientFields.Presence = presenceID.String() diff --git a/syncapi/sync/requestpool_test.go b/syncapi/sync/requestpool_test.go index 3e5769d8c..faa0b49c6 100644 --- a/syncapi/sync/requestpool_test.go +++ b/syncapi/sync/requestpool_test.go @@ -29,8 +29,8 @@ func (d dummyDB) UpdatePresence(ctx context.Context, userID string, presence typ return 0, nil } -func (d dummyDB) GetPresence(ctx context.Context, userID string) (*types.PresenceInternal, error) { - return &types.PresenceInternal{}, nil +func (d dummyDB) GetPresences(ctx context.Context, userID []string) ([]*types.PresenceInternal, error) { + return []*types.PresenceInternal{}, nil } func (d dummyDB) PresenceAfter(ctx context.Context, after types.StreamPosition, filter gomatrixserverlib.EventFilter) (map[string]*types.PresenceInternal, error) { From 8846de7312d447cd2866236493b6ae172fbebfa9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 8 Dec 2022 10:19:55 +0000 Subject: [PATCH 26/67] Bump nokogiri from 1.13.9 to 1.13.10 in /docs (#2909) Bumps [nokogiri](https://github.com/sparklemotion/nokogiri) from 1.13.9 to 1.13.10.
Release notes

Sourced from nokogiri's releases.

1.13.10 / 2022-12-07

Security

  • [CRuby] Address CVE-2022-23476, unchecked return value from xmlTextReaderExpand. See GHSA-qv4q-mr5r-qprj for more information.

Improvements

  • [CRuby] XML::Reader#attribute_hash now returns nil on parse errors. This restores the behavior of #attributes from v1.13.7 and earlier. [#2715]

sha256 checksums:

777ce2e80f64772e91459b943e531dfef387e768f2255f9bc7a1655f254bbaa1
nokogiri-1.13.10-aarch64-linux.gem
b432ff47c51386e07f7e275374fe031c1349e37eaef2216759063bc5fa5624aa
nokogiri-1.13.10-arm64-darwin.gem
73ac581ddcb680a912e92da928ffdbac7b36afd3368418f2cee861b96e8c830b
nokogiri-1.13.10-java.gem
916aa17e624611dddbf2976ecce1b4a80633c6378f8465cff0efab022ebc2900
nokogiri-1.13.10-x64-mingw-ucrt.gem
0f85a1ad8c2b02c166a6637237133505b71a05f1bb41b91447005449769bced0
nokogiri-1.13.10-x64-mingw32.gem
91fa3a8724a1ce20fccbd718dafd9acbde099258183ac486992a61b00bb17020
nokogiri-1.13.10-x86-linux.gem
d6663f5900ccd8f72d43660d7f082565b7ffcaade0b9a59a74b3ef8791034168
nokogiri-1.13.10-x86-mingw32.gem
81755fc4b8130ef9678c76a2e5af3db7a0a6664b3cba7d9fe8ef75e7d979e91b
nokogiri-1.13.10-x86_64-darwin.gem
51d5246705dedad0a09b374d09cc193e7383a5dd32136a690a3cd56e95adf0a3
nokogiri-1.13.10-x86_64-linux.gem
d3ee00f26c151763da1691c7fc6871ddd03e532f74f85101f5acedc2d099e958
nokogiri-1.13.10.gem
Changelog

Sourced from nokogiri's changelog.

1.13.10 / 2022-12-07

Security

  • [CRuby] Address CVE-2022-23476, unchecked return value from xmlTextReaderExpand. See GHSA-qv4q-mr5r-qprj for more information.

Improvements

  • [CRuby] XML::Reader#attribute_hash now returns nil on parse errors. This restores the behavior of #attributes from v1.13.7 and earlier. [#2715]
Commits
  • 4c80121 version bump to v1.13.10
  • 85410e3 Merge pull request #2715 from sparklemotion/flavorjones-fix-reader-error-hand...
  • 9fe0761 fix(cruby): XML::Reader#attribute_hash returns nil on error
  • 3b9c736 Merge pull request #2717 from sparklemotion/flavorjones-lock-psych-to-fix-bui...
  • 2efa87b test: skip large cdata test on system libxml2
  • 3187d67 dep(dev): pin psych to v4 until v5 builds in CI
  • a16b4bf style(rubocop): disable Minitest/EmptyLineBeforeAssertionMethods
  • See full diff in compare view

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=nokogiri&package-manager=bundler&previous-version=1.13.9&new-version=1.13.10)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) - `@dependabot use these labels` will set the current labels as the default for future PRs for this repo and language - `@dependabot use these reviewers` will set the current reviewers as the default for future PRs for this repo and language - `@dependabot use these assignees` will set the current assignees as the default for future PRs for this repo and language - `@dependabot use this milestone` will set the current milestone as the default for future PRs for this repo and language You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/matrix-org/dendrite/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/Gemfile.lock | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index c7ba43711..509a8cbcf 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -231,9 +231,9 @@ GEM jekyll-seo-tag (~> 2.1) minitest (5.15.0) multipart-post (2.1.1) - nokogiri (1.13.9-arm64-darwin) + nokogiri (1.13.10-arm64-darwin) racc (~> 1.4) - nokogiri (1.13.9-x86_64-linux) + nokogiri (1.13.10-x86_64-linux) racc (~> 1.4) octokit (4.22.0) faraday (>= 0.9) @@ -241,7 +241,7 @@ GEM pathutil (0.16.2) forwardable-extended (~> 2.6) public_suffix (4.0.7) - racc (1.6.0) + racc (1.6.1) rb-fsevent (0.11.1) rb-inotify (0.10.1) ffi (~> 1.0) From aaf4e5c8654463cc5431d57db9163ad9ed558f53 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 9 Dec 2022 18:45:42 +0100 Subject: [PATCH 27/67] Use older sytest-dendrite image --- .github/workflows/dendrite.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 593012ef3..2c04005d2 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -331,7 +331,8 @@ jobs: postgres: postgres api: full-http container: - image: matrixdotorg/sytest-dendrite:latest + # Temporary for debugging to see if this image is working better. + image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73 volumes: - ${{ github.workspace }}:/src - /root/.cache/go-build:/github/home/.cache/go-build From 7d2344049d0780e92b071939addc41219ce94f5a Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 12 Dec 2022 08:20:59 +0100 Subject: [PATCH 28/67] Cleanup stale device lists for users we don't share a room with anymore (#2857) The stale device lists table might contain entries for users we don't share a room with anymore. This now asks the roomserver about left users and removes those entries from the table. Co-authored-by: Neil Alexander --- build/dendritejs-pinecone/main.go | 6 +- build/gobind-pinecone/monolith.go | 2 +- build/gobind-yggdrasil/monolith.go | 2 +- cmd/dendrite-demo-pinecone/main.go | 2 +- cmd/dendrite-demo-yggdrasil/main.go | 5 +- cmd/dendrite-monolith-server/main.go | 2 +- .../personalities/keyserver.go | 3 +- keyserver/internal/device_list_update.go | 29 +++++- keyserver/internal/device_list_update_test.go | 86 ++++++++++++++++- keyserver/keyserver.go | 12 ++- keyserver/keyserver_test.go | 29 ++++++ keyserver/storage/interface.go | 5 + .../storage/postgres/stale_device_lists.go | 33 +++++-- keyserver/storage/shared/storage.go | 10 ++ .../storage/sqlite3/stale_device_lists.go | 44 +++++++-- keyserver/storage/tables/interface.go | 1 + .../storage/tables/stale_device_lists_test.go | 94 ++++++++++++++++++ roomserver/api/api.go | 5 + roomserver/api/api_trace.go | 6 ++ roomserver/api/query.go | 12 +++ roomserver/internal/query/query.go | 6 ++ roomserver/inthttp/client.go | 8 ++ roomserver/inthttp/server.go | 5 + roomserver/roomserver_test.go | 77 ++++++++++++++- roomserver/storage/interface.go | 1 + .../storage/postgres/membership_table.go | 34 ++++++- roomserver/storage/shared/storage.go | 37 +++++++ roomserver/storage/shared/storage_test.go | 96 +++++++++++++++++++ .../storage/sqlite3/membership_table.go | 47 ++++++++- roomserver/storage/tables/interface.go | 1 + .../storage/tables/membership_table_test.go | 6 ++ 31 files changed, 666 insertions(+), 40 deletions(-) create mode 100644 keyserver/keyserver_test.go create mode 100644 keyserver/storage/tables/stale_device_lists_test.go create mode 100644 roomserver/storage/shared/storage_test.go diff --git a/build/dendritejs-pinecone/main.go b/build/dendritejs-pinecone/main.go index e070173aa..f44a77488 100644 --- a/build/dendritejs-pinecone/main.go +++ b/build/dendritejs-pinecone/main.go @@ -180,14 +180,14 @@ func startup() { base := base.NewBaseDendrite(cfg, "Monolith") defer base.Close() // nolint: errcheck + rsAPI := roomserver.NewInternalAPI(base) + federation := conn.CreateFederationClient(base, pSessions) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - rsAPI := roomserver.NewInternalAPI(base) - userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index e8ed8fe85..b8f8111d2 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -350,7 +350,7 @@ func (m *DendriteMonolith) Start() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(m.userAPI) diff --git a/build/gobind-yggdrasil/monolith.go b/build/gobind-yggdrasil/monolith.go index 9a3ac5d7b..8c2d0a006 100644 --- a/build/gobind-yggdrasil/monolith.go +++ b/build/gobind-yggdrasil/monolith.go @@ -165,7 +165,7 @@ func (m *DendriteMonolith) Start() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI) userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 2f647a41b..3f627b41d 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -213,7 +213,7 @@ func main() { base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsComponent) userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(userAPI) diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 5dd61b1b7..3ea4a08b0 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -157,11 +157,12 @@ func main() { serverKeyAPI := &signing.YggdrasilKeys{} keyRing := serverKeyAPI.KeyRing() - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation) - rsComponent := roomserver.NewInternalAPI( base, ) + + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsComponent) + rsAPI := rsComponent userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient()) diff --git a/cmd/dendrite-monolith-server/main.go b/cmd/dendrite-monolith-server/main.go index 2d2f32b00..6836b6426 100644 --- a/cmd/dendrite-monolith-server/main.go +++ b/cmd/dendrite-monolith-server/main.go @@ -95,7 +95,7 @@ func main() { } keyRing := fsAPI.KeyRing() - keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) keyAPI := keyImpl if base.UseHTTPAPIs { keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics) diff --git a/cmd/dendrite-polylith-multi/personalities/keyserver.go b/cmd/dendrite-polylith-multi/personalities/keyserver.go index d2924b892..ad0bd0e54 100644 --- a/cmd/dendrite-polylith-multi/personalities/keyserver.go +++ b/cmd/dendrite-polylith-multi/personalities/keyserver.go @@ -22,7 +22,8 @@ import ( func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) { fsAPI := base.FederationAPIHTTPClient() - intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI) + rsAPI := base.RoomserverHTTPClient() + intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) intAPI.SetUserAPI(base.UserAPIClient()) keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics) diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 8ff9dfc31..c7bf8da53 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -24,6 +24,8 @@ import ( "sync" "time" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrix" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -102,6 +104,7 @@ type DeviceListUpdater struct { // block on or timeout via a select. userIDToChan map[string]chan bool userIDToChanMu *sync.Mutex + rsAPI rsapi.KeyserverRoomserverAPI } // DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater. @@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface { // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced. DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error + + DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error } type DeviceListUpdaterAPI interface { @@ -140,7 +145,7 @@ func NewDeviceListUpdater( process *process.ProcessContext, db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer, fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int, - thisServer gomatrixserverlib.ServerName, + rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName, ) *DeviceListUpdater { return &DeviceListUpdater{ process: process, @@ -154,6 +159,7 @@ func NewDeviceListUpdater( workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers), userIDToChan: make(map[string]chan bool), userIDToChanMu: &sync.Mutex{}, + rsAPI: rsAPI, } } @@ -168,7 +174,7 @@ func (u *DeviceListUpdater) Start() error { go u.worker(ch) } - staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{}) + staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) if err != nil { return err } @@ -186,6 +192,25 @@ func (u *DeviceListUpdater) Start() error { return nil } +// CleanUp removes stale device entries for users we don't share a room with anymore +func (u *DeviceListUpdater) CleanUp() error { + staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{}) + if err != nil { + return err + } + + res := rsapi.QueryLeftUsersResponse{} + if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil { + return err + } + + if len(res.LeftUsers) == 0 { + return nil + } + logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers)) + return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers) +} + func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex { u.mu.Lock() defer u.mu.Unlock() diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index a374c9516..60a2c2f30 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -30,7 +30,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/keyserver/storage" + roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" ) var ( @@ -53,6 +58,10 @@ type mockDeviceListUpdaterDatabase struct { mu sync.Mutex // protect staleUsers } +func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error { + return nil +} + // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) { @@ -153,7 +162,7 @@ func TestUpdateHavePrevID(t *testing.T) { } ap := &mockDeviceListUpdaterAPI{} producer := &mockKeyChangeProducer{} - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost") event := gomatrixserverlib.DeviceListUpdateEvent{ DeviceDisplayName: "Foo Bar", Deleted: false, @@ -225,7 +234,7 @@ func TestUpdateNoPrevID(t *testing.T) { `)), }, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -239,6 +248,7 @@ func TestUpdateNoPrevID(t *testing.T) { UserID: remoteUserID, } err := updater.Update(ctx, event) + if err != nil { t.Fatalf("Update returned an error: %s", err) } @@ -294,7 +304,7 @@ func TestDebounce(t *testing.T) { close(incomingFedReq) return <-fedCh, nil }) - updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost") + updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost") if err := updater.Start(); err != nil { t.Fatalf("failed to start updater: %s", err) } @@ -349,3 +359,73 @@ func TestDebounce(t *testing.T) { t.Errorf("user %s is marked as stale", userID) } } + +func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) { + t.Helper() + + base, _, _ := testrig.Base(nil) + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}) + if err != nil { + t.Fatal(err) + } + + return db, clearDB +} + +type mockKeyserverRoomserverAPI struct { + leftUsers []string +} + +func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { + res.LeftUsers = m.leftUsers + return nil +} + +func TestDeviceListUpdater_CleanUp(t *testing.T) { + processCtx := process.NewProcessContext() + + alice := test.NewUser(t) + bob := test.NewUser(t) + + // Bob is not joined to any of our rooms + rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}} + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, clearDB := mustCreateKeyserverDB(t, dbType) + defer clearDB() + + // This should not get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil { + t.Error(err) + } + + // this one should get deleted + if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil { + t.Error(err) + } + + updater := NewDeviceListUpdater(processCtx, db, nil, + nil, nil, + 0, rsAPI, "test") + if err := updater.CleanUp(); err != nil { + t.Error(err) + } + + // check that we still have Alice in our stale list + staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Error(err) + } + + // There should only be Alice + wantCount := 1 + if count := len(staleUsers); count != wantCount { + t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count) + } + + if staleUsers[0] != alice.ID { + t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID) + } + }) +} diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 5360c06fd..275576773 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -18,6 +18,8 @@ import ( "github.com/gorilla/mux" "github.com/sirupsen/logrus" + rsapi "github.com/matrix-org/dendrite/roomserver/api" + fedsenderapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/consumers" @@ -40,6 +42,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI, enableMetr // can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes. func NewInternalAPI( base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI, + rsAPI rsapi.KeyserverRoomserverAPI, ) api.KeyInternalAPI { js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) @@ -47,6 +50,7 @@ func NewInternalAPI( if err != nil { logrus.WithError(err).Panicf("failed to connect to key server database") } + keyChangeProducer := &producers.KeyChange{ Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)), JetStream: js, @@ -58,8 +62,14 @@ func NewInternalAPI( FedClient: fedClient, Producer: keyChangeProducer, } - updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable + updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable ap.Updater = updater + + // Remove users which we don't share a room with anymore + if err := updater.CleanUp(); err != nil { + logrus.WithError(err).Error("failed to cleanup stale device lists") + } + go func() { if err := updater.Start(); err != nil { logrus.WithError(err).Panicf("failed to start device list updater") diff --git a/keyserver/keyserver_test.go b/keyserver/keyserver_test.go new file mode 100644 index 000000000..159b280f5 --- /dev/null +++ b/keyserver/keyserver_test.go @@ -0,0 +1,29 @@ +package keyserver + +import ( + "context" + "testing" + + roomserver "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +type mockKeyserverRoomserverAPI struct { + leftUsers []string +} + +func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error { + res.LeftUsers = m.leftUsers + return nil +} + +// Merely tests that we can create an internal keyserver API +func Test_NewInternalAPI(t *testing.T) { + rsAPI := &mockKeyserverRoomserverAPI{} + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, closeBase := testrig.CreateBaseDendrite(t, dbType) + defer closeBase() + _ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + }) +} diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 242e16a06..c6a8f44cd 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -85,4 +85,9 @@ type Database interface { StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error + + DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, + ) error } diff --git a/keyserver/storage/postgres/stale_device_lists.go b/keyserver/storage/postgres/stale_device_lists.go index d0fe50d00..248ddfb45 100644 --- a/keyserver/storage/postgres/stale_device_lists.go +++ b/keyserver/storage/postgres/stale_device_lists.go @@ -19,6 +19,10 @@ import ( "database/sql" "time" + "github.com/lib/pq" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" @@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)" + type staleDeviceListsStatements struct { upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt + deleteStaleDeviceListsStmt *sql.Stmt } func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { @@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + {&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, + }.Prepare(db) } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { @@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt) + _, err := stmt.ExecContext(ctx, pq.Array(userIDs)) + return err +} + func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for rows.Next() { diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5beeed0f1..54dd6ddc9 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -249,3 +249,13 @@ func (d *Database) StoreCrossSigningSigsForTarget( return nil }) } + +// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore. +func (d *Database) DeleteStaleDeviceLists( + ctx context.Context, + userIDs []string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs) + }) +} diff --git a/keyserver/storage/sqlite3/stale_device_lists.go b/keyserver/storage/sqlite3/stale_device_lists.go index 1e08b266c..fd76a6e3b 100644 --- a/keyserver/storage/sqlite3/stale_device_lists.go +++ b/keyserver/storage/sqlite3/stale_device_lists.go @@ -17,8 +17,11 @@ package sqlite3 import ( "context" "database/sql" + "strings" "time" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/gomatrixserverlib" @@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" + const selectStaleDeviceListsSQL = "" + "SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC" +const deleteStaleDevicesSQL = "" + + "DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)" + type staleDeviceListsStatements struct { db *sql.DB upsertStaleDeviceListStmt *sql.Stmt selectStaleDeviceListsWithDomainsStmt *sql.Stmt selectStaleDeviceListsStmt *sql.Stmt + // deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime } func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) { @@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) if err != nil { return nil, err } - if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil { - return nil, err - } - if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL}, + {&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL}, + {&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL}, + // { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime + }.Prepare(db) } func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error { @@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte return result, nil } +// DeleteStaleDeviceLists removes users from stale device lists +func (s *staleDeviceListsStatements) DeleteStaleDeviceLists( + ctx context.Context, txn *sql.Tx, userIDs []string, +) error { + qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1) + stmt, err := s.db.Prepare(qry) + if err != nil { + return err + } + defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed") + stmt = sqlutil.TxStmt(txn, stmt) + + params := make([]any, len(userIDs)) + for i := range userIDs { + params[i] = userIDs[i] + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) { defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed") for rows.Next() { diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 37a010a7c..24da1125e 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -56,6 +56,7 @@ type KeyChanges interface { type StaleDeviceLists interface { InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) + DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error } type CrossSigningKeys interface { diff --git a/keyserver/storage/tables/stale_device_lists_test.go b/keyserver/storage/tables/stale_device_lists_test.go new file mode 100644 index 000000000..76d3baddd --- /dev/null +++ b/keyserver/storage/tables/stale_device_lists_test.go @@ -0,0 +1,94 @@ +package tables_test + +import ( + "context" + "testing" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/config" + + "github.com/matrix-org/dendrite/keyserver/storage/postgres" + "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/dendrite/test" +) + +func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, nil) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresStaleDeviceListsTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db) + } + if err != nil { + t.Fatalf("failed to create new table: %s", err) + } + return tab, close +} + +func TestStaleDeviceLists(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := "@charlie:localhost" + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateTable(t, dbType) + defer closeDB() + + if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil { + t.Fatalf("failed to insert stale device: %s", err) + } + + // Query one server + wantStaleUsers := []string{alice.ID, bob.ID} + gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Query all servers + wantStaleUsers = []string{alice.ID, bob.ID, charlie} + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) { + t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers) + } + + // Delete stale devices + deleteUsers := []string{alice.ID, bob.ID} + if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil { + t.Fatalf("failed to delete stale device lists: %s", err) + } + + // Verify we don't get anything back after deleting + gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"}) + if err != nil { + t.Fatalf("failed to query stale device lists: %s", err) + } + + if gotCount := len(gotStaleUsers); gotCount > 0 { + t.Fatalf("expected no stale users, got %d", gotCount) + } + }) +} diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 01e87ec8a..420ef278a 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -17,6 +17,7 @@ type RoomserverInternalAPI interface { ClientRoomserverAPI UserRoomserverAPI FederationRoomserverAPI + KeyserverRoomserverAPI // needed to avoid chicken and egg scenario when setting up the // interdependencies between the roomserver and other input APIs @@ -199,3 +200,7 @@ type FederationRoomserverAPI interface { // Query a given amount (or less) of events prior to a given set of events. PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error } + +type KeyserverRoomserverAPI interface { + QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error +} diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 342a3904c..b23263d17 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -19,6 +19,12 @@ type RoomserverInternalAPITrace struct { Impl RoomserverInternalAPI } +func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error { + err := t.Impl.QueryLeftUsers(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("QueryLeftUsers req=%+v res=%+v", js(req), js(res)) + return err +} + func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) { t.Impl.SetFederationAPI(fsAPI, keyRing) } diff --git a/roomserver/api/query.go b/roomserver/api/query.go index b62907f3c..76f8298ca 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -447,3 +447,15 @@ type QueryMembershipAtEventResponse struct { // do not have known state will return an empty array here. Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"` } + +// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a +// a room with anymore. This is used to cleanup stale device list entries, where we would +// otherwise keep on trying to get device lists. +type QueryLeftUsersRequest struct { + StaleDeviceListUsers []string `json:"user_ids"` +} + +// QueryLeftUsersResponse is the response to QueryLeftUsersRequest. +type QueryLeftUsersResponse struct { + LeftUsers []string `json:"user_ids"` +} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index d8456fb43..69d841dda 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -805,6 +805,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS return nil } +func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error { + var err error + res.LeftUsers, err = r.DB.GetLeftUsers(ctx, req.StaleDeviceListUsers) + return err +} + func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error { roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join") if err != nil { diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 1bd1b3fb7..8a2e0a03c 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -63,6 +63,7 @@ const ( RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain" RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed" RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent" + RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers" ) type httpRoomserverInternalAPI struct { @@ -553,3 +554,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context, h.httpClient, ctx, request, response, ) } + +func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error { + return httputil.CallInternalRPCAPI( + "RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath, + h.httpClient, ctx, request, response, + ) +} diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 6e7c2d985..4d21909b7 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -203,4 +203,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe RoomserverQueryMembershipAtEventPath, httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent), ) + + internalAPIMux.Handle( + RoomserverQueryLeftMembersPath, + httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", enableMetrics, r.QueryLeftUsers), + ) } diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 24b5515e5..518bb3722 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -2,20 +2,27 @@ package roomserver_test import ( "context" + "net/http" "testing" + "time" + "github.com/gorilla/mux" + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" - "github.com/matrix-org/gomatrixserverlib" ) func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) { + t.Helper() base, close := testrig.CreateBaseDendrite(t, dbType) - db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches) + db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches) if err != nil { t.Fatalf("failed to create Database: %v", err) } @@ -67,3 +74,69 @@ func Test_SharedUsers(t *testing.T) { } }) } + +func Test_QueryLeftUsers(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite and join Bob + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, _, close := mustCreateDatabase(t, dbType) + defer close() + + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // Query the left users, there should only be "@idontexist:test", + // as Alice and Bob are still joined. + res := &api.QueryLeftUsersResponse{} + leftUserID := "@idontexist:test" + getLeftUsersList := []string{alice.ID, bob.ID, leftUserID} + + testCase := func(rsAPI api.RoomserverInternalAPI) { + if err := rsAPI.QueryLeftUsers(ctx, &api.QueryLeftUsersRequest{StaleDeviceListUsers: getLeftUsersList}, res); err != nil { + t.Fatalf("unable to query left users: %v", err) + } + wantCount := 1 + if count := len(res.LeftUsers); count > wantCount { + t.Fatalf("unexpected left users count: want %d, got %d", wantCount, count) + } + if res.LeftUsers[0] != leftUserID { + t.Fatalf("unexpected left users : want %s, got %s", leftUserID, res.LeftUsers[0]) + } + } + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + roomserver.AddInternalRoutes(router, rsAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + httpAPI, err := inthttp.NewRoomserverClient(apiURL, &http.Client{Timeout: time.Second * 5}, nil) + if err != nil { + t.Fatalf("failed to create HTTP client") + } + testCase(httpAPI) + }) + t.Run("Monolith", func(t *testing.T) { + testCase(rsAPI) + // also test tracing + traceAPI := &api.RoomserverInternalAPITrace{Impl: rsAPI} + testCase(traceAPI) + }) + + }) +} diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 06db4b2d8..92bc2e66f 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -172,5 +172,6 @@ type Database interface { ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) + GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error } diff --git a/roomserver/storage/postgres/membership_table.go b/roomserver/storage/postgres/membership_table.go index 0150534e1..d774b7892 100644 --- a/roomserver/storage/postgres/membership_table.go +++ b/roomserver/storage/postgres/membership_table.go @@ -21,12 +21,13 @@ import ( "fmt" "github.com/lib/pq" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const membershipSchema = ` @@ -157,6 +158,12 @@ const selectServerInRoomSQL = "" + " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" +const selectJoinedUsersSQL = ` +SELECT DISTINCT target_nid +FROM roomserver_membership m +WHERE membership_nid > $1 AND target_nid = ANY($2) +` + type membershipStatements struct { insertMembershipStmt *sql.Stmt selectMembershipForUpdateStmt *sql.Stmt @@ -174,6 +181,7 @@ type membershipStatements struct { selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt deleteMembershipStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt } func CreateMembershipTable(db *sql.DB) error { @@ -209,9 +217,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) { {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL}, {&s.deleteMembershipStmt, deleteMembershipSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, }.Prepare(db) } +func (s *membershipStatements) SelectJoinedUsers( + ctx context.Context, txn *sql.Tx, + targetUserNIDs []types.EventStateKeyNID, +) ([]types.EventStateKeyNID, error) { + result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs)) + + stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt) + rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed") + var targetNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&targetNID); err != nil { + return nil, err + } + result = append(result, targetNID) + } + + return result, rows.Err() +} + func (s *membershipStatements) InsertMembership( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 16898bcb1..725cc5bc7 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -1365,6 +1365,43 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [ return result, nil } +// GetLeftUsers calculates users we (the server) don't share a room with anymore. +func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) { + // Get the userNID for all users with a stale device list + stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs) + if err != nil { + return nil, err + } + + userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap)) + userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap)) + // Create a map from userNID -> userID + for userID, nid := range stateKeyNIDMap { + userNIDs = append(userNIDs, nid) + userNIDtoUserID[nid] = userID + } + + // Get all users whose membership is still join, knock or invite. + stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs) + if err != nil { + return nil, err + } + + // Remove joined users from the "user with stale devices" list, which contains left AND joined users + for _, joinedUser := range stillJoinedUsersNIDs { + delete(userNIDtoUserID, joinedUser) + } + + // The users still in our userNIDtoUserID map are the users we don't share a room with anymore, + // and the return value we are looking for. + leftUsers := make([]string, 0, len(userNIDtoUserID)) + for _, userID := range userNIDtoUserID { + leftUsers = append(leftUsers, userID) + } + + return leftUsers, nil +} + // GetLocalServerInRoom returns true if we think we're in a given room or false otherwise. func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) { return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID) diff --git a/roomserver/storage/shared/storage_test.go b/roomserver/storage/shared/storage_test.go new file mode 100644 index 000000000..58724340c --- /dev/null +++ b/roomserver/storage/shared/storage_test.go @@ -0,0 +1,96 @@ +package shared_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/postgres" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/sqlite3" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" +) + +func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Database, func()) { + t.Helper() + + connStr, clearDB := test.PrepareDBConnectionString(t, dbType) + base, _, _ := testrig.Base(nil) + dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)} + + db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + + var membershipTable tables.Membership + var stateKeyTable tables.EventStateKeys + switch dbType { + case test.DBTypePostgres: + err = postgres.CreateEventStateKeysTable(db) + assert.NoError(t, err) + err = postgres.CreateMembershipTable(db) + assert.NoError(t, err) + membershipTable, err = postgres.PrepareMembershipTable(db) + assert.NoError(t, err) + stateKeyTable, err = postgres.PrepareEventStateKeysTable(db) + case test.DBTypeSQLite: + err = sqlite3.CreateEventStateKeysTable(db) + assert.NoError(t, err) + err = sqlite3.CreateMembershipTable(db) + assert.NoError(t, err) + membershipTable, err = sqlite3.PrepareMembershipTable(db) + assert.NoError(t, err) + stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db) + } + assert.NoError(t, err) + + return &shared.Database{ + DB: db, + EventStateKeysTable: stateKeyTable, + MembershipTable: membershipTable, + Writer: sqlutil.NewExclusiveWriter(), + }, func() { + err := base.Close() + assert.NoError(t, err) + clearDB() + err = db.Close() + assert.NoError(t, err) + } +} + +func Test_GetLeftUsers(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRoomserverDatabase(t, dbType) + defer close() + + // Create dummy entries + for _, user := range []*test.User{alice, bob, charlie} { + nid, err := db.EventStateKeysTable.InsertEventStateKeyNID(ctx, nil, user.ID) + assert.NoError(t, err) + err = db.MembershipTable.InsertMembership(ctx, nil, 1, nid, true) + assert.NoError(t, err) + // We must update the membership with a non-zero event NID or it will get filtered out in later queries + membershipNID := tables.MembershipStateLeaveOrBan + if user == alice { + membershipNID = tables.MembershipStateJoin + } + _, err = db.MembershipTable.UpdateMembership(ctx, nil, 1, nid, nid, membershipNID, 1, false) + assert.NoError(t, err) + } + + // Now try to get the left users, this should be Bob and Charlie, since they have a "leave" membership + expectedUserIDs := []string{bob.ID, charlie.ID} + leftUsers, err := db.GetLeftUsers(context.Background(), []string{alice.ID, bob.ID, charlie.ID}) + assert.NoError(t, err) + assert.ElementsMatch(t, expectedUserIDs, leftUsers) + }) +} diff --git a/roomserver/storage/sqlite3/membership_table.go b/roomserver/storage/sqlite3/membership_table.go index cd149f0ed..8a60b359f 100644 --- a/roomserver/storage/sqlite3/membership_table.go +++ b/roomserver/storage/sqlite3/membership_table.go @@ -21,12 +21,13 @@ import ( "fmt" "strings" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/gomatrixserverlib" ) const membershipSchema = ` @@ -133,6 +134,12 @@ const selectServerInRoomSQL = "" + const deleteMembershipSQL = "" + "DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2" +const selectJoinedUsersSQL = ` +SELECT DISTINCT target_nid +FROM roomserver_membership m +WHERE membership_nid > $1 AND target_nid IN ($2) +` + type membershipStatements struct { db *sql.DB insertMembershipStmt *sql.Stmt @@ -149,6 +156,7 @@ type membershipStatements struct { selectLocalServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt deleteMembershipStmt *sql.Stmt + // selectJoinedUsersStmt *sql.Stmt // Prepared at runtime } func CreateMembershipTable(db *sql.DB) error { @@ -412,3 +420,40 @@ func (s *membershipStatements) DeleteMembership( ) return err } + +func (s *membershipStatements) SelectJoinedUsers( + ctx context.Context, txn *sql.Tx, + targetUserNIDs []types.EventStateKeyNID, +) ([]types.EventStateKeyNID, error) { + result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs)) + + qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1) + + stmt, err := s.db.Prepare(qry) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed") + + params := make([]any, len(targetUserNIDs)+1) + params[0] = tables.MembershipStateLeaveOrBan + for i := range targetUserNIDs { + params[i+1] = targetUserNIDs[i] + } + + stmt = sqlutil.TxStmt(txn, stmt) + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed") + var targetNID types.EventStateKeyNID + for rows.Next() { + if err = rows.Scan(&targetNID); err != nil { + return nil, err + } + result = append(result, targetNID) + } + + return result, rows.Err() +} diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 50d27c756..80fcf72dd 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -144,6 +144,7 @@ type Membership interface { SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error + SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error) } type Published interface { diff --git a/roomserver/storage/tables/membership_table_test.go b/roomserver/storage/tables/membership_table_test.go index c9541d9d2..c4524ee44 100644 --- a/roomserver/storage/tables/membership_table_test.go +++ b/roomserver/storage/tables/membership_table_test.go @@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) { knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2) assert.NoError(t, err) assert.Equal(t, 1, len(knownUsers)) + + // get users we share a room with, given their userNID + joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs) + assert.NoError(t, err) + // Only userNIDs[0] is actually joined, so we only expect this userNID + assert.Equal(t, userNIDs[:1], joinedUsers) }) } From 76db8e90defdfb9e61f6caea8a312c5d60bcc005 Mon Sep 17 00:00:00 2001 From: Kento Okamoto Date: Mon, 12 Dec 2022 08:46:37 -0800 Subject: [PATCH 29/67] Dendrite Documentation Fix (#2913) ### Pull Request Checklist I was reading through the Dendrite documentation on https://matrix-org.github.io/dendrite/development/contributing and noticed the installation link leads to a 404 error. This link works fine if it is viewed directly from [docs/CONTRIBUTING.md](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md) but this might not be very obvious to new contributors who are reading through the [contribution page](https://matrix-org.github.io/dendrite/development/contributing) directly. This PR is mainly a small re-organization of the online documentation mainly in the [Development](https://matrix-org.github.io/dendrite/development) tab along with any links throughout the doc that may be impacted by the change. This does not contain any Go unit tests as this does not actually touch core dendrite functionality. * [ ] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Kento Okamoto ` --- docs/FAQ.md | 4 ++-- docs/{ => development}/CONTRIBUTING.md | 6 +++--- docs/{ => development}/PROFILING.md | 0 docs/{ => development}/coverage.md | 0 docs/{ => development}/sytest.md | 0 docs/{ => development}/tracing/opentracing.md | 0 docs/{ => development}/tracing/setup.md | 0 7 files changed, 5 insertions(+), 5 deletions(-) rename docs/{ => development}/CONTRIBUTING.md (97%) rename docs/{ => development}/PROFILING.md (100%) rename docs/{ => development}/coverage.md (100%) rename docs/{ => development}/sytest.md (100%) rename docs/{ => development}/tracing/opentracing.md (100%) rename docs/{ => development}/tracing/setup.md (100%) diff --git a/docs/FAQ.md b/docs/FAQ.md index ca72b151d..816130515 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -91,7 +91,7 @@ Please use PostgreSQL wherever possible, especially if you are planning to run a ## Dendrite is using a lot of CPU Generally speaking, you should expect to see some CPU spikes, particularly if you are joining or participating in large rooms. However, constant/sustained high CPU usage is not expected - if you are experiencing that, please join `#dendrite-dev:matrix.org` and let us know what you were doing when the -CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](PROFILING.md) then that would +CPU usage shot up, or file a GitHub issue. If you can take a [CPU profile](development/PROFILING.md) then that would be a huge help too, as that will help us to understand where the CPU time is going. ## Dendrite is using a lot of RAM @@ -99,7 +99,7 @@ be a huge help too, as that will help us to understand where the CPU time is goi As above with CPU usage, some memory spikes are expected if Dendrite is doing particularly heavy work at a given instant. However, if it is using more RAM than you expect for a long time, that's probably not expected. Join `#dendrite-dev:matrix.org` and let us know what you were doing when the memory usage -ballooned, or file a GitHub issue if you can. If you can take a [memory profile](PROFILING.md) then that +ballooned, or file a GitHub issue if you can. If you can take a [memory profile](development/PROFILING.md) then that would be a huge help too, as that will help us to understand where the memory usage is happening. ## Dendrite is running out of PostgreSQL database connections diff --git a/docs/CONTRIBUTING.md b/docs/development/CONTRIBUTING.md similarity index 97% rename from docs/CONTRIBUTING.md rename to docs/development/CONTRIBUTING.md index 21b0e7ab1..2aec4c363 100644 --- a/docs/CONTRIBUTING.md +++ b/docs/development/CONTRIBUTING.md @@ -9,7 +9,7 @@ permalink: /development/contributing Everyone is welcome to contribute to Dendrite! We aim to make it as easy as possible to get started. - ## Contribution types +## Contribution types We are a small team maintaining a large project. As a result, we cannot merge every feature, even if it is bug-free and useful, because we then commit to maintaining it indefinitely. We will always accept: @@ -57,7 +57,7 @@ to do so for future contributions. ## Getting up and running -See the [Installation](installation) section for information on how to build an +See the [Installation](../installation) section for information on how to build an instance of Dendrite. You will likely need this in order to test your changes. ## Code style @@ -151,7 +151,7 @@ significant amount of CPU and RAM. Once the code builds, run [Sytest](https://github.com/matrix-org/sytest) according to the guide in -[docs/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/sytest.md#using-a-sytest-docker-image) +[docs/development/sytest.md](https://github.com/matrix-org/dendrite/blob/main/docs/development/sytest.md#using-a-sytest-docker-image) so you can see whether something is being broken and whether there are newly passing tests. diff --git a/docs/PROFILING.md b/docs/development/PROFILING.md similarity index 100% rename from docs/PROFILING.md rename to docs/development/PROFILING.md diff --git a/docs/coverage.md b/docs/development/coverage.md similarity index 100% rename from docs/coverage.md rename to docs/development/coverage.md diff --git a/docs/sytest.md b/docs/development/sytest.md similarity index 100% rename from docs/sytest.md rename to docs/development/sytest.md diff --git a/docs/tracing/opentracing.md b/docs/development/tracing/opentracing.md similarity index 100% rename from docs/tracing/opentracing.md rename to docs/development/tracing/opentracing.md diff --git a/docs/tracing/setup.md b/docs/development/tracing/setup.md similarity index 100% rename from docs/tracing/setup.md rename to docs/development/tracing/setup.md From d3db542fbf5b35377586567b21bc5c28872167a1 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 10:56:20 +0100 Subject: [PATCH 30/67] Add federation peeking table tests (#2920) As the title says, adds tests for inbound/outbound peeking federation table tests. Also removes some unused code --- federationapi/queue/queue_test.go | 22 --- federationapi/storage/interface.go | 3 - .../storage/postgres/inbound_peeks_table.go | 32 ++-- .../storage/postgres/outbound_peeks_table.go | 31 ++-- .../storage/postgres/queue_edus_table.go | 21 --- .../storage/postgres/queue_pdus_table.go | 23 --- federationapi/storage/shared/storage_edus.go | 9 - federationapi/storage/shared/storage_pdus.go | 9 - .../storage/sqlite3/inbound_peeks_table.go | 32 ++-- .../storage/sqlite3/outbound_peeks_table.go | 31 ++-- .../storage/sqlite3/queue_edus_table.go | 21 --- .../storage/sqlite3/queue_pdus_table.go | 23 --- federationapi/storage/storage_test.go | 166 ++++++++++++++++++ .../tables/inbound_peeks_table_test.go | 148 ++++++++++++++++ federationapi/storage/tables/interface.go | 2 - .../tables/outbound_peeks_table_test.go | 147 ++++++++++++++++ 16 files changed, 503 insertions(+), 217 deletions(-) create mode 100644 federationapi/storage/tables/inbound_peeks_table_test.go create mode 100644 federationapi/storage/tables/outbound_peeks_table_test.go diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index b2ec4b836..c317edc21 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -221,28 +221,6 @@ func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverl return nil } -func (d *fakeDatabase) GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var count int64 - if pdus, ok := d.associatedPDUs[serverName]; ok { - count = int64(len(pdus)) - } - return count, nil -} - -func (d *fakeDatabase) GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var count int64 - if edus, ok := d.associatedEDUs[serverName]; ok { - count = int64(len(edus)) - } - return count, nil -} - func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { d.dbMutex.Lock() defer d.dbMutex.Unlock() diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index b15b8bfae..276cd9a50 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -45,9 +45,6 @@ type Database interface { CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error - GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) - GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) diff --git a/federationapi/storage/postgres/inbound_peeks_table.go b/federationapi/storage/postgres/inbound_peeks_table.go index df5c60761..ad2afcb15 100644 --- a/federationapi/storage/postgres/inbound_peeks_table.go +++ b/federationapi/storage/postgres/inbound_peeks_table.go @@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectInboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER by creation_ts" const renewInboundPeekSQL = "" + "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteInboundPeekSQL = "" + - "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteInboundPeeksSQL = "" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" @@ -74,25 +74,15 @@ func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err er return } - if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { - return - } - if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil { - return - } - if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil { - return - } - if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertInboundPeekStmt, insertInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeeksStmt, selectInboundPeeksSQL}, + {&s.renewInboundPeekStmt, renewInboundPeekSQL}, + {&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL}, + {&s.deleteInboundPeekStmt, deleteInboundPeekSQL}, + }.Prepare(db) } func (s *inboundPeeksStatements) InsertInboundPeek( diff --git a/federationapi/storage/postgres/outbound_peeks_table.go b/federationapi/storage/postgres/outbound_peeks_table.go index c22d893f7..5df684318 100644 --- a/federationapi/storage/postgres/outbound_peeks_table.go +++ b/federationapi/storage/postgres/outbound_peeks_table.go @@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewOutboundPeekSQL = "" + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteOutboundPeeksSQL = "" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" @@ -74,25 +74,14 @@ func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err return } - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertOutboundPeekStmt, insertOutboundPeekSQL}, + {&s.selectOutboundPeekStmt, selectOutboundPeekSQL}, + {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL}, + {&s.renewOutboundPeekStmt, renewOutboundPeekSQL}, + {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL}, + {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL}, + }.Prepare(db) } func (s *outboundPeeksStatements) InsertOutboundPeek( diff --git a/federationapi/storage/postgres/queue_edus_table.go b/federationapi/storage/postgres/queue_edus_table.go index d6507e13b..8870dc88d 100644 --- a/federationapi/storage/postgres/queue_edus_table.go +++ b/federationapi/storage/postgres/queue_edus_table.go @@ -62,10 +62,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_edus" + " WHERE json_nid = $1" -const selectQueueEDUCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_edus" + - " WHERE server_name = $1" - const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" @@ -81,7 +77,6 @@ type queueEDUsStatements struct { deleteQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt - selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt selectExpiredEDUsStmt *sql.Stmt deleteExpiredEDUsStmt *sql.Stmt @@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error { {&s.deleteQueueEDUStmt, deleteQueueEDUSQL}, {&s.selectQueueEDUStmt, selectQueueEDUSQL}, {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, - {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, @@ -186,21 +180,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( return count, err } -func (s *queueEDUsStatements) SelectQueueEDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { diff --git a/federationapi/storage/postgres/queue_pdus_table.go b/federationapi/storage/postgres/queue_pdus_table.go index 38ac5a6eb..3b0bef9af 100644 --- a/federationapi/storage/postgres/queue_pdus_table.go +++ b/federationapi/storage/postgres/queue_pdus_table.go @@ -58,10 +58,6 @@ const selectQueuePDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" -const selectQueuePDUsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_pdus" + - " WHERE server_name = $1" - const selectQueuePDUServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" @@ -71,7 +67,6 @@ type queuePDUsStatements struct { deleteQueuePDUsStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt selectQueuePDUReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt selectQueuePDUServerNamesStmt *sql.Stmt } @@ -95,9 +90,6 @@ func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil { - return - } if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil { return } @@ -146,21 +138,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) SelectQueuePDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index c796d2f8f..be8355f31 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -162,15 +162,6 @@ func (d *Database) CleanEDUs( }) } -// GetPendingEDUCount returns the number of EDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingEDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName) -} - // GetPendingServerNames returns the server names that have EDUs // waiting to be sent. func (d *Database) GetPendingEDUServerNames( diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index dc37d7507..da4cb979d 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -141,15 +141,6 @@ func (d *Database) CleanPDUs( }) } -// GetPendingPDUCount returns the number of PDUs waiting to be -// sent for a given servername. -func (d *Database) GetPendingPDUCount( - ctx context.Context, - serverName gomatrixserverlib.ServerName, -) (int64, error) { - return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName) -} - // GetPendingServerNames returns the server names that have PDUs // waiting to be sent. func (d *Database) GetPendingPDUServerNames( diff --git a/federationapi/storage/sqlite3/inbound_peeks_table.go b/federationapi/storage/sqlite3/inbound_peeks_table.go index ad3c4a6dd..8c3567934 100644 --- a/federationapi/storage/sqlite3/inbound_peeks_table.go +++ b/federationapi/storage/sqlite3/inbound_peeks_table.go @@ -44,13 +44,13 @@ const selectInboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectInboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewInboundPeekSQL = "" + "UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteInboundPeekSQL = "" + - "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteInboundPeeksSQL = "" + "DELETE FROM federationsender_inbound_peeks WHERE room_id = $1" @@ -74,25 +74,15 @@ func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err erro return } - if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil { - return - } - if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil { - return - } - if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil { - return - } - if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil { - return - } - if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertInboundPeekStmt, insertInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeekStmt, selectInboundPeekSQL}, + {&s.selectInboundPeeksStmt, selectInboundPeeksSQL}, + {&s.renewInboundPeekStmt, renewInboundPeekSQL}, + {&s.deleteInboundPeeksStmt, deleteInboundPeeksSQL}, + {&s.deleteInboundPeekStmt, deleteInboundPeekSQL}, + }.Prepare(db) } func (s *inboundPeeksStatements) InsertInboundPeek( diff --git a/federationapi/storage/sqlite3/outbound_peeks_table.go b/federationapi/storage/sqlite3/outbound_peeks_table.go index e29026fab..33f452b68 100644 --- a/federationapi/storage/sqlite3/outbound_peeks_table.go +++ b/federationapi/storage/sqlite3/outbound_peeks_table.go @@ -44,13 +44,13 @@ const selectOutboundPeekSQL = "" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const selectOutboundPeeksSQL = "" + - "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1" + "SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 ORDER BY creation_ts" const renewOutboundPeekSQL = "" + "UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5" const deleteOutboundPeekSQL = "" + - "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3" const deleteOutboundPeeksSQL = "" + "DELETE FROM federationsender_outbound_peeks WHERE room_id = $1" @@ -74,25 +74,14 @@ func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err er return } - if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil { - return - } - if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil { - return - } - if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil { - return - } - if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil { - return - } - if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil { - return - } - return + return s, sqlutil.StatementList{ + {&s.insertOutboundPeekStmt, insertOutboundPeekSQL}, + {&s.selectOutboundPeekStmt, selectOutboundPeekSQL}, + {&s.selectOutboundPeeksStmt, selectOutboundPeeksSQL}, + {&s.renewOutboundPeekStmt, renewOutboundPeekSQL}, + {&s.deleteOutboundPeeksStmt, deleteOutboundPeeksSQL}, + {&s.deleteOutboundPeekStmt, deleteOutboundPeekSQL}, + }.Prepare(db) } func (s *outboundPeeksStatements) InsertOutboundPeek( diff --git a/federationapi/storage/sqlite3/queue_edus_table.go b/federationapi/storage/sqlite3/queue_edus_table.go index 8e7e7901f..0dc914328 100644 --- a/federationapi/storage/sqlite3/queue_edus_table.go +++ b/federationapi/storage/sqlite3/queue_edus_table.go @@ -63,10 +63,6 @@ const selectQueueEDUReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_edus" + " WHERE json_nid = $1" -const selectQueueEDUCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_edus" + - " WHERE server_name = $1" - const selectQueueServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_edus" @@ -82,7 +78,6 @@ type queueEDUsStatements struct { // deleteQueueEDUStmt *sql.Stmt - prepared at runtime due to variadic selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt - selectQueueEDUCountStmt *sql.Stmt selectQueueEDUServerNamesStmt *sql.Stmt selectExpiredEDUsStmt *sql.Stmt deleteExpiredEDUsStmt *sql.Stmt @@ -116,7 +111,6 @@ func (s *queueEDUsStatements) Prepare() error { {&s.insertQueueEDUStmt, insertQueueEDUSQL}, {&s.selectQueueEDUStmt, selectQueueEDUSQL}, {&s.selectQueueEDUReferenceJSONCountStmt, selectQueueEDUReferenceJSONCountSQL}, - {&s.selectQueueEDUCountStmt, selectQueueEDUCountSQL}, {&s.selectQueueEDUServerNamesStmt, selectQueueServerNamesSQL}, {&s.selectExpiredEDUsStmt, selectExpiredEDUsSQL}, {&s.deleteExpiredEDUsStmt, deleteExpiredEDUsSQL}, @@ -198,21 +192,6 @@ func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount( return count, err } -func (s *queueEDUsStatements) SelectQueueEDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queueEDUsStatements) SelectQueueEDUServerNames( ctx context.Context, txn *sql.Tx, ) ([]gomatrixserverlib.ServerName, error) { diff --git a/federationapi/storage/sqlite3/queue_pdus_table.go b/federationapi/storage/sqlite3/queue_pdus_table.go index e818585a5..aee8b03d6 100644 --- a/federationapi/storage/sqlite3/queue_pdus_table.go +++ b/federationapi/storage/sqlite3/queue_pdus_table.go @@ -66,10 +66,6 @@ const selectQueuePDUsReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" -const selectQueuePDUsCountSQL = "" + - "SELECT COUNT(*) FROM federationsender_queue_pdus" + - " WHERE server_name = $1" - const selectQueuePDUsServerNamesSQL = "" + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" @@ -79,7 +75,6 @@ type queuePDUsStatements struct { selectQueueNextTransactionIDStmt *sql.Stmt selectQueuePDUsStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt - selectQueuePDUsCountStmt *sql.Stmt selectQueueServerNamesStmt *sql.Stmt // deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic } @@ -107,9 +102,6 @@ func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil { return } - if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { - return - } if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil { return } @@ -179,21 +171,6 @@ func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount( return count, err } -func (s *queuePDUsStatements) SelectQueuePDUCount( - ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (int64, error) { - var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) - err := stmt.QueryRowContext(ctx, serverName).Scan(&count) - if err == sql.ErrNoRows { - // It's acceptable for there to be no rows referencing a given - // JSON NID but it's not an error condition. Just return as if - // there's a zero count. - return 0, nil - } - return count, err -} - func (s *queuePDUsStatements) SelectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index f7408fa9f..14efa2655 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -2,10 +2,12 @@ package storage_test import ( "context" + "reflect" "testing" "time" "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/federationapi/storage" @@ -80,3 +82,167 @@ func TestExpireEDUs(t *testing.T) { assert.Equal(t, 2, len(data)) }) } + +func TestOutboundPeeking(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + + // Add outbound peek + if err := db.AddOutboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + outboundPeek1, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if outboundPeek1.PeekID != peekID { + t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) + } + if outboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) + } + if outboundPeek1.ServerName != serverName { + t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) + } + if outboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) + } + // Renew the peek + if err = db.RenewOutboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + outboundPeek2, err := db.GetOutboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(outboundPeek1, outboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if outboundPeek1.ServerName != outboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) + } + if outboundPeek1.RoomID != outboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) + } + + // insert some peeks + peekIDs := []string{peekID} + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = db.AddOutboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + outboundPeeks, err := db.GetOutboundPeeks(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) + } + for i := range outboundPeeks { + if outboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i]) + } + } + }) +} + +func TestInboundPeeking(t *testing.T) { + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + + // Add inbound peek + if err := db.AddInboundPeek(ctx, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + inboundPeek1, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if inboundPeek1.PeekID != peekID { + t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID) + } + if inboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID) + } + if inboundPeek1.ServerName != serverName { + t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName) + } + if inboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval) + } + // Renew the peek + if err = db.RenewInboundPeek(ctx, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + inboundPeek2, err := db.GetInboundPeek(ctx, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(inboundPeek1, inboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if inboundPeek1.ServerName != inboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName) + } + if inboundPeek1.RoomID != inboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID) + } + + // insert some peeks + peekIDs := []string{peekID} + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = db.AddInboundPeek(ctx, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + inboundPeeks, err := db.GetInboundPeeks(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) + } + for i := range inboundPeeks { + if inboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i]) + } + } + }) +} diff --git a/federationapi/storage/tables/inbound_peeks_table_test.go b/federationapi/storage/tables/inbound_peeks_table_test.go new file mode 100644 index 000000000..3a76a8576 --- /dev/null +++ b/federationapi/storage/tables/inbound_peeks_table_test.go @@ -0,0 +1,148 @@ +package tables_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + var tab tables.FederationInboundPeeks + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresInboundPeeksTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteInboundPeeksTable(db) + } + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + return tab, close +} + +func TestInboundPeeksTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateInboundpeeksTable(t, dbType) + defer closeDB() + + // Insert a peek + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + if err := tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + inboundPeek1, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if inboundPeek1.PeekID != peekID { + t.Fatalf("unexpected inbound peek ID: %s, want %s", inboundPeek1.PeekID, peekID) + } + if inboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected inbound peek room ID: %s, want %s", inboundPeek1.RoomID, peekID) + } + if inboundPeek1.ServerName != serverName { + t.Fatalf("unexpected inbound peek servername: %s, want %s", inboundPeek1.ServerName, serverName) + } + if inboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected inbound peek renewal interval: %d, want %d", inboundPeek1.RenewalInterval, renewalInterval) + } + + // Renew the peek + if err = tab.RenewInboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + inboundPeek2, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(inboundPeek1, inboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if inboundPeek1.ServerName != inboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", inboundPeek1.ServerName, inboundPeek2.ServerName) + } + if inboundPeek1.RoomID != inboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", inboundPeek1.RoomID, inboundPeek2.RoomID) + } + + // delete the peek + if err = tab.DeleteInboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil { + t.Fatal(err) + } + + // There should be no peek anymore + peek, err := tab.SelectInboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if peek != nil { + t.Fatalf("got a peek which should be deleted: %+v", peek) + } + + // insert some peeks + var peekIDs []string + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = tab.InsertInboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + inboundPeeks, err := tab.SelectInboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) + } + for i := range inboundPeeks { + if inboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("") + } + } + + // And delete them again + if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil { + t.Fatal(err) + } + + // they should be gone now + inboundPeeks, err = tab.SelectInboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(inboundPeeks) > 0 { + t.Fatal("got inbound peeks which should be deleted") + } + + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 9f4e86a6e..2b36edb46 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -28,7 +28,6 @@ type FederationQueuePDUs interface { InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) } @@ -38,7 +37,6 @@ type FederationQueueEDUs interface { DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error) - SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error) SelectExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) ([]int64, error) DeleteExpiredEDUs(ctx context.Context, txn *sql.Tx, expiredBefore gomatrixserverlib.Timestamp) error diff --git a/federationapi/storage/tables/outbound_peeks_table_test.go b/federationapi/storage/tables/outbound_peeks_table_test.go new file mode 100644 index 000000000..dad6b9825 --- /dev/null +++ b/federationapi/storage/tables/outbound_peeks_table_test.go @@ -0,0 +1,147 @@ +package tables_test + +import ( + "context" + "reflect" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) { + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + if err != nil { + t.Fatalf("failed to open database: %s", err) + } + var tab tables.FederationOutboundPeeks + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresOutboundPeeksTable(db) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteOutboundPeeksTable(db) + } + if err != nil { + t.Fatalf("failed to create table: %s", err) + } + return tab, close +} + +func TestOutboundPeeksTable(t *testing.T) { + ctx := context.Background() + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + _, serverName, _ := gomatrixserverlib.SplitID('@', alice.ID) + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + tab, closeDB := mustCreateOutboundpeeksTable(t, dbType) + defer closeDB() + + // Insert a peek + peekID := util.RandomString(8) + var renewalInterval int64 = 1000 + if err := tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, renewalInterval); err != nil { + t.Fatal(err) + } + + // select the newly inserted peek + outboundPeek1, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + + // Assert fields are set as expected + if outboundPeek1.PeekID != peekID { + t.Fatalf("unexpected outbound peek ID: %s, want %s", outboundPeek1.PeekID, peekID) + } + if outboundPeek1.RoomID != room.ID { + t.Fatalf("unexpected outbound peek room ID: %s, want %s", outboundPeek1.RoomID, peekID) + } + if outboundPeek1.ServerName != serverName { + t.Fatalf("unexpected outbound peek servername: %s, want %s", outboundPeek1.ServerName, serverName) + } + if outboundPeek1.RenewalInterval != renewalInterval { + t.Fatalf("unexpected outbound peek renewal interval: %d, want %d", outboundPeek1.RenewalInterval, renewalInterval) + } + + // Renew the peek + if err = tab.RenewOutboundPeek(ctx, nil, serverName, room.ID, peekID, 2000); err != nil { + t.Fatal(err) + } + + // verify the values changed + outboundPeek2, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if reflect.DeepEqual(outboundPeek1, outboundPeek2) { + t.Fatal("expected a change peek, but they are the same") + } + if outboundPeek1.ServerName != outboundPeek2.ServerName { + t.Fatalf("unexpected servername change: %s -> %s", outboundPeek1.ServerName, outboundPeek2.ServerName) + } + if outboundPeek1.RoomID != outboundPeek2.RoomID { + t.Fatalf("unexpected roomID change: %s -> %s", outboundPeek1.RoomID, outboundPeek2.RoomID) + } + + // delete the peek + if err = tab.DeleteOutboundPeek(ctx, nil, serverName, room.ID, peekID); err != nil { + t.Fatal(err) + } + + // There should be no peek anymore + peek, err := tab.SelectOutboundPeek(ctx, nil, serverName, room.ID, peekID) + if err != nil { + t.Fatal(err) + } + if peek != nil { + t.Fatalf("got a peek which should be deleted: %+v", peek) + } + + // insert some peeks + var peekIDs []string + for i := 0; i < 5; i++ { + peekID = util.RandomString(8) + if err = tab.InsertOutboundPeek(ctx, nil, serverName, room.ID, peekID, 1000); err != nil { + t.Fatal(err) + } + peekIDs = append(peekIDs, peekID) + } + + // Now select them + outboundPeeks, err := tab.SelectOutboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) != len(peekIDs) { + t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) + } + for i := range outboundPeeks { + if outboundPeeks[i].PeekID != peekIDs[i] { + t.Fatalf("") + } + } + + // And delete them again + if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil { + t.Fatal(err) + } + + // they should be gone now + outboundPeeks, err = tab.SelectOutboundPeeks(ctx, nil, room.ID) + if err != nil { + t.Fatal(err) + } + if len(outboundPeeks) > 0 { + t.Fatal("got outbound peeks which should be deleted") + } + }) +} From beea2432e6144a98370138f8d3f6334c19a044bb Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 11:31:54 +0100 Subject: [PATCH 31/67] Fix flakey test --- federationapi/storage/tables/inbound_peeks_table_test.go | 9 +++++---- .../storage/tables/outbound_peeks_table_test.go | 9 +++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/federationapi/storage/tables/inbound_peeks_table_test.go b/federationapi/storage/tables/inbound_peeks_table_test.go index 3a76a8576..e5d898b3a 100644 --- a/federationapi/storage/tables/inbound_peeks_table_test.go +++ b/federationapi/storage/tables/inbound_peeks_table_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateInboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationInboundPeeks, func()) { @@ -124,11 +125,11 @@ func TestInboundPeeksTable(t *testing.T) { if len(inboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) } - for i := range inboundPeeks { - if inboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("") - } + gotPeekIDs := make([]string, 0, len(inboundPeeks)) + for _, p := range inboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) // And delete them again if err = tab.DeleteInboundPeeks(ctx, nil, room.ID); err != nil { diff --git a/federationapi/storage/tables/outbound_peeks_table_test.go b/federationapi/storage/tables/outbound_peeks_table_test.go index dad6b9825..a460af09d 100644 --- a/federationapi/storage/tables/outbound_peeks_table_test.go +++ b/federationapi/storage/tables/outbound_peeks_table_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateOutboundpeeksTable(t *testing.T, dbType test.DBType) (tables.FederationOutboundPeeks, func()) { @@ -124,11 +125,11 @@ func TestOutboundPeeksTable(t *testing.T) { if len(outboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) } - for i := range outboundPeeks { - if outboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("") - } + gotPeekIDs := make([]string, 0, len(outboundPeeks)) + for _, p := range outboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) // And delete them again if err = tab.DeleteOutboundPeeks(ctx, nil, room.ID); err != nil { From d1d2d16738a248846ea4367fe2b33485d56db6cd Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 11:54:03 +0100 Subject: [PATCH 32/67] Fix reset password endpoint (#2921) Fixes the admin password reset endpoint. It was using a wrong variable, so could not detect the user. Adds some more checks to validate we can actually change the password. --- clientapi/admin_test.go | 134 ++++++++++++++++++++++++++++++ clientapi/clientapi.go | 5 +- clientapi/routing/admin.go | 38 +++++++-- clientapi/routing/password.go | 3 +- clientapi/routing/register.go | 24 +----- clientapi/routing/routing.go | 14 +++- docs/administration/4_adminapi.md | 7 +- internal/httputil/httpapi.go | 6 +- internal/validate.go | 44 ++++++++++ test/user.go | 4 +- 10 files changed, 237 insertions(+), 42 deletions(-) create mode 100644 clientapi/admin_test.go create mode 100644 internal/validate.go diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go new file mode 100644 index 000000000..0d973f350 --- /dev/null +++ b/clientapi/admin_test.go @@ -0,0 +1,134 @@ +package clientapi + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/tidwall/gjson" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestAdminResetPassword(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t, test.WithAccountType(uapi.AccountTypeUser)) + vhUser := &test.User{ID: "@vhuser:vh1"} + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + // add a vhost + base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, + }) + + rsAPI := roomserver.NewInternalAPI(base) + // Needed for changing the password/login + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil) + + // Create the users in the userapi and login + accessTokens := map[*test.User]string{ + aliceAdmin: "", + bob: "", + vhUser: "", + } + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + } + + testCases := []struct { + name string + requestingUser *test.User + userID string + requestOpt test.HTTPRequestOpt + wantOK bool + withHeader bool + }{ + {name: "Missing auth", requestingUser: bob, wantOK: false, userID: bob.ID}, + {name: "Bob is denied access", requestingUser: bob, wantOK: false, withHeader: true, userID: bob.ID}, + {name: "Alice is allowed access", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(8), + })}, + {name: "missing userID does not call function", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: ""}, // this 404s + {name: "rejects empty password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": "", + })}, + {name: "rejects unknown server name", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:localhost", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "rejects unknown user", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "@doesnotexist:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "allows changing password for different vhost", requestingUser: aliceAdmin, wantOK: true, withHeader: true, userID: vhUser.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(8), + })}, + {name: "rejects existing user, missing body", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID}, + {name: "rejects invalid userID", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: "!notauserid:test", requestOpt: test.WithJSONBody(t, map[string]interface{}{})}, + {name: "rejects invalid json", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, `{invalidJSON}`)}, + {name: "rejects too weak password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(6), + })}, + {name: "rejects too long password", requestingUser: aliceAdmin, wantOK: false, withHeader: true, userID: bob.ID, requestOpt: test.WithJSONBody(t, map[string]interface{}{ + "password": util.RandomString(513), + })}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID) + if tc.requestOpt != nil { + req = test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID, tc.requestOpt) + } + + if tc.withHeader { + req.Header.Set("Authorization", "Bearer "+accessTokens[tc.requestingUser]) + } + + rec := httptest.NewRecorder() + base.DendriteAdminMux.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + }) +} diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 080d4d9fa..62ffa6155 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -57,10 +57,7 @@ func AddPublicRoutes( } routing.Setup( - base.PublicClientAPIMux, - base.PublicWellKnownAPIMux, - base.SynapseAdminMux, - base.DendriteAdminMux, + base, cfg, rsAPI, asAPI, userAPI, userDirectoryProvider, federation, syncProducer, transactionsCache, fsAPI, keyAPI, diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index be8073c33..8419622df 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -7,6 +7,7 @@ import ( "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -98,20 +99,40 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi } func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { + if req.Body == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.Unknown("Missing request body"), + } + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) } - serverName := cfg.Matrix.ServerName - localpart, ok := vars["localpart"] - if !ok { + var localpart string + userID := vars["userID"] + localpart, serverName, err := cfg.Matrix.SplitLocalID('@', userID) + if err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.MissingArgument("Expecting user localpart."), + JSON: jsonerror.InvalidArgumentValue(err.Error()), } } - if l, s, err := cfg.Matrix.SplitLocalID('@', localpart); err == nil { - localpart, serverName = l, s + accAvailableResp := &userapi.QueryAccountAvailabilityResponse{} + if err = userAPI.QueryAccountAvailability(req.Context(), &userapi.QueryAccountAvailabilityRequest{ + Localpart: localpart, + ServerName: serverName, + }, accAvailableResp); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.InternalAPIError(req.Context(), err), + } + } + if accAvailableResp.Available { + return util.JSONResponse{ + Code: http.StatusNotFound, + JSON: jsonerror.Unknown("User does not exist"), + } } request := struct { Password string `json:"password"` @@ -128,6 +149,11 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap JSON: jsonerror.MissingArgument("Expecting non-empty password."), } } + + if resErr := internal.ValidatePassword(request.Password); resErr != nil { + return *resErr + } + updateReq := &userapi.PerformPasswordUpdateRequest{ Localpart: localpart, ServerName: serverName, diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index 9772f669a..cd88b025a 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -7,6 +7,7 @@ import ( "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -81,7 +82,7 @@ func Password( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. - if resErr = validatePassword(r.NewPassword); resErr != nil { + if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil { return *resErr } diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 801000f61..4abbcdf9e 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -30,6 +30,7 @@ import ( "sync" "time" + "github.com/matrix-org/dendrite/internal" "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" @@ -60,8 +61,6 @@ var ( ) const ( - minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based - maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain sessionIDLength = 24 ) @@ -315,23 +314,6 @@ func validateApplicationServiceUsername(localpart string, domain gomatrixserverl return nil } -// validatePassword returns an error response if the password is invalid -func validatePassword(password string) *util.JSONResponse { - // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if len(password) > maxPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("'password' >%d characters", maxPasswordLength)), - } - } else if len(password) > 0 && len(password) < minPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), - } - } - return nil -} - // validateRecaptcha returns an error response if the captcha response is invalid func validateRecaptcha( cfg *config.ClientAPI, @@ -636,7 +618,7 @@ func Register( return *resErr } } - if resErr := validatePassword(r.Password); resErr != nil { + if resErr := internal.ValidatePassword(r.Password); resErr != nil { return *resErr } @@ -1138,7 +1120,7 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { return *resErr } - if resErr := validatePassword(ssrr.Password); resErr != nil { + if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil { return *resErr } deviceID := "shared_secret_registration" diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index a510761eb..69b46214c 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/nats-io/nats.go" @@ -49,7 +50,7 @@ import ( // applied: // nolint: gocyclo func Setup( - publicAPIMux, wkMux, synapseAdminRouter, dendriteAdminRouter *mux.Router, + base *base.BaseDendrite, cfg *config.ClientAPI, rsAPI roomserverAPI.ClientRoomserverAPI, asAPI appserviceAPI.AppServiceInternalAPI, @@ -63,7 +64,14 @@ func Setup( extRoomsProvider api.ExtraPublicRoomsProvider, mscCfg *config.MSCs, natsClient *nats.Conn, ) { - prometheus.MustRegister(amtRegUsers, sendEventDuration) + publicAPIMux := base.PublicClientAPIMux + wkMux := base.PublicWellKnownAPIMux + synapseAdminRouter := base.SynapseAdminMux + dendriteAdminRouter := base.DendriteAdminMux + + if base.EnableMetrics { + prometheus.MustRegister(amtRegUsers, sendEventDuration) + } rateLimits := httputil.NewRateLimits(&cfg.RateLimiting) userInteractiveAuth := auth.NewUserInteractive(userAPI, cfg) @@ -631,7 +639,7 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { vars := mux.Vars(req) return AuthFallback(w, req, vars["authType"], cfg) }), diff --git a/docs/administration/4_adminapi.md b/docs/administration/4_adminapi.md index 56e19a8b4..c521cbc90 100644 --- a/docs/administration/4_adminapi.md +++ b/docs/administration/4_adminapi.md @@ -44,7 +44,9 @@ This endpoint will instruct Dendrite to part the given local `userID` in the URL all rooms which they are currently joined. A JSON body will be returned containing the room IDs of all affected rooms. -## POST `/_dendrite/admin/resetPassword/{localpart}` +## POST `/_dendrite/admin/resetPassword/{userID}` + +Reset the password of a local user. Request body format: @@ -54,9 +56,6 @@ Request body format: } ``` -Reset the password of a local user. The `localpart` is the username only, i.e. if -the full user ID is `@alice:domain.com` then the local part is `alice`. - ## GET `/_dendrite/admin/fulltext/reindex` This endpoint instructs Dendrite to reindex all searchable events (`m.room.message`, `m.room.topic` and `m.room.name`). An empty JSON body will be returned immediately. diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 127d1fac7..383913c60 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -198,7 +198,7 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // MakeHTMLAPI adds Span metrics to the HTML Handler function // This is used to serve HTML alongside JSON error messages -func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { +func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { span := opentracing.StartSpan(metricsName) defer span.Finish() @@ -211,6 +211,10 @@ func MakeHTMLAPI(metricsName string, f func(http.ResponseWriter, *http.Request) } } + if !enableMetrics { + return http.HandlerFunc(withSpan) + } + return promhttp.InstrumentHandlerCounter( promauto.NewCounterVec( prometheus.CounterOpts{ diff --git a/internal/validate.go b/internal/validate.go new file mode 100644 index 000000000..fc685ad50 --- /dev/null +++ b/internal/validate.go @@ -0,0 +1,44 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "fmt" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/util" +) + +const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based + +const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + +// ValidatePassword returns an error response if the password is invalid +func ValidatePassword(password string) *util.JSONResponse { + // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + if len(password) > maxPasswordLength { + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)), + } + } else if len(password) > 0 && len(password) < minPasswordLength { + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), + } + } + return nil +} diff --git a/test/user.go b/test/user.go index 692eae351..95a8f83e6 100644 --- a/test/user.go +++ b/test/user.go @@ -47,7 +47,7 @@ var ( type User struct { ID string - accountType api.AccountType + AccountType api.AccountType // key ID and private key of the server who has this user, if known. keyID gomatrixserverlib.KeyID privKey ed25519.PrivateKey @@ -66,7 +66,7 @@ func WithSigningServer(srvName gomatrixserverlib.ServerName, keyID gomatrixserve func WithAccountType(accountType api.AccountType) UserOpt { return func(u *User) { - u.accountType = accountType + u.AccountType = accountType } } From 09dff951d6be1fee1cc7c6872e98eb27e81fc778 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 13:04:32 +0100 Subject: [PATCH 33/67] More flakey tests --- federationapi/storage/storage_test.go | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 14efa2655..5b57d40d4 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -157,11 +157,11 @@ func TestOutboundPeeking(t *testing.T) { if len(outboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(outboundPeeks)) } - for i := range outboundPeeks { - if outboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("unexpected peek ID: %s, want %s", outboundPeeks[i].PeekID, peekIDs[i]) - } + gotPeekIDs := make([]string, 0, len(outboundPeeks)) + for _, p := range outboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } @@ -239,10 +239,10 @@ func TestInboundPeeking(t *testing.T) { if len(inboundPeeks) != len(peekIDs) { t.Fatalf("inserted %d peeks, selected %d", len(peekIDs), len(inboundPeeks)) } - for i := range inboundPeeks { - if inboundPeeks[i].PeekID != peekIDs[i] { - t.Fatalf("unexpected peek ID: %s, want %s", inboundPeeks[i].PeekID, peekIDs[i]) - } + gotPeekIDs := make([]string, 0, len(inboundPeeks)) + for _, p := range inboundPeeks { + gotPeekIDs = append(gotPeekIDs, p.PeekID) } + assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } From 5eed31fea330f5f0500384c98272b9a75a44fba4 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 22 Dec 2022 13:05:59 +0100 Subject: [PATCH 34/67] Handle guest access [1/2?] (#2872) Needs https://github.com/matrix-org/sytest/pull/1315, as otherwise the membership events aren't persisted yet when hitting `/state` after kicking guest users. Makes the following tests pass: ``` Guest users denied access over federation if guest access prohibited Guest users are kicked from guest_access rooms on revocation of guest_access Guest users are kicked from guest_access rooms on revocation of guest_access over federation ``` Todo (in a follow up PR): - Restrict access to CS API Endpoints as per https://spec.matrix.org/v1.4/client-server-api/#client-behaviour-14 Co-authored-by: kegsay --- clientapi/clientapi.go | 3 +- clientapi/routing/joinroom.go | 10 +- clientapi/routing/joinroom_test.go | 158 ++++++++++++++++++++ roomserver/api/perform.go | 1 + roomserver/internal/api.go | 20 ++- roomserver/internal/input/input.go | 4 + roomserver/internal/input/input_events.go | 105 +++++++++++++ roomserver/internal/perform/perform_join.go | 23 +++ roomserver/roomserver_test.go | 141 +++++++++++++---- setup/config/config_global.go | 2 +- setup/config/config_test.go | 54 +++++++ sytest-blacklist | 5 +- sytest-whitelist | 5 +- test/room.go | 22 ++- userapi/api/api.go | 10 ++ userapi/api/api_trace.go | 6 + userapi/internal/api.go | 5 + userapi/inthttp/client.go | 12 ++ userapi/inthttp/server.go | 5 + userapi/userapi_test.go | 61 ++++++++ 20 files changed, 607 insertions(+), 45 deletions(-) create mode 100644 clientapi/routing/joinroom_test.go diff --git a/clientapi/clientapi.go b/clientapi/clientapi.go index 62ffa6155..2d17e0928 100644 --- a/clientapi/clientapi.go +++ b/clientapi/clientapi.go @@ -15,6 +15,8 @@ package clientapi import ( + "github.com/matrix-org/gomatrixserverlib" + appserviceAPI "github.com/matrix-org/dendrite/appservice/api" "github.com/matrix-org/dendrite/clientapi/api" "github.com/matrix-org/dendrite/clientapi/producers" @@ -26,7 +28,6 @@ import ( "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/jetstream" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" ) // AddPublicRoutes sets up and registers HTTP handlers for the ClientAPI component. diff --git a/clientapi/routing/joinroom.go b/clientapi/routing/joinroom.go index c50e552bd..e371d9214 100644 --- a/clientapi/routing/joinroom.go +++ b/clientapi/routing/joinroom.go @@ -37,6 +37,7 @@ func JoinRoomByIDOrAlias( joinReq := roomserverAPI.PerformJoinRequest{ RoomIDOrAlias: roomIDOrAlias, UserID: device.UserID, + IsGuest: device.AccountType == api.AccountTypeGuest, Content: map[string]interface{}{}, } joinRes := roomserverAPI.PerformJoinResponse{} @@ -84,7 +85,14 @@ func JoinRoomByIDOrAlias( if err := rsAPI.PerformJoin(req.Context(), &joinReq, &joinRes); err != nil { done <- jsonerror.InternalAPIError(req.Context(), err) } else if joinRes.Error != nil { - done <- joinRes.Error.JSONResponse() + if joinRes.Error.Code == roomserverAPI.PerformErrorNotAllowed && device.AccountType == api.AccountTypeGuest { + done <- util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.GuestAccessForbidden(joinRes.Error.Msg), + } + } else { + done <- joinRes.Error.JSONResponse() + } } else { done <- util.JSONResponse{ Code: http.StatusOK, diff --git a/clientapi/routing/joinroom_test.go b/clientapi/routing/joinroom_test.go new file mode 100644 index 000000000..9e8208e6d --- /dev/null +++ b/clientapi/routing/joinroom_test.go @@ -0,0 +1,158 @@ +package routing + +import ( + "bytes" + "context" + "net/http" + "testing" + "time" + + "github.com/matrix-org/gomatrixserverlib" + + "github.com/matrix-org/dendrite/appservice" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestJoinRoomByIDOrAlias(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t, test.WithAccountType(uapi.AccountTypeGuest)) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + asAPI := appservice.NewInternalAPI(base, userAPI, rsAPI) + rsAPI.SetFederationAPI(nil, nil) // creates the rs.Inputer etc + + // Create the users in the userapi + for _, u := range []*test.User{alice, bob, charlie} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + } + + aliceDev := &uapi.Device{UserID: alice.ID} + bobDev := &uapi.Device{UserID: bob.ID} + charlieDev := &uapi.Device{UserID: charlie.ID, AccountType: uapi.AccountTypeGuest} + + // create a room with disabled guest access and invite Bob + resp := createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + RoomAliasName: "alias", + Invite: []string{bob.ID}, + GuestCanJoin: false, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crResp, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // create a room with guest access enabled and invite Charlie + resp = createRoom(ctx, createRoomRequest{ + Name: "testing", + IsDirect: true, + Topic: "testing", + Visibility: "public", + Preset: presetPublicChat, + Invite: []string{charlie.ID}, + GuestCanJoin: true, + }, aliceDev, &base.Cfg.ClientAPI, userAPI, rsAPI, asAPI, time.Now()) + crRespWithGuestAccess, ok := resp.JSON.(createRoomResponse) + if !ok { + t.Fatalf("response is not a createRoomResponse: %+v", resp) + } + + // Dummy request + body := &bytes.Buffer{} + req, err := http.NewRequest(http.MethodPost, "/?server_name=test", body) + if err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + device *uapi.Device + roomID string + wantHTTP200 bool + }{ + { + name: "User can join successfully by alias", + device: bobDev, + roomID: crResp.RoomAlias, + wantHTTP200: true, + }, + { + name: "User can join successfully by roomID", + device: bobDev, + roomID: crResp.RoomID, + wantHTTP200: true, + }, + { + name: "join is forbidden if user is guest", + device: charlieDev, + roomID: crResp.RoomID, + }, + { + name: "room does not exist", + device: aliceDev, + roomID: "!doesnotexist:test", + }, + { + name: "user from different server", + device: &uapi.Device{UserID: "@wrong:server"}, + roomID: crResp.RoomAlias, + }, + { + name: "user doesn't exist locally", + device: &uapi.Device{UserID: "@doesnotexist:test"}, + roomID: crResp.RoomAlias, + }, + { + name: "invalid room ID", + device: aliceDev, + roomID: "invalidRoomID", + }, + { + name: "roomAlias does not exist", + device: aliceDev, + roomID: "#doesnotexist:test", + }, + { + name: "room with guest_access event", + device: charlieDev, + roomID: crRespWithGuestAccess.RoomID, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + joinResp := JoinRoomByIDOrAlias(req, tc.device, rsAPI, userAPI, tc.roomID) + if tc.wantHTTP200 && !joinResp.Is2xx() { + t.Fatalf("expected join room to succeed, but didn't: %+v", joinResp) + } + }) + } + }) +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index e70e5ea9c..e789b9568 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -78,6 +78,7 @@ const ( type PerformJoinRequest struct { RoomIDOrAlias string `json:"room_id_or_alias"` UserID string `json:"user_id"` + IsGuest bool `json:"is_guest"` Content map[string]interface{} `json:"content"` ServerNames []gomatrixserverlib.ServerName `json:"server_names"` Unsigned map[string]interface{} `json:"unsigned"` diff --git a/roomserver/internal/api.go b/roomserver/internal/api.go index 1a3626609..451b37696 100644 --- a/roomserver/internal/api.go +++ b/roomserver/internal/api.go @@ -4,6 +4,10 @@ import ( "context" "github.com/getsentry/sentry-go" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" + asAPI "github.com/matrix-org/dendrite/appservice/api" fsAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal/caching" @@ -19,9 +23,6 @@ import ( "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/setup/process" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" - "github.com/nats-io/nats.go" - "github.com/sirupsen/logrus" ) // RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI @@ -104,6 +105,11 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio r.fsAPI = fsAPI r.KeyRing = keyRing + identity, err := r.Cfg.Matrix.SigningIdentityFor(r.ServerName) + if err != nil { + logrus.Panic(err) + } + r.Inputer = &input.Inputer{ Cfg: &r.Base.Cfg.RoomServer, Base: r.Base, @@ -114,7 +120,8 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio JetStream: r.JetStream, NATSClient: r.NATSClient, Durable: nats.Durable(r.Durable), - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, + SigningIdentity: identity, FSAPI: fsAPI, KeyRing: keyRing, ACLs: r.ServerACLs, @@ -135,7 +142,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Queryer: r.Queryer, } r.Peeker = &perform.Peeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -146,7 +153,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio Inputer: r.Inputer, } r.Unpeeker = &perform.Unpeeker{ - ServerName: r.Cfg.Matrix.ServerName, + ServerName: r.ServerName, Cfg: r.Cfg, DB: r.DB, FSAPI: r.fsAPI, @@ -193,6 +200,7 @@ func (r *RoomserverInternalAPI) SetFederationAPI(fsAPI fsAPI.RoomserverFederatio func (r *RoomserverInternalAPI) SetUserAPI(userAPI userapi.RoomserverUserAPI) { r.Leaver.UserAPI = userAPI + r.Inputer.UserAPI = userAPI } func (r *RoomserverInternalAPI) SetAppserviceAPI(asAPI asAPI.AppServiceInternalAPI) { diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index e965691c9..941311030 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -23,6 +23,8 @@ import ( "sync" "time" + userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/Arceliar/phony" "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" @@ -79,6 +81,7 @@ type Inputer struct { JetStream nats.JetStreamContext Durable nats.SubOpt ServerName gomatrixserverlib.ServerName + SigningIdentity *gomatrixserverlib.SigningIdentity FSAPI fedapi.RoomserverFederationAPI KeyRing gomatrixserverlib.JSONVerifier ACLs *acls.ServerACLs @@ -87,6 +90,7 @@ type Inputer struct { workers sync.Map // room ID -> *worker Queryer *query.Queryer + UserAPI userapi.RoomserverUserAPI } // If a room consumer is inactive for a while then we will allow NATS diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 10b8ee27f..4179fc1ef 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -19,6 +19,7 @@ package input import ( "context" "database/sql" + "encoding/json" "errors" "fmt" "time" @@ -31,6 +32,8 @@ import ( "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" + userAPI "github.com/matrix-org/dendrite/userapi/api" + fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/eventutil" @@ -440,6 +443,13 @@ func (r *Inputer) processRoomEvent( } } + // If guest_access changed and is not can_join, kick all guest users. + if event.Type() == gomatrixserverlib.MRoomGuestAccess && gjson.GetBytes(event.Content(), "guest_access").Str != "can_join" { + if err = r.kickGuests(ctx, event, roomInfo); err != nil { + logrus.WithError(err).Error("failed to kick guest users on m.room.guest_access revocation") + } + } + // Everything was OK — the latest events updater didn't error and // we've sent output events. Finally, generate a hook call. hooks.Run(hooks.KindNewEventPersisted, headered) @@ -729,3 +739,98 @@ func (r *Inputer) calculateAndSetState( succeeded = true return nil } + +// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited. +func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error { + membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true) + if err != nil { + return err + } + + memberEvents, err := r.DB.Events(ctx, membershipNIDs) + if err != nil { + return err + } + + inputEvents := make([]api.InputRoomEvent, 0, len(memberEvents)) + latestReq := &api.QueryLatestEventsAndStateRequest{ + RoomID: event.RoomID(), + } + latestRes := &api.QueryLatestEventsAndStateResponse{} + if err = r.Queryer.QueryLatestEventsAndState(ctx, latestReq, latestRes); err != nil { + return err + } + + prevEvents := latestRes.LatestEvents + for _, memberEvent := range memberEvents { + if memberEvent.StateKey() == nil { + continue + } + + localpart, senderDomain, err := gomatrixserverlib.SplitID('@', *memberEvent.StateKey()) + if err != nil { + continue + } + + accountRes := &userAPI.QueryAccountByLocalpartResponse{} + if err = r.UserAPI.QueryAccountByLocalpart(ctx, &userAPI.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: senderDomain, + }, accountRes); err != nil { + return err + } + if accountRes.Account == nil { + continue + } + + if accountRes.Account.AccountType != userAPI.AccountTypeGuest { + continue + } + + var memberContent gomatrixserverlib.MemberContent + if err = json.Unmarshal(memberEvent.Content(), &memberContent); err != nil { + return err + } + memberContent.Membership = gomatrixserverlib.Leave + + stateKey := *memberEvent.StateKey() + fledglingEvent := &gomatrixserverlib.EventBuilder{ + RoomID: event.RoomID(), + Type: gomatrixserverlib.MRoomMember, + StateKey: &stateKey, + Sender: stateKey, + PrevEvents: prevEvents, + } + + if fledglingEvent.Content, err = json.Marshal(memberContent); err != nil { + return err + } + + eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(fledglingEvent) + if err != nil { + return err + } + + event, err := eventutil.BuildEvent(ctx, fledglingEvent, r.Cfg.Matrix, r.SigningIdentity, time.Now(), &eventsNeeded, latestRes) + if err != nil { + return err + } + + inputEvents = append(inputEvents, api.InputRoomEvent{ + Kind: api.KindNew, + Event: event, + Origin: senderDomain, + SendAsServer: string(senderDomain), + }) + prevEvents = []gomatrixserverlib.EventReference{ + event.EventReference(), + } + } + + inputReq := &api.InputRoomEventsRequest{ + InputRoomEvents: inputEvents, + Asynchronous: true, // Needs to be async, as we otherwise create a deadlock + } + inputRes := &api.InputRoomEventsResponse{} + return r.InputRoomEvents(ctx, inputReq, inputRes) +} diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 4de008c66..fc7ba940c 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -16,6 +16,7 @@ package perform import ( "context" + "database/sql" "errors" "fmt" "strings" @@ -270,6 +271,28 @@ func (r *Joiner) performJoinRoomByID( } } + // If a guest is trying to join a room, check that the room has a m.room.guest_access event + if req.IsGuest { + var guestAccessEvent *gomatrixserverlib.HeaderedEvent + guestAccess := "forbidden" + guestAccessEvent, err = r.DB.GetStateEvent(ctx, req.RoomIDOrAlias, gomatrixserverlib.MRoomGuestAccess, "") + if (err != nil && !errors.Is(err, sql.ErrNoRows)) || guestAccessEvent == nil { + logrus.WithError(err).Warn("unable to get m.room.guest_access event, defaulting to 'forbidden'") + } + if guestAccessEvent != nil { + guestAccess = gjson.GetBytes(guestAccessEvent.Content(), "guest_access").String() + } + + // Servers MUST only allow guest users to join rooms if the m.room.guest_access state event + // is present on the room and has the guest_access value can_join. + if guestAccess != "can_join" { + return "", "", &rsAPI.PerformError{ + Code: rsAPI.PerformErrorNotAllowed, + Msg: "Guest access is forbidden", + } + } + } + // If we should do a forced federated join then do that. var joinedVia gomatrixserverlib.ServerName if forceFederatedJoin { diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 518bb3722..595ceb526 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -3,18 +3,23 @@ package roomserver_test import ( "context" "net/http" + "reflect" "testing" "time" "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/userapi" + + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/inthttp" "github.com/matrix-org/dendrite/roomserver/storage" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" ) @@ -29,7 +34,28 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, s return base, db, close } -func Test_SharedUsers(t *testing.T) { +func TestUsers(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + rsAPI := roomserver.NewInternalAPI(base) + // SetFederationAPI starts the room event input consumer + rsAPI.SetFederationAPI(nil, nil) + + t.Run("shared users", func(t *testing.T) { + testSharedUsers(t, rsAPI) + }) + + t.Run("kick users", func(t *testing.T) { + usrAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, nil, rsAPI, nil) + rsAPI.SetUserAPI(usrAPI) + testKickUsers(t, rsAPI, usrAPI) + }) + }) + +} + +func testSharedUsers(t *testing.T, rsAPI api.RoomserverInternalAPI) { alice := test.NewUser(t) bob := test.NewUser(t) room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) @@ -43,36 +69,93 @@ func Test_SharedUsers(t *testing.T) { }, test.WithStateKey(bob.ID)) ctx := context.Background() - test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - base, _, close := mustCreateDatabase(t, dbType) - defer close() - rsAPI := roomserver.NewInternalAPI(base) - // SetFederationAPI starts the room event input consumer - rsAPI.SetFederationAPI(nil, nil) - // Create the room - if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { - t.Fatalf("failed to send events: %v", err) + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Query the shared users for Alice, there should only be Bob. + // This is used by the SyncAPI keychange consumer. + res := &api.QuerySharedUsersResponse{} + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } + // Also verify that we get the expected result when specifying OtherUserIDs. + // This is used by the SyncAPI when getting device list changes. + if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { + t.Errorf("unable to query known users: %v", err) + } + if _, ok := res.UserIDsToCount[bob.ID]; !ok { + t.Errorf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) + } +} + +func testKickUsers(t *testing.T, rsAPI api.RoomserverInternalAPI, usrAPI userAPI.UserInternalAPI) { + // Create users and room; Bob is going to be the guest and kicked on revocation of guest access + alice := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeUser)) + bob := test.NewUser(t, test.WithAccountType(userAPI.AccountTypeGuest)) + + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat), test.GuestsCanJoin(true)) + + // Join with the guest user + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + // Create the users in the userapi, so the RSAPI can query the account type later + for _, u := range []*test.User{alice, bob} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &userAPI.PerformAccountCreationResponse{} + if err := usrAPI.PerformAccountCreation(ctx, &userAPI.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: "someRandomPassword", + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + } + + // Create the room in the database + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // Get the membership events BEFORE revoking guest access + membershipRes := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes); err != nil { + t.Errorf("failed to query membership for room: %s", err) + } + + // revoke guest access + revokeEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomGuestAccess, map[string]string{"guest_access": "forbidden"}, test.WithStateKey("")) + if err := api.SendEvents(ctx, rsAPI, api.KindNew, []*gomatrixserverlib.HeaderedEvent{revokeEvent}, "test", "test", "test", nil, false); err != nil { + t.Errorf("failed to send events: %v", err) + } + + // TODO: Even though we are sending the events sync, the "kickUsers" function is sending the events async, so we need + // to loop and wait for the events to be processed by the roomserver. + for i := 0; i <= 20; i++ { + // Get the membership events AFTER revoking guest access + membershipRes2 := &api.QueryMembershipsForRoomResponse{} + if err := rsAPI.QueryMembershipsForRoom(ctx, &api.QueryMembershipsForRoomRequest{LocalOnly: true, JoinedOnly: true, RoomID: room.ID}, membershipRes2); err != nil { + t.Errorf("failed to query membership for room: %s", err) } - // Query the shared users for Alice, there should only be Bob. - // This is used by the SyncAPI keychange consumer. - res := &api.QuerySharedUsersResponse{} - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) + // The membership events should NOT match, as Bob (guest user) should now be kicked from the room + if !reflect.DeepEqual(membershipRes, membershipRes2) { + return } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) - } - // Also verify that we get the expected result when specifying OtherUserIDs. - // This is used by the SyncAPI when getting device list changes. - if err := rsAPI.QuerySharedUsers(ctx, &api.QuerySharedUsersRequest{UserID: alice.ID, OtherUserIDs: []string{bob.ID}}, res); err != nil { - t.Fatalf("unable to query known users: %v", err) - } - if _, ok := res.UserIDsToCount[bob.ID]; !ok { - t.Fatalf("expected to find %s in shared users, but didn't: %+v", bob.ID, res.UserIDsToCount) - } - }) + time.Sleep(time.Millisecond * 10) + } + + t.Errorf("memberships didn't change in time") } func Test_QueryLeftUsers(t *testing.T) { diff --git a/setup/config/config_global.go b/setup/config/config_global.go index 511951fe6..804eb1a2d 100644 --- a/setup/config/config_global.go +++ b/setup/config/config_global.go @@ -174,7 +174,7 @@ func (c *Global) SigningIdentityFor(serverName gomatrixserverlib.ServerName) (*g return id, nil } } - return nil, fmt.Errorf("no signing identity %q", serverName) + return nil, fmt.Errorf("no signing identity for %q", serverName) } func (c *Global) SigningIdentities() []*gomatrixserverlib.SigningIdentity { diff --git a/setup/config/config_test.go b/setup/config/config_test.go index ee7e7389c..3408bf46d 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -16,8 +16,10 @@ package config import ( "fmt" + "reflect" "testing" + "github.com/matrix-org/gomatrixserverlib" "gopkg.in/yaml.v2" ) @@ -290,3 +292,55 @@ func TestUnmarshalDataUnit(t *testing.T) { } } } + +func Test_SigningIdentityFor(t *testing.T) { + tests := []struct { + name string + virtualHosts []*VirtualHost + serverName gomatrixserverlib.ServerName + want *gomatrixserverlib.SigningIdentity + wantErr bool + }{ + { + name: "no virtual hosts defined", + wantErr: true, + }, + { + name: "no identity found", + serverName: gomatrixserverlib.ServerName("doesnotexist"), + wantErr: true, + }, + { + name: "found identity", + serverName: gomatrixserverlib.ServerName("main"), + want: &gomatrixserverlib.SigningIdentity{ServerName: "main"}, + }, + { + name: "identity found on virtual hosts", + serverName: gomatrixserverlib.ServerName("vh2"), + virtualHosts: []*VirtualHost{ + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}}, + {SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh2"}}, + }, + want: &gomatrixserverlib.SigningIdentity{ServerName: "vh2"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &Global{ + VirtualHosts: tt.virtualHosts, + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "main", + }, + } + got, err := c.SigningIdentityFor(tt.serverName) + if (err != nil) != tt.wantErr { + t.Errorf("SigningIdentityFor() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SigningIdentityFor() got = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/sytest-blacklist b/sytest-blacklist index c35b03bd7..99cfbabc8 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -48,4 +48,7 @@ If a device list update goes missing, the server resyncs on the next one Leaves are present in non-gapped incremental syncs # Below test was passing for the wrong reason, failing correctly since #2858 -New federated private chats get full presence information (SYN-115) \ No newline at end of file +New federated private chats get full presence information (SYN-115) + +# We don't have any state to calculate m.room.guest_access when accepting invites +Guest users can accept invites to private rooms over federation \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 49ffb8fe8..215889a49 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -763,4 +763,7 @@ AS and main public room lists are separate local user has tags copied to the new room remote user has tags copied to the new room /upgrade moves remote aliases to the new room -Local and remote users' homeservers remove a room from their public directory on upgrade \ No newline at end of file +Local and remote users' homeservers remove a room from their public directory on upgrade +Guest users denied access over federation if guest access prohibited +Guest users are kicked from guest_access rooms on revocation of guest_access +Guest users are kicked from guest_access rooms on revocation of guest_access over federation \ No newline at end of file diff --git a/test/room.go b/test/room.go index 4328bf84f..685876cb0 100644 --- a/test/room.go +++ b/test/room.go @@ -38,11 +38,12 @@ var ( ) type Room struct { - ID string - Version gomatrixserverlib.RoomVersion - preset Preset - visibility gomatrixserverlib.HistoryVisibility - creator *User + ID string + Version gomatrixserverlib.RoomVersion + preset Preset + guestCanJoin bool + visibility gomatrixserverlib.HistoryVisibility + creator *User authEvents gomatrixserverlib.AuthEvents currentState map[string]*gomatrixserverlib.HeaderedEvent @@ -120,6 +121,11 @@ func (r *Room) insertCreateEvents(t *testing.T) { r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomPowerLevels, plContent, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomJoinRules, joinRule, WithStateKey("")) r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomHistoryVisibility, hisVis, WithStateKey("")) + if r.guestCanJoin { + r.CreateAndInsert(t, r.creator, gomatrixserverlib.MRoomGuestAccess, map[string]string{ + "guest_access": "can_join", + }, WithStateKey("")) + } } // Create an event in this room but do not insert it. Does not modify the room in any way (depth, fwd extremities, etc) so is thread-safe. @@ -268,3 +274,9 @@ func RoomVersion(ver gomatrixserverlib.RoomVersion) roomModifier { r.Version = ver } } + +func GuestsCanJoin(canJoin bool) roomModifier { + return func(t *testing.T, r *Room) { + r.guestCanJoin = canJoin + } +} diff --git a/userapi/api/api.go b/userapi/api/api.go index d3f5aefc8..4ea2e91c3 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -50,6 +50,7 @@ type KeyserverUserAPI interface { type RoomserverUserAPI interface { QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error + QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error) } // api functions required by the media api @@ -671,3 +672,12 @@ type PerformSaveThreePIDAssociationRequest struct { ServerName gomatrixserverlib.ServerName Medium string } + +type QueryAccountByLocalpartRequest struct { + Localpart string + ServerName gomatrixserverlib.ServerName +} + +type QueryAccountByLocalpartResponse struct { + Account *Account +} diff --git a/userapi/api/api_trace.go b/userapi/api/api_trace.go index ce661770f..d10b5767b 100644 --- a/userapi/api/api_trace.go +++ b/userapi/api/api_trace.go @@ -204,6 +204,12 @@ func (t *UserInternalAPITrace) PerformSaveThreePIDAssociation(ctx context.Contex return err } +func (t *UserInternalAPITrace) QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) error { + err := t.Impl.QueryAccountByLocalpart(ctx, req, res) + util.GetLogger(ctx).Infof("QueryAccountByLocalpart req=%+v res=%+v", js(req), js(res)) + return err +} + func js(thing interface{}) string { b, err := json.Marshal(thing) if err != nil { diff --git a/userapi/internal/api.go b/userapi/internal/api.go index 3f256457e..0bb480da6 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -548,6 +548,11 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc return nil } +func (a *UserInternalAPI) QueryAccountByLocalpart(ctx context.Context, req *api.QueryAccountByLocalpartRequest, res *api.QueryAccountByLocalpartResponse) (err error) { + res.Account, err = a.DB.GetAccountByLocalpart(ctx, req.Localpart, req.ServerName) + return +} + // Return the appservice 'device' or nil if the token is not an appservice. Returns an error if there was a problem // creating a 'device'. func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appServiceUserID string) (*api.Device, error) { diff --git a/userapi/inthttp/client.go b/userapi/inthttp/client.go index 87ae058c2..51b0fe3ef 100644 --- a/userapi/inthttp/client.go +++ b/userapi/inthttp/client.go @@ -60,6 +60,7 @@ const ( QueryAccountByPasswordPath = "/userapi/queryAccountByPassword" QueryLocalpartForThreePIDPath = "/userapi/queryLocalpartForThreePID" QueryThreePIDsForLocalpartPath = "/userapi/queryThreePIDsForLocalpart" + QueryAccountByLocalpartPath = "/userapi/queryAccountType" ) // NewUserAPIClient creates a UserInternalAPI implemented by talking to a HTTP POST API. @@ -440,3 +441,14 @@ func (h *httpUserInternalAPI) PerformSaveThreePIDAssociation( h.httpClient, ctx, request, response, ) } + +func (h *httpUserInternalAPI) QueryAccountByLocalpart( + ctx context.Context, + req *api.QueryAccountByLocalpartRequest, + res *api.QueryAccountByLocalpartResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryAccountByLocalpart", h.apiURL+QueryAccountByLocalpartPath, + h.httpClient, ctx, req, res, + ) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index f0579079f..b40b507c2 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -189,4 +189,9 @@ func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI, enableMetrics PerformSaveThreePIDAssociationPath, httputil.MakeInternalRPCAPI("UserAPIPerformSaveThreePIDAssociation", enableMetrics, s.PerformSaveThreePIDAssociation), ) + + internalAPIMux.Handle( + QueryAccountByLocalpartPath, + httputil.MakeInternalRPCAPI("AccountByLocalpart", enableMetrics, s.QueryAccountByLocalpart), + ) } diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 8a19af195..dada56de4 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -307,3 +307,64 @@ func TestLoginToken(t *testing.T) { }) }) } + +func TestQueryAccountByLocalpart(t *testing.T) { + alice := test.NewUser(t) + + localpart, userServername, _ := gomatrixserverlib.SplitID('@', alice.ID) + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + defer close() + + createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType) + if err != nil { + t.Error(err) + } + + testCases := func(t *testing.T, internalAPI api.UserInternalAPI) { + // Query existing account + queryAccResp := &api.QueryAccountByLocalpartResponse{} + if err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: localpart, + ServerName: userServername, + }, queryAccResp); err != nil { + t.Error(err) + } + if !reflect.DeepEqual(createdAcc, queryAccResp.Account) { + t.Fatalf("created and queried accounts don't match:\n%+v vs.\n%+v", createdAcc, queryAccResp.Account) + } + + // Query non-existent account, this should result in an error + err = internalAPI.QueryAccountByLocalpart(ctx, &api.QueryAccountByLocalpartRequest{ + Localpart: "doesnotexist", + ServerName: userServername, + }, queryAccResp) + + if err == nil { + t.Fatalf("expected an error, but got none: %+v", queryAccResp) + } + } + + t.Run("Monolith", func(t *testing.T) { + testCases(t, intAPI) + // also test tracing + testCases(t, &api.UserInternalAPITrace{Impl: intAPI}) + }) + + t.Run("HTTP API", func(t *testing.T) { + router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() + userapi.AddInternalRoutes(router, intAPI, false) + apiURL, cancel := test.ListenAndServe(t, router, false) + defer cancel() + + userHTTPApi, err := inthttp.NewUserAPIClient(apiURL, &http.Client{Timeout: time.Second * 5}) + if err != nil { + t.Fatalf("failed to create HTTP client: %s", err) + } + testCases(t, userHTTPApi) + + }) + }) +} From f47515e38b0bbf734bf977daedd836bf85465272 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 23 Dec 2022 12:52:47 +0100 Subject: [PATCH 35/67] Pushrule tweaks, make `pattern` non-optional on `EventMatchCondition` (#2918) This should fix https://github.com/matrix-org/dendrite/issues/2882 (Tested with FluffyChat 1.7.1) Also adds tests that the predefined push rules (as per the spec) is what we have in Dendrite. --- internal/pushrules/condition.go | 2 +- internal/pushrules/default_content.go | 9 +- internal/pushrules/default_override.go | 54 +++++---- internal/pushrules/default_pushrules_test.go | 111 +++++++++++++++++++ internal/pushrules/default_underride.go | 39 ++----- internal/pushrules/evaluate.go | 10 +- internal/pushrules/evaluate_test.go | 51 +++++---- internal/pushrules/pushrules.go | 10 +- internal/pushrules/util.go | 4 + internal/pushrules/validate.go | 5 +- internal/pushrules/validate_test.go | 19 ++-- userapi/consumers/roomserver_test.go | 6 - 12 files changed, 210 insertions(+), 110 deletions(-) create mode 100644 internal/pushrules/default_pushrules_test.go diff --git a/internal/pushrules/condition.go b/internal/pushrules/condition.go index 2d9773c0f..c7b30da8e 100644 --- a/internal/pushrules/condition.go +++ b/internal/pushrules/condition.go @@ -14,7 +14,7 @@ type Condition struct { // Pattern indicates the value pattern that must match. Required // for EventMatchCondition. - Pattern string `json:"pattern,omitempty"` + Pattern *string `json:"pattern,omitempty"` // Is indicates the condition that must be fulfilled. Required for // RoomMemberCountCondition. diff --git a/internal/pushrules/default_content.go b/internal/pushrules/default_content.go index 8982dd587..a055ba03c 100644 --- a/internal/pushrules/default_content.go +++ b/internal/pushrules/default_content.go @@ -15,13 +15,7 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule { RuleID: MRuleContainsUserName, Default: true, Enabled: true, - Pattern: localpart, - Conditions: []*Condition{ - { - Kind: EventMatchCondition, - Key: "content.body", - }, - }, + Pattern: &localpart, Actions: []*Action{ {Kind: NotifyAction}, { @@ -32,7 +26,6 @@ func mRuleContainsUserNameDefinition(localpart string) *Rule { { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } diff --git a/internal/pushrules/default_override.go b/internal/pushrules/default_override.go index a9788df2f..f97427b71 100644 --- a/internal/pushrules/default_override.go +++ b/internal/pushrules/default_override.go @@ -22,15 +22,15 @@ const ( MRuleTombstone = ".m.rule.tombstone" MRuleRoomNotif = ".m.rule.roomnotif" MRuleReaction = ".m.rule.reaction" + MRuleRoomACLs = ".m.rule.room.server_acl" ) var ( mRuleMasterDefinition = Rule{ - RuleID: MRuleMaster, - Default: true, - Enabled: false, - Conditions: []*Condition{}, - Actions: []*Action{{Kind: DontNotifyAction}}, + RuleID: MRuleMaster, + Default: true, + Enabled: false, + Actions: []*Action{{Kind: DontNotifyAction}}, } mRuleSuppressNoticesDefinition = Rule{ RuleID: MRuleSuppressNotices, @@ -40,7 +40,7 @@ var ( { Kind: EventMatchCondition, Key: "content.msgtype", - Pattern: "m.notice", + Pattern: pointer("m.notice"), }, }, Actions: []*Action{{Kind: DontNotifyAction}}, @@ -53,7 +53,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.member", + Pattern: pointer("m.room.member"), }, }, Actions: []*Action{{Kind: DontNotifyAction}}, @@ -73,7 +73,6 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } @@ -85,12 +84,12 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.tombstone", + Pattern: pointer("m.room.tombstone"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: "", + Pattern: pointer(""), }, }, Actions: []*Action{ @@ -98,10 +97,27 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } + mRuleACLsDefinition = Rule{ + RuleID: MRuleRoomACLs, + Default: true, + Enabled: true, + Conditions: []*Condition{ + { + Kind: EventMatchCondition, + Key: "type", + Pattern: pointer("m.room.server_acl"), + }, + { + Kind: EventMatchCondition, + Key: "state_key", + Pattern: pointer(""), + }, + }, + Actions: []*Action{}, + } mRuleRoomNotifDefinition = Rule{ RuleID: MRuleRoomNotif, Default: true, @@ -110,7 +126,7 @@ var ( { Kind: EventMatchCondition, Key: "content.body", - Pattern: "@room", + Pattern: pointer("@room"), }, { Kind: SenderNotificationPermissionCondition, @@ -122,7 +138,6 @@ var ( { Kind: SetTweakAction, Tweak: HighlightTweak, - Value: true, }, }, } @@ -134,7 +149,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.reaction", + Pattern: pointer("m.reaction"), }, }, Actions: []*Action{ @@ -152,17 +167,17 @@ func mRuleInviteForMeDefinition(userID string) *Rule { { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.member", + Pattern: pointer("m.room.member"), }, { Kind: EventMatchCondition, Key: "content.membership", - Pattern: "invite", + Pattern: pointer("invite"), }, { Kind: EventMatchCondition, Key: "state_key", - Pattern: userID, + Pattern: pointer(userID), }, }, Actions: []*Action{ @@ -172,11 +187,6 @@ func mRuleInviteForMeDefinition(userID string) *Rule { Tweak: SoundTweak, Value: "default", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } } diff --git a/internal/pushrules/default_pushrules_test.go b/internal/pushrules/default_pushrules_test.go new file mode 100644 index 000000000..dea829842 --- /dev/null +++ b/internal/pushrules/default_pushrules_test.go @@ -0,0 +1,111 @@ +package pushrules + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +// Tests that the pre-defined rules as of +// https://spec.matrix.org/v1.4/client-server-api/#predefined-rules +// are correct +func TestDefaultRules(t *testing.T) { + type testCase struct { + name string + inputBytes []byte + want Rule + } + + testCases := []testCase{ + // Default override rules + { + name: ".m.rule.master", + inputBytes: []byte(`{"rule_id":".m.rule.master","default":true,"enabled":false,"actions":["dont_notify"]}`), + want: mRuleMasterDefinition, + }, + { + name: ".m.rule.suppress_notices", + inputBytes: []byte(`{"rule_id":".m.rule.suppress_notices","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.msgtype","pattern":"m.notice"}],"actions":["dont_notify"]}`), + want: mRuleSuppressNoticesDefinition, + }, + { + name: ".m.rule.invite_for_me", + inputBytes: []byte(`{"rule_id":".m.rule.invite_for_me","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"},{"kind":"event_match","key":"content.membership","pattern":"invite"},{"kind":"event_match","key":"state_key","pattern":"@test:localhost"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: *mRuleInviteForMeDefinition("@test:localhost"), + }, + { + name: ".m.rule.member_event", + inputBytes: []byte(`{"rule_id":".m.rule.member_event","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.member"}],"actions":["dont_notify"]}`), + want: mRuleMemberEventDefinition, + }, + { + name: ".m.rule.contains_display_name", + inputBytes: []byte(`{"rule_id":".m.rule.contains_display_name","default":true,"enabled":true,"conditions":[{"kind":"contains_display_name"}],"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}]}`), + want: mRuleContainsDisplayNameDefinition, + }, + { + name: ".m.rule.tombstone", + inputBytes: []byte(`{"rule_id":".m.rule.tombstone","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.tombstone"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":["notify",{"set_tweak":"highlight"}]}`), + want: mRuleTombstoneDefinition, + }, + { + name: ".m.rule.room.server_acl", + inputBytes: []byte(`{"rule_id":".m.rule.room.server_acl","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.server_acl"},{"kind":"event_match","key":"state_key","pattern":""}],"actions":[]}`), + want: mRuleACLsDefinition, + }, + { + name: ".m.rule.roomnotif", + inputBytes: []byte(`{"rule_id":".m.rule.roomnotif","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"content.body","pattern":"@room"},{"kind":"sender_notification_permission","key":"room"}],"actions":["notify",{"set_tweak":"highlight"}]}`), + want: mRuleRoomNotifDefinition, + }, + // Default content rules + { + name: ".m.rule.contains_user_name", + inputBytes: []byte(`{"rule_id":".m.rule.contains_user_name","default":true,"enabled":true,"actions":["notify",{"set_tweak":"sound","value":"default"},{"set_tweak":"highlight"}],"pattern":"myLocalUser"}`), + want: *mRuleContainsUserNameDefinition("myLocalUser"), + }, + // default underride rules + { + name: ".m.rule.call", + inputBytes: []byte(`{"rule_id":".m.rule.call","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.call.invite"}],"actions":["notify",{"set_tweak":"sound","value":"ring"}]}`), + want: mRuleCallDefinition, + }, + { + name: ".m.rule.encrypted_room_one_to_one", + inputBytes: []byte(`{"rule_id":".m.rule.encrypted_room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: mRuleEncryptedRoomOneToOneDefinition, + }, + { + name: ".m.rule.room_one_to_one", + inputBytes: []byte(`{"rule_id":".m.rule.room_one_to_one","default":true,"enabled":true,"conditions":[{"kind":"room_member_count","is":"2"},{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify",{"set_tweak":"sound","value":"default"}]}`), + want: mRuleRoomOneToOneDefinition, + }, + { + name: ".m.rule.message", + inputBytes: []byte(`{"rule_id":".m.rule.message","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.message"}],"actions":["notify"]}`), + want: mRuleMessageDefinition, + }, + { + name: ".m.rule.encrypted", + inputBytes: []byte(`{"rule_id":".m.rule.encrypted","default":true,"enabled":true,"conditions":[{"kind":"event_match","key":"type","pattern":"m.room.encrypted"}],"actions":["notify"]}`), + want: mRuleEncryptedDefinition, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + r := Rule{} + // unmarshal predefined push rules + err := json.Unmarshal(tc.inputBytes, &r) + assert.NoError(t, err) + assert.Equal(t, tc.want, r) + + // and reverse it to check we get the expected result + got, err := json.Marshal(r) + assert.NoError(t, err) + assert.Equal(t, string(got), string(tc.inputBytes)) + }) + + } +} diff --git a/internal/pushrules/default_underride.go b/internal/pushrules/default_underride.go index 8da449a19..118bfae59 100644 --- a/internal/pushrules/default_underride.go +++ b/internal/pushrules/default_underride.go @@ -25,7 +25,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.call.invite", + Pattern: pointer("m.call.invite"), }, }, Actions: []*Action{ @@ -35,11 +35,6 @@ var ( Tweak: SoundTweak, Value: "ring", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleEncryptedRoomOneToOneDefinition = Rule{ @@ -54,7 +49,7 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.encrypted", + Pattern: pointer("m.room.encrypted"), }, }, Actions: []*Action{ @@ -64,11 +59,6 @@ var ( Tweak: SoundTweak, Value: "default", }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleRoomOneToOneDefinition = Rule{ @@ -83,20 +73,15 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.message", + Pattern: pointer("m.room.message"), }, }, Actions: []*Action{ {Kind: NotifyAction}, { Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, + Tweak: SoundTweak, + Value: "default", }, }, } @@ -108,16 +93,11 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.message", + Pattern: pointer("m.room.message"), }, }, Actions: []*Action{ {Kind: NotifyAction}, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } mRuleEncryptedDefinition = Rule{ @@ -128,16 +108,11 @@ var ( { Kind: EventMatchCondition, Key: "type", - Pattern: "m.room.encrypted", + Pattern: pointer("m.room.encrypted"), }, }, Actions: []*Action{ {Kind: NotifyAction}, - { - Kind: SetTweakAction, - Tweak: HighlightTweak, - Value: false, - }, }, } ) diff --git a/internal/pushrules/evaluate.go b/internal/pushrules/evaluate.go index 4ff9939a6..fc8e0f174 100644 --- a/internal/pushrules/evaluate.go +++ b/internal/pushrules/evaluate.go @@ -104,7 +104,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu case ContentKind: // TODO: "These configure behaviour for (unencrypted) messages // that match certain patterns." - Does that mean "content.body"? - return patternMatches("content.body", rule.Pattern, event) + if rule.Pattern == nil { + return false, nil + } + return patternMatches("content.body", *rule.Pattern, event) case RoomKind: return rule.RuleID == event.RoomID(), nil @@ -120,7 +123,10 @@ func ruleMatches(rule *Rule, kind Kind, event *gomatrixserverlib.Event, ec Evalu func conditionMatches(cond *Condition, event *gomatrixserverlib.Event, ec EvaluationContext) (bool, error) { switch cond.Kind { case EventMatchCondition: - return patternMatches(cond.Key, cond.Pattern, event) + if cond.Pattern == nil { + return false, fmt.Errorf("missing condition pattern") + } + return patternMatches(cond.Key, *cond.Pattern, event) case ContainsDisplayNameCondition: return patternMatches("content.body", ec.UserDisplayName(), event) diff --git a/internal/pushrules/evaluate_test.go b/internal/pushrules/evaluate_test.go index c5d5abd2a..ca8ae5519 100644 --- a/internal/pushrules/evaluate_test.go +++ b/internal/pushrules/evaluate_test.go @@ -79,8 +79,8 @@ func TestRuleMatches(t *testing.T) { {"underrideConditionMatch", UnderrideKind, Rule{Enabled: true}, `{}`, true}, {"underrideConditionNoMatch", UnderrideKind, Rule{Enabled: true, Conditions: []*Condition{{}}}, `{}`, false}, - {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: "b"}, `{"content":{"body":"abc"}}`, true}, - {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: "d"}, `{"content":{"body":"abc"}}`, false}, + {"contentMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("b")}, `{"content":{"body":"abc"}}`, true}, + {"contentNoMatch", ContentKind, Rule{Enabled: true, Pattern: pointer("d")}, `{"content":{"body":"abc"}}`, false}, {"roomMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!room@example.com"}`, true}, {"roomNoMatch", RoomKind, Rule{Enabled: true, RuleID: "!room@example.com"}, `{"room_id":"!otherroom@example.com"}`, false}, @@ -106,41 +106,44 @@ func TestConditionMatches(t *testing.T) { Name string Cond Condition EventJSON string - Want bool + WantMatch bool + WantErr bool }{ - {"empty", Condition{}, `{}`, false}, - {"empty", Condition{Kind: "unknownstring"}, `{}`, false}, + {Name: "empty", Cond: Condition{}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "empty", Cond: Condition{Kind: "unknownstring"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, // Neither of these should match because `content` is not a full string match, // and `content.body` is not a string value. - {"eventMatch", Condition{Kind: EventMatchCondition, Key: "content"}, `{"content":{}}`, false}, - {"eventBodyMatch", Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3"}, `{"content":{"body": 3}}`, false}, + {Name: "eventMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content", Pattern: pointer("")}, EventJSON: `{"content":{}}`, WantMatch: false, WantErr: false}, + {Name: "eventBodyMatch", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Is: "3", Pattern: pointer("")}, EventJSON: `{"content":{"body": "3"}}`, WantMatch: false, WantErr: false}, + {Name: "eventBodyMatch matches", Cond: Condition{Kind: EventMatchCondition, Key: "content.body", Pattern: pointer("world")}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: true, WantErr: false}, + {Name: "EventMatch missing pattern", Cond: Condition{Kind: EventMatchCondition, Key: "content.body"}, EventJSON: `{"content":{"body": "hello world!"}}`, WantMatch: false, WantErr: true}, - {"displayNameNoMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"something without displayname"}}`, false}, - {"displayNameMatch", Condition{Kind: ContainsDisplayNameCondition}, `{"content":{"body":"hello Dear User, how are you?"}}`, true}, + {Name: "displayNameNoMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"something without displayname"}}`, WantMatch: false, WantErr: false}, + {Name: "displayNameMatch", Cond: Condition{Kind: ContainsDisplayNameCondition}, EventJSON: `{"content":{"body":"hello Dear User, how are you?"}}`, WantMatch: true, WantErr: false}, - {"roomMemberCountLessNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<2"}, `{}`, false}, - {"roomMemberCountLessMatch", Condition{Kind: RoomMemberCountCondition, Is: "<3"}, `{}`, true}, - {"roomMemberCountLessEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, `{}`, false}, - {"roomMemberCountLessEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, `{}`, true}, - {"roomMemberCountEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: "==1"}, `{}`, false}, - {"roomMemberCountEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: "==2"}, `{}`, true}, - {"roomMemberCountGreaterEqualNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, `{}`, false}, - {"roomMemberCountGreaterEqualMatch", Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, `{}`, true}, - {"roomMemberCountGreaterNoMatch", Condition{Kind: RoomMemberCountCondition, Is: ">2"}, `{}`, false}, - {"roomMemberCountGreaterMatch", Condition{Kind: RoomMemberCountCondition, Is: ">1"}, `{}`, true}, + {Name: "roomMemberCountLessNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<2"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountLessMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<3"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountLessEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=1"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountLessEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "<=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==1"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: "==2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountGreaterEqualNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=3"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountGreaterEqualMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">=2"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, + {Name: "roomMemberCountGreaterNoMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">2"}, EventJSON: `{}`, WantMatch: false, WantErr: false}, + {Name: "roomMemberCountGreaterMatch", Cond: Condition{Kind: RoomMemberCountCondition, Is: ">1"}, EventJSON: `{}`, WantMatch: true, WantErr: false}, - {"senderNotificationPermissionMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@poweruser:example.com"}`, true}, - {"senderNotificationPermissionNoMatch", Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, `{"sender":"@nobody:example.com"}`, false}, + {Name: "senderNotificationPermissionMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@poweruser:example.com"}`, WantMatch: true, WantErr: false}, + {Name: "senderNotificationPermissionNoMatch", Cond: Condition{Kind: SenderNotificationPermissionCondition, Key: "powerlevel"}, EventJSON: `{"sender":"@nobody:example.com"}`, WantMatch: false, WantErr: false}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { got, err := conditionMatches(&tst.Cond, mustEventFromJSON(t, tst.EventJSON), &fakeEvaluationContext{2}) - if err != nil { + if err != nil && !tst.WantErr { t.Fatalf("conditionMatches failed: %v", err) } - if got != tst.Want { - t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.Want, tst.Name) + if got != tst.WantMatch { + t.Errorf("conditionMatches: got %v, want %v on %s", got, tst.WantMatch, tst.Name) } }) } diff --git a/internal/pushrules/pushrules.go b/internal/pushrules/pushrules.go index bbed1f95f..98deaf132 100644 --- a/internal/pushrules/pushrules.go +++ b/internal/pushrules/pushrules.go @@ -36,18 +36,18 @@ type Rule struct { // around. Required. Enabled bool `json:"enabled"` + // Conditions provide the rule's conditions for OverrideKind and + // UnderrideKind. Not allowed for other kinds. + Conditions []*Condition `json:"conditions,omitempty"` + // Actions describe the desired outcome, should the rule // match. Required. Actions []*Action `json:"actions"` - // Conditions provide the rule's conditions for OverrideKind and - // UnderrideKind. Not allowed for other kinds. - Conditions []*Condition `json:"conditions"` - // Pattern is the body pattern to match for ContentKind. Required // for that kind. The interpretation is the same as that of // Condition.Pattern. - Pattern string `json:"pattern"` + Pattern *string `json:"pattern,omitempty"` } // Scope only has one valid value. See also AccountRuleSets. diff --git a/internal/pushrules/util.go b/internal/pushrules/util.go index fb9c05be2..de8fe5cd0 100644 --- a/internal/pushrules/util.go +++ b/internal/pushrules/util.go @@ -128,3 +128,7 @@ func parseRoomMemberCountCondition(s string) (func(int) bool, error) { b = int(v) return cmp, nil } + +func pointer[t any](s t) *t { + return &s +} diff --git a/internal/pushrules/validate.go b/internal/pushrules/validate.go index 5d260f0b9..f50c51bd7 100644 --- a/internal/pushrules/validate.go +++ b/internal/pushrules/validate.go @@ -34,7 +34,10 @@ func ValidateRule(kind Kind, rule *Rule) []error { } case ContentKind: - if rule.Pattern == "" { + if rule.Pattern == nil { + errs = append(errs, fmt.Errorf("missing content rule pattern")) + } + if rule.Pattern != nil && *rule.Pattern == "" { errs = append(errs, fmt.Errorf("missing content rule pattern")) } diff --git a/internal/pushrules/validate_test.go b/internal/pushrules/validate_test.go index b276eb551..966e46259 100644 --- a/internal/pushrules/validate_test.go +++ b/internal/pushrules/validate_test.go @@ -12,15 +12,16 @@ func TestValidateRuleNegatives(t *testing.T) { Rule Rule WantErrString string }{ - {"emptyRuleID", OverrideKind, Rule{}, "invalid rule ID"}, - {"invalidKind", Kind("something else"), Rule{}, "invalid rule kind"}, - {"ruleIDBackslash", OverrideKind, Rule{RuleID: "#foo\\:example.com"}, "invalid rule ID"}, - {"noActions", OverrideKind, Rule{}, "missing actions"}, - {"invalidAction", OverrideKind, Rule{Actions: []*Action{{}}}, "invalid rule action kind"}, - {"invalidCondition", OverrideKind, Rule{Conditions: []*Condition{{}}}, "invalid rule condition kind"}, - {"overrideNoCondition", OverrideKind, Rule{}, "missing rule conditions"}, - {"underrideNoCondition", UnderrideKind, Rule{}, "missing rule conditions"}, - {"contentNoPattern", ContentKind, Rule{}, "missing content rule pattern"}, + {Name: "emptyRuleID", Kind: OverrideKind, Rule: Rule{}, WantErrString: "invalid rule ID"}, + {Name: "invalidKind", Kind: Kind("something else"), Rule: Rule{}, WantErrString: "invalid rule kind"}, + {Name: "ruleIDBackslash", Kind: OverrideKind, Rule: Rule{RuleID: "#foo\\:example.com"}, WantErrString: "invalid rule ID"}, + {Name: "noActions", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing actions"}, + {Name: "invalidAction", Kind: OverrideKind, Rule: Rule{Actions: []*Action{{}}}, WantErrString: "invalid rule action kind"}, + {Name: "invalidCondition", Kind: OverrideKind, Rule: Rule{Conditions: []*Condition{{}}}, WantErrString: "invalid rule condition kind"}, + {Name: "overrideNoCondition", Kind: OverrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"}, + {Name: "underrideNoCondition", Kind: UnderrideKind, Rule: Rule{}, WantErrString: "missing rule conditions"}, + {Name: "contentNoPattern", Kind: ContentKind, Rule: Rule{}, WantErrString: "missing content rule pattern"}, + {Name: "contentEmptyPattern", Kind: ContentKind, Rule: Rule{Pattern: pointer("")}, WantErrString: "missing content rule pattern"}, } for _, tst := range tsts { t.Run(tst.Name, func(t *testing.T) { diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 265e3a3aa..39f4aab4a 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -81,11 +81,6 @@ func Test_evaluatePushRules(t *testing.T) { wantAction: pushrules.NotifyAction, wantActions: []*pushrules.Action{ {Kind: pushrules.NotifyAction}, - { - Kind: pushrules.SetTweakAction, - Tweak: pushrules.HighlightTweak, - Value: false, - }, }, }, { @@ -103,7 +98,6 @@ func Test_evaluatePushRules(t *testing.T) { { Kind: pushrules.SetTweakAction, Tweak: pushrules.HighlightTweak, - Value: true, }, }, }, From f762ce1050f2add409a83b1eeb6da5940177cfa7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 23 Dec 2022 14:11:11 +0100 Subject: [PATCH 36/67] Add clientapi tests (#2916) This PR - adds several tests for the clientapi, mostly around `/register` and auth fallback. - removes the now deprecated `homeserver` field from responses to `/register` and `/login` - slightly refactors auth fallback handling --- .github/workflows/dendrite.yml | 3 +- clientapi/routing/admin.go | 6 +- clientapi/routing/auth_fallback.go | 115 ++++----- clientapi/routing/auth_fallback_test.go | 149 ++++++++++++ clientapi/routing/login.go | 9 +- clientapi/routing/password.go | 4 +- clientapi/routing/register.go | 148 ++++-------- clientapi/routing/register_test.go | 306 ++++++++++++++++++++++++ clientapi/routing/routing.go | 4 +- cmd/create-account/main.go | 32 +-- internal/httputil/httpapi.go | 9 +- internal/validate.go | 84 ++++++- internal/validate_test.go | 170 +++++++++++++ setup/config/config.go | 12 +- setup/config/config_clientapi.go | 7 +- 15 files changed, 838 insertions(+), 220 deletions(-) create mode 100644 clientapi/routing/auth_fallback_test.go create mode 100644 internal/validate_test.go diff --git a/.github/workflows/dendrite.yml b/.github/workflows/dendrite.yml index 2c04005d2..1de39850d 100644 --- a/.github/workflows/dendrite.yml +++ b/.github/workflows/dendrite.yml @@ -331,8 +331,7 @@ jobs: postgres: postgres api: full-http container: - # Temporary for debugging to see if this image is working better. - image: matrixdotorg/sytest-dendrite@sha256:434ad464a9f4ed3f8c3cc47200275b6ccb5c5031a8063daf4acea62be5a23c73 + image: matrixdotorg/sytest-dendrite volumes: - ${{ github.workspace }}:/src - /root/.cache/go-build:/github/home/.cache/go-build diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index 8419622df..dbd913376 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -137,7 +137,7 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap request := struct { Password string `json:"password"` }{} - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + if err = json.NewDecoder(req.Body).Decode(&request); err != nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown("Failed to decode request body: " + err.Error()), @@ -150,8 +150,8 @@ func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userap } } - if resErr := internal.ValidatePassword(request.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(request.Password); err != nil { + return *internal.PasswordResponse(err) } updateReq := &userapi.PerformPasswordUpdateRequest{ diff --git a/clientapi/routing/auth_fallback.go b/clientapi/routing/auth_fallback.go index ad870993e..f8d3684fe 100644 --- a/clientapi/routing/auth_fallback.go +++ b/clientapi/routing/auth_fallback.go @@ -15,11 +15,11 @@ package routing import ( + "fmt" "html/template" "net/http" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/util" ) @@ -101,14 +101,28 @@ func serveTemplate(w http.ResponseWriter, templateHTML string, data map[string]s func AuthFallback( w http.ResponseWriter, req *http.Request, authType string, cfg *config.ClientAPI, -) *util.JSONResponse { - sessionID := req.URL.Query().Get("session") +) { + // We currently only support "m.login.recaptcha", so fail early if that's not requested + if authType == authtypes.LoginTypeRecaptcha { + if !cfg.RecaptchaEnabled { + writeHTTPMessage(w, req, + "Recaptcha login is disabled on this Homeserver", + http.StatusBadRequest, + ) + return + } + } else { + writeHTTPMessage(w, req, fmt.Sprintf("Unknown authtype %q", authType), http.StatusNotImplemented) + return + } + sessionID := req.URL.Query().Get("session") if sessionID == "" { - return writeHTTPMessage(w, req, + writeHTTPMessage(w, req, "Session ID not provided", http.StatusBadRequest, ) + return } serveRecaptcha := func() { @@ -130,70 +144,44 @@ func AuthFallback( if req.Method == http.MethodGet { // Handle Recaptcha - if authType == authtypes.LoginTypeRecaptcha { - if err := checkRecaptchaEnabled(cfg, w, req); err != nil { - return err - } - - serveRecaptcha() - return nil - } - return &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown auth stage type"), - } + serveRecaptcha() + return } else if req.Method == http.MethodPost { // Handle Recaptcha - if authType == authtypes.LoginTypeRecaptcha { - if err := checkRecaptchaEnabled(cfg, w, req); err != nil { - return err - } - - clientIP := req.RemoteAddr - err := req.ParseForm() - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed") - res := jsonerror.InternalServerError() - return &res - } - - response := req.Form.Get(cfg.RecaptchaFormField) - if err := validateRecaptcha(cfg, response, clientIP); err != nil { - util.GetLogger(req.Context()).Error(err) - return err - } - - // Success. Add recaptcha as a completed login flow - sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) - - serveSuccess() - return nil + clientIP := req.RemoteAddr + err := req.ParseForm() + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("req.ParseForm failed") + w.WriteHeader(http.StatusBadRequest) + serveRecaptcha() + return } - return &util.JSONResponse{ - Code: http.StatusNotFound, - JSON: jsonerror.NotFound("Unknown auth stage type"), + response := req.Form.Get(cfg.RecaptchaFormField) + err = validateRecaptcha(cfg, response, clientIP) + switch err { + case ErrMissingResponse: + w.WriteHeader(http.StatusBadRequest) + serveRecaptcha() // serve the initial page again, instead of nothing + return + case ErrInvalidCaptcha: + w.WriteHeader(http.StatusUnauthorized) + serveRecaptcha() + return + case nil: + default: // something else failed + util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") + serveRecaptcha() + return } - } - return &util.JSONResponse{ - Code: http.StatusMethodNotAllowed, - JSON: jsonerror.NotFound("Bad method"), - } -} -// checkRecaptchaEnabled creates an error response if recaptcha is not usable on homeserver. -func checkRecaptchaEnabled( - cfg *config.ClientAPI, - w http.ResponseWriter, - req *http.Request, -) *util.JSONResponse { - if !cfg.RecaptchaEnabled { - return writeHTTPMessage(w, req, - "Recaptcha login is disabled on this Homeserver", - http.StatusBadRequest, - ) + // Success. Add recaptcha as a completed login flow + sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypeRecaptcha) + + serveSuccess() + return } - return nil + writeHTTPMessage(w, req, "Bad method", http.StatusMethodNotAllowed) } // writeHTTPMessage writes the given header and message to the HTTP response writer. @@ -201,13 +189,10 @@ func checkRecaptchaEnabled( func writeHTTPMessage( w http.ResponseWriter, req *http.Request, message string, header int, -) *util.JSONResponse { +) { w.WriteHeader(header) _, err := w.Write([]byte(message)) if err != nil { util.GetLogger(req.Context()).WithError(err).Error("w.Write failed") - res := jsonerror.InternalServerError() - return &res } - return nil } diff --git a/clientapi/routing/auth_fallback_test.go b/clientapi/routing/auth_fallback_test.go new file mode 100644 index 000000000..0d77f9a01 --- /dev/null +++ b/clientapi/routing/auth_fallback_test.go @@ -0,0 +1,149 @@ +package routing + +import ( + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test/testrig" +) + +func Test_AuthFallback(t *testing.T) { + base, _, _ := testrig.Base(nil) + defer base.Close() + + for _, useHCaptcha := range []bool{false, true} { + for _, recaptchaEnabled := range []bool{false, true} { + for _, wantErr := range []bool{false, true} { + t.Run(fmt.Sprintf("useHCaptcha(%v) - recaptchaEnabled(%v) - wantErr(%v)", useHCaptcha, recaptchaEnabled, wantErr), func(t *testing.T) { + // Set the defaults for each test + base.Cfg.ClientAPI.Defaults(config.DefaultOpts{Generate: true, Monolithic: true}) + base.Cfg.ClientAPI.RecaptchaEnabled = recaptchaEnabled + base.Cfg.ClientAPI.RecaptchaPublicKey = "pub" + base.Cfg.ClientAPI.RecaptchaPrivateKey = "priv" + if useHCaptcha { + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = "https://hcaptcha.com/siteverify" + base.Cfg.ClientAPI.RecaptchaApiJsUrl = "https://js.hcaptcha.com/1/api.js" + base.Cfg.ClientAPI.RecaptchaFormField = "h-captcha-response" + base.Cfg.ClientAPI.RecaptchaSitekeyClass = "h-captcha" + } + cfgErrs := &config.ConfigErrors{} + base.Cfg.ClientAPI.Verify(cfgErrs, true) + if len(*cfgErrs) > 0 { + t.Fatalf("(hCaptcha=%v) unexpected config errors: %s", useHCaptcha, cfgErrs.Error()) + } + + req := httptest.NewRequest(http.MethodGet, "/?session=1337", nil) + rec := httptest.NewRecorder() + + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if !recaptchaEnabled { + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest) + } + if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" { + t.Fatalf("unexpected response body: %s", rec.Body.String()) + } + } else { + if !strings.Contains(rec.Body.String(), base.Cfg.ClientAPI.RecaptchaSitekeyClass) { + t.Fatalf("body does not contain %s: %s", base.Cfg.ClientAPI.RecaptchaSitekeyClass, rec.Body.String()) + } + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if wantErr { + _, _ = w.Write([]byte(`{"success":false}`)) + return + } + _, _ = w.Write([]byte(`{"success":true}`)) + })) + defer srv.Close() // nolint: errcheck + + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL + + // check the result after sending the captcha + req = httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + req.Form = url.Values{} + req.Form.Add(base.Cfg.ClientAPI.RecaptchaFormField, "someRandomValue") + rec = httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if recaptchaEnabled { + if !wantErr { + if rec.Code != http.StatusOK { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusOK) + } + if rec.Body.String() != successTemplate { + t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), successTemplate) + } + } else { + if rec.Code != http.StatusUnauthorized { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusUnauthorized) + } + wantString := "Authentication" + if !strings.Contains(rec.Body.String(), wantString) { + t.Fatalf("expected response to contain '%s', but didn't: %s", wantString, rec.Body.String()) + } + } + } else { + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected response code: %d, want %d", rec.Code, http.StatusBadRequest) + } + if rec.Body.String() != "Recaptcha login is disabled on this Homeserver" { + t.Fatalf("unexpected response: %s, want %s", rec.Body.String(), "successTemplate") + } + } + }) + } + } + } + + t.Run("unknown fallbacks are handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, "DoesNotExist", &base.Cfg.ClientAPI) + if rec.Code != http.StatusNotImplemented { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusNotImplemented) + } + }) + + t.Run("unknown methods are handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodDelete, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusMethodNotAllowed { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusMethodNotAllowed) + } + }) + + t.Run("missing session parameter is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing session parameter is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) + + t.Run("missing 'response' is handled correctly", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/?session=1337", nil) + rec := httptest.NewRecorder() + AuthFallback(rec, req, authtypes.LoginTypeRecaptcha, &base.Cfg.ClientAPI) + if rec.Code != http.StatusBadRequest { + t.Fatalf("unexpected http status: %d, want %d", rec.Code, http.StatusBadRequest) + } + }) +} diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 0de324da1..778c8c0c3 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -23,15 +23,13 @@ import ( "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" - "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) type loginResponse struct { - UserID string `json:"user_id"` - AccessToken string `json:"access_token"` - HomeServer gomatrixserverlib.ServerName `json:"home_server"` - DeviceID string `json:"device_id"` + UserID string `json:"user_id"` + AccessToken string `json:"access_token"` + DeviceID string `json:"device_id"` } type flows struct { @@ -116,7 +114,6 @@ func completeAuth( JSON: loginResponse{ UserID: performRes.Device.UserID, AccessToken: performRes.Device.AccessToken, - HomeServer: serverName, DeviceID: performRes.Device.ID, }, } diff --git a/clientapi/routing/password.go b/clientapi/routing/password.go index cd88b025a..f7f9da622 100644 --- a/clientapi/routing/password.go +++ b/clientapi/routing/password.go @@ -82,8 +82,8 @@ func Password( sessions.addCompletedSessionStage(sessionID, authtypes.LoginTypePassword) // Check the new password strength. - if resErr = internal.ValidatePassword(r.NewPassword); resErr != nil { - return *resErr + if err := internal.ValidatePassword(r.NewPassword); err != nil { + return *internal.PasswordResponse(err) } // Get the local part. diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 4abbcdf9e..6087bda0c 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -18,12 +18,12 @@ package routing import ( "context" "encoding/json" + "errors" "fmt" "io" "net" "net/http" "net/url" - "regexp" "sort" "strconv" "strings" @@ -60,10 +60,7 @@ var ( ) ) -const ( - maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain - sessionIDLength = 24 -) +const sessionIDLength = 24 // sessionsDict keeps track of completed auth stages for each session. // It shouldn't be passed by value because it contains a mutex. @@ -198,8 +195,7 @@ func (d *sessionsDict) getDeviceToDelete(sessionID string) (string, bool) { } var ( - sessions = newSessionsDict() - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) + sessions = newSessionsDict() ) // registerRequest represents the submitted registration request. @@ -262,10 +258,9 @@ func newUserInteractiveResponse( // http://matrix.org/speculator/spec/HEAD/client_server/unstable.html#post-matrix-client-unstable-register type registerResponse struct { - UserID string `json:"user_id"` - AccessToken string `json:"access_token,omitempty"` - HomeServer gomatrixserverlib.ServerName `json:"home_server"` - DeviceID string `json:"device_id,omitempty"` + UserID string `json:"user_id"` + AccessToken string `json:"access_token,omitempty"` + DeviceID string `json:"device_id,omitempty"` } // recaptchaResponse represents the HTTP response from a Google Recaptcha server @@ -276,66 +271,28 @@ type recaptchaResponse struct { ErrorCodes []int `json:"error-codes"` } -// validateUsername returns an error response if the username is invalid -func validateUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { - // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 - if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), - } - } else if !validUsernameRegex.MatchString(localpart) { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - } - } else if localpart[0] == '_' { // Regex checks its not a zero length string - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username cannot start with a '_'"), - } - } - return nil -} - -// validateApplicationServiceUsername returns an error response if the username is invalid for an application service -func validateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) *util.JSONResponse { - if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("%q exceeds the maximum length of %d characters", id, maxUsernameLength)), - } - } else if !validUsernameRegex.MatchString(localpart) { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.InvalidUsername("Username can only contain characters a-z, 0-9, or '_-./='"), - } - } - return nil -} +var ( + ErrInvalidCaptcha = errors.New("invalid captcha response") + ErrMissingResponse = errors.New("captcha response is required") + ErrCaptchaDisabled = errors.New("captcha registration is disabled") +) // validateRecaptcha returns an error response if the captcha response is invalid func validateRecaptcha( cfg *config.ClientAPI, response string, clientip string, -) *util.JSONResponse { +) error { ip, _, _ := net.SplitHostPort(clientip) if !cfg.RecaptchaEnabled { - return &util.JSONResponse{ - Code: http.StatusConflict, - JSON: jsonerror.Unknown("Captcha registration is disabled"), - } + return ErrCaptchaDisabled } if response == "" { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("Captcha response is required"), - } + return ErrMissingResponse } - // Make a POST request to Google's API to check the captcha response + // Make a POST request to the captcha provider API to check the captcha response resp, err := http.PostForm(cfg.RecaptchaSiteVerifyAPI, url.Values{ "secret": {cfg.RecaptchaPrivateKey}, @@ -345,10 +302,7 @@ func validateRecaptcha( ) if err != nil { - return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in requesting validation of captcha response"), - } + return err } // Close the request once we're finishing reading from it @@ -358,25 +312,16 @@ func validateRecaptcha( var r recaptchaResponse body, err := io.ReadAll(resp.Body) if err != nil { - return &util.JSONResponse{ - Code: http.StatusGatewayTimeout, - JSON: jsonerror.Unknown("Error in contacting captcha server" + err.Error()), - } + return err } err = json.Unmarshal(body, &r) if err != nil { - return &util.JSONResponse{ - Code: http.StatusInternalServerError, - JSON: jsonerror.BadJSON("Error in unmarshaling captcha server's response: " + err.Error()), - } + return err } // Check that we received a "success" if !r.Success { - return &util.JSONResponse{ - Code: http.StatusUnauthorized, - JSON: jsonerror.BadJSON("Invalid captcha response. Please try again."), - } + return ErrInvalidCaptcha } return nil } @@ -508,8 +453,8 @@ func validateApplicationService( } // Check username application service is trying to register is valid - if err := validateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { - return "", err + if err := internal.ValidateApplicationServiceUsername(username, cfg.Matrix.ServerName); err != nil { + return "", internal.UsernameResponse(err) } // No errors, registration valid @@ -564,15 +509,12 @@ func Register( if resErr := httputil.UnmarshalJSON(reqBody, &r); resErr != nil { return *resErr } - if l, d, err := cfg.Matrix.SplitLocalID('@', r.Username); err == nil { - r.Username, r.ServerName = l, d - } if req.URL.Query().Get("kind") == "guest" { return handleGuestRegistration(req, r, cfg, userAPI) } // Don't allow numeric usernames less than MAX_INT64. - if _, err := strconv.ParseInt(r.Username, 10, 64); err == nil { + if _, err = strconv.ParseInt(r.Username, 10, 64); err == nil { return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), @@ -584,7 +526,7 @@ func Register( ServerName: r.ServerName, } nres := &userapi.QueryNumericLocalpartResponse{} - if err := userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { + if err = userAPI.QueryNumericLocalpart(req.Context(), nreq, nres); err != nil { util.GetLogger(req.Context()).WithError(err).Error("userAPI.QueryNumericLocalpart failed") return jsonerror.InternalServerError() } @@ -601,8 +543,8 @@ func Register( case r.Type == authtypes.LoginTypeApplicationService && accessTokenErr == nil: // Spec-compliant case (the access_token is specified and the login type // is correctly set, so it's an appservice registration) - if resErr := validateApplicationServiceUsername(r.Username, r.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateApplicationServiceUsername(r.Username, r.ServerName); err != nil { + return *internal.UsernameResponse(err) } case accessTokenErr == nil: // Non-spec-compliant case (the access_token is specified but the login @@ -614,12 +556,12 @@ func Register( default: // Spec-compliant case (neither the access_token nor the login type are // specified, so it's a normal user registration) - if resErr := validateUsername(r.Username, r.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateUsername(r.Username, r.ServerName); err != nil { + return *internal.UsernameResponse(err) } } - if resErr := internal.ValidatePassword(r.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(r.Password); err != nil { + return *internal.PasswordResponse(err) } logger := util.GetLogger(req.Context()) @@ -697,7 +639,6 @@ func handleGuestRegistration( JSON: registerResponse{ UserID: devRes.Device.UserID, AccessToken: devRes.Device.AccessToken, - HomeServer: res.Account.ServerName, DeviceID: devRes.Device.ID, }, } @@ -761,9 +702,18 @@ func handleRegistrationFlow( switch r.Auth.Type { case authtypes.LoginTypeRecaptcha: // Check given captcha response - resErr := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) - if resErr != nil { - return *resErr + err := validateRecaptcha(cfg, r.Auth.Response, req.RemoteAddr) + switch err { + case ErrCaptchaDisabled: + return util.JSONResponse{Code: http.StatusForbidden, JSON: jsonerror.Unknown(err.Error())} + case ErrMissingResponse: + return util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(err.Error())} + case ErrInvalidCaptcha: + return util.JSONResponse{Code: http.StatusUnauthorized, JSON: jsonerror.BadJSON(err.Error())} + case nil: + default: + util.GetLogger(req.Context()).WithError(err).Error("failed to validate recaptcha") + return util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()} } // Add Recaptcha to the list of completed registration stages @@ -924,8 +874,7 @@ func completeRegistration( return util.JSONResponse{ Code: http.StatusOK, JSON: registerResponse{ - UserID: userutil.MakeUserID(username, accRes.Account.ServerName), - HomeServer: accRes.Account.ServerName, + UserID: userutil.MakeUserID(username, accRes.Account.ServerName), }, } } @@ -958,7 +907,6 @@ func completeRegistration( result := registerResponse{ UserID: devRes.Device.UserID, AccessToken: devRes.Device.AccessToken, - HomeServer: accRes.Account.ServerName, DeviceID: devRes.Device.ID, } sessions.addCompletedRegistration(sessionID, result) @@ -1054,8 +1002,8 @@ func RegisterAvailable( } } - if err := validateUsername(username, domain); err != nil { - return *err + if err := internal.ValidateUsername(username, domain); err != nil { + return *internal.UsernameResponse(err) } // Check if this username is reserved by an application service @@ -1117,11 +1065,11 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien // downcase capitals ssrr.User = strings.ToLower(ssrr.User) - if resErr := validateUsername(ssrr.User, cfg.Matrix.ServerName); resErr != nil { - return *resErr + if err = internal.ValidateUsername(ssrr.User, cfg.Matrix.ServerName); err != nil { + return *internal.UsernameResponse(err) } - if resErr := internal.ValidatePassword(ssrr.Password); resErr != nil { - return *resErr + if err = internal.ValidatePassword(ssrr.Password); err != nil { + return *internal.PasswordResponse(err) } deviceID := "shared_secret_registration" diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index 85846c7d6..b8fd19e90 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -15,12 +15,27 @@ package routing import ( + "bytes" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "reflect" "regexp" + "strings" "testing" "time" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/util" ) var ( @@ -264,3 +279,294 @@ func TestSessionCleanUp(t *testing.T) { } }) } + +func Test_register(t *testing.T) { + testCases := []struct { + name string + kind string + password string + username string + loginType string + forceEmpty bool + registrationDisabled bool + guestsDisabled bool + enableRecaptcha bool + captchaBody string + wantResponse util.JSONResponse + }{ + { + name: "disallow guests", + kind: "guest", + guestsDisabled: true, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(`Guest registration is disabled on "test"`), + }, + }, + { + name: "allow guests", + kind: "guest", + }, + { + name: "unknown login type", + loginType: "im.not.known", + wantResponse: util.JSONResponse{ + Code: http.StatusNotImplemented, + JSON: jsonerror.Unknown("unknown/unimplemented auth type"), + }, + }, + { + name: "disabled registration", + registrationDisabled: true, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden(`Registration is disabled on "test"`), + }, + }, + { + name: "successful registration, numeric ID", + username: "", + password: "someRandomPassword", + forceEmpty: true, + }, + { + name: "successful registration", + username: "success", + }, + { + name: "failing registration - user already exists", + username: "success", + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.UserInUse("Desired user ID is already taken."), + }, + }, + { + name: "successful registration uppercase username", + username: "LOWERCASED", // this is going to be lower-cased + }, + { + name: "invalid username", + username: "#totalyNotValid", + wantResponse: *internal.UsernameResponse(internal.ErrUsernameInvalid), + }, + { + name: "numeric username is forbidden", + username: "1337", + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Numeric user IDs are reserved"), + }, + }, + { + name: "disabled recaptcha login", + loginType: authtypes.LoginTypeRecaptcha, + wantResponse: util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Unknown(ErrCaptchaDisabled.Error()), + }, + }, + { + name: "enabled recaptcha, no response defined", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + wantResponse: util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrMissingResponse.Error()), + }, + }, + { + name: "invalid captcha response", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `notvalid`, + wantResponse: util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.BadJSON(ErrInvalidCaptcha.Error()), + }, + }, + { + name: "valid captcha response", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `success`, + }, + { + name: "captcha invalid from remote", + enableRecaptcha: true, + loginType: authtypes.LoginTypeRecaptcha, + captchaBody: `i should fail for other reasons`, + wantResponse: util.JSONResponse{Code: http.StatusInternalServerError, JSON: jsonerror.InternalServerError()}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if tc.enableRecaptcha { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Fatal(err) + } + response := r.Form.Get("response") + + // Respond with valid JSON or no JSON at all to test happy/error cases + switch response { + case "success": + json.NewEncoder(w).Encode(recaptchaResponse{Success: true}) + case "notvalid": + json.NewEncoder(w).Encode(recaptchaResponse{Success: false}) + default: + + } + })) + defer srv.Close() + base.Cfg.ClientAPI.RecaptchaSiteVerifyAPI = srv.URL + } + + if err := base.Cfg.Derive(); err != nil { + t.Fatalf("failed to derive config: %s", err) + } + + base.Cfg.ClientAPI.RecaptchaEnabled = tc.enableRecaptcha + base.Cfg.ClientAPI.RegistrationDisabled = tc.registrationDisabled + base.Cfg.ClientAPI.GuestsDisabled = tc.guestsDisabled + + if tc.kind == "" { + tc.kind = "user" + } + if tc.password == "" && !tc.forceEmpty { + tc.password = "someRandomPassword" + } + if tc.username == "" && !tc.forceEmpty { + tc.username = "valid" + } + if tc.loginType == "" { + tc.loginType = "m.login.dummy" + } + + reg := registerRequest{ + Password: tc.password, + Username: tc.username, + } + + body := &bytes.Buffer{} + err := json.NewEncoder(body).Encode(reg) + if err != nil { + t.Fatal(err) + } + + req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/?kind=%s", tc.kind), body) + + resp := Register(req, userAPI, &base.Cfg.ClientAPI) + t.Logf("Resp: %+v", resp) + + // The first request should return a userInteractiveResponse + switch r := resp.JSON.(type) { + case userInteractiveResponse: + // Check that the flows are the ones we configured + if !reflect.DeepEqual(r.Flows, base.Cfg.Derived.Registration.Flows) { + t.Fatalf("unexpected registration flows: %+v, want %+v", r.Flows, base.Cfg.Derived.Registration.Flows) + } + case *jsonerror.MatrixError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("(%s), unexpected response: %+v, want: %+v", tc.name, resp, tc.wantResponse) + } + return + case registerResponse: + // this should only be possible on guest user registration, never for normal users + if tc.kind != "guest" { + t.Fatalf("got register response on first request: %+v", r) + } + // assert we've got a UserID, AccessToken and DeviceID + if r.UserID == "" { + t.Fatalf("missing userID in response") + } + if r.AccessToken == "" { + t.Fatalf("missing accessToken in response") + } + if r.DeviceID == "" { + t.Fatalf("missing deviceID in response") + } + return + default: + t.Logf("Got response: %T", resp.JSON) + } + + // If we reached this, we should have received a UIA response + uia, ok := resp.JSON.(userInteractiveResponse) + if !ok { + t.Fatalf("did not receive a userInteractiveResponse: %T", resp.JSON) + } + t.Logf("%+v", uia) + + // Register the user + reg.Auth = authDict{ + Type: authtypes.LoginType(tc.loginType), + Session: uia.Session, + } + + if tc.captchaBody != "" { + reg.Auth.Response = tc.captchaBody + } + + dummy := "dummy" + reg.DeviceID = &dummy + reg.InitialDisplayName = &dummy + reg.Type = authtypes.LoginType(tc.loginType) + + err = json.NewEncoder(body).Encode(reg) + if err != nil { + t.Fatal(err) + } + + req = httptest.NewRequest(http.MethodPost, "/", body) + + resp = Register(req, userAPI, &base.Cfg.ClientAPI) + + switch resp.JSON.(type) { + case *jsonerror.MatrixError: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + case util.JSONResponse: + if !reflect.DeepEqual(tc.wantResponse, resp) { + t.Fatalf("unexpected response: %+v, want: %+v", resp, tc.wantResponse) + } + return + } + + rr, ok := resp.JSON.(registerResponse) + if !ok { + t.Fatalf("expected a registerresponse, got %T", resp.JSON) + } + + // validate the response + if tc.forceEmpty { + // when not supplying a username, one will be generated. Given this _SHOULD_ be + // the second user, set the username accordingly + reg.Username = "2" + } + wantUserID := strings.ToLower(fmt.Sprintf("@%s:%s", reg.Username, "test")) + if wantUserID != rr.UserID { + t.Fatalf("unexpected userID: %s, want %s", rr.UserID, wantUserID) + } + if rr.DeviceID != *reg.DeviceID { + t.Fatalf("unexpected deviceID: %s, want %s", rr.DeviceID, *reg.DeviceID) + } + if rr.AccessToken == "" { + t.Fatalf("missing accessToken in response") + } + }) + } + }) +} diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 69b46214c..09c2cd02f 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -639,9 +639,9 @@ func Setup( ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) v3mux.Handle("/auth/{authType}/fallback/web", - httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { + httputil.MakeHTMLAPI("auth_fallback", base.EnableMetrics, func(w http.ResponseWriter, req *http.Request) { vars := mux.Vars(req) - return AuthFallback(w, req, vars["authType"], cfg) + AuthFallback(w, req, vars["authType"], cfg) }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) diff --git a/cmd/create-account/main.go b/cmd/create-account/main.go index 15b043ed5..772778680 100644 --- a/cmd/create-account/main.go +++ b/cmd/create-account/main.go @@ -25,10 +25,10 @@ import ( "io" "net/http" "os" - "regexp" "strings" "time" + "github.com/matrix-org/dendrite/internal" "github.com/tidwall/gjson" "github.com/sirupsen/logrus" @@ -58,15 +58,14 @@ Arguments: ` var ( - username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") - password = flag.String("password", "", "The password to associate with the account") - pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") - pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") - isAdmin = flag.Bool("admin", false, "Create an admin account") - resetPassword = flag.Bool("reset-password", false, "Deprecated") - serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") - validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) - timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") + username = flag.String("username", "", "The username of the account to register (specify the localpart only, e.g. 'alice' for '@alice:domain.com')") + password = flag.String("password", "", "The password to associate with the account") + pwdFile = flag.String("passwordfile", "", "The file to use for the password (e.g. for automated account creation)") + pwdStdin = flag.Bool("passwordstdin", false, "Reads the password from stdin") + isAdmin = flag.Bool("admin", false, "Create an admin account") + resetPassword = flag.Bool("reset-password", false, "Deprecated") + serverURL = flag.String("url", "http://localhost:8008", "The URL to connect to.") + timeout = flag.Duration("timeout", time.Second*30, "Timeout for the http client when connecting to the server") ) var cl = http.Client{ @@ -95,20 +94,21 @@ func main() { os.Exit(1) } - if !validUsernameRegex.MatchString(*username) { - logrus.Warn("Username can only contain characters a-z, 0-9, or '_-./='") + if err := internal.ValidateUsername(*username, cfg.Global.ServerName); err != nil { + logrus.WithError(err).Error("Specified username is invalid") os.Exit(1) } - if len(fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) > 255 { - logrus.Fatalf("Username can not be longer than 255 characters: %s", fmt.Sprintf("@%s:%s", *username, cfg.Global.ServerName)) - } - pass, err := getPassword(*password, *pwdFile, *pwdStdin, os.Stdin) if err != nil { logrus.Fatalln(err) } + if err = internal.ValidatePassword(pass); err != nil { + logrus.WithError(err).Error("Specified password is invalid") + os.Exit(1) + } + cl.Timeout = *timeout accessToken, err := sharedSecretRegister(cfg.ClientAPI.RegistrationSharedSecret, *serverURL, *username, pass, *isAdmin) diff --git a/internal/httputil/httpapi.go b/internal/httputil/httpapi.go index 383913c60..37d144f4e 100644 --- a/internal/httputil/httpapi.go +++ b/internal/httputil/httpapi.go @@ -198,17 +198,12 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse // MakeHTMLAPI adds Span metrics to the HTML Handler function // This is used to serve HTML alongside JSON error messages -func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request) *util.JSONResponse) http.Handler { +func MakeHTMLAPI(metricsName string, enableMetrics bool, f func(http.ResponseWriter, *http.Request)) http.Handler { withSpan := func(w http.ResponseWriter, req *http.Request) { span := opentracing.StartSpan(metricsName) defer span.Finish() req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span)) - if err := f(w, req); err != nil { - h := util.MakeJSONAPI(util.NewJSONRequestHandler(func(req *http.Request) util.JSONResponse { - return *err - })) - h.ServeHTTP(w, req) - } + f(w, req) } if !enableMetrics { diff --git a/internal/validate.go b/internal/validate.go index fc685ad50..0461b897e 100644 --- a/internal/validate.go +++ b/internal/validate.go @@ -15,30 +15,96 @@ package internal import ( + "errors" "fmt" "net/http" + "regexp" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) -const minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based +const ( + maxUsernameLength = 254 // http://matrix.org/speculator/spec/HEAD/intro.html#user-identifiers TODO account for domain -const maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + minPasswordLength = 8 // http://matrix.org/docs/spec/client_server/r0.2.0.html#password-based + maxPasswordLength = 512 // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 +) -// ValidatePassword returns an error response if the password is invalid -func ValidatePassword(password string) *util.JSONResponse { +var ( + ErrPasswordTooLong = fmt.Errorf("password too long: max %d characters", maxPasswordLength) + ErrPasswordWeak = fmt.Errorf("password too weak: min %d characters", minPasswordLength) + ErrUsernameTooLong = fmt.Errorf("username exceeds the maximum length of %d characters", maxUsernameLength) + ErrUsernameInvalid = errors.New("username can only contain characters a-z, 0-9, or '_-./='") + ErrUsernameUnderscore = errors.New("username cannot start with a '_'") + validUsernameRegex = regexp.MustCompile(`^[0-9a-z_\-=./]+$`) +) + +// ValidatePassword returns an error if the password is invalid +func ValidatePassword(password string) error { // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 if len(password) > maxPasswordLength { - return &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON(fmt.Sprintf("password too long: max %d characters", maxPasswordLength)), - } + return ErrPasswordTooLong } else if len(password) > 0 && len(password) < minPasswordLength { + return ErrPasswordWeak + } + return nil +} + +// PasswordResponse returns a util.JSONResponse for a given error, if any. +func PasswordResponse(err error) *util.JSONResponse { + switch err { + case ErrPasswordWeak: return &util.JSONResponse{ Code: http.StatusBadRequest, - JSON: jsonerror.WeakPassword(fmt.Sprintf("password too weak: min %d chars", minPasswordLength)), + JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error()), + } + case ErrPasswordTooLong: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error()), } } return nil } + +// ValidateUsername returns an error if the username is invalid +func ValidateUsername(localpart string, domain gomatrixserverlib.ServerName) error { + // https://github.com/matrix-org/synapse/blob/v0.20.0/synapse/rest/client/v2_alpha/register.py#L161 + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { + return ErrUsernameTooLong + } else if !validUsernameRegex.MatchString(localpart) { + return ErrUsernameInvalid + } else if localpart[0] == '_' { // Regex checks its not a zero length string + return ErrUsernameUnderscore + } + return nil +} + +// UsernameResponse returns a util.JSONResponse for the given error, if any. +func UsernameResponse(err error) *util.JSONResponse { + switch err { + case ErrUsernameTooLong: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(err.Error()), + } + case ErrUsernameInvalid, ErrUsernameUnderscore: + return &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(err.Error()), + } + } + return nil +} + +// ValidateApplicationServiceUsername returns an error if the username is invalid for an application service +func ValidateApplicationServiceUsername(localpart string, domain gomatrixserverlib.ServerName) error { + if id := fmt.Sprintf("@%s:%s", localpart, domain); len(id) > maxUsernameLength { + return ErrUsernameTooLong + } else if !validUsernameRegex.MatchString(localpart) { + return ErrUsernameInvalid + } + return nil +} diff --git a/internal/validate_test.go b/internal/validate_test.go new file mode 100644 index 000000000..d0ad04707 --- /dev/null +++ b/internal/validate_test.go @@ -0,0 +1,170 @@ +package internal + +import ( + "net/http" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +func Test_validatePassword(t *testing.T) { + tests := []struct { + name string + password string + wantError error + wantJSON *util.JSONResponse + }{ + { + name: "password too short", + password: "shortpw", + wantError: ErrPasswordWeak, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.WeakPassword(ErrPasswordWeak.Error())}, + }, + { + name: "password too long", + password: strings.Repeat("a", maxPasswordLength+1), + wantError: ErrPasswordTooLong, + wantJSON: &util.JSONResponse{Code: http.StatusBadRequest, JSON: jsonerror.BadJSON(ErrPasswordTooLong.Error())}, + }, + { + name: "password OK", + password: util.RandomString(10), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := ValidatePassword(tt.password) + if !reflect.DeepEqual(gotErr, tt.wantError) { + t.Errorf("validatePassword() = %v, wantJSON %v", gotErr, tt.wantError) + } + + if got := PasswordResponse(gotErr); !reflect.DeepEqual(got, tt.wantJSON) { + t.Errorf("validatePassword() = %v, wantJSON %v", got, tt.wantJSON) + } + }) + } +} + +func Test_validateUsername(t *testing.T) { + tooLongUsername := strings.Repeat("a", maxUsernameLength) + tests := []struct { + name string + localpart string + domain gomatrixserverlib.ServerName + wantErr error + wantJSON *util.JSONResponse + }{ + { + name: "empty username", + localpart: "", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "invalid username", + localpart: "INVALIDUSERNAME", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "username too long", + localpart: tooLongUsername, + domain: "localhost", + wantErr: ErrUsernameTooLong, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON(ErrUsernameTooLong.Error()), + }, + }, + { + name: "localpart starting with an underscore", + localpart: "_notvalid", + domain: "localhost", + wantErr: ErrUsernameUnderscore, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameUnderscore.Error()), + }, + }, + { + name: "valid username", + localpart: "valid", + domain: "localhost", + }, + { + name: "complex username", + localpart: "f00_bar-baz.=40/", + domain: "localhost", + }, + { + name: "rejects emoji username 💥", + localpart: "💥", + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "special characters are allowed", + localpart: "/dev/null", + domain: "localhost", + }, + { + name: "special characters are allowed 2", + localpart: "i_am_allowed=1", + domain: "localhost", + }, + { + name: "not all special characters are allowed", + localpart: "notallowed#", // contains # + domain: "localhost", + wantErr: ErrUsernameInvalid, + wantJSON: &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername(ErrUsernameInvalid.Error()), + }, + }, + { + name: "username containing numbers", + localpart: "hello1337", + domain: "localhost", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotErr := ValidateUsername(tt.localpart, tt.domain) + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr) + } + if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) { + t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON) + } + + // Application services are allowed usernames starting with an underscore + if tt.wantErr == ErrUsernameUnderscore { + return + } + gotErr = ValidateApplicationServiceUsername(tt.localpart, tt.domain) + if !reflect.DeepEqual(gotErr, tt.wantErr) { + t.Errorf("ValidateUsername() = %v, wantErr %v", gotErr, tt.wantErr) + } + if gotJSON := UsernameResponse(gotErr); !reflect.DeepEqual(gotJSON, tt.wantJSON) { + t.Errorf("UsernameResponse() = %v, wantJSON %v", gotJSON, tt.wantJSON) + } + }) + } +} diff --git a/setup/config/config.go b/setup/config/config.go index 7e7ed1aa1..6523a2452 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -29,7 +29,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" "golang.org/x/crypto/ed25519" - yaml "gopkg.in/yaml.v2" + "gopkg.in/yaml.v2" jaegerconfig "github.com/uber/jaeger-client-go/config" jaegermetrics "github.com/uber/jaeger-lib/metrics" @@ -314,11 +314,13 @@ func (config *Dendrite) Derive() error { if config.ClientAPI.RecaptchaEnabled { config.Derived.Registration.Params[authtypes.LoginTypeRecaptcha] = map[string]string{"public_key": config.ClientAPI.RecaptchaPublicKey} - config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, - authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}) + config.Derived.Registration.Flows = []authtypes.Flow{ + {Stages: []authtypes.LoginType{authtypes.LoginTypeRecaptcha}}, + } } else { - config.Derived.Registration.Flows = append(config.Derived.Registration.Flows, - authtypes.Flow{Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}) + config.Derived.Registration.Flows = []authtypes.Flow{ + {Stages: []authtypes.LoginType{authtypes.LoginTypeDummy}}, + } } // Load application service configuration files diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 0a871da18..11628b1b0 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -78,9 +78,6 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { c.TURN.Verify(configErrs) c.RateLimiting.Verify(configErrs) if c.RecaptchaEnabled { - checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) - checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) - checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) if c.RecaptchaSiteVerifyAPI == "" { c.RecaptchaSiteVerifyAPI = "https://www.google.com/recaptcha/api/siteverify" } @@ -93,6 +90,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { if c.RecaptchaSitekeyClass == "" { c.RecaptchaSitekeyClass = "g-recaptcha-response" } + checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) + checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) + checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", c.RecaptchaSiteVerifyAPI) + checkNotEmpty(configErrs, "client_api.recaptcha_sitekey_class", c.RecaptchaSitekeyClass) } // Ensure there is any spam counter measure when enabling registration if !c.RegistrationDisabled && !c.OpenRegistrationWithoutVerificationEnabled { From e449d174ccf7569b2536289f3c8145298e80bc90 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 23 Dec 2022 14:28:15 +0100 Subject: [PATCH 37/67] Add possibility to run complement with coverage enabled (#2901) This adds the possibility to run Complement with coverage enabled. In combination with https://github.com/matrix-org/complement/pull/566 we should then be able to extract the coverage logs, combine them with https://github.com/wadey/gocovmerge (or similar) and upload them to Codecov (with different flags, depending on SQLite, HTTP etc.) --- Dockerfile | 27 --------------------- build/scripts/Complement.Dockerfile | 7 ++++-- build/scripts/ComplementLocal.Dockerfile | 5 +++- build/scripts/ComplementPostgres.Dockerfile | 7 ++++-- build/scripts/complement-cmd.sh | 22 +++++++++++++++++ 5 files changed, 36 insertions(+), 32 deletions(-) create mode 100755 build/scripts/complement-cmd.sh diff --git a/Dockerfile b/Dockerfile index a9bbce925..ede33e635 100644 --- a/Dockerfile +++ b/Dockerfile @@ -63,30 +63,3 @@ WORKDIR /etc/dendrite ENTRYPOINT ["/usr/bin/dendrite-monolith-server"] EXPOSE 8008 8448 -# -# Builds the Complement image, used for integration tests -# -FROM base AS complement -LABEL org.opencontainers.image.title="Dendrite (Complement)" -RUN apk add --no-cache sqlite openssl ca-certificates - -COPY --from=build /out/generate-config /usr/bin/generate-config -COPY --from=build /out/generate-keys /usr/bin/generate-keys -COPY --from=build /out/dendrite-monolith-server /usr/bin/dendrite-monolith-server - -WORKDIR /dendrite -RUN /usr/bin/generate-keys --private-key matrix_key.pem && \ - mkdir /ca && \ - openssl genrsa -out /ca/ca.key 2048 && \ - openssl req -new -x509 -key /ca/ca.key -days 3650 -subj "/C=GB/ST=London/O=matrix.org/CN=Complement CA" -out /ca/ca.crt - -ENV SERVER_NAME=localhost -ENV API=0 -EXPOSE 8008 8448 - -# At runtime, generate TLS cert based on the CA now mounted at /ca -# At runtime, replace the SERVER_NAME with what we are told -CMD /usr/bin/generate-keys --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /ca/ca.crt --tls-authority-key /ca/ca.key && \ - /usr/bin/generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ - cp /ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - /usr/bin/dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} diff --git a/build/scripts/Complement.Dockerfile b/build/scripts/Complement.Dockerfile index 79422e645..3a00fbdf0 100644 --- a/build/scripts/Complement.Dockerfile +++ b/build/scripts/Complement.Dockerfile @@ -16,13 +16,16 @@ RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ - CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \ + CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \ + cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost ENV API=0 +ENV COVER=0 EXPOSE 8008 8448 # At runtime, generate TLS cert based on the CA now mounted at /ca @@ -30,4 +33,4 @@ EXPOSE 8008 8448 CMD ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key && \ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} + exec /complement-cmd.sh diff --git a/build/scripts/ComplementLocal.Dockerfile b/build/scripts/ComplementLocal.Dockerfile index 3a019fc20..e3fbe1aa8 100644 --- a/build/scripts/ComplementLocal.Dockerfile +++ b/build/scripts/ComplementLocal.Dockerfile @@ -12,18 +12,20 @@ FROM golang:1.18-stretch RUN apt-get update && apt-get install -y sqlite3 ENV SERVER_NAME=localhost +ENV COVER=0 EXPOSE 8008 8448 WORKDIR /runtime # This script compiles Dendrite for us. RUN echo '\ #!/bin/bash -eux \n\ - if test -f "/runtime/dendrite-monolith-server"; then \n\ + if test -f "/runtime/dendrite-monolith-server" && test -f "/runtime/dendrite-monolith-server-cover"; then \n\ echo "Skipping compilation; binaries exist" \n\ exit 0 \n\ fi \n\ cd /dendrite \n\ go build -v -o /runtime /dendrite/cmd/dendrite-monolith-server \n\ + go test -c -cover -covermode=atomic -o /runtime/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." /dendrite/cmd/dendrite-monolith-server \n\ ' > compile.sh && chmod +x compile.sh # This script runs Dendrite for us. Must be run in the /runtime directory. @@ -33,6 +35,7 @@ RUN echo '\ ./generate-keys -keysize 1024 --server $SERVER_NAME --tls-cert server.crt --tls-key server.key --tls-authority-cert /complement/ca/ca.crt --tls-authority-key /complement/ca/ca.key \n\ ./generate-config -server $SERVER_NAME --ci > dendrite.yaml \n\ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates \n\ + [ ${COVER} -eq 1 ] && exec ./dendrite-monolith-server-cover --test.coverprofile=integrationcover.log --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml \n\ ' > run.sh && chmod +x run.sh diff --git a/build/scripts/ComplementPostgres.Dockerfile b/build/scripts/ComplementPostgres.Dockerfile index 3faf43cc7..444cb947d 100644 --- a/build/scripts/ComplementPostgres.Dockerfile +++ b/build/scripts/ComplementPostgres.Dockerfile @@ -34,13 +34,16 @@ RUN --mount=target=. \ --mount=type=cache,target=/root/.cache/go-build \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-config && \ CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/generate-keys && \ - CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server + CGO_ENABLED=${CGO} go build -o /dendrite ./cmd/dendrite-monolith-server && \ + CGO_ENABLED=${CGO} go test -c -cover -covermode=atomic -o /dendrite/dendrite-monolith-server-cover -coverpkg "github.com/matrix-org/..." ./cmd/dendrite-monolith-server && \ + cp build/scripts/complement-cmd.sh /complement-cmd.sh WORKDIR /dendrite RUN ./generate-keys --private-key matrix_key.pem ENV SERVER_NAME=localhost ENV API=0 +ENV COVER=0 EXPOSE 8008 8448 @@ -51,4 +54,4 @@ CMD /build/run_postgres.sh && ./generate-keys --keysize 1024 --server $SERVER_NA # Bump max_open_conns up here in the global database config sed -i 's/max_open_conns:.*$/max_open_conns: 1990/g' dendrite.yaml && \ cp /complement/ca/ca.crt /usr/local/share/ca-certificates/ && update-ca-certificates && \ - exec ./dendrite-monolith-server --really-enable-open-registration --tls-cert server.crt --tls-key server.key --config dendrite.yaml -api=${API:-0} \ No newline at end of file + exec /complement-cmd.sh \ No newline at end of file diff --git a/build/scripts/complement-cmd.sh b/build/scripts/complement-cmd.sh new file mode 100755 index 000000000..061bd18eb --- /dev/null +++ b/build/scripts/complement-cmd.sh @@ -0,0 +1,22 @@ +#!/bin/bash -e + +# This script is intended to be used inside a docker container for Complement + +if [[ "${COVER}" -eq 1 ]]; then + echo "Running with coverage" + exec /dendrite/dendrite-monolith-server-cover \ + --really-enable-open-registration \ + --tls-cert server.crt \ + --tls-key server.key \ + --config dendrite.yaml \ + -api=${API:-0} \ + --test.coverprofile=integrationcover.log +else + echo "Not running with coverage" + exec /dendrite/dendrite-monolith-server \ + --really-enable-open-registration \ + --tls-cert server.crt \ + --tls-key server.key \ + --config dendrite.yaml \ + -api=${API:-0} +fi From 2e1fe589375b650f9b2d9a09e1fcffb3ab6fe5b6 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 5 Jan 2023 09:24:00 +0100 Subject: [PATCH 38/67] Fix backfilling (#2926) This should fix https://github.com/matrix-org/dendrite/issues/2923 --- go.mod | 2 +- go.sum | 4 ++-- roomserver/internal/perform/perform_backfill.go | 9 +++++++-- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/go.mod b/go.mod index d3eb4890a..2d7174150 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index ad9372c84..b12f65eab 100644 --- a/go.sum +++ b/go.sum @@ -348,8 +348,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8 h1:jVvlCGs6OosCdvw9MkfiVnTVnIt7vKMHg/F6th9BtSo= -github.com/matrix-org/gomatrixserverlib v0.0.0-20221129095800-8835f6db16b8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index 069f017a9..d9214fdc6 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -122,11 +122,14 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform ctx, req.VirtualHost, requester, r.KeyRing, req.RoomID, info.RoomVersion, req.PrevEventIDs(), 100, ) - if err != nil { + // Only return an error if we really couldn't get any events. + if err != nil && len(events) == 0 { logrus.WithError(err).Errorf("gomatrixserverlib.RequestBackfill failed") return err } - logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) + // If we got an error but still got events, that's fine, because a server might have returned a 404 (or something) + // but other servers could provide the missing event. + logrus.WithError(err).WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events)) // persist these new events - auth checks have already been done roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events) @@ -319,6 +322,7 @@ FederationHit: FedClient: b.fsAPI, RememberAuthEvents: false, Server: srv, + Origin: b.virtualHost, } res, err := c.StateIDsBeforeEvent(ctx, targetEvent) if err != nil { @@ -394,6 +398,7 @@ func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatr FedClient: b.fsAPI, RememberAuthEvents: false, Server: srv, + Origin: b.virtualHost, } result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs) if err != nil { From d579ddb8e7c1a7e118797bcef08113379535e6fb Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:44:10 +0100 Subject: [PATCH 39/67] Add simplified helm chart (#2905) As discussed yesterday, a simplified version of [my helm](https://github.com/S7evinK/dendrite-helm) which deploys a monolith with internal NATS and an optionally enabled PostgreSQL server. If the PostgreSQL dependency is not enabled, a user specified connection string is constructed. Co-authored-by: kegsay --- .github/workflows/gh-pages.yml | 52 +++ .github/workflows/helm.yml | 39 ++ .github/workflows/k8s.yml | 90 +++++ helm/cr.yaml | 1 + helm/ct.yaml | 7 + helm/dendrite/.helm-docs/about.gotmpl | 5 + helm/dendrite/.helm-docs/appservices.gotmpl | 5 + helm/dendrite/.helm-docs/database.gotmpl | 18 + helm/dendrite/.helm-docs/state.gotmpl | 3 + helm/dendrite/Chart.yaml | 19 + helm/dendrite/README.md | 147 ++++++++ helm/dendrite/README.md.gotmpl | 13 + helm/dendrite/ci/ct-ingress-values.yaml | 13 + .../ci/ct-postgres-sharedsecret-values.yaml | 16 + helm/dendrite/templates/_helpers.tpl | 72 ++++ helm/dendrite/templates/_overrides.yaml | 16 + helm/dendrite/templates/deployment.yaml | 103 ++++++ helm/dendrite/templates/ingress.yaml | 55 +++ helm/dendrite/templates/jobs.yaml | 99 +++++ helm/dendrite/templates/pvc.yaml | 48 +++ helm/dendrite/templates/secrets.yaml | 33 ++ helm/dendrite/templates/service.yaml | 17 + .../templates/tests/test-version.yaml | 17 + helm/dendrite/values.yaml | 348 ++++++++++++++++++ setup/config/config.go | 4 +- 25 files changed, 1238 insertions(+), 2 deletions(-) create mode 100644 .github/workflows/gh-pages.yml create mode 100644 .github/workflows/helm.yml create mode 100644 .github/workflows/k8s.yml create mode 100644 helm/cr.yaml create mode 100644 helm/ct.yaml create mode 100644 helm/dendrite/.helm-docs/about.gotmpl create mode 100644 helm/dendrite/.helm-docs/appservices.gotmpl create mode 100644 helm/dendrite/.helm-docs/database.gotmpl create mode 100644 helm/dendrite/.helm-docs/state.gotmpl create mode 100644 helm/dendrite/Chart.yaml create mode 100644 helm/dendrite/README.md create mode 100644 helm/dendrite/README.md.gotmpl create mode 100644 helm/dendrite/ci/ct-ingress-values.yaml create mode 100644 helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml create mode 100644 helm/dendrite/templates/_helpers.tpl create mode 100644 helm/dendrite/templates/_overrides.yaml create mode 100644 helm/dendrite/templates/deployment.yaml create mode 100644 helm/dendrite/templates/ingress.yaml create mode 100644 helm/dendrite/templates/jobs.yaml create mode 100644 helm/dendrite/templates/pvc.yaml create mode 100644 helm/dendrite/templates/secrets.yaml create mode 100644 helm/dendrite/templates/service.yaml create mode 100644 helm/dendrite/templates/tests/test-version.yaml create mode 100644 helm/dendrite/values.yaml diff --git a/.github/workflows/gh-pages.yml b/.github/workflows/gh-pages.yml new file mode 100644 index 000000000..b5a8f0bbd --- /dev/null +++ b/.github/workflows/gh-pages.yml @@ -0,0 +1,52 @@ +# Sample workflow for building and deploying a Jekyll site to GitHub Pages +name: Deploy GitHub Pages dependencies preinstalled + +on: + # Runs on pushes targeting the default branch + push: + branches: ["main"] + paths: + - 'docs/**' # only execute if we have docs changes + + # Allows you to run this workflow manually from the Actions tab + workflow_dispatch: + +# Sets permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages +permissions: + contents: read + pages: write + id-token: write + +# Allow one concurrent deployment +concurrency: + group: "pages" + cancel-in-progress: true + +jobs: + # Build job + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + - name: Setup Pages + uses: actions/configure-pages@v2 + - name: Build with Jekyll + uses: actions/jekyll-build-pages@v1 + with: + source: ./docs + destination: ./_site + - name: Upload artifact + uses: actions/upload-pages-artifact@v1 + + # Deployment job + deploy: + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v1 diff --git a/.github/workflows/helm.yml b/.github/workflows/helm.yml new file mode 100644 index 000000000..7cdc369ba --- /dev/null +++ b/.github/workflows/helm.yml @@ -0,0 +1,39 @@ +name: Release Charts + +on: + push: + branches: + - main + paths: + - 'helm/**' # only execute if we have helm chart changes + +jobs: + release: + # depending on default permission settings for your org (contents being read-only or read-write for workloads), you will have to add permissions + # see: https://docs.github.com/en/actions/security-guides/automatic-token-authentication#modifying-the-permissions-for-the-github_token + permissions: + contents: write + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Configure Git + run: | + git config user.name "$GITHUB_ACTOR" + git config user.email "$GITHUB_ACTOR@users.noreply.github.com" + + - name: Install Helm + uses: azure/setup-helm@v3 + with: + version: v3.10.0 + + - name: Run chart-releaser + uses: helm/chart-releaser-action@v1.4.1 + env: + CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}" + with: + config: helm/cr.yaml + charts_dir: helm/ diff --git a/.github/workflows/k8s.yml b/.github/workflows/k8s.yml new file mode 100644 index 000000000..fc5e8c906 --- /dev/null +++ b/.github/workflows/k8s.yml @@ -0,0 +1,90 @@ +name: k8s + +on: + push: + branches: ["main"] + paths: + - 'helm/**' # only execute if we have helm chart changes + pull_request: + branches: ["main"] + paths: + - 'helm/**' + +jobs: + lint: + name: Lint Helm chart + runs-on: ubuntu-latest + outputs: + changed: ${{ steps.list-changed.outputs.changed }} + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + - uses: azure/setup-helm@v3 + with: + version: v3.10.0 + - uses: actions/setup-python@v4 + with: + python-version: 3.11 + check-latest: true + - uses: helm/chart-testing-action@v2.3.1 + - name: Get changed status + id: list-changed + run: | + changed=$(ct list-changed --config helm/ct.yaml --target-branch ${{ github.event.repository.default_branch }}) + if [[ -n "$changed" ]]; then + echo "::set-output name=changed::true" + fi + + - name: Run lint + run: ct lint --config helm/ct.yaml + + # only bother to run if lint step reports a change to the helm chart + install: + needs: + - lint + if: ${{ needs.lint.outputs.changed == 'true' }} + name: Install Helm charts + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v3 + with: + fetch-depth: 0 + ref: ${{ inputs.checkoutCommit }} + - name: Install Kubernetes tools + uses: yokawasa/action-setup-kube-tools@v0.8.2 + with: + setup-tools: | + helmv3 + helm: "3.10.3" + - uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Set up chart-testing + uses: helm/chart-testing-action@v2.3.1 + - name: Create k3d cluster + uses: nolar/setup-k3d-k3s@v1 + with: + version: v1.21 + - name: Remove node taints + run: | + kubectl taint --all=true nodes node.cloudprovider.kubernetes.io/uninitialized- || true + - name: Run chart-testing (install) + run: ct install --config helm/ct.yaml + + # Install the chart using helm directly and test with create-account + - name: Install chart + run: | + helm install --values helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml dendrite helm/dendrite + - name: Wait for Postgres and Dendrite to be up + run: | + kubectl wait --for=condition=ready --timeout=90s pod -l app.kubernetes.io/name=postgresql || kubectl get pods -A + kubectl wait --for=condition=ready --timeout=90s pod -l app.kubernetes.io/name=dendrite || kubectl get pods -A + kubectl get pods -A + kubectl get services + kubectl get ingress + - name: Run create account + run: | + podName=$(kubectl get pods -l app.kubernetes.io/name=dendrite -o name) + kubectl exec "${podName}" -- /usr/bin/create-account -username alice -password somerandompassword \ No newline at end of file diff --git a/helm/cr.yaml b/helm/cr.yaml new file mode 100644 index 000000000..f895ab8d6 --- /dev/null +++ b/helm/cr.yaml @@ -0,0 +1 @@ +release-name-template: "helm-{{ .Name }}-{{ .Version }}" \ No newline at end of file diff --git a/helm/ct.yaml b/helm/ct.yaml new file mode 100644 index 000000000..af706fa3d --- /dev/null +++ b/helm/ct.yaml @@ -0,0 +1,7 @@ +remote: origin +target-branch: main +chart-repos: + - bitnami=https://charts.bitnami.com/bitnami +chart-dirs: + - helm +validate-maintainers: false \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/about.gotmpl b/helm/dendrite/.helm-docs/about.gotmpl new file mode 100644 index 000000000..a92c6be42 --- /dev/null +++ b/helm/dendrite/.helm-docs/about.gotmpl @@ -0,0 +1,5 @@ +{{ define "chart.about" }} +## About + +This chart creates a monolith deployment, including an optionally enabled PostgreSQL dependency to connect to. +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/appservices.gotmpl b/helm/dendrite/.helm-docs/appservices.gotmpl new file mode 100644 index 000000000..8a79a0780 --- /dev/null +++ b/helm/dendrite/.helm-docs/appservices.gotmpl @@ -0,0 +1,5 @@ +{{ define "chart.appservices" }} +## Usage with appservices + +Create a folder `appservices` and place your configurations in there. The configurations will be read and placed in a secret `dendrite-appservices-conf`. +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/database.gotmpl b/helm/dendrite/.helm-docs/database.gotmpl new file mode 100644 index 000000000..85ef01ecc --- /dev/null +++ b/helm/dendrite/.helm-docs/database.gotmpl @@ -0,0 +1,18 @@ +{{ define "chart.dbCreation" }} +## Manual database creation + +(You can skip this, if you're deploying the PostgreSQL dependency) + +You'll need to create the following database before starting Dendrite (see [installation](https://matrix-org.github.io/dendrite/installation/database#single-database-creation)): + +```postgres +create database dendrite +``` + +or + +```bash +sudo -u postgres createdb -O dendrite -E UTF-8 dendrite +``` + +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/.helm-docs/state.gotmpl b/helm/dendrite/.helm-docs/state.gotmpl new file mode 100644 index 000000000..2fe987ddd --- /dev/null +++ b/helm/dendrite/.helm-docs/state.gotmpl @@ -0,0 +1,3 @@ +{{ define "chart.state" }} +Status: **NOT PRODUCTION READY** +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml new file mode 100644 index 000000000..15d1e6d19 --- /dev/null +++ b/helm/dendrite/Chart.yaml @@ -0,0 +1,19 @@ +apiVersion: v2 +name: dendrite +version: "0.10.8" +appVersion: "0.10.8" +description: Dendrite Matrix Homeserver +type: application +keywords: + - matrix + - chat + - homeserver + - dendrite +home: https://github.com/matrix-org/dendrite +sources: + - https://github.com/matrix-org/dendrite +dependencies: +- name: postgresql + version: 12.1.7 + repository: https://charts.bitnami.com/bitnami + condition: postgresql.enabled diff --git a/helm/dendrite/README.md b/helm/dendrite/README.md new file mode 100644 index 000000000..cb850d655 --- /dev/null +++ b/helm/dendrite/README.md @@ -0,0 +1,147 @@ +# dendrite + +![Version: 0.10.8](https://img.shields.io/badge/Version-0.10.8-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.10.8](https://img.shields.io/badge/AppVersion-0.10.8-informational?style=flat-square) +Dendrite Matrix Homeserver + +Status: **NOT PRODUCTION READY** + +## About + +This chart creates a monolith deployment, including an optionally enabled PostgreSQL dependency to connect to. + +## Manual database creation + +(You can skip this, if you're deploying the PostgreSQL dependency) + +You'll need to create the following database before starting Dendrite (see [installation](https://matrix-org.github.io/dendrite/installation/database#single-database-creation)): + +```postgres +create database dendrite +``` + +or + +```bash +sudo -u postgres createdb -O dendrite -E UTF-8 dendrite +``` + +## Usage with appservices + +Create a folder `appservices` and place your configurations in there. The configurations will be read and placed in a secret `dendrite-appservices-conf`. + +## Source Code + +* +## Requirements + +| Repository | Name | Version | +|------------|------|---------| +| https://charts.bitnami.com/bitnami | postgresql | 12.1.7 | +## Values + +| Key | Type | Default | Description | +|-----|------|---------|-------------| +| image.name | string | `"ghcr.io/matrix-org/dendrite-monolith:v0.10.8"` | Docker repository/image to use | +| image.pullPolicy | string | `"IfNotPresent"` | Kubernetes pullPolicy | +| signing_key.create | bool | `true` | Create a new signing key, if not exists | +| signing_key.existingSecret | string | `""` | Use an existing secret | +| resources | object | sets some sane default values | Default resource requests/limits. | +| persistence.storageClass | string | `""` | The storage class to use for volume claims. Defaults to the cluster default storage class. | +| persistence.jetstream.existingClaim | string | `""` | Use an existing volume claim for jetstream | +| persistence.jetstream.capacity | string | `"1Gi"` | PVC Storage Request for the jetstream volume | +| persistence.media.existingClaim | string | `""` | Use an existing volume claim for media files | +| persistence.media.capacity | string | `"1Gi"` | PVC Storage Request for the media volume | +| persistence.search.existingClaim | string | `""` | Use an existing volume claim for the fulltext search index | +| persistence.search.capacity | string | `"1Gi"` | PVC Storage Request for the search volume | +| dendrite_config.version | int | `2` | | +| dendrite_config.global.server_name | string | `""` | **REQUIRED** Servername for this Dendrite deployment. | +| dendrite_config.global.private_key | string | `"/etc/dendrite/secrets/signing.key"` | The private key to use. (**NOTE**: This is overriden in Helm) | +| dendrite_config.global.well_known_server_name | string | `""` | The server name to delegate server-server communications to, with optional port e.g. localhost:443 | +| dendrite_config.global.well_known_client_name | string | `""` | The server name to delegate client-server communications to, with optional port e.g. localhost:443 | +| dendrite_config.global.trusted_third_party_id_servers | list | `["matrix.org","vector.im"]` | Lists of domains that the server will trust as identity servers to verify third party identifiers such as phone numbers and email addresses. | +| dendrite_config.global.old_private_keys | string | `nil` | The paths and expiry timestamps (as a UNIX timestamp in millisecond precision) to old signing keys that were formerly in use on this domain name. These keys will not be used for federation request or event signing, but will be provided to any other homeserver that asks when trying to verify old events. | +| dendrite_config.global.disable_federation | bool | `false` | Disable federation. Dendrite will not be able to make any outbound HTTP requests to other servers and the federation API will not be exposed. | +| dendrite_config.global.key_validity_period | string | `"168h0m0s"` | | +| dendrite_config.global.database.connection_string | string | `""` | The connection string for connections to Postgres. This will be set automatically if using the Postgres dependency | +| dendrite_config.global.database.max_open_conns | int | `90` | Default database maximum open connections | +| dendrite_config.global.database.max_idle_conns | int | `5` | Default database maximum idle connections | +| dendrite_config.global.database.conn_max_lifetime | int | `-1` | Default database maximum lifetime | +| dendrite_config.global.jetstream.storage_path | string | `"/data/jetstream"` | Persistent directory to store JetStream streams in. | +| dendrite_config.global.jetstream.addresses | list | `[]` | NATS JetStream server addresses if not using internal NATS. | +| dendrite_config.global.jetstream.topic_prefix | string | `"Dendrite"` | The prefix for JetStream streams | +| dendrite_config.global.jetstream.in_memory | bool | `false` | Keep all data in memory. (**NOTE**: This is overriden in Helm to `false`) | +| dendrite_config.global.jetstream.disable_tls_validation | bool | `true` | Disables TLS validation. This should **NOT** be used in production. | +| dendrite_config.global.cache.max_size_estimated | string | `"1gb"` | The estimated maximum size for the global cache in bytes, or in terabytes, gigabytes, megabytes or kilobytes when the appropriate 'tb', 'gb', 'mb' or 'kb' suffix is specified. Note that this is not a hard limit, nor is it a memory limit for the entire process. A cache that is too small may ultimately provide little or no benefit. | +| dendrite_config.global.cache.max_age | string | `"1h"` | The maximum amount of time that a cache entry can live for in memory before it will be evicted and/or refreshed from the database. Lower values result in easier admission of new cache entries but may also increase database load in comparison to higher values, so adjust conservatively. Higher values may make it harder for new items to make it into the cache, e.g. if new rooms suddenly become popular. | +| dendrite_config.global.report_stats.enabled | bool | `false` | Configures phone-home statistics reporting. These statistics contain the server name, number of active users and some information on your deployment config. We use this information to understand how Dendrite is being used in the wild. | +| dendrite_config.global.report_stats.endpoint | string | `"https://matrix.org/report-usage-stats/push"` | Endpoint to report statistics to. | +| dendrite_config.global.presence.enable_inbound | bool | `false` | Controls whether we receive presence events from other servers | +| dendrite_config.global.presence.enable_outbound | bool | `false` | Controls whether we send presence events for our local users to other servers. (_May increase CPU/memory usage_) | +| dendrite_config.global.server_notices.enabled | bool | `false` | Server notices allows server admins to send messages to all users on the server. | +| dendrite_config.global.server_notices.local_part | string | `"_server"` | The local part for the user sending server notices. | +| dendrite_config.global.server_notices.display_name | string | `"Server Alerts"` | The display name for the user sending server notices. | +| dendrite_config.global.server_notices.avatar_url | string | `""` | The avatar URL (as a mxc:// URL) name for the user sending server notices. | +| dendrite_config.global.server_notices.room_name | string | `"Server Alerts"` | | +| dendrite_config.global.metrics.enabled | bool | `false` | Whether or not Prometheus metrics are enabled. | +| dendrite_config.global.metrics.basic_auth.user | string | `"metrics"` | HTTP basic authentication username | +| dendrite_config.global.metrics.basic_auth.password | string | `"metrics"` | HTTP basic authentication password | +| dendrite_config.global.dns_cache.enabled | bool | `false` | Whether or not the DNS cache is enabled. | +| dendrite_config.global.dns_cache.cache_size | int | `256` | Maximum number of entries to hold in the DNS cache | +| dendrite_config.global.dns_cache.cache_lifetime | string | `"10m"` | Duration for how long DNS cache items should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) | +| dendrite_config.global.profiling.enabled | bool | `false` | Enable pprof. You will need to manually create a port forwarding to the deployment to access PPROF, as it will only listen on localhost and the defined port. e.g. `kubectl port-forward deployments/dendrite 65432:65432` | +| dendrite_config.global.profiling.port | int | `65432` | pprof port, if enabled | +| dendrite_config.mscs | object | `{"mscs":["msc2946"]}` | Configuration for experimental MSC's. (Valid values are: msc2836 and msc2946) | +| dendrite_config.app_service_api.disable_tls_validation | bool | `false` | Disable the validation of TLS certificates of appservices. This is not recommended in production since it may allow appservice traffic to be sent to an insecure endpoint. | +| dendrite_config.app_service_api.config_files | list | `[]` | Appservice config files to load on startup. (**NOTE**: This is overriden by Helm, if a folder `./appservices/` exists) | +| dendrite_config.client_api.registration_disabled | bool | `true` | Prevents new users from being able to register on this homeserver, except when using the registration shared secret below. | +| dendrite_config.client_api.guests_disabled | bool | `true` | | +| dendrite_config.client_api.registration_shared_secret | string | `""` | If set, allows registration by anyone who knows the shared secret, regardless of whether registration is otherwise disabled. | +| dendrite_config.client_api.enable_registration_captcha | bool | `false` | enable reCAPTCHA registration | +| dendrite_config.client_api.recaptcha_public_key | string | `""` | reCAPTCHA public key | +| dendrite_config.client_api.recaptcha_private_key | string | `""` | reCAPTCHA private key | +| dendrite_config.client_api.recaptcha_bypass_secret | string | `""` | reCAPTCHA bypass secret | +| dendrite_config.client_api.recaptcha_siteverify_api | string | `""` | | +| dendrite_config.client_api.turn.turn_user_lifetime | string | `"24h"` | Duration for how long users should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) | +| dendrite_config.client_api.turn.turn_uris | list | `[]` | | +| dendrite_config.client_api.turn.turn_shared_secret | string | `""` | | +| dendrite_config.client_api.turn.turn_username | string | `""` | The TURN username | +| dendrite_config.client_api.turn.turn_password | string | `""` | The TURN password | +| dendrite_config.client_api.rate_limiting.enabled | bool | `true` | Enable rate limiting | +| dendrite_config.client_api.rate_limiting.threshold | int | `20` | After how many requests a rate limit should be activated | +| dendrite_config.client_api.rate_limiting.cooloff_ms | int | `500` | Cooloff time in milliseconds | +| dendrite_config.client_api.rate_limiting.exempt_user_ids | string | `nil` | Users which should be exempt from rate limiting | +| dendrite_config.federation_api.send_max_retries | int | `16` | Federation failure threshold. How many consecutive failures that we should tolerate when sending federation requests to a specific server. The backoff is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds, etc. The default value is 16 if not specified, which is circa 18 hours. | +| dendrite_config.federation_api.disable_tls_validation | bool | `false` | Disable TLS validation. This should **NOT** be used in production. | +| dendrite_config.federation_api.prefer_direct_fetch | bool | `false` | | +| dendrite_config.federation_api.disable_http_keepalives | bool | `false` | Prevents Dendrite from keeping HTTP connections open for reuse for future requests. Connections will be closed quicker but we may spend more time on TLS handshakes instead. | +| dendrite_config.federation_api.key_perspectives | list | See value.yaml | Perspective keyservers, to use as a backup when direct key fetch requests don't succeed. | +| dendrite_config.media_api.base_path | string | `"/data/media_store"` | The path to store media files (e.g. avatars) in | +| dendrite_config.media_api.max_file_size_bytes | int | `10485760` | The max file size for uploaded media files | +| dendrite_config.media_api.dynamic_thumbnails | bool | `false` | | +| dendrite_config.media_api.max_thumbnail_generators | int | `10` | The maximum number of simultaneous thumbnail generators to run. | +| dendrite_config.media_api.thumbnail_sizes | list | See value.yaml | A list of thumbnail sizes to be generated for media content. | +| dendrite_config.sync_api.real_ip_header | string | `"X-Real-IP"` | This option controls which HTTP header to inspect to find the real remote IP address of the client. This is likely required if Dendrite is running behind a reverse proxy server. | +| dendrite_config.sync_api.search | object | `{"enabled":true,"index_path":"/data/search","language":"en"}` | Configuration for the full-text search engine. | +| dendrite_config.sync_api.search.enabled | bool | `true` | Whether fulltext search is enabled. | +| dendrite_config.sync_api.search.index_path | string | `"/data/search"` | The path to store the search index in. | +| dendrite_config.sync_api.search.language | string | `"en"` | The language most likely to be used on the server - used when indexing, to ensure the returned results match expectations. A full list of possible languages can be found [here](https://github.com/matrix-org/dendrite/blob/76db8e90defdfb9e61f6caea8a312c5d60bcc005/internal/fulltext/bleve.go#L25-L46) | +| dendrite_config.user_api.bcrypt_cost | int | `10` | bcrypt cost to use when hashing passwords. (ranges from 4-31; 4 being least secure, 31 being most secure; _NOTE: Using a too high value can cause clients to timeout and uses more CPU._) | +| dendrite_config.user_api.openid_token_lifetime_ms | int | `3600000` | OpenID Token lifetime in milliseconds. | +| dendrite_config.user_api.push_gateway_disable_tls_validation | bool | `false` | | +| dendrite_config.user_api.auto_join_rooms | list | `[]` | Rooms to join users to after registration | +| dendrite_config.logging | list | `[{"level":"info","type":"std"}]` | Default logging configuration | +| postgresql.enabled | bool | See value.yaml | Enable and configure postgres as the database for dendrite. | +| postgresql.image.repository | string | `"bitnami/postgresql"` | | +| postgresql.image.tag | string | `"15.1.0"` | | +| postgresql.auth.username | string | `"dendrite"` | | +| postgresql.auth.password | string | `"changeme"` | | +| postgresql.auth.database | string | `"dendrite"` | | +| postgresql.persistence.enabled | bool | `false` | | +| ingress.enabled | bool | `false` | Create an ingress for a monolith deployment | +| ingress.hosts | list | `[]` | | +| ingress.className | string | `""` | | +| ingress.hostName | string | `""` | | +| ingress.annotations | object | `{}` | Extra, custom annotations | +| ingress.tls | list | `[]` | | +| service.type | string | `"ClusterIP"` | | +| service.port | int | `80` | | diff --git a/helm/dendrite/README.md.gotmpl b/helm/dendrite/README.md.gotmpl new file mode 100644 index 000000000..7c32f7b02 --- /dev/null +++ b/helm/dendrite/README.md.gotmpl @@ -0,0 +1,13 @@ +{{ template "chart.header" . }} +{{ template "chart.deprecationWarning" . }} +{{ template "chart.badgesSection" . }} +{{ template "chart.description" . }} +{{ template "chart.state" . }} +{{ template "chart.about" . }} +{{ template "chart.dbCreation" . }} +{{ template "chart.appservices" . }} +{{ template "chart.maintainersSection" . }} +{{ template "chart.sourcesSection" . }} +{{ template "chart.requirementsSection" . }} +{{ template "chart.valuesSection" . }} +{{ template "helm-docs.versionFooter" . }} \ No newline at end of file diff --git a/helm/dendrite/ci/ct-ingress-values.yaml b/helm/dendrite/ci/ct-ingress-values.yaml new file mode 100644 index 000000000..28311d33e --- /dev/null +++ b/helm/dendrite/ci/ct-ingress-values.yaml @@ -0,0 +1,13 @@ +--- +postgresql: + enabled: true + primary: + persistence: + size: 1Gi + +dendrite_config: + global: + server_name: "localhost" + +ingress: + enabled: true diff --git a/helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml b/helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml new file mode 100644 index 000000000..55e652c63 --- /dev/null +++ b/helm/dendrite/ci/ct-postgres-sharedsecret-values.yaml @@ -0,0 +1,16 @@ +--- +postgresql: + enabled: true + primary: + persistence: + size: 1Gi + +dendrite_config: + global: + server_name: "localhost" + + client_api: + registration_shared_secret: "d233f2fcb0470845a8e150a20ef594ddbe0b4cf7fe482fb9d5120c198557acbf" # echo "dendrite" | sha256sum + +ingress: + enabled: true diff --git a/helm/dendrite/templates/_helpers.tpl b/helm/dendrite/templates/_helpers.tpl new file mode 100644 index 000000000..291f351bc --- /dev/null +++ b/helm/dendrite/templates/_helpers.tpl @@ -0,0 +1,72 @@ +{{- define "validate.config" }} +{{- if not .Values.signing_key.create -}} +{{- fail "You must create a signing key for configuration.signing_key. (see https://github.com/matrix-org/dendrite/blob/master/docs/INSTALL.md#server-key-generation)" -}} +{{- end -}} +{{- if not (or .Values.dendrite_config.global.database.host .Values.postgresql.enabled) -}} +{{- fail "Database server must be set." -}} +{{- end -}} +{{- if not (or .Values.dendrite_config.global.database.user .Values.postgresql.enabled) -}} +{{- fail "Database user must be set." -}} +{{- end -}} +{{- if not (or .Values.dendrite_config.global.database.password .Values.postgresql.enabled) -}} +{{- fail "Database password must be set." -}} +{{- end -}} +{{- end -}} + + +{{- define "image.name" -}} +image: {{ .name }} +imagePullPolicy: {{ .pullPolicy }} +{{- end -}} + +{{/* +Expand the name of the chart. +*/}} +{{- define "dendrite.name" -}} +{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create a default fully qualified app name. +We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). +If release name contains chart name it will be used as a full name. +*/}} +{{- define "dendrite.fullname" -}} +{{- if .Values.fullnameOverride }} +{{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- $name := default .Chart.Name .Values.nameOverride }} +{{- if contains $name .Release.Name }} +{{- .Release.Name | trunc 63 | trimSuffix "-" }} +{{- else }} +{{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} +{{- end }} +{{- end }} +{{- end }} + +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "dendrite.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Common labels +*/}} +{{- define "dendrite.labels" -}} +helm.sh/chart: {{ include "dendrite.chart" . }} +{{ include "dendrite.selectorLabels" . }} +{{- if .Chart.AppVersion }} +app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} +{{- end }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end }} + +{{/* +Selector labels +*/}} +{{- define "dendrite.selectorLabels" -}} +app.kubernetes.io/name: {{ include "dendrite.name" . }} +app.kubernetes.io/instance: {{ .Release.Name }} +{{- end }} \ No newline at end of file diff --git a/helm/dendrite/templates/_overrides.yaml b/helm/dendrite/templates/_overrides.yaml new file mode 100644 index 000000000..edb8ba83a --- /dev/null +++ b/helm/dendrite/templates/_overrides.yaml @@ -0,0 +1,16 @@ +{{- define "override.config" }} +{{- if .Values.postgresql.enabled }} +{{- $_ := set .Values.dendrite_config.global.database "connection_string" (print "postgresql://" .Values.postgresql.auth.username ":" .Values.postgresql.auth.password "@" .Release.Name "-postgresql/dendrite?sslmode=disable") -}} +{{ end }} +global: + private_key: /etc/dendrite/secrets/signing.key + jetstream: + in_memory: false +{{ if (gt (len (.Files.Glob "appservices/*")) 0) }} +app_service_api: + config_files: + {{- range $x, $y := .Files.Glob "appservices/*" }} + - /etc/dendrite/appservices/{{ base $x }} + {{ end }} +{{ end }} +{{ end }} diff --git a/helm/dendrite/templates/deployment.yaml b/helm/dendrite/templates/deployment.yaml new file mode 100644 index 000000000..629ffe528 --- /dev/null +++ b/helm/dendrite/templates/deployment.yaml @@ -0,0 +1,103 @@ +{{ template "validate.config" . }} +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + namespace: {{ $.Release.Namespace }} + name: {{ include "dendrite.fullname" . }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} +spec: + selector: + matchLabels: + {{- include "dendrite.selectorLabels" . | nindent 6 }} + replicas: 1 + template: + metadata: + labels: + {{- include "dendrite.selectorLabels" . | nindent 8 }} + annotations: + confighash-global: secret-{{ .Values.global | toYaml | sha256sum | trunc 32 }} + confighash-clientapi: clientapi-{{ .Values.clientapi | toYaml | sha256sum | trunc 32 }} + confighash-federationapi: federationapi-{{ .Values.federationapi | toYaml | sha256sum | trunc 32 }} + confighash-mediaapi: mediaapi-{{ .Values.mediaapi | toYaml | sha256sum | trunc 32 }} + confighash-syncapi: syncapi-{{ .Values.syncapi | toYaml | sha256sum | trunc 32 }} + spec: + volumes: + - name: {{ include "dendrite.fullname" . }}-conf-vol + secret: + secretName: {{ include "dendrite.fullname" . }}-conf + - name: {{ include "dendrite.fullname" . }}-signing-key + secret: + secretName: {{ default (print ( include "dendrite.fullname" . ) "-signing-key") $.Values.signing_key.existingSecret | quote }} + {{- if (gt (len ($.Files.Glob "appservices/*")) 0) }} + - name: {{ include "dendrite.fullname" . }}-appservices + secret: + secretName: {{ include "dendrite.fullname" . }}-appservices-conf + {{- end }} + - name: {{ include "dendrite.fullname" . }}-jetstream + persistentVolumeClaim: + claimName: {{ default (print ( include "dendrite.fullname" . ) "-jetstream-pvc") $.Values.persistence.jetstream.existingClaim | quote }} + - name: {{ include "dendrite.fullname" . }}-media + persistentVolumeClaim: + claimName: {{ default (print ( include "dendrite.fullname" . ) "-media-pvc") $.Values.persistence.media.existingClaim | quote }} + - name: {{ include "dendrite.fullname" . }}-search + persistentVolumeClaim: + claimName: {{ default (print ( include "dendrite.fullname" . ) "-search-pvc") $.Values.persistence.search.existingClaim | quote }} + containers: + - name: {{ $.Chart.Name }} + {{- include "image.name" $.Values.image | nindent 8 }} + args: + - '--config' + - '/etc/dendrite/dendrite.yaml' + ports: + - name: http + containerPort: 8008 + protocol: TCP + {{- if $.Values.dendrite_config.global.profiling.enabled }} + env: + - name: PPROFLISTEN + value: "localhost:{{- $.Values.global.profiling.port -}}" + {{- end }} + resources: + {{- toYaml $.Values.resources | nindent 10 }} + volumeMounts: + - mountPath: /etc/dendrite/ + name: {{ include "dendrite.fullname" . }}-conf-vol + - mountPath: /etc/dendrite/secrets/ + name: {{ include "dendrite.fullname" . }}-signing-key + {{- if (gt (len ($.Files.Glob "appservices/*")) 0) }} + - mountPath: /etc/dendrite/appservices + name: {{ include "dendrite.fullname" . }}-appservices + readOnly: true + {{ end }} + - mountPath: {{ .Values.dendrite_config.media_api.base_path }} + name: {{ include "dendrite.fullname" . }}-media + - mountPath: {{ .Values.dendrite_config.global.jetstream.storage_path }} + name: {{ include "dendrite.fullname" . }}-jetstream + - mountPath: {{ .Values.dendrite_config.sync_api.search.index_path }} + name: {{ include "dendrite.fullname" . }}-search + livenessProbe: + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + httpGet: + path: /_dendrite/monitor/health + port: http + readinessProbe: + initialDelaySeconds: 10 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + httpGet: + path: /_dendrite/monitor/health + port: http + startupProbe: + initialDelaySeconds: 5 + periodSeconds: 10 + timeoutSeconds: 5 + failureThreshold: 10 + httpGet: + path: /_dendrite/monitor/up + port: http \ No newline at end of file diff --git a/helm/dendrite/templates/ingress.yaml b/helm/dendrite/templates/ingress.yaml new file mode 100644 index 000000000..8f86ad723 --- /dev/null +++ b/helm/dendrite/templates/ingress.yaml @@ -0,0 +1,55 @@ +{{- if .Values.ingress.enabled -}} + {{- $fullName := include "dendrite.fullname" . -}} + {{- $svcPort := .Values.service.port -}} + {{- if and .Values.ingress.className (not (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion)) }} + {{- if not (hasKey .Values.ingress.annotations "kubernetes.io/ingress.class") }} + {{- $_ := set .Values.ingress.annotations "kubernetes.io/ingress.class" .Values.ingress.className}} + {{- end }} + {{- end }} + {{- if semverCompare ">=1.19-0" .Capabilities.KubeVersion.GitVersion -}} +apiVersion: networking.k8s.io/v1 + {{- else if semverCompare ">=1.14-0" .Capabilities.KubeVersion.GitVersion -}} +apiVersion: networking.k8s.io/v1beta1 + {{- else -}} +apiVersion: extensions/v1beta1 + {{- end }} +kind: Ingress +metadata: + name: {{ $fullName }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} + annotations: + {{- with .Values.ingress.annotations }} + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + {{- if and .Values.ingress.className (semverCompare ">=1.18-0" .Capabilities.KubeVersion.GitVersion) }} + ingressClassName: {{ .Values.ingress.className }} + {{- end }} + {{- if .Values.ingress.tls }} + tls: + {{- range .Values.ingress.tls }} + - hosts: + {{- range .hosts }} + - {{ . | quote }} + {{- end }} + secretName: {{ .secretName }} + {{- end }} + {{- end }} + rules: + - host: {{ .Values.ingress.hostName | quote }} + http: + paths: + - path: / + pathType: ImplementationSpecific + backend: + {{- if semverCompare ">=1.19-0" $.Capabilities.KubeVersion.GitVersion }} + service: + name: {{ $fullName }} + port: + number: {{ $svcPort }} + {{- else }} + serviceName: {{ $fullName }} + servicePort: {{ $svcPort }} + {{- end }} + {{- end }} \ No newline at end of file diff --git a/helm/dendrite/templates/jobs.yaml b/helm/dendrite/templates/jobs.yaml new file mode 100644 index 000000000..76915694d --- /dev/null +++ b/helm/dendrite/templates/jobs.yaml @@ -0,0 +1,99 @@ +{{ if and .Values.signing_key.create (not .Values.signing_key.existingSecret ) }} +{{ $name := (print ( include "dendrite.fullname" . ) "-signing-key") }} +{{ $secretName := (print ( include "dendrite.fullname" . ) "-signing-key") }} +--- +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{ $name }} + labels: + app.kubernetes.io/component: signingkey-job +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: Role +metadata: + name: {{ $name }} + labels: + app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} +rules: + - apiGroups: + - "" + resources: + - secrets + resourceNames: + - {{ $secretName }} + verbs: + - get + - update + - patch +--- +apiVersion: rbac.authorization.k8s.io/v1 +kind: RoleBinding +metadata: + name: {{ $name }} + labels: + app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} +roleRef: + apiGroup: rbac.authorization.k8s.io + kind: Role + name: {{ $name }} +subjects: + - kind: ServiceAccount + name: {{ $name }} + namespace: {{ .Release.Namespace }} +--- +apiVersion: batch/v1 +kind: Job +metadata: + name: generate-signing-key + labels: + {{- include "dendrite.labels" . | nindent 4 }} +spec: + template: + spec: + restartPolicy: "Never" + serviceAccount: {{ $name }} + containers: + - name: upload-key + image: bitnami/kubectl + command: + - sh + - -c + - | + # check if key already exists + key=$(kubectl get secret {{ $secretName }} -o jsonpath="{.data['signing\.key']}" 2> /dev/null) + [ $? -ne 0 ] && echo "Failed to get existing secret" && exit 1 + [ -n "$key" ] && echo "Key already created, exiting." && exit 0 + # wait for signing key + while [ ! -f /etc/dendrite/signing-key.pem ]; do + echo "Waiting for signing key.." + sleep 5; + done + # update secret + kubectl patch secret {{ $secretName }} -p "{\"data\":{\"signing.key\":\"$(base64 /etc/dendrite/signing-key.pem | tr -d '\n')\"}}" + [ $? -ne 0 ] && echo "Failed to update secret." && exit 1 + echo "Signing key successfully created." + volumeMounts: + - mountPath: /etc/dendrite/ + name: signing-key + readOnly: true + - name: generate-key + {{- include "image.name" $.Values.image | nindent 8 }} + command: + - sh + - -c + - | + /usr/bin/generate-keys -private-key /etc/dendrite/signing-key.pem + chown 1001:1001 /etc/dendrite/signing-key.pem + volumeMounts: + - mountPath: /etc/dendrite/ + name: signing-key + volumes: + - name: signing-key + emptyDir: {} + parallelism: 1 + completions: 1 + backoffLimit: 1 +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/templates/pvc.yaml b/helm/dendrite/templates/pvc.yaml new file mode 100644 index 000000000..897957e60 --- /dev/null +++ b/helm/dendrite/templates/pvc.yaml @@ -0,0 +1,48 @@ +{{ if not .Values.persistence.media.existingClaim }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-media-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.persistence.media.capacity }} + storageClassName: {{ .Values.persistence.storageClass }} +{{ end }} +{{ if not .Values.persistence.jetstream.existingClaim }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-jetstream-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.persistence.jetstream.capacity }} + storageClassName: {{ .Values.persistence.storageClass }} +{{ end }} +{{ if not .Values.persistence.search.existingClaim }} +--- +apiVersion: v1 +kind: PersistentVolumeClaim +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-search-pvc +spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.persistence.search.capacity }} + storageClassName: {{ .Values.persistence.storageClass }} +{{ end }} \ No newline at end of file diff --git a/helm/dendrite/templates/secrets.yaml b/helm/dendrite/templates/secrets.yaml new file mode 100644 index 000000000..d4b8ecbf2 --- /dev/null +++ b/helm/dendrite/templates/secrets.yaml @@ -0,0 +1,33 @@ +{{ if (gt (len (.Files.Glob "appservices/*")) 0) }} +--- +apiVersion: v1 +kind: Secret +metadata: + name: {{ include "dendrite.fullname" . }}-appservices-conf + namespace: {{ .Release.Namespace }} +type: Opaque +data: +{{ (.Files.Glob "appservices/*").AsSecrets | indent 2 }} +{{ end }} +{{ if and .Values.signing_key.create (not .Values.signing_key.existingSecret) }} +--- +apiVersion: v1 +kind: Secret +metadata: + annotations: + helm.sh/resource-policy: keep + name: {{ include "dendrite.fullname" . }}-signing-key + namespace: {{ .Release.Namespace }} +type: Opaque +{{ end }} + +--- +apiVersion: v1 +kind: Secret +type: Opaque +metadata: + name: {{ include "dendrite.fullname" . }}-conf + namespace: {{ .Release.Namespace }} +stringData: + dendrite.yaml: | + {{ toYaml ( mustMergeOverwrite .Values.dendrite_config ( fromYaml (include "override.config" .) ) .Values.dendrite_config ) | nindent 4 }} \ No newline at end of file diff --git a/helm/dendrite/templates/service.yaml b/helm/dendrite/templates/service.yaml new file mode 100644 index 000000000..365a43f04 --- /dev/null +++ b/helm/dendrite/templates/service.yaml @@ -0,0 +1,17 @@ +{{ template "validate.config" . }} +--- +apiVersion: v1 +kind: Service +metadata: + namespace: {{ $.Release.Namespace }} + name: {{ include "dendrite.fullname" . }} + labels: + {{- include "dendrite.labels" . | nindent 4 }} +spec: + selector: + {{- include "dendrite.selectorLabels" . | nindent 4 }} + ports: + - name: http + protocol: TCP + port: 8008 + targetPort: 8008 \ No newline at end of file diff --git a/helm/dendrite/templates/tests/test-version.yaml b/helm/dendrite/templates/tests/test-version.yaml new file mode 100644 index 000000000..d88751325 --- /dev/null +++ b/helm/dendrite/templates/tests/test-version.yaml @@ -0,0 +1,17 @@ +--- +apiVersion: v1 +kind: Pod +metadata: + name: "{{ include "dendrite.fullname" . }}-test-version" + labels: + {{- include "dendrite.selectorLabels" . | nindent 4 }} + annotations: + "helm.sh/hook": test +spec: + containers: + - name: curl + image: curlimages/curl + imagePullPolicy: IfNotPresent + args: + - 'http://{{- include "dendrite.fullname" . -}}:8008/_matrix/client/versions' + restartPolicy: Never diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml new file mode 100644 index 000000000..2c6e80942 --- /dev/null +++ b/helm/dendrite/values.yaml @@ -0,0 +1,348 @@ +image: + # -- Docker repository/image to use + name: "ghcr.io/matrix-org/dendrite-monolith:v0.10.8" + # -- Kubernetes pullPolicy + pullPolicy: IfNotPresent + + +# signing key to use +signing_key: + # -- Create a new signing key, if not exists + create: true + # -- Use an existing secret + existingSecret: "" + +# -- Default resource requests/limits. +# @default -- sets some sane default values +resources: + requests: + memory: "512Mi" + + limits: + memory: "4096Mi" + +persistence: + # -- The storage class to use for volume claims. Defaults to the + # cluster default storage class. + storageClass: "" + jetstream: + # -- Use an existing volume claim for jetstream + existingClaim: "" + # -- PVC Storage Request for the jetstream volume + capacity: "1Gi" + media: + # -- Use an existing volume claim for media files + existingClaim: "" + # -- PVC Storage Request for the media volume + capacity: "1Gi" + search: + # -- Use an existing volume claim for the fulltext search index + existingClaim: "" + # -- PVC Storage Request for the search volume + capacity: "1Gi" + +dendrite_config: + version: 2 + global: + # -- **REQUIRED** Servername for this Dendrite deployment. + server_name: "" + + # -- The private key to use. (**NOTE**: This is overriden in Helm) + private_key: /etc/dendrite/secrets/signing.key + + # -- The server name to delegate server-server communications to, with optional port + # e.g. localhost:443 + well_known_server_name: "" + + # -- The server name to delegate client-server communications to, with optional port + # e.g. localhost:443 + well_known_client_name: "" + + # -- Lists of domains that the server will trust as identity servers to verify third + # party identifiers such as phone numbers and email addresses. + trusted_third_party_id_servers: + - matrix.org + - vector.im + + # -- The paths and expiry timestamps (as a UNIX timestamp in millisecond precision) + # to old signing keys that were formerly in use on this domain name. These + # keys will not be used for federation request or event signing, but will be + # provided to any other homeserver that asks when trying to verify old events. + old_private_keys: + # If the old private key file is available: + # - private_key: old_matrix_key.pem + # expired_at: 1601024554498 + # If only the public key (in base64 format) and key ID are known: + # - public_key: mn59Kxfdq9VziYHSBzI7+EDPDcBS2Xl7jeUdiiQcOnM= + # key_id: ed25519:mykeyid + # expired_at: 1601024554498 + + # -- Disable federation. Dendrite will not be able to make any outbound HTTP requests + # to other servers and the federation API will not be exposed. + disable_federation: false + + key_validity_period: 168h0m0s + + database: + # -- The connection string for connections to Postgres. + # This will be set automatically if using the Postgres dependency + connection_string: "" + + # -- Default database maximum open connections + max_open_conns: 90 + # -- Default database maximum idle connections + max_idle_conns: 5 + # -- Default database maximum lifetime + conn_max_lifetime: -1 + + jetstream: + # -- Persistent directory to store JetStream streams in. + storage_path: "/data/jetstream" + # -- NATS JetStream server addresses if not using internal NATS. + addresses: [] + # -- The prefix for JetStream streams + topic_prefix: "Dendrite" + # -- Keep all data in memory. (**NOTE**: This is overriden in Helm to `false`) + in_memory: false + # -- Disables TLS validation. This should **NOT** be used in production. + disable_tls_validation: true + + cache: + # -- The estimated maximum size for the global cache in bytes, or in terabytes, + # gigabytes, megabytes or kilobytes when the appropriate 'tb', 'gb', 'mb' or + # 'kb' suffix is specified. Note that this is not a hard limit, nor is it a + # memory limit for the entire process. A cache that is too small may ultimately + # provide little or no benefit. + max_size_estimated: 1gb + # -- The maximum amount of time that a cache entry can live for in memory before + # it will be evicted and/or refreshed from the database. Lower values result in + # easier admission of new cache entries but may also increase database load in + # comparison to higher values, so adjust conservatively. Higher values may make + # it harder for new items to make it into the cache, e.g. if new rooms suddenly + # become popular. + max_age: 1h + + report_stats: + # -- Configures phone-home statistics reporting. These statistics contain the server + # name, number of active users and some information on your deployment config. + # We use this information to understand how Dendrite is being used in the wild. + enabled: false + # -- Endpoint to report statistics to. + endpoint: https://matrix.org/report-usage-stats/push + + presence: + # -- Controls whether we receive presence events from other servers + enable_inbound: false + # -- Controls whether we send presence events for our local users to other servers. + # (_May increase CPU/memory usage_) + enable_outbound: false + + server_notices: + # -- Server notices allows server admins to send messages to all users on the server. + enabled: false + # -- The local part for the user sending server notices. + local_part: "_server" + # -- The display name for the user sending server notices. + display_name: "Server Alerts" + # -- The avatar URL (as a mxc:// URL) name for the user sending server notices. + avatar_url: "" + # The room name to be used when sending server notices. This room name will + # appear in user clients. + room_name: "Server Alerts" + + # prometheus metrics + metrics: + # -- Whether or not Prometheus metrics are enabled. + enabled: false + # HTTP basic authentication to protect access to monitoring. + basic_auth: + # -- HTTP basic authentication username + user: "metrics" + # -- HTTP basic authentication password + password: metrics + + dns_cache: + # -- Whether or not the DNS cache is enabled. + enabled: false + # -- Maximum number of entries to hold in the DNS cache + cache_size: 256 + # -- Duration for how long DNS cache items should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) + cache_lifetime: "10m" + + profiling: + # -- Enable pprof. You will need to manually create a port forwarding to the deployment to access PPROF, + # as it will only listen on localhost and the defined port. + # e.g. `kubectl port-forward deployments/dendrite 65432:65432` + enabled: false + # -- pprof port, if enabled + port: 65432 + + # -- Configuration for experimental MSC's. (Valid values are: msc2836 and msc2946) + mscs: + mscs: + - msc2946 + # A list of enabled MSC's + # Currently valid values are: + # - msc2836 (Threading, see https://github.com/matrix-org/matrix-doc/pull/2836) + # - msc2946 (Spaces Summary, see https://github.com/matrix-org/matrix-doc/pull/2946) + + + app_service_api: + # -- Disable the validation of TLS certificates of appservices. This is + # not recommended in production since it may allow appservice traffic + # to be sent to an insecure endpoint. + disable_tls_validation: false + # -- Appservice config files to load on startup. (**NOTE**: This is overriden by Helm, if a folder `./appservices/` exists) + config_files: [] + + client_api: + # -- Prevents new users from being able to register on this homeserver, except when + # using the registration shared secret below. + registration_disabled: true + + # Prevents new guest accounts from being created. Guest registration is also + # disabled implicitly by setting 'registration_disabled' above. + guests_disabled: true + + # -- If set, allows registration by anyone who knows the shared secret, regardless of + # whether registration is otherwise disabled. + registration_shared_secret: "" + + # -- enable reCAPTCHA registration + enable_registration_captcha: false + # -- reCAPTCHA public key + recaptcha_public_key: "" + # -- reCAPTCHA private key + recaptcha_private_key: "" + # -- reCAPTCHA bypass secret + recaptcha_bypass_secret: "" + recaptcha_siteverify_api: "" + + # TURN server information that this homeserver should send to clients. + turn: + # -- Duration for how long users should be considered valid ([see time.ParseDuration](https://pkg.go.dev/time#ParseDuration) for more) + turn_user_lifetime: "24h" + turn_uris: [] + turn_shared_secret: "" + # -- The TURN username + turn_username: "" + # -- The TURN password + turn_password: "" + + rate_limiting: + # -- Enable rate limiting + enabled: true + # -- After how many requests a rate limit should be activated + threshold: 20 + # -- Cooloff time in milliseconds + cooloff_ms: 500 + # -- Users which should be exempt from rate limiting + exempt_user_ids: + + federation_api: + # -- Federation failure threshold. How many consecutive failures that we should + # tolerate when sending federation requests to a specific server. The backoff + # is 2**x seconds, so 1 = 2 seconds, 2 = 4 seconds, 3 = 8 seconds, etc. + # The default value is 16 if not specified, which is circa 18 hours. + send_max_retries: 16 + # -- Disable TLS validation. This should **NOT** be used in production. + disable_tls_validation: false + prefer_direct_fetch: false + # -- Prevents Dendrite from keeping HTTP connections + # open for reuse for future requests. Connections will be closed quicker + # but we may spend more time on TLS handshakes instead. + disable_http_keepalives: false + # -- Perspective keyservers, to use as a backup when direct key fetch + # requests don't succeed. + # @default -- See value.yaml + key_perspectives: + - server_name: matrix.org + keys: + - key_id: ed25519:auto + public_key: Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw + - key_id: ed25519:a_RXGa + public_key: l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ + + media_api: + # -- The path to store media files (e.g. avatars) in + base_path: "/data/media_store" + # -- The max file size for uploaded media files + max_file_size_bytes: 10485760 + # Whether to dynamically generate thumbnails if needed. + dynamic_thumbnails: false + # -- The maximum number of simultaneous thumbnail generators to run. + max_thumbnail_generators: 10 + # -- A list of thumbnail sizes to be generated for media content. + # @default -- See value.yaml + thumbnail_sizes: + - width: 32 + height: 32 + method: crop + - width: 96 + height: 96 + method: crop + - width: 640 + height: 480 + method: scale + + sync_api: + # -- This option controls which HTTP header to inspect to find the real remote IP + # address of the client. This is likely required if Dendrite is running behind + # a reverse proxy server. + real_ip_header: X-Real-IP + # -- Configuration for the full-text search engine. + search: + # -- Whether fulltext search is enabled. + enabled: true + # -- The path to store the search index in. + index_path: "/data/search" + # -- The language most likely to be used on the server - used when indexing, to + # ensure the returned results match expectations. A full list of possible languages + # can be found [here](https://github.com/matrix-org/dendrite/blob/76db8e90defdfb9e61f6caea8a312c5d60bcc005/internal/fulltext/bleve.go#L25-L46) + language: "en" + + user_api: + # -- bcrypt cost to use when hashing passwords. + # (ranges from 4-31; 4 being least secure, 31 being most secure; _NOTE: Using a too high value can cause clients to timeout and uses more CPU._) + bcrypt_cost: 10 + # -- OpenID Token lifetime in milliseconds. + openid_token_lifetime_ms: 3600000 + # - Disable TLS validation when hitting push gateways. This should **NOT** be used in production. + push_gateway_disable_tls_validation: false + # -- Rooms to join users to after registration + auto_join_rooms: [] + + # -- Default logging configuration + logging: + - type: std + level: info + +postgresql: + # -- Enable and configure postgres as the database for dendrite. + # @default -- See value.yaml + enabled: false + image: + repository: bitnami/postgresql + tag: "15.1.0" + auth: + username: dendrite + password: changeme + database: dendrite + + persistence: + enabled: false + +ingress: + # -- Create an ingress for a monolith deployment + enabled: false + hosts: [] + className: "" + hostName: "" + # -- Extra, custom annotations + annotations: {} + + tls: [] + +service: + type: ClusterIP + port: 80 diff --git a/setup/config/config.go b/setup/config/config.go index 6523a2452..41d2b6674 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -228,7 +228,7 @@ func loadConfig( privateKeyPath := absPath(basePath, c.Global.PrivateKeyPath) if c.Global.KeyID, c.Global.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { - return nil, err + return nil, fmt.Errorf("failed to load private_key: %w", err) } for _, v := range c.Global.VirtualHosts { @@ -242,7 +242,7 @@ func loadConfig( } privateKeyPath := absPath(basePath, v.PrivateKeyPath) if v.KeyID, v.PrivateKey, err = LoadMatrixKey(privateKeyPath, readFile); err != nil { - return nil, err + return nil, fmt.Errorf("failed to load private_key for virtualhost %s: %w", v.ServerName, err) } } From 002310390f874f69e52be1a8d20b17a3cb11d126 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:51:07 +0100 Subject: [PATCH 40/67] Output to docs folder, hopefully --- helm/cr.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/helm/cr.yaml b/helm/cr.yaml index f895ab8d6..014803cf1 100644 --- a/helm/cr.yaml +++ b/helm/cr.yaml @@ -1 +1,2 @@ -release-name-template: "helm-{{ .Name }}-{{ .Version }}" \ No newline at end of file +release-name-template: "helm-{{ .Name }}-{{ .Version }}" +package-path: docs/ \ No newline at end of file From 3fd95e60cc5fc3ae0610d0aca2177d8436a65ee1 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Fri, 6 Jan 2023 15:54:04 +0100 Subject: [PATCH 41/67] Try that again --- helm/cr.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/helm/cr.yaml b/helm/cr.yaml index 014803cf1..884c2b46b 100644 --- a/helm/cr.yaml +++ b/helm/cr.yaml @@ -1,2 +1,2 @@ release-name-template: "helm-{{ .Name }}-{{ .Version }}" -package-path: docs/ \ No newline at end of file +pages-index-path: docs/index.yaml \ No newline at end of file From 54b47a98e57cf210b568fa99ae159fd000012eb2 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 6 Jan 2023 11:49:59 -0700 Subject: [PATCH 42/67] Add curl to dendrite docker containers --- Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile b/Dockerfile index ede33e635..6da555c04 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,6 +27,7 @@ RUN --mount=target=. \ # The dendrite base image # FROM alpine:latest AS dendrite-base +RUN apk --update --no-cache add curl LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" LABEL org.opencontainers.image.licenses="Apache-2.0" From 0995dc48224b90432e38fa92345cf5735bca6090 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Fri, 6 Jan 2023 12:02:43 -0700 Subject: [PATCH 43/67] Add curl to dendrite-demo-pinecone docker container --- build/docker/Dockerfile.demo-pinecone | 1 + 1 file changed, 1 insertion(+) diff --git a/build/docker/Dockerfile.demo-pinecone b/build/docker/Dockerfile.demo-pinecone index facd1e3af..90f515167 100644 --- a/build/docker/Dockerfile.demo-pinecone +++ b/build/docker/Dockerfile.demo-pinecone @@ -17,6 +17,7 @@ RUN go build -trimpath -o bin/ ./cmd/create-account RUN go build -trimpath -o bin/ ./cmd/generate-keys FROM alpine:latest +RUN apk --update --no-cache add curl LABEL org.opencontainers.image.title="Dendrite (Pinecone demo)" LABEL org.opencontainers.image.description="Next-generation Matrix homeserver written in Go" LABEL org.opencontainers.image.source="https://github.com/matrix-org/dendrite" From b0c5af6674465a3384a1b55c84325e7989ce1eb5 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 10 Jan 2023 17:02:38 +0100 Subject: [PATCH 44/67] Fix `/login` issue causing wrong device list updates (#2922) Fixes https://github.com/matrix-org/dendrite/issues/2914 and possibly https://github.com/matrix-org/dendrite/issues/2073? --- clientapi/auth/login_test.go | 5 +- clientapi/auth/password.go | 5 ++ clientapi/routing/login_test.go | 152 ++++++++++++++++++++++++++++++++ 3 files changed, 160 insertions(+), 2 deletions(-) create mode 100644 clientapi/routing/login_test.go diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go index b79c573aa..044062c42 100644 --- a/clientapi/auth/login_test.go +++ b/clientapi/auth/login_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -47,7 +48,7 @@ func TestLoginFromJSONReader(t *testing.T) { "password": "herpassword", "device_id": "adevice" }`, - WantUsername: "alice", + WantUsername: "@alice:example.com", WantDeviceID: "adevice", }, { @@ -174,7 +175,7 @@ func (ua *fakeUserInternalAPI) QueryAccountByPassword(ctx context.Context, req * return nil } res.Exists = true - res.Account = &uapi.Account{} + res.Account = &uapi.Account{UserID: userutil.MakeUserID(req.Localpart, req.ServerName)} return nil } diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index 4de2b443c..f2b0383ab 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -101,6 +101,8 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } + // If we couldn't find the user by the lower cased localpart, try the provided + // localpart as is. if !res.Exists { err = t.GetAccountByPassword(ctx, &api.QueryAccountByPasswordRequest{ Localpart: localpart, @@ -122,5 +124,8 @@ func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, } } } + // Set the user, so login.Username() can do the right thing + r.Identifier.User = res.Account.UserID + r.User = res.Account.UserID return &r.Login, nil } diff --git a/clientapi/routing/login_test.go b/clientapi/routing/login_test.go new file mode 100644 index 000000000..d429d7f8c --- /dev/null +++ b/clientapi/routing/login_test.go @@ -0,0 +1,152 @@ +package routing + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/dendrite/userapi" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestLogin(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bobUser := &test.User{ID: "@bob:test", AccountType: uapi.AccountTypeUser} + charlie := &test.User{ID: "@Charlie:test", AccountType: uapi.AccountTypeUser} + vhUser := &test.User{ID: "@vhuser:vh1"} + + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.ClientAPI.RateLimiting.Enabled = false + // add a vhost + base.Cfg.Global.VirtualHosts = append(base.Cfg.Global.VirtualHosts, &config.VirtualHost{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ServerName: "vh1"}, + }) + + rsAPI := roomserver.NewInternalAPI(base) + // Needed for /login + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + + // We mostly need the userAPI for this test, so nil for other APIs/caches etc. + Setup(base, &base.Cfg.ClientAPI, nil, nil, userAPI, nil, nil, nil, nil, nil, keyAPI, nil, &base.Cfg.MSCs, nil) + + // Create password + password := util.RandomString(8) + + // create the users + for _, u := range []*test.User{aliceAdmin, bobUser, vhUser, charlie} { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + if !userRes.AccountCreated { + t.Fatalf("account not created") + } + } + + testCases := []struct { + name string + userID string + wantOK bool + }{ + { + name: "aliceAdmin can login", + userID: aliceAdmin.ID, + wantOK: true, + }, + { + name: "bobUser can login", + userID: bobUser.ID, + wantOK: true, + }, + { + name: "vhuser can login", + userID: vhUser.ID, + wantOK: true, + }, + { + name: "bob with uppercase can login", + userID: "@Bob:test", + wantOK: true, + }, + { + name: "Charlie can login (existing uppercase)", + userID: charlie.ID, + wantOK: true, + }, + { + name: "Charlie can not login with lowercase userID", + userID: strings.ToLower(charlie.ID), + wantOK: false, + }, + } + + ctx := context.Background() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": tc.userID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + + t.Logf("Response: %s", rec.Body.String()) + // get the response + resp := loginResponse{} + if err := json.Unmarshal(rec.Body.Bytes(), &resp); err != nil { + t.Fatal(err) + } + // everything OK + if !tc.wantOK && resp.AccessToken == "" { + return + } + if tc.wantOK && resp.AccessToken == "" { + t.Fatalf("expected accessToken after successful login but got none: %+v", resp) + } + + devicesResp := &uapi.QueryDevicesResponse{} + if err := userAPI.QueryDevices(ctx, &uapi.QueryDevicesRequest{UserID: resp.UserID}, devicesResp); err != nil { + t.Fatal(err) + } + for _, dev := range devicesResp.Devices { + // We expect the userID on the device to be the same as resp.UserID + if dev.UserID != resp.UserID { + t.Fatalf("unexpected userID on device: %s", dev.UserID) + } + } + }) + } + }) +} From 7482cd2b47cf19f4da9121461950409ea8f07a12 Mon Sep 17 00:00:00 2001 From: devonh Date: Tue, 10 Jan 2023 18:09:25 +0000 Subject: [PATCH 45/67] Handle DisplayName field in admin user registration endpoint (#2935) `/_synapse/admin/v1/register` has a `displayname` field that we were previously ignoring. This handles that field and adds the displayname to the new user if one was provided. --- clientapi/routing/register.go | 28 +++++-- clientapi/routing/register_secret.go | 13 ++-- clientapi/routing/register_secret_test.go | 2 +- clientapi/routing/register_test.go | 93 +++++++++++++++++++++++ 4 files changed, 123 insertions(+), 13 deletions(-) diff --git a/clientapi/routing/register.go b/clientapi/routing/register.go index 6087bda0c..ff6a0900e 100644 --- a/clientapi/routing/register.go +++ b/clientapi/routing/register.go @@ -780,7 +780,7 @@ func handleApplicationServiceRegistration( // Don't need to worry about appending to registration stages as // application service registration is entirely separate. return completeRegistration( - req.Context(), userAPI, r.Username, r.ServerName, "", appserviceID, req.RemoteAddr, + req.Context(), userAPI, r.Username, r.ServerName, "", "", appserviceID, req.RemoteAddr, req.UserAgent(), r.Auth.Session, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeAppService, ) @@ -800,7 +800,7 @@ func checkAndCompleteFlow( if checkFlowCompleted(flow, cfg.Derived.Registration.Flows) { // This flow was completed, registration can continue return completeRegistration( - req.Context(), userAPI, r.Username, r.ServerName, r.Password, "", req.RemoteAddr, + req.Context(), userAPI, r.Username, r.ServerName, "", r.Password, "", req.RemoteAddr, req.UserAgent(), sessionID, r.InhibitLogin, r.InitialDisplayName, r.DeviceID, userapi.AccountTypeUser, ) @@ -824,10 +824,10 @@ func checkAndCompleteFlow( func completeRegistration( ctx context.Context, userAPI userapi.ClientUserAPI, - username string, serverName gomatrixserverlib.ServerName, + username string, serverName gomatrixserverlib.ServerName, displayName string, password, appserviceID, ipAddr, userAgent, sessionID string, inhibitLogin eventutil.WeakBoolean, - displayName, deviceID *string, + deviceDisplayName, deviceID *string, accType userapi.AccountType, ) util.JSONResponse { if username == "" { @@ -887,12 +887,28 @@ func completeRegistration( } } + if displayName != "" { + nameReq := userapi.PerformUpdateDisplayNameRequest{ + Localpart: username, + ServerName: serverName, + DisplayName: displayName, + } + var nameRes userapi.PerformUpdateDisplayNameResponse + err = userAPI.SetDisplayName(ctx, &nameReq, &nameRes) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.Unknown("failed to set display name: " + err.Error()), + } + } + } + var devRes userapi.PerformDeviceCreationResponse err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ Localpart: username, ServerName: serverName, AccessToken: token, - DeviceDisplayName: displayName, + DeviceDisplayName: deviceDisplayName, DeviceID: deviceID, IPAddr: ipAddr, UserAgent: userAgent, @@ -1077,5 +1093,5 @@ func handleSharedSecretRegistration(cfg *config.ClientAPI, userAPI userapi.Clien if ssrr.Admin { accType = userapi.AccountTypeAdmin } - return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) + return completeRegistration(req.Context(), userAPI, ssrr.User, cfg.Matrix.ServerName, ssrr.DisplayName, ssrr.Password, "", req.RemoteAddr, req.UserAgent(), "", false, &ssrr.User, &deviceID, accType) } diff --git a/clientapi/routing/register_secret.go b/clientapi/routing/register_secret.go index 1a974b77a..f384b604a 100644 --- a/clientapi/routing/register_secret.go +++ b/clientapi/routing/register_secret.go @@ -18,12 +18,13 @@ import ( ) type SharedSecretRegistrationRequest struct { - User string `json:"username"` - Password string `json:"password"` - Nonce string `json:"nonce"` - MacBytes []byte - MacStr string `json:"mac"` - Admin bool `json:"admin"` + User string `json:"username"` + Password string `json:"password"` + Nonce string `json:"nonce"` + MacBytes []byte + MacStr string `json:"mac"` + Admin bool `json:"admin"` + DisplayName string `json:"displayname,omitempty"` } func NewSharedSecretRegistrationRequest(reader io.ReadCloser) (*SharedSecretRegistrationRequest, error) { diff --git a/clientapi/routing/register_secret_test.go b/clientapi/routing/register_secret_test.go index a2ed35853..ca265d237 100644 --- a/clientapi/routing/register_secret_test.go +++ b/clientapi/routing/register_secret_test.go @@ -10,7 +10,7 @@ import ( func TestSharedSecretRegister(t *testing.T) { // these values have come from a local synapse instance to ensure compatibility - jsonStr := []byte(`{"admin":false,"mac":"f1ba8d37123866fd659b40de4bad9b0f8965c565","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice"}`) + jsonStr := []byte(`{"admin":false,"mac":"f1ba8d37123866fd659b40de4bad9b0f8965c565","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) sharedSecret := "dendritetest" req, err := NewSharedSecretRegistrationRequest(io.NopCloser(bytes.NewBuffer(jsonStr))) diff --git a/clientapi/routing/register_test.go b/clientapi/routing/register_test.go index b8fd19e90..bccc1b79b 100644 --- a/clientapi/routing/register_test.go +++ b/clientapi/routing/register_test.go @@ -18,6 +18,7 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "net/http/httptest" "reflect" @@ -35,7 +36,10 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/dendrite/userapi" + "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/util" + "github.com/patrickmn/go-cache" + "github.com/stretchr/testify/assert" ) var ( @@ -570,3 +574,92 @@ func Test_register(t *testing.T) { } }) } + +func TestRegisterUserWithDisplayName(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.Global.ServerName = "server" + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + deviceName, deviceID := "deviceName", "deviceID" + expectedDisplayName := "DisplayName" + response := completeRegistration( + base.Context(), + userAPI, + "user", + "server", + expectedDisplayName, + "password", + "", + "localhost", + "user agent", + "session", + false, + &deviceName, + &deviceID, + api.AccountTypeAdmin, + ) + + assert.Equal(t, http.StatusOK, response.Code) + + req := api.QueryProfileRequest{UserID: "@user:server"} + var res api.QueryProfileResponse + err := userAPI.QueryProfile(base.Context(), &req, &res) + assert.NoError(t, err) + assert.Equal(t, expectedDisplayName, res.DisplayName) + }) +} + +func TestRegisterAdminUsingSharedSecret(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + base.Cfg.Global.ServerName = "server" + sharedSecret := "dendritetest" + base.Cfg.ClientAPI.RegistrationSharedSecret = sharedSecret + + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + keyAPI.SetUserAPI(userAPI) + + expectedDisplayName := "rabbit" + jsonStr := []byte(`{"admin":true,"mac":"24dca3bba410e43fe64b9b5c28306693bf3baa9f","nonce":"759f047f312b99ff428b21d581256f8592b8976e58bc1b543972dc6147e529a79657605b52d7becd160ff5137f3de11975684319187e06901955f79e5a6c5a79","password":"wonderland","username":"alice","displayname":"rabbit"}`) + req, err := NewSharedSecretRegistrationRequest(io.NopCloser(bytes.NewBuffer(jsonStr))) + assert.NoError(t, err) + if err != nil { + t.Fatalf("failed to read request: %s", err) + } + + r := NewSharedSecretRegistration(sharedSecret) + + // force the nonce to be known + r.nonces.Set(req.Nonce, true, cache.DefaultExpiration) + + _, err = r.IsValidMacLogin(req.Nonce, req.User, req.Password, req.Admin, req.MacBytes) + assert.NoError(t, err) + + body := &bytes.Buffer{} + err = json.NewEncoder(body).Encode(req) + assert.NoError(t, err) + ssrr := httptest.NewRequest(http.MethodPost, "/", body) + + response := handleSharedSecretRegistration( + &base.Cfg.ClientAPI, + userAPI, + r, + ssrr, + ) + assert.Equal(t, http.StatusOK, response.Code) + + profilReq := api.QueryProfileRequest{UserID: "@alice:server"} + var profileRes api.QueryProfileResponse + err = userAPI.QueryProfile(base.Context(), &profilReq, &profileRes) + assert.NoError(t, err) + assert.Equal(t, expectedDisplayName, profileRes.DisplayName) + }) +} From 97ebd72b5a731decdf8f67742179e1adc0f9f30d Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Tue, 10 Jan 2023 16:26:41 -0700 Subject: [PATCH 46/67] Add FAQs based on commonly asked questions from the community --- docs/FAQ.md | 54 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/docs/FAQ.md b/docs/FAQ.md index 816130515..4047bfffc 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -6,6 +6,12 @@ permalink: /faq # FAQ +## Why does Dendrite exist? + +Dendrite aims to provide a matrix compatible server that has low resource usage compared to [Synapse](https://github.com/matrix-org/synapse). +It also aims to provide more flexibility when scaling either up or down. +Dendrite's code is also very easy to hack on which makes it suitable for experimenting with new matrix features such as peer-to-peer. + ## Is Dendrite stable? Mostly, although there are still bugs and missing features. If you are a confident power user and you are happy to spend some time debugging things when they go wrong, then please try out Dendrite. If you are a community, organisation or business that demands stability and uptime, then Dendrite is not for you yet - please install Synapse instead. @@ -34,6 +40,10 @@ No, Dendrite has a very different database schema to Synapse and the two are not Monolith deployments are always preferred where possible, and at this time, are far better tested than polylith deployments are. The only reason to consider a polylith deployment is if you wish to run different Dendrite components on separate physical machines, but this is an advanced configuration which we don't recommend. +## Can I configure which port Dendrite listens on? + +Yes, use the cli flag `-http-bind-address`. + ## I've installed Dendrite but federation isn't working Check the [Federation Tester](https://federationtester.matrix.org). You need at least: @@ -42,6 +52,10 @@ Check the [Federation Tester](https://federationtester.matrix.org). You need at * A valid TLS certificate for that DNS name * Either DNS SRV records or well-known files +## Whenever I try to connect from Element it says unable to connect to homeserver + +Check that your dendrite instance is running. Otherwise this is most likely due to a reverse proxy misconfiguration. + ## Does Dendrite work with my favourite client? It should do, although we are aware of some minor issues: @@ -49,6 +63,10 @@ It should do, although we are aware of some minor issues: * **Element Android**: registration does not work, but logging in with an existing account does * **Hydrogen**: occasionally sync can fail due to gaps in the `since` parameter, but clearing the cache fixes this +## Is there a public instance of Dendrite I can try out? + +Use [dendrite.matrix.org](https://dendrite.matrix.org) which we officially support. + ## Does Dendrite support Space Summaries? Yes, [Space Summaries](https://github.com/matrix-org/matrix-spec-proposals/pull/2946) were merged into the Matrix Spec as of 2022-01-17 however, they are still treated as an MSC (Matrix Specification Change) in Dendrite. In order to enable Space Summaries in Dendrite, you must add the MSC to the MSC configuration section in the configuration YAML. If the MSC is not enabled, a user will typically see a perpetual loading icon on the summary page. See below for a demonstration of how to add to the Dendrite configuration: @@ -84,10 +102,42 @@ Remember to add the config file(s) to the `app_service_api` section of the confi Yes, you can do this by disabling federation - set `disable_federation` to `true` in the `global` section of the Dendrite configuration file. +## How can I migrate a room in order to change the internal ID? + +This can be done by performing a room upgrade. Use the command `/upgraderoom ` in Element to do this. + +## How do I reset somebody's password on my server? + +Use the admin endpoint [resetpassword](https://matrix-org.github.io/dendrite/administration/adminapi#post-_dendriteadminresetpassworduserid) + ## Should I use PostgreSQL or SQLite for my databases? Please use PostgreSQL wherever possible, especially if you are planning to run a homeserver that caters to more than a couple of users. +## What data needs to be kept if transferring/backing up Dendrite? + +The list of files that need to be stored is: +- matrix-key.pem +- dendrite.yaml +- the postgres or sqlite DB +- the media store +- the search index (although this can be regenerated) + +Note that this list may change / be out of date. We don't officially maintain instructions for migrations like this. + +## How can I prepare enough storage for media caches? + +This might be what you want: [matrix-media-repo](https://github.com/turt2live/matrix-media-repo) +We don't officially support this or any other dedicated media storage solutions. + +## Is there an upgrade guide for Dendrite? + +Run a newer docker image. We don't officially support deployments other than Docker. +Most of the time you should be able to just +- stop +- replace binary +- start + ## Dendrite is using a lot of CPU Generally speaking, you should expect to see some CPU spikes, particularly if you are joining or participating in large rooms. However, constant/sustained high CPU usage is not expected - if you are experiencing that, please join `#dendrite-dev:matrix.org` and let us know what you were doing when the @@ -102,6 +152,10 @@ not expected. Join `#dendrite-dev:matrix.org` and let us know what you were doin ballooned, or file a GitHub issue if you can. If you can take a [memory profile](development/PROFILING.md) then that would be a huge help too, as that will help us to understand where the memory usage is happening. +## Do I need to generate the self-signed certificate if I'm going to use a reverse proxy? + +No, if you already have a proper certificate from some provider, like Let's Encrypt, and use that on your reverse proxy, and the reverse proxy does TLS termination, then you’re good and can use HTTP to the dendrite process. + ## Dendrite is running out of PostgreSQL database connections You may need to revisit the connection limit of your PostgreSQL server and/or make changes to the `max_connections` lines in your Dendrite configuration. Be aware that each Dendrite component opens its own database connections and has its own connection limit, even in monolith mode! From 11a07d855dd7f08fcd386cb778cbdd353ddd5aa4 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 11 Jan 2023 09:52:58 -0700 Subject: [PATCH 47/67] Initial attempt at adding cypress tests to ci --- .github/workflows/schedules.yaml | 35 ++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index d2a1f6e1f..63a60a241 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -73,3 +73,38 @@ jobs: path: | /logs/results.tap /logs/**/*.log* + + element_web: + runs-on: ubuntu-latest + steps: + - uses: tecolicom/actions-use-apt-tools@v1 + with: + # Our test suite includes some screenshot tests with unusual diacritics, which are + # supposed to be covered by STIXGeneral. + tools: fonts-stix + - uses: actions/checkout@v2 + with: + repository: matrix-org/matrix-react-sdk + - uses: actions/setup-node@v3 + with: + cache: 'yarn' + - name: Fetch layered build + run: scripts/ci/layered.sh + - name: Copy config + run: cp element.io/develop/config.json config.json + working-directory: ./element-web + - name: Build + env: + CI_PACKAGE: true + run: yarn build + working-directory: ./element-web + - name: "Run cypress tests" + uses: cypress-io/github-action@v4.1.1 + with: + browser: chrome + start: npx serve -p 8080 ./element-web/webapp + wait-on: 'http://localhost:8080' + env: + PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true + TMPDIR: ${{ runner.temp }} + HOMESERVER: 'dendrite' From 8fef692741f2e228c04b5f986ab65036e89947d2 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 11 Jan 2023 10:10:24 -0700 Subject: [PATCH 48/67] Edit cypress config before running tests --- .github/workflows/schedules.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index 63a60a241..ba05f2083 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -98,6 +98,9 @@ jobs: CI_PACKAGE: true run: yarn build working-directory: ./element-web + - name: Edit Test Config + run: | + sed -i '/HOMESERVER/c\ HOMESERVER: "dendrite",' cypress.config.ts - name: "Run cypress tests" uses: cypress-io/github-action@v4.1.1 with: @@ -107,4 +110,3 @@ jobs: env: PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true TMPDIR: ${{ runner.temp }} - HOMESERVER: 'dendrite' From b297ea7379d6d5b953a810fe2475b549a917cc9a Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 11 Jan 2023 10:40:38 -0700 Subject: [PATCH 49/67] Add cypress cloud recording --- .github/workflows/schedules.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index ba05f2083..45098925f 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -75,6 +75,7 @@ jobs: /logs/**/*.log* element_web: + timeout-minutes: 120 runs-on: ubuntu-latest steps: - uses: tecolicom/actions-use-apt-tools@v1 @@ -107,6 +108,13 @@ jobs: browser: chrome start: npx serve -p 8080 ./element-web/webapp wait-on: 'http://localhost:8080' + record: + true env: + # pass the Dashboard record key as an environment variable + CYPRESS_RECORD_KEY: ${{ secrets.CYPRESS_RECORD_KEY }} + # pass GitHub token to allow accurately detecting a build vs a re-run build + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true TMPDIR: ${{ runner.temp }} From 6ae1dd565c739efb1f847558e4170cdc0cb4085a Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 11 Jan 2023 10:46:52 -0700 Subject: [PATCH 50/67] Revert "Add cypress cloud recording" This reverts commit b297ea7379d6d5b953a810fe2475b549a917cc9a. --- .github/workflows/schedules.yaml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index 45098925f..ba05f2083 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -75,7 +75,6 @@ jobs: /logs/**/*.log* element_web: - timeout-minutes: 120 runs-on: ubuntu-latest steps: - uses: tecolicom/actions-use-apt-tools@v1 @@ -108,13 +107,6 @@ jobs: browser: chrome start: npx serve -p 8080 ./element-web/webapp wait-on: 'http://localhost:8080' - record: - true env: - # pass the Dashboard record key as an environment variable - CYPRESS_RECORD_KEY: ${{ secrets.CYPRESS_RECORD_KEY }} - # pass GitHub token to allow accurately detecting a build vs a re-run build - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PUPPETEER_SKIP_CHROMIUM_DOWNLOAD: true TMPDIR: ${{ runner.temp }} From 25dfbc6ec3991ba04f317cbae4a4dd51bab6013e Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Wed, 11 Jan 2023 10:47:37 -0700 Subject: [PATCH 51/67] Extend cypress test timeout in ci --- .github/workflows/schedules.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/schedules.yaml b/.github/workflows/schedules.yaml index ba05f2083..5636c4cf9 100644 --- a/.github/workflows/schedules.yaml +++ b/.github/workflows/schedules.yaml @@ -75,6 +75,7 @@ jobs: /logs/**/*.log* element_web: + timeout-minutes: 120 runs-on: ubuntu-latest steps: - uses: tecolicom/actions-use-apt-tools@v1 From 0491a8e3436bc17535a4c57d26376af83685a97c Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 12 Jan 2023 10:06:03 +0100 Subject: [PATCH 52/67] Fix room summary returning wrong heroes (#2930) This should fix #2910. Probably makes Sytest/Complement a bit upset, since this not using `sort.Strings` anymore. --- syncapi/storage/interface.go | 2 +- .../postgres/current_room_state_table.go | 92 +++++---- syncapi/storage/postgres/memberships_table.go | 27 --- syncapi/storage/shared/storage_sync.go | 58 +++++- .../sqlite3/current_room_state_table.go | 97 +++++++--- syncapi/storage/sqlite3/memberships_table.go | 39 ---- syncapi/storage/storage_test.go | 179 ++++++++++++++++++ syncapi/storage/tables/interface.go | 4 +- syncapi/storage/tables/memberships_test.go | 36 ---- syncapi/streams/stream_pdu.go | 54 +----- 10 files changed, 378 insertions(+), 210 deletions(-) diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 75afbce15..4e22f8a6f 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -45,7 +45,7 @@ type DatabaseTransaction interface { GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error) MembershipCount(ctx context.Context, roomID, membership string, pos types.StreamPosition) (int, error) - GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) + GetRoomSummary(ctx context.Context, roomID, userID string) (summary *types.Summary, err error) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) GetBackwardTopologyPos(ctx context.Context, events []*gomatrixserverlib.HeaderedEvent) (types.TopologyToken, error) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index 48ed20021..3caafa14b 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" @@ -110,6 +111,15 @@ const selectSharedUsersSQL = "" + " SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + ") AND type = 'm.room.member' AND state_key = ANY($2) AND membership IN ('join', 'invite');" +const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2` + +const selectRoomHeroes = ` +SELECT state_key FROM syncapi_current_room_state +WHERE type = 'm.room.member' AND room_id = $1 AND membership = ANY($2) AND state_key != $3 +ORDER BY added_at, state_key +LIMIT 5 +` + type currentRoomStateStatements struct { upsertRoomStateStmt *sql.Stmt deleteRoomStateByEventIDStmt *sql.Stmt @@ -122,6 +132,8 @@ type currentRoomStateStatements struct { selectEventsWithEventIDsStmt *sql.Stmt selectStateEventStmt *sql.Stmt selectSharedUsersStmt *sql.Stmt + selectMembershipCountStmt *sql.Stmt + selectRoomHeroesStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -141,40 +153,21 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro return nil, err } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { - return nil, err - } - if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { - return nil, err - } - if s.selectEventsWithEventIDsStmt, err = db.Prepare(selectEventsWithEventIDsSQL); err != nil { - return nil, err - } - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } - if s.selectSharedUsersStmt, err = db.Prepare(selectSharedUsersSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertRoomStateStmt, upsertRoomStateSQL}, + {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL}, + {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL}, + {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL}, + {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL}, + {&s.selectCurrentStateStmt, selectCurrentStateSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, + {&s.selectJoinedUsersInRoomStmt, selectJoinedUsersInRoomSQL}, + {&s.selectEventsWithEventIDsStmt, selectEventsWithEventIDsSQL}, + {&s.selectStateEventStmt, selectStateEventSQL}, + {&s.selectSharedUsersStmt, selectSharedUsersSQL}, + {&s.selectMembershipCountStmt, selectMembershipCount}, + {&s.selectRoomHeroesStmt, selectRoomHeroes}, + }.Prepare(db) } // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. @@ -447,3 +440,34 @@ func (s *currentRoomStateStatements) SelectSharedUsers( } return result, rows.Err() } + +func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomHeroesStmt) + rows, err := stmt.QueryContext(ctx, roomID, pq.StringArray(memberships), excludeUserID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroesStmt: rows.close() failed") + + var stateKey string + result := make([]string, 0, 5) + for rows.Next() { + if err = rows.Scan(&stateKey); err != nil { + return nil, err + } + result = append(result, stateKey) + } + return result, rows.Err() +} + +func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt) + err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return 0, err + } + return count, nil +} diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index b555e8456..ac44b235f 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -19,10 +19,8 @@ import ( "database/sql" "fmt" - "github.com/lib/pq" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" + " SELECT DISTINCT ON (room_id, user_id) room_id, user_id, membership FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC" + ") t WHERE t.membership = $3" -const selectHeroesSQL = "" + - "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership = ANY($3) LIMIT 5" - const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" @@ -81,7 +76,6 @@ WHERE ($3::text IS NULL OR t.membership = $3) type membershipsStatements struct { upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt - selectHeroesStmt *sql.Stmt selectMembershipForUserStmt *sql.Stmt selectMembersStmt *sql.Stmt } @@ -95,7 +89,6 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { return s, sqlutil.StatementList{ {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, - {&s.selectHeroesStmt, selectHeroesSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembersStmt, selectMembersSQL}, }.Prepare(db) @@ -129,26 +122,6 @@ func (s *membershipsStatements) SelectMembershipCount( return } -func (s *membershipsStatements) SelectHeroes( - ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string, -) (heroes []string, err error) { - stmt := sqlutil.TxStmt(txn, s.selectHeroesStmt) - var rows *sql.Rows - rows, err = stmt.QueryContext(ctx, roomID, userID, pq.StringArray(memberships)) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed") - var hero string - for rows.Next() { - if err = rows.Scan(&hero); err != nil { - return - } - heroes = append(heroes, hero) - } - return heroes, rows.Err() -} - // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // string as the membership. diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 77afa0290..c6933486c 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/matrix-org/gomatrixserverlib" + "github.com/tidwall/gjson" "github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/syncapi/types" @@ -92,8 +93,61 @@ func (d *DatabaseTransaction) MembershipCount(ctx context.Context, roomID, membe return d.Memberships.SelectMembershipCount(ctx, d.txn, roomID, membership, pos) } -func (d *DatabaseTransaction) GetRoomHeroes(ctx context.Context, roomID, userID string, memberships []string) ([]string, error) { - return d.Memberships.SelectHeroes(ctx, d.txn, roomID, userID, memberships) +func (d *DatabaseTransaction) GetRoomSummary(ctx context.Context, roomID, userID string) (*types.Summary, error) { + summary := &types.Summary{Heroes: []string{}} + + joinCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Join) + if err != nil { + return summary, err + } + inviteCount, err := d.CurrentRoomState.SelectMembershipCount(ctx, d.txn, roomID, gomatrixserverlib.Invite) + if err != nil { + return summary, err + } + summary.InvitedMemberCount = &inviteCount + summary.JoinedMemberCount = &joinCount + + // Get the room name and canonical alias, if any + filter := gomatrixserverlib.DefaultStateFilter() + filterTypes := []string{gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias} + filterRooms := []string{roomID} + + filter.Types = &filterTypes + filter.Rooms = &filterRooms + evs, err := d.CurrentRoomState.SelectCurrentState(ctx, d.txn, roomID, &filter, nil) + if err != nil { + return summary, err + } + + for _, ev := range evs { + switch ev.Type() { + case gomatrixserverlib.MRoomName: + if gjson.GetBytes(ev.Content(), "name").Str != "" { + return summary, nil + } + case gomatrixserverlib.MRoomCanonicalAlias: + if gjson.GetBytes(ev.Content(), "alias").Str != "" { + return summary, nil + } + } + } + + // If there's no room name or canonical alias, get the room heroes, excluding the user + heroes, err := d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Join, gomatrixserverlib.Invite}) + if err != nil { + return summary, err + } + + // "When no joined or invited members are available, this should consist of the banned and left users" + if len(heroes) == 0 { + heroes, err = d.CurrentRoomState.SelectRoomHeroes(ctx, d.txn, roomID, userID, []string{gomatrixserverlib.Leave, gomatrixserverlib.Ban}) + if err != nil { + return summary, err + } + } + summary.Heroes = heroes + + return summary, nil } func (d *DatabaseTransaction) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) { diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index 7a381f68b..6bc1b267a 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "errors" "fmt" "strings" @@ -95,6 +96,15 @@ const selectSharedUsersSQL = "" + " SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" + ") AND type = 'm.room.member' AND state_key IN ($2) AND membership IN ('join', 'invite');" +const selectMembershipCount = `SELECT count(*) FROM syncapi_current_room_state WHERE type = 'm.room.member' AND room_id = $1 AND membership = $2` + +const selectRoomHeroes = ` +SELECT state_key FROM syncapi_current_room_state +WHERE type = 'm.room.member' AND room_id = $1 AND state_key != $2 AND membership IN ($3) +ORDER BY added_at, state_key +LIMIT 5 +` + type currentRoomStateStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -107,6 +117,8 @@ type currentRoomStateStatements struct { //selectJoinedUsersInRoomStmt *sql.Stmt - prepared at runtime due to variadic selectStateEventStmt *sql.Stmt //selectSharedUsersSQL *sql.Stmt - prepared at runtime due to variadic + selectMembershipCountStmt *sql.Stmt + //selectRoomHeroes *sql.Stmt - prepared at runtime due to variadic } func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) { @@ -129,31 +141,16 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t return nil, err } - if s.upsertRoomStateStmt, err = db.Prepare(upsertRoomStateSQL); err != nil { - return nil, err - } - if s.deleteRoomStateByEventIDStmt, err = db.Prepare(deleteRoomStateByEventIDSQL); err != nil { - return nil, err - } - if s.deleteRoomStateForRoomStmt, err = db.Prepare(deleteRoomStateForRoomSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { - return nil, err - } - if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { - return nil, err - } - if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { - return nil, err - } - //if s.selectJoinedUsersInRoomStmt, err = db.Prepare(selectJoinedUsersInRoomSQL); err != nil { - // return nil, err - //} - if s.selectStateEventStmt, err = db.Prepare(selectStateEventSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.upsertRoomStateStmt, upsertRoomStateSQL}, + {&s.deleteRoomStateByEventIDStmt, deleteRoomStateByEventIDSQL}, + {&s.deleteRoomStateForRoomStmt, deleteRoomStateForRoomSQL}, + {&s.selectRoomIDsWithMembershipStmt, selectRoomIDsWithMembershipSQL}, + {&s.selectRoomIDsWithAnyMembershipStmt, selectRoomIDsWithAnyMembershipSQL}, + {&s.selectJoinedUsersStmt, selectJoinedUsersSQL}, + {&s.selectStateEventStmt, selectStateEventSQL}, + {&s.selectMembershipCountStmt, selectMembershipCount}, + }.Prepare(db) } // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. @@ -485,3 +482,53 @@ func (s *currentRoomStateStatements) SelectSharedUsers( return result, err } + +func (s *currentRoomStateStatements) SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) { + params := make([]interface{}, len(memberships)+2) + params[0] = roomID + params[1] = excludeUserID + for k, v := range memberships { + params[k+2] = v + } + + query := strings.Replace(selectRoomHeroes, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) + var stmt *sql.Stmt + var err error + if txn != nil { + stmt, err = txn.Prepare(query) + } else { + stmt, err = s.db.Prepare(query) + } + if err != nil { + return []string{}, err + } + defer internal.CloseAndLogIfError(ctx, stmt, "selectRoomHeroes: stmt.close() failed") + + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomHeroes: rows.close() failed") + + var stateKey string + result := make([]string, 0, 5) + for rows.Next() { + if err = rows.Scan(&stateKey); err != nil { + return nil, err + } + result = append(result, stateKey) + } + return result, rows.Err() +} + +func (s *currentRoomStateStatements) SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (count int, err error) { + stmt := sqlutil.TxStmt(txn, s.selectMembershipCountStmt) + err = stmt.QueryRowContext(ctx, roomID, membership).Scan(&count) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, nil + } + return 0, err + } + return count, nil +} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 7e54fac17..905a1e1a8 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -18,11 +18,9 @@ import ( "context" "database/sql" "fmt" - "strings" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -64,9 +62,6 @@ const selectMembershipCountSQL = "" + " SELECT * FROM syncapi_memberships WHERE room_id = $1 AND stream_pos <= $2 GROUP BY user_id HAVING(max(stream_pos))" + ") t WHERE t.membership = $3" -const selectHeroesSQL = "" + - "SELECT DISTINCT user_id FROM syncapi_memberships WHERE room_id = $1 AND user_id != $2 AND membership IN ($3) LIMIT 5" - const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" @@ -99,7 +94,6 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembersStmt, selectMembersSQL}, - // {&s.selectHeroesStmt, selectHeroesSQL}, - prepared at runtime due to variadic }.Prepare(db) } @@ -131,39 +125,6 @@ func (s *membershipsStatements) SelectMembershipCount( return } -func (s *membershipsStatements) SelectHeroes( - ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string, -) (heroes []string, err error) { - stmtSQL := strings.Replace(selectHeroesSQL, "($3)", sqlutil.QueryVariadicOffset(len(memberships), 2), 1) - stmt, err := s.db.PrepareContext(ctx, stmtSQL) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, stmt, "SelectHeroes: stmt.close() failed") - params := []interface{}{ - roomID, userID, - } - for _, membership := range memberships { - params = append(params, membership) - } - - stmt = sqlutil.TxStmt(txn, stmt) - var rows *sql.Rows - rows, err = stmt.QueryContext(ctx, params...) - if err != nil { - return - } - defer internal.CloseAndLogIfError(ctx, rows, "SelectHeroes: rows.close() failed") - var hero string - for rows.Next() { - if err = rows.Scan(&hero); err != nil { - return - } - heroes = append(heroes, hero) - } - return heroes, rows.Err() -} - // SelectMembershipForUser returns the membership of the user before and including the given position. If no membership can be found // returns "leave", the topological position and no error. If an error occurs, other than sql.ErrNoRows, returns that and an empty // string as the membership. diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 5ff185a32..166ddd233 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" ) var ctx = context.Background() @@ -664,3 +665,181 @@ func topologyTokenBefore(t *testing.T, db storage.Database, eventID string) *typ return &tok } */ + +func pointer[t any](s t) *t { + return &s +} + +func TestRoomSummary(t *testing.T) { + + alice := test.NewUser(t) + bob := test.NewUser(t) + charlie := test.NewUser(t) + + // Create some dummy users + moreUsers := []*test.User{} + moreUserIDs := []string{} + for i := 0; i < 10; i++ { + u := test.NewUser(t) + moreUsers = append(moreUsers, u) + moreUserIDs = append(moreUserIDs, u.ID) + } + + testCases := []struct { + name string + wantSummary *types.Summary + additionalEvents func(t *testing.T, room *test.Room) + }{ + { + name: "after initial creation", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{}}, + }, + { + name: "invited user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "invited user, but declined", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "joined user after invitation", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(3), InvitedMemberCount: pointer(0), Heroes: []string{charlie.ID, bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, charlie, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined/invited user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID, bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "multiple joined/invited/left user", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(1), Heroes: []string{charlie.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(charlie.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "leaving user after joining", + wantSummary: &types.Summary{JoinedMemberCount: pointer(1), InvitedMemberCount: pointer(0), Heroes: []string{bob.ID}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + }, + { + name: "many users", // heroes ordered by stream id + wantSummary: &types.Summary{JoinedMemberCount: pointer(len(moreUserIDs) + 1), InvitedMemberCount: pointer(0), Heroes: moreUserIDs[:5]}, + additionalEvents: func(t *testing.T, room *test.Room) { + for _, x := range moreUsers { + room.CreateAndInsert(t, x, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(x.ID)) + } + }, + }, + { + name: "canonical alias set", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomCanonicalAlias, map[string]interface{}{ + "alias": "myalias", + }, test.WithStateKey("")) + }, + }, + { + name: "room name set", + wantSummary: &types.Summary{JoinedMemberCount: pointer(2), InvitedMemberCount: pointer(0), Heroes: []string{}}, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomName, map[string]interface{}{ + "name": "my room name", + }, test.WithStateKey("")) + }, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close, closeBase := MustCreateDatabase(t, dbType) + defer close() + defer closeBase() + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + + r := test.NewRoom(t, alice) + + if tc.additionalEvents != nil { + tc.additionalEvents(t, r) + } + + // write the room before creating a transaction + MustWriteEvents(t, db, r.Events()) + + transaction, err := db.NewDatabaseTransaction(ctx) + assert.NoError(t, err) + defer transaction.Rollback() + + summary, err := transaction.GetRoomSummary(ctx, r.ID, alice.ID) + assert.NoError(t, err) + assert.Equal(t, tc.wantSummary, summary) + }) + } + }) +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index a0574b257..c02e4ecc5 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -115,6 +115,9 @@ type CurrentRoomState interface { SelectJoinedUsersInRoom(ctx context.Context, txn *sql.Tx, roomIDs []string) (map[string][]string, error) // SelectSharedUsers returns a subset of otherUserIDs that share a room with userID. SelectSharedUsers(ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string) ([]string, error) + + SelectRoomHeroes(ctx context.Context, txn *sql.Tx, roomID, excludeUserID string, memberships []string) ([]string, error) + SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string) (int, error) } // BackwardsExtremities keeps track of backwards extremities for a room. @@ -185,7 +188,6 @@ type Receipts interface { type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) - SelectHeroes(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (heroes []string, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) SelectMemberships( ctx context.Context, txn *sql.Tx, diff --git a/syncapi/storage/tables/memberships_test.go b/syncapi/storage/tables/memberships_test.go index 0cee7f5a5..df593ae78 100644 --- a/syncapi/storage/tables/memberships_test.go +++ b/syncapi/storage/tables/memberships_test.go @@ -3,8 +3,6 @@ package tables_test import ( "context" "database/sql" - "reflect" - "sort" "testing" "time" @@ -88,43 +86,9 @@ func TestMembershipsTable(t *testing.T) { testUpsert(t, ctx, table, userEvents[0], alice, room) testMembershipCount(t, ctx, table, room) - testHeroes(t, ctx, table, alice, room, users) }) } -func testHeroes(t *testing.T, ctx context.Context, table tables.Memberships, user *test.User, room *test.Room, users []string) { - - // Re-slice and sort the expected users - users = users[1:] - sort.Strings(users) - type testCase struct { - name string - memberships []string - wantHeroes []string - } - - testCases := []testCase{ - {name: "no memberships queried", memberships: []string{}}, - {name: "joined memberships queried should be limited", memberships: []string{gomatrixserverlib.Join}, wantHeroes: users[:5]}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - got, err := table.SelectHeroes(ctx, nil, room.ID, user.ID, tc.memberships) - if err != nil { - t.Fatalf("unable to select heroes: %s", err) - } - if gotLen := len(got); gotLen != len(tc.wantHeroes) { - t.Fatalf("expected %d heroes, got %d", len(tc.wantHeroes), gotLen) - } - - if !reflect.DeepEqual(got, tc.wantHeroes) { - t.Fatalf("expected heroes to be %+v, got %+v", tc.wantHeroes, got) - } - }) - } -} - func testMembershipCount(t *testing.T, ctx context.Context, table tables.Memberships, room *test.Room) { t.Run("membership counts are correct", func(t *testing.T) { // After 10 events, we should have 6 users (5 create related [incl. one member event], 5 member events = 6 users) diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index dd7845574..4664276cf 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "fmt" - "sort" "time" "github.com/matrix-org/dendrite/internal/caching" @@ -14,11 +13,9 @@ import ( "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/syncapi/notifier" "github.com/matrix-org/gomatrixserverlib" "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - - "github.com/matrix-org/dendrite/syncapi/notifier" ) // The max number of per-room goroutines to have running. @@ -339,7 +336,10 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse( case gomatrixserverlib.Join: jr := types.NewJoinResponse() if hasMembershipChange { - p.addRoomSummary(ctx, snapshot, jr, delta.RoomID, device.UserID, latestPosition) + jr.Summary, err = snapshot.GetRoomSummary(ctx, delta.RoomID, device.UserID) + if err != nil { + logrus.WithError(err).Warn("failed to get room summary") + } } jr.Timeline.PrevBatch = &prevBatch jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatSync) @@ -411,45 +411,6 @@ func applyHistoryVisibilityFilter( return events, nil } -func (p *PDUStreamProvider) addRoomSummary(ctx context.Context, snapshot storage.DatabaseTransaction, jr *types.JoinResponse, roomID, userID string, latestPosition types.StreamPosition) { - // Work out how many members are in the room. - joinedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Join, latestPosition) - invitedCount, _ := snapshot.MembershipCount(ctx, roomID, gomatrixserverlib.Invite, latestPosition) - - jr.Summary.JoinedMemberCount = &joinedCount - jr.Summary.InvitedMemberCount = &invitedCount - - fetchStates := []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomName}, - {EventType: gomatrixserverlib.MRoomCanonicalAlias}, - } - // Check if the room has a name or a canonical alias - latestState := &roomserverAPI.QueryLatestEventsAndStateResponse{} - err := p.rsAPI.QueryLatestEventsAndState(ctx, &roomserverAPI.QueryLatestEventsAndStateRequest{StateToFetch: fetchStates, RoomID: roomID}, latestState) - if err != nil { - return - } - // Check if the room has a name or canonical alias, if so, return. - for _, ev := range latestState.StateEvents { - switch ev.Type() { - case gomatrixserverlib.MRoomName: - if gjson.GetBytes(ev.Content(), "name").Str != "" { - return - } - case gomatrixserverlib.MRoomCanonicalAlias: - if gjson.GetBytes(ev.Content(), "alias").Str != "" { - return - } - } - } - heroes, err := snapshot.GetRoomHeroes(ctx, roomID, userID, []string{"join", "invite"}) - if err != nil { - return - } - sort.Strings(heroes) - jr.Summary.Heroes = heroes -} - func (p *PDUStreamProvider) getJoinResponseForCompleteSync( ctx context.Context, snapshot storage.DatabaseTransaction, @@ -493,7 +454,10 @@ func (p *PDUStreamProvider) getJoinResponseForCompleteSync( return } - p.addRoomSummary(ctx, snapshot, jr, roomID, device.UserID, r.From) + jr.Summary, err = snapshot.GetRoomSummary(ctx, roomID, device.UserID) + if err != nil { + logrus.WithError(err).Warn("failed to get room summary") + } // We don't include a device here as we don't need to send down // transaction IDs for complete syncs, but we do it anyway because Sytest demands it for: From 477a44faa67eabba0f5d7f632b12fd6bb2d7ec5b Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Thu, 12 Jan 2023 09:22:53 -0700 Subject: [PATCH 53/67] Always initialize statistics server map --- federationapi/statistics/statistics.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 2ba99112c..0a44375c6 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -35,18 +35,13 @@ func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistic DB: db, FailuresUntilBlacklist: failuresUntilBlacklist, backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), } } // ForServer returns server statistics for the given server name. If it // does not exist, it will create empty statistics and return those. func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics { - // If the map hasn't been initialised yet then do that. - if s.servers == nil { - s.mutex.Lock() - s.servers = make(map[gomatrixserverlib.ServerName]*ServerStatistics) - s.mutex.Unlock() - } // Look up if we have statistics for this server already. s.mutex.RLock() server, found := s.servers[serverName] From eeeb3017d662ad6777c1398b325aa98bc36bae94 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 16 Jan 2023 11:52:30 +0000 Subject: [PATCH 54/67] Switch the default config option values for `recaptcha_sitekey_class` and `recaptcha_form_field` (#2939) Attempting to use the [web auth fallback mechanism](https://spec.matrix.org/v1.5/client-server-api/#fallback) for Google ReCAPTCHA with the default setting for `client_api.recaptcha_sitekey_class` of "g-recaptcha-response" results in no captcha being rendered: ![image](https://user-images.githubusercontent.com/1342360/212482321-14980045-6e20-4d59-adaa-59a01ad88367.png) I cross-checked the captcha code between [dendrite.matrix.org's fallback page](https://dendrite.matrix.org/_matrix/client/r0/auth/m.login.recaptcha/fallback/web?session=asdhjaksd) and [matrix-client.matrix.org's one](https://matrix-client.matrix.org/_matrix/client/r0/auth/m.login.recaptcha/fallback/web?session=asdhjaksd) (which both use the same captcha public key) and noticed a discrepancy in the `class` attribute of the div that renders the captcha. [ReCAPTCHA's docs state](https://developers.google.com/recaptcha/docs/v3#automatically_bind_the_challenge_to_a_button) to use "g-recaptcha" as the class for the submit button. I noticed this when user `@parappanon:parappa.party` reported that they were also seeing no captcha being rendered on their Dendrite instance. Changing `client_api.recaptcha_sitekey_class` to "g-recaptcha" caused their captcha to render properly as well. There may have been a change in the class name from ReCAPTCHA v2 to v3? The [docs for v2](https://developers.google.com/recaptcha/docs/display#auto_render) also request one uses "g-recaptcha" though. Thus I propose changing the default setting to unbreak people's recaptcha auth fallback pages. Should fix dendrite.matrix.org as well. --- setup/config/config_clientapi.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 11628b1b0..1deba6bb5 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -85,10 +85,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js" } if c.RecaptchaFormField == "" { - c.RecaptchaFormField = "g-recaptcha" + c.RecaptchaFormField = "g-recaptcha-response" } if c.RecaptchaSitekeyClass == "" { - c.RecaptchaSitekeyClass = "g-recaptcha-response" + c.RecaptchaSitekeyClass = "g-recaptcha" } checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) From 8582c7520abbfca680da9ba16e40a9a92b9fd21c Mon Sep 17 00:00:00 2001 From: Umar Getagazov Date: Tue, 17 Jan 2023 11:07:42 +0300 Subject: [PATCH 55/67] Omit state field from `/messages` response if empty (#2940) The field type is `[ClientEvent]` in the [spec](https://spec.matrix.org/v1.5/client-server-api/#get_matrixclientv3roomsroomidmessages), but right now `null` can also be returned. Omit the field completely if it's empty. Some clients (rightfully) assume it's either not present at all or it's of the right type (see https://github.com/matrix-org/matrix-react-sdk/pull/9913). ### Pull Request Checklist * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * The PR is a simple struct tag fix * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Umar Getagazov ` Signed-off-by: Umar Getagazov --- syncapi/routing/messages.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 0d740ebfc..cafba17c9 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -57,7 +57,7 @@ type messagesResp struct { StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token End string `json:"end,omitempty"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` - State []gomatrixserverlib.ClientEvent `json:"state"` + State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` } // OnIncomingMessagesRequest implements the /messages endpoint from the From 0d0280cf5ff71ec975b17d0f6dadcae7e46574b5 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 17 Jan 2023 10:08:23 +0100 Subject: [PATCH 56/67] `/sync` performance optimizations (#2927) Since #2849 there is no limit for the current state we fetch to calculate history visibility. In large rooms this can cause us to fetch thousands of membership events we don't really care about. This now only gets the state event types and senders in our timeline, which should significantly reduce the amount of events we fetch from the database. Also removes `MaxTopologicalPosition`, as it is an unnecessary DB call, given we use the result in `topological_position < $1` calls. --- federationapi/consumers/roomserver.go | 2 +- syncapi/routing/memberships.go | 19 +- syncapi/routing/messages.go | 31 +-- syncapi/storage/interface.go | 2 - .../output_room_events_topology_table.go | 19 -- syncapi/storage/shared/storage_sync.go | 17 +- .../output_room_events_topology_table.go | 16 -- syncapi/storage/storage_test.go | 88 ++++++- syncapi/storage/tables/interface.go | 2 - syncapi/streams/stream_pdu.go | 35 ++- syncapi/syncapi_test.go | 246 ++++++++++++++++++ 11 files changed, 372 insertions(+), 105 deletions(-) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 0c1080afa..52b5744a6 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -195,7 +195,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } // If we added new hosts, inform them about our known presence events for this room - if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { + if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { membership, _ := ore.Event.Membership() if membership == gomatrixserverlib.Join { s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 3fcc3235c..9ffdf513f 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -16,16 +16,16 @@ package routing import ( "encoding/json" + "math" "net/http" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) type getMembershipResponse struct { @@ -87,19 +87,18 @@ func GetMemberships( if err != nil { return jsonerror.InternalServerError() } + defer db.Rollback() // nolint: errcheck atToken, err := types.NewTopologyTokenFromString(at) if err != nil { + atToken = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} if queryRes.HasBeenInRoom && !queryRes.IsInRoom { // If you have left the room then this will be the members of the room when you left. atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID) - } else { - // If you are joined to the room then this will be the current members of the room. - atToken, err = db.MaxTopologicalPosition(req.Context(), roomID) - } - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") - return jsonerror.InternalServerError() + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") + return jsonerror.InternalServerError() + } } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index cafba17c9..4a01ec357 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -17,6 +17,7 @@ package routing import ( "context" "fmt" + "math" "net/http" "sort" "time" @@ -177,10 +178,11 @@ func OnIncomingMessagesRequest( // If "to" isn't provided, it defaults to either the earliest stream // position (if we're going backward) or to the latest one (if we're // going forward). - to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed") - return jsonerror.InternalServerError() + to = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} + if backwardOrdering { + // go 1 earlier than the first event so we correctly fetch the earliest event + // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. + to = types.TopologyToken{} } wasToProvided = false } @@ -577,24 +579,3 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] return events, nil } - -// setToDefault returns the default value for the "to" query parameter of a -// request to /messages if not provided. It defaults to either the earliest -// topological position (if we're going backward) or to the latest one (if we're -// going forward). -// Returns an error if there was an issue with retrieving the latest position -// from the database -func setToDefault( - ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool, - roomID string, -) (to types.TopologyToken, err error) { - if backwardOrdering { - // go 1 earlier than the first event so we correctly fetch the earliest event - // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. - to = types.TopologyToken{} - } else { - to, err = snapshot.MaxTopologicalPosition(ctx, roomID) - } - - return -} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 4e22f8a6f..a4ba82327 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -84,8 +84,6 @@ type DatabaseTransaction interface { EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) - // MaxTopologicalPosition returns the highest topological position for a given room. - MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 6fab900eb..d0e99f267 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -65,14 +65,6 @@ const selectPositionInTopologySQL = "" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" - // Select the max topological position for the room, then sort by stream position and take the highest, - // returning both topological and stream positions. -const selectMaxPositionInTopologySQL = "" + - "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + - " WHERE topological_position=(" + - "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + - ") ORDER BY stream_position DESC LIMIT 1" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" @@ -84,7 +76,6 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt } @@ -107,9 +98,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { return nil, err } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { return nil, err } @@ -189,10 +177,3 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } - -func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( - ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, spos types.StreamPosition, err error) { - err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos) - return -} diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index c6933486c..7b07cac5e 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "math" "github.com/matrix-org/gomatrixserverlib" "github.com/tidwall/gjson" @@ -269,16 +270,6 @@ func (d *DatabaseTransaction) BackwardExtremitiesForRoom( return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID) } -func (d *DatabaseTransaction) MaxTopologicalPosition( - ctx context.Context, roomID string, -) (types.TopologyToken, error) { - depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) - if err != nil { - return types.TopologyToken{}, err - } - return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil -} - func (d *DatabaseTransaction) EventPositionInTopology( ctx context.Context, eventID string, ) (types.TopologyToken, error) { @@ -297,11 +288,7 @@ func (d *DatabaseTransaction) StreamToTopologicalPosition( case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward return types.TopologyToken{PDUPosition: streamPos}, nil case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward - topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) - if err != nil { - return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) - } - return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil + return types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, nil case err != nil: // some other error happened return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err) default: diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 81b264988..879456441 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -61,10 +61,6 @@ const selectPositionInTopologySQL = "" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" -const selectMaxPositionInTopologySQL = "" + - "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 ORDER BY stream_position DESC" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" @@ -77,7 +73,6 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt } @@ -102,9 +97,6 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { return nil, err } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { return nil, err } @@ -182,11 +174,3 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } - -func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( - ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, spos types.StreamPosition, err error) { - stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) - err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) - return -} diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 166ddd233..e65367d8b 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "math" "reflect" "testing" @@ -199,10 +200,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { _ = MustWriteEvents(t, db, events) WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { - from, err := snapshot.MaxTopologicalPosition(ctx, r.ID) - if err != nil { - t.Fatalf("failed to get MaxTopologicalPosition: %s", err) - } + from := types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} t.Logf("max topo pos = %+v", from) // head towards the beginning of time to := types.TopologyToken{} @@ -219,6 +217,88 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { }) } +func TestStreamToTopologicalPosition(t *testing.T) { + alice := test.NewUser(t) + r := test.NewRoom(t, alice) + + testCases := []struct { + name string + roomID string + streamPos types.StreamPosition + backwardOrdering bool + wantToken types.TopologyToken + }{ + { + name: "forward ordering found streamPos returns found position", + roomID: r.ID, + streamPos: 1, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1}, + }, + { + name: "forward ordering not found streamPos returns max position", + roomID: r.ID, + streamPos: 100, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, + }, + { + name: "backward ordering found streamPos returns found position", + roomID: r.ID, + streamPos: 1, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1}, + }, + { + name: "backward ordering not found streamPos returns maxDepth with param pduPosition", + roomID: r.ID, + streamPos: 100, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 5, PDUPosition: 100}, + }, + { + name: "backward non-existent room returns zero token", + roomID: "!doesnotexist:localhost", + streamPos: 1, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 0, PDUPosition: 1}, + }, + { + name: "forward non-existent room returns max token", + roomID: "!doesnotexist:localhost", + streamPos: 1, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close, closeBase := MustCreateDatabase(t, dbType) + defer close() + defer closeBase() + + txn, err := db.NewDatabaseTransaction(ctx) + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + MustWriteEvents(t, db, r.Events()) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + token, err := txn.StreamToTopologicalPosition(ctx, tc.roomID, tc.streamPos, tc.backwardOrdering) + if err != nil { + t.Fatal(err) + } + if tc.wantToken != token { + t.Fatalf("expected token %q, got %q", tc.wantToken, token) + } + }) + } + + }) +} + /* // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index c02e4ecc5..8366a67dc 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -91,8 +91,6 @@ type Topology interface { SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error) // SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to. SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) - // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. - SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 4664276cf..44013e37c 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -384,19 +384,32 @@ func applyHistoryVisibilityFilter( roomID, userID string, recentEvents []*gomatrixserverlib.HeaderedEvent, ) ([]*gomatrixserverlib.HeaderedEvent, error) { - // We need to make sure we always include the latest states events, if they are in the timeline. - // We grep at least limit * 2 events, to ensure we really get the needed events. - filter := gomatrixserverlib.DefaultStateFilter() - stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) - if err != nil { - // Not a fatal error, we can continue without the stateEvents, - // they are only needed if there are state events in the timeline. - logrus.WithError(err).Warnf("Failed to get current room state for history visibility") + // We need to make sure we always include the latest state events, if they are in the timeline. + alwaysIncludeIDs := make(map[string]struct{}) + var stateTypes []string + var senders []string + for _, ev := range recentEvents { + if ev.StateKey() != nil { + stateTypes = append(stateTypes, ev.Type()) + senders = append(senders, ev.Sender()) + } } - alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents)) - for _, ev := range stateEvents { - alwaysIncludeIDs[ev.EventID()] = struct{}{} + + // Only get the state again if there are state events in the timeline + if len(stateTypes) > 0 { + filter := gomatrixserverlib.DefaultStateFilter() + filter.Types = &stateTypes + filter.Senders = &senders + stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) + if err != nil { + return nil, fmt.Errorf("failed to get current room state for history visibility calculation: %w", err) + } + + for _, ev := range stateEvents { + alwaysIncludeIDs[ev.EventID()] = struct{}{} + } } + startTime := time.Now() events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") if err != nil { diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 483274481..666a872f8 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -521,6 +521,252 @@ func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatr } } +func TestGetMembership(t *testing.T) { + alice := test.NewUser(t) + + aliceDev := userapi.Device{ + ID: "ALICEID", + UserID: alice.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + bob := test.NewUser(t) + bobDev := userapi.Device{ + ID: "BOBID", + UserID: bob.ID, + AccessToken: "notjoinedtoanyrooms", + } + + testCases := []struct { + name string + roomID string + additionalEvents func(t *testing.T, room *test.Room) + request func(t *testing.T, room *test.Room) *http.Request + wantOK bool + wantMemberCount int + useSleep bool // :/ + }{ + { + name: "/members - Alice joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "/members - Bob never joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Bob never joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Alice joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: true, + }, + { + name: "Alice leaves before Bob joins, should not be able to see Bob", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "Alice leaves after Bob joins, should be able to see Bob", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 2, + }, + { + name: "/joined_members - Alice leaves, shouldn't be able to see members ", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + useSleep: true, + wantOK: false, + }, + { + name: "'at' specified, returns memberships before Bob joins", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "at": "t2_5", + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "'membership=leave' specified, returns no memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "membership": "leave", + })) + }, + wantOK: true, + wantMemberCount: 0, + }, + { + name: "'not_membership=join' specified, returns no memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "not_membership": "join", + })) + }, + wantOK: true, + wantMemberCount: 0, + }, + { + name: "'not_membership=leave' & 'membership=join' specified, returns correct memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "not_membership": "leave", + "membership": "join", + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "non-existent room ID", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", "!notavalidroom:test"), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: false, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + // Use an actual roomserver for this + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{}) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + room := test.NewRoom(t, alice) + t.Cleanup(func() { + t.Logf("running cleanup for %s", tc.name) + }) + // inject additional events + if tc.additionalEvents != nil { + tc.additionalEvents(t, room) + } + if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // wait for the events to come down sync + if tc.useSleep { + time.Sleep(time.Millisecond * 100) + } else { + syncUntil(t, base, aliceDev.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) + return gjson.Get(syncBody, path).Exists() + }) + } + + w := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, tc.request(t, room)) + if w.Code != 200 && tc.wantOK { + t.Logf("%s", w.Body.String()) + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + t.Logf("[%s] Resp: %s", tc.name, w.Body.String()) + + // check we got the expected events + if tc.wantOK { + memberCount := len(gjson.GetBytes(w.Body.Bytes(), "chunk").Array()) + if memberCount != tc.wantMemberCount { + t.Fatalf("expected %d members, got %d", tc.wantMemberCount, memberCount) + } + } + }) + } + }) +} + func TestSendToDevice(t *testing.T) { test.WithAllDatabases(t, testSendToDevice) } From b55a7c238fb4b4db9ff4da0a25f0f83316d20f5e Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 17 Jan 2023 19:04:02 +0100 Subject: [PATCH 57/67] Version 0.10.9 (#2942) --- CHANGES.md | 20 ++++++++++++++++++++ internal/version.go | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index f5a82cfe2..fa8230659 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,25 @@ # Changelog +## Dendrite 0.10.9 (2023-01-17) + +### Features + +* Stale device lists are now cleaned up on startup, removing entries for users the server doesn't share a room with anymore +* Dendrite now has its own Helm chart +* Guest access is now handled correctly (disallow joins, kick guests on revocation of guest access, as well as over federation) + +### Fixes + +* Push rules have seen several tweaks and fixes, which should, for example, fix notifications for `m.read_receipts` +* Outgoing presence will now correctly be sent to newly joined hosts +* Fixes the `/_dendrite/admin/resetPassword/{userID}` admin endpoint to use the correct variable +* Federated backfilling for medium/large rooms has been fixed +* `/login` causing wrong device list updates has been resolved +* `/sync` should now return the correct room summary heroes +* The default config options for `recaptcha_sitekey_class` and `recaptcha_form_field` are now set correctly +* `/messages` now omits empty `state` to be more spec compliant (contributed by [handlerug](https://github.com/handlerug)) +* `/sync` has been optimised to only query state events for history visibility if they are really needed + ## Dendrite 0.10.8 (2022-11-29) ### Features diff --git a/internal/version.go b/internal/version.go index 685237b9e..ff31dd784 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 10 - VersionPatch = 8 + VersionPatch = 9 VersionTag = "" // example: "rc1" ) From 67f5c5bc1e837bbdee14d7d3388984ed8960528a Mon Sep 17 00:00:00 2001 From: genofire Date: Wed, 18 Jan 2023 08:45:34 +0100 Subject: [PATCH 58/67] =?UTF-8?q?fix(helm):=20extract=20image=20tag=20to?= =?UTF-8?q?=20value=20(and=20use=20as=20default=20from=20Chart.=E2=80=A6?= =?UTF-8?q?=20(#2934)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit improve image tag handling on the default helm way. with usage of appVersion from: https://github.com/matrix-org/dendrite/blob/0995dc48224b90432e38fa92345cf5735bca6090/helm/dendrite/Chart.yaml#L4 maybe you like to review @S7evinK ? ### Pull Request Checklist * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Geno ` --- helm/dendrite/Chart.yaml | 4 ++-- helm/dendrite/templates/_helpers.tpl | 4 +++- helm/dendrite/templates/deployment.yaml | 4 ++-- helm/dendrite/templates/jobs.yaml | 3 ++- helm/dendrite/templates/service.yaml | 2 +- helm/dendrite/values.yaml | 6 ++++-- 6 files changed, 14 insertions(+), 9 deletions(-) diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml index 15d1e6d19..6e6641c8d 100644 --- a/helm/dendrite/Chart.yaml +++ b/helm/dendrite/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v2 name: dendrite -version: "0.10.8" -appVersion: "0.10.8" +version: "0.10.9" +appVersion: "0.10.9" description: Dendrite Matrix Homeserver type: application keywords: diff --git a/helm/dendrite/templates/_helpers.tpl b/helm/dendrite/templates/_helpers.tpl index 291f351bc..026706588 100644 --- a/helm/dendrite/templates/_helpers.tpl +++ b/helm/dendrite/templates/_helpers.tpl @@ -15,9 +15,11 @@ {{- define "image.name" -}} -image: {{ .name }} +{{- with .Values.image -}} +image: {{ .repository }}:{{ .tag | default (printf "v%s" $.Chart.AppVersion) }} imagePullPolicy: {{ .pullPolicy }} {{- end -}} +{{- end -}} {{/* Expand the name of the chart. diff --git a/helm/dendrite/templates/deployment.yaml b/helm/dendrite/templates/deployment.yaml index 629ffe528..b463c7d0b 100644 --- a/helm/dendrite/templates/deployment.yaml +++ b/helm/dendrite/templates/deployment.yaml @@ -45,8 +45,8 @@ spec: persistentVolumeClaim: claimName: {{ default (print ( include "dendrite.fullname" . ) "-search-pvc") $.Values.persistence.search.existingClaim | quote }} containers: - - name: {{ $.Chart.Name }} - {{- include "image.name" $.Values.image | nindent 8 }} + - name: {{ .Chart.Name }} + {{- include "image.name" . | nindent 8 }} args: - '--config' - '/etc/dendrite/dendrite.yaml' diff --git a/helm/dendrite/templates/jobs.yaml b/helm/dendrite/templates/jobs.yaml index 76915694d..c10f358b0 100644 --- a/helm/dendrite/templates/jobs.yaml +++ b/helm/dendrite/templates/jobs.yaml @@ -8,6 +8,7 @@ metadata: name: {{ $name }} labels: app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: Role @@ -80,7 +81,7 @@ spec: name: signing-key readOnly: true - name: generate-key - {{- include "image.name" $.Values.image | nindent 8 }} + {{- include "image.name" . | nindent 8 }} command: - sh - -c diff --git a/helm/dendrite/templates/service.yaml b/helm/dendrite/templates/service.yaml index 365a43f04..3b571df1f 100644 --- a/helm/dendrite/templates/service.yaml +++ b/helm/dendrite/templates/service.yaml @@ -13,5 +13,5 @@ spec: ports: - name: http protocol: TCP - port: 8008 + port: {{ .Values.service.port }} targetPort: 8008 \ No newline at end of file diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml index 2c6e80942..87027a886 100644 --- a/helm/dendrite/values.yaml +++ b/helm/dendrite/values.yaml @@ -1,8 +1,10 @@ image: # -- Docker repository/image to use - name: "ghcr.io/matrix-org/dendrite-monolith:v0.10.8" + repository: "ghcr.io/matrix-org/dendrite-monolith" # -- Kubernetes pullPolicy pullPolicy: IfNotPresent + # Overrides the image tag whose default is the chart appVersion. + tag: "" # signing key to use @@ -345,4 +347,4 @@ ingress: service: type: ClusterIP - port: 80 + port: 8008 From 738686ae686004c5efa9fe2096502cdc426c6dd8 Mon Sep 17 00:00:00 2001 From: Neil Date: Thu, 19 Jan 2023 20:02:32 +0000 Subject: [PATCH 59/67] Add `/_dendrite/admin/purgeRoom/{roomID}` (#2662) This adds a new admin endpoint `/_dendrite/admin/purgeRoom/{roomID}`. It completely erases all database entries for a given room ID. The roomserver will start by clearing all data for that room and then will generate an output event to notify downstream components (i.e. the sync API and federation API) to do the same. It does not currently clear media and it is currently not implemented for SQLite since it relies on SQL array operations right now. Co-authored-by: Neil Alexander Co-authored-by: Till Faelligen <2353100+S7evinK@users.noreply.github.com> --- clientapi/admin_test.go | 103 ++++++++++- clientapi/routing/admin.go | 32 ++++ clientapi/routing/routing.go | 6 + federationapi/consumers/roomserver.go | 15 +- federationapi/storage/interface.go | 2 + federationapi/storage/shared/storage.go | 15 ++ internal/sqlutil/sql.go | 21 +++ internal/sqlutil/sqlutil_test.go | 51 +++++- roomserver/api/api.go | 1 + roomserver/api/api_trace.go | 10 ++ roomserver/api/output.go | 8 + roomserver/api/perform.go | 8 + roomserver/internal/perform/perform_admin.go | 37 ++++ roomserver/inthttp/client.go | 12 ++ roomserver/inthttp/server.go | 5 + roomserver/roomserver_test.go | 165 ++++++++++++++++++ roomserver/storage/interface.go | 1 + .../storage/postgres/purge_statements.go | 133 ++++++++++++++ roomserver/storage/postgres/rooms_table.go | 14 ++ roomserver/storage/postgres/storage.go | 5 + roomserver/storage/shared/storage.go | 16 ++ .../storage/sqlite3/purge_statements.go | 153 ++++++++++++++++ roomserver/storage/sqlite3/rooms_table.go | 14 ++ .../storage/sqlite3/state_block_table.go | 3 +- .../storage/sqlite3/state_snapshot_table.go | 33 +++- roomserver/storage/sqlite3/storage.go | 6 + roomserver/storage/tables/interface.go | 7 + syncapi/consumers/roomserver.go | 21 +++ syncapi/storage/interface.go | 2 + .../postgres/backwards_extremities_table.go | 27 +-- syncapi/storage/postgres/invites_table.go | 31 ++-- syncapi/storage/postgres/memberships_table.go | 22 ++- .../postgres/notification_data_table.go | 12 ++ .../postgres/output_room_events_table.go | 12 ++ .../output_room_events_topology_table.go | 42 ++--- syncapi/storage/postgres/peeks_table.go | 39 +++-- syncapi/storage/postgres/receipt_table.go | 27 +-- syncapi/storage/shared/storage_consumer.go | 14 -- syncapi/storage/shared/storage_sync.go | 47 +++++ .../sqlite3/backwards_extremities_table.go | 27 +-- syncapi/storage/sqlite3/invites_table.go | 31 ++-- syncapi/storage/sqlite3/memberships_table.go | 12 ++ .../sqlite3/notification_data_table.go | 12 ++ .../sqlite3/output_room_events_table.go | 12 ++ .../output_room_events_topology_table.go | 42 ++--- syncapi/storage/sqlite3/peeks_table.go | 39 +++-- syncapi/storage/sqlite3/receipt_table.go | 27 +-- syncapi/storage/tables/interface.go | 9 + 48 files changed, 1213 insertions(+), 170 deletions(-) create mode 100644 roomserver/storage/postgres/purge_statements.go create mode 100644 roomserver/storage/sqlite3/purge_statements.go diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 0d973f350..c7ca019ff 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -7,9 +7,12 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/tidwall/gjson" @@ -41,7 +44,7 @@ func TestAdminResetPassword(t *testing.T) { userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) keyAPI.SetUserAPI(userAPI) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil) + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) // Create the users in the userapi and login accessTokens := map[*test.User]string{ @@ -112,6 +115,7 @@ func TestAdminResetPassword(t *testing.T) { } for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case t.Run(tc.name, func(t *testing.T) { req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID) if tc.requestOpt != nil { @@ -132,3 +136,100 @@ func TestAdminResetPassword(t *testing.T) { } }) } + +func TestPurgeRoom(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t) + room := test.NewRoom(t, aliceAdmin, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite Bob + room.CreateAndInsert(t, aliceAdmin, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + fedClient := base.CreateFederationClient() + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + + // this starts the JetStream consumers + syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI) + federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + rsAPI.SetFederationAPI(nil, nil) + keyAPI.SetUserAPI(userAPI) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) + + // Create the users in the userapi and login + accessTokens := map[*test.User]string{ + aliceAdmin: "", + } + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + } + + testCases := []struct { + name string + roomID string + wantOK bool + }{ + {name: "Can purge existing room", wantOK: true, roomID: room.ID}, + {name: "Can not purge non-existent room", wantOK: false, roomID: "!doesnotexist:localhost"}, + {name: "rejects invalid room ID", wantOK: false, roomID: "@doesnotexist:localhost"}, + } + + for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/purgeRoom/"+tc.roomID) + + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) + + rec := httptest.NewRecorder() + base.DendriteAdminMux.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + + }) +} diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index dbd913376..4b4dedfd1 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -1,6 +1,7 @@ package routing import ( + "context" "encoding/json" "fmt" "net/http" @@ -98,6 +99,37 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi } } +func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + roomID, ok := vars["roomID"] + if !ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Expecting room ID."), + } + } + res := &roomserverAPI.PerformAdminPurgeRoomResponse{} + if err := rsAPI.PerformAdminPurgeRoom( + context.Background(), + &roomserverAPI.PerformAdminPurgeRoomRequest{ + RoomID: roomID, + }, + res, + ); err != nil { + return util.ErrorResponse(err) + } + if err := res.Error; err != nil { + return err.JSONResponse() + } + return util.JSONResponse{ + Code: 200, + JSON: res, + } +} + func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { if req.Body == nil { return util.JSONResponse{ diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 09c2cd02f..93f6ea901 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -165,6 +165,12 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", + httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminPurgeRoom(req, cfg, device, rsAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminResetPassword(req, cfg, device, userAPI) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 52b5744a6..82a4db3f7 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/queue" @@ -90,8 +91,10 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms msg := msgs[0] // Guaranteed to exist if onMessage is called receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType)) - // Only handle events we care about - if receivedType != api.OutputTypeNewRoomEvent && receivedType != api.OutputTypeNewInboundPeek { + // Only handle events we care about, avoids unneeded unmarshalling + switch receivedType { + case api.OutputTypeNewRoomEvent, api.OutputTypeNewInboundPeek, api.OutputTypePurgeRoom: + default: return true } @@ -126,6 +129,14 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return false } + case api.OutputTypePurgeRoom: + log.WithField("room_id", output.PurgeRoom.RoomID).Warn("Purging room from federation API") + if err := s.db.PurgeRoom(ctx, output.PurgeRoom.RoomID); err != nil { + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from federation API") + } else { + logrus.WithField("room_id", output.PurgeRoom.RoomID).Warn("Room purged from federation API") + } + default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 276cd9a50..2b4d905fc 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -71,4 +71,6 @@ type Database interface { GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) // DeleteExpiredEDUs cleans up expired EDUs DeleteExpiredEDUs(ctx context.Context) error + + PurgeRoom(ctx context.Context, roomID string) error } diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 1e1ea9e17..6cda55725 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -259,3 +259,18 @@ func (d *Database) GetNotaryKeys( }) return sks, err } + +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge joined hosts: %w", err) + } + if err := d.FederationInboundPeeks.DeleteInboundPeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge inbound peeks: %w", err) + } + if err := d.FederationOutboundPeeks.DeleteOutboundPeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge outbound peeks: %w", err) + } + return nil + }) +} diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 19483b268..81c055edd 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -124,6 +124,11 @@ type QueryProvider interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } +// ExecProvider defines the interface for querys used by RunLimitedVariablesExec. +type ExecProvider interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + // SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement // SQLlite can handle. See https://www.sqlite.org/limits.html for more information. const SQLite3MaxVariables = 999 @@ -153,6 +158,22 @@ func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvide return nil } +// RunLimitedVariablesExec split up a query with more variables than the used database can handle in multiple queries. +func RunLimitedVariablesExec(ctx context.Context, query string, qp ExecProvider, variables []interface{}, limit uint) error { + var start int + for start < len(variables) { + n := minOfInts(len(variables)-start, int(limit)) + nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1) + _, err := qp.ExecContext(ctx, nextQuery, variables[start:start+n]...) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("ExecContext returned an error") + return err + } + start = start + n + } + return nil +} + // StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. type StatementList []struct { Statement **sql.Stmt diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go index 79469cddc..c40757893 100644 --- a/internal/sqlutil/sqlutil_test.go +++ b/internal/sqlutil/sqlutil_test.go @@ -3,10 +3,11 @@ package sqlutil import ( "context" "database/sql" + "errors" "reflect" "testing" - sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/DATA-DOG/go-sqlmock" ) func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { @@ -164,6 +165,54 @@ func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { } } +func TestRunLimitedVariablesExec(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + + // Query and expect two queries to be executed + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + variables := []interface{}{ + 1, 2, 3, 4, + } + + query := "DELETE FROM WHERE id IN ($1)" + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables, 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 3 parameters, still queries two times + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:3], 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 2 parameters, queries only once + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:2], 2); err != nil { + t.Fatal(err) + } + + // Test with invalid query (typo) should return an error + mock.ExpectExec(`DELTE FROM`). + WillReturnResult(sqlmock.NewResult(0, 0)). + WillReturnError(errors.New("typo in query")) + + if err = RunLimitedVariablesExec(context.Background(), "DELTE FROM", db, variables[:2], 2); err == nil { + t.Fatal("expected an error, but got none") + } +} + func assertNoError(t *testing.T, err error, msg string) { t.Helper() if err == nil { diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 420ef278a..a8228ae81 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -151,6 +151,7 @@ type ClientRoomserverAPI interface { PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error + PerformAdminPurgeRoom(ctx context.Context, req *PerformAdminPurgeRoomRequest, res *PerformAdminPurgeRoomResponse) error PerformAdminDownloadState(ctx context.Context, req *PerformAdminDownloadStateRequest, res *PerformAdminDownloadStateResponse) error PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index b23263d17..166b651a2 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -137,6 +137,16 @@ func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser( return err } +func (t *RoomserverInternalAPITrace) PerformAdminPurgeRoom( + ctx context.Context, + req *PerformAdminPurgeRoomRequest, + res *PerformAdminPurgeRoomResponse, +) error { + err := t.Impl.PerformAdminPurgeRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformAdminPurgeRoom req=%+v res=%+v", js(req), js(res)) + return err +} + func (t *RoomserverInternalAPITrace) PerformAdminDownloadState( ctx context.Context, req *PerformAdminDownloadStateRequest, diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 36d0625c7..0c0f52c45 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -55,6 +55,8 @@ const ( OutputTypeNewInboundPeek OutputType = "new_inbound_peek" // OutputTypeRetirePeek indicates that the kafka event is an OutputRetirePeek OutputTypeRetirePeek OutputType = "retire_peek" + // OutputTypePurgeRoom indicates the event is an OutputPurgeRoom + OutputTypePurgeRoom OutputType = "purge_room" ) // An OutputEvent is an entry in the roomserver output kafka log. @@ -78,6 +80,8 @@ type OutputEvent struct { NewInboundPeek *OutputNewInboundPeek `json:"new_inbound_peek,omitempty"` // The content of event with type OutputTypeRetirePeek RetirePeek *OutputRetirePeek `json:"retire_peek,omitempty"` + // The content of the event with type OutputPurgeRoom + PurgeRoom *OutputPurgeRoom `json:"purge_room,omitempty"` } // Type of the OutputNewRoomEvent. @@ -257,3 +261,7 @@ type OutputRetirePeek struct { UserID string DeviceID string } + +type OutputPurgeRoom struct { + RoomID string +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index e789b9568..83cb0460a 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -241,6 +241,14 @@ type PerformAdminEvacuateUserResponse struct { Error *PerformError } +type PerformAdminPurgeRoomRequest struct { + RoomID string `json:"room_id"` +} + +type PerformAdminPurgeRoomResponse struct { + Error *PerformError `json:"error,omitempty"` +} + type PerformAdminDownloadStateRequest struct { RoomID string `json:"room_id"` UserID string `json:"user_id"` diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index d42f4e45d..3256162b4 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) type Admin struct { @@ -242,6 +243,42 @@ func (r *Admin) PerformAdminEvacuateUser( return nil } +func (r *Admin) PerformAdminPurgeRoom( + ctx context.Context, + req *api.PerformAdminPurgeRoomRequest, + res *api.PerformAdminPurgeRoomResponse, +) error { + // Validate we actually got a room ID and nothing else + if _, _, err := gomatrixserverlib.SplitID('!', req.RoomID); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Malformed room ID: %s", err), + } + return nil + } + + logrus.WithField("room_id", req.RoomID).Warn("Purging room from roomserver") + if err := r.DB.PurgeRoom(ctx, req.RoomID); err != nil { + logrus.WithField("room_id", req.RoomID).WithError(err).Warn("Failed to purge room from roomserver") + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: err.Error(), + } + return nil + } + + logrus.WithField("room_id", req.RoomID).Warn("Room purged from roomserver") + + return r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{ + { + Type: api.OutputTypePurgeRoom, + PurgeRoom: &api.OutputPurgeRoom{ + RoomID: req.RoomID, + }, + }, + }) +} + func (r *Admin) PerformAdminDownloadState( ctx context.Context, req *api.PerformAdminDownloadStateRequest, diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 8a2e0a03c..556a137ba 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -40,6 +40,7 @@ const ( RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom" RoomserverPerformAdminEvacuateUserPath = "/roomserver/performAdminEvacuateUser" RoomserverPerformAdminDownloadStatePath = "/roomserver/performAdminDownloadState" + RoomserverPerformAdminPurgeRoomPath = "/roomserver/performAdminPurgeRoom" // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" @@ -285,6 +286,17 @@ func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser( ) } +func (h *httpRoomserverInternalAPI) PerformAdminPurgeRoom( + ctx context.Context, + request *api.PerformAdminPurgeRoomRequest, + response *api.PerformAdminPurgeRoomResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformAdminPurgeRoom", h.roomserverURL+RoomserverPerformAdminPurgeRoomPath, + h.httpClient, ctx, request, response, + ) +} + // QueryLatestEventsAndState implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 4d21909b7..f3a51b0b1 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -65,6 +65,11 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", enableMetrics, r.PerformAdminEvacuateUser), ) + internalAPIMux.Handle( + RoomserverPerformAdminPurgeRoomPath, + httputil.MakeInternalRPCAPI("RoomserverPerformAdminPurgeRoom", enableMetrics, r.PerformAdminPurgeRoom), + ) + internalAPIMux.Handle( RoomserverPerformAdminDownloadStatePath, httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", enableMetrics, r.PerformAdminDownloadState), diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 595ceb526..3ec2560d6 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -14,6 +14,10 @@ import ( userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/roomserver" @@ -223,3 +227,164 @@ func Test_QueryLeftUsers(t *testing.T) { }) } + +func TestPurgeRoom(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite Bob + inviteEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, db, close := mustCreateDatabase(t, dbType) + defer close() + + jsCtx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsCtx, &base.Cfg.Global.JetStream) + + fedClient := base.CreateFederationClient() + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + + // this starts the JetStream consumers + syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI) + federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + rsAPI.SetFederationAPI(nil, nil) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // some dummy entries to validate after purging + publishResp := &api.PerformPublishResponse{} + if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil { + t.Fatal(err) + } + if publishResp.Error != nil { + t.Fatal(publishResp.Error) + } + + isPublished, err := db.GetPublishedRoom(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if !isPublished { + t.Fatalf("room should be published before purging") + } + + aliasResp := &api.SetRoomAliasResponse{} + if err = rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{RoomID: room.ID, Alias: "myalias", UserID: alice.ID}, aliasResp); err != nil { + t.Fatal(err) + } + // check the alias is actually there + aliasesResp := &api.GetAliasesForRoomIDResponse{} + if err = rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: room.ID}, aliasesResp); err != nil { + t.Fatal(err) + } + wantAliases := 1 + if gotAliases := len(aliasesResp.Aliases); gotAliases != wantAliases { + t.Fatalf("expected %d aliases, got %d", wantAliases, gotAliases) + } + + // validate the room exists before purging + roomInfo, err := db.RoomInfo(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if roomInfo == nil { + t.Fatalf("room does not exist") + } + // remember the roomInfo before purging + existingRoomInfo := roomInfo + + // validate there is an invite for bob + nids, err := db.EventStateKeyNIDs(ctx, []string{bob.ID}) + if err != nil { + t.Fatal(err) + } + bobNID, ok := nids[bob.ID] + if !ok { + t.Fatalf("%s does not exist", bob.ID) + } + + _, inviteEventIDs, _, err := db.GetInvitesForUser(ctx, roomInfo.RoomNID, bobNID) + if err != nil { + t.Fatal(err) + } + wantInviteCount := 1 + if inviteCount := len(inviteEventIDs); inviteCount != wantInviteCount { + t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount) + } + if inviteEventIDs[0] != inviteEvent.EventID() { + t.Fatalf("expected invite event ID %s, got %s", inviteEvent.EventID(), inviteEventIDs[0]) + } + + // purge the room from the database + purgeResp := &api.PerformAdminPurgeRoomResponse{} + if err = rsAPI.PerformAdminPurgeRoom(ctx, &api.PerformAdminPurgeRoomRequest{RoomID: room.ID}, purgeResp); err != nil { + t.Fatal(err) + } + + // wait for all consumers to process the purge event + var sum = 1 + timeout := time.Second * 5 + deadline, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + for sum > 0 { + if deadline.Err() != nil { + t.Fatalf("test timed out after %s", timeout) + } + sum = 0 + consumerCh := jsCtx.Consumers(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)) + for x := range consumerCh { + sum += x.NumAckPending + } + time.Sleep(time.Millisecond) + } + + roomInfo, err = db.RoomInfo(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if roomInfo != nil { + t.Fatalf("room should not exist after purging: %+v", roomInfo) + } + + // validation below + + // There should be no invite left + _, inviteEventIDs, _, err = db.GetInvitesForUser(ctx, existingRoomInfo.RoomNID, bobNID) + if err != nil { + t.Fatal(err) + } + + if inviteCount := len(inviteEventIDs); inviteCount > 0 { + t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount) + } + + // aliases should be deleted + aliases, err := db.GetAliasesForRoomID(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if aliasCount := len(aliases); aliasCount > 0 { + t.Fatalf("expected there to be only %d invite events, got %d", 0, aliasCount) + } + + // published room should be deleted + isPublished, err = db.GetPublishedRoom(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if isPublished { + t.Fatalf("room should not be published after purging") + } + }) +} diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 92bc2e66f..e0b9c56b3 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -173,5 +173,6 @@ type Database interface { GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) + PurgeRoom(ctx context.Context, roomID string) error UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error } diff --git a/roomserver/storage/postgres/purge_statements.go b/roomserver/storage/postgres/purge_statements.go new file mode 100644 index 000000000..efba439bd --- /dev/null +++ b/roomserver/storage/postgres/purge_statements.go @@ -0,0 +1,133 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const purgeEventJSONSQL = "" + + "DELETE FROM roomserver_event_json WHERE event_nid = ANY(" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeEventsSQL = "" + + "DELETE FROM roomserver_events WHERE room_nid = $1" + +const purgeInvitesSQL = "" + + "DELETE FROM roomserver_invites WHERE room_nid = $1" + +const purgeMembershipsSQL = "" + + "DELETE FROM roomserver_membership WHERE room_nid = $1" + +const purgePreviousEventsSQL = "" + + "DELETE FROM roomserver_previous_events WHERE event_nids && ANY(" + + " SELECT ARRAY_AGG(event_nid) FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgePublishedSQL = "" + + "DELETE FROM roomserver_published WHERE room_id = $1" + +const purgeRedactionsSQL = "" + + "DELETE FROM roomserver_redactions WHERE redaction_event_id = ANY(" + + " SELECT event_id FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeRoomAliasesSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE room_id = $1" + +const purgeRoomSQL = "" + + "DELETE FROM roomserver_rooms WHERE room_nid = $1" + +const purgeStateBlockEntriesSQL = "" + + "DELETE FROM roomserver_state_block WHERE state_block_nid = ANY(" + + " SELECT DISTINCT UNNEST(state_block_nids) FROM roomserver_state_snapshots WHERE room_nid = $1" + + ")" + +const purgeStateSnapshotEntriesSQL = "" + + "DELETE FROM roomserver_state_snapshots WHERE room_nid = $1" + +type purgeStatements struct { + purgeEventJSONStmt *sql.Stmt + purgeEventsStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt + purgePreviousEventsStmt *sql.Stmt + purgePublishedStmt *sql.Stmt + purgeRedactionStmt *sql.Stmt + purgeRoomAliasesStmt *sql.Stmt + purgeRoomStmt *sql.Stmt + purgeStateBlockEntriesStmt *sql.Stmt + purgeStateSnapshotEntriesStmt *sql.Stmt +} + +func PreparePurgeStatements(db *sql.DB) (*purgeStatements, error) { + s := &purgeStatements{} + + return s, sqlutil.StatementList{ + {&s.purgeEventJSONStmt, purgeEventJSONSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, + {&s.purgePublishedStmt, purgePublishedSQL}, + {&s.purgePreviousEventsStmt, purgePreviousEventsSQL}, + {&s.purgeRedactionStmt, purgeRedactionsSQL}, + {&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL}, + {&s.purgeRoomStmt, purgeRoomSQL}, + {&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL}, + {&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL}, + }.Prepare(db) +} + +func (s *purgeStatements) PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, +) error { + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateBlockEntriesStmt, + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 994399532..c8346733d 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -58,6 +58,9 @@ const insertRoomNIDSQL = "" + const selectRoomNIDSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1 FOR UPDATE" + const selectLatestEventNIDsSQL = "" + "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" @@ -85,6 +88,7 @@ const bulkSelectRoomNIDsSQL = "" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -106,6 +110,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { return s, sqlutil.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 23a5f79eb..872084383 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -189,6 +189,10 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + purge, err := PreparePurgeStatements(db) + if err != nil { + return err + } d.Database = shared.Database{ DB: db, Cache: cache, @@ -206,6 +210,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room MembershipTable: membership, PublishedTable: published, RedactionsTable: redactions, + Purge: purge, } return nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 725cc5bc7..654b078d2 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -43,6 +43,7 @@ type Database struct { MembershipTable tables.Membership PublishedTable tables.Published RedactionsTable tables.Redactions + Purge tables.Purge GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } @@ -1445,6 +1446,21 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget }) } +// PurgeRoom removes all information about a given room from the roomserver. +// For large rooms this operation may take a considerable amount of time. +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err := d.RoomsTable.SelectRoomNIDForUpdate(ctx, txn, roomID) + if err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("room %s does not exist", roomID) + } + return fmt.Errorf("failed to lock the room: %w", err) + } + return d.Purge.PurgeRoom(ctx, txn, roomNID, roomID) + }) +} + func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/roomserver/storage/sqlite3/purge_statements.go b/roomserver/storage/sqlite3/purge_statements.go new file mode 100644 index 000000000..c7b4d27a5 --- /dev/null +++ b/roomserver/storage/sqlite3/purge_statements.go @@ -0,0 +1,153 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const purgeEventJSONSQL = "" + + "DELETE FROM roomserver_event_json WHERE event_nid IN (" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeEventsSQL = "" + + "DELETE FROM roomserver_events WHERE room_nid = $1" + +const purgeInvitesSQL = "" + + "DELETE FROM roomserver_invites WHERE room_nid = $1" + +const purgeMembershipsSQL = "" + + "DELETE FROM roomserver_membership WHERE room_nid = $1" + +const purgePreviousEventsSQL = "" + + "DELETE FROM roomserver_previous_events WHERE event_nids IN(" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgePublishedSQL = "" + + "DELETE FROM roomserver_published WHERE room_id = $1" + +const purgeRedactionsSQL = "" + + "DELETE FROM roomserver_redactions WHERE redaction_event_id IN(" + + " SELECT event_id FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeRoomAliasesSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE room_id = $1" + +const purgeRoomSQL = "" + + "DELETE FROM roomserver_rooms WHERE room_nid = $1" + +const purgeStateSnapshotEntriesSQL = "" + + "DELETE FROM roomserver_state_snapshots WHERE room_nid = $1" + +type purgeStatements struct { + purgeEventJSONStmt *sql.Stmt + purgeEventsStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt + purgePreviousEventsStmt *sql.Stmt + purgePublishedStmt *sql.Stmt + purgeRedactionStmt *sql.Stmt + purgeRoomAliasesStmt *sql.Stmt + purgeRoomStmt *sql.Stmt + purgeStateSnapshotEntriesStmt *sql.Stmt + stateSnapshot *stateSnapshotStatements +} + +func PreparePurgeStatements(db *sql.DB, stateSnapshot *stateSnapshotStatements) (*purgeStatements, error) { + s := &purgeStatements{stateSnapshot: stateSnapshot} + return s, sqlutil.StatementList{ + {&s.purgeEventJSONStmt, purgeEventJSONSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, + {&s.purgePublishedStmt, purgePublishedSQL}, + {&s.purgePreviousEventsStmt, purgePreviousEventsSQL}, + {&s.purgeRedactionStmt, purgeRedactionsSQL}, + {&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL}, + {&s.purgeRoomStmt, purgeRoomSQL}, + //{&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL}, + {&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL}, + }.Prepare(db) +} + +func (s *purgeStatements) PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, +) error { + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + if err := s.purgeStateBlocks(ctx, txn, roomNID); err != nil { + return err + } + + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil +} + +func (s *purgeStatements) purgeStateBlocks( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) error { + // Get all stateBlockNIDs + stateBlockNIDs, err := s.stateSnapshot.selectStateBlockNIDsForRoomNID(ctx, txn, roomNID) + if err != nil { + return err + } + params := make([]interface{}, len(stateBlockNIDs)) + seenNIDs := make(map[types.StateBlockNID]struct{}, len(stateBlockNIDs)) + // dedupe NIDs + for k, v := range stateBlockNIDs { + if _, ok := seenNIDs[v]; ok { + continue + } + params[k] = v + seenNIDs[v] = struct{}{} + } + + query := "DELETE FROM roomserver_state_block WHERE state_block_nid IN($1)" + return sqlutil.RunLimitedVariablesExec(ctx, query, txn, params, sqlutil.SQLite3MaxVariables) +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 25b611b3e..7556b3461 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -74,10 +74,14 @@ const bulkSelectRoomIDsSQL = "" + const bulkSelectRoomNIDsSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -105,6 +109,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, }.Prepare(db) } @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 4e67d4da1..ae8181cfa 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -68,7 +67,7 @@ func CreateStateBlockTable(db *sql.DB) error { return err } -func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func PrepareStateBlockTable(db *sql.DB) (*stateBlockStatements, error) { s := &stateBlockStatements{ db: db, } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 73827522c..930ad14dd 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -62,10 +62,14 @@ const bulkSelectStateBlockNIDsSQL = "" + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" +const selectStateBlockNIDsForRoomNID = "" + + "SELECT state_block_nids FROM roomserver_state_snapshots WHERE room_nid = $1" + type stateSnapshotStatements struct { db *sql.DB insertStateStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt + selectStateBlockNIDsStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -73,7 +77,7 @@ func CreateStateSnapshotTable(db *sql.DB) error { return err } -func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) { s := &stateSnapshotStatements{ db: db, } @@ -81,6 +85,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + {&s.selectStateBlockNIDsStmt, selectStateBlockNIDsForRoomNID}, }.Prepare(db) } @@ -146,3 +151,29 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( ) ([]types.EventNID, error) { return nil, tables.OptimisationNotSupportedError } + +func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.StateBlockNID, error) { + var res []types.StateBlockNID + rows, err := sqlutil.TxStmt(txn, s.selectStateBlockNIDsStmt).QueryContext(ctx, roomNID) + if err != nil { + return res, nil + } + defer internal.CloseAndLogIfError(ctx, rows, "selectStateBlockNIDsForRoomNID: rows.close() failed") + + var stateBlockNIDs []types.StateBlockNID + var stateBlockNIDsJSON string + for rows.Next() { + if err = rows.Scan(&stateBlockNIDsJSON); err != nil { + return nil, err + } + if err = json.Unmarshal([]byte(stateBlockNIDsJSON), &stateBlockNIDs); err != nil { + return nil, err + } + + res = append(res, stateBlockNIDs...) + } + + return res, rows.Err() +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 01c3f879c..392edd289 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -197,6 +197,11 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + purge, err := PreparePurgeStatements(db, stateSnapshot) + if err != nil { + return err + } + d.Database = shared.Database{ DB: db, Cache: cache, @@ -215,6 +220,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room PublishedTable: published, RedactionsTable: redactions, GetRoomUpdaterFn: d.GetRoomUpdater, + Purge: purge, } return nil } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 80fcf72dd..64145f83d 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -73,6 +73,7 @@ type Events interface { type Rooms interface { InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error) SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) + SelectRoomNIDForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error @@ -173,6 +174,12 @@ type Redactions interface { MarkRedactionValidated(ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool) error } +type Purge interface { + PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, + ) error +} + // StrippedEvent represents a stripped event for returning extracted content values. type StrippedEvent struct { RoomID string diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 1b67f5684..21838039a 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -23,6 +23,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -127,6 +128,12 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms s.onRetirePeek(s.ctx, *output.RetirePeek) case api.OutputTypeRedactedEvent: err = s.onRedactEvent(s.ctx, *output.RedactedEvent) + case api.OutputTypePurgeRoom: + err = s.onPurgeRoom(s.ctx, *output.PurgeRoom) + if err != nil { + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from sync API") + return true // non-fatal, as otherwise we end up in a loop of trying to purge the room + } default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", @@ -473,6 +480,20 @@ func (s *OutputRoomEventConsumer) onRetirePeek( s.notifier.OnRetirePeek(msg.RoomID, msg.UserID, msg.DeviceID, types.StreamingToken{PDUPosition: sp}) } +func (s *OutputRoomEventConsumer) onPurgeRoom( + ctx context.Context, req api.OutputPurgeRoom, +) error { + logrus.WithField("room_id", req.RoomID).Warn("Purging room from sync API") + + if err := s.db.PurgeRoom(ctx, req.RoomID); err != nil { + logrus.WithField("room_id", req.RoomID).WithError(err).Error("Failed to purge room from sync API") + return err + } else { + logrus.WithField("room_id", req.RoomID).Warn("Room purged from sync API") + return nil + } +} + func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.HeaderedEvent) (*gomatrixserverlib.HeaderedEvent, error) { if event.StateKey() == nil { return event, nil diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index a4ba82327..a7a127e3a 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -132,6 +132,8 @@ type Database interface { // PurgeRoomState completely purges room state from the sync API. This is done when // receiving an output event that completely resets the state. PurgeRoomState(ctx context.Context, roomID string) error + // PurgeRoom entirely eliminates a room from the sync API, timeline, state and all. + PurgeRoom(ctx context.Context, roomID string) error // UpsertAccountData keeps track of new or updated account data, by saving the type // of the new/updated data, and the user ID and room ID the data is related to (empty) // room ID means the data isn't specific to any room) diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index 8fc92091f..c20d860a7 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -47,10 +47,14 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const purgeBackwardExtremitiesSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + purgeBackwardExtremitiesStmt *sql.Stmt } func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -59,16 +63,12 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err - } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL}, + {&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL}, + {&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL}, + {&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL}, + }.Prepare(db) } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( @@ -106,3 +106,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return } + +func (s *backwardExtremitiesStatements) PurgeBackwardExtremities( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index aada70d5e..151bffa5d 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -62,11 +62,15 @@ const selectInviteEventsInRangeSQL = "" + const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" +const purgeInvitesSQL = "" + + "DELETE FROM syncapi_invite_events WHERE room_id = $1" + type inviteEventsStatements struct { insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt } func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { @@ -75,19 +79,13 @@ func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { if err != nil { return nil, err } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL}, + {&s.deleteInviteEventStmt, deleteInviteEventSQL}, + {&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + }.Prepare(db) } func (s *inviteEventsStatements) InsertInviteEvent( @@ -181,3 +179,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID( } return } + +func (s *inviteEventsStatements) PurgeInvites( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index ac44b235f..47833893a 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -65,18 +65,22 @@ const selectMembershipCountSQL = "" + const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" +const purgeMembershipsSQL = "" + + "DELETE FROM syncapi_memberships WHERE room_id = $1" + const selectMembersSQL = ` -SELECT event_id FROM ( - SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC -) t -WHERE ($3::text IS NULL OR t.membership = $3) - AND ($4::text IS NULL OR t.membership <> $4) + SELECT event_id FROM ( + SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC + ) t + WHERE ($3::text IS NULL OR t.membership = $3) + AND ($4::text IS NULL OR t.membership <> $4) ` type membershipsStatements struct { upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt selectMembershipForUserStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt selectMembersStmt *sql.Stmt } @@ -90,6 +94,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, {&s.selectMembersStmt, selectMembersSQL}, }.Prepare(db) } @@ -139,6 +144,13 @@ func (s *membershipsStatements) SelectMembershipForUser( return membership, topologyPos, nil } +func (s *membershipsStatements) PurgeMemberships( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID) + return err +} + func (s *membershipsStatements) SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go index 2c7b24800..7edfd54a6 100644 --- a/syncapi/storage/postgres/notification_data_table.go +++ b/syncapi/storage/postgres/notification_data_table.go @@ -37,6 +37,7 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + {&r.purgeNotificationData, purgeNotificationDataSQL}, }.Prepare(db) } @@ -44,6 +45,7 @@ type notificationDataStatements struct { upsertRoomUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt selectMaxID *sql.Stmt + purgeNotificationData *sql.Stmt } const notificationDataSchema = ` @@ -70,6 +72,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` +const purgeNotificationDataSQL = "" + + "DELETE FROM syncapi_notification_data WHERE room_id = $1" + func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { err = sqlutil.TxStmt(txn, r.upsertRoomUnreadCounts).QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) return @@ -106,3 +111,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } + +func (s *notificationDataStatements) PurgeNotificationData( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 3b69b26f6..0075fc8d3 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -176,6 +176,9 @@ const selectContextAfterEventSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id ASC LIMIT $3" +const purgeEventsSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3" type outputRoomEventsStatements struct { @@ -193,6 +196,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + purgeEventsStmt *sql.Stmt selectSearchStmt *sql.Stmt } @@ -230,6 +234,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, {&s.selectSearchStmt, selectSearchSQL}, }.Prepare(db) } @@ -658,6 +663,13 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { return result, rows.Err() } +func (s *outputRoomEventsStatements) PurgeEvents( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID) + return err +} + func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit) if err != nil { diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index d0e99f267..2382fca5c 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -18,11 +18,12 @@ import ( "context" "database/sql" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsTopologySchema = ` @@ -71,6 +72,9 @@ const selectStreamToTopologicalPositionAscSQL = "" + const selectStreamToTopologicalPositionDescSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;" +const purgeEventsTopologySQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt @@ -78,6 +82,7 @@ type outputRoomEventsTopologyStatements struct { selectPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt + purgeEventsTopologyStmt *sql.Stmt } func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -86,25 +91,15 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertEventInTopologyStmt, insertEventInTopologySQL}, + {&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL}, + {&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL}, + {&s.selectPositionInTopologyStmt, selectPositionInTopologySQL}, + {&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL}, + {&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL}, + {&s.purgeEventsTopologyStmt, purgeEventsTopologySQL}, + }.Prepare(db) } // InsertEventInTopology inserts the given event in the room's topology, based @@ -177,3 +172,10 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } + +func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/peeks_table.go b/syncapi/storage/postgres/peeks_table.go index e20a4882f..64183073d 100644 --- a/syncapi/storage/postgres/peeks_table.go +++ b/syncapi/storage/postgres/peeks_table.go @@ -65,6 +65,9 @@ const selectPeekingDevicesSQL = "" + const selectMaxPeekIDSQL = "" + "SELECT MAX(id) FROM syncapi_peeks" +const purgePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1" + type peekStatements struct { db *sql.DB insertPeekStmt *sql.Stmt @@ -73,6 +76,7 @@ type peekStatements struct { selectPeeksInRangeStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt + purgePeeksStmt *sql.Stmt } func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { @@ -83,25 +87,15 @@ func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { s := &peekStatements{ db: db, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertPeekStmt, insertPeekSQL}, + {&s.deletePeekStmt, deletePeekSQL}, + {&s.deletePeeksStmt, deletePeeksSQL}, + {&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL}, + {&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL}, + {&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL}, + {&s.purgePeeksStmt, purgePeeksSQL}, + }.Prepare(db) } func (s *peekStatements) InsertPeek( @@ -184,3 +178,10 @@ func (s *peekStatements) SelectMaxPeekID( } return } + +func (s *peekStatements) PurgePeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index 327a7a372..0fcbebfcb 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -62,11 +62,15 @@ const selectRoomReceipts = "" + const selectMaxReceiptIDSQL = "" + "SELECT MAX(id) FROM syncapi_receipts" +const purgeReceiptsSQL = "" + + "DELETE FROM syncapi_receipts WHERE room_id = $1" + type receiptStatements struct { db *sql.DB upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt + purgeReceiptsStmt *sql.Stmt } func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { @@ -86,16 +90,12 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { r := &receiptStatements{ db: db, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - return r, nil + return r, sqlutil.StatementList{ + {&r.upsertReceipt, upsertReceipt}, + {&r.selectRoomReceipts, selectRoomReceipts}, + {&r.selectMaxReceiptID, selectMaxReceiptIDSQL}, + {&r.purgeReceiptsStmt, purgeReceiptsSQL}, + }.Prepare(db) } func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { @@ -138,3 +138,10 @@ func (s *receiptStatements) SelectMaxReceiptID( } return } + +func (s *receiptStatements) PurgeReceipts( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index df2338cf8..aeeebb1d2 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -242,20 +242,6 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e return nil } -func (d *Database) PurgeRoomState( - ctx context.Context, roomID string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // If the event is a create event then we'll delete all of the existing - // data for the room. The only reason that a create event would be replayed - // to us in this way is if we're about to receive the entire room state. - if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { - return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) - } - return nil - }) -} - func (d *Database) WriteEvent( ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 7b07cac5e..8385b95a5 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -649,6 +649,53 @@ func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) return d.Presence.GetMaxPresenceID(ctx, d.txn) } +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.BackwardExtremities.PurgeBackwardExtremities(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge backward extremities: %w", err) + } + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge current room state: %w", err) + } + if err := d.Invites.PurgeInvites(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge invites: %w", err) + } + if err := d.Memberships.PurgeMemberships(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge memberships: %w", err) + } + if err := d.NotificationData.PurgeNotificationData(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge notification data: %w", err) + } + if err := d.OutputEvents.PurgeEvents(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge events: %w", err) + } + if err := d.Topology.PurgeEventsTopology(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge events topology: %w", err) + } + if err := d.Peeks.PurgePeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge peeks: %w", err) + } + if err := d.Receipts.PurgeReceipts(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge receipts: %w", err) + } + return nil + }) +} + +func (d *Database) PurgeRoomState( + ctx context.Context, roomID string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) + } + return nil + }) +} + func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) { id, err := d.Relations.SelectMaxRelationID(ctx, d.txn) return types.StreamPosition(id), err diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 3a5fd6be3..2d8cf2ed2 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -47,11 +47,15 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const purgeBackwardExtremitiesSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { db *sql.DB insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + purgeBackwardExtremitiesStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -62,16 +66,12 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err - } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL}, + {&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL}, + {&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL}, + {&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL}, + }.Prepare(db) } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( @@ -109,3 +109,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err } + +func (s *backwardExtremitiesStatements) PurgeBackwardExtremities( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index e2dbcd5c8..19450099a 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -57,6 +57,9 @@ const selectInviteEventsInRangeSQL = "" + const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" +const purgeInvitesSQL = "" + + "DELETE FROM syncapi_invite_events WHERE room_id = $1" + type inviteEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -64,6 +67,7 @@ type inviteEventsStatements struct { selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt } func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) { @@ -75,19 +79,13 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Inv if err != nil { return nil, err } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL}, + {&s.deleteInviteEventStmt, deleteInviteEventSQL}, + {&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + }.Prepare(db) } func (s *inviteEventsStatements) InsertInviteEvent( @@ -192,3 +190,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID( } return } + +func (s *inviteEventsStatements) PurgeInvites( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 905a1e1a8..2cc46a10a 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -72,6 +72,9 @@ SELECT event_id FROM AND ($4 IS NULL OR t.membership <> $4) ` +const purgeMembershipsSQL = "" + + "DELETE FROM syncapi_memberships WHERE room_id = $1" + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt @@ -79,6 +82,7 @@ type membershipsStatements struct { //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic selectMembershipForUserStmt *sql.Stmt selectMembersStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -94,6 +98,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembersStmt, selectMembersSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, }.Prepare(db) } @@ -142,6 +147,13 @@ func (s *membershipsStatements) SelectMembershipForUser( return membership, topologyPos, nil } +func (s *membershipsStatements) PurgeMemberships( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID) + return err +} + func (s *membershipsStatements) SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index 6242898e1..af2b2c074 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -38,6 +38,7 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + {&r.purgeNotificationData, purgeNotificationDataSQL}, // {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime }.Prepare(db) } @@ -47,6 +48,7 @@ type notificationDataStatements struct { streamIDStatements *StreamIDStatements upsertRoomUnreadCounts *sql.Stmt selectMaxID *sql.Stmt + purgeNotificationData *sql.Stmt //selectUserUnreadCountsForRooms *sql.Stmt } @@ -73,6 +75,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` +const purgeNotificationDataSQL = "" + + "DELETE FROM syncapi_notification_data WHERE room_id = $1" + func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { pos, err = r.streamIDStatements.nextNotificationID(ctx, nil) if err != nil { @@ -124,3 +129,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } + +func (s *notificationDataStatements) PurgeNotificationData( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 1aa4bfff7..db708c083 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -120,6 +120,9 @@ const selectContextAfterEventSQL = "" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC" +const purgeEventsSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -130,6 +133,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + purgeEventsStmt *sql.Stmt //selectSearchStmt *sql.Stmt - prepared at runtime } @@ -163,6 +167,7 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, //{&s.selectSearchStmt, selectSearchSQL}, - prepared at runtime }.Prepare(db) } @@ -666,6 +671,13 @@ func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs [ return } +func (s *outputRoomEventsStatements) PurgeEvents( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID) + return err +} + func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { params := make([]interface{}, len(types)) for i := range types { diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 879456441..dc698de2d 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -18,10 +18,11 @@ import ( "context" "database/sql" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsTopologySchema = ` @@ -67,6 +68,9 @@ const selectStreamToTopologicalPositionAscSQL = "" + const selectStreamToTopologicalPositionDescSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;" +const purgeEventsTopologySQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { db *sql.DB insertEventInTopologyStmt *sql.Stmt @@ -75,6 +79,7 @@ type outputRoomEventsTopologyStatements struct { selectPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt + purgeEventsTopologyStmt *sql.Stmt } func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -85,25 +90,15 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertEventInTopologyStmt, insertEventInTopologySQL}, + {&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL}, + {&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL}, + {&s.selectPositionInTopologyStmt, selectPositionInTopologySQL}, + {&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL}, + {&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL}, + {&s.purgeEventsTopologyStmt, purgeEventsTopologySQL}, + }.Prepare(db) } // insertEventInTopology inserts the given event in the room's topology, based @@ -174,3 +169,10 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } + +func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index 4ef51b103..5d5200abc 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -64,6 +64,9 @@ const selectPeekingDevicesSQL = "" + const selectMaxPeekIDSQL = "" + "SELECT MAX(id) FROM syncapi_peeks" +const purgePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1" + type peekStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -73,6 +76,7 @@ type peekStatements struct { selectPeeksInRangeStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt + purgePeeksStmt *sql.Stmt } func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) { @@ -84,25 +88,15 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks db: db, streamIDStatements: streamID, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertPeekStmt, insertPeekSQL}, + {&s.deletePeekStmt, deletePeekSQL}, + {&s.deletePeeksStmt, deletePeeksSQL}, + {&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL}, + {&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL}, + {&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL}, + {&s.purgePeeksStmt, purgePeeksSQL}, + }.Prepare(db) } func (s *peekStatements) InsertPeek( @@ -204,3 +198,10 @@ func (s *peekStatements) SelectMaxPeekID( } return } + +func (s *peekStatements) PurgePeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index a4a9b4395..ca3d80fb4 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -58,12 +58,16 @@ const selectRoomReceipts = "" + const selectMaxReceiptIDSQL = "" + "SELECT MAX(id) FROM syncapi_receipts" +const purgeReceiptsSQL = "" + + "DELETE FROM syncapi_receipts WHERE room_id = $1" + type receiptStatements struct { db *sql.DB streamIDStatements *StreamIDStatements upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt + purgeReceiptsStmt *sql.Stmt } func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) { @@ -84,16 +88,12 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re db: db, streamIDStatements: streamID, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - return r, nil + return r, sqlutil.StatementList{ + {&r.upsertReceipt, upsertReceipt}, + {&r.selectRoomReceipts, selectRoomReceipts}, + {&r.selectMaxReceiptID, selectMaxReceiptIDSQL}, + {&r.purgeReceiptsStmt, purgeReceiptsSQL}, + }.Prepare(db) } // UpsertReceipt creates new user receipts @@ -153,3 +153,10 @@ func (s *receiptStatements) SelectMaxReceiptID( } return } + +func (s *receiptStatements) PurgeReceipts( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 8366a67dc..145e197cc 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -39,6 +39,7 @@ type Invites interface { // for the room. SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgeInvites(ctx context.Context, txn *sql.Tx, roomID string) error } type Peeks interface { @@ -48,6 +49,7 @@ type Peeks interface { SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) SelectPeekingDevices(ctxt context.Context, txn *sql.Tx) (peekingDevices map[string][]types.PeekingDevice, err error) SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgePeeks(ctx context.Context, txn *sql.Tx, roomID string) error } type Events interface { @@ -75,6 +77,8 @@ type Events interface { SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + + PurgeEvents(ctx context.Context, txn *sql.Tx, roomID string) error ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) } @@ -93,6 +97,7 @@ type Topology interface { SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) + PurgeEventsTopology(ctx context.Context, txn *sql.Tx, roomID string) error } type CurrentRoomState interface { @@ -146,6 +151,7 @@ type BackwardsExtremities interface { SelectBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) + PurgeBackwardExtremities(ctx context.Context, txn *sql.Tx, roomID string) error } // SendToDevice tracks send-to-device messages which are sent to individual @@ -181,12 +187,14 @@ type Receipts interface { UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgeReceipts(ctx context.Context, txn *sql.Tx, roomID string) error } type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + PurgeMemberships(ctx context.Context, txn *sql.Tx, roomID string) error SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, @@ -198,6 +206,7 @@ type NotificationData interface { UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) + PurgeNotificationData(ctx context.Context, txn *sql.Tx, roomID string) error } type Ignores interface { From ce2bfc3f2e507a012044906af7f25c9dc52873d7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 20 Jan 2023 12:45:56 +0100 Subject: [PATCH 60/67] Make tests more reliable (#2948) When using `testrig.CreateBase` and then using that base for other `NewInternalAPI` calls, we never actually shutdown the components. `testrig.CreateBase` returns a `close` function, which only removes the database, so still running components have issues connecting to the database, since we ripped it out underneath it - which can result in "Disk I/O" or "pq deadlock detected" issues. --- federationapi/federationapi.go | 5 +---- federationapi/federationapi_test.go | 6 +++--- federationapi/routing/routing.go | 17 ++++++++++++----- setup/base/base.go | 2 ++ test/testrig/base.go | 12 ++++++++++-- 5 files changed, 28 insertions(+), 14 deletions(-) diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 87eb751f5..ce0ce98e9 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -85,10 +85,7 @@ func AddPublicRoutes( } routing.Setup( - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - cfg, + base, rsAPI, f, keyRing, federation, userAPI, keyAPI, mscCfg, servers, producer, diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 68a06a033..7009230cc 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -273,12 +273,12 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { cfg.Global.ServerName = gomatrixserverlib.ServerName("localhost") cfg.Global.PrivateKey = privKey cfg.Global.JetStream.InMemory = true - base := base.NewBaseDendrite(cfg, "Monolith") + b := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) keyRing := &test.NopJSONVerifier{} // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(base, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) - baseURL, cancel := test.ListenAndServe(t, base.PublicFederationAPIMux, true) + federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) + baseURL, cancel := test.ListenAndServe(t, b.PublicFederationAPIMux, true) defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 0a3ab7a88..04eb3d067 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -32,6 +32,7 @@ import ( keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -49,8 +50,7 @@ import ( // applied: // nolint: gocyclo func Setup( - fedMux, keyMux, wkMux *mux.Router, - cfg *config.FederationAPI, + base *base.BaseDendrite, rsAPI roomserverAPI.FederationRoomserverAPI, fsAPI *fedInternal.FederationInternalAPI, keys gomatrixserverlib.JSONVerifier, @@ -61,9 +61,16 @@ func Setup( servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, ) { - prometheus.MustRegister( - pduCountTotal, eduCountTotal, - ) + fedMux := base.PublicFederationAPIMux + keyMux := base.PublicKeyAPIMux + wkMux := base.PublicWellKnownAPIMux + cfg := &base.Cfg.FederationAPI + + if base.EnableMetrics { + prometheus.MustRegister( + pduCountTotal, eduCountTotal, + ) + } v2keysmux := keyMux.PathPrefix("/v2").Subrouter() v1fedmux := fedMux.PathPrefix("/v1").Subrouter() diff --git a/setup/base/base.go b/setup/base/base.go index d3adbf53f..ff38209fb 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -264,6 +264,8 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base // Close implements io.Closer func (b *BaseDendrite) Close() error { + b.ProcessContext.ShutdownDendrite() + b.ProcessContext.WaitForShutdown() return b.tracerCloser.Close() } diff --git a/test/testrig/base.go b/test/testrig/base.go index 7bc26a5c5..52e6ef5f1 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -62,7 +62,12 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f MaxIdleConnections: 2, ConnMaxLifetimeSeconds: 60, } - return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close + base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) + return base, func() { + base.ShutdownDendrite() + base.WaitForShutdown() + close() + } case test.DBTypeSQLite: cfg.Defaults(config.DefaultOpts{ Generate: true, @@ -72,7 +77,10 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) - return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() { + base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) + return base, func() { + base.ShutdownDendrite() + base.WaitForShutdown() // cleanup db files. This risks getting out of sync as we add more database strings :( dbFiles := []config.DataSource{ cfg.FederationAPI.Database.ConnectionString, From a2b486091218e761adc0a00ce19ed4b600e489a2 Mon Sep 17 00:00:00 2001 From: Bernhard Feichtinger <43303168+BieHDC@users.noreply.github.com> Date: Fri, 20 Jan 2023 13:13:36 +0100 Subject: [PATCH 61/67] Fix oversight in cmd/generate-config (#2946) The -dir argument was ignored for media_api->base_path. Signed-off-by: `Bernhard Feichtinger <43303168+BieHDC@users.noreply.github.com>` --- cmd/generate-config/main.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 33b18c471..5f75f5e4d 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -54,6 +54,9 @@ func main() { } else { cfg.Global.DatabaseOptions.ConnectionString = uri } + cfg.MediaAPI.BasePath = config.Path(filepath.Join(*dirPath, "media")) + cfg.Global.JetStream.StoragePath = config.Path(*dirPath) + cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(*dirPath, "searchindex")) cfg.Logging = []config.LogrusHook{ { Type: "file", From caf310fd7976ed3fe8abbbf8cb72d380c7efd3c2 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 20 Jan 2023 15:18:06 +0100 Subject: [PATCH 62/67] AWSY missing federation tests (#2943) In an attempt to fix the missing AWSY tests and to get to 100% server-server compliance. --- are-we-synapse-yet.list | 10 +- go.mod | 2 +- go.sum | 4 + roomserver/internal/input/input_events.go | 107 +++++++++++----------- sytest-blacklist | 38 +------- sytest-whitelist | 14 ++- 6 files changed, 83 insertions(+), 92 deletions(-) diff --git a/are-we-synapse-yet.list b/are-we-synapse-yet.list index 81c0f8049..585374738 100644 --- a/are-we-synapse-yet.list +++ b/are-we-synapse-yet.list @@ -936,4 +936,12 @@ fst Room state after a rejected message event is the same as before fst Room state after a rejected state event is the same as before fpb Federation publicRoom Name/topic keys are correct fed New federated private chats get full presence information (SYN-115) (10 subtests) -dvk Rejects invalid device keys \ No newline at end of file +dvk Rejects invalid device keys +rmv User can create and send/receive messages in a room with version 10 +rmv local user can join room with version 10 +rmv User can invite local user to room with version 10 +rmv remote user can join room with version 10 +rmv User can invite remote user to room with version 10 +rmv Remote user can backfill in a room with version 10 +rmv Can reject invites over federation for rooms with version 10 +rmv Can receive redactions from regular users over federation in room version 10 \ No newline at end of file diff --git a/go.mod b/go.mod index 2d7174150..a86dd2cb8 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab + github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index b12f65eab..e5cd67bed 100644 --- a/go.sum +++ b/go.sum @@ -350,6 +350,10 @@ github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8= github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45 h1:zGrmcm2M4F4f+zk5JXAkw3oHa/zXhOh5XVGBdl7GdPo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 h1:P7me2oCmksST9B4+1I1nA+XrnDQwIqAWmy6ntQrXwc8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 4179fc1ef..67edb3217 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -24,6 +24,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/tidwall/gjson" "github.com/matrix-org/gomatrixserverlib" @@ -40,7 +41,6 @@ import ( "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -166,6 +166,7 @@ func (r *Inputer) processRoomEvent( missingPrev = !input.HasState && len(missingPrevIDs) > 0 } + // If we have missing events (auth or prev), we build a list of servers to ask if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ RoomID: event.RoomID(), @@ -200,59 +201,8 @@ func (r *Inputer) processRoomEvent( } } - // First of all, check that the auth events of the event are known. - // If they aren't then we will ask the federation API for them. isRejected := false - authEvents := gomatrixserverlib.NewAuthEvents(nil) - knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.fetchAuthEvents: %w", err) - } - - // Check if the event is allowed by its auth events. If it isn't then - // we consider the event to be "rejected" — it will still be persisted. var rejectionErr error - if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { - isRejected = true - logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) - } - - // Accumulate the auth event NIDs. - authEventIDs := event.AuthEventIDs() - authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) - for _, authEventID := range authEventIDs { - if _, ok := knownEvents[authEventID]; !ok { - // Unknown auth events only really matter if the event actually failed - // auth. If it passed auth then we can assume that everything that was - // known was sufficient, even if extraneous auth events were specified - // but weren't found. - if isRejected { - if event.StateKey() != nil { - return fmt.Errorf( - "missing auth event %s for state event %s (type %q, state key %q)", - authEventID, event.EventID(), event.Type(), *event.StateKey(), - ) - } else { - return fmt.Errorf( - "missing auth event %s for timeline event %s (type %q)", - authEventID, event.EventID(), event.Type(), - ) - } - } - } else { - authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) - } - } - - var softfail bool - if input.Kind == api.KindNew { - // Check that the event passes authentication checks based on the - // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) - if err != nil { - logger.WithError(err).Warn("Error authing soft-failed event") - } - } // At this point we are checking whether we know all of the prev events, and // if we know the state before the prev events. This is necessary before we @@ -314,6 +264,59 @@ func (r *Inputer) processRoomEvent( } } + // Check that the auth events of the event are known. + // If they aren't then we will ask the federation API for them. + authEvents := gomatrixserverlib.NewAuthEvents(nil) + knownEvents := map[string]*types.Event{} + if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return fmt.Errorf("r.fetchAuthEvents: %w", err) + } + + // Check if the event is allowed by its auth events. If it isn't then + // we consider the event to be "rejected" — it will still be persisted. + if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + isRejected = true + rejectionErr = err + logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) + } + + // Accumulate the auth event NIDs. + authEventIDs := event.AuthEventIDs() + authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) + for _, authEventID := range authEventIDs { + if _, ok := knownEvents[authEventID]; !ok { + // Unknown auth events only really matter if the event actually failed + // auth. If it passed auth then we can assume that everything that was + // known was sufficient, even if extraneous auth events were specified + // but weren't found. + if isRejected { + if event.StateKey() != nil { + return fmt.Errorf( + "missing auth event %s for state event %s (type %q, state key %q)", + authEventID, event.EventID(), event.Type(), *event.StateKey(), + ) + } else { + return fmt.Errorf( + "missing auth event %s for timeline event %s (type %q)", + authEventID, event.EventID(), event.Type(), + ) + } + } + } else { + authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) + } + } + + var softfail bool + if input.Kind == api.KindNew { + // Check that the event passes authentication checks based on the + // current room state. + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) + if err != nil { + logger.WithError(err).Warn("Error authing soft-failed event") + } + } + // Get the state before the event so that we can work out if the event was // allowed at the time, and also to get the history visibility. We won't // bother doing this if the event was already rejected as it just ends up diff --git a/sytest-blacklist b/sytest-blacklist index 99cfbabc8..bb0ee368f 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,54 +1,18 @@ -# Relies on a rejected PL event which will never be accepted into the DAG - -# Caused by - -Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state - -# We don't implement lazy membership loading yet - +# Blacklisted due to https://github.com/matrix-org/matrix-spec/issues/942 The only membership state included in a gapped incremental sync is for senders in the timeline -# Blacklisted out of flakiness after #1479 - -Invited user can reject local invite after originator leaves -Invited user can reject invite for empty room -If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes - -# Blacklisted due to flakiness - -Forgotten room messages cannot be paginated - -# Blacklisted due to flakiness after #1774 - -Local device key changes get to remote servers with correct prev_id - -# we don't support groups - -Remove group category -Remove group role - # Flakey - AS-ghosted users can use rooms themselves AS-ghosted users can use rooms via AS Events in rooms with AS-hosted room aliases are sent to AS server Inviting an AS-hosted user asks the AS server Accesing an AS-hosted room alias asks the AS server -# More flakey - -Guest users can join guest_access rooms - # This will fail in HTTP API mode, so blacklisted for now - If a device list update goes missing, the server resyncs on the next one # Might be a bug in the test because leaves do appear :-( - Leaves are present in non-gapped incremental syncs -# Below test was passing for the wrong reason, failing correctly since #2858 -New federated private chats get full presence information (SYN-115) - # We don't have any state to calculate m.room.guest_access when accepting invites Guest users can accept invites to private rooms over federation \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 215889a49..1f6ecc29e 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -766,4 +766,16 @@ remote user has tags copied to the new room Local and remote users' homeservers remove a room from their public directory on upgrade Guest users denied access over federation if guest access prohibited Guest users are kicked from guest_access rooms on revocation of guest_access -Guest users are kicked from guest_access rooms on revocation of guest_access over federation \ No newline at end of file +Guest users are kicked from guest_access rooms on revocation of guest_access over federation +User can create and send/receive messages in a room with version 10 +local user can join room with version 10 +User can invite local user to room with version 10 +remote user can join room with version 10 +User can invite remote user to room with version 10 +Remote user can backfill in a room with version 10 +Can reject invites over federation for rooms with version 10 +Can receive redactions from regular users over federation in room version 10 +New federated private chats get full presence information (SYN-115) +/state returns M_NOT_FOUND for an outlier +/state_ids returns M_NOT_FOUND for an outlier +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state \ No newline at end of file From 25cb65acdbb6702e84a6bcb6245d6d23d90c2359 Mon Sep 17 00:00:00 2001 From: Catalan Lover <48515417+FSG-Cat@users.noreply.github.com> Date: Fri, 20 Jan 2023 15:41:29 +0100 Subject: [PATCH 63/67] Change Default Room version to 10 (#2933) This PR implements [MSC3904](https://github.com/matrix-org/matrix-spec-proposals/pull/3904). This PR is almost identical to #2781 but this PR is also filed well technically 1 day before the MSC passes FCP but well everyone knows this MSC is expected to have passed FCP on monday so im refiling this change today on saturday as i was doing prep work for monday. I assume that this PR wont be counted as clogging the queue since by the next time i expect to be a work day for this project this PR will be implementing an FCP passed disposition merge MSC. Also as for the lack of tests i belive that this simple change does not need to pass new tests due to that these tests are expected to already have been passed by the successful use of Dendrite with Room version 10 already. ### Pull Request Checklist * [X] I have added tests for PR _or_ I have justified why this PR doesn't need tests. * [X] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) Signed-off-by: Catalan Lover Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com> Co-authored-by: kegsay --- roomserver/version/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roomserver/version/version.go b/roomserver/version/version.go index 729d00a80..c40d8e0f7 100644 --- a/roomserver/version/version.go +++ b/roomserver/version/version.go @@ -23,7 +23,7 @@ import ( // DefaultRoomVersion contains the room version that will, by // default, be used to create new rooms on this server. func DefaultRoomVersion() gomatrixserverlib.RoomVersion { - return gomatrixserverlib.RoomVersionV9 + return gomatrixserverlib.RoomVersionV10 } // RoomVersions returns a map of all known room versions to this From 430932f0f161dd836c98082ff97b57beedec02e6 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 20 Jan 2023 16:20:01 +0100 Subject: [PATCH 64/67] Version 0.11.0 (#2949) --- CHANGES.md | 14 ++++++++++++++ helm/dendrite/Chart.yaml | 4 ++-- helm/dendrite/README.md | 7 ++++--- helm/dendrite/values.yaml | 2 +- internal/version.go | 4 ++-- 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index fa8230659..e1f7affb5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,19 @@ # Changelog +## Dendrite 0.11.0 (2023-01-20) + +The last three missing federation API Sytests have been fixed - bringing us to 100% server-server Synapse parity, with client-server parity at 93% 🎉 + +### Features + +* Added `/_dendrite/admin/purgeRoom/{roomID}` to clean up the database +* The default room version was updated to 10 (contributed by [FSG-Cat](https://github.com/FSG-Cat)) + +### Fixes + +* An oversight in the `create-config` binary, which now correctly sets the media path if specified (contributed by [BieHDC](https://github.com/BieHDC)) +* The Helm chart now uses the `$.Chart.AppVersion` as the default image version to pull, with the possibility to override it (contributed by [genofire](https://github.com/genofire)) + ## Dendrite 0.10.9 (2023-01-17) ### Features diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml index 6e6641c8d..174fc5496 100644 --- a/helm/dendrite/Chart.yaml +++ b/helm/dendrite/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v2 name: dendrite -version: "0.10.9" -appVersion: "0.10.9" +version: "0.11.0" +appVersion: "0.11.0" description: Dendrite Matrix Homeserver type: application keywords: diff --git a/helm/dendrite/README.md b/helm/dendrite/README.md index cb850d655..6a1658429 100644 --- a/helm/dendrite/README.md +++ b/helm/dendrite/README.md @@ -1,6 +1,6 @@ # dendrite -![Version: 0.10.8](https://img.shields.io/badge/Version-0.10.8-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.10.8](https://img.shields.io/badge/AppVersion-0.10.8-informational?style=flat-square) +![Version: 0.11.0](https://img.shields.io/badge/Version-0.11.0-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.11.0](https://img.shields.io/badge/AppVersion-0.11.0-informational?style=flat-square) Dendrite Matrix Homeserver Status: **NOT PRODUCTION READY** @@ -41,8 +41,9 @@ Create a folder `appservices` and place your configurations in there. The confi | Key | Type | Default | Description | |-----|------|---------|-------------| -| image.name | string | `"ghcr.io/matrix-org/dendrite-monolith:v0.10.8"` | Docker repository/image to use | +| image.repository | string | `"ghcr.io/matrix-org/dendrite-monolith"` | Docker repository/image to use | | image.pullPolicy | string | `"IfNotPresent"` | Kubernetes pullPolicy | +| image.tag | string | `""` | Overrides the image tag whose default is the chart appVersion. | | signing_key.create | bool | `true` | Create a new signing key, if not exists | | signing_key.existingSecret | string | `""` | Use an existing secret | | resources | object | sets some sane default values | Default resource requests/limits. | @@ -144,4 +145,4 @@ Create a folder `appservices` and place your configurations in there. The confi | ingress.annotations | object | `{}` | Extra, custom annotations | | ingress.tls | list | `[]` | | | service.type | string | `"ClusterIP"` | | -| service.port | int | `80` | | +| service.port | int | `8008` | | diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml index 87027a886..848241ab6 100644 --- a/helm/dendrite/values.yaml +++ b/helm/dendrite/values.yaml @@ -3,7 +3,7 @@ image: repository: "ghcr.io/matrix-org/dendrite-monolith" # -- Kubernetes pullPolicy pullPolicy: IfNotPresent - # Overrides the image tag whose default is the chart appVersion. + # -- Overrides the image tag whose default is the chart appVersion. tag: "" diff --git a/internal/version.go b/internal/version.go index ff31dd784..fbe4a01b0 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 10 - VersionPatch = 9 + VersionMinor = 11 + VersionPatch = 0 VersionTag = "" // example: "rc1" ) From 48fa869fa3578741d1d5775d30f24f6b097ab995 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 23 Jan 2023 13:17:15 +0100 Subject: [PATCH 65/67] Use `t.TempDir` for SQLite databases, so tests don't rip out each others databases (#2950) This should hopefully finally fix issues about `disk I/O error` as seen [here](https://gitlab.alpinelinux.org/alpine/aports/-/jobs/955030/raw) Hopefully this will also fix `SSL accept attempt failed` issues by disabling HTTP keep alives when generating a config for CI. --- cmd/generate-config/main.go | 1 + internal/log.go | 2 ++ internal/log_unix.go | 2 ++ setup/jetstream/helpers.go | 5 +++++ sytest-blacklist | 1 + sytest-whitelist | 7 ++++++- test/db.go | 12 +++++------- test/testrig/base.go | 37 ++++++++++++++----------------------- 8 files changed, 36 insertions(+), 31 deletions(-) diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 5f75f5e4d..56a145653 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -70,6 +70,7 @@ func main() { cfg.AppServiceAPI.DisableTLSValidation = true cfg.ClientAPI.RateLimiting.Enabled = false cfg.FederationAPI.DisableTLSValidation = false + cfg.FederationAPI.DisableHTTPKeepalives = true // don't hit matrix.org when running tests!!! cfg.FederationAPI.KeyPerspectives = config.KeyPerspectives{} cfg.MediaAPI.BasePath = config.Path(filepath.Join(*dirPath, "media")) diff --git a/internal/log.go b/internal/log.go index d7e852c81..da6e20418 100644 --- a/internal/log.go +++ b/internal/log.go @@ -24,6 +24,7 @@ import ( "path/filepath" "runtime" "strings" + "sync" "github.com/matrix-org/util" @@ -37,6 +38,7 @@ import ( // this unfortunately results in us adding the same hook multiple times. // This map ensures we only ever add one level hook. var stdLevelLogAdded = make(map[logrus.Level]bool) +var levelLogAddedMu = &sync.Mutex{} type utcFormatter struct { logrus.Formatter diff --git a/internal/log_unix.go b/internal/log_unix.go index b38e7c2e8..8f34c320d 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -85,6 +85,8 @@ func checkSyslogHookParams(params map[string]interface{}) { } func setupStdLogHook(level logrus.Level) { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() if stdLevelLogAdded[level] { return } diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index c1ce9583f..533652160 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -77,6 +77,11 @@ func JetStreamConsumer( // The consumer was deleted so stop. return } else { + // Unfortunately, there's no ErrServerShutdown or similar, so we need to compare the string + if err.Error() == "nats: Server Shutdown" { + logrus.WithContext(ctx).Warn("nats server shutting down") + return + } // Something else went wrong, so we'll panic. sentry.CaptureException(err) logrus.WithContext(ctx).WithField("subject", subj).Fatal(err) diff --git a/sytest-blacklist b/sytest-blacklist index bb0ee368f..49a3cc870 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -7,6 +7,7 @@ AS-ghosted users can use rooms via AS Events in rooms with AS-hosted room aliases are sent to AS server Inviting an AS-hosted user asks the AS server Accesing an AS-hosted room alias asks the AS server +If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes # This will fail in HTTP API mode, so blacklisted for now If a device list update goes missing, the server resyncs on the next one diff --git a/sytest-whitelist b/sytest-whitelist index 1f6ecc29e..c61e0bc3c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -778,4 +778,9 @@ Can receive redactions from regular users over federation in room version 10 New federated private chats get full presence information (SYN-115) /state returns M_NOT_FOUND for an outlier /state_ids returns M_NOT_FOUND for an outlier -Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state \ No newline at end of file +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state +Invited user can reject invite for empty room +Invited user can reject local invite after originator leaves +Guest users can join guest_access rooms +Forgotten room messages cannot be paginated +Local device key changes get to remote servers with correct prev_id \ No newline at end of file diff --git a/test/db.go b/test/db.go index 17f637e18..54ded6adb 100644 --- a/test/db.go +++ b/test/db.go @@ -22,6 +22,7 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "testing" "github.com/lib/pq" @@ -103,13 +104,10 @@ func currentUser() string { // TODO: namespace for concurrent package tests func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { if dbType == DBTypeSQLite { - // this will be made in the current working directory which namespaces concurrent package runs correctly - dbname := "dendrite_test.db" + // this will be made in the t.TempDir, which is unique per test + dbname := filepath.Join(t.TempDir(), "dendrite_test.db") return fmt.Sprintf("file:%s", dbname), func() { - err := os.Remove(dbname) - if err != nil { - t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err) - } + t.Cleanup(func() {}) // removes the t.TempDir } } @@ -176,7 +174,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) { for dbName, dbType := range dbs { dbt := dbType t.Run(dbName, func(tt *testing.T) { - //tt.Parallel() + tt.Parallel() testFn(tt, dbt) }) } diff --git a/test/testrig/base.go b/test/testrig/base.go index 52e6ef5f1..9773da223 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -15,18 +15,14 @@ package testrig import ( - "errors" "fmt" - "io/fs" - "os" - "strings" + "path/filepath" "testing" - "github.com/nats-io/nats.go" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/nats-io/nats.go" ) func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) { @@ -77,27 +73,22 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) + + // Use a temp dir provided by go for tests, this will be cleanup by a call to t.CleanUp() + tempDir := t.TempDir() + cfg.FederationAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "federationapi.db")) + cfg.KeyServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "keyserver.db")) + cfg.MSCs.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "mscs.db")) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "mediaapi.db")) + cfg.RoomServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "roomserver.db")) + cfg.SyncAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "syncapi.db")) + cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "userapi.db")) + base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) return base, func() { base.ShutdownDendrite() base.WaitForShutdown() - // cleanup db files. This risks getting out of sync as we add more database strings :( - dbFiles := []config.DataSource{ - cfg.FederationAPI.Database.ConnectionString, - cfg.KeyServer.Database.ConnectionString, - cfg.MSCs.Database.ConnectionString, - cfg.MediaAPI.Database.ConnectionString, - cfg.RoomServer.Database.ConnectionString, - cfg.SyncAPI.Database.ConnectionString, - cfg.UserAPI.AccountDatabase.ConnectionString, - } - for _, fileURI := range dbFiles { - path := strings.TrimPrefix(string(fileURI), "file:") - err := os.Remove(path) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - t.Fatalf("failed to cleanup sqlite db '%s': %s", fileURI, err) - } - } + t.Cleanup(func() {}) // removes t.TempDir, where all database files are created } default: t.Fatalf("unknown db type: %v", dbType) From 5b73592f5a4dddf64184fcbe33f4c1835c656480 Mon Sep 17 00:00:00 2001 From: devonh Date: Mon, 23 Jan 2023 17:55:12 +0000 Subject: [PATCH 66/67] Initial Store & Forward Implementation (#2917) This adds store & forward relays into dendrite for p2p. A few things have changed: - new relay api serves new http endpoints for s&f federation - updated outbound federation queueing which will attempt to forward using s&f if appropriate - database entries to track s&f relays for other nodes --- .gitignore | 1 + build/gobind-pinecone/monolith.go | 201 ++++- build/gobind-pinecone/monolith_test.go | 198 +++++ cmd/dendrite-demo-pinecone/ARCHITECTURE.md | 59 ++ cmd/dendrite-demo-pinecone/README.md | 39 + cmd/dendrite-demo-pinecone/main.go | 125 ++- federationapi/api/api.go | 25 +- federationapi/federationapi.go | 5 +- federationapi/internal/api.go | 9 +- .../internal/federationclient_test.go | 202 +++++ federationapi/internal/perform.go | 73 +- federationapi/internal/perform_test.go | 190 ++++ federationapi/inthttp/client.go | 12 + federationapi/queue/destinationqueue.go | 80 +- federationapi/queue/queue.go | 25 +- federationapi/queue/queue_test.go | 436 ++++------ federationapi/routing/profile_test.go | 94 ++ federationapi/routing/query_test.go | 94 ++ federationapi/routing/routing.go | 14 +- federationapi/routing/send.go | 339 +------- federationapi/routing/send_test.go | 605 ++----------- federationapi/statistics/statistics.go | 186 +++- federationapi/statistics/statistics_test.go | 58 +- federationapi/storage/interface.go | 47 +- .../storage/postgres/assumed_offline_table.go | 107 +++ .../storage/postgres/relay_servers_table.go | 137 +++ federationapi/storage/postgres/storage.go | 10 + .../storage/shared/receipt/receipt.go | 42 + federationapi/storage/shared/storage.go | 184 +++- federationapi/storage/shared/storage_edus.go | 29 +- federationapi/storage/shared/storage_pdus.go | 27 +- .../storage/sqlite3/assumed_offline_table.go | 107 +++ .../storage/sqlite3/relay_servers_table.go | 148 ++++ federationapi/storage/sqlite3/storage.go | 13 +- federationapi/storage/storage_test.go | 103 ++- federationapi/storage/tables/interface.go | 27 + .../tables/relay_servers_table_test.go | 224 +++++ go.mod | 20 +- go.sum | 42 +- internal/log.go | 2 + internal/log_unix.go | 4 +- internal/transactionrequest.go | 356 ++++++++ internal/transactionrequest_test.go | 820 ++++++++++++++++++ mediaapi/routing/routing.go | 26 +- relayapi/api/api.go | 56 ++ relayapi/internal/api.go | 53 ++ relayapi/internal/perform.go | 141 +++ relayapi/internal/perform_test.go | 121 +++ relayapi/relayapi.go | 74 ++ relayapi/relayapi_test.go | 154 ++++ relayapi/routing/relaytxn.go | 74 ++ relayapi/routing/relaytxn_test.go | 220 +++++ relayapi/routing/routing.go | 123 +++ relayapi/routing/sendrelay.go | 77 ++ relayapi/routing/sendrelay_test.go | 209 +++++ relayapi/storage/interface.go | 47 + .../postgres/relay_queue_json_table.go | 113 +++ .../storage/postgres/relay_queue_table.go | 156 ++++ relayapi/storage/postgres/storage.go | 64 ++ relayapi/storage/shared/storage.go | 170 ++++ .../storage/sqlite3/relay_queue_json_table.go | 137 +++ relayapi/storage/sqlite3/relay_queue_table.go | 168 ++++ relayapi/storage/sqlite3/storage.go | 64 ++ relayapi/storage/storage.go | 46 + relayapi/storage/tables/interface.go | 66 ++ .../tables/relay_queue_json_table_test.go | 173 ++++ .../storage/tables/relay_queue_table_test.go | 229 +++++ setup/base/base.go | 6 + setup/config/config.go | 5 +- setup/config/config_federationapi.go | 7 + setup/config/config_relayapi.go | 52 ++ setup/config/config_test.go | 29 +- setup/monolith.go | 7 + test/db.go | 1 - test/memory_federation_db.go | 488 +++++++++++ test/memory_relay_db.go | 140 +++ test/testrig/base.go | 4 +- 77 files changed, 7646 insertions(+), 1373 deletions(-) create mode 100644 build/gobind-pinecone/monolith_test.go create mode 100644 cmd/dendrite-demo-pinecone/ARCHITECTURE.md create mode 100644 federationapi/internal/federationclient_test.go create mode 100644 federationapi/internal/perform_test.go create mode 100644 federationapi/routing/profile_test.go create mode 100644 federationapi/routing/query_test.go create mode 100644 federationapi/storage/postgres/assumed_offline_table.go create mode 100644 federationapi/storage/postgres/relay_servers_table.go create mode 100644 federationapi/storage/shared/receipt/receipt.go create mode 100644 federationapi/storage/sqlite3/assumed_offline_table.go create mode 100644 federationapi/storage/sqlite3/relay_servers_table.go create mode 100644 federationapi/storage/tables/relay_servers_table_test.go create mode 100644 internal/transactionrequest.go create mode 100644 internal/transactionrequest_test.go create mode 100644 relayapi/api/api.go create mode 100644 relayapi/internal/api.go create mode 100644 relayapi/internal/perform.go create mode 100644 relayapi/internal/perform_test.go create mode 100644 relayapi/relayapi.go create mode 100644 relayapi/relayapi_test.go create mode 100644 relayapi/routing/relaytxn.go create mode 100644 relayapi/routing/relaytxn_test.go create mode 100644 relayapi/routing/routing.go create mode 100644 relayapi/routing/sendrelay.go create mode 100644 relayapi/routing/sendrelay_test.go create mode 100644 relayapi/storage/interface.go create mode 100644 relayapi/storage/postgres/relay_queue_json_table.go create mode 100644 relayapi/storage/postgres/relay_queue_table.go create mode 100644 relayapi/storage/postgres/storage.go create mode 100644 relayapi/storage/shared/storage.go create mode 100644 relayapi/storage/sqlite3/relay_queue_json_table.go create mode 100644 relayapi/storage/sqlite3/relay_queue_table.go create mode 100644 relayapi/storage/sqlite3/storage.go create mode 100644 relayapi/storage/storage.go create mode 100644 relayapi/storage/tables/interface.go create mode 100644 relayapi/storage/tables/relay_queue_json_table_test.go create mode 100644 relayapi/storage/tables/relay_queue_table_test.go create mode 100644 setup/config/config_relayapi.go create mode 100644 test/memory_federation_db.go create mode 100644 test/memory_relay_db.go diff --git a/.gitignore b/.gitignore index e4f0112c4..fe5e82797 100644 --- a/.gitignore +++ b/.gitignore @@ -56,6 +56,7 @@ dendrite.yaml # Database files *.db +*.db-journal # Log files *.log* diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index b8f8111d2..ff61ea6c8 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -41,13 +41,16 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/relayapi" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" userapiAPI "github.com/matrix-org/dendrite/userapi/api" @@ -67,24 +70,27 @@ import ( ) const ( - PeerTypeRemote = pineconeRouter.PeerTypeRemote - PeerTypeMulticast = pineconeRouter.PeerTypeMulticast - PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth - PeerTypeBonjour = pineconeRouter.PeerTypeBonjour + PeerTypeRemote = pineconeRouter.PeerTypeRemote + PeerTypeMulticast = pineconeRouter.PeerTypeMulticast + PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth + PeerTypeBonjour = pineconeRouter.PeerTypeBonjour + relayServerRetryInterval = time.Second * 30 ) type DendriteMonolith struct { - logger logrus.Logger - PineconeRouter *pineconeRouter.Router - PineconeMulticast *pineconeMulticast.Multicast - PineconeQUIC *pineconeSessions.Sessions - PineconeManager *pineconeConnections.ConnectionManager - StorageDirectory string - CacheDirectory string - listener net.Listener - httpServer *http.Server - processContext *process.ProcessContext - userAPI userapiAPI.UserInternalAPI + logger logrus.Logger + baseDendrite *base.BaseDendrite + PineconeRouter *pineconeRouter.Router + PineconeMulticast *pineconeMulticast.Multicast + PineconeQUIC *pineconeSessions.Sessions + PineconeManager *pineconeConnections.ConnectionManager + StorageDirectory string + CacheDirectory string + listener net.Listener + httpServer *http.Server + userAPI userapiAPI.UserInternalAPI + federationAPI api.FederationInternalAPI + relayServersQueried map[gomatrixserverlib.ServerName]bool } func (m *DendriteMonolith) PublicKey() string { @@ -326,6 +332,7 @@ func (m *DendriteMonolith) Start() { cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(m.StorageDirectory, prefix))) cfg.MediaAPI.BasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(m.StorageDirectory, prefix))) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true @@ -335,9 +342,9 @@ func (m *DendriteMonolith) Start() { panic(err) } - base := base.NewBaseDendrite(cfg, "Monolith") + base := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) + m.baseDendrite = base base.ConfigureAdminEndpoints() - defer base.Close() // nolint: errcheck federation := conn.CreateFederationClient(base, m.PineconeQUIC) @@ -346,11 +353,11 @@ func (m *DendriteMonolith) Start() { rsAPI := roomserver.NewInternalAPI(base) - fsAPI := federationapi.NewInternalAPI( + m.federationAPI = federationapi.NewInternalAPI( base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, m.federationAPI, rsAPI) m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(m.userAPI) @@ -358,10 +365,24 @@ func (m *DendriteMonolith) Start() { // The underlying roomserver implementation needs to be able to call the fedsender. // This is different to rsAPI which can be the http client which doesn't need this dependency - rsAPI.SetFederationAPI(fsAPI, keyRing) + rsAPI.SetFederationAPI(m.federationAPI, keyRing) userProvider := users.NewPineconeUserProvider(m.PineconeRouter, m.PineconeQUIC, m.userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, fsAPI, federation) + roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, m.federationAPI, federation) + + js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &base.Cfg.FederationAPI, + UserAPI: m.userAPI, + } + relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer) monolith := setup.Monolith{ Config: base.Cfg, @@ -370,10 +391,11 @@ func (m *DendriteMonolith) Start() { KeyRing: keyRing, AppserviceAPI: asAPI, - FederationAPI: fsAPI, + FederationAPI: m.federationAPI, RoomserverAPI: rsAPI, UserAPI: m.userAPI, KeyAPI: keyAPI, + RelayAPI: relayAPI, ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } @@ -411,8 +433,6 @@ func (m *DendriteMonolith) Start() { Handler: h2c.NewHandler(pMux, h2s), } - m.processContext = base.ProcessContext - go func() { m.logger.Info("Listening on ", cfg.Global.ServerName) @@ -420,7 +440,7 @@ func (m *DendriteMonolith) Start() { case net.ErrClosed, http.ErrServerClosed: m.logger.Info("Stopped listening on ", cfg.Global.ServerName) default: - m.logger.Fatal(err) + m.logger.Error("Stopped listening on ", cfg.Global.ServerName) } }() go func() { @@ -430,33 +450,44 @@ func (m *DendriteMonolith) Start() { case net.ErrClosed, http.ErrServerClosed: m.logger.Info("Stopped listening on ", cfg.Global.ServerName) default: - m.logger.Fatal(err) + m.logger.Error("Stopped listening on ", cfg.Global.ServerName) } }() go func(ch <-chan pineconeEvents.Event) { eLog := logrus.WithField("pinecone", "events") + stopRelayServerSync := make(chan bool) + + relayRetriever := RelayServerRetriever{ + Context: context.Background(), + ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), + FederationAPI: m.federationAPI, + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + RelayAPI: monolith.RelayAPI, + running: *atomic.NewBool(false), + } + relayRetriever.InitializeRelayServers(eLog) for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: + if !relayRetriever.running.Load() { + go relayRetriever.SyncRelayServers(stopRelayServerSync) + } case pineconeEvents.PeerRemoved: - case pineconeEvents.TreeParentUpdate: - case pineconeEvents.SnakeDescUpdate: - case pineconeEvents.TreeRootAnnUpdate: - case pineconeEvents.SnakeEntryAdded: - case pineconeEvents.SnakeEntryRemoved: + if relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { + stopRelayServerSync <- true + } case pineconeEvents.BroadcastReceived: - eLog.Info("Broadcast received from: ", e.PeerID) + // eLog.Info("Broadcast received from: ", e.PeerID) req := &api.PerformWakeupServersRequest{ ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, } res := &api.PerformWakeupServersResponse{} - if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + if err := m.federationAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) } - case pineconeEvents.BandwidthReport: default: } } @@ -464,12 +495,106 @@ func (m *DendriteMonolith) Start() { } func (m *DendriteMonolith) Stop() { - m.processContext.ShutdownDendrite() + m.baseDendrite.Close() + m.baseDendrite.WaitForShutdown() _ = m.listener.Close() m.PineconeMulticast.Stop() _ = m.PineconeQUIC.Close() _ = m.PineconeRouter.Close() - m.processContext.WaitForComponentsToFinish() +} + +type RelayServerRetriever struct { + Context context.Context + ServerName gomatrixserverlib.ServerName + FederationAPI api.FederationInternalAPI + RelayAPI relayServerAPI.RelayInternalAPI + relayServersQueried map[gomatrixserverlib.ServerName]bool + queriedServersMutex sync.Mutex + running atomic.Bool +} + +func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} + response := api.P2PQueryRelayServersResponse{} + err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) + if err != nil { + eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + } + for _, server := range response.RelayServers { + m.relayServersQueried[server] = false + } + + eLog.Infof("Registered relay servers: %v", response.RelayServers) +} + +func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { + defer m.running.Store(false) + + t := time.NewTimer(relayServerRetryInterval) + for { + relayServersToQuery := []gomatrixserverlib.ServerName{} + func() { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + for server, complete := range m.relayServersQueried { + if !complete { + relayServersToQuery = append(relayServersToQuery, server) + } + } + }() + if len(relayServersToQuery) == 0 { + // All relay servers have been synced. + return + } + m.queryRelayServers(relayServersToQuery) + t.Reset(relayServerRetryInterval) + + select { + case <-stop: + if !t.Stop() { + <-t.C + } + return + case <-t.C: + } + } +} + +func (m *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + + result := map[gomatrixserverlib.ServerName]bool{} + for server, queried := range m.relayServersQueried { + result[server] = queried + } + return result +} + +func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("querying relay servers for any available transactions") + for _, server := range relayServers { + userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) + if err != nil { + return + } + err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) + if err == nil { + func() { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + m.relayServersQueried[server] = true + }() + // TODO : What happens if your relay receives new messages after this point? + // Should you continue to check with them, or should they try and contact you? + // They could send a "new_async_events" message your way maybe? + // Then you could mark them as needing to be queried again. + // What if you miss this message? + // Maybe you should try querying them again after a certain period of time as a backup? + } else { + logrus.Errorf("Failed querying relay server: %s", err.Error()) + } + } } const MaxFrameSize = types.MaxFrameSize diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go new file mode 100644 index 000000000..edcf22bbe --- /dev/null +++ b/build/gobind-pinecone/monolith_test.go @@ -0,0 +1,198 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gobind + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/api" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "gotest.tools/v3/poll" +) + +var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + +type TestNetConn struct { + net.Conn + shouldFail bool +} + +func (t *TestNetConn) Read(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + n := copy(b, TestBuf) + return n, nil + } +} + +func (t *TestNetConn) Write(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + return len(b), nil + } +} + +func (t *TestNetConn) Close() error { + if t.shouldFail { + return fmt.Errorf("Failed") + } else { + return nil + } +} + +func TestConduitStoresPort(t *testing.T) { + conduit := Conduit{port: 7} + assert.Equal(t, 7, conduit.Port()) +} + +func TestConduitRead(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + b := make([]byte, len(TestBuf)) + bytes, err := conduit.Read(b) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) + assert.Equal(t, TestBuf, b) +} + +func TestConduitReadCopy(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + result, err := conduit.ReadCopy() + assert.NoError(t, err) + assert.Equal(t, TestBuf, result) +} + +func TestConduitWrite(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + bytes, err := conduit.Write(TestBuf) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) +} + +func TestConduitClose(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + assert.True(t, conduit.closed.Load()) +} + +func TestConduitReadClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + b := make([]byte, len(TestBuf)) + _, err = conduit.Read(b) + assert.Error(t, err) +} + +func TestConduitReadCopyClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.ReadCopy() + assert.Error(t, err) +} + +func TestConduitWriteClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.Write(TestBuf) + assert.Error(t, err) +} + +func TestConduitReadCopyFails(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{shouldFail: true}} + _, err := conduit.ReadCopy() + assert.Error(t, err) +} + +var testRelayServers = []gomatrixserverlib.ServerName{"relay1", "relay2"} + +type FakeFedAPI struct { + api.FederationInternalAPI +} + +func (f *FakeFedAPI) P2PQueryRelayServers(ctx context.Context, req *api.P2PQueryRelayServersRequest, res *api.P2PQueryRelayServersResponse) error { + res.RelayServers = testRelayServers + return nil +} + +type FakeRelayAPI struct { + relayServerAPI.RelayInternalAPI +} + +func (r *FakeRelayAPI) PerformRelayServerSync(ctx context.Context, userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error { + return nil +} + +func TestRelayRetrieverInitialization(t *testing.T) { + retriever := RelayServerRetriever{ + Context: context.Background(), + ServerName: "server", + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + FederationAPI: &FakeFedAPI{}, + RelayAPI: &FakeRelayAPI{}, + } + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) +} + +func TestRelayRetrieverSync(t *testing.T) { + retriever := RelayServerRetriever{ + Context: context.Background(), + ServerName: "server", + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + FederationAPI: &FakeFedAPI{}, + RelayAPI: &FakeRelayAPI{}, + } + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) + + stopRelayServerSync := make(chan bool) + go retriever.SyncRelayServers(stopRelayServerSync) + + check := func(log poll.LogT) poll.Result { + relayServers := retriever.GetQueriedServerStatus() + for _, queried := range relayServers { + if !queried { + return poll.Continue("waiting for all servers to be queried") + } + } + + stopRelayServerSync <- true + return poll.Success() + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestMonolithStarts(t *testing.T) { + monolith := DendriteMonolith{} + monolith.Start() + monolith.PublicKey() + monolith.Stop() +} diff --git a/cmd/dendrite-demo-pinecone/ARCHITECTURE.md b/cmd/dendrite-demo-pinecone/ARCHITECTURE.md new file mode 100644 index 000000000..1b0941053 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/ARCHITECTURE.md @@ -0,0 +1,59 @@ +## Relay Server Architecture + +Relay Servers function similar to the way physical mail drop boxes do. +A node can have many associated relay servers. Matrix events can be sent to them instead of to the destination node, and the destination node will eventually retrieve them from the relay server. +Nodes that want to send events to an offline node need to know what relay servers are associated with their intended destination. +Currently this is manually configured in the dendrite database. In the future this information could be configurable in the app and shared automatically via other means. + +Currently events are sent as complete Matrix Transactions. +Transactions include a list of PDUs, (which contain, among other things, lists of authorization events, previous events, and signatures) a list of EDUs, and other information about the transaction. +There is no additional information sent along with the transaction other than what is typically added to them during Matrix federation today. +In the future this will probably need to change in order to handle more complex room state resolution during p2p usage. + +### Relay Server Architecture + +``` + 0 +--------------------+ + +----------------------------------------+ | P2P Node A | + | Relay Server | | +--------+ | + | | | | Client | | + | +--------------------+ | | +--------+ | + | | Relay Server API | | | | | + | | | | | V | + | .--------. 2 | +-------------+ | | 1 | +------------+ | + | |`--------`| <----- | Forwarder | <------------- | Homeserver | | + | | Database | | +-------------+ | | | +------------+ | + | `----------` | | | +--------------------+ + | ^ | | | + | | 4 | +-------------+ | | + | `------------ | Retriever | <------. +--------------------+ + | | +-------------+ | | | | P2P Node B | + | | | | | | +--------+ | + | +--------------------+ | | | | Client | | + | | | | +--------+ | + +----------------------------------------+ | | | | + | | V | + 3 | | +------------+ | + `------ | Homeserver | | + | +------------+ | + +--------------------+ +``` + +- 0: This relay server is currently only acting on behalf of `P2P Node B`. It will only receive, and later forward events that are destined for `P2P Node B`. +- 1: When `P2P Node A` fails sending directly to `P2P Node B` (after a configurable number of attempts), it checks for any known relay servers associated with `P2P Node B` and sends to all of them. + - If sending to any of the relay servers succeeds, that transaction is considered to be successfully sent. +- 2: The relay server `forwarder` stores the transaction json in its database and marks it as destined for `P2P Node B`. +- 3: When `P2P Node B` comes online, it queries all its relay servers for any missed messages. +- 4: The relay server `retriever` will look in its database for any transactions that are destined for `P2P Node B` and returns them one at a time. + +For now, it is important that we don’t design out a hybrid approach of having both sender-side and recipient-side relay servers. +Both approaches make sense and determining which makes for a better experience depends on the use case. + +#### Sender-Side Relay Servers + +If we are running around truly ad-hoc, and I don't know when or where you will be able to pick up messages, then having a sender designated server makes sense to give things the best chance at making their way to the destination. +But in order to achieve this, you are either relying on p2p presence broadcasts for the relay to know when to try forwarding (which means you are in a pretty small network), or the relay just keeps on periodically attempting to forward to the destination which will lead to a lot of extra traffic on the network. + +#### Recipient-Side Relay Servers + +If we have agreed to some static relay server before going off and doing other things, or if we are talking about more global p2p federation, then having a recipient designated relay server can cut down on redundant traffic since it will sit there idle until the recipient pulls events from it. diff --git a/cmd/dendrite-demo-pinecone/README.md b/cmd/dendrite-demo-pinecone/README.md index d6dd95905..5cacd0924 100644 --- a/cmd/dendrite-demo-pinecone/README.md +++ b/cmd/dendrite-demo-pinecone/README.md @@ -24,3 +24,42 @@ Then point your favourite Matrix client to the homeserver URL`http://localhost: If your peering connection is operational then you should see a `Connected TCP:` line in the log output. If not then try a different peer. Once logged in, you should be able to open the room directory or join a room by its ID. + +## Store & Forward Relays + +To test out the store & forward relay functionality, you need a minimum of 3 instances. +One instance will act as the relay, and the other two instances will be the users trying to communicate. +Then you can send messages between the two nodes and watch as the relay is used if the receiving node is offline. + +### Launching the Nodes + +Relay Server: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir relay/ -listen "[::]:49000" +``` + +Node 1: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir node-1/ -peer "[::]:49000" -port 8007 +``` + +Node 2: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir node-2/ -peer "[::]:49000" -port 8009 +``` + +### Database Setup + +At the moment, the database must be manually configured. +For both `Node 1` and `Node 2` add the following entries to their respective `relay_server` table in the federationapi database: +``` +server_name: {node_1_public_key}, relay_server_name: {relay_public_key} +server_name: {node_2_public_key}, relay_server_name: {relay_public_key} +``` + +After editing the database you will need to relaunch the nodes for the changes to be picked up by dendrite. + +### Testing + +Now you can run two separate instances of element and connect them to `Node 1` and `Node 2`. +You can shutdown one of the nodes and continue sending messages. If you wait long enough, the message will be sent to the relay server. (you can see this in the log output of the relay server) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 3f627b41d..a813c37a2 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -38,16 +38,21 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/relayapi" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" + "go.uber.org/atomic" pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" @@ -66,6 +71,8 @@ var ( instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)") ) +const relayServerRetryInterval = time.Second * 30 + // nolint:gocyclo func main() { flag.Parse() @@ -139,6 +146,7 @@ func main() { cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName))) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(*instanceDir, *instanceName))) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) cfg.ClientAPI.RegistrationDisabled = false @@ -224,6 +232,20 @@ func main() { userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation) roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation) + js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &base.Cfg.FederationAPI, + UserAPI: userAPI, + } + relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer) + monolith := setup.Monolith{ Config: base.Cfg, Client: conn.CreateClient(base, pQUIC), @@ -235,6 +257,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, KeyAPI: keyAPI, + RelayAPI: relayAPI, ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } @@ -305,27 +328,38 @@ func main() { go func(ch <-chan pineconeEvents.Event) { eLog := logrus.WithField("pinecone", "events") + relayServerSyncRunning := atomic.NewBool(false) + stopRelayServerSync := make(chan bool) + + m := RelayServerRetriever{ + Context: context.Background(), + ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()), + FederationAPI: fsAPI, + RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + RelayAPI: monolith.RelayAPI, + } + m.InitializeRelayServers(eLog) for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: + if !relayServerSyncRunning.Load() { + go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning) + } case pineconeEvents.PeerRemoved: - case pineconeEvents.TreeParentUpdate: - case pineconeEvents.SnakeDescUpdate: - case pineconeEvents.TreeRootAnnUpdate: - case pineconeEvents.SnakeEntryAdded: - case pineconeEvents.SnakeEntryRemoved: + if relayServerSyncRunning.Load() && pRouter.TotalPeerCount() == 0 { + stopRelayServerSync <- true + } case pineconeEvents.BroadcastReceived: - eLog.Info("Broadcast received from: ", e.PeerID) + // eLog.Info("Broadcast received from: ", e.PeerID) req := &api.PerformWakeupServersRequest{ ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, } res := &api.PerformWakeupServersResponse{} if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) } - case pineconeEvents.BandwidthReport: default: } } @@ -333,3 +367,78 @@ func main() { base.WaitForShutdown() } + +type RelayServerRetriever struct { + Context context.Context + ServerName gomatrixserverlib.ServerName + FederationAPI api.FederationInternalAPI + RelayServersQueried map[gomatrixserverlib.ServerName]bool + RelayAPI relayServerAPI.RelayInternalAPI +} + +func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} + response := api.P2PQueryRelayServersResponse{} + err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) + if err != nil { + eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + } + for _, server := range response.RelayServers { + m.RelayServersQueried[server] = false + } + + eLog.Infof("Registered relay servers: %v", response.RelayServers) +} + +func (m *RelayServerRetriever) syncRelayServers(stop <-chan bool, running atomic.Bool) { + defer running.Store(false) + + t := time.NewTimer(relayServerRetryInterval) + for { + relayServersToQuery := []gomatrixserverlib.ServerName{} + for server, complete := range m.RelayServersQueried { + if !complete { + relayServersToQuery = append(relayServersToQuery, server) + } + } + if len(relayServersToQuery) == 0 { + // All relay servers have been synced. + return + } + m.queryRelayServers(relayServersToQuery) + t.Reset(relayServerRetryInterval) + + select { + case <-stop: + // We have been asked to stop syncing, drain the timer and return. + if !t.Stop() { + <-t.C + } + return + case <-t.C: + // The timer has expired. Continue to the next loop iteration. + } + } +} + +func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("querying relay servers for any available transactions") + for _, server := range relayServers { + userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) + if err != nil { + return + } + err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) + if err == nil { + m.RelayServersQueried[server] = true + // TODO : What happens if your relay receives new messages after this point? + // Should you continue to check with them, or should they try and contact you? + // They could send a "new_async_events" message your way maybe? + // Then you could mark them as needing to be queried again. + // What if you miss this message? + // Maybe you should try querying them again after a certain period of time as a backup? + } else { + logrus.Errorf("Failed querying relay server: %s", err.Error()) + } + } +} diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 50d0339e4..417b08521 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -18,6 +18,7 @@ type FederationInternalAPI interface { gomatrixserverlib.KeyDatabase ClientFederationAPI RoomserverFederationAPI + P2PFederationAPI QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) @@ -30,7 +31,6 @@ type FederationInternalAPI interface { request *PerformBroadcastEDURequest, response *PerformBroadcastEDUResponse, ) error - PerformWakeupServers( ctx context.Context, request *PerformWakeupServersRequest, @@ -71,6 +71,15 @@ type RoomserverFederationAPI interface { LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } +type P2PFederationAPI interface { + // Relay Server sync api used in the pinecone demos. + P2PQueryRelayServers( + ctx context.Context, + request *P2PQueryRelayServersRequest, + response *P2PQueryRelayServersResponse, + ) error +} + // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError @@ -82,6 +91,7 @@ type KeyserverFederationAPI interface { // an interface for gmsl.FederationClient - contains functions called by federationapi only. type FederationClient interface { + P2PFederationClient gomatrixserverlib.KeyClient SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) @@ -110,6 +120,11 @@ type FederationClient interface { LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } +type P2PFederationClient interface { + P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) + P2PGetTransactionFromRelay(ctx context.Context, u gomatrixserverlib.UserID, prev gomatrixserverlib.RelayEntry, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetRelayTransaction, err error) +} + // FederationClientError is returned from FederationClient methods in the event of a problem. type FederationClientError struct { Err string @@ -233,3 +248,11 @@ type InputPublicKeysRequest struct { type InputPublicKeysResponse struct { } + +type P2PQueryRelayServersRequest struct { + Server gomatrixserverlib.ServerName +} + +type P2PQueryRelayServersResponse struct { + RelayServers []gomatrixserverlib.ServerName +} diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index ce0ce98e9..ed9a545d6 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -113,7 +113,10 @@ func NewInternalAPI( _ = federationDB.RemoveAllServersFromBlacklist() } - stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) + stats := statistics.NewStatistics( + federationDB, + cfg.FederationMaxRetries+1, + cfg.P2PFederationRetriesUntilAssumedOffline+1) js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 14056eafc..99773a750 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -109,13 +109,14 @@ func NewFederationInternalAPI( func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) { stats := a.statistics.ForServer(s) - until, blacklisted := stats.BackoffInfo() - if blacklisted { + if stats.Blacklisted() { return stats, &api.FederationClientError{ Blacklisted: true, } } + now := time.Now() + until := stats.BackoffInfo() if until != nil && now.Before(*until) { return stats, &api.FederationClientError{ RetryAfter: time.Until(*until), @@ -163,7 +164,7 @@ func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted( RetryAfter: retryAfter, } } - stats.Success() + stats.Success(statistics.SendDirect) return res, nil } @@ -171,7 +172,7 @@ func (a *FederationInternalAPI) doRequestIfNotBlacklisted( s gomatrixserverlib.ServerName, request func() (interface{}, error), ) (interface{}, error) { stats := a.statistics.ForServer(s) - if _, blacklisted := stats.BackoffInfo(); blacklisted { + if blacklisted := stats.Blacklisted(); blacklisted { return stats, &api.FederationClientError{ Err: fmt.Sprintf("server %q is blacklisted", s), Blacklisted: true, diff --git a/federationapi/internal/federationclient_test.go b/federationapi/internal/federationclient_test.go new file mode 100644 index 000000000..49137e2d8 --- /dev/null +++ b/federationapi/internal/federationclient_test.go @@ -0,0 +1,202 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/federationapi/queue" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + FailuresUntilAssumedOffline = 3 + FailuresUntilBlacklist = 8 +) + +func (t *testFedClient) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) { + t.queryKeysCalled = true + if t.shouldFail { + return gomatrixserverlib.RespQueryKeys{}, fmt.Errorf("Failure") + } + return gomatrixserverlib.RespQueryKeys{}, nil +} + +func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) { + t.claimKeysCalled = true + if t.shouldFail { + return gomatrixserverlib.RespClaimKeys{}, fmt.Errorf("Failure") + } + return gomatrixserverlib.RespClaimKeys{}, nil +} + +func TestFederationClientQueryKeys(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.Nil(t, err) + assert.True(t, fedClient.queryKeysCalled) +} + +func TestFederationClientQueryKeysBlacklisted(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + testDB.AddServerToBlacklist("server") + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.False(t, fedClient.queryKeysCalled) +} + +func TestFederationClientQueryKeysFailure(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{shouldFail: true} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.True(t, fedClient.queryKeysCalled) +} + +func TestFederationClientClaimKeys(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil) + assert.Nil(t, err) + assert.True(t, fedClient.claimKeysCalled) +} + +func TestFederationClientClaimKeysBlacklisted(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + testDB.AddServerToBlacklist("server") + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.False(t, fedClient.claimKeysCalled) +} diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index d86d07e03..552942f28 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/consumers" + "github.com/matrix-org/dendrite/federationapi/statistics" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" ) @@ -24,6 +25,10 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( request *api.PerformDirectoryLookupRequest, response *api.PerformDirectoryLookupResponse, ) (err error) { + if !r.shouldAttemptDirectFederation(request.ServerName) { + return fmt.Errorf("relay servers have no meaningful response for directory lookup.") + } + dir, err := r.federation.LookupRoomAlias( ctx, r.cfg.Matrix.ServerName, @@ -36,7 +41,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( } response.RoomID = dir.RoomID response.ServerNames = dir.Servers - r.statistics.ForServer(request.ServerName).Success() + r.statistics.ForServer(request.ServerName).Success(statistics.SendDirect) return nil } @@ -144,6 +149,10 @@ func (r *FederationInternalAPI) performJoinUsingServer( supportedVersions []gomatrixserverlib.RoomVersion, unsigned map[string]interface{}, ) error { + if !r.shouldAttemptDirectFederation(serverName) { + return fmt.Errorf("relay servers have no meaningful response for join.") + } + _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) if err != nil { return err @@ -164,7 +173,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.MakeJoin: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" @@ -219,7 +228,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.SendJoin: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // If the remote server returned an event in the "event" key of // the send_join request then we should use that instead. It may @@ -407,6 +416,10 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( serverName gomatrixserverlib.ServerName, supportedVersions []gomatrixserverlib.RoomVersion, ) error { + if !r.shouldAttemptDirectFederation(serverName) { + return fmt.Errorf("relay servers have no meaningful response for outbound peek.") + } + // create a unique ID for this peek. // for now we just use the room ID again. In future, if we ever // support concurrent peeks to the same room with different filters @@ -446,7 +459,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.Peek: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // Work out if we support the room version that has been supplied in // the peek response. @@ -516,6 +529,10 @@ func (r *FederationInternalAPI) PerformLeave( // Try each server that we were provided until we land on one that // successfully completes the make-leave send-leave dance. for _, serverName := range request.ServerNames { + if !r.shouldAttemptDirectFederation(serverName) { + continue + } + // Try to perform a make_leave using the information supplied in the // request. respMakeLeave, err := r.federation.MakeLeave( @@ -585,7 +602,7 @@ func (r *FederationInternalAPI) PerformLeave( continue } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) return nil } @@ -616,6 +633,12 @@ func (r *FederationInternalAPI) PerformInvite( return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } + // TODO (devon): This should be allowed via a relay. Currently only transactions + // can be sent to relays. Would need to extend relays to handle invites. + if !r.shouldAttemptDirectFederation(destination) { + return fmt.Errorf("relay servers have no meaningful response for invite.") + } + logrus.WithFields(logrus.Fields{ "event_id": request.Event.EventID(), "user_id": *request.Event.StateKey(), @@ -682,12 +705,8 @@ func (r *FederationInternalAPI) PerformWakeupServers( func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { for _, srv := range destinations { - // Check the statistics cache for the blacklist status to prevent hitting - // the database unnecessarily. - if r.queues.IsServerBlacklisted(srv) { - _ = r.db.RemoveServerFromBlacklist(srv) - } - r.queues.RetryServer(srv) + wasBlacklisted := r.statistics.ForServer(srv).MarkServerAlive() + r.queues.RetryServer(srv, wasBlacklisted) } } @@ -719,7 +738,9 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { return fmt.Errorf("auth chain response is missing m.room.create event") } -func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion { +func setDefaultRoomVersionFromJoinEvent( + joinEvent gomatrixserverlib.EventBuilder, +) gomatrixserverlib.RoomVersion { // if auth events are not event references we know it must be v3+ // we have to do these shenanigans to satisfy sytest, specifically for: // "Outbound federation rejects m.room.create events with an unknown room version" @@ -802,3 +823,31 @@ func federatedAuthProvider( return returning, nil } } + +// P2PQueryRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PQueryRelayServers( + ctx context.Context, + request *api.P2PQueryRelayServersRequest, + response *api.P2PQueryRelayServersResponse, +) error { + logrus.Infof("Getting relay servers for: %s", request.Server) + relayServers, err := r.db.P2PGetRelayServersForServer(ctx, request.Server) + if err != nil { + return err + } + + response.RelayServers = relayServers + return nil +} + +func (r *FederationInternalAPI) shouldAttemptDirectFederation( + destination gomatrixserverlib.ServerName, +) bool { + var shouldRelay bool + stats := r.statistics.ForServer(destination) + if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 { + shouldRelay = true + } + + return !shouldRelay +} diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go new file mode 100644 index 000000000..e8e0d00a3 --- /dev/null +++ b/federationapi/internal/perform_test.go @@ -0,0 +1,190 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/queue" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type testFedClient struct { + api.FederationClient + queryKeysCalled bool + claimKeysCalled bool + shouldFail bool +} + +func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { + return gomatrixserverlib.RespDirectory{}, nil +} + +func TestPerformWakeupServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + testDB.AddServerToBlacklist(server) + testDB.SetServerAssumedOffline(context.Background(), server) + blacklisted, err := testDB.IsServerBlacklisted(server) + assert.NoError(t, err) + assert.True(t, blacklisted) + offline, err := testDB.IsServerAssumedOffline(context.Background(), server) + assert.NoError(t, err) + assert.True(t, offline) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{server}, + } + res := api.PerformWakeupServersResponse{} + err = fedAPI.PerformWakeupServers(context.Background(), &req, &res) + assert.NoError(t, err) + + blacklisted, err = testDB.IsServerBlacklisted(server) + assert.NoError(t, err) + assert.False(t, blacklisted) + offline, err = testDB.IsServerAssumedOffline(context.Background(), server) + assert.NoError(t, err) + assert.False(t, offline) +} + +func TestQueryRelayServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.P2PQueryRelayServersRequest{ + Server: server, + } + res := api.P2PQueryRelayServersResponse{} + err = fedAPI.P2PQueryRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + assert.Equal(t, len(relayServers), len(res.RelayServers)) +} + +func TestPerformDirectoryLookup(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformDirectoryLookupRequest{ + RoomAlias: "room", + ServerName: "server", + } + res := api.PerformDirectoryLookupResponse{} + err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res) + assert.NoError(t, err) +} + +func TestPerformDirectoryLookupRelaying(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + testDB.SetServerAssumedOffline(context.Background(), server) + testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"}) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: server, + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformDirectoryLookupRequest{ + RoomAlias: "room", + ServerName: server, + } + res := api.PerformDirectoryLookupResponse{} + err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res) + assert.Error(t, err) +} diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 6eefdc7cd..6130a567d 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -24,6 +24,7 @@ const ( FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" + FederationAPIQueryRelayServers = "/federationapi/queryRelayServers" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" @@ -510,3 +511,14 @@ func (h *httpFederationInternalAPI) QueryPublicKeys( h.httpClient, ctx, request, response, ) } + +func (h *httpFederationInternalAPI) P2PQueryRelayServers( + ctx context.Context, + request *api.P2PQueryRelayServersRequest, + response *api.P2PQueryRelayServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryRelayServers", h.federationAPIURL+FederationAPIQueryRelayServers, + h.httpClient, ctx, request, response, + ) +} diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index a4a87fe99..51350916d 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -29,7 +29,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -70,7 +70,7 @@ type destinationQueue struct { // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return @@ -84,8 +84,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re oq.pendingMutex.Lock() if len(oq.pendingPDUs) < maxPDUsInMemory { oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ - pdu: event, - receipt: receipt, + pdu: event, + dbReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -101,7 +101,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re // sendEDU adds the EDU event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return @@ -115,8 +115,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share oq.pendingMutex.Lock() if len(oq.pendingEDUs) < maxEDUsInMemory { oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ - edu: event, - receipt: receipt, + edu: event, + dbReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -210,10 +210,10 @@ func (oq *destinationQueue) getPendingFromDatabase() { gotPDUs := map[string]struct{}{} gotEDUs := map[string]struct{}{} for _, pdu := range oq.pendingPDUs { - gotPDUs[pdu.receipt.String()] = struct{}{} + gotPDUs[pdu.dbReceipt.String()] = struct{}{} } for _, edu := range oq.pendingEDUs { - gotEDUs[edu.receipt.String()] = struct{}{} + gotEDUs[edu.dbReceipt.String()] = struct{}{} } overflowed := false @@ -371,7 +371,7 @@ func (oq *destinationQueue) backgroundSend() { // If we have pending PDUs or EDUs then construct a transaction. // Try sending the next transaction and see what happens. - terr := oq.nextTransaction(toSendPDUs, toSendEDUs) + terr, sendMethod := oq.nextTransaction(toSendPDUs, toSendEDUs) if terr != nil { // We failed to send the transaction. Mark it as a failure. _, blacklisted := oq.statistics.Failure() @@ -388,18 +388,19 @@ func (oq *destinationQueue) backgroundSend() { return } } else { - oq.handleTransactionSuccess(pduCount, eduCount) + oq.handleTransactionSuccess(pduCount, eduCount, sendMethod) } } } // nextTransaction creates a new transaction from the pending event // queue and sends it. -// Returns an error if the transaction wasn't sent. +// Returns an error if the transaction wasn't sent. And whether the success +// was to a relay server or not. func (oq *destinationQueue) nextTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) error { +) (err error, sendMethod statistics.SendMethod) { // Create the transaction. t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) @@ -407,7 +408,37 @@ func (oq *destinationQueue) nextTransaction( // Try to send the transaction to the destination server. ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) defer cancel() - _, err := oq.client.SendTransaction(ctx, t) + + relayServers := oq.statistics.KnownRelayServers() + if oq.statistics.AssumedOffline() && len(relayServers) > 0 { + sendMethod = statistics.SendViaRelay + relaySuccess := false + logrus.Infof("Sending to relay servers: %v", relayServers) + // TODO : how to pass through actual userID here?!?!?!?! + userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) + if userErr != nil { + return userErr, sendMethod + } + + // Attempt sending to each known relay server. + for _, relayServer := range relayServers { + _, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer) + if relayErr != nil { + err = relayErr + } else { + // If sending to one of the relay servers succeeds, consider the send successful. + relaySuccess = true + } + } + + // Clear the error if sending to any of the relay servers succeeded. + if relaySuccess { + err = nil + } + } else { + sendMethod = statistics.SendDirect + _, err = oq.client.SendTransaction(ctx, t) + } switch errResponse := err.(type) { case nil: // Clean up the transaction in the database. @@ -427,7 +458,7 @@ func (oq *destinationQueue) nextTransaction( oq.transactionIDMutex.Lock() oq.transactionID = "" oq.transactionIDMutex.Unlock() - return nil + return nil, sendMethod case gomatrix.HTTPError: // Report that we failed to send the transaction and we // will retry again, subject to backoff. @@ -437,13 +468,13 @@ func (oq *destinationQueue) nextTransaction( // to a 400-ish error code := errResponse.Code logrus.Debug("Transaction failed with HTTP", code) - return err + return err, sendMethod default: logrus.WithFields(logrus.Fields{ "destination": oq.destination, logrus.ErrorKey: err, }).Debugf("Failed to send transaction %q", t.TransactionID) - return err + return err, sendMethod } } @@ -453,7 +484,7 @@ func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) createTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) { +) (gomatrixserverlib.Transaction, []*receipt.Receipt, []*receipt.Receipt) { // If there's no projected transaction ID then generate one. If // the transaction succeeds then we'll set it back to "" so that // we generate a new one next time. If it fails, we'll preserve @@ -474,8 +505,8 @@ func (oq *destinationQueue) createTransaction( t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) t.TransactionID = oq.transactionID - var pduReceipts []*shared.Receipt - var eduReceipts []*shared.Receipt + var pduReceipts []*receipt.Receipt + var eduReceipts []*receipt.Receipt // Go through PDUs that we retrieved from the database, if any, // and add them into the transaction. @@ -487,7 +518,7 @@ func (oq *destinationQueue) createTransaction( // Append the JSON of the event, since this is a json.RawMessage type in the // gomatrixserverlib.Transaction struct t.PDUs = append(t.PDUs, pdu.pdu.JSON()) - pduReceipts = append(pduReceipts, pdu.receipt) + pduReceipts = append(pduReceipts, pdu.dbReceipt) } // Do the same for pending EDUS in the queue. @@ -497,7 +528,7 @@ func (oq *destinationQueue) createTransaction( continue } t.EDUs = append(t.EDUs, *edu.edu) - eduReceipts = append(eduReceipts, edu.receipt) + eduReceipts = append(eduReceipts, edu.dbReceipt) } return t, pduReceipts, eduReceipts @@ -530,10 +561,11 @@ func (oq *destinationQueue) blacklistDestination() { // handleTransactionSuccess updates the cached event queues as well as the success and // backoff information for this server. -func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) { +func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int, sendMethod statistics.SendMethod) { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. - oq.statistics.Success() + + oq.statistics.Success(sendMethod) oq.pendingMutex.Lock() defer oq.pendingMutex.Unlock() diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 75b1b36be..5d6b8d44c 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -30,7 +30,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -138,13 +138,13 @@ func NewOutgoingQueues( } type queuedPDU struct { - receipt *shared.Receipt - pdu *gomatrixserverlib.HeaderedEvent + dbReceipt *receipt.Receipt + pdu *gomatrixserverlib.HeaderedEvent } type queuedEDU struct { - receipt *shared.Receipt - edu *gomatrixserverlib.EDU + dbReceipt *receipt.Receipt + edu *gomatrixserverlib.EDU } func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { @@ -374,24 +374,13 @@ func (oqs *OutgoingQueues) SendEDU( return nil } -// IsServerBlacklisted returns whether or not the provided server is currently -// blacklisted. -func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool { - return oqs.statistics.ForServer(srv).Blacklisted() -} - // RetryServer attempts to resend events to the given server if we had given up. -func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { +func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName, wasBlacklisted bool) { if oqs.disabled { return } - serverStatistics := oqs.statistics.ForServer(srv) - forceWakeup := serverStatistics.Blacklisted() - serverStatistics.RemoveBlacklist() - serverStatistics.ClearBackoff() - if queue := oqs.getQueue(srv); queue != nil { - queue.wakeQueueIfEventsPending(forceWakeup) + queue.wakeQueueIfEventsPending(wasBlacklisted) } } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index c317edc21..36e2ccbc2 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - "sync" "testing" "time" @@ -26,13 +25,11 @@ import ( "gotest.tools/v3/poll" "github.com/matrix-org/gomatrixserverlib" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" @@ -57,7 +54,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } else { // Fake Database - db := createDatabase() + db := test.NewInMemoryFederationDatabase() b := struct { ProcessContext *process.ProcessContext }{ProcessContext: process.NewProcessContext()} @@ -65,220 +62,6 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } -func createDatabase() storage.Database { - return &fakeDatabase{ - pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), - pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), - associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - } -} - -type fakeDatabase struct { - storage.Database - dbMutex sync.Mutex - pendingPDUServers map[gomatrixserverlib.ServerName]struct{} - pendingEDUServers map[gomatrixserverlib.ServerName]struct{} - blacklistedServers map[gomatrixserverlib.ServerName]struct{} - pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent - pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU - associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} - associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} -} - -var nidMutex sync.Mutex -var nid = int64(0) - -func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal([]byte(js), &event); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingPDUs[&receipt] = &event - return &receipt, nil - } - - var edu gomatrixserverlib.EDU - if err := json.Unmarshal([]byte(js), &edu); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingEDUs[&receipt] = &edu - return &receipt, nil - } - - return nil, errors.New("Failed to determine type of json to store") -} - -func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - pduCount := 0 - pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) - if receipts, ok := d.associatedPDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingPDUs[receipt]; ok { - pdus[receipt] = event - pduCount++ - if pduCount == limit { - break - } - } - } - } - return pdus, nil -} - -func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - eduCount := 0 - edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) - if receipts, ok := d.associatedEDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingEDUs[receipt]; ok { - edus[receipt] = event - eduCount++ - if eduCount == limit { - break - } - } - } - } - return edus, nil -} - -func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingPDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedPDUs[destination]; !ok { - d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedPDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("PDU doesn't exist") - } -} - -func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingEDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedEDUs[destination]; !ok { - d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedEDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("EDU doesn't exist") - } -} - -func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if pdus, ok := d.associatedPDUs[serverName]; ok { - for _, receipt := range receipts { - delete(pdus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if edus, ok := d.associatedEDUs[serverName]; ok { - for _, receipt := range receipts { - delete(edus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingPDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingEDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers[serverName] = struct{}{} - return nil -} - -func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - delete(d.blacklistedServers, serverName) - return nil -} - -func (d *fakeDatabase) RemoveAllServersFromBlacklist() error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) - return nil -} - -func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - isBlacklisted := false - if _, ok := d.blacklistedServers[serverName]; ok { - isBlacklisted = true - } - - return isBlacklisted, nil -} - type stubFederationRoomServerAPI struct { rsapi.FederationRoomserverAPI } @@ -290,8 +73,10 @@ func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Cont type stubFederationClient struct { api.FederationClient - shouldTxSucceed bool - txCount atomic.Uint32 + shouldTxSucceed bool + shouldTxRelaySucceed bool + txCount atomic.Uint32 + txRelayCount atomic.Uint32 } func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { @@ -304,6 +89,16 @@ func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixse return gomatrixserverlib.RespSend{}, result } +func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) { + var result error + if !f.shouldTxRelaySucceed { + result = fmt.Errorf("relay transaction failed") + } + + f.txRelayCount.Add(1) + return gomatrixserverlib.EmptyResp{}, result +} + func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { t.Helper() content := `{"type":"m.room.message"}` @@ -319,15 +114,18 @@ func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} } -func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { +func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32, shouldTxSucceed bool, shouldTxRelaySucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase) fc := &stubFederationClient{ - shouldTxSucceed: shouldTxSucceed, - txCount: *atomic.NewUint32(0), + shouldTxSucceed: shouldTxSucceed, + shouldTxRelaySucceed: shouldTxRelaySucceed, + txCount: *atomic.NewUint32(0), + txRelayCount: *atomic.NewUint32(0), } rs := &stubFederationRoomServerAPI{} - stats := statistics.NewStatistics(db, failuresUntilBlacklist) + + stats := statistics.NewStatistics(db, failuresUntilBlacklist, failuresUntilAssumedOffline) signingInfo := []*gomatrixserverlib.SigningIdentity{ { KeyID: "ed21019:auto", @@ -344,7 +142,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -373,7 +171,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -402,7 +200,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -432,7 +230,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -462,7 +260,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -513,7 +311,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -564,7 +362,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -596,7 +394,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -628,7 +426,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -662,7 +460,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -696,7 +494,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -730,8 +528,8 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -747,7 +545,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -781,8 +579,8 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -801,7 +599,7 @@ func TestSendPDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -845,7 +643,7 @@ func TestSendEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -889,7 +687,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -940,7 +738,7 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -978,7 +776,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { destination := gomatrixserverlib.ServerName("remotehost") destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. defer close() defer func() { @@ -1023,8 +821,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) assert.NoError(t, dbErrPDU) @@ -1038,3 +836,147 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) }) } + +func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + failuresUntilAssumedOffline := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + relayServers := []gomatrixserverlib.ServerName{"relayserver"} + queues.statistics.ForServer(destination).AddRelayServers(relayServers) + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + if fc.txRelayCount.Load() == 1 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load()) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) + assert.Equal(t, true, assumedOffline) +} + +func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + failuresUntilAssumedOffline := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + relayServers := []gomatrixserverlib.ServerName{"relayserver"} + queues.statistics.ForServer(destination).AddRelayServers(relayServers) + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + if fc.txRelayCount.Load() == 1 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load()) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) + assert.Equal(t, true, assumedOffline) +} diff --git a/federationapi/routing/profile_test.go b/federationapi/routing/profile_test.go new file mode 100644 index 000000000..763656081 --- /dev/null +++ b/federationapi/routing/profile_test.go @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/hex" + "io" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" +) + +type fakeUserAPI struct { + userAPI.FederationUserAPI +} + +func (u *fakeUserAPI) QueryProfile(ctx context.Context, req *userAPI.QueryProfileRequest, res *userAPI.QueryProfileResponse) error { + return nil +} + +func TestHandleQueryProfile(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedClient := fakeFedClient{} + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + userapi := fakeUserAPI{} + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin))) + type queryContent struct{} + content := queryContent{} + err := req.SetContent(content) + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + // vars := map[string]string{"room_alias": "#room:server"} + w := httptest.NewRecorder() + // httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + data, _ := io.ReadAll(res.Body) + println(string(data)) + assert.Equal(t, 200, res.StatusCode) + }) +} diff --git a/federationapi/routing/query_test.go b/federationapi/routing/query_test.go new file mode 100644 index 000000000..21f35bf0c --- /dev/null +++ b/federationapi/routing/query_test.go @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/hex" + "io" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedclient "github.com/matrix-org/dendrite/federationapi/api" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" +) + +type fakeFedClient struct { + fedclient.FederationClient +} + +func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { + return +} + +func TestHandleQueryDirectory(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedClient := fakeFedClient{} + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + userapi := fakeUserAPI{} + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server")) + type queryContent struct{} + content := queryContent{} + err := req.SetContent(content) + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + // vars := map[string]string{"room_alias": "#room:server"} + w := httptest.NewRecorder() + // httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + data, _ := io.ReadAll(res.Body) + println(string(data)) + assert.Equal(t, 200, res.StatusCode) + }) +} diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 04eb3d067..5eb30c6ec 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -41,6 +41,12 @@ import ( "github.com/sirupsen/logrus" ) +const ( + SendRouteName = "Send" + QueryDirectoryRouteName = "QueryDirectory" + QueryProfileRouteName = "QueryProfile" +) + // Setup registers HTTP handlers with the given ServeMux. // The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly // path unescape twice (once from the router, once from MakeFedAPI). We need to have this enabled @@ -68,7 +74,7 @@ func Setup( if base.EnableMetrics { prometheus.MustRegister( - pduCountTotal, eduCountTotal, + internal.PDUCountTotal, internal.EDUCountTotal, ) } @@ -138,7 +144,7 @@ func Setup( cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer, ) }, - )).Methods(http.MethodPut, http.MethodOptions) + )).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName) v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, @@ -248,7 +254,7 @@ func Setup( httpReq, federation, cfg, rsAPI, fsAPI, ) }, - )).Methods(http.MethodGet) + )).Methods(http.MethodGet).Name(QueryDirectoryRouteName) v1fedmux.Handle("/query/profile", MakeFedAPI( "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, @@ -257,7 +263,7 @@ func Setup( httpReq, userAPI, cfg, ) }, - )).Methods(http.MethodGet) + )).Methods(http.MethodGet).Name(QueryProfileRouteName) v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index a146d85bd..67b513c90 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -17,26 +17,20 @@ package routing import ( "context" "encoding/json" - "fmt" "net/http" "sync" "time" - "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" - "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" - syncTypes "github.com/matrix-org/dendrite/syncapi/types" ) const ( @@ -56,26 +50,6 @@ const ( MetricsWorkMissingPrevEvents = "missing_prev_events" ) -var ( - pduCountTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "recv_pdus", - Help: "Number of incoming PDUs from remote servers with labels for success", - }, - []string{"status"}, // 'success' or 'total' - ) - eduCountTotal = prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "recv_edus", - Help: "Number of incoming EDUs from remote servers", - }, - ) -) - var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse // Send implements /_matrix/federation/v1/send/{txnID} @@ -123,18 +97,6 @@ func Send( defer close(ch) defer inFlightTxnsPerOrigin.Delete(index) - t := txnReq{ - rsAPI: rsAPI, - keys: keys, - ourServerName: cfg.Matrix.ServerName, - federation: federation, - servers: servers, - keyAPI: keyAPI, - roomsMu: mu, - producer: producer, - inboundPresenceEnabled: cfg.Matrix.Presence.EnableInbound, - } - var txnEvents struct { PDUs []json.RawMessage `json:"pdus"` EDUs []gomatrixserverlib.EDU `json:"edus"` @@ -155,16 +117,23 @@ func Send( } } - // TODO: Really we should have a function to convert FederationRequest to txnReq - t.PDUs = txnEvents.PDUs - t.EDUs = txnEvents.EDUs - t.Origin = request.Origin() - t.TransactionID = txnID - t.Destination = cfg.Matrix.ServerName + t := internal.NewTxnReq( + rsAPI, + keyAPI, + cfg.Matrix.ServerName, + keys, + mu, + producer, + cfg.Matrix.Presence.EnableInbound, + txnEvents.PDUs, + txnEvents.EDUs, + request.Origin(), + txnID, + cfg.Matrix.ServerName) util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) - resp, jsonErr := t.processTransaction(httpReq.Context()) + resp, jsonErr := t.ProcessTransaction(httpReq.Context()) if jsonErr != nil { util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") return *jsonErr @@ -181,283 +150,3 @@ func Send( ch <- res return res } - -type txnReq struct { - gomatrixserverlib.Transaction - rsAPI api.FederationRoomserverAPI - keyAPI keyapi.FederationKeyAPI - ourServerName gomatrixserverlib.ServerName - keys gomatrixserverlib.JSONVerifier - federation txnFederationClient - roomsMu *internal.MutexByRoom - servers federationAPI.ServersInRoomProvider - producer *producers.SyncAPIProducer - inboundPresenceEnabled bool -} - -// A subset of FederationClient functionality that txn requires. Useful for testing. -type txnFederationClient interface { - LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - res gomatrixserverlib.RespState, err error, - ) - LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, - roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) -} - -func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - t.processEDUs(ctx) - }() - - results := make(map[string]gomatrixserverlib.PDUResult) - roomVersions := make(map[string]gomatrixserverlib.RoomVersion) - getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { - if v, ok := roomVersions[roomID]; ok { - return v - } - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) - return "" - } - roomVersions[roomID] = verRes.RoomVersion - return verRes.RoomVersion - } - - for _, pdu := range t.PDUs { - pduCountTotal.WithLabelValues("total").Inc() - var header struct { - RoomID string `json:"room_id"` - } - if err := json.Unmarshal(pdu, &header); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") - // We don't know the event ID at this point so we can't return the - // failure in the PDU results - continue - } - roomVersion := getRoomVersion(header.RoomID) - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) - if err != nil { - if _, ok := err.(gomatrixserverlib.BadJSONError); ok { - // Room version 6 states that homeservers should strictly enforce canonical JSON - // on PDUs. - // - // This enforces that the entire transaction is rejected if a single bad PDU is - // sent. It is unclear if this is the correct behaviour or not. - // - // See https://github.com/matrix-org/synapse/issues/7543 - return nil, &util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON("PDU contains bad JSON"), - } - } - util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) - continue - } - if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { - continue - } - if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: "Forbidden by server ACLs", - } - continue - } - if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - continue - } - - // pass the event to the roomserver which will do auth checks - // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently - // discarded by the caller of this function - if err = api.SendEvents( - ctx, - t.rsAPI, - api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - event.Headered(roomVersion), - }, - t.Destination, - t.Origin, - api.DoNotSendToOtherServers, - nil, - true, - ); err != nil { - util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - continue - } - - results[event.EventID()] = gomatrixserverlib.PDUResult{} - pduCountTotal.WithLabelValues("success").Inc() - } - - wg.Wait() - return &gomatrixserverlib.RespSend{PDUs: results}, nil -} - -// nolint:gocyclo -func (t *txnReq) processEDUs(ctx context.Context) { - for _, e := range t.EDUs { - eduCountTotal.Inc() - switch e.Type { - case gomatrixserverlib.MTyping: - // https://matrix.org/docs/spec/server_server/latest#typing-notifications - var typingPayload struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` - Typing bool `json:"typing"` - } - if err := json.Unmarshal(e.Content, &typingPayload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") - continue - } - if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") - } - case gomatrixserverlib.MDirectToDevice: - // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema - var directPayload gomatrixserverlib.ToDeviceMessage - if err := json.Unmarshal(e.Content, &directPayload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") - continue - } - if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - for userID, byUser := range directPayload.Messages { - for deviceID, message := range byUser { - // TODO: check that the user and the device actually exist here - if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { - sentry.CaptureException(err) - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "sender": directPayload.Sender, - "user_id": userID, - "device_id": deviceID, - }).Error("Failed to send send-to-device event to JetStream") - } - } - } - case gomatrixserverlib.MDeviceListUpdate: - if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { - sentry.CaptureException(err) - util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") - } - case gomatrixserverlib.MReceipt: - // https://matrix.org/docs/spec/server_server/r0.1.4#receipts - payload := map[string]types.FederationReceiptMRead{} - - if err := json.Unmarshal(e.Content, &payload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") - continue - } - - for roomID, receipt := range payload { - for userID, mread := range receipt.User { - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") - continue - } - if t.Origin != domain { - util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) - continue - } - if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "sender": t.Origin, - "user_id": userID, - "room_id": roomID, - "events": mread.EventIDs, - }).Error("Failed to send receipt event to JetStream") - continue - } - } - } - case types.MSigningKeyUpdate: - if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil { - sentry.CaptureException(err) - logrus.WithError(err).Errorf("Failed to process signing key update") - } - case gomatrixserverlib.MPresence: - if t.inboundPresenceEnabled { - if err := t.processPresence(ctx, e); err != nil { - logrus.WithError(err).Errorf("Failed to process presence update") - } - } - default: - util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") - } - } -} - -// processPresence handles m.receipt events -func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error { - payload := types.Presence{} - if err := json.Unmarshal(e.Content, &payload); err != nil { - return err - } - for _, content := range payload.Push { - if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - presence, ok := syncTypes.PresenceFromString(content.Presence) - if !ok { - continue - } - if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil { - return err - } - } - return nil -} - -// processReceiptEvent sends receipt events to JetStream -func (t *txnReq) processReceiptEvent(ctx context.Context, - userID, roomID, receiptType string, - timestamp gomatrixserverlib.Timestamp, - eventIDs []string, -) error { - if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { - return nil - } else if serverName == t.ourServerName { - return nil - } else if serverName != t.Origin { - return nil - } - // store every event - for _, eventID := range eventIDs { - if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { - return fmt.Errorf("unable to set receipt event: %w", err) - } - } - - return nil -} diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index b8bfe0221..d7feee0e5 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -1,552 +1,87 @@ -package routing +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test import ( - "context" + "encoding/hex" "encoding/json" - "fmt" + "net/http/httptest" "testing" - "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/api" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") - testDestination = gomatrixserverlib.ServerName("white.orchard") + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") ) -var ( - testRoomVersion = gomatrixserverlib.RoomVersionV1 - testData = []json.RawMessage{ - []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), - // messages - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), - } - testEvents = []*gomatrixserverlib.HeaderedEvent{} - testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) -) +type sendContent struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` +} -func init() { - for _, j := range testData { - e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) +func TestHandleSend(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedapi := fedAPI.NewInternalAPI(base, nil, nil, nil, nil, true) + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, nil, nil, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234") + content := sendContent{} + err := req.SetContent(content) if err != nil { - panic("cannot load test data: " + err.Error()) + t.Fatalf("Error: %s", err.Error()) } - h := e.Headered(testRoomVersion) - testEvents = append(testEvents, h) - if e.StateKey() != nil { - testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: e.Type(), - StateKey: *e.StateKey(), - }] = h + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) } - } + vars := map[string]string{"txnID": "1234"} + w := httptest.NewRecorder() + httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + assert.Equal(t, 200, res.StatusCode) + }) } - -type testRoomserverAPI struct { - api.RoomserverInternalAPITrace - inputRoomEvents []api.InputRoomEvent - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse -} - -func (t *testRoomserverAPI) InputRoomEvents( - ctx context.Context, - request *api.InputRoomEventsRequest, - response *api.InputRoomEventsResponse, -) error { - t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) - for _, ire := range request.InputRoomEvents { - fmt.Println("InputRoomEvents: ", ire.Event.EventID()) - } - return nil -} - -// Query the latest events and state for a room from the room server. -func (t *testRoomserverAPI) QueryLatestEventsAndState( - ctx context.Context, - request *api.QueryLatestEventsAndStateRequest, - response *api.QueryLatestEventsAndStateResponse, -) error { - r := t.queryLatestEventsAndState(request) - response.RoomExists = r.RoomExists - response.RoomVersion = testRoomVersion - response.LatestEvents = r.LatestEvents - response.StateEvents = r.StateEvents - response.Depth = r.Depth - return nil -} - -// Query the state after a list of events in a room from the room server. -func (t *testRoomserverAPI) QueryStateAfterEvents( - ctx context.Context, - request *api.QueryStateAfterEventsRequest, - response *api.QueryStateAfterEventsResponse, -) error { - response.RoomVersion = testRoomVersion - res := t.queryStateAfterEvents(request) - response.PrevEventsExist = res.PrevEventsExist - response.RoomExists = res.RoomExists - response.StateEvents = res.StateEvents - return nil -} - -// Query a list of events by event ID. -func (t *testRoomserverAPI) QueryEventsByID( - ctx context.Context, - request *api.QueryEventsByIDRequest, - response *api.QueryEventsByIDResponse, -) error { - res := t.queryEventsByID(request) - response.Events = res.Events - return nil -} - -// Query if a server is joined to a room -func (t *testRoomserverAPI) QueryServerJoinedToRoom( - ctx context.Context, - request *api.QueryServerJoinedToRoomRequest, - response *api.QueryServerJoinedToRoomResponse, -) error { - response.RoomExists = true - response.IsInRoom = true - return nil -} - -// Asks for the room version for a given room. -func (t *testRoomserverAPI) QueryRoomVersionForRoom( - ctx context.Context, - request *api.QueryRoomVersionForRoomRequest, - response *api.QueryRoomVersionForRoomResponse, -) error { - response.RoomVersion = testRoomVersion - return nil -} - -func (t *testRoomserverAPI) QueryServerBannedFromRoom( - ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, -) error { - res.Banned = false - return nil -} - -type txnFedClient struct { - state map[string]gomatrixserverlib.RespState // event_id to response - stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response - getEvent map[string]gomatrixserverlib.Transaction // event_id to response - getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) -} - -func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - res gomatrixserverlib.RespState, err error, -) { - fmt.Println("testFederationClient.LookupState", eventID) - r, ok := c.state[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /state for event %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { - fmt.Println("testFederationClient.LookupStateIDs", eventID) - r, ok := c.stateIDs[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { - fmt.Println("testFederationClient.GetEvent", eventID) - r, ok := c.getEvent[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, - roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { - return c.getMissingEvents(missing) -} - -func mustCreateTransaction(rsAPI api.FederationRoomserverAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { - t := &txnReq{ - rsAPI: rsAPI, - keys: &test.NopJSONVerifier{}, - federation: fedClient, - roomsMu: internal.NewMutexByRoom(), - } - t.PDUs = pdus - t.Origin = testOrigin - t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - t.Destination = testDestination - return t -} - -func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) { - res, err := txn.processTransaction(context.Background()) - if err != nil { - t.Errorf("txn.processTransaction returned an error: %v", err) - return - } - if len(res.PDUs) != len(txn.PDUs) { - t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) - return - } -NextPDU: - for eventID, result := range res.PDUs { - if result.Error == "" { - continue - } - for _, eventIDWantError := range pdusWithErrors { - if eventID == eventIDWantError { - break NextPDU - } - } - t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) - } -} - -/* -func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []*gomatrixserverlib.HeaderedEvent) { -NextTuple: - for _, t := range tuples { - for _, o := range omitTuples { - if t == o { - break NextTuple - } - } - h, ok := testStateEvents[t] - if ok { - result = append(result, h) - } - } - return -} -*/ - -func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { - for _, g := range got { - fmt.Println("GOT ", g.Event.EventID()) - } - if len(got) != len(want) { - t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) - return - } - for i := range got { - if got[i].Event.EventID() != want[i].EventID() { - t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) - } - } -} - -// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on -// to the roomserver. It's the most basic test possible. -func TestBasicTransaction(t *testing.T) { - rsAPI := &testRoomserverAPI{} - pdus := []json.RawMessage{ - testData[len(testData)-1], // a message event - } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) -} - -// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver -// as it does the auth check. -func TestTransactionFailAuthChecks(t *testing.T) { - rsAPI := &testRoomserverAPI{} - pdus := []json.RawMessage{ - testData[len(testData)-1], // a message event - } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, []string{}) - // expect message to be sent to the roomserver - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) -} - -// The purpose of this test is to make sure that when an event is received for which we do not know the prev_events, -// we request them from /get_missing_events. It works by setting PrevEventsExist=false in the roomserver query response, -// resulting in a call to /get_missing_events which returns the missing prev event. Both events should be processed in -// topological order and sent to the roomserver. -/* -func TestTransactionFetchMissingPrevEvents(t *testing.T) { - haveEvent := testEvents[len(testEvents)-3] - prevEvent := testEvents[len(testEvents)-2] - inputEvent := testEvents[len(testEvents)-1] - - var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions - rsAPI = &testRoomserverAPI{ - queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { - res := api.QueryEventsByIDResponse{} - for _, ev := range testEvents { - for _, id := range req.EventIDs { - if ev.EventID() == id { - res.Events = append(res.Events, ev) - } - } - } - return res - }, - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: true, - StateEvents: testEvents[:5], - } - }, - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - missingPrevEvent := []string{"missing_prev_event"} - if len(req.PrevEventIDs) == 1 { - switch req.PrevEventIDs[0] { - case haveEvent.EventID(): - missingPrevEvent = []string{} - case prevEvent.EventID(): - // we only have this event if we've been send prevEvent - if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() { - missingPrevEvent = []string{} - } - } - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, - queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { - return api.QueryLatestEventsAndStateResponse{ - RoomExists: true, - Depth: haveEvent.Depth(), - LatestEvents: []gomatrixserverlib.EventReference{ - haveEvent.EventReference(), - }, - StateEvents: fromStateTuples(req.StateToFetch, nil), - } - }, - } - - cli := &txnFedClient{ - getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { - if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) { - t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, haveEvent.EventID()) - } - if !reflect.DeepEqual(missing.LatestEvents, []string{inputEvent.EventID()}) { - t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, inputEvent.EventID()) - } - return gomatrixserverlib.RespMissingEvents{ - Events: []*gomatrixserverlib.Event{ - prevEvent.Unwrap(), - }, - }, nil - }, - } - - pdus := []json.RawMessage{ - inputEvent.JSON(), - } - txn := mustCreateTransaction(rsAPI, cli, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent}) -} - -// The purpose of this test is to check that when there are missing prev_events and we still haven't been able to fill -// in the hole with /get_missing_events that the state BEFORE the events we want to persist is fetched via /state_ids -// and /event. It works by setting PrevEventsExist=false in the roomserver query response, resulting in -// a call to /get_missing_events which returns 1 out of the 2 events it needs to fill in the gap. Synapse and Dendrite -// both give up after 1x /get_missing_events call, relying on requesting the state AFTER the missing event in order to -// continue. The DAG looks something like: -// FE GME TXN -// A ---> B ---> C ---> D -// TXN=event in the txn, GME=response to /get_missing_events, FE=roomserver's forward extremity. Should result in: -// - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event. -// - state resolution is done to check C is allowed. -// This results in B being sent as an outlier FIRST, then C,D. -func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { - eventA := testEvents[len(testEvents)-5] - // this is also len(testEvents)-4 - eventB := testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }] - eventC := testEvents[len(testEvents)-3] - eventD := testEvents[len(testEvents)-2] - fmt.Println("a:", eventA.EventID()) - fmt.Println("b:", eventB.EventID()) - fmt.Println("c:", eventC.EventID()) - fmt.Println("d:", eventD.EventID()) - var rsAPI *testRoomserverAPI - rsAPI = &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - omitTuples := []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }, - } - askingForEvent := req.PrevEventIDs[0] - haveEventB := false - haveEventC := false - for _, ev := range rsAPI.inputRoomEvents { - switch ev.Event.EventID() { - case eventB.EventID(): - haveEventB = true - omitTuples = nil // include event B now - case eventC.EventID(): - haveEventC = true - } - } - prevEventExists := false - if askingForEvent == eventC.EventID() { - prevEventExists = haveEventC - } else if askingForEvent == eventB.EventID() { - prevEventExists = haveEventB - } - var stateEvents []*gomatrixserverlib.HeaderedEvent - if prevEventExists { - stateEvents = fromStateTuples(req.StateToFetch, omitTuples) - } - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: prevEventExists, - RoomExists: true, - StateEvents: stateEvents, - } - }, - - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - askingForEvent := req.PrevEventIDs[0] - haveEventB := false - haveEventC := false - for _, ev := range rsAPI.inputRoomEvents { - switch ev.Event.EventID() { - case eventB.EventID(): - haveEventB = true - case eventC.EventID(): - haveEventC = true - } - } - prevEventExists := false - if askingForEvent == eventC.EventID() { - prevEventExists = haveEventC - } else if askingForEvent == eventB.EventID() { - prevEventExists = haveEventB - } - - var missingPrevEvent []string - if !prevEventExists { - missingPrevEvent = []string{"test"} - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, - - queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { - omitTuples := []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}, - } - return api.QueryLatestEventsAndStateResponse{ - RoomExists: true, - Depth: eventA.Depth(), - LatestEvents: []gomatrixserverlib.EventReference{ - eventA.EventReference(), - }, - StateEvents: fromStateTuples(req.StateToFetch, omitTuples), - } - }, - queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { - var res api.QueryEventsByIDResponse - fmt.Println("queryEventsByID ", req.EventIDs) - for _, wantEventID := range req.EventIDs { - for _, ev := range testStateEvents { - // roomserver is missing the power levels event unless it's been sent to us recently as an outlier - if wantEventID == eventB.EventID() { - fmt.Println("Asked for pl event") - for _, inEv := range rsAPI.inputRoomEvents { - fmt.Println("recv ", inEv.Event.EventID()) - if inEv.Event.EventID() == wantEventID { - res.Events = append(res.Events, inEv.Event) - break - } - } - continue - } - if ev.EventID() == wantEventID { - res.Events = append(res.Events, ev) - } - } - } - return res - }, - } - // /state_ids for event B returns every state event but B (it's the state before) - var authEventIDs []string - var stateEventIDs []string - for _, ev := range testStateEvents { - if ev.EventID() == eventB.EventID() { - continue - } - // state res checks what auth events you give it, and this isn't a valid auth event - if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility { - authEventIDs = append(authEventIDs, ev.EventID()) - } - stateEventIDs = append(stateEventIDs, ev.EventID()) - } - cli := &txnFedClient{ - stateIDs: map[string]gomatrixserverlib.RespStateIDs{ - eventB.EventID(): { - StateEventIDs: stateEventIDs, - AuthEventIDs: authEventIDs, - }, - }, - // /event for event B returns it - getEvent: map[string]gomatrixserverlib.Transaction{ - eventB.EventID(): { - PDUs: []json.RawMessage{ - eventB.JSON(), - }, - }, - }, - // /get_missing_events should be done exactly once - getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { - if !reflect.DeepEqual(missing.EarliestEvents, []string{eventA.EventID()}) { - t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, eventA.EventID()) - } - if !reflect.DeepEqual(missing.LatestEvents, []string{eventD.EventID()}) { - t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, eventD.EventID()) - } - // just return event C, not event B so /state_ids logic kicks in as there will STILL be missing prev_events - return gomatrixserverlib.RespMissingEvents{ - Events: []*gomatrixserverlib.Event{ - eventC.Unwrap(), - }, - }, nil - }, - } - - pdus := []json.RawMessage{ - eventD.JSON(), - } - txn := mustCreateTransaction(rsAPI, cli, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) -} -*/ diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 0a44375c6..866c09336 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -1,6 +1,7 @@ package statistics import ( + "context" "math" "math/rand" "sync" @@ -28,14 +29,24 @@ type Statistics struct { // just blacklist the host altogether? The backoff is exponential, // so the max time here to attempt is 2**failures seconds. FailuresUntilBlacklist uint32 + + // How many times should we tolerate consecutive failures before we + // mark the destination as offline. At this point we should attempt + // to send messages to the user's async relay servers if we know them. + FailuresUntilAssumedOffline uint32 } -func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics { +func NewStatistics( + db storage.Database, + failuresUntilBlacklist uint32, + failuresUntilAssumedOffline uint32, +) Statistics { return Statistics{ - DB: db, - FailuresUntilBlacklist: failuresUntilBlacklist, - backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), - servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), + DB: db, + FailuresUntilBlacklist: failuresUntilBlacklist, + FailuresUntilAssumedOffline: failuresUntilAssumedOffline, + backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), } } @@ -50,8 +61,9 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS if !found { s.mutex.Lock() server = &ServerStatistics{ - statistics: s, - serverName: serverName, + statistics: s, + serverName: serverName, + knownRelayServers: []gomatrixserverlib.ServerName{}, } s.servers[serverName] = server s.mutex.Unlock() @@ -61,24 +73,49 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS } else { server.blacklisted.Store(blacklisted) } + assumedOffline, err := s.DB.IsServerAssumedOffline(context.Background(), serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName) + } else { + server.assumedOffline.Store(assumedOffline) + } + + knownRelayServers, err := s.DB.P2PGetRelayServersForServer(context.Background(), serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName) + } else { + server.relayMutex.Lock() + server.knownRelayServers = knownRelayServers + server.relayMutex.Unlock() + } } return server } +type SendMethod uint8 + +const ( + SendDirect SendMethod = iota + SendViaRelay +) + // ServerStatistics contains information about our interactions with a // remote federated host, e.g. how many times we were successful, how // many times we failed etc. It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - successCounter atomic.Uint32 // how many times have we succeeded? - backoffNotifier func() // notifies destination queue when backoff completes - notifierMutex sync.Mutex + statistics *Statistics // + serverName gomatrixserverlib.ServerName // + blacklisted atomic.Bool // is the node blacklisted + assumedOffline atomic.Bool // is the node assumed to be offline + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes + notifierMutex sync.Mutex + knownRelayServers []gomatrixserverlib.ServerName + relayMutex sync.Mutex } const maxJitterMultiplier = 1.4 @@ -113,13 +150,19 @@ func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) { // attempt, which increases the sent counter and resets the idle and // failure counters. If a host was blacklisted at this point then // we will unblacklist it. -func (s *ServerStatistics) Success() { +// `relay` specifies whether the success was to the actual destination +// or one of their relay servers. +func (s *ServerStatistics) Success(method SendMethod) { s.cancel() s.backoffCount.Store(0) - s.successCounter.Inc() - if s.statistics.DB != nil { - if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + // NOTE : Sending to the final destination vs. a relay server has + // slightly different semantics. + if method == SendDirect { + s.successCounter.Inc() + if s.blacklisted.Load() && s.statistics.DB != nil { + if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + } } } } @@ -139,7 +182,18 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { // start a goroutine which will wait out the backoff and // unset the backoffStarted flag when done. if s.backoffStarted.CompareAndSwap(false, true) { - if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist { + backoffCount := s.backoffCount.Inc() + + if backoffCount >= s.statistics.FailuresUntilAssumedOffline { + s.assumedOffline.CompareAndSwap(false, true) + if s.statistics.DB != nil { + if err := s.statistics.DB.SetServerAssumedOffline(context.Background(), s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName) + } + } + } + + if backoffCount >= s.statistics.FailuresUntilBlacklist { s.blacklisted.Store(true) if s.statistics.DB != nil { if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { @@ -157,13 +211,21 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { s.backoffUntil.Store(until) s.statistics.backoffMutex.Lock() - defer s.statistics.backoffMutex.Unlock() s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished) + s.statistics.backoffMutex.Unlock() } return s.backoffUntil.Load().(time.Time), false } +// MarkServerAlive removes the assumed offline and blacklisted statuses from this server. +// Returns whether the server was blacklisted before this point. +func (s *ServerStatistics) MarkServerAlive() bool { + s.removeAssumedOffline() + wasBlacklisted := s.removeBlacklist() + return wasBlacklisted +} + // ClearBackoff stops the backoff timer for this destination if it is running // and removes the timer from the backoffTimers map. func (s *ServerStatistics) ClearBackoff() { @@ -191,13 +253,13 @@ func (s *ServerStatistics) backoffFinished() { } // BackoffInfo returns information about the current or previous backoff. -// Returns the last backoffUntil time and whether the server is currently blacklisted or not. -func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) { +// Returns the last backoffUntil time. +func (s *ServerStatistics) BackoffInfo() *time.Time { until, ok := s.backoffUntil.Load().(time.Time) if ok { - return &until, s.blacklisted.Load() + return &until } - return nil, s.blacklisted.Load() + return nil } // Blacklisted returns true if the server is blacklisted and false @@ -206,10 +268,33 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } -// RemoveBlacklist removes the blacklisted status from the server. -func (s *ServerStatistics) RemoveBlacklist() { +// AssumedOffline returns true if the server is assumed offline and false +// otherwise. +func (s *ServerStatistics) AssumedOffline() bool { + return s.assumedOffline.Load() +} + +// removeBlacklist removes the blacklisted status from the server. +// Returns whether the server was blacklisted. +func (s *ServerStatistics) removeBlacklist() bool { + var wasBlacklisted bool + + if s.Blacklisted() { + wasBlacklisted = true + _ = s.statistics.DB.RemoveServerFromBlacklist(s.serverName) + } s.cancel() s.backoffCount.Store(0) + + return wasBlacklisted +} + +// removeAssumedOffline removes the assumed offline status from the server. +func (s *ServerStatistics) removeAssumedOffline() { + if s.AssumedOffline() { + _ = s.statistics.DB.RemoveServerAssumedOffline(context.Background(), s.serverName) + } + s.assumedOffline.Store(false) } // SuccessCount returns the number of successful requests. This is @@ -217,3 +302,46 @@ func (s *ServerStatistics) RemoveBlacklist() { func (s *ServerStatistics) SuccessCount() uint32 { return s.successCounter.Load() } + +// KnownRelayServers returns the list of relay servers associated with this +// server. +func (s *ServerStatistics) KnownRelayServers() []gomatrixserverlib.ServerName { + s.relayMutex.Lock() + defer s.relayMutex.Unlock() + return s.knownRelayServers +} + +func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.ServerName) { + seenSet := make(map[gomatrixserverlib.ServerName]bool) + uniqueList := []gomatrixserverlib.ServerName{} + for _, srv := range relayServers { + if seenSet[srv] { + continue + } + seenSet[srv] = true + uniqueList = append(uniqueList, srv) + } + + err := s.statistics.DB.P2PAddRelayServersForServer(context.Background(), s.serverName, uniqueList) + if err != nil { + logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList) + return + } + + for _, newServer := range uniqueList { + alreadyKnown := false + knownRelayServers := s.KnownRelayServers() + for _, srv := range knownRelayServers { + if srv == newServer { + alreadyKnown = true + } + } + if !alreadyKnown { + { + s.relayMutex.Lock() + s.knownRelayServers = append(s.knownRelayServers, newServer) + s.relayMutex.Unlock() + } + } + } +} diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 6aa997f44..183b9aa0c 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -4,17 +4,26 @@ import ( "math" "testing" "time" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + FailuresUntilAssumedOffline = 3 + FailuresUntilBlacklist = 8 ) func TestBackoff(t *testing.T) { - stats := NewStatistics(nil, 7) + stats := NewStatistics(nil, FailuresUntilBlacklist, FailuresUntilAssumedOffline) server := ServerStatistics{ statistics: &stats, serverName: "test.com", } // Start by checking that counting successes works. - server.Success() + server.Success(SendDirect) if successes := server.SuccessCount(); successes != 1 { t.Fatalf("Expected success count 1, got %d", successes) } @@ -31,9 +40,8 @@ func TestBackoff(t *testing.T) { // side effects since a backoff is already in progress. If it does // then we'll fail. until, blacklisted := server.Failure() - - // Get the duration. - _, blacklist := server.BackoffInfo() + blacklist := server.Blacklisted() + assumedOffline := server.AssumedOffline() duration := time.Until(until) // Unset the backoff, or otherwise our next call will think that @@ -41,16 +49,43 @@ func TestBackoff(t *testing.T) { server.cancel() server.backoffStarted.Store(false) + if i >= stats.FailuresUntilAssumedOffline { + if !assumedOffline { + t.Fatalf("Backoff %d should have resulted in assuming the destination was offline but didn't", i) + } + } + + // Check if we should be assumed offline by now. + if i >= stats.FailuresUntilAssumedOffline { + if !assumedOffline { + t.Fatalf("Backoff %d should have resulted in assumed offline but didn't", i) + } else { + t.Logf("Backoff %d is assumed offline as expected", i) + } + } else { + if assumedOffline { + t.Fatalf("Backoff %d should not have resulted in assumed offline but did", i) + } else { + t.Logf("Backoff %d is not assumed offline as expected", i) + } + } + // Check if we should be blacklisted by now. if i >= stats.FailuresUntilBlacklist { if !blacklist { t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) } else if blacklist != blacklisted { - t.Fatalf("BackoffInfo and Failure returned different blacklist values") + t.Fatalf("Blacklisted and Failure returned different blacklist values") } else { t.Logf("Backoff %d is blacklisted as expected", i) continue } + } else { + if blacklist { + t.Fatalf("Backoff %d should not have resulted in blacklist but did", i) + } else { + t.Logf("Backoff %d is not blacklisted as expected", i) + } } // Check if the duration is what we expect. @@ -69,3 +104,14 @@ func TestBackoff(t *testing.T) { } } } + +func TestRelayServersListing(t *testing.T) { + stats := NewStatistics(test.NewInMemoryFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline) + server := ServerStatistics{statistics: &stats} + server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + relayServers := server.KnownRelayServers() + assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) + server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + relayServers = server.KnownRelayServers() + assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) +} diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 2b4d905fc..4f5300af1 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -20,11 +20,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/types" ) type Database interface { + P2PDatabase gomatrixserverlib.KeyDatabase UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) @@ -34,16 +35,16 @@ type Database interface { // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) - StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) + StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error) - GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) - GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) + GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) + GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error - AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error - CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error - CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error + CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error + CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) @@ -54,6 +55,18 @@ type Database interface { RemoveAllServersFromBlacklist() error IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) + // Adds the server to the list of assumed offline servers. + // If the server already exists in the table, nothing happens and returns success. + SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + // Removes the server from the list of assumed offline servers. + // If the server doesn't exist in the table, nothing happens and returns success. + RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + // Purges all entries from the assumed offline table. + RemoveAllServersAssumedOffline(ctx context.Context) error + // Gets whether the provided server is present in the table. + // If it is present, returns true. If not, returns false. + IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error) + AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) @@ -74,3 +87,21 @@ type Database interface { PurgeRoom(ctx context.Context, roomID string) error } + +type P2PDatabase interface { + // Stores the given list of servers as relay servers for the provided destination server. + // Providing duplicates will only lead to a single entry and won't lead to an error. + P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Get the list of relay servers associated with the provided destination server. + // If no entry exists in the table, an empty list is returned and does not result in an error. + P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + + // Deletes any entries for the provided destination server that match the provided relayServers list. + // If any of the provided servers don't match an entry, nothing happens and no error is returned. + P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Deletes all entries for the provided destination server. + // If the destination server doesn't exist in the table, nothing happens and no error is returned. + P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error +} diff --git a/federationapi/storage/postgres/assumed_offline_table.go b/federationapi/storage/postgres/assumed_offline_table.go new file mode 100644 index 000000000..5695d2e54 --- /dev/null +++ b/federationapi/storage/postgres/assumed_offline_table.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT PRIMARY KEY NOT NULL +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "TRUNCATE federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL}, + {&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL}, + {&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL}, + {&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL}, + }.Prepare(db) +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/postgres/relay_servers_table.go b/federationapi/storage/postgres/relay_servers_table.go new file mode 100644 index 000000000..f7267978f --- /dev/null +++ b/federationapi/storage/postgres/relay_servers_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayServersSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_relay_servers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The relay server name for a given destination + relay_server_name TEXT NOT NULL, + UNIQUE (server_name, relay_server_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx + ON federationsender_relay_servers (server_name); +` + +const insertRelayServersSQL = "" + + "INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectRelayServersSQL = "" + + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" + +const deleteRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)" + +const deleteAllRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" + +type relayServersStatements struct { + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + deleteRelayServersStmt *sql.Stmt + deleteAllRelayServersStmt *sql.Stmt +} + +func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) { + s = &relayServersStatements{ + db: db, + } + _, err = db.Exec(relayServersSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertRelayServersStmt, insertRelayServersSQL}, + {&s.selectRelayServersStmt, selectRelayServersSQL}, + {&s.deleteRelayServersStmt, deleteRelayServersSQL}, + {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, + }.Prepare(db) +} + +func (s *relayServersStatements) InsertRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + for _, relayServer := range relayServers { + stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { + return err + } + } + return nil +} + +func (s *relayServersStatements) SelectRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var relayServer string + if err = rows.Scan(&relayServer); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(relayServer)) + } + return result, nil +} + +func (s *relayServersStatements) DeleteRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers)) + return err +} + +func (s *relayServersStatements) DeleteAllRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index fe84e932e..b81f128e7 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -62,6 +62,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewPostgresAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } + relayServers, err := NewPostgresRelayServersTable(d.db) + if err != nil { + return nil, err + } inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) if err != nil { return nil, err @@ -104,6 +112,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, FederationInboundPeeks: inboundPeeks, FederationOutboundPeeks: outboundPeeks, NotaryServerKeysJSON: notaryJSON, diff --git a/federationapi/storage/shared/receipt/receipt.go b/federationapi/storage/shared/receipt/receipt.go new file mode 100644 index 000000000..b347269c1 --- /dev/null +++ b/federationapi/storage/shared/receipt/receipt.go @@ -0,0 +1,42 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// A Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. +// We don't actually export the NIDs but we need the caller to be able +// to pass them back so that we can clean up if the transaction sends +// successfully. + +package receipt + +import "fmt" + +// Receipt is a wrapper type used to represent a nid that corresponds to a unique row entry +// in some database table. +// The internal nid value cannot be modified after a Receipt has been created. +// This guarantees a receipt will always refer to the same table entry that it was created +// to represent. +type Receipt struct { + nid int64 +} + +func NewReceipt(nid int64) Receipt { + return Receipt{nid: nid} +} + +func (r *Receipt) GetNID() int64 { + return r.nid +} + +func (r *Receipt) String() string { + return fmt.Sprintf("%d", r.nid) +} diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 6cda55725..6769637bc 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -20,6 +20,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal/caching" @@ -37,6 +38,8 @@ type Database struct { FederationQueueJSON tables.FederationQueueJSON FederationJoinedHosts tables.FederationJoinedHosts FederationBlacklist tables.FederationBlacklist + FederationAssumedOffline tables.FederationAssumedOffline + FederationRelayServers tables.FederationRelayServers FederationOutboundPeeks tables.FederationOutboundPeeks FederationInboundPeeks tables.FederationInboundPeeks NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON @@ -44,22 +47,6 @@ type Database struct { ServerSigningKeys tables.FederationServerSigningKeys } -// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. -// We don't actually export the NIDs but we need the caller to be able -// to pass them back so that we can clean up if the transaction sends -// successfully. -type Receipt struct { - nid int64 -} - -func NewReceipt(nid int64) Receipt { - return Receipt{nid: nid} -} - -func (r *Receipt) String() string { - return fmt.Sprintf("%d", r.nid) -} - // UpdateRoom updates the joined hosts for a room and returns what the joined // hosts were before the update, or nil if this was a duplicate message. // This is called when we receive a message from kafka, so we pass in @@ -113,11 +100,18 @@ func (d *Database) GetJoinedHosts( // GetAllJoinedHosts returns the currently joined hosts for // all rooms known to the federation sender. // Returns an error if something goes wrong. -func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetJoinedHostsForRooms( + ctx context.Context, + roomIDs []string, + excludeSelf, + excludeBlacklisted bool, +) ([]gomatrixserverlib.ServerName, error) { servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err @@ -139,7 +133,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, // metadata entries. func (d *Database) StoreJSON( ctx context.Context, js string, -) (*Receipt, error) { +) (*receipt.Receipt, error) { var nid int64 var err error _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -149,18 +143,21 @@ func (d *Database) StoreJSON( if err != nil { return nil, fmt.Errorf("d.insertQueueJSON: %w", err) } - return &Receipt{ - nid: nid, - }, nil + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil } -func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) }) } -func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName) }) @@ -172,51 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error { }) } -func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { +func (d *Database) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } -func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName) + }) +} + +func (d *Database) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName) + }) +} + +func (d *Database) RemoveAllServersAssumedOffline( + ctx context.Context, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn) + }) +} + +func (d *Database) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName) +} + +func (d *Database) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers) + }) +} + +func (d *Database) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName) +} + +func (d *Database) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers) + }) +} + +func (d *Database) P2PRemoveAllRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName) + }) +} + +func (d *Database) AddOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { +func (d *Database) GetOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID, + peekID string, +) (*types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { +func (d *Database) GetOutboundPeeks( + ctx context.Context, + roomID string, +) ([]types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID) } -func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) AddInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { +func (d *Database) GetInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, +) (*types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { +func (d *Database) GetInboundPeeks( + ctx context.Context, + roomID string, +) ([]types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID) } -func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { +func (d *Database) UpdateNotaryKeys( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + serverKeys gomatrixserverlib.ServerKeys, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { validUntil := serverKeys.ValidUntilTS // Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid. @@ -251,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv } func (d *Database) GetNotaryKeys( - ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID, + ctx context.Context, + serverName gomatrixserverlib.ServerName, + optKeyIDs []gomatrixserverlib.KeyID, ) (sks []gomatrixserverlib.ServerKeys, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs) diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index be8355f31..cff1ade6f 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -22,6 +22,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{ func (d *Database) AssociateEDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, ) error { @@ -62,12 +63,12 @@ func (d *Database) AssociateEDUWithDestinations( var err error for destination := range destinations { err = d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire ) } return err @@ -81,10 +82,10 @@ func (d *Database) GetPendingEDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - edus map[*Receipt]*gomatrixserverlib.EDU, + edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error, ) { - edus = make(map[*Receipt]*gomatrixserverlib.EDU) + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) if err != nil { @@ -94,7 +95,8 @@ func (d *Database) GetPendingEDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok { - edus[&Receipt{nid}] = edu + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = edu } else { retrieve = append(retrieve, nid) } @@ -110,7 +112,8 @@ func (d *Database) GetPendingEDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - edus[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = &event d.Cache.StoreFederationQueuedEDU(nid, &event) } @@ -124,7 +127,7 @@ func (d *Database) GetPendingEDUs( func (d *Database) CleanEDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -132,7 +135,7 @@ func (d *Database) CleanEDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index da4cb979d..854e00553 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -30,17 +31,17 @@ import ( func (d *Database) AssociatePDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + dbReceipt *receipt.Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error for destination := range destinations { err = d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - "", // transaction ID - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table ) } return err @@ -54,7 +55,7 @@ func (d *Database) GetPendingPDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - events map[*Receipt]*gomatrixserverlib.HeaderedEvent, + events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error, ) { // Strictly speaking this doesn't need to be using the writer @@ -62,7 +63,7 @@ func (d *Database) GetPendingPDUs( // a guarantee of transactional isolation, it's actually useful // to know in SQLite mode that nothing else is trying to modify // the database. - events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent) + events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit) if err != nil { @@ -72,7 +73,8 @@ func (d *Database) GetPendingPDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok { - events[&Receipt{nid}] = event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = event } else { retrieve = append(retrieve, nid) } @@ -88,7 +90,8 @@ func (d *Database) GetPendingPDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - events[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = &event d.Cache.StoreFederationQueuedPDU(nid, &event) } @@ -103,7 +106,7 @@ func (d *Database) GetPendingPDUs( func (d *Database) CleanPDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -111,7 +114,7 @@ func (d *Database) CleanPDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/sqlite3/assumed_offline_table.go b/federationapi/storage/sqlite3/assumed_offline_table.go new file mode 100644 index 000000000..ff2afb4da --- /dev/null +++ b/federationapi/storage/sqlite3/assumed_offline_table.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT PRIMARY KEY NOT NULL +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL}, + {&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL}, + {&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL}, + {&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL}, + }.Prepare(db) +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/sqlite3/relay_servers_table.go b/federationapi/storage/sqlite3/relay_servers_table.go new file mode 100644 index 000000000..27c3cca2c --- /dev/null +++ b/federationapi/storage/sqlite3/relay_servers_table.go @@ -0,0 +1,148 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayServersSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_relay_servers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The relay server name for a given destination + relay_server_name TEXT NOT NULL, + UNIQUE (server_name, relay_server_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx + ON federationsender_relay_servers (server_name); +` + +const insertRelayServersSQL = "" + + "INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectRelayServersSQL = "" + + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" + +const deleteRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)" + +const deleteAllRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" + +type relayServersStatements struct { + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + // deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic + deleteAllRelayServersStmt *sql.Stmt +} + +func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) { + s = &relayServersStatements{ + db: db, + } + _, err = db.Exec(relayServersSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertRelayServersStmt, insertRelayServersSQL}, + {&s.selectRelayServersStmt, selectRelayServersSQL}, + {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, + }.Prepare(db) +} + +func (s *relayServersStatements) InsertRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + for _, relayServer := range relayServers { + stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { + return err + } + } + return nil +} + +func (s *relayServersStatements) SelectRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var relayServer string + if err = rows.Scan(&relayServer); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(relayServer)) + } + return result, nil +} + +func (s *relayServersStatements) DeleteRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1) + deleteStmt, err := s.db.Prepare(deleteSQL) + if err != nil { + return err + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + params := make([]interface{}, len(relayServers)+1) + params[0] = serverName + for i, v := range relayServers { + params[i+1] = v + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *relayServersStatements) DeleteAllRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index d13b5defc..1e7e41a2c 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -1,5 +1,4 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -61,6 +60,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } + relayServers, err := NewSQLiteRelayServersTable(d.db) + if err != nil { + return nil, err + } outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) if err != nil { return nil, err @@ -103,6 +110,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, FederationOutboundPeeks: outboundPeeks, FederationInboundPeeks: inboundPeeks, NotaryServerKeysJSON: notaryKeys, diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 5b57d40d4..1d2a13e81 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -6,14 +6,13 @@ import ( "testing" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/stretchr/testify/assert" - "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { @@ -246,3 +245,99 @@ func TestInboundPeeking(t *testing.T) { assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } + +func TestServersAssumedOffline(t *testing.T) { + server1 := gomatrixserverlib.ServerName("server1") + server2 := gomatrixserverlib.ServerName("server2") + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + + // Set server1 & server2 as assumed offline. + err := db.SetServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + err = db.SetServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + + // Ensure both servers are assumed offline. + isOffline, err := db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.True(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.True(t, isOffline) + + // Set server1 as not assumed offline. + err = db.RemoveServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + + // Ensure both servers have correct state. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.False(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.True(t, isOffline) + + // Re-set server1 as assumed offline. + err = db.SetServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + + // Ensure server1 is assumed offline. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.True(t, isOffline) + + err = db.RemoveAllServersAssumedOffline(context.Background()) + assert.Nil(t, err) + + // Ensure both servers have correct state. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.False(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.False(t, isOffline) + }) +} + +func TestRelayServersStored(t *testing.T) { + server := gomatrixserverlib.ServerName("server") + relayServer1 := gomatrixserverlib.ServerName("relayserver1") + relayServer2 := gomatrixserverlib.ServerName("relayserver2") + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + + err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + assert.Nil(t, err) + + relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Equal(t, relayServer1, relayServers[0]) + + err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Zero(t, len(relayServers)) + + err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2}) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Equal(t, relayServer1, relayServers[0]) + assert.Equal(t, relayServer2, relayServers[1]) + + err = db.P2PRemoveAllRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Zero(t, len(relayServers)) + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 2b36edb46..762504e45 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -49,6 +49,19 @@ type FederationQueueJSON interface { SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) } +type FederationQueueTransactions interface { + InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +type FederationTransactionJSON interface { + InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} + type FederationJoinedHosts interface { InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error @@ -66,6 +79,20 @@ type FederationBlacklist interface { DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } +type FederationAssumedOffline interface { + InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) + DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error +} + +type FederationRelayServers interface { + InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error +} + type FederationOutboundPeeks interface { InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) diff --git a/federationapi/storage/tables/relay_servers_table_test.go b/federationapi/storage/tables/relay_servers_table_test.go new file mode 100644 index 000000000..b41211551 --- /dev/null +++ b/federationapi/storage/tables/relay_servers_table_test.go @@ -0,0 +1,224 @@ +package tables_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + server1 = "server1" + server2 = "server2" + server3 = "server3" + server4 = "server4" +) + +type RelayServersDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.FederationRelayServers +} + +func mustCreateRelayServersTable( + t *testing.T, + dbType test.DBType, +) (database RelayServersDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.FederationRelayServers + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayServersTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayServersTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayServersDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func Equal(a, b []gomatrixserverlib.ServerName) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +func TestShouldInsertRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldInsertRelayServersWithDuplicates(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2} + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + // Insert the same list again, this shouldn't fail and should have no effect. + err = db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldGetRelayServersUnknownDestination(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + + // Query relay servers for a destination that doesn't exist in the table. + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, []gomatrixserverlib.ServerName{}) { + t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers) + } + }) +} + +func TestShouldDeleteCorrectRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + relayServers1 := []gomatrixserverlib.ServerName{server2, server3} + relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4} + + err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) + } + err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error()) + } + + expectedRelayServers := []gomatrixserverlib.ServerName{server3} + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldDeleteAllRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteAllRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) + } + + expectedRelayServers1 := []gomatrixserverlib.ServerName{} + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers1) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers) + } + relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} diff --git a/go.mod b/go.mod index a86dd2cb8..871e94eb3 100644 --- a/go.mod +++ b/go.mod @@ -22,9 +22,9 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 - github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 - github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f + github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb + github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.15 github.com/nats-io/nats-server/v2 v2.9.8 github.com/nats-io/nats.go v1.20.0 @@ -37,17 +37,17 @@ require ( github.com/prometheus/client_golang v1.13.0 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.1 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.1.0 + golang.org/x/crypto v0.5.0 golang.org/x/image v0.1.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e - golang.org/x/net v0.1.0 - golang.org/x/term v0.1.0 + golang.org/x/net v0.5.0 + golang.org/x/term v0.4.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 @@ -119,12 +119,12 @@ require ( github.com/prometheus/procfs v0.8.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect go.etcd.io/bbolt v1.3.6 // indirect golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 // indirect golang.org/x/mod v0.6.0 // indirect - golang.org/x/sys v0.1.0 // indirect - golang.org/x/text v0.4.0 // indirect + golang.org/x/sys v0.4.0 // indirect + golang.org/x/text v0.6.0 // indirect golang.org/x/time v0.1.0 // indirect golang.org/x/tools v0.2.0 // indirect google.golang.org/protobuf v1.28.1 // indirect diff --git a/go.sum b/go.sum index e5cd67bed..1ca3e8a80 100644 --- a/go.sum +++ b/go.sum @@ -348,16 +348,12 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45 h1:zGrmcm2M4F4f+zk5JXAkw3oHa/zXhOh5XVGBdl7GdPo= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 h1:P7me2oCmksST9B4+1I1nA+XrnDQwIqAWmy6ntQrXwc8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= -github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= -github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= -github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f h1:niRWEVkeeekpjxwnMhKn8PD0PUloDsNXP8W+Ez/co/M= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb h1:2L+ltfNKab56FoBBqAvbBLjoAbxwwoZie+B8d+Mp3JI= +github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= +github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= +github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -494,12 +490,13 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= @@ -543,8 +540,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= -golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -625,8 +622,8 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= +golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -701,12 +698,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.4.0 h1:O7UWfv5+A2qiuulQk30kVinPoMtoIPeVaKLEgLpVkvg= +golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -714,8 +711,9 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/log.go b/internal/log.go index da6e20418..9e8656c5b 100644 --- a/internal/log.go +++ b/internal/log.go @@ -101,6 +101,8 @@ func SetupPprof() { // SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded. func SetupStdLogging() { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() logrus.SetReportCaller(true) logrus.SetFormatter(&utcFormatter{ &logrus.TextFormatter{ diff --git a/internal/log_unix.go b/internal/log_unix.go index 8f34c320d..859427041 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -32,6 +32,8 @@ import ( // If something fails here it means that the logging was improperly configured, // so we just exit with the error func SetupHookLogging(hooks []config.LogrusHook, componentName string) { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() for _, hook := range hooks { // Check we received a proper logging level level, err := logrus.ParseLevel(hook.Level) @@ -85,8 +87,6 @@ func checkSyslogHookParams(params map[string]interface{}) { } func setupStdLogHook(level logrus.Level) { - levelLogAddedMu.Lock() - defer levelLogAddedMu.Unlock() if stdLevelLogAdded[level] { return } diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go new file mode 100644 index 000000000..95673fc14 --- /dev/null +++ b/internal/transactionrequest.go @@ -0,0 +1,356 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/federationapi/types" + keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/roomserver/api" + syncTypes "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" +) + +var ( + PDUCountTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "federationapi", + Name: "recv_pdus", + Help: "Number of incoming PDUs from remote servers with labels for success", + }, + []string{"status"}, // 'success' or 'total' + ) + EDUCountTotal = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "federationapi", + Name: "recv_edus", + Help: "Number of incoming EDUs from remote servers", + }, + ) +) + +type TxnReq struct { + gomatrixserverlib.Transaction + rsAPI api.FederationRoomserverAPI + keyAPI keyapi.FederationKeyAPI + ourServerName gomatrixserverlib.ServerName + keys gomatrixserverlib.JSONVerifier + roomsMu *MutexByRoom + producer *producers.SyncAPIProducer + inboundPresenceEnabled bool +} + +func NewTxnReq( + rsAPI api.FederationRoomserverAPI, + keyAPI keyapi.FederationKeyAPI, + ourServerName gomatrixserverlib.ServerName, + keys gomatrixserverlib.JSONVerifier, + roomsMu *MutexByRoom, + producer *producers.SyncAPIProducer, + inboundPresenceEnabled bool, + pdus []json.RawMessage, + edus []gomatrixserverlib.EDU, + origin gomatrixserverlib.ServerName, + transactionID gomatrixserverlib.TransactionID, + destination gomatrixserverlib.ServerName, +) TxnReq { + t := TxnReq{ + rsAPI: rsAPI, + keyAPI: keyAPI, + ourServerName: ourServerName, + keys: keys, + roomsMu: roomsMu, + producer: producer, + inboundPresenceEnabled: inboundPresenceEnabled, + } + + t.PDUs = pdus + t.EDUs = edus + t.Origin = origin + t.TransactionID = transactionID + t.Destination = destination + + return t +} + +func (t *TxnReq) ProcessTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if t.producer != nil { + t.processEDUs(ctx) + } + }() + + results := make(map[string]gomatrixserverlib.PDUResult) + roomVersions := make(map[string]gomatrixserverlib.RoomVersion) + getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { + if v, ok := roomVersions[roomID]; ok { + return v + } + verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} + verRes := api.QueryRoomVersionForRoomResponse{} + if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) + return "" + } + roomVersions[roomID] = verRes.RoomVersion + return verRes.RoomVersion + } + + for _, pdu := range t.PDUs { + PDUCountTotal.WithLabelValues("total").Inc() + var header struct { + RoomID string `json:"room_id"` + } + if err := json.Unmarshal(pdu, &header); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") + // We don't know the event ID at this point so we can't return the + // failure in the PDU results + continue + } + roomVersion := getRoomVersion(header.RoomID) + event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + if err != nil { + if _, ok := err.(gomatrixserverlib.BadJSONError); ok { + // Room version 6 states that homeservers should strictly enforce canonical JSON + // on PDUs. + // + // This enforces that the entire transaction is rejected if a single bad PDU is + // sent. It is unclear if this is the correct behaviour or not. + // + // See https://github.com/matrix-org/synapse/issues/7543 + return nil, &util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON("PDU contains bad JSON"), + } + } + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + continue + } + if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + continue + } + if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: "Forbidden by server ACLs", + } + continue + } + if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue + } + + // pass the event to the roomserver which will do auth checks + // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently + // discarded by the caller of this function + if err = api.SendEvents( + ctx, + t.rsAPI, + api.KindNew, + []*gomatrixserverlib.HeaderedEvent{ + event.Headered(roomVersion), + }, + t.Destination, + t.Origin, + api.DoNotSendToOtherServers, + nil, + true, + ); err != nil { + util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue + } + + results[event.EventID()] = gomatrixserverlib.PDUResult{} + PDUCountTotal.WithLabelValues("success").Inc() + } + + wg.Wait() + return &gomatrixserverlib.RespSend{PDUs: results}, nil +} + +// nolint:gocyclo +func (t *TxnReq) processEDUs(ctx context.Context) { + for _, e := range t.EDUs { + EDUCountTotal.Inc() + switch e.Type { + case gomatrixserverlib.MTyping: + // https://matrix.org/docs/spec/server_server/latest#typing-notifications + var typingPayload struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + Typing bool `json:"typing"` + } + if err := json.Unmarshal(e.Content, &typingPayload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") + continue + } + if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") + } + case gomatrixserverlib.MDirectToDevice: + // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema + var directPayload gomatrixserverlib.ToDeviceMessage + if err := json.Unmarshal(e.Content, &directPayload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") + continue + } + if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + for userID, byUser := range directPayload.Messages { + for deviceID, message := range byUser { + // TODO: check that the user and the device actually exist here + if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { + sentry.CaptureException(err) + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": directPayload.Sender, + "user_id": userID, + "device_id": deviceID, + }).Error("Failed to send send-to-device event to JetStream") + } + } + } + case gomatrixserverlib.MDeviceListUpdate: + if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { + sentry.CaptureException(err) + util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") + } + case gomatrixserverlib.MReceipt: + // https://matrix.org/docs/spec/server_server/r0.1.4#receipts + payload := map[string]types.FederationReceiptMRead{} + + if err := json.Unmarshal(e.Content, &payload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") + continue + } + + for roomID, receipt := range payload { + for userID, mread := range receipt.User { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") + continue + } + if t.Origin != domain { + util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + continue + } + if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": t.Origin, + "user_id": userID, + "room_id": roomID, + "events": mread.EventIDs, + }).Error("Failed to send receipt event to JetStream") + continue + } + } + } + case types.MSigningKeyUpdate: + if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil { + sentry.CaptureException(err) + logrus.WithError(err).Errorf("Failed to process signing key update") + } + case gomatrixserverlib.MPresence: + if t.inboundPresenceEnabled { + if err := t.processPresence(ctx, e); err != nil { + logrus.WithError(err).Errorf("Failed to process presence update") + } + } + default: + util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") + } + } +} + +// processPresence handles m.receipt events +func (t *TxnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error { + payload := types.Presence{} + if err := json.Unmarshal(e.Content, &payload); err != nil { + return err + } + for _, content := range payload.Push { + if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + presence, ok := syncTypes.PresenceFromString(content.Presence) + if !ok { + continue + } + if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil { + return err + } + } + return nil +} + +// processReceiptEvent sends receipt events to JetStream +func (t *TxnReq) processReceiptEvent(ctx context.Context, + userID, roomID, receiptType string, + timestamp gomatrixserverlib.Timestamp, + eventIDs []string, +) error { + if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { + return nil + } else if serverName == t.ourServerName { + return nil + } else if serverName != t.Origin { + return nil + } + // store every event + for _, eventID := range eventIDs { + if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { + return fmt.Errorf("unable to set receipt event: %w", err) + } + } + + return nil +} diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go new file mode 100644 index 000000000..dd1bd3502 --- /dev/null +++ b/internal/transactionrequest_test.go @@ -0,0 +1,820 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/producers" + keyAPI "github.com/matrix-org/dendrite/keyserver/api" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" + "gotest.tools/v3/poll" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testDestination = gomatrixserverlib.ServerName("white.orchard") +) + +var ( + invalidSignatures = json.RawMessage(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localishhost","sender":"@userid:localhost","signatures":{"localhost":{"ed2559:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiaQiWAQ"}},"type":"m.room.member"}`) + testData = []json.RawMessage{ + []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), + // messages + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), + } + testEvent = []byte(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`) + testRoomVersion = gomatrixserverlib.RoomVersionV1 + testEvents = []*gomatrixserverlib.HeaderedEvent{} + testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) +) + +type FakeRsAPI struct { + rsAPI.RoomserverInternalAPI + shouldFailQuery bool + bannedFromRoom bool + shouldEventsFail bool +} + +func (r *FakeRsAPI) QueryRoomVersionForRoom( + ctx context.Context, + req *rsAPI.QueryRoomVersionForRoomRequest, + res *rsAPI.QueryRoomVersionForRoomResponse, +) error { + if r.shouldFailQuery { + return fmt.Errorf("Failure") + } + res.RoomVersion = gomatrixserverlib.RoomVersionV10 + return nil +} + +func (r *FakeRsAPI) QueryServerBannedFromRoom( + ctx context.Context, + req *rsAPI.QueryServerBannedFromRoomRequest, + res *rsAPI.QueryServerBannedFromRoomResponse, +) error { + if r.bannedFromRoom { + res.Banned = true + } else { + res.Banned = false + } + return nil +} + +func (r *FakeRsAPI) InputRoomEvents( + ctx context.Context, + req *rsAPI.InputRoomEventsRequest, + res *rsAPI.InputRoomEventsResponse, +) error { + if r.shouldEventsFail { + return fmt.Errorf("Failure") + } + return nil +} + +func TestEmptyTransactionRequest(t *testing.T) { + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", nil, nil, nil, false, []json.RawMessage{}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func TestProcessTransactionRequestPDU(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUs(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, append(testData, testEvent), []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestBadPDU(t *testing.T) { + pdu := json.RawMessage("{\"room_id\":\"asdf\"}") + pdu2 := json.RawMessage("\"roomid\":\"asdf\"") + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{pdu, pdu2, testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUQueryFailure(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{shouldFailQuery: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func TestProcessTransactionRequestPDUBannedFromRoom(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{bannedFromRoom: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUInvalidSignature(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{invalidSignatures}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUSendFail(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{shouldEventsFail: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func createTransactionWithEDU(ctx *process.ProcessContext, edus []gomatrixserverlib.EDU) (TxnReq, nats.JetStreamContext, *config.Dendrite) { + cfg := &config.Dendrite{} + cfg.Defaults(config.DefaultOpts{ + Generate: true, + Monolithic: true, + }) + cfg.Global.JetStream.InMemory = true + natsInstance := &jetstream.NATSInstance{} + js, _ := natsInstance.Prepare(ctx, &cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &cfg.FederationAPI, + UserAPI: nil, + } + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, producer, true, []json.RawMessage{}, edus, "kaer.morhen", "", "ourserver") + return txn, js, cfg +} + +func TestProcessTransactionRequestEDUTyping(t *testing.T) { + var err error + roomID := "!roomid:kaer.morhen" + userID := "@userid:kaer.morhen" + typing := true + edu := gomatrixserverlib.EDU{Type: "m.typing"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "room_id": roomID, + "user_id": userID, + "typing": typing, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.typing"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + room := msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, room) + user := msg.Header.Get(jetstream.UserID) + assert.Equal(t, userID, user) + typ, parseErr := strconv.ParseBool(msg.Header.Get("typing")) + if parseErr != nil { + return true + } + assert.Equal(t, typing, typ) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + cfg.Global.JetStream.Durable("TestTypingConsumer"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUToDevice(t *testing.T) { + var err error + sender := "@userid:kaer.morhen" + messageID := "$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg" + msgType := "m.dendrite.test" + edu := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "sender": sender, + "type": msgType, + "message_id": messageID, + "messages": map[string]interface{}{ + "@alice:example.org": map[string]interface{}{ + "IWHQUZUIAH": map[string]interface{}{ + "algorithm": "m.megolm.v1.aes-sha2", + "room_id": "!Cuyf34gef24t:localhost", + "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ", + "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY...", + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputSendToDeviceEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, sender, output.Sender) + assert.Equal(t, msgType, output.Type) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + cfg.Global.JetStream.Durable("TestToDevice"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUDeviceListUpdate(t *testing.T) { + var err error + deviceID := "QBUAZIFURK" + userID := "@john:example.com" + edu := gomatrixserverlib.EDU{Type: "m.device_list_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "device_display_name": "Mobile", + "device_id": deviceID, + "key": "value", + "keys": map[string]interface{}{ + "algorithms": []string{ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", + }, + "device_id": "JLAFKJWSCS", + "keys": map[string]interface{}{ + "curve25519:JLAFKJWSCS": "3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI", + "ed25519:JLAFKJWSCS": "lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI", + }, + "signatures": map[string]interface{}{ + "@alice:example.com": map[string]interface{}{ + "ed25519:JLAFKJWSCS": "dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA", + }, + }, + "user_id": "@alice:example.com", + }, + "prev_id": []int{ + 5, + }, + "stream_id": 6, + "user_id": userID, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.device_list_update"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output gomatrixserverlib.DeviceListUpdateEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, userID, output.UserID) + assert.Equal(t, deviceID, output.DeviceID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + cfg.Global.JetStream.Durable("TestDeviceListUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUReceipt(t *testing.T) { + var err error + roomID := "!some_room:example.org" + edu := gomatrixserverlib.EDU{Type: "m.receipt"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:kaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.receipt"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badUser := gomatrixserverlib.EDU{Type: "m.receipt"} + if badUser.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "johnkaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badDomain := gomatrixserverlib.EDU{Type: "m.receipt"} + if badDomain.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:bad.domain": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + edus := []gomatrixserverlib.EDU{badEDU, badUser, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputReceiptEvent + output.RoomID = msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, output.RoomID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + cfg.Global.JetStream.Durable("TestReceipt"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUSigningKeyUpdate(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output keyAPI.CrossSigningKeyUpdate + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + cfg.Global.JetStream.Durable("TestSigningKeyUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUPresence(t *testing.T) { + var err error + userID := "@john:kaer.morhen" + presence := "online" + edu := gomatrixserverlib.EDU{Type: "m.presence"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "push": []map[string]interface{}{{ + "currently_active": true, + "last_active_ago": 5000, + "presence": presence, + "status_msg": "Making cupcakes", + "user_id": userID, + }}, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.presence"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + userIDRes := msg.Header.Get(jetstream.UserID) + presenceRes := msg.Header.Get("presence") + assert.Equal(t, userID, userIDRes) + assert.Equal(t, presence, presenceRes) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + cfg.Global.JetStream.Durable("TestPresence"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUUnhandled(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.unhandled"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, _, _ := createTransactionWithEDU(ctx, []gomatrixserverlib.EDU{edu}) + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func init() { + for _, j := range testData { + e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) + if err != nil { + panic("cannot load test data: " + err.Error()) + } + h := e.Headered(testRoomVersion) + testEvents = append(testEvents, h) + if e.StateKey() != nil { + testStateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = h + } + } +} + +type testRoomserverAPI struct { + rsAPI.RoomserverInternalAPITrace + inputRoomEvents []rsAPI.InputRoomEvent + queryStateAfterEvents func(*rsAPI.QueryStateAfterEventsRequest) rsAPI.QueryStateAfterEventsResponse + queryEventsByID func(req *rsAPI.QueryEventsByIDRequest) rsAPI.QueryEventsByIDResponse + queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse +} + +func (t *testRoomserverAPI) InputRoomEvents( + ctx context.Context, + request *rsAPI.InputRoomEventsRequest, + response *rsAPI.InputRoomEventsResponse, +) error { + t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) + for _, ire := range request.InputRoomEvents { + fmt.Println("InputRoomEvents: ", ire.Event.EventID()) + } + return nil +} + +// Query the latest events and state for a room from the room server. +func (t *testRoomserverAPI) QueryLatestEventsAndState( + ctx context.Context, + request *rsAPI.QueryLatestEventsAndStateRequest, + response *rsAPI.QueryLatestEventsAndStateResponse, +) error { + r := t.queryLatestEventsAndState(request) + response.RoomExists = r.RoomExists + response.RoomVersion = testRoomVersion + response.LatestEvents = r.LatestEvents + response.StateEvents = r.StateEvents + response.Depth = r.Depth + return nil +} + +// Query the state after a list of events in a room from the room server. +func (t *testRoomserverAPI) QueryStateAfterEvents( + ctx context.Context, + request *rsAPI.QueryStateAfterEventsRequest, + response *rsAPI.QueryStateAfterEventsResponse, +) error { + response.RoomVersion = testRoomVersion + res := t.queryStateAfterEvents(request) + response.PrevEventsExist = res.PrevEventsExist + response.RoomExists = res.RoomExists + response.StateEvents = res.StateEvents + return nil +} + +// Query a list of events by event ID. +func (t *testRoomserverAPI) QueryEventsByID( + ctx context.Context, + request *rsAPI.QueryEventsByIDRequest, + response *rsAPI.QueryEventsByIDResponse, +) error { + res := t.queryEventsByID(request) + response.Events = res.Events + return nil +} + +// Query if a server is joined to a room +func (t *testRoomserverAPI) QueryServerJoinedToRoom( + ctx context.Context, + request *rsAPI.QueryServerJoinedToRoomRequest, + response *rsAPI.QueryServerJoinedToRoomResponse, +) error { + response.RoomExists = true + response.IsInRoom = true + return nil +} + +// Asks for the room version for a given room. +func (t *testRoomserverAPI) QueryRoomVersionForRoom( + ctx context.Context, + request *rsAPI.QueryRoomVersionForRoomRequest, + response *rsAPI.QueryRoomVersionForRoomResponse, +) error { + response.RoomVersion = testRoomVersion + return nil +} + +func (t *testRoomserverAPI) QueryServerBannedFromRoom( + ctx context.Context, req *rsAPI.QueryServerBannedFromRoomRequest, res *rsAPI.QueryServerBannedFromRoomResponse, +) error { + res.Banned = false + return nil +} + +func mustCreateTransaction(rsAPI rsAPI.FederationRoomserverAPI, pdus []json.RawMessage) *TxnReq { + t := NewTxnReq( + rsAPI, + nil, + "", + &test.NopJSONVerifier{}, + NewMutexByRoom(), + nil, + false, + pdus, + nil, + testOrigin, + gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())), + testDestination) + t.PDUs = pdus + t.Origin = testOrigin + t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + t.Destination = testDestination + return &t +} + +func mustProcessTransaction(t *testing.T, txn *TxnReq, pdusWithErrors []string) { + res, err := txn.ProcessTransaction(context.Background()) + if err != nil { + t.Errorf("txn.processTransaction returned an error: %v", err) + return + } + if len(res.PDUs) != len(txn.PDUs) { + t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) + return + } +NextPDU: + for eventID, result := range res.PDUs { + if result.Error == "" { + continue + } + for _, eventIDWantError := range pdusWithErrors { + if eventID == eventIDWantError { + break NextPDU + } + } + t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) + } +} + +func assertInputRoomEvents(t *testing.T, got []rsAPI.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { + for _, g := range got { + fmt.Println("GOT ", g.Event.EventID()) + } + if len(got) != len(want) { + t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) + return + } + for i := range got { + if got[i].Event.EventID() != want[i].EventID() { + t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) + } + } +} + +// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on +// to the roomserver. It's the most basic test possible. +func TestBasicTransaction(t *testing.T) { + rsAPI := &testRoomserverAPI{} + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, pdus) + mustProcessTransaction(t, txn, nil) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} + +// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver +// as it does the auth check. +func TestTransactionFailAuthChecks(t *testing.T) { + rsAPI := &testRoomserverAPI{} + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, pdus) + mustProcessTransaction(t, txn, []string{}) + // expect message to be sent to the roomserver + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 9dcfa955f..50af2f884 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -108,13 +108,16 @@ func makeDownloadAPI( activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, ) http.HandlerFunc { - counterVec := promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: name, - Help: "Total number of media_api requests for either thumbnails or full downloads", - }, - []string{"code"}, - ) + var counterVec *prometheus.CounterVec + if cfg.Matrix.Metrics.Enabled { + counterVec = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: name, + Help: "Total number of media_api requests for either thumbnails or full downloads", + }, + []string{"code"}, + ) + } httpHandler := func(w http.ResponseWriter, req *http.Request) { req = util.RequestWithLogging(req) @@ -166,5 +169,12 @@ func makeDownloadAPI( vars["downloadName"], ) } - return promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + + var handlerFunc http.HandlerFunc + if counterVec != nil { + handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + } else { + handlerFunc = http.HandlerFunc(httpHandler) + } + return handlerFunc } diff --git a/relayapi/api/api.go b/relayapi/api/api.go new file mode 100644 index 000000000..9db393225 --- /dev/null +++ b/relayapi/api/api.go @@ -0,0 +1,56 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayInternalAPI is used to query information from the relay server. +type RelayInternalAPI interface { + RelayServerAPI + + // Retrieve from external relay server all transactions stored for us and process them. + PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, + ) error +} + +// RelayServerAPI exposes the store & query transaction functionality of a relay server. +type RelayServerAPI interface { + // Store transactions for forwarding to the destination at a later time. + PerformStoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, + userID gomatrixserverlib.UserID, + ) error + + // Obtain the oldest stored transaction for the specified userID. + QueryTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + previousEntry gomatrixserverlib.RelayEntry, + ) (QueryRelayTransactionsResponse, error) +} + +type QueryRelayTransactionsResponse struct { + Transaction gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id"` + EntriesQueued bool `json:"entries_queued"` +} diff --git a/relayapi/internal/api.go b/relayapi/internal/api.go new file mode 100644 index 000000000..3ff8c2add --- /dev/null +++ b/relayapi/internal/api.go @@ -0,0 +1,53 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + fedAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type RelayInternalAPI struct { + db storage.Database + fedClient fedAPI.FederationClient + rsAPI rsAPI.RoomserverInternalAPI + keyRing *gomatrixserverlib.KeyRing + producer *producers.SyncAPIProducer + presenceEnabledInbound bool + serverName gomatrixserverlib.ServerName +} + +func NewRelayInternalAPI( + db storage.Database, + fedClient fedAPI.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, + presenceEnabledInbound bool, + serverName gomatrixserverlib.ServerName, +) *RelayInternalAPI { + return &RelayInternalAPI{ + db: db, + fedClient: fedClient, + rsAPI: rsAPI, + keyRing: keyRing, + producer: producer, + presenceEnabledInbound: presenceEnabledInbound, + serverName: serverName, + } +} diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go new file mode 100644 index 000000000..594299334 --- /dev/null +++ b/relayapi/internal/perform.go @@ -0,0 +1,141 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// PerformRelayServerSync implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, +) error { + // Providing a default RelayEntry (EntryID = 0) is done to ask the relay if there are any + // transactions available for this node. + prevEntry := gomatrixserverlib.RelayEntry{} + asyncResponse, err := r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) + if err != nil { + logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Txn) + + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + for asyncResponse.EntriesQueued { + // There are still more entries available for this node from the relay. + logrus.Infof("Retrieving next entry from relay, previous: %v", prevEntry) + asyncResponse, err = r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + if err != nil { + logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Txn) + } + + return nil +} + +// PerformStoreTransaction implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformStoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, + userID gomatrixserverlib.UserID, +) error { + logrus.Warnf("Storing transaction for %v", userID) + receipt, err := r.db.StoreTransaction(ctx, transaction) + if err != nil { + logrus.Errorf("db.StoreTransaction: %s", err.Error()) + return err + } + err = r.db.AssociateTransactionWithDestinations( + ctx, + map[gomatrixserverlib.UserID]struct{}{ + userID: {}, + }, + transaction.TransactionID, + receipt) + + return err +} + +// QueryTransactions implements api.RelayInternalAPI +func (r *RelayInternalAPI) QueryTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + previousEntry gomatrixserverlib.RelayEntry, +) (api.QueryRelayTransactionsResponse, error) { + logrus.Infof("QueryTransactions for %s", userID.Raw()) + if previousEntry.EntryID > 0 { + logrus.Infof("Cleaning previous entry (%v) from db for %s", + previousEntry.EntryID, + userID.Raw(), + ) + prevReceipt := receipt.NewReceipt(previousEntry.EntryID) + err := r.db.CleanTransactions(ctx, userID, []*receipt.Receipt{&prevReceipt}) + if err != nil { + logrus.Errorf("db.CleanTransactions: %s", err.Error()) + return api.QueryRelayTransactionsResponse{}, err + } + } + + transaction, receipt, err := r.db.GetTransaction(ctx, userID) + if err != nil { + logrus.Errorf("db.GetTransaction: %s", err.Error()) + return api.QueryRelayTransactionsResponse{}, err + } + + response := api.QueryRelayTransactionsResponse{} + if transaction != nil && receipt != nil { + logrus.Infof("Obtained transaction (%v) for %s", transaction.TransactionID, userID.Raw()) + response.Transaction = *transaction + response.EntryID = receipt.GetNID() + response.EntriesQueued = true + } else { + logrus.Infof("No more entries in the queue for %s", userID.Raw()) + response.EntryID = 0 + response.EntriesQueued = false + } + + return response, nil +} + +func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) { + logrus.Warn("Processing transaction from relay server") + mu := internal.NewMutexByRoom() + t := internal.NewTxnReq( + r.rsAPI, + nil, + r.serverName, + r.keyRing, + mu, + r.producer, + r.presenceEnabledInbound, + txn.PDUs, + txn.EDUs, + txn.Origin, + txn.TransactionID, + txn.Destination) + + t.ProcessTransaction(context.TODO()) +} diff --git a/relayapi/internal/perform_test.go b/relayapi/internal/perform_test.go new file mode 100644 index 000000000..fb71b7d0e --- /dev/null +++ b/relayapi/internal/perform_test.go @@ -0,0 +1,121 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "testing" + + fedAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type testFedClient struct { + fedAPI.FederationClient + shouldFail bool + queryCount uint + queueDepth uint +} + +func (f *testFedClient) P2PGetTransactionFromRelay( + ctx context.Context, + u gomatrixserverlib.UserID, + prev gomatrixserverlib.RelayEntry, + relayServer gomatrixserverlib.ServerName, +) (res gomatrixserverlib.RespGetRelayTransaction, err error) { + f.queryCount++ + if f.shouldFail { + return res, fmt.Errorf("Error") + } + + res = gomatrixserverlib.RespGetRelayTransaction{ + Txn: gomatrixserverlib.Transaction{}, + EntryID: 0, + } + if f.queueDepth > 0 { + res.EntriesQueued = true + } else { + res.EntriesQueued = false + } + f.queueDepth -= 1 + + return +} + +func TestPerformRelayServerSync(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.NoError(t, err) +} + +func TestPerformRelayServerSyncFedError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{shouldFail: true} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.Error(t, err) +} + +func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{queueDepth: 2} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.NoError(t, err) + assert.Equal(t, uint(3), fedClient.queryCount) +} diff --git a/relayapi/relayapi.go b/relayapi/relayapi.go new file mode 100644 index 000000000..f9f9d4ff9 --- /dev/null +++ b/relayapi/relayapi.go @@ -0,0 +1,74 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relayapi + +import ( + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. +func AddPublicRoutes( + base *base.BaseDendrite, + keyRing gomatrixserverlib.JSONVerifier, + relayAPI api.RelayInternalAPI, +) { + fedCfg := &base.Cfg.FederationAPI + + relay, ok := relayAPI.(*internal.RelayInternalAPI) + if !ok { + panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " + + "RelayInternalAPI. This is a programming error.") + } + + routing.Setup( + base.PublicFederationAPIMux, + fedCfg, + relay, + keyRing, + ) +} + +func NewRelayInternalAPI( + base *base.BaseDendrite, + fedClient *gomatrixserverlib.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, +) api.RelayInternalAPI { + cfg := &base.Cfg.RelayAPI + + relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) + if err != nil { + logrus.WithError(err).Panic("failed to connect to relay db") + } + + return internal.NewRelayInternalAPI( + relayDB, + fedClient, + rsAPI, + keyRing, + producer, + base.Cfg.Global.Presence.EnableInbound, + base.Cfg.Global.ServerName, + ) +} diff --git a/relayapi/relayapi_test.go b/relayapi/relayapi_test.go new file mode 100644 index 000000000..dfa06811d --- /dev/null +++ b/relayapi/relayapi_test.go @@ -0,0 +1,154 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relayapi_test + +import ( + "crypto/ed25519" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + "github.com/matrix-org/dendrite/relayapi" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func TestCreateNewRelayInternalAPI(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + assert.NotNil(t, relayAPI) + }) +} + +func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + if dbType == test.DBTypeSQLite { + base.Cfg.RelayAPI.Database.ConnectionString = "file:" + } else { + base.Cfg.RelayAPI.Database.ConnectionString = "test" + } + defer close() + + assert.Panics(t, func() { + relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + }) + }) +} + +func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + assert.Panics(t, func() { + relayapi.AddPublicRoutes(base, nil, nil) + }) + }) +} + +func createGetRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, userID string) *http.Request { + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID) + content := gomatrixserverlib.RelayEntry{EntryID: 0} + req.SetContent(content) + req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) + httpreq, _ := req.HTTPRequest() + vars := map[string]string{"userID": userID} + httpreq = mux.SetURLVars(httpreq, vars) + return httpreq +} + +type sendRelayContent struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` +} + +func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnID string, userID string) *http.Request { + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID) + content := sendRelayContent{} + req.SetContent(content) + req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) + httpreq, _ := req.HTTPRequest() + vars := map[string]string{"userID": userID, "txnID": txnID} + httpreq = mux.SetURLVars(httpreq, vars) + return httpreq +} + +func TestCreateRelayPublicRoutes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + assert.NotNil(t, relayAPI) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + relayapi.AddPublicRoutes(base, keyRing, relayAPI) + + testCases := []struct { + name string + req *http.Request + wantCode int + wantJoinedRooms []string + }{ + { + name: "relay_txn invalid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"), + wantCode: 400, + }, + { + name: "relay_txn valid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), + wantCode: 200, + }, + { + name: "send_relay invalid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"), + wantCode: 400, + }, + { + name: "send_relay valid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), + wantCode: 200, + }, + } + + for _, tc := range testCases { + w := httptest.NewRecorder() + base.PublicFederationAPIMux.ServeHTTP(w, tc.req) + if w.Code != tc.wantCode { + t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) + } + } + }) +} diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go new file mode 100644 index 000000000..1b11b0ecd --- /dev/null +++ b/relayapi/routing/relaytxn.go @@ -0,0 +1,74 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type RelayTransactionResponse struct { + Transaction gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id,omitempty"` + EntriesQueued bool `json:"entries_queued"` +} + +// GetTransactionFromRelay implements /_matrix/federation/v1/relay_txn/{userID} +// This endpoint can be extracted into a separate relay server service. +func GetTransactionFromRelay( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + relayAPI api.RelayInternalAPI, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + logrus.Infof("Handling relay_txn for %s", userID.Raw()) + + previousEntry := gomatrixserverlib.RelayEntry{} + if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("invalid json provided"), + } + } + if previousEntry.EntryID < 0 { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("Invalid entry id provided. Must be >= 0."), + } + } + logrus.Infof("Previous entry provided: %v", previousEntry.EntryID) + + response, err := relayAPI.QueryTransactions(httpReq.Context(), userID, previousEntry) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: RelayTransactionResponse{ + Transaction: response.Transaction, + EntryID: response.EntryID, + EntriesQueued: response.EntriesQueued, + }, + } +} diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go new file mode 100644 index 000000000..a47fdb198 --- /dev/null +++ b/relayapi/routing/relaytxn_test.go @@ -0,0 +1,220 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func createQuery( + userID gomatrixserverlib.UserID, + prevEntry gomatrixserverlib.RelayEntry, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/relay_txn/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("GET", userID.Domain(), "relay", path) + request.SetContent(prevEntry) + + return request +} + +func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.Equal(t, false, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} + +func TestGetInvalidPrevEntryFails(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusInternalServerError, response.Code) +} + +func TestGetReturnsSavedTransaction(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + receipt, err := db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + assert.NoError(t, err, "Failed to associate transaction with user") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Transaction) + + // And once more to clear the queue + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.False(t, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} + +func TestGetReturnsMultipleSavedTransactions(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + receipt, err := db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + assert.NoError(t, err, "Failed to associate transaction with user") + + transaction2 := createTransaction() + receipt2, err := db.StoreTransaction(context.Background(), transaction2) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction2.TransactionID, + receipt2) + assert.NoError(t, err, "Failed to associate transaction with user") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Transaction) + + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction2, jsonResponse.Transaction) + + // And once more to clear the queue + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.False(t, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go new file mode 100644 index 000000000..6df0cdc5f --- /dev/null +++ b/relayapi/routing/routing.go @@ -0,0 +1,123 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "fmt" + "net/http" + "time" + + "github.com/getsentry/sentry-go" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/httputil" + relayInternal "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// Setup registers HTTP handlers with the given ServeMux. +// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly +// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled +// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz] +// +// Due to Setup being used to call many other functions, a gocyclo nolint is +// applied: +// nolint: gocyclo +func Setup( + fedMux *mux.Router, + cfg *config.FederationAPI, + relayAPI *relayInternal.RelayInternalAPI, + keys gomatrixserverlib.JSONVerifier, +) { + v1fedmux := fedMux.PathPrefix("/v1").Subrouter() + + v1fedmux.Handle("/send_relay/{txnID}/{userID}", MakeRelayAPI( + "send_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return SendTransactionToRelay( + httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]), + *userID, + ) + }, + )).Methods(http.MethodPut, http.MethodOptions) + + v1fedmux.Handle("/relay_txn/{userID}", MakeRelayAPI( + "get_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return GetTransactionFromRelay(httpReq, request, relayAPI, *userID) + }, + )).Methods(http.MethodGet, http.MethodOptions) +} + +// MakeRelayAPI makes an http.Handler that checks matrix relay authentication. +func MakeRelayAPI( + metricsName string, serverName gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, + keyRing gomatrixserverlib.JSONVerifier, + f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + req, time.Now(), serverName, isLocalServerName, keyRing, + ) + if fedReq == nil { + return errResp + } + // add the user to Sentry, if enabled + hub := sentry.GetHubFromContext(req.Context()) + if hub != nil { + hub.Scope().SetTag("origin", string(fedReq.Origin())) + hub.Scope().SetTag("uri", fedReq.RequestURI()) + } + defer func() { + if r := recover(); r != nil { + if hub != nil { + hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path)) + } + // re-panic to return the 500 + panic(r) + } + }() + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") + } + + jsonRes := f(req, fedReq, vars) + // do not log 4xx as errors as they are client fails, not server fails + if hub != nil && jsonRes.Code >= 500 { + hub.Scope().SetExtra("response", jsonRes) + hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code)) + } + return jsonRes + } + return httputil.MakeExternalAPI(metricsName, h) +} diff --git a/relayapi/routing/sendrelay.go b/relayapi/routing/sendrelay.go new file mode 100644 index 000000000..a7027f293 --- /dev/null +++ b/relayapi/routing/sendrelay.go @@ -0,0 +1,77 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +// SendTransactionToRelay implements PUT /_matrix/federation/v1/relay_txn/{txnID}/{userID} +// This endpoint can be extracted into a separate relay server service. +func SendTransactionToRelay( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + relayAPI api.RelayInternalAPI, + txnID gomatrixserverlib.TransactionID, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + var txnEvents struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` + } + + if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { + logrus.Info("The request body could not be decoded into valid JSON." + err.Error()) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON." + err.Error()), + } + } + + // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs. + // https://matrix.org/docs/spec/server_server/latest#transactions + if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"), + } + } + + t := gomatrixserverlib.Transaction{} + t.PDUs = txnEvents.PDUs + t.EDUs = txnEvents.EDUs + t.Origin = fedReq.Origin() + t.TransactionID = txnID + t.Destination = userID.Domain() + + util.GetLogger(httpReq.Context()).Warnf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, fedReq.Origin(), len(t.PDUs), len(t.EDUs)) + + err := relayAPI.PerformStoreTransaction(httpReq.Context(), t, userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("could not store the transaction for forwarding"), + } + } + + return util.JSONResponse{Code: 200} +} diff --git a/relayapi/routing/sendrelay_test.go b/relayapi/routing/sendrelay_test.go new file mode 100644 index 000000000..d9ed75002 --- /dev/null +++ b/relayapi/routing/sendrelay_test.go @@ -0,0 +1,209 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") +) + +func createTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + return txn +} + +func createFederationRequest( + userID gomatrixserverlib.UserID, + txnID gomatrixserverlib.TransactionID, + origin gomatrixserverlib.ServerName, + destination gomatrixserverlib.ServerName, + content interface{}, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/send_relay/" + string(txnID) + "/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("PUT", origin, destination, path) + request.SetContent(content) + + return request +} + +func TestForwardEmptyReturnsOk(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.Equal(t, 200, response.Code) +} + +func TestForwardBadJSONReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field bool `json:"pdus"` + } + content := BadData{ + Field: false, + } + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyPDUsReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field []json.RawMessage `json:"pdus"` + } + content := BadData{ + Field: []json.RawMessage{}, + } + for i := 0; i < 51; i++ { + content.Field = append(content.Field, []byte{}) + } + assert.Greater(t, len(content.Field), 50) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyEDUsReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field []gomatrixserverlib.EDU `json:"edus"` + } + content := BadData{ + Field: []gomatrixserverlib.EDU{}, + } + for i := 0; i < 101; i++ { + content.Field = append(content.Field, gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}) + } + assert.Greater(t, len(content.Field), 100) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestUniqueTransactionStoredInDatabase(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay( + httpReq, &request, relayAPI, txn.TransactionID, *userID) + transaction, _, err := db.GetTransaction(context.Background(), *userID) + assert.NoError(t, err, "Failed retrieving transaction") + + transactionCount, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err, "Failed retrieving transaction count") + + assert.Equal(t, 200, response.Code) + assert.Equal(t, int64(1), transactionCount) + assert.Equal(t, txn.TransactionID, transaction.TransactionID) +} diff --git a/relayapi/storage/interface.go b/relayapi/storage/interface.go new file mode 100644 index 000000000..f5f9a06e5 --- /dev/null +++ b/relayapi/storage/interface.go @@ -0,0 +1,47 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database interface { + // Adds a new transaction to the queue json table. + // Adding a duplicate transaction will result in a new row being added and a new unique nid. + // return: unique nid representing this entry. + StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*receipt.Receipt, error) + + // Adds a new transaction_id: server_name mapping with associated json table nid to the queue + // entry table for each provided destination. + AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error + + // Removes every server_name: receipt pair provided from the queue entries table. + // Will then remove every entry for each receipt provided from the queue json table. + // If any of the entries don't exist in either table, nothing will happen for that entry and + // an error will not be generated. + CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error + + // Gets the oldest transaction for the provided server_name. + // If no transactions exist, returns nil and no error. + GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) + + // Gets the number of transactions being stored for the provided server_name. + // If the server doesn't exist in the database then 0 is returned with no error. + GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) +} diff --git a/relayapi/storage/postgres/relay_queue_json_table.go b/relayapi/storage/postgres/relay_queue_json_table.go new file mode 100644 index 000000000..74410fc88 --- /dev/null +++ b/relayapi/storage/postgres/relay_queue_json_table.go @@ -0,0 +1,113 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid BIGSERIAL, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); +` + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + + " VALUES ($1)" + + " RETURNING json_nid" + +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + + " WHERE json_nid = ANY($1)" + +type relayQueueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + deleteJSONStmt *sql.Stmt + selectJSONStmt *sql.Stmt +} + +func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ + db: db, + } + _, err = s.db.Exec(relayQueueJSONSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertQueueJSONSQL}, + {&s.deleteJSONStmt, deleteQueueJSONSQL}, + {&s.selectJSONStmt, selectQueueJSONSQL}, + }.Prepare(db) +} + +func (s *relayQueueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (int64, error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + var lastid int64 + if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil { + return 0, err + } + return lastid, nil +} + +func (s *relayQueueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) + _, err := stmt.ExecContext(ctx, pq.Int64Array(nids)) + return err +} + +func (s *relayQueueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, s.selectJSONStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, err + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/relayapi/storage/postgres/relay_queue_table.go b/relayapi/storage/postgres/relay_queue_table.go new file mode 100644 index 000000000..e97cf8cc0 --- /dev/null +++ b/relayapi/storage/postgres/relay_queue_table.go @@ -0,0 +1,156 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The destination server that we will send the event to. + server_name TEXT NOT NULL, + -- The JSON NID from the relayapi_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); +` + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + + " WHERE server_name = $1" + + " ORDER BY json_nid" + + " LIMIT $2" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + + " WHERE server_name = $1" + +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + deleteQueueEntriesStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt +} + +func NewPostgresRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ + db: db, + } + _, err = s.db.Exec(relayQueueSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertQueueEntryStmt, insertQueueEntrySQL}, + {&s.deleteQueueEntriesStmt, deleteQueueEntriesSQL}, + {&s.selectQueueEntriesStmt, selectQueueEntriesSQL}, + {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL}, + }.Prepare(db) +} + +func (s *relayQueueStatements) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *relayQueueStatements) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) + return err +} + +func (s *relayQueueStatements) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *relayQueueStatements) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/relayapi/storage/postgres/storage.go b/relayapi/storage/postgres/storage.go new file mode 100644 index 000000000..1042beba7 --- /dev/null +++ b/relayapi/storage/postgres/storage.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the relayapi +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + return nil, err + } + queue, err := NewPostgresRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewPostgresRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/shared/storage.go b/relayapi/storage/shared/storage.go new file mode 100644 index 000000000..0993707bf --- /dev/null +++ b/relayapi/storage/shared/storage.go @@ -0,0 +1,170 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database struct { + DB *sql.DB + IsLocalServerName func(gomatrixserverlib.ServerName) bool + Cache caching.FederationCache + Writer sqlutil.Writer + RelayQueue tables.RelayQueue + RelayQueueJSON tables.RelayQueueJSON +} + +func (d *Database) StoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, +) (*receipt.Receipt, error) { + var err error + jsonTransaction, err := json.Marshal(transaction) + if err != nil { + return nil, fmt.Errorf("failed to marshal: %w", err) + } + + var nid int64 + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(jsonTransaction)) + return err + }) + if err != nil { + return nil, fmt.Errorf("d.insertQueueJSON: %w", err) + } + + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil +} + +func (d *Database) AssociateTransactionWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.UserID]struct{}, + transactionID gomatrixserverlib.TransactionID, + dbReceipt *receipt.Receipt, +) error { + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var lastErr error + for destination := range destinations { + destination := destination + err := d.RelayQueue.InsertQueueEntry( + ctx, + txn, + transactionID, + destination.Domain(), + dbReceipt.GetNID(), + ) + if err != nil { + lastErr = fmt.Errorf("d.insertQueueEntry: %w", err) + } + } + return lastErr + }) + + return err +} + +func (d *Database) CleanTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + receipts []*receipt.Receipt, +) error { + nids := make([]int64, len(receipts)) + for i, dbReceipt := range receipts { + nids[i] = dbReceipt.GetNID() + } + + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + deleteEntryErr := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids) + // TODO : If there are still queue entries for any of these nids for other destinations + // then we shouldn't delete the json entries. + // But this can't happen with the current api design. + // There will only ever be one server entry for each nid since each call to send_relay + // only accepts a single server name and inside there we create a new json entry. + // So for multiple destinations we would call send_relay multiple times and have multiple + // json entries of the same transaction. + // + // TLDR; this works as expected right now but can easily be optimised in the future. + deleteJSONErr := d.RelayQueueJSON.DeleteQueueJSON(ctx, txn, nids) + + if deleteEntryErr != nil { + return fmt.Errorf("d.deleteQueueEntries: %w", deleteEntryErr) + } + if deleteJSONErr != nil { + return fmt.Errorf("d.deleteQueueJSON: %w", deleteJSONErr) + } + return nil + }) + + return err +} + +func (d *Database) GetTransaction( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) { + entriesRequested := 1 + nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), entriesRequested) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err) + } + if len(nids) == 0 { + return nil, nil, nil + } + firstNID := nids[0] + + txns := map[int64][]byte{} + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids) + return err + }) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err) + } + + transaction := &gomatrixserverlib.Transaction{} + if _, ok := txns[firstNID]; !ok { + return nil, nil, fmt.Errorf("Failed retrieving json blob for transaction: %d", firstNID) + } + + err = json.Unmarshal(txns[firstNID], transaction) + if err != nil { + return nil, nil, fmt.Errorf("Unmarshal transaction: %w", err) + } + + newReceipt := receipt.NewReceipt(firstNID) + return transaction, &newReceipt, nil +} + +func (d *Database) GetTransactionCount( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (int64, error) { + count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain()) + if err != nil { + return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err) + } + return count, nil +} diff --git a/relayapi/storage/sqlite3/relay_queue_json_table.go b/relayapi/storage/sqlite3/relay_queue_json_table.go new file mode 100644 index 000000000..502da3b00 --- /dev/null +++ b/relayapi/storage/sqlite3/relay_queue_json_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid INTEGER PRIMARY KEY AUTOINCREMENT, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); +` + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + + " VALUES ($1)" + +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + + " WHERE json_nid IN ($1)" + +type relayQueueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic + //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ + db: db, + } + _, err = db.Exec(relayQueueJSONSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertQueueJSONSQL}, + }.Prepare(db) +} + +func (s *relayQueueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (lastid int64, err error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return 0, fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err = res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("res.LastInsertId: %w", err) + } + return +} + +func (s *relayQueueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(nids)) + for k, v := range nids { + iNIDs[k] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) + return err +} + +func (s *relayQueueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) + selectStmt, err := txn.Prepare(selectSQL) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(jsonNIDs)) + for k, v := range jsonNIDs { + iNIDs[k] = v + } + + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, selectStmt) + rows, err := stmt.QueryContext(ctx, iNIDs...) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err) + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/relayapi/storage/sqlite3/relay_queue_table.go b/relayapi/storage/sqlite3/relay_queue_table.go new file mode 100644 index 000000000..49c6b4de5 --- /dev/null +++ b/relayapi/storage/sqlite3/relay_queue_table.go @@ -0,0 +1,168 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The domain part of the user ID the m.room.member event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the relayapi_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); +` + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + + " WHERE server_name = $1" + + " ORDER BY json_nid" + + " LIMIT $2" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + + " WHERE server_name = $1" + +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt + // deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ + db: db, + } + _, err = db.Exec(relayQueueSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertQueueEntryStmt, insertQueueEntrySQL}, + {&s.selectQueueEntriesStmt, selectQueueEntriesSQL}, + {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL}, + }.Prepare(db) +} + +func (s *relayQueueStatements) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *relayQueueStatements) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err) + } + + params := make([]interface{}, len(jsonNIDs)+1) + params[0] = serverName + for k, v := range jsonNIDs { + params[k+1] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *relayQueueStatements) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *relayQueueStatements) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/relayapi/storage/sqlite3/storage.go b/relayapi/storage/sqlite3/storage.go new file mode 100644 index 000000000..3ed4ab046 --- /dev/null +++ b/relayapi/storage/sqlite3/storage.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the federation sender +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + return nil, err + } + queue, err := NewSQLiteRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/storage.go b/relayapi/storage/storage.go new file mode 100644 index 000000000..16ecbcfb7 --- /dev/null +++ b/relayapi/storage/storage.go @@ -0,0 +1,46 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !wasm +// +build !wasm + +package storage + +import ( + "fmt" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (Database, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/relayapi/storage/tables/interface.go b/relayapi/storage/tables/interface.go new file mode 100644 index 000000000..9056a5678 --- /dev/null +++ b/relayapi/storage/tables/interface.go @@ -0,0 +1,66 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayQueue table contains a mapping of server name to transaction id and the corresponding nid. +// These are the transactions being stored for the given destination server. +// The nids correspond to entries in the RelayQueueJSON table. +type RelayQueue interface { + // Adds a new transaction_id: server_name mapping with associated json table nid to the table. + // Will ensure only one transaction id is present for each server_name: nid mapping. + // Adding duplicates will silently do nothing. + InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + + // Removes multiple entries from the table corresponding the the list of nids provided. + // If any of the provided nids don't match a row in the table, that deletion is considered + // successful. + DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + + // Get a list of nids associated with the provided server name. + // Returns up to `limit` nids. The entries are returned oldest first. + // Will return an empty list if no matches were found. + SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + + // Get the number of entries in the table associated with the provided server name. + // If there are no matching rows, a count of 0 is returned with err set to nil. + SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +// RelayQueueJSON table contains a map of nid to the raw transaction json. +type RelayQueueJSON interface { + // Adds a new transaction to the table. + // Adding a duplicate transaction will result in a new row being added and a new unique nid. + // return: unique nid representing this entry. + InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + + // Removes multiple nids from the table. + // If any of the provided nids don't match a row in the table, that deletion is considered + // successful. + DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + + // Get the transaction json corresponding to the provided nids. + // Will return a partial result containing any matching nid from the table. + // Will return an empty map if no matches were found. + // It is the caller's responsibility to deal with the results appropriately. + // return: map indexed by nid of each matching transaction json. + SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} diff --git a/relayapi/storage/tables/relay_queue_json_table_test.go b/relayapi/storage/tables/relay_queue_json_table_test.go new file mode 100644 index 000000000..efa3363e5 --- /dev/null +++ b/relayapi/storage/tables/relay_queue_json_table_test.go @@ -0,0 +1,173 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") +) + +func mustCreateTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + + return txn +} + +type RelayQueueJSONDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.RelayQueueJSON +} + +func mustCreateQueueJSONTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueJSONDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.RelayQueueJSON + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayQueueJSONTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayQueueJSONDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + _, err = db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + var storedJSON map[int64][]byte + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, 1, len(storedJSON)) + + var storedTx gomatrixserverlib.Transaction + json.Unmarshal(storedJSON[1], &storedTx) + + assert.Equal(t, transaction, storedTx) + }) +} + +func TestShouldDeleteTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + storedJSON := map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + storedJSON = map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, 0, len(storedJSON)) + }) +} diff --git a/relayapi/storage/tables/relay_queue_table_test.go b/relayapi/storage/tables/relay_queue_table_test.go new file mode 100644 index 000000000..99f9922c0 --- /dev/null +++ b/relayapi/storage/tables/relay_queue_table_test.go @@ -0,0 +1,229 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type RelayQueueDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.RelayQueue +} + +func mustCreateQueueTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.RelayQueue + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayQueueTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayQueueTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayQueueDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, nid, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) + }) +} + +func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(2) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName = gomatrixserverlib.ServerName("domain") + oldestNID := int64(1) + err = db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, oldestNID) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 1) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, oldestNID, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) + + retrievedNids, err = db.Table.SelectQueueEntries(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, oldestNID, retrievedNids[0]) + assert.Equal(t, nid, retrievedNids[1]) + assert.Equal(t, 2, len(retrievedNids)) + }) +} + +func TestShouldDeleteQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(0), count) + }) +} + +func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano())) + serverName2 := gomatrixserverlib.ServerName("domain2") + nid2 := int64(2) + transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano())) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(1), count) + + count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(1), count) + }) +} diff --git a/setup/base/base.go b/setup/base/base.go index ff38209fb..de8f81517 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -595,6 +595,12 @@ func (b *BaseDendrite) WaitForShutdown() { logrus.Warnf("failed to flush all Sentry events!") } } + if b.Fulltext != nil { + err := b.Fulltext.Close() + if err != nil { + logrus.Warnf("failed to close full text search!") + } + } logrus.Warnf("Dendrite is exiting now") } diff --git a/setup/config/config.go b/setup/config/config.go index 41d2b6674..2b38cd512 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -62,6 +62,7 @@ type Dendrite struct { RoomServer RoomServer `yaml:"room_server"` SyncAPI SyncAPI `yaml:"sync_api"` UserAPI UserAPI `yaml:"user_api"` + RelayAPI RelayAPI `yaml:"relay_api"` MSCs MSCs `yaml:"mscs"` @@ -349,6 +350,7 @@ func (c *Dendrite) Defaults(opts DefaultOpts) { c.SyncAPI.Defaults(opts) c.UserAPI.Defaults(opts) c.AppServiceAPI.Defaults(opts) + c.RelayAPI.Defaults(opts) c.MSCs.Defaults(opts) c.Wiring() } @@ -361,7 +363,7 @@ func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) { &c.Global, &c.ClientAPI, &c.FederationAPI, &c.KeyServer, &c.MediaAPI, &c.RoomServer, &c.SyncAPI, &c.UserAPI, - &c.AppServiceAPI, &c.MSCs, + &c.AppServiceAPI, &c.RelayAPI, &c.MSCs, } { c.Verify(configErrs, isMonolith) } @@ -377,6 +379,7 @@ func (c *Dendrite) Wiring() { c.SyncAPI.Matrix = &c.Global c.UserAPI.Matrix = &c.Global c.AppServiceAPI.Matrix = &c.Global + c.RelayAPI.Matrix = &c.Global c.MSCs.Matrix = &c.Global c.ClientAPI.Derived = &c.Derived diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 0f853865f..6c198018d 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -18,6 +18,12 @@ type FederationAPI struct { // The default value is 16 if not specified, which is circa 18 hours. FederationMaxRetries uint32 `yaml:"send_max_retries"` + // P2P Feature: How many consecutive failures that we should tolerate when + // sending federation requests to a specific server until we should assume they + // are offline. If we assume they are offline then we will attempt to send + // messages to their relay server if we know of one that is appropriate. + P2PFederationRetriesUntilAssumedOffline uint32 `yaml:"p2p_retries_until_assumed_offline"` + // FederationDisableTLSValidation disables the validation of X.509 TLS certs // on remote federation endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` @@ -43,6 +49,7 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) { c.Database.Defaults(10) } c.FederationMaxRetries = 16 + c.P2PFederationRetriesUntilAssumedOffline = 2 c.DisableTLSValidation = false c.DisableHTTPKeepalives = false if opts.Generate { diff --git a/setup/config/config_relayapi.go b/setup/config/config_relayapi.go new file mode 100644 index 000000000..5a6b093d4 --- /dev/null +++ b/setup/config/config_relayapi.go @@ -0,0 +1,52 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +type RelayAPI struct { + Matrix *Global `yaml:"-"` + + InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` + ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` + + // The database stores information used by the relay queue to + // forward transactions to remote servers. + Database DatabaseOptions `yaml:"database,omitempty"` +} + +func (c *RelayAPI) Defaults(opts DefaultOpts) { + if !opts.Monolithic { + c.InternalAPI.Listen = "http://localhost:7775" + c.InternalAPI.Connect = "http://localhost:7775" + c.ExternalAPI.Listen = "http://[::]:8075" + c.Database.Defaults(10) + } + if opts.Generate { + if !opts.Monolithic { + c.Database.ConnectionString = "file:relayapi.db" + } + } +} + +func (c *RelayAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { + if isMonolith { // polylith required configs below + return + } + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "relay_api.database.connection_string", string(c.Database.ConnectionString)) + } + checkURL(configErrs, "relay_api.external_api.listen", string(c.ExternalAPI.Listen)) + checkURL(configErrs, "relay_api.internal_api.listen", string(c.InternalAPI.Listen)) + checkURL(configErrs, "relay_api.internal_api.connect", string(c.InternalAPI.Connect)) +} diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 3408bf46d..ffbf4c3c5 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -20,11 +20,12 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) func TestLoadConfigRelative(t *testing.T) { - _, err := loadConfig("/my/config/dir", []byte(testConfig), + cfg, err := loadConfig("/my/config/dir", []byte(testConfig), mockReadFile{ "/my/config/dir/matrix_key.pem": testKey, "/my/config/dir/tls_cert.pem": testCert, @@ -34,6 +35,15 @@ func TestLoadConfigRelative(t *testing.T) { if err != nil { t.Error("failed to load config:", err) } + + configErrors := &ConfigErrors{} + cfg.Verify(configErrors, false) + if len(*configErrors) > 0 { + for _, err := range *configErrors { + logrus.Errorf("Configuration error: %s", err) + } + t.Error("configuration verification failed") + } } const testConfig = ` @@ -68,6 +78,8 @@ global: display_name: "Server alerts" avatar: "" room_name: "Server Alerts" + jetstream: + addresses: ["test"] app_service_api: internal_api: listen: http://localhost:7777 @@ -84,7 +96,7 @@ client_api: connect: http://localhost:7771 external_api: listen: http://[::]:8071 - registration_disabled: false + registration_disabled: true registration_shared_secret: "" enable_registration_captcha: false recaptcha_public_key: "" @@ -112,6 +124,8 @@ federation_api: connect: http://localhost:7772 external_api: listen: http://[::]:8072 + database: + connection_string: file:federationapi.db key_server: internal_api: listen: http://localhost:7779 @@ -194,6 +208,17 @@ user_api: max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 +relay_api: + internal_api: + listen: http://localhost:7775 + connect: http://localhost:7775 + external_api: + listen: http://[::]:8075 + database: + connection_string: file:relayapi.db +mscs: + database: + connection_string: file:mscs.db tracing: enabled: false jaeger: diff --git a/setup/monolith.go b/setup/monolith.go index 41a897024..5bbe4019e 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -23,6 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal/transactions" keyAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/mediaapi" + "github.com/matrix-org/dendrite/relayapi" + relayAPI "github.com/matrix-org/dendrite/relayapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" @@ -44,6 +46,7 @@ type Monolith struct { RoomserverAPI roomserverAPI.RoomserverInternalAPI UserAPI userapi.UserInternalAPI KeyAPI keyAPI.KeyInternalAPI + RelayAPI relayAPI.RelayInternalAPI // Optional ExtPublicRoomsProvider api.ExtraPublicRoomsProvider @@ -71,4 +74,8 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) { syncapi.AddPublicRoutes( base, m.UserAPI, m.RoomserverAPI, m.KeyAPI, ) + + if m.RelayAPI != nil { + relayapi.AddPublicRoutes(base, m.KeyRing, m.RelayAPI) + } } diff --git a/test/db.go b/test/db.go index 54ded6adb..d2f405d49 100644 --- a/test/db.go +++ b/test/db.go @@ -101,7 +101,6 @@ func currentUser() string { // Returns the connection string to use and a close function which must be called when the test finishes. // Calling this function twice will return the same database, which will have data from previous tests // unless close() is called. -// TODO: namespace for concurrent package tests func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { if dbType == DBTypeSQLite { // this will be made in the t.TempDir, which is unique per test diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go new file mode 100644 index 000000000..cc9e1e8fd --- /dev/null +++ b/test/memory_federation_db.go @@ -0,0 +1,488 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "encoding/json" + "errors" + "sync" + "time" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/federationapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var nidMutex sync.Mutex +var nid = int64(0) + +type InMemoryFederationDatabase struct { + dbMutex sync.Mutex + pendingPDUServers map[gomatrixserverlib.ServerName]struct{} + pendingEDUServers map[gomatrixserverlib.ServerName]struct{} + blacklistedServers map[gomatrixserverlib.ServerName]struct{} + assumedOffline map[gomatrixserverlib.ServerName]struct{} + pendingPDUs map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent + pendingEDUs map[*receipt.Receipt]*gomatrixserverlib.EDU + associatedPDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} + associatedEDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} + relayServers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName +} + +func NewInMemoryFederationDatabase() *InMemoryFederationDatabase { + return &InMemoryFederationDatabase{ + pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + assumedOffline: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUs: make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingEDUs: make(map[*receipt.Receipt]*gomatrixserverlib.EDU), + associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), + associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), + relayServers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName), + } +} + +func (d *InMemoryFederationDatabase) StoreJSON( + ctx context.Context, + js string, +) (*receipt.Receipt, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal([]byte(js), &event); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + newReceipt := receipt.NewReceipt(nid) + d.pendingPDUs[&newReceipt] = &event + return &newReceipt, nil + } + + var edu gomatrixserverlib.EDU + if err := json.Unmarshal([]byte(js), &edu); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + newReceipt := receipt.NewReceipt(nid) + d.pendingEDUs[&newReceipt] = &edu + return &newReceipt, nil + } + + return nil, errors.New("Failed to determine type of json to store") +} + +func (d *InMemoryFederationDatabase) GetPendingPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + pduCount := 0 + pdus = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) + if receipts, ok := d.associatedPDUs[serverName]; ok { + for dbReceipt := range receipts { + if event, ok := d.pendingPDUs[dbReceipt]; ok { + pdus[dbReceipt] = event + pduCount++ + if pduCount == limit { + break + } + } + } + } + return pdus, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + eduCount := 0 + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) + if receipts, ok := d.associatedEDUs[serverName]; ok { + for dbReceipt := range receipts { + if event, ok := d.pendingEDUs[dbReceipt]; ok { + edus[dbReceipt] = event + eduCount++ + if eduCount == limit { + break + } + } + } + } + return edus, nil +} + +func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.ServerName]struct{}, + dbReceipt *receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingPDUs[dbReceipt]; ok { + for destination := range destinations { + if _, ok := d.associatedPDUs[destination]; !ok { + d.associatedPDUs[destination] = make(map[*receipt.Receipt]struct{}) + } + d.associatedPDUs[destination][dbReceipt] = struct{}{} + } + + return nil + } else { + return errors.New("PDU doesn't exist") + } +} + +func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.ServerName]struct{}, + dbReceipt *receipt.Receipt, + eduType string, + expireEDUTypes map[string]time.Duration, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingEDUs[dbReceipt]; ok { + for destination := range destinations { + if _, ok := d.associatedEDUs[destination]; !ok { + d.associatedEDUs[destination] = make(map[*receipt.Receipt]struct{}) + } + d.associatedEDUs[destination][dbReceipt] = struct{}{} + } + + return nil + } else { + return errors.New("EDU doesn't exist") + } +} + +func (d *InMemoryFederationDatabase) CleanPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipts []*receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if pdus, ok := d.associatedPDUs[serverName]; ok { + for _, dbReceipt := range receipts { + delete(pdus, dbReceipt) + } + } + + return nil +} + +func (d *InMemoryFederationDatabase) CleanEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipts []*receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if edus, ok := d.associatedEDUs[serverName]; ok { + for _, dbReceipt := range receipts { + delete(edus, dbReceipt) + } + } + + return nil +} + +func (d *InMemoryFederationDatabase) GetPendingPDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if pdus, ok := d.associatedPDUs[serverName]; ok { + count = int64(len(pdus)) + } + return count, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if edus, ok := d.associatedEDUs[serverName]; ok { + count = int64(len(edus)) + } + return count, nil +} + +func (d *InMemoryFederationDatabase) GetPendingPDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingPDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingEDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *InMemoryFederationDatabase) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers[serverName] = struct{}{} + return nil +} + +func (d *InMemoryFederationDatabase) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.blacklistedServers, serverName) + return nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersFromBlacklist() error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *InMemoryFederationDatabase) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + isBlacklisted := false + if _, ok := d.blacklistedServers[serverName]; ok { + isBlacklisted = true + } + + return isBlacklisted, nil +} + +func (d *InMemoryFederationDatabase) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline[serverName] = struct{}{} + return nil +} + +func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.assumedOffline, serverName) + return nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine( + ctx context.Context, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *InMemoryFederationDatabase) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + assumedOffline := false + if _, ok := d.assumedOffline[serverName]; ok { + assumedOffline = true + } + + return assumedOffline, nil +} + +func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + knownRelayServers := []gomatrixserverlib.ServerName{} + if relayServers, ok := d.relayServers[serverName]; ok { + knownRelayServers = relayServers + } + + return knownRelayServers, nil +} + +func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if knownRelayServers, ok := d.relayServers[serverName]; ok { + for _, relayServer := range relayServers { + alreadyKnown := false + for _, knownRelayServer := range knownRelayServers { + if relayServer == knownRelayServer { + alreadyKnown = true + } + } + if !alreadyKnown { + d.relayServers[serverName] = append(d.relayServers[serverName], relayServer) + } + } + } else { + d.relayServers[serverName] = relayServers + } + + return nil +} + +func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) FetcherName() string { + return "" +} + +func (d *InMemoryFederationDatabase) StoreKeys(ctx context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error { + return nil +} + +func (d *InMemoryFederationDatabase) UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context.Context) error { + return nil +} + +func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { + return nil +} + +func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { + return nil +} + +func (d *InMemoryFederationDatabase) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) DeleteExpiredEDUs(ctx context.Context) error { + return nil +} + +func (d *InMemoryFederationDatabase) PurgeRoom(ctx context.Context, roomID string) error { + return nil +} diff --git a/test/memory_relay_db.go b/test/memory_relay_db.go new file mode 100644 index 000000000..db93919df --- /dev/null +++ b/test/memory_relay_db.go @@ -0,0 +1,140 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + + "github.com/matrix-org/gomatrixserverlib" +) + +type InMemoryRelayDatabase struct { + nid int64 + nidMutex sync.Mutex + transactions map[int64]json.RawMessage + associations map[gomatrixserverlib.ServerName][]int64 +} + +func NewInMemoryRelayDatabase() *InMemoryRelayDatabase { + return &InMemoryRelayDatabase{ + nid: 1, + nidMutex: sync.Mutex{}, + transactions: make(map[int64]json.RawMessage), + associations: make(map[gomatrixserverlib.ServerName][]int64), + } +} + +func (d *InMemoryRelayDatabase) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + if _, ok := d.associations[serverName]; !ok { + d.associations[serverName] = []int64{} + } + d.associations[serverName] = append(d.associations[serverName], nid) + return nil +} + +func (d *InMemoryRelayDatabase) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + for _, nid := range jsonNIDs { + for index, associatedNID := range d.associations[serverName] { + if associatedNID == nid { + d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...) + } + } + } + + return nil +} + +func (d *InMemoryRelayDatabase) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + results := []int64{} + resultCount := limit + if limit > len(d.associations[serverName]) { + resultCount = len(d.associations[serverName]) + } + if resultCount > 0 { + for i := 0; i < resultCount; i++ { + results = append(results, d.associations[serverName][i]) + } + } + + return results, nil +} + +func (d *InMemoryRelayDatabase) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + return int64(len(d.associations[serverName])), nil +} + +func (d *InMemoryRelayDatabase) InsertQueueJSON( + ctx context.Context, + txn *sql.Tx, + json string, +) (int64, error) { + d.nidMutex.Lock() + defer d.nidMutex.Unlock() + + nid := d.nid + d.transactions[nid] = []byte(json) + d.nid++ + + return nid, nil +} + +func (d *InMemoryRelayDatabase) DeleteQueueJSON( + ctx context.Context, + txn *sql.Tx, + nids []int64, +) error { + for _, nid := range nids { + delete(d.transactions, nid) + } + + return nil +} + +func (d *InMemoryRelayDatabase) SelectQueueJSON( + ctx context.Context, + txn *sql.Tx, + jsonNIDs []int64, +) (map[int64][]byte, error) { + result := make(map[int64][]byte) + for _, nid := range jsonNIDs { + if transaction, ok := d.transactions[nid]; ok { + result[nid] = transaction + } + } + + return result, nil +} diff --git a/test/testrig/base.go b/test/testrig/base.go index 9773da223..dfc0d8aaf 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -67,9 +67,10 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f case test.DBTypeSQLite: cfg.Defaults(config.DefaultOpts{ Generate: true, - Monolithic: false, // because we need a database per component + Monolithic: true, }) cfg.Global.ServerName = "test" + // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) @@ -83,6 +84,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f cfg.RoomServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "roomserver.db")) cfg.SyncAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "syncapi.db")) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "userapi.db")) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "relayapi.db")) base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) return base, func() { From ace44458b25768099f7b86663f2bb45ddf0d39c9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 26 Jan 2023 08:25:39 +0100 Subject: [PATCH 67/67] Bump commonmarker from 0.23.6 to 0.23.7 in /docs (#2952) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [commonmarker](https://github.com/gjtorikian/commonmarker) from 0.23.6 to 0.23.7.
Release notes

Sourced from commonmarker's releases.

v0.23.7

What's Changed

Full Changelog: https://github.com/gjtorikian/commonmarker/compare/v0.23.6...v0.23.7

v0.23.7.pre1

What's Changed

Full Changelog: https://github.com/gjtorikian/commonmarker/compare/v0.23.6...v0.23.7.pre1

Changelog

Sourced from commonmarker's changelog.

Changelog

v1.0.0.pre6 (2023-01-09)

Full Changelog

Closed issues:

  • Cargo.lock prevents Ruby 3.2.0 from installing commonmarker v1.0.0.pre4 #211

Merged pull requests:

  • always use rb_sys (don't use Ruby's emerging cargo tooling where available) #213 (kivikakk)

v1.0.0.pre5 (2023-01-08)

Full Changelog

Merged pull requests:

v1.0.0.pre4 (2022-12-28)

Full Changelog

Closed issues:

  • Will the cmark-gfm branch continue to be maintained for awhile? #207

Merged pull requests:

v1.0.0.pre3 (2022-11-30)

Full Changelog

Closed issues:

  • Code block incorrectly parsed in commonmarker 1.0.0.pre #202

Merged pull requests:

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=commonmarker&package-manager=bundler&previous-version=0.23.6&new-version=0.23.7)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) - `@dependabot use these labels` will set the current labels as the default for future PRs for this repo and language - `@dependabot use these reviewers` will set the current reviewers as the default for future PRs for this repo and language - `@dependabot use these assignees` will set the current assignees as the default for future PRs for this repo and language - `@dependabot use this milestone` will set the current milestone as the default for future PRs for this repo and language You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/matrix-org/dendrite/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/Gemfile.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index 509a8cbcf..5d79365f8 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -14,7 +14,7 @@ GEM execjs coffee-script-source (1.11.1) colorator (1.1.0) - commonmarker (0.23.6) + commonmarker (0.23.7) concurrent-ruby (1.1.10) dnsruby (1.61.9) simpleidn (~> 0.1)