From eeeb3017d662ad6777c1398b325aa98bc36bae94 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 16 Jan 2023 11:52:30 +0000 Subject: [PATCH 01/14] Switch the default config option values for `recaptcha_sitekey_class` and `recaptcha_form_field` (#2939) Attempting to use the [web auth fallback mechanism](https://spec.matrix.org/v1.5/client-server-api/#fallback) for Google ReCAPTCHA with the default setting for `client_api.recaptcha_sitekey_class` of "g-recaptcha-response" results in no captcha being rendered: ![image](https://user-images.githubusercontent.com/1342360/212482321-14980045-6e20-4d59-adaa-59a01ad88367.png) I cross-checked the captcha code between [dendrite.matrix.org's fallback page](https://dendrite.matrix.org/_matrix/client/r0/auth/m.login.recaptcha/fallback/web?session=asdhjaksd) and [matrix-client.matrix.org's one](https://matrix-client.matrix.org/_matrix/client/r0/auth/m.login.recaptcha/fallback/web?session=asdhjaksd) (which both use the same captcha public key) and noticed a discrepancy in the `class` attribute of the div that renders the captcha. [ReCAPTCHA's docs state](https://developers.google.com/recaptcha/docs/v3#automatically_bind_the_challenge_to_a_button) to use "g-recaptcha" as the class for the submit button. I noticed this when user `@parappanon:parappa.party` reported that they were also seeing no captcha being rendered on their Dendrite instance. Changing `client_api.recaptcha_sitekey_class` to "g-recaptcha" caused their captcha to render properly as well. There may have been a change in the class name from ReCAPTCHA v2 to v3? The [docs for v2](https://developers.google.com/recaptcha/docs/display#auto_render) also request one uses "g-recaptcha" though. Thus I propose changing the default setting to unbreak people's recaptcha auth fallback pages. Should fix dendrite.matrix.org as well. --- setup/config/config_clientapi.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index 11628b1b0..1deba6bb5 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -85,10 +85,10 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { c.RecaptchaApiJsUrl = "https://www.google.com/recaptcha/api.js" } if c.RecaptchaFormField == "" { - c.RecaptchaFormField = "g-recaptcha" + c.RecaptchaFormField = "g-recaptcha-response" } if c.RecaptchaSitekeyClass == "" { - c.RecaptchaSitekeyClass = "g-recaptcha-response" + c.RecaptchaSitekeyClass = "g-recaptcha" } checkNotEmpty(configErrs, "client_api.recaptcha_public_key", c.RecaptchaPublicKey) checkNotEmpty(configErrs, "client_api.recaptcha_private_key", c.RecaptchaPrivateKey) From 8582c7520abbfca680da9ba16e40a9a92b9fd21c Mon Sep 17 00:00:00 2001 From: Umar Getagazov Date: Tue, 17 Jan 2023 11:07:42 +0300 Subject: [PATCH 02/14] Omit state field from `/messages` response if empty (#2940) The field type is `[ClientEvent]` in the [spec](https://spec.matrix.org/v1.5/client-server-api/#get_matrixclientv3roomsroomidmessages), but right now `null` can also be returned. Omit the field completely if it's empty. Some clients (rightfully) assume it's either not present at all or it's of the right type (see https://github.com/matrix-org/matrix-react-sdk/pull/9913). ### Pull Request Checklist * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * The PR is a simple struct tag fix * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Umar Getagazov ` Signed-off-by: Umar Getagazov --- syncapi/routing/messages.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index 0d740ebfc..cafba17c9 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -57,7 +57,7 @@ type messagesResp struct { StartStream string `json:"start_stream,omitempty"` // NOTSPEC: used by Cerulean, so clients can hit /messages then immediately /sync with a latest sync token End string `json:"end,omitempty"` Chunk []gomatrixserverlib.ClientEvent `json:"chunk"` - State []gomatrixserverlib.ClientEvent `json:"state"` + State []gomatrixserverlib.ClientEvent `json:"state,omitempty"` } // OnIncomingMessagesRequest implements the /messages endpoint from the From 0d0280cf5ff71ec975b17d0f6dadcae7e46574b5 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 17 Jan 2023 10:08:23 +0100 Subject: [PATCH 03/14] `/sync` performance optimizations (#2927) Since #2849 there is no limit for the current state we fetch to calculate history visibility. In large rooms this can cause us to fetch thousands of membership events we don't really care about. This now only gets the state event types and senders in our timeline, which should significantly reduce the amount of events we fetch from the database. Also removes `MaxTopologicalPosition`, as it is an unnecessary DB call, given we use the result in `topological_position < $1` calls. --- federationapi/consumers/roomserver.go | 2 +- syncapi/routing/memberships.go | 19 +- syncapi/routing/messages.go | 31 +-- syncapi/storage/interface.go | 2 - .../output_room_events_topology_table.go | 19 -- syncapi/storage/shared/storage_sync.go | 17 +- .../output_room_events_topology_table.go | 16 -- syncapi/storage/storage_test.go | 88 ++++++- syncapi/storage/tables/interface.go | 2 - syncapi/streams/stream_pdu.go | 35 ++- syncapi/syncapi_test.go | 246 ++++++++++++++++++ 11 files changed, 372 insertions(+), 105 deletions(-) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 0c1080afa..52b5744a6 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -195,7 +195,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew } // If we added new hosts, inform them about our known presence events for this room - if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { + if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil { membership, _ := ore.Event.Membership() if membership == gomatrixserverlib.Join { s.sendPresence(ore.Event.RoomID(), addsJoinedHosts) diff --git a/syncapi/routing/memberships.go b/syncapi/routing/memberships.go index 3fcc3235c..9ffdf513f 100644 --- a/syncapi/routing/memberships.go +++ b/syncapi/routing/memberships.go @@ -16,16 +16,16 @@ package routing import ( "encoding/json" + "math" "net/http" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/storage" "github.com/matrix-org/dendrite/syncapi/types" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" ) type getMembershipResponse struct { @@ -87,19 +87,18 @@ func GetMemberships( if err != nil { return jsonerror.InternalServerError() } + defer db.Rollback() // nolint: errcheck atToken, err := types.NewTopologyTokenFromString(at) if err != nil { + atToken = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} if queryRes.HasBeenInRoom && !queryRes.IsInRoom { // If you have left the room then this will be the members of the room when you left. atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID) - } else { - // If you are joined to the room then this will be the current members of the room. - atToken, err = db.MaxTopologicalPosition(req.Context(), roomID) - } - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") - return jsonerror.InternalServerError() + if err != nil { + util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'") + return jsonerror.InternalServerError() + } } } diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index cafba17c9..4a01ec357 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -17,6 +17,7 @@ package routing import ( "context" "fmt" + "math" "net/http" "sort" "time" @@ -177,10 +178,11 @@ func OnIncomingMessagesRequest( // If "to" isn't provided, it defaults to either the earliest stream // position (if we're going backward) or to the latest one (if we're // going forward). - to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed") - return jsonerror.InternalServerError() + to = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} + if backwardOrdering { + // go 1 earlier than the first event so we correctly fetch the earliest event + // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. + to = types.TopologyToken{} } wasToProvided = false } @@ -577,24 +579,3 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] return events, nil } - -// setToDefault returns the default value for the "to" query parameter of a -// request to /messages if not provided. It defaults to either the earliest -// topological position (if we're going backward) or to the latest one (if we're -// going forward). -// Returns an error if there was an issue with retrieving the latest position -// from the database -func setToDefault( - ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool, - roomID string, -) (to types.TopologyToken, err error) { - if backwardOrdering { - // go 1 earlier than the first event so we correctly fetch the earliest event - // this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound. - to = types.TopologyToken{} - } else { - to, err = snapshot.MaxTopologicalPosition(ctx, roomID) - } - - return -} diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 4e22f8a6f..a4ba82327 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -84,8 +84,6 @@ type DatabaseTransaction interface { EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error) - // MaxTopologicalPosition returns the highest topological position for a given room. - MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error) // StreamEventsToEvents converts streamEvent to Event. If device is non-nil and // matches the streamevent.transactionID device then the transaction ID gets // added to the unsigned section of the output event. diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 6fab900eb..d0e99f267 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -65,14 +65,6 @@ const selectPositionInTopologySQL = "" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" - // Select the max topological position for the room, then sort by stream position and take the highest, - // returning both topological and stream positions. -const selectMaxPositionInTopologySQL = "" + - "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + - " WHERE topological_position=(" + - "SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" + - ") ORDER BY stream_position DESC LIMIT 1" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" @@ -84,7 +76,6 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt } @@ -107,9 +98,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { return nil, err } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { return nil, err } @@ -189,10 +177,3 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } - -func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( - ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, spos types.StreamPosition, err error) { - err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos) - return -} diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index c6933486c..7b07cac5e 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "math" "github.com/matrix-org/gomatrixserverlib" "github.com/tidwall/gjson" @@ -269,16 +270,6 @@ func (d *DatabaseTransaction) BackwardExtremitiesForRoom( return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID) } -func (d *DatabaseTransaction) MaxTopologicalPosition( - ctx context.Context, roomID string, -) (types.TopologyToken, error) { - depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) - if err != nil { - return types.TopologyToken{}, err - } - return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil -} - func (d *DatabaseTransaction) EventPositionInTopology( ctx context.Context, eventID string, ) (types.TopologyToken, error) { @@ -297,11 +288,7 @@ func (d *DatabaseTransaction) StreamToTopologicalPosition( case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward return types.TopologyToken{PDUPosition: streamPos}, nil case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward - topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID) - if err != nil { - return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err) - } - return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil + return types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, nil case err != nil: // some other error happened return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err) default: diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 81b264988..879456441 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -61,10 +61,6 @@ const selectPositionInTopologySQL = "" + "SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE event_id = $1" -const selectMaxPositionInTopologySQL = "" + - "SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" + - " WHERE room_id = $1 ORDER BY stream_position DESC" - const selectStreamToTopologicalPositionAscSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;" @@ -77,7 +73,6 @@ type outputRoomEventsTopologyStatements struct { selectEventIDsInRangeASCStmt *sql.Stmt selectEventIDsInRangeDESCStmt *sql.Stmt selectPositionInTopologyStmt *sql.Stmt - selectMaxPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt } @@ -102,9 +97,6 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { return nil, err } - if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil { - return nil, err - } if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { return nil, err } @@ -182,11 +174,3 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } - -func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology( - ctx context.Context, txn *sql.Tx, roomID string, -) (pos types.StreamPosition, spos types.StreamPosition, err error) { - stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt) - err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos) - return -} diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index 166ddd233..e65367d8b 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "math" "reflect" "testing" @@ -199,10 +200,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { _ = MustWriteEvents(t, db, events) WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { - from, err := snapshot.MaxTopologicalPosition(ctx, r.ID) - if err != nil { - t.Fatalf("failed to get MaxTopologicalPosition: %s", err) - } + from := types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} t.Logf("max topo pos = %+v", from) // head towards the beginning of time to := types.TopologyToken{} @@ -219,6 +217,88 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { }) } +func TestStreamToTopologicalPosition(t *testing.T) { + alice := test.NewUser(t) + r := test.NewRoom(t, alice) + + testCases := []struct { + name string + roomID string + streamPos types.StreamPosition + backwardOrdering bool + wantToken types.TopologyToken + }{ + { + name: "forward ordering found streamPos returns found position", + roomID: r.ID, + streamPos: 1, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1}, + }, + { + name: "forward ordering not found streamPos returns max position", + roomID: r.ID, + streamPos: 100, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, + }, + { + name: "backward ordering found streamPos returns found position", + roomID: r.ID, + streamPos: 1, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1}, + }, + { + name: "backward ordering not found streamPos returns maxDepth with param pduPosition", + roomID: r.ID, + streamPos: 100, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 5, PDUPosition: 100}, + }, + { + name: "backward non-existent room returns zero token", + roomID: "!doesnotexist:localhost", + streamPos: 1, + backwardOrdering: true, + wantToken: types.TopologyToken{Depth: 0, PDUPosition: 1}, + }, + { + name: "forward non-existent room returns max token", + roomID: "!doesnotexist:localhost", + streamPos: 1, + backwardOrdering: false, + wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close, closeBase := MustCreateDatabase(t, dbType) + defer close() + defer closeBase() + + txn, err := db.NewDatabaseTransaction(ctx) + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + MustWriteEvents(t, db, r.Events()) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + token, err := txn.StreamToTopologicalPosition(ctx, tc.roomID, tc.streamPos, tc.backwardOrdering) + if err != nil { + t.Fatal(err) + } + if tc.wantToken != token { + t.Fatalf("expected token %q, got %q", tc.wantToken, token) + } + }) + } + + }) +} + /* // The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth. // For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index c02e4ecc5..8366a67dc 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -91,8 +91,6 @@ type Topology interface { SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error) // SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to. SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) - // SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position. - SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) } diff --git a/syncapi/streams/stream_pdu.go b/syncapi/streams/stream_pdu.go index 4664276cf..44013e37c 100644 --- a/syncapi/streams/stream_pdu.go +++ b/syncapi/streams/stream_pdu.go @@ -384,19 +384,32 @@ func applyHistoryVisibilityFilter( roomID, userID string, recentEvents []*gomatrixserverlib.HeaderedEvent, ) ([]*gomatrixserverlib.HeaderedEvent, error) { - // We need to make sure we always include the latest states events, if they are in the timeline. - // We grep at least limit * 2 events, to ensure we really get the needed events. - filter := gomatrixserverlib.DefaultStateFilter() - stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) - if err != nil { - // Not a fatal error, we can continue without the stateEvents, - // they are only needed if there are state events in the timeline. - logrus.WithError(err).Warnf("Failed to get current room state for history visibility") + // We need to make sure we always include the latest state events, if they are in the timeline. + alwaysIncludeIDs := make(map[string]struct{}) + var stateTypes []string + var senders []string + for _, ev := range recentEvents { + if ev.StateKey() != nil { + stateTypes = append(stateTypes, ev.Type()) + senders = append(senders, ev.Sender()) + } } - alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents)) - for _, ev := range stateEvents { - alwaysIncludeIDs[ev.EventID()] = struct{}{} + + // Only get the state again if there are state events in the timeline + if len(stateTypes) > 0 { + filter := gomatrixserverlib.DefaultStateFilter() + filter.Types = &stateTypes + filter.Senders = &senders + stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil) + if err != nil { + return nil, fmt.Errorf("failed to get current room state for history visibility calculation: %w", err) + } + + for _, ev := range stateEvents { + alwaysIncludeIDs[ev.EventID()] = struct{}{} + } } + startTime := time.Now() events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync") if err != nil { diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 483274481..666a872f8 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -521,6 +521,252 @@ func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatr } } +func TestGetMembership(t *testing.T) { + alice := test.NewUser(t) + + aliceDev := userapi.Device{ + ID: "ALICEID", + UserID: alice.ID, + AccessToken: "ALICE_BEARER_TOKEN", + DisplayName: "Alice", + AccountType: userapi.AccountTypeUser, + } + + bob := test.NewUser(t) + bobDev := userapi.Device{ + ID: "BOBID", + UserID: bob.ID, + AccessToken: "notjoinedtoanyrooms", + } + + testCases := []struct { + name string + roomID string + additionalEvents func(t *testing.T, room *test.Room) + request func(t *testing.T, room *test.Room) *http.Request + wantOK bool + wantMemberCount int + useSleep bool // :/ + }{ + { + name: "/members - Alice joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "/members - Bob never joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Bob never joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": bobDev.AccessToken, + })) + }, + wantOK: false, + }, + { + name: "/joined_members - Alice joined", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: true, + }, + { + name: "Alice leaves before Bob joins, should not be able to see Bob", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "Alice leaves after Bob joins, should be able to see Bob", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 2, + }, + { + name: "/joined_members - Alice leaves, shouldn't be able to see members ", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(alice.ID)) + }, + useSleep: true, + wantOK: false, + }, + { + name: "'at' specified, returns memberships before Bob joins", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "at": "t2_5", + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + }, + useSleep: true, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "'membership=leave' specified, returns no memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "membership": "leave", + })) + }, + wantOK: true, + wantMemberCount: 0, + }, + { + name: "'not_membership=join' specified, returns no memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "not_membership": "join", + })) + }, + wantOK: true, + wantMemberCount: 0, + }, + { + name: "'not_membership=leave' & 'membership=join' specified, returns correct memberships", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + "not_membership": "leave", + "membership": "join", + })) + }, + additionalEvents: func(t *testing.T, room *test.Room) { + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "join", + }, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "leave", + }, test.WithStateKey(bob.ID)) + }, + wantOK: true, + wantMemberCount: 1, + }, + { + name: "non-existent room ID", + request: func(t *testing.T, room *test.Room) *http.Request { + return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", "!notavalidroom:test"), test.WithQueryParams(map[string]string{ + "access_token": aliceDev.AccessToken, + })) + }, + wantOK: false, + }, + } + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream) + + // Use an actual roomserver for this + rsAPI := roomserver.NewInternalAPI(base) + rsAPI.SetFederationAPI(nil, nil) + + AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{}) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + room := test.NewRoom(t, alice) + t.Cleanup(func() { + t.Logf("running cleanup for %s", tc.name) + }) + // inject additional events + if tc.additionalEvents != nil { + tc.additionalEvents(t, room) + } + if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // wait for the events to come down sync + if tc.useSleep { + time.Sleep(time.Millisecond * 100) + } else { + syncUntil(t, base, aliceDev.AccessToken, false, func(syncBody string) bool { + // wait for the last sent eventID to come down sync + path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID()) + return gjson.Get(syncBody, path).Exists() + }) + } + + w := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(w, tc.request(t, room)) + if w.Code != 200 && tc.wantOK { + t.Logf("%s", w.Body.String()) + t.Fatalf("got HTTP %d want %d", w.Code, 200) + } + t.Logf("[%s] Resp: %s", tc.name, w.Body.String()) + + // check we got the expected events + if tc.wantOK { + memberCount := len(gjson.GetBytes(w.Body.Bytes(), "chunk").Array()) + if memberCount != tc.wantMemberCount { + t.Fatalf("expected %d members, got %d", tc.wantMemberCount, memberCount) + } + } + }) + } + }) +} + func TestSendToDevice(t *testing.T) { test.WithAllDatabases(t, testSendToDevice) } From b55a7c238fb4b4db9ff4da0a25f0f83316d20f5e Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Tue, 17 Jan 2023 19:04:02 +0100 Subject: [PATCH 04/14] Version 0.10.9 (#2942) --- CHANGES.md | 20 ++++++++++++++++++++ internal/version.go | 2 +- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/CHANGES.md b/CHANGES.md index f5a82cfe2..fa8230659 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,25 @@ # Changelog +## Dendrite 0.10.9 (2023-01-17) + +### Features + +* Stale device lists are now cleaned up on startup, removing entries for users the server doesn't share a room with anymore +* Dendrite now has its own Helm chart +* Guest access is now handled correctly (disallow joins, kick guests on revocation of guest access, as well as over federation) + +### Fixes + +* Push rules have seen several tweaks and fixes, which should, for example, fix notifications for `m.read_receipts` +* Outgoing presence will now correctly be sent to newly joined hosts +* Fixes the `/_dendrite/admin/resetPassword/{userID}` admin endpoint to use the correct variable +* Federated backfilling for medium/large rooms has been fixed +* `/login` causing wrong device list updates has been resolved +* `/sync` should now return the correct room summary heroes +* The default config options for `recaptcha_sitekey_class` and `recaptcha_form_field` are now set correctly +* `/messages` now omits empty `state` to be more spec compliant (contributed by [handlerug](https://github.com/handlerug)) +* `/sync` has been optimised to only query state events for history visibility if they are really needed + ## Dendrite 0.10.8 (2022-11-29) ### Features diff --git a/internal/version.go b/internal/version.go index 685237b9e..ff31dd784 100644 --- a/internal/version.go +++ b/internal/version.go @@ -17,7 +17,7 @@ var build string const ( VersionMajor = 0 VersionMinor = 10 - VersionPatch = 8 + VersionPatch = 9 VersionTag = "" // example: "rc1" ) From 67f5c5bc1e837bbdee14d7d3388984ed8960528a Mon Sep 17 00:00:00 2001 From: genofire Date: Wed, 18 Jan 2023 08:45:34 +0100 Subject: [PATCH 05/14] =?UTF-8?q?fix(helm):=20extract=20image=20tag=20to?= =?UTF-8?q?=20value=20(and=20use=20as=20default=20from=20Chart.=E2=80=A6?= =?UTF-8?q?=20(#2934)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit improve image tag handling on the default helm way. with usage of appVersion from: https://github.com/matrix-org/dendrite/blob/0995dc48224b90432e38fa92345cf5735bca6090/helm/dendrite/Chart.yaml#L4 maybe you like to review @S7evinK ? ### Pull Request Checklist * [x] I have added Go unit tests or [Complement integration tests](https://github.com/matrix-org/complement) for this PR _or_ I have justified why this PR doesn't need tests * [x] Pull request includes a [sign off below using a legally identifiable name](https://matrix-org.github.io/dendrite/development/contributing#sign-off) _or_ I have already signed off privately Signed-off-by: `Geno ` --- helm/dendrite/Chart.yaml | 4 ++-- helm/dendrite/templates/_helpers.tpl | 4 +++- helm/dendrite/templates/deployment.yaml | 4 ++-- helm/dendrite/templates/jobs.yaml | 3 ++- helm/dendrite/templates/service.yaml | 2 +- helm/dendrite/values.yaml | 6 ++++-- 6 files changed, 14 insertions(+), 9 deletions(-) diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml index 15d1e6d19..6e6641c8d 100644 --- a/helm/dendrite/Chart.yaml +++ b/helm/dendrite/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v2 name: dendrite -version: "0.10.8" -appVersion: "0.10.8" +version: "0.10.9" +appVersion: "0.10.9" description: Dendrite Matrix Homeserver type: application keywords: diff --git a/helm/dendrite/templates/_helpers.tpl b/helm/dendrite/templates/_helpers.tpl index 291f351bc..026706588 100644 --- a/helm/dendrite/templates/_helpers.tpl +++ b/helm/dendrite/templates/_helpers.tpl @@ -15,9 +15,11 @@ {{- define "image.name" -}} -image: {{ .name }} +{{- with .Values.image -}} +image: {{ .repository }}:{{ .tag | default (printf "v%s" $.Chart.AppVersion) }} imagePullPolicy: {{ .pullPolicy }} {{- end -}} +{{- end -}} {{/* Expand the name of the chart. diff --git a/helm/dendrite/templates/deployment.yaml b/helm/dendrite/templates/deployment.yaml index 629ffe528..b463c7d0b 100644 --- a/helm/dendrite/templates/deployment.yaml +++ b/helm/dendrite/templates/deployment.yaml @@ -45,8 +45,8 @@ spec: persistentVolumeClaim: claimName: {{ default (print ( include "dendrite.fullname" . ) "-search-pvc") $.Values.persistence.search.existingClaim | quote }} containers: - - name: {{ $.Chart.Name }} - {{- include "image.name" $.Values.image | nindent 8 }} + - name: {{ .Chart.Name }} + {{- include "image.name" . | nindent 8 }} args: - '--config' - '/etc/dendrite/dendrite.yaml' diff --git a/helm/dendrite/templates/jobs.yaml b/helm/dendrite/templates/jobs.yaml index 76915694d..c10f358b0 100644 --- a/helm/dendrite/templates/jobs.yaml +++ b/helm/dendrite/templates/jobs.yaml @@ -8,6 +8,7 @@ metadata: name: {{ $name }} labels: app.kubernetes.io/component: signingkey-job + {{- include "dendrite.labels" . | nindent 4 }} --- apiVersion: rbac.authorization.k8s.io/v1 kind: Role @@ -80,7 +81,7 @@ spec: name: signing-key readOnly: true - name: generate-key - {{- include "image.name" $.Values.image | nindent 8 }} + {{- include "image.name" . | nindent 8 }} command: - sh - -c diff --git a/helm/dendrite/templates/service.yaml b/helm/dendrite/templates/service.yaml index 365a43f04..3b571df1f 100644 --- a/helm/dendrite/templates/service.yaml +++ b/helm/dendrite/templates/service.yaml @@ -13,5 +13,5 @@ spec: ports: - name: http protocol: TCP - port: 8008 + port: {{ .Values.service.port }} targetPort: 8008 \ No newline at end of file diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml index 2c6e80942..87027a886 100644 --- a/helm/dendrite/values.yaml +++ b/helm/dendrite/values.yaml @@ -1,8 +1,10 @@ image: # -- Docker repository/image to use - name: "ghcr.io/matrix-org/dendrite-monolith:v0.10.8" + repository: "ghcr.io/matrix-org/dendrite-monolith" # -- Kubernetes pullPolicy pullPolicy: IfNotPresent + # Overrides the image tag whose default is the chart appVersion. + tag: "" # signing key to use @@ -345,4 +347,4 @@ ingress: service: type: ClusterIP - port: 80 + port: 8008 From 738686ae686004c5efa9fe2096502cdc426c6dd8 Mon Sep 17 00:00:00 2001 From: Neil Date: Thu, 19 Jan 2023 20:02:32 +0000 Subject: [PATCH 06/14] Add `/_dendrite/admin/purgeRoom/{roomID}` (#2662) This adds a new admin endpoint `/_dendrite/admin/purgeRoom/{roomID}`. It completely erases all database entries for a given room ID. The roomserver will start by clearing all data for that room and then will generate an output event to notify downstream components (i.e. the sync API and federation API) to do the same. It does not currently clear media and it is currently not implemented for SQLite since it relies on SQL array operations right now. Co-authored-by: Neil Alexander Co-authored-by: Till Faelligen <2353100+S7evinK@users.noreply.github.com> --- clientapi/admin_test.go | 103 ++++++++++- clientapi/routing/admin.go | 32 ++++ clientapi/routing/routing.go | 6 + federationapi/consumers/roomserver.go | 15 +- federationapi/storage/interface.go | 2 + federationapi/storage/shared/storage.go | 15 ++ internal/sqlutil/sql.go | 21 +++ internal/sqlutil/sqlutil_test.go | 51 +++++- roomserver/api/api.go | 1 + roomserver/api/api_trace.go | 10 ++ roomserver/api/output.go | 8 + roomserver/api/perform.go | 8 + roomserver/internal/perform/perform_admin.go | 37 ++++ roomserver/inthttp/client.go | 12 ++ roomserver/inthttp/server.go | 5 + roomserver/roomserver_test.go | 165 ++++++++++++++++++ roomserver/storage/interface.go | 1 + .../storage/postgres/purge_statements.go | 133 ++++++++++++++ roomserver/storage/postgres/rooms_table.go | 14 ++ roomserver/storage/postgres/storage.go | 5 + roomserver/storage/shared/storage.go | 16 ++ .../storage/sqlite3/purge_statements.go | 153 ++++++++++++++++ roomserver/storage/sqlite3/rooms_table.go | 14 ++ .../storage/sqlite3/state_block_table.go | 3 +- .../storage/sqlite3/state_snapshot_table.go | 33 +++- roomserver/storage/sqlite3/storage.go | 6 + roomserver/storage/tables/interface.go | 7 + syncapi/consumers/roomserver.go | 21 +++ syncapi/storage/interface.go | 2 + .../postgres/backwards_extremities_table.go | 27 +-- syncapi/storage/postgres/invites_table.go | 31 ++-- syncapi/storage/postgres/memberships_table.go | 22 ++- .../postgres/notification_data_table.go | 12 ++ .../postgres/output_room_events_table.go | 12 ++ .../output_room_events_topology_table.go | 42 ++--- syncapi/storage/postgres/peeks_table.go | 39 +++-- syncapi/storage/postgres/receipt_table.go | 27 +-- syncapi/storage/shared/storage_consumer.go | 14 -- syncapi/storage/shared/storage_sync.go | 47 +++++ .../sqlite3/backwards_extremities_table.go | 27 +-- syncapi/storage/sqlite3/invites_table.go | 31 ++-- syncapi/storage/sqlite3/memberships_table.go | 12 ++ .../sqlite3/notification_data_table.go | 12 ++ .../sqlite3/output_room_events_table.go | 12 ++ .../output_room_events_topology_table.go | 42 ++--- syncapi/storage/sqlite3/peeks_table.go | 39 +++-- syncapi/storage/sqlite3/receipt_table.go | 27 +-- syncapi/storage/tables/interface.go | 9 + 48 files changed, 1213 insertions(+), 170 deletions(-) create mode 100644 roomserver/storage/postgres/purge_statements.go create mode 100644 roomserver/storage/sqlite3/purge_statements.go diff --git a/clientapi/admin_test.go b/clientapi/admin_test.go index 0d973f350..c7ca019ff 100644 --- a/clientapi/admin_test.go +++ b/clientapi/admin_test.go @@ -7,9 +7,12 @@ import ( "testing" "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/keyserver" "github.com/matrix-org/dendrite/roomserver" + "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/tidwall/gjson" @@ -41,7 +44,7 @@ func TestAdminResetPassword(t *testing.T) { userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) keyAPI.SetUserAPI(userAPI) // We mostly need the userAPI for this test, so nil for other APIs/caches etc. - AddPublicRoutes(base, nil, nil, nil, nil, nil, userAPI, nil, nil, nil) + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) // Create the users in the userapi and login accessTokens := map[*test.User]string{ @@ -112,6 +115,7 @@ func TestAdminResetPassword(t *testing.T) { } for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case t.Run(tc.name, func(t *testing.T) { req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/resetPassword/"+tc.userID) if tc.requestOpt != nil { @@ -132,3 +136,100 @@ func TestAdminResetPassword(t *testing.T) { } }) } + +func TestPurgeRoom(t *testing.T) { + aliceAdmin := test.NewUser(t, test.WithAccountType(uapi.AccountTypeAdmin)) + bob := test.NewUser(t) + room := test.NewRoom(t, aliceAdmin, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite Bob + room.CreateAndInsert(t, aliceAdmin, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, baseClose := testrig.CreateBaseDendrite(t, dbType) + defer baseClose() + + fedClient := base.CreateFederationClient() + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + + // this starts the JetStream consumers + syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI) + federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + rsAPI.SetFederationAPI(nil, nil) + keyAPI.SetUserAPI(userAPI) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // We mostly need the rsAPI for this test, so nil for other APIs/caches etc. + AddPublicRoutes(base, nil, rsAPI, nil, nil, nil, userAPI, nil, nil, nil) + + // Create the users in the userapi and login + accessTokens := map[*test.User]string{ + aliceAdmin: "", + } + for u := range accessTokens { + localpart, serverName, _ := gomatrixserverlib.SplitID('@', u.ID) + userRes := &uapi.PerformAccountCreationResponse{} + password := util.RandomString(8) + if err := userAPI.PerformAccountCreation(ctx, &uapi.PerformAccountCreationRequest{ + AccountType: u.AccountType, + Localpart: localpart, + ServerName: serverName, + Password: password, + }, userRes); err != nil { + t.Errorf("failed to create account: %s", err) + } + + req := test.NewRequest(t, http.MethodPost, "/_matrix/client/v3/login", test.WithJSONBody(t, map[string]interface{}{ + "type": authtypes.LoginTypePassword, + "identifier": map[string]interface{}{ + "type": "m.id.user", + "user": u.ID, + }, + "password": password, + })) + rec := httptest.NewRecorder() + base.PublicClientAPIMux.ServeHTTP(rec, req) + if rec.Code != http.StatusOK { + t.Fatalf("failed to login: %s", rec.Body.String()) + } + accessTokens[u] = gjson.GetBytes(rec.Body.Bytes(), "access_token").String() + } + + testCases := []struct { + name string + roomID string + wantOK bool + }{ + {name: "Can purge existing room", wantOK: true, roomID: room.ID}, + {name: "Can not purge non-existent room", wantOK: false, roomID: "!doesnotexist:localhost"}, + {name: "rejects invalid room ID", wantOK: false, roomID: "@doesnotexist:localhost"}, + } + + for _, tc := range testCases { + tc := tc // ensure we don't accidentally only test the last test case + t.Run(tc.name, func(t *testing.T) { + req := test.NewRequest(t, http.MethodPost, "/_dendrite/admin/purgeRoom/"+tc.roomID) + + req.Header.Set("Authorization", "Bearer "+accessTokens[aliceAdmin]) + + rec := httptest.NewRecorder() + base.DendriteAdminMux.ServeHTTP(rec, req) + t.Logf("%s", rec.Body.String()) + if tc.wantOK && rec.Code != http.StatusOK { + t.Fatalf("expected http status %d, got %d: %s", http.StatusOK, rec.Code, rec.Body.String()) + } + }) + } + + }) +} diff --git a/clientapi/routing/admin.go b/clientapi/routing/admin.go index dbd913376..4b4dedfd1 100644 --- a/clientapi/routing/admin.go +++ b/clientapi/routing/admin.go @@ -1,6 +1,7 @@ package routing import ( + "context" "encoding/json" "fmt" "net/http" @@ -98,6 +99,37 @@ func AdminEvacuateUser(req *http.Request, cfg *config.ClientAPI, device *userapi } } +func AdminPurgeRoom(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, rsAPI roomserverAPI.ClientRoomserverAPI) util.JSONResponse { + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.ErrorResponse(err) + } + roomID, ok := vars["roomID"] + if !ok { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("Expecting room ID."), + } + } + res := &roomserverAPI.PerformAdminPurgeRoomResponse{} + if err := rsAPI.PerformAdminPurgeRoom( + context.Background(), + &roomserverAPI.PerformAdminPurgeRoomRequest{ + RoomID: roomID, + }, + res, + ); err != nil { + return util.ErrorResponse(err) + } + if err := res.Error; err != nil { + return err.JSONResponse() + } + return util.JSONResponse{ + Code: 200, + JSON: res, + } +} + func AdminResetPassword(req *http.Request, cfg *config.ClientAPI, device *userapi.Device, userAPI userapi.ClientUserAPI) util.JSONResponse { if req.Body == nil { return util.JSONResponse{ diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index 09c2cd02f..93f6ea901 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -165,6 +165,12 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/purgeRoom/{roomID}", + httputil.MakeAdminAPI("admin_purge_room", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + return AdminPurgeRoom(req, cfg, device, rsAPI) + }), + ).Methods(http.MethodPost, http.MethodOptions) + dendriteAdminRouter.Handle("/admin/resetPassword/{userID}", httputil.MakeAdminAPI("admin_reset_password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { return AdminResetPassword(req, cfg, device, userAPI) diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 52b5744a6..82a4db3f7 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/federationapi/queue" @@ -90,8 +91,10 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms msg := msgs[0] // Guaranteed to exist if onMessage is called receivedType := api.OutputType(msg.Header.Get(jetstream.RoomEventType)) - // Only handle events we care about - if receivedType != api.OutputTypeNewRoomEvent && receivedType != api.OutputTypeNewInboundPeek { + // Only handle events we care about, avoids unneeded unmarshalling + switch receivedType { + case api.OutputTypeNewRoomEvent, api.OutputTypeNewInboundPeek, api.OutputTypePurgeRoom: + default: return true } @@ -126,6 +129,14 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms return false } + case api.OutputTypePurgeRoom: + log.WithField("room_id", output.PurgeRoom.RoomID).Warn("Purging room from federation API") + if err := s.db.PurgeRoom(ctx, output.PurgeRoom.RoomID); err != nil { + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from federation API") + } else { + logrus.WithField("room_id", output.PurgeRoom.RoomID).Warn("Room purged from federation API") + } + default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 276cd9a50..2b4d905fc 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -71,4 +71,6 @@ type Database interface { GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) // DeleteExpiredEDUs cleans up expired EDUs DeleteExpiredEDUs(ctx context.Context) error + + PurgeRoom(ctx context.Context, roomID string) error } diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 1e1ea9e17..6cda55725 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -259,3 +259,18 @@ func (d *Database) GetNotaryKeys( }) return sks, err } + +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge joined hosts: %w", err) + } + if err := d.FederationInboundPeeks.DeleteInboundPeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge inbound peeks: %w", err) + } + if err := d.FederationOutboundPeeks.DeleteOutboundPeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge outbound peeks: %w", err) + } + return nil + }) +} diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index 19483b268..81c055edd 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -124,6 +124,11 @@ type QueryProvider interface { QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) } +// ExecProvider defines the interface for querys used by RunLimitedVariablesExec. +type ExecProvider interface { + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + // SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement // SQLlite can handle. See https://www.sqlite.org/limits.html for more information. const SQLite3MaxVariables = 999 @@ -153,6 +158,22 @@ func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvide return nil } +// RunLimitedVariablesExec split up a query with more variables than the used database can handle in multiple queries. +func RunLimitedVariablesExec(ctx context.Context, query string, qp ExecProvider, variables []interface{}, limit uint) error { + var start int + for start < len(variables) { + n := minOfInts(len(variables)-start, int(limit)) + nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1) + _, err := qp.ExecContext(ctx, nextQuery, variables[start:start+n]...) + if err != nil { + util.GetLogger(ctx).WithError(err).Error("ExecContext returned an error") + return err + } + start = start + n + } + return nil +} + // StatementList is a list of SQL statements to prepare and a pointer to where to store the resulting prepared statement. type StatementList []struct { Statement **sql.Stmt diff --git a/internal/sqlutil/sqlutil_test.go b/internal/sqlutil/sqlutil_test.go index 79469cddc..c40757893 100644 --- a/internal/sqlutil/sqlutil_test.go +++ b/internal/sqlutil/sqlutil_test.go @@ -3,10 +3,11 @@ package sqlutil import ( "context" "database/sql" + "errors" "reflect" "testing" - sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/DATA-DOG/go-sqlmock" ) func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) { @@ -164,6 +165,54 @@ func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) { } } +func TestRunLimitedVariablesExec(t *testing.T) { + db, mock, err := sqlmock.New() + assertNoError(t, err, "Failed to make DB") + + // Query and expect two queries to be executed + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + variables := []interface{}{ + 1, 2, 3, 4, + } + + query := "DELETE FROM WHERE id IN ($1)" + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables, 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 3 parameters, still queries two times + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:3], 2); err != nil { + t.Fatal(err) + } + + // Query again, but only 2 parameters, queries only once + mock.ExpectExec(`DELETE FROM WHERE id IN \(\$1\, \$2\)`). + WillReturnResult(sqlmock.NewResult(0, 0)) + + if err = RunLimitedVariablesExec(context.Background(), query, db, variables[:2], 2); err != nil { + t.Fatal(err) + } + + // Test with invalid query (typo) should return an error + mock.ExpectExec(`DELTE FROM`). + WillReturnResult(sqlmock.NewResult(0, 0)). + WillReturnError(errors.New("typo in query")) + + if err = RunLimitedVariablesExec(context.Background(), "DELTE FROM", db, variables[:2], 2); err == nil { + t.Fatal("expected an error, but got none") + } +} + func assertNoError(t *testing.T, err error, msg string) { t.Helper() if err == nil { diff --git a/roomserver/api/api.go b/roomserver/api/api.go index 420ef278a..a8228ae81 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -151,6 +151,7 @@ type ClientRoomserverAPI interface { PerformRoomUpgrade(ctx context.Context, req *PerformRoomUpgradeRequest, resp *PerformRoomUpgradeResponse) error PerformAdminEvacuateRoom(ctx context.Context, req *PerformAdminEvacuateRoomRequest, res *PerformAdminEvacuateRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, req *PerformAdminEvacuateUserRequest, res *PerformAdminEvacuateUserResponse) error + PerformAdminPurgeRoom(ctx context.Context, req *PerformAdminPurgeRoomRequest, res *PerformAdminPurgeRoomResponse) error PerformAdminDownloadState(ctx context.Context, req *PerformAdminDownloadStateRequest, res *PerformAdminDownloadStateResponse) error PerformPeek(ctx context.Context, req *PerformPeekRequest, res *PerformPeekResponse) error PerformUnpeek(ctx context.Context, req *PerformUnpeekRequest, res *PerformUnpeekResponse) error diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index b23263d17..166b651a2 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -137,6 +137,16 @@ func (t *RoomserverInternalAPITrace) PerformAdminEvacuateUser( return err } +func (t *RoomserverInternalAPITrace) PerformAdminPurgeRoom( + ctx context.Context, + req *PerformAdminPurgeRoomRequest, + res *PerformAdminPurgeRoomResponse, +) error { + err := t.Impl.PerformAdminPurgeRoom(ctx, req, res) + util.GetLogger(ctx).WithError(err).Infof("PerformAdminPurgeRoom req=%+v res=%+v", js(req), js(res)) + return err +} + func (t *RoomserverInternalAPITrace) PerformAdminDownloadState( ctx context.Context, req *PerformAdminDownloadStateRequest, diff --git a/roomserver/api/output.go b/roomserver/api/output.go index 36d0625c7..0c0f52c45 100644 --- a/roomserver/api/output.go +++ b/roomserver/api/output.go @@ -55,6 +55,8 @@ const ( OutputTypeNewInboundPeek OutputType = "new_inbound_peek" // OutputTypeRetirePeek indicates that the kafka event is an OutputRetirePeek OutputTypeRetirePeek OutputType = "retire_peek" + // OutputTypePurgeRoom indicates the event is an OutputPurgeRoom + OutputTypePurgeRoom OutputType = "purge_room" ) // An OutputEvent is an entry in the roomserver output kafka log. @@ -78,6 +80,8 @@ type OutputEvent struct { NewInboundPeek *OutputNewInboundPeek `json:"new_inbound_peek,omitempty"` // The content of event with type OutputTypeRetirePeek RetirePeek *OutputRetirePeek `json:"retire_peek,omitempty"` + // The content of the event with type OutputPurgeRoom + PurgeRoom *OutputPurgeRoom `json:"purge_room,omitempty"` } // Type of the OutputNewRoomEvent. @@ -257,3 +261,7 @@ type OutputRetirePeek struct { UserID string DeviceID string } + +type OutputPurgeRoom struct { + RoomID string +} diff --git a/roomserver/api/perform.go b/roomserver/api/perform.go index e789b9568..83cb0460a 100644 --- a/roomserver/api/perform.go +++ b/roomserver/api/perform.go @@ -241,6 +241,14 @@ type PerformAdminEvacuateUserResponse struct { Error *PerformError } +type PerformAdminPurgeRoomRequest struct { + RoomID string `json:"room_id"` +} + +type PerformAdminPurgeRoomResponse struct { + Error *PerformError `json:"error,omitempty"` +} + type PerformAdminDownloadStateRequest struct { RoomID string `json:"room_id"` UserID string `json:"user_id"` diff --git a/roomserver/internal/perform/perform_admin.go b/roomserver/internal/perform/perform_admin.go index d42f4e45d..3256162b4 100644 --- a/roomserver/internal/perform/perform_admin.go +++ b/roomserver/internal/perform/perform_admin.go @@ -28,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" ) type Admin struct { @@ -242,6 +243,42 @@ func (r *Admin) PerformAdminEvacuateUser( return nil } +func (r *Admin) PerformAdminPurgeRoom( + ctx context.Context, + req *api.PerformAdminPurgeRoomRequest, + res *api.PerformAdminPurgeRoomResponse, +) error { + // Validate we actually got a room ID and nothing else + if _, _, err := gomatrixserverlib.SplitID('!', req.RoomID); err != nil { + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: fmt.Sprintf("Malformed room ID: %s", err), + } + return nil + } + + logrus.WithField("room_id", req.RoomID).Warn("Purging room from roomserver") + if err := r.DB.PurgeRoom(ctx, req.RoomID); err != nil { + logrus.WithField("room_id", req.RoomID).WithError(err).Warn("Failed to purge room from roomserver") + res.Error = &api.PerformError{ + Code: api.PerformErrorBadRequest, + Msg: err.Error(), + } + return nil + } + + logrus.WithField("room_id", req.RoomID).Warn("Room purged from roomserver") + + return r.Inputer.OutputProducer.ProduceRoomEvents(req.RoomID, []api.OutputEvent{ + { + Type: api.OutputTypePurgeRoom, + PurgeRoom: &api.OutputPurgeRoom{ + RoomID: req.RoomID, + }, + }, + }) +} + func (r *Admin) PerformAdminDownloadState( ctx context.Context, req *api.PerformAdminDownloadStateRequest, diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 8a2e0a03c..556a137ba 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -40,6 +40,7 @@ const ( RoomserverPerformAdminEvacuateRoomPath = "/roomserver/performAdminEvacuateRoom" RoomserverPerformAdminEvacuateUserPath = "/roomserver/performAdminEvacuateUser" RoomserverPerformAdminDownloadStatePath = "/roomserver/performAdminDownloadState" + RoomserverPerformAdminPurgeRoomPath = "/roomserver/performAdminPurgeRoom" // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" @@ -285,6 +286,17 @@ func (h *httpRoomserverInternalAPI) PerformAdminEvacuateUser( ) } +func (h *httpRoomserverInternalAPI) PerformAdminPurgeRoom( + ctx context.Context, + request *api.PerformAdminPurgeRoomRequest, + response *api.PerformAdminPurgeRoomResponse, +) error { + return httputil.CallInternalRPCAPI( + "PerformAdminPurgeRoom", h.roomserverURL+RoomserverPerformAdminPurgeRoomPath, + h.httpClient, ctx, request, response, + ) +} + // QueryLatestEventsAndState implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryLatestEventsAndState( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index 4d21909b7..f3a51b0b1 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -65,6 +65,11 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe httputil.MakeInternalRPCAPI("RoomserverPerformAdminEvacuateUser", enableMetrics, r.PerformAdminEvacuateUser), ) + internalAPIMux.Handle( + RoomserverPerformAdminPurgeRoomPath, + httputil.MakeInternalRPCAPI("RoomserverPerformAdminPurgeRoom", enableMetrics, r.PerformAdminPurgeRoom), + ) + internalAPIMux.Handle( RoomserverPerformAdminDownloadStatePath, httputil.MakeInternalRPCAPI("RoomserverPerformAdminDownloadState", enableMetrics, r.PerformAdminDownloadState), diff --git a/roomserver/roomserver_test.go b/roomserver/roomserver_test.go index 595ceb526..3ec2560d6 100644 --- a/roomserver/roomserver_test.go +++ b/roomserver/roomserver_test.go @@ -14,6 +14,10 @@ import ( userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/dendrite/federationapi" + "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/syncapi" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/dendrite/roomserver" @@ -223,3 +227,164 @@ func Test_QueryLeftUsers(t *testing.T) { }) } + +func TestPurgeRoom(t *testing.T) { + alice := test.NewUser(t) + bob := test.NewUser(t) + room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat)) + + // Invite Bob + inviteEvent := room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{ + "membership": "invite", + }, test.WithStateKey(bob.ID)) + + ctx := context.Background() + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, db, close := mustCreateDatabase(t, dbType) + defer close() + + jsCtx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + defer jetstream.DeleteAllStreams(jsCtx, &base.Cfg.Global.JetStream) + + fedClient := base.CreateFederationClient() + rsAPI := roomserver.NewInternalAPI(base) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fedClient, rsAPI) + userAPI := userapi.NewInternalAPI(base, &base.Cfg.UserAPI, nil, keyAPI, rsAPI, nil) + + // this starts the JetStream consumers + syncapi.AddPublicRoutes(base, userAPI, rsAPI, keyAPI) + federationapi.NewInternalAPI(base, fedClient, rsAPI, base.Caches, nil, true) + rsAPI.SetFederationAPI(nil, nil) + + // Create the room + if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil { + t.Fatalf("failed to send events: %v", err) + } + + // some dummy entries to validate after purging + publishResp := &api.PerformPublishResponse{} + if err := rsAPI.PerformPublish(ctx, &api.PerformPublishRequest{RoomID: room.ID, Visibility: "public"}, publishResp); err != nil { + t.Fatal(err) + } + if publishResp.Error != nil { + t.Fatal(publishResp.Error) + } + + isPublished, err := db.GetPublishedRoom(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if !isPublished { + t.Fatalf("room should be published before purging") + } + + aliasResp := &api.SetRoomAliasResponse{} + if err = rsAPI.SetRoomAlias(ctx, &api.SetRoomAliasRequest{RoomID: room.ID, Alias: "myalias", UserID: alice.ID}, aliasResp); err != nil { + t.Fatal(err) + } + // check the alias is actually there + aliasesResp := &api.GetAliasesForRoomIDResponse{} + if err = rsAPI.GetAliasesForRoomID(ctx, &api.GetAliasesForRoomIDRequest{RoomID: room.ID}, aliasesResp); err != nil { + t.Fatal(err) + } + wantAliases := 1 + if gotAliases := len(aliasesResp.Aliases); gotAliases != wantAliases { + t.Fatalf("expected %d aliases, got %d", wantAliases, gotAliases) + } + + // validate the room exists before purging + roomInfo, err := db.RoomInfo(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if roomInfo == nil { + t.Fatalf("room does not exist") + } + // remember the roomInfo before purging + existingRoomInfo := roomInfo + + // validate there is an invite for bob + nids, err := db.EventStateKeyNIDs(ctx, []string{bob.ID}) + if err != nil { + t.Fatal(err) + } + bobNID, ok := nids[bob.ID] + if !ok { + t.Fatalf("%s does not exist", bob.ID) + } + + _, inviteEventIDs, _, err := db.GetInvitesForUser(ctx, roomInfo.RoomNID, bobNID) + if err != nil { + t.Fatal(err) + } + wantInviteCount := 1 + if inviteCount := len(inviteEventIDs); inviteCount != wantInviteCount { + t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount) + } + if inviteEventIDs[0] != inviteEvent.EventID() { + t.Fatalf("expected invite event ID %s, got %s", inviteEvent.EventID(), inviteEventIDs[0]) + } + + // purge the room from the database + purgeResp := &api.PerformAdminPurgeRoomResponse{} + if err = rsAPI.PerformAdminPurgeRoom(ctx, &api.PerformAdminPurgeRoomRequest{RoomID: room.ID}, purgeResp); err != nil { + t.Fatal(err) + } + + // wait for all consumers to process the purge event + var sum = 1 + timeout := time.Second * 5 + deadline, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + for sum > 0 { + if deadline.Err() != nil { + t.Fatalf("test timed out after %s", timeout) + } + sum = 0 + consumerCh := jsCtx.Consumers(base.Cfg.Global.JetStream.Prefixed(jetstream.OutputRoomEvent)) + for x := range consumerCh { + sum += x.NumAckPending + } + time.Sleep(time.Millisecond) + } + + roomInfo, err = db.RoomInfo(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if roomInfo != nil { + t.Fatalf("room should not exist after purging: %+v", roomInfo) + } + + // validation below + + // There should be no invite left + _, inviteEventIDs, _, err = db.GetInvitesForUser(ctx, existingRoomInfo.RoomNID, bobNID) + if err != nil { + t.Fatal(err) + } + + if inviteCount := len(inviteEventIDs); inviteCount > 0 { + t.Fatalf("expected there to be only %d invite events, got %d", wantInviteCount, inviteCount) + } + + // aliases should be deleted + aliases, err := db.GetAliasesForRoomID(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if aliasCount := len(aliases); aliasCount > 0 { + t.Fatalf("expected there to be only %d invite events, got %d", 0, aliasCount) + } + + // published room should be deleted + isPublished, err = db.GetPublishedRoom(ctx, room.ID) + if err != nil { + t.Fatal(err) + } + if isPublished { + t.Fatalf("room should not be published after purging") + } + }) +} diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index 92bc2e66f..e0b9c56b3 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -173,5 +173,6 @@ type Database interface { GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) + PurgeRoom(ctx context.Context, roomID string) error UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error } diff --git a/roomserver/storage/postgres/purge_statements.go b/roomserver/storage/postgres/purge_statements.go new file mode 100644 index 000000000..efba439bd --- /dev/null +++ b/roomserver/storage/postgres/purge_statements.go @@ -0,0 +1,133 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const purgeEventJSONSQL = "" + + "DELETE FROM roomserver_event_json WHERE event_nid = ANY(" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeEventsSQL = "" + + "DELETE FROM roomserver_events WHERE room_nid = $1" + +const purgeInvitesSQL = "" + + "DELETE FROM roomserver_invites WHERE room_nid = $1" + +const purgeMembershipsSQL = "" + + "DELETE FROM roomserver_membership WHERE room_nid = $1" + +const purgePreviousEventsSQL = "" + + "DELETE FROM roomserver_previous_events WHERE event_nids && ANY(" + + " SELECT ARRAY_AGG(event_nid) FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgePublishedSQL = "" + + "DELETE FROM roomserver_published WHERE room_id = $1" + +const purgeRedactionsSQL = "" + + "DELETE FROM roomserver_redactions WHERE redaction_event_id = ANY(" + + " SELECT event_id FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeRoomAliasesSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE room_id = $1" + +const purgeRoomSQL = "" + + "DELETE FROM roomserver_rooms WHERE room_nid = $1" + +const purgeStateBlockEntriesSQL = "" + + "DELETE FROM roomserver_state_block WHERE state_block_nid = ANY(" + + " SELECT DISTINCT UNNEST(state_block_nids) FROM roomserver_state_snapshots WHERE room_nid = $1" + + ")" + +const purgeStateSnapshotEntriesSQL = "" + + "DELETE FROM roomserver_state_snapshots WHERE room_nid = $1" + +type purgeStatements struct { + purgeEventJSONStmt *sql.Stmt + purgeEventsStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt + purgePreviousEventsStmt *sql.Stmt + purgePublishedStmt *sql.Stmt + purgeRedactionStmt *sql.Stmt + purgeRoomAliasesStmt *sql.Stmt + purgeRoomStmt *sql.Stmt + purgeStateBlockEntriesStmt *sql.Stmt + purgeStateSnapshotEntriesStmt *sql.Stmt +} + +func PreparePurgeStatements(db *sql.DB) (*purgeStatements, error) { + s := &purgeStatements{} + + return s, sqlutil.StatementList{ + {&s.purgeEventJSONStmt, purgeEventJSONSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, + {&s.purgePublishedStmt, purgePublishedSQL}, + {&s.purgePreviousEventsStmt, purgePreviousEventsSQL}, + {&s.purgeRedactionStmt, purgeRedactionsSQL}, + {&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL}, + {&s.purgeRoomStmt, purgeRoomSQL}, + {&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL}, + {&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL}, + }.Prepare(db) +} + +func (s *purgeStatements) PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, +) error { + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateBlockEntriesStmt, + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil +} diff --git a/roomserver/storage/postgres/rooms_table.go b/roomserver/storage/postgres/rooms_table.go index 994399532..c8346733d 100644 --- a/roomserver/storage/postgres/rooms_table.go +++ b/roomserver/storage/postgres/rooms_table.go @@ -58,6 +58,9 @@ const insertRoomNIDSQL = "" + const selectRoomNIDSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1 FOR UPDATE" + const selectLatestEventNIDsSQL = "" + "SELECT latest_event_nids, state_snapshot_nid FROM roomserver_rooms WHERE room_nid = $1" @@ -85,6 +88,7 @@ const bulkSelectRoomNIDsSQL = "" + type roomStatements struct { insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -106,6 +110,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { return s, sqlutil.StatementList{ {&s.insertRoomNIDStmt, insertRoomNIDSQL}, {&s.selectRoomNIDStmt, selectRoomNIDSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, {&s.selectLatestEventNIDsStmt, selectLatestEventNIDsSQL}, {&s.selectLatestEventNIDsForUpdateStmt, selectLatestEventNIDsForUpdateSQL}, {&s.updateLatestEventNIDsStmt, updateLatestEventNIDsSQL}, @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index 23a5f79eb..872084383 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -189,6 +189,10 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + purge, err := PreparePurgeStatements(db) + if err != nil { + return err + } d.Database = shared.Database{ DB: db, Cache: cache, @@ -206,6 +210,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room MembershipTable: membership, PublishedTable: published, RedactionsTable: redactions, + Purge: purge, } return nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 725cc5bc7..654b078d2 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -43,6 +43,7 @@ type Database struct { MembershipTable tables.Membership PublishedTable tables.Published RedactionsTable tables.Redactions + Purge tables.Purge GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error) } @@ -1445,6 +1446,21 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget }) } +// PurgeRoom removes all information about a given room from the roomserver. +// For large rooms this operation may take a considerable amount of time. +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + roomNID, err := d.RoomsTable.SelectRoomNIDForUpdate(ctx, txn, roomID) + if err != nil { + if err == sql.ErrNoRows { + return fmt.Errorf("room %s does not exist", roomID) + } + return fmt.Errorf("failed to lock the room: %w", err) + } + return d.Purge.PurgeRoom(ctx, txn, roomNID, roomID) + }) +} + func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/roomserver/storage/sqlite3/purge_statements.go b/roomserver/storage/sqlite3/purge_statements.go new file mode 100644 index 000000000..c7b4d27a5 --- /dev/null +++ b/roomserver/storage/sqlite3/purge_statements.go @@ -0,0 +1,153 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const purgeEventJSONSQL = "" + + "DELETE FROM roomserver_event_json WHERE event_nid IN (" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeEventsSQL = "" + + "DELETE FROM roomserver_events WHERE room_nid = $1" + +const purgeInvitesSQL = "" + + "DELETE FROM roomserver_invites WHERE room_nid = $1" + +const purgeMembershipsSQL = "" + + "DELETE FROM roomserver_membership WHERE room_nid = $1" + +const purgePreviousEventsSQL = "" + + "DELETE FROM roomserver_previous_events WHERE event_nids IN(" + + " SELECT event_nid FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgePublishedSQL = "" + + "DELETE FROM roomserver_published WHERE room_id = $1" + +const purgeRedactionsSQL = "" + + "DELETE FROM roomserver_redactions WHERE redaction_event_id IN(" + + " SELECT event_id FROM roomserver_events WHERE room_nid = $1" + + ")" + +const purgeRoomAliasesSQL = "" + + "DELETE FROM roomserver_room_aliases WHERE room_id = $1" + +const purgeRoomSQL = "" + + "DELETE FROM roomserver_rooms WHERE room_nid = $1" + +const purgeStateSnapshotEntriesSQL = "" + + "DELETE FROM roomserver_state_snapshots WHERE room_nid = $1" + +type purgeStatements struct { + purgeEventJSONStmt *sql.Stmt + purgeEventsStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt + purgePreviousEventsStmt *sql.Stmt + purgePublishedStmt *sql.Stmt + purgeRedactionStmt *sql.Stmt + purgeRoomAliasesStmt *sql.Stmt + purgeRoomStmt *sql.Stmt + purgeStateSnapshotEntriesStmt *sql.Stmt + stateSnapshot *stateSnapshotStatements +} + +func PreparePurgeStatements(db *sql.DB, stateSnapshot *stateSnapshotStatements) (*purgeStatements, error) { + s := &purgeStatements{stateSnapshot: stateSnapshot} + return s, sqlutil.StatementList{ + {&s.purgeEventJSONStmt, purgeEventJSONSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, + {&s.purgePublishedStmt, purgePublishedSQL}, + {&s.purgePreviousEventsStmt, purgePreviousEventsSQL}, + {&s.purgeRedactionStmt, purgeRedactionsSQL}, + {&s.purgeRoomAliasesStmt, purgeRoomAliasesSQL}, + {&s.purgeRoomStmt, purgeRoomSQL}, + //{&s.purgeStateBlockEntriesStmt, purgeStateBlockEntriesSQL}, + {&s.purgeStateSnapshotEntriesStmt, purgeStateSnapshotEntriesSQL}, + }.Prepare(db) +} + +func (s *purgeStatements) PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, +) error { + + // purge by roomID + purgeByRoomID := []*sql.Stmt{ + s.purgeRoomAliasesStmt, + s.purgePublishedStmt, + } + for _, stmt := range purgeByRoomID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomID) + if err != nil { + return err + } + } + + // purge by roomNID + if err := s.purgeStateBlocks(ctx, txn, roomNID); err != nil { + return err + } + + purgeByRoomNID := []*sql.Stmt{ + s.purgeStateSnapshotEntriesStmt, + s.purgeInvitesStmt, + s.purgeMembershipsStmt, + s.purgePreviousEventsStmt, + s.purgeEventJSONStmt, + s.purgeRedactionStmt, + s.purgeEventsStmt, + s.purgeRoomStmt, + } + for _, stmt := range purgeByRoomNID { + _, err := sqlutil.TxStmt(txn, stmt).ExecContext(ctx, roomNID) + if err != nil { + return err + } + } + return nil +} + +func (s *purgeStatements) purgeStateBlocks( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) error { + // Get all stateBlockNIDs + stateBlockNIDs, err := s.stateSnapshot.selectStateBlockNIDsForRoomNID(ctx, txn, roomNID) + if err != nil { + return err + } + params := make([]interface{}, len(stateBlockNIDs)) + seenNIDs := make(map[types.StateBlockNID]struct{}, len(stateBlockNIDs)) + // dedupe NIDs + for k, v := range stateBlockNIDs { + if _, ok := seenNIDs[v]; ok { + continue + } + params[k] = v + seenNIDs[v] = struct{}{} + } + + query := "DELETE FROM roomserver_state_block WHERE state_block_nid IN($1)" + return sqlutil.RunLimitedVariablesExec(ctx, query, txn, params, sqlutil.SQLite3MaxVariables) +} diff --git a/roomserver/storage/sqlite3/rooms_table.go b/roomserver/storage/sqlite3/rooms_table.go index 25b611b3e..7556b3461 100644 --- a/roomserver/storage/sqlite3/rooms_table.go +++ b/roomserver/storage/sqlite3/rooms_table.go @@ -74,10 +74,14 @@ const bulkSelectRoomIDsSQL = "" + const bulkSelectRoomNIDsSQL = "" + "SELECT room_nid FROM roomserver_rooms WHERE room_id IN ($1)" +const selectRoomNIDForUpdateSQL = "" + + "SELECT room_nid FROM roomserver_rooms WHERE room_id = $1" + type roomStatements struct { db *sql.DB insertRoomNIDStmt *sql.Stmt selectRoomNIDStmt *sql.Stmt + selectRoomNIDForUpdateStmt *sql.Stmt selectLatestEventNIDsStmt *sql.Stmt selectLatestEventNIDsForUpdateStmt *sql.Stmt updateLatestEventNIDsStmt *sql.Stmt @@ -105,6 +109,7 @@ func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) { //{&s.selectRoomVersionForRoomNIDsStmt, selectRoomVersionForRoomNIDsSQL}, {&s.selectRoomInfoStmt, selectRoomInfoSQL}, {&s.selectRoomIDsStmt, selectRoomIDsSQL}, + {&s.selectRoomNIDForUpdateStmt, selectRoomNIDForUpdateSQL}, }.Prepare(db) } @@ -169,6 +174,15 @@ func (s *roomStatements) SelectRoomNID( return types.RoomNID(roomNID), err } +func (s *roomStatements) SelectRoomNIDForUpdate( + ctx context.Context, txn *sql.Tx, roomID string, +) (types.RoomNID, error) { + var roomNID int64 + stmt := sqlutil.TxStmt(txn, s.selectRoomNIDForUpdateStmt) + err := stmt.QueryRowContext(ctx, roomID).Scan(&roomNID) + return types.RoomNID(roomNID), err +} + func (s *roomStatements) SelectLatestEventNIDs( ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, ) ([]types.EventNID, types.StateSnapshotNID, error) { diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 4e67d4da1..ae8181cfa 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -24,7 +24,6 @@ import ( "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" ) @@ -68,7 +67,7 @@ func CreateStateBlockTable(db *sql.DB) error { return err } -func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) { +func PrepareStateBlockTable(db *sql.DB) (*stateBlockStatements, error) { s := &stateBlockStatements{ db: db, } diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 73827522c..930ad14dd 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -62,10 +62,14 @@ const bulkSelectStateBlockNIDsSQL = "" + "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" +const selectStateBlockNIDsForRoomNID = "" + + "SELECT state_block_nids FROM roomserver_state_snapshots WHERE room_nid = $1" + type stateSnapshotStatements struct { db *sql.DB insertStateStmt *sql.Stmt bulkSelectStateBlockNIDsStmt *sql.Stmt + selectStateBlockNIDsStmt *sql.Stmt } func CreateStateSnapshotTable(db *sql.DB) error { @@ -73,7 +77,7 @@ func CreateStateSnapshotTable(db *sql.DB) error { return err } -func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { +func PrepareStateSnapshotTable(db *sql.DB) (*stateSnapshotStatements, error) { s := &stateSnapshotStatements{ db: db, } @@ -81,6 +85,7 @@ func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { return s, sqlutil.StatementList{ {&s.insertStateStmt, insertStateSQL}, {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, + {&s.selectStateBlockNIDsStmt, selectStateBlockNIDsForRoomNID}, }.Prepare(db) } @@ -146,3 +151,29 @@ func (s *stateSnapshotStatements) BulkSelectStateForHistoryVisibility( ) ([]types.EventNID, error) { return nil, tables.OptimisationNotSupportedError } + +func (s *stateSnapshotStatements) selectStateBlockNIDsForRoomNID( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, +) ([]types.StateBlockNID, error) { + var res []types.StateBlockNID + rows, err := sqlutil.TxStmt(txn, s.selectStateBlockNIDsStmt).QueryContext(ctx, roomNID) + if err != nil { + return res, nil + } + defer internal.CloseAndLogIfError(ctx, rows, "selectStateBlockNIDsForRoomNID: rows.close() failed") + + var stateBlockNIDs []types.StateBlockNID + var stateBlockNIDsJSON string + for rows.Next() { + if err = rows.Scan(&stateBlockNIDsJSON); err != nil { + return nil, err + } + if err = json.Unmarshal([]byte(stateBlockNIDsJSON), &stateBlockNIDs); err != nil { + return nil, err + } + + res = append(res, stateBlockNIDs...) + } + + return res, rows.Err() +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 01c3f879c..392edd289 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -197,6 +197,11 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room if err != nil { return err } + purge, err := PreparePurgeStatements(db, stateSnapshot) + if err != nil { + return err + } + d.Database = shared.Database{ DB: db, Cache: cache, @@ -215,6 +220,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room PublishedTable: published, RedactionsTable: redactions, GetRoomUpdaterFn: d.GetRoomUpdater, + Purge: purge, } return nil } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 80fcf72dd..64145f83d 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -73,6 +73,7 @@ type Events interface { type Rooms interface { InsertRoomNID(ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion) (types.RoomNID, error) SelectRoomNID(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) + SelectRoomNIDForUpdate(ctx context.Context, txn *sql.Tx, roomID string) (types.RoomNID, error) SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error) SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error) UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error @@ -173,6 +174,12 @@ type Redactions interface { MarkRedactionValidated(ctx context.Context, txn *sql.Tx, redactionEventID string, validated bool) error } +type Purge interface { + PurgeRoom( + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, roomID string, + ) error +} + // StrippedEvent represents a stripped event for returning extracted content values. type StrippedEvent struct { RoomID string diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index 1b67f5684..21838039a 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -23,6 +23,7 @@ import ( "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/nats-io/nats.go" + "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -127,6 +128,12 @@ func (s *OutputRoomEventConsumer) onMessage(ctx context.Context, msgs []*nats.Ms s.onRetirePeek(s.ctx, *output.RetirePeek) case api.OutputTypeRedactedEvent: err = s.onRedactEvent(s.ctx, *output.RedactedEvent) + case api.OutputTypePurgeRoom: + err = s.onPurgeRoom(s.ctx, *output.PurgeRoom) + if err != nil { + logrus.WithField("room_id", output.PurgeRoom.RoomID).WithError(err).Error("Failed to purge room from sync API") + return true // non-fatal, as otherwise we end up in a loop of trying to purge the room + } default: log.WithField("type", output.Type).Debug( "roomserver output log: ignoring unknown output type", @@ -473,6 +480,20 @@ func (s *OutputRoomEventConsumer) onRetirePeek( s.notifier.OnRetirePeek(msg.RoomID, msg.UserID, msg.DeviceID, types.StreamingToken{PDUPosition: sp}) } +func (s *OutputRoomEventConsumer) onPurgeRoom( + ctx context.Context, req api.OutputPurgeRoom, +) error { + logrus.WithField("room_id", req.RoomID).Warn("Purging room from sync API") + + if err := s.db.PurgeRoom(ctx, req.RoomID); err != nil { + logrus.WithField("room_id", req.RoomID).WithError(err).Error("Failed to purge room from sync API") + return err + } else { + logrus.WithField("room_id", req.RoomID).Warn("Room purged from sync API") + return nil + } +} + func (s *OutputRoomEventConsumer) updateStateEvent(event *gomatrixserverlib.HeaderedEvent) (*gomatrixserverlib.HeaderedEvent, error) { if event.StateKey() == nil { return event, nil diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index a4ba82327..a7a127e3a 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -132,6 +132,8 @@ type Database interface { // PurgeRoomState completely purges room state from the sync API. This is done when // receiving an output event that completely resets the state. PurgeRoomState(ctx context.Context, roomID string) error + // PurgeRoom entirely eliminates a room from the sync API, timeline, state and all. + PurgeRoom(ctx context.Context, roomID string) error // UpsertAccountData keeps track of new or updated account data, by saving the type // of the new/updated data, and the user ID and room ID the data is related to (empty) // room ID means the data isn't specific to any room) diff --git a/syncapi/storage/postgres/backwards_extremities_table.go b/syncapi/storage/postgres/backwards_extremities_table.go index 8fc92091f..c20d860a7 100644 --- a/syncapi/storage/postgres/backwards_extremities_table.go +++ b/syncapi/storage/postgres/backwards_extremities_table.go @@ -47,10 +47,14 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const purgeBackwardExtremitiesSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + purgeBackwardExtremitiesStmt *sql.Stmt } func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -59,16 +63,12 @@ func NewPostgresBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremiti if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err - } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL}, + {&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL}, + {&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL}, + {&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL}, + }.Prepare(db) } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( @@ -106,3 +106,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return } + +func (s *backwardExtremitiesStatements) PurgeBackwardExtremities( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/invites_table.go b/syncapi/storage/postgres/invites_table.go index aada70d5e..151bffa5d 100644 --- a/syncapi/storage/postgres/invites_table.go +++ b/syncapi/storage/postgres/invites_table.go @@ -62,11 +62,15 @@ const selectInviteEventsInRangeSQL = "" + const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" +const purgeInvitesSQL = "" + + "DELETE FROM syncapi_invite_events WHERE room_id = $1" + type inviteEventsStatements struct { insertInviteEventStmt *sql.Stmt selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt } func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { @@ -75,19 +79,13 @@ func NewPostgresInvitesTable(db *sql.DB) (tables.Invites, error) { if err != nil { return nil, err } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL}, + {&s.deleteInviteEventStmt, deleteInviteEventSQL}, + {&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + }.Prepare(db) } func (s *inviteEventsStatements) InsertInviteEvent( @@ -181,3 +179,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID( } return } + +func (s *inviteEventsStatements) PurgeInvites( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index ac44b235f..47833893a 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -65,18 +65,22 @@ const selectMembershipCountSQL = "" + const selectMembershipBeforeSQL = "" + "SELECT membership, topological_pos FROM syncapi_memberships WHERE room_id = $1 and user_id = $2 AND topological_pos <= $3 ORDER BY topological_pos DESC LIMIT 1" +const purgeMembershipsSQL = "" + + "DELETE FROM syncapi_memberships WHERE room_id = $1" + const selectMembersSQL = ` -SELECT event_id FROM ( - SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC -) t -WHERE ($3::text IS NULL OR t.membership = $3) - AND ($4::text IS NULL OR t.membership <> $4) + SELECT event_id FROM ( + SELECT DISTINCT ON (room_id, user_id) room_id, user_id, event_id, membership FROM syncapi_memberships WHERE room_id = $1 AND topological_pos <= $2 ORDER BY room_id, user_id, stream_pos DESC + ) t + WHERE ($3::text IS NULL OR t.membership = $3) + AND ($4::text IS NULL OR t.membership <> $4) ` type membershipsStatements struct { upsertMembershipStmt *sql.Stmt selectMembershipCountStmt *sql.Stmt selectMembershipForUserStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt selectMembersStmt *sql.Stmt } @@ -90,6 +94,7 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.upsertMembershipStmt, upsertMembershipSQL}, {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, {&s.selectMembersStmt, selectMembersSQL}, }.Prepare(db) } @@ -139,6 +144,13 @@ func (s *membershipsStatements) SelectMembershipForUser( return membership, topologyPos, nil } +func (s *membershipsStatements) PurgeMemberships( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID) + return err +} + func (s *membershipsStatements) SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, diff --git a/syncapi/storage/postgres/notification_data_table.go b/syncapi/storage/postgres/notification_data_table.go index 2c7b24800..7edfd54a6 100644 --- a/syncapi/storage/postgres/notification_data_table.go +++ b/syncapi/storage/postgres/notification_data_table.go @@ -37,6 +37,7 @@ func NewPostgresNotificationDataTable(db *sql.DB) (tables.NotificationData, erro {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + {&r.purgeNotificationData, purgeNotificationDataSQL}, }.Prepare(db) } @@ -44,6 +45,7 @@ type notificationDataStatements struct { upsertRoomUnreadCounts *sql.Stmt selectUserUnreadCountsForRooms *sql.Stmt selectMaxID *sql.Stmt + purgeNotificationData *sql.Stmt } const notificationDataSchema = ` @@ -70,6 +72,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` +const purgeNotificationDataSQL = "" + + "DELETE FROM syncapi_notification_data WHERE room_id = $1" + func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { err = sqlutil.TxStmt(txn, r.upsertRoomUnreadCounts).QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) return @@ -106,3 +111,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } + +func (s *notificationDataStatements) PurgeNotificationData( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index 3b69b26f6..0075fc8d3 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -176,6 +176,9 @@ const selectContextAfterEventSQL = "" + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + " ORDER BY id ASC LIMIT $3" +const purgeEventsSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE id > $1 AND type = ANY($2) ORDER BY id ASC LIMIT $3" type outputRoomEventsStatements struct { @@ -193,6 +196,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + purgeEventsStmt *sql.Stmt selectSearchStmt *sql.Stmt } @@ -230,6 +234,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, {&s.selectSearchStmt, selectSearchSQL}, }.Prepare(db) } @@ -658,6 +663,13 @@ func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { return result, rows.Err() } +func (s *outputRoomEventsStatements) PurgeEvents( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID) + return err +} + func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { rows, err := sqlutil.TxStmt(txn, s.selectSearchStmt).QueryContext(ctx, afterID, pq.StringArray(types), limit) if err != nil { diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index d0e99f267..2382fca5c 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -18,11 +18,12 @@ import ( "context" "database/sql" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsTopologySchema = ` @@ -71,6 +72,9 @@ const selectStreamToTopologicalPositionAscSQL = "" + const selectStreamToTopologicalPositionDescSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;" +const purgeEventsTopologySQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { insertEventInTopologyStmt *sql.Stmt selectEventIDsInRangeASCStmt *sql.Stmt @@ -78,6 +82,7 @@ type outputRoomEventsTopologyStatements struct { selectPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt + purgeEventsTopologyStmt *sql.Stmt } func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -86,25 +91,15 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) { if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertEventInTopologyStmt, insertEventInTopologySQL}, + {&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL}, + {&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL}, + {&s.selectPositionInTopologyStmt, selectPositionInTopologySQL}, + {&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL}, + {&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL}, + {&s.purgeEventsTopologyStmt, purgeEventsTopologySQL}, + }.Prepare(db) } // InsertEventInTopology inserts the given event in the room's topology, based @@ -177,3 +172,10 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } + +func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/peeks_table.go b/syncapi/storage/postgres/peeks_table.go index e20a4882f..64183073d 100644 --- a/syncapi/storage/postgres/peeks_table.go +++ b/syncapi/storage/postgres/peeks_table.go @@ -65,6 +65,9 @@ const selectPeekingDevicesSQL = "" + const selectMaxPeekIDSQL = "" + "SELECT MAX(id) FROM syncapi_peeks" +const purgePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1" + type peekStatements struct { db *sql.DB insertPeekStmt *sql.Stmt @@ -73,6 +76,7 @@ type peekStatements struct { selectPeeksInRangeStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt + purgePeeksStmt *sql.Stmt } func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { @@ -83,25 +87,15 @@ func NewPostgresPeeksTable(db *sql.DB) (tables.Peeks, error) { s := &peekStatements{ db: db, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertPeekStmt, insertPeekSQL}, + {&s.deletePeekStmt, deletePeekSQL}, + {&s.deletePeeksStmt, deletePeeksSQL}, + {&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL}, + {&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL}, + {&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL}, + {&s.purgePeeksStmt, purgePeeksSQL}, + }.Prepare(db) } func (s *peekStatements) InsertPeek( @@ -184,3 +178,10 @@ func (s *peekStatements) SelectMaxPeekID( } return } + +func (s *peekStatements) PurgePeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/postgres/receipt_table.go b/syncapi/storage/postgres/receipt_table.go index 327a7a372..0fcbebfcb 100644 --- a/syncapi/storage/postgres/receipt_table.go +++ b/syncapi/storage/postgres/receipt_table.go @@ -62,11 +62,15 @@ const selectRoomReceipts = "" + const selectMaxReceiptIDSQL = "" + "SELECT MAX(id) FROM syncapi_receipts" +const purgeReceiptsSQL = "" + + "DELETE FROM syncapi_receipts WHERE room_id = $1" + type receiptStatements struct { db *sql.DB upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt + purgeReceiptsStmt *sql.Stmt } func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { @@ -86,16 +90,12 @@ func NewPostgresReceiptsTable(db *sql.DB) (tables.Receipts, error) { r := &receiptStatements{ db: db, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - return r, nil + return r, sqlutil.StatementList{ + {&r.upsertReceipt, upsertReceipt}, + {&r.selectRoomReceipts, selectRoomReceipts}, + {&r.selectMaxReceiptID, selectMaxReceiptIDSQL}, + {&r.purgeReceiptsStmt, purgeReceiptsSQL}, + }.Prepare(db) } func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) { @@ -138,3 +138,10 @@ func (s *receiptStatements) SelectMaxReceiptID( } return } + +func (s *receiptStatements) PurgeReceipts( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/shared/storage_consumer.go b/syncapi/storage/shared/storage_consumer.go index df2338cf8..aeeebb1d2 100644 --- a/syncapi/storage/shared/storage_consumer.go +++ b/syncapi/storage/shared/storage_consumer.go @@ -242,20 +242,6 @@ func (d *Database) handleBackwardExtremities(ctx context.Context, txn *sql.Tx, e return nil } -func (d *Database) PurgeRoomState( - ctx context.Context, roomID string, -) error { - return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - // If the event is a create event then we'll delete all of the existing - // data for the room. The only reason that a create event would be replayed - // to us in this way is if we're about to receive the entire room state. - if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { - return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) - } - return nil - }) -} - func (d *Database) WriteEvent( ctx context.Context, ev *gomatrixserverlib.HeaderedEvent, diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 7b07cac5e..8385b95a5 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -649,6 +649,53 @@ func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) return d.Presence.GetMaxPresenceID(ctx, d.txn) } +func (d *Database) PurgeRoom(ctx context.Context, roomID string) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + if err := d.BackwardExtremities.PurgeBackwardExtremities(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge backward extremities: %w", err) + } + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge current room state: %w", err) + } + if err := d.Invites.PurgeInvites(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge invites: %w", err) + } + if err := d.Memberships.PurgeMemberships(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge memberships: %w", err) + } + if err := d.NotificationData.PurgeNotificationData(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge notification data: %w", err) + } + if err := d.OutputEvents.PurgeEvents(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge events: %w", err) + } + if err := d.Topology.PurgeEventsTopology(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge events topology: %w", err) + } + if err := d.Peeks.PurgePeeks(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge peeks: %w", err) + } + if err := d.Receipts.PurgeReceipts(ctx, txn, roomID); err != nil { + return fmt.Errorf("failed to purge receipts: %w", err) + } + return nil + }) +} + +func (d *Database) PurgeRoomState( + ctx context.Context, roomID string, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + // If the event is a create event then we'll delete all of the existing + // data for the room. The only reason that a create event would be replayed + // to us in this way is if we're about to receive the entire room state. + if err := d.CurrentRoomState.DeleteRoomStateForRoom(ctx, txn, roomID); err != nil { + return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateForRoom: %w", err) + } + return nil + }) +} + func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) { id, err := d.Relations.SelectMaxRelationID(ctx, d.txn) return types.StreamPosition(id), err diff --git a/syncapi/storage/sqlite3/backwards_extremities_table.go b/syncapi/storage/sqlite3/backwards_extremities_table.go index 3a5fd6be3..2d8cf2ed2 100644 --- a/syncapi/storage/sqlite3/backwards_extremities_table.go +++ b/syncapi/storage/sqlite3/backwards_extremities_table.go @@ -47,11 +47,15 @@ const selectBackwardExtremitiesForRoomSQL = "" + const deleteBackwardExtremitySQL = "" + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2" +const purgeBackwardExtremitiesSQL = "" + + "DELETE FROM syncapi_backward_extremities WHERE room_id = $1" + type backwardExtremitiesStatements struct { db *sql.DB insertBackwardExtremityStmt *sql.Stmt selectBackwardExtremitiesForRoomStmt *sql.Stmt deleteBackwardExtremityStmt *sql.Stmt + purgeBackwardExtremitiesStmt *sql.Stmt } func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) { @@ -62,16 +66,12 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities if err != nil { return nil, err } - if s.insertBackwardExtremityStmt, err = db.Prepare(insertBackwardExtremitySQL); err != nil { - return nil, err - } - if s.selectBackwardExtremitiesForRoomStmt, err = db.Prepare(selectBackwardExtremitiesForRoomSQL); err != nil { - return nil, err - } - if s.deleteBackwardExtremityStmt, err = db.Prepare(deleteBackwardExtremitySQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertBackwardExtremityStmt, insertBackwardExtremitySQL}, + {&s.selectBackwardExtremitiesForRoomStmt, selectBackwardExtremitiesForRoomSQL}, + {&s.deleteBackwardExtremityStmt, deleteBackwardExtremitySQL}, + {&s.purgeBackwardExtremitiesStmt, purgeBackwardExtremitiesSQL}, + }.Prepare(db) } func (s *backwardExtremitiesStatements) InsertsBackwardExtremity( @@ -109,3 +109,10 @@ func (s *backwardExtremitiesStatements) DeleteBackwardExtremity( _, err = sqlutil.TxStmt(txn, s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID) return err } + +func (s *backwardExtremitiesStatements) PurgeBackwardExtremities( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeBackwardExtremitiesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/invites_table.go b/syncapi/storage/sqlite3/invites_table.go index e2dbcd5c8..19450099a 100644 --- a/syncapi/storage/sqlite3/invites_table.go +++ b/syncapi/storage/sqlite3/invites_table.go @@ -57,6 +57,9 @@ const selectInviteEventsInRangeSQL = "" + const selectMaxInviteIDSQL = "" + "SELECT MAX(id) FROM syncapi_invite_events" +const purgeInvitesSQL = "" + + "DELETE FROM syncapi_invite_events WHERE room_id = $1" + type inviteEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -64,6 +67,7 @@ type inviteEventsStatements struct { selectInviteEventsInRangeStmt *sql.Stmt deleteInviteEventStmt *sql.Stmt selectMaxInviteIDStmt *sql.Stmt + purgeInvitesStmt *sql.Stmt } func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) { @@ -75,19 +79,13 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Inv if err != nil { return nil, err } - if s.insertInviteEventStmt, err = db.Prepare(insertInviteEventSQL); err != nil { - return nil, err - } - if s.selectInviteEventsInRangeStmt, err = db.Prepare(selectInviteEventsInRangeSQL); err != nil { - return nil, err - } - if s.deleteInviteEventStmt, err = db.Prepare(deleteInviteEventSQL); err != nil { - return nil, err - } - if s.selectMaxInviteIDStmt, err = db.Prepare(selectMaxInviteIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertInviteEventStmt, insertInviteEventSQL}, + {&s.selectInviteEventsInRangeStmt, selectInviteEventsInRangeSQL}, + {&s.deleteInviteEventStmt, deleteInviteEventSQL}, + {&s.selectMaxInviteIDStmt, selectMaxInviteIDSQL}, + {&s.purgeInvitesStmt, purgeInvitesSQL}, + }.Prepare(db) } func (s *inviteEventsStatements) InsertInviteEvent( @@ -192,3 +190,10 @@ func (s *inviteEventsStatements) SelectMaxInviteID( } return } + +func (s *inviteEventsStatements) PurgeInvites( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeInvitesStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index 905a1e1a8..2cc46a10a 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -72,6 +72,9 @@ SELECT event_id FROM AND ($4 IS NULL OR t.membership <> $4) ` +const purgeMembershipsSQL = "" + + "DELETE FROM syncapi_memberships WHERE room_id = $1" + type membershipsStatements struct { db *sql.DB upsertMembershipStmt *sql.Stmt @@ -79,6 +82,7 @@ type membershipsStatements struct { //selectHeroesStmt *sql.Stmt - prepared at runtime due to variadic selectMembershipForUserStmt *sql.Stmt selectMembersStmt *sql.Stmt + purgeMembershipsStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -94,6 +98,7 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { {&s.selectMembershipCountStmt, selectMembershipCountSQL}, {&s.selectMembershipForUserStmt, selectMembershipBeforeSQL}, {&s.selectMembersStmt, selectMembersSQL}, + {&s.purgeMembershipsStmt, purgeMembershipsSQL}, }.Prepare(db) } @@ -142,6 +147,13 @@ func (s *membershipsStatements) SelectMembershipForUser( return membership, topologyPos, nil } +func (s *membershipsStatements) PurgeMemberships( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeMembershipsStmt).ExecContext(ctx, roomID) + return err +} + func (s *membershipsStatements) SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, diff --git a/syncapi/storage/sqlite3/notification_data_table.go b/syncapi/storage/sqlite3/notification_data_table.go index 6242898e1..af2b2c074 100644 --- a/syncapi/storage/sqlite3/notification_data_table.go +++ b/syncapi/storage/sqlite3/notification_data_table.go @@ -38,6 +38,7 @@ func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (t return r, sqlutil.StatementList{ {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.selectMaxID, selectMaxNotificationIDSQL}, + {&r.purgeNotificationData, purgeNotificationDataSQL}, // {&r.selectUserUnreadCountsForRooms, selectUserUnreadNotificationsForRooms}, // used at runtime }.Prepare(db) } @@ -47,6 +48,7 @@ type notificationDataStatements struct { streamIDStatements *StreamIDStatements upsertRoomUnreadCounts *sql.Stmt selectMaxID *sql.Stmt + purgeNotificationData *sql.Stmt //selectUserUnreadCountsForRooms *sql.Stmt } @@ -73,6 +75,9 @@ const selectUserUnreadNotificationsForRooms = `SELECT room_id, notification_coun const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` +const purgeNotificationDataSQL = "" + + "DELETE FROM syncapi_notification_data WHERE room_id = $1" + func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { pos, err = r.streamIDStatements.nextNotificationID(ctx, nil) if err != nil { @@ -124,3 +129,10 @@ func (r *notificationDataStatements) SelectMaxID(ctx context.Context, txn *sql.T err := sqlutil.TxStmt(txn, r.selectMaxID).QueryRowContext(ctx).Scan(&id) return id, err } + +func (s *notificationDataStatements) PurgeNotificationData( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeNotificationData).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 1aa4bfff7..db708c083 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -120,6 +120,9 @@ const selectContextAfterEventSQL = "" + const selectSearchSQL = "SELECT id, event_id, headered_event_json FROM syncapi_output_room_events WHERE type IN ($1) AND id > $2 LIMIT $3 ORDER BY id ASC" +const purgeEventsSQL = "" + + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -130,6 +133,7 @@ type outputRoomEventsStatements struct { selectContextEventStmt *sql.Stmt selectContextBeforeEventStmt *sql.Stmt selectContextAfterEventStmt *sql.Stmt + purgeEventsStmt *sql.Stmt //selectSearchStmt *sql.Stmt - prepared at runtime } @@ -163,6 +167,7 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even {&s.selectContextEventStmt, selectContextEventSQL}, {&s.selectContextBeforeEventStmt, selectContextBeforeEventSQL}, {&s.selectContextAfterEventStmt, selectContextAfterEventSQL}, + {&s.purgeEventsStmt, purgeEventsSQL}, //{&s.selectSearchStmt, selectSearchSQL}, - prepared at runtime }.Prepare(db) } @@ -666,6 +671,13 @@ func unmarshalStateIDs(addIDsJSON, delIDsJSON string) (addIDs []string, delIDs [ return } +func (s *outputRoomEventsStatements) PurgeEvents( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsStmt).ExecContext(ctx, roomID) + return err +} + func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, limit, afterID int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) { params := make([]interface{}, len(types)) for i := range types { diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 879456441..dc698de2d 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -18,10 +18,11 @@ import ( "context" "database/sql" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" - "github.com/matrix-org/gomatrixserverlib" ) const outputRoomEventsTopologySchema = ` @@ -67,6 +68,9 @@ const selectStreamToTopologicalPositionAscSQL = "" + const selectStreamToTopologicalPositionDescSQL = "" + "SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position <= $2 ORDER BY topological_position DESC LIMIT 1;" +const purgeEventsTopologySQL = "" + + "DELETE FROM syncapi_output_room_events_topology WHERE room_id = $1" + type outputRoomEventsTopologyStatements struct { db *sql.DB insertEventInTopologyStmt *sql.Stmt @@ -75,6 +79,7 @@ type outputRoomEventsTopologyStatements struct { selectPositionInTopologyStmt *sql.Stmt selectStreamToTopologicalPositionAscStmt *sql.Stmt selectStreamToTopologicalPositionDescStmt *sql.Stmt + purgeEventsTopologyStmt *sql.Stmt } func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { @@ -85,25 +90,15 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) { if err != nil { return nil, err } - if s.insertEventInTopologyStmt, err = db.Prepare(insertEventInTopologySQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeASCStmt, err = db.Prepare(selectEventIDsInRangeASCSQL); err != nil { - return nil, err - } - if s.selectEventIDsInRangeDESCStmt, err = db.Prepare(selectEventIDsInRangeDESCSQL); err != nil { - return nil, err - } - if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil { - return nil, err - } - if s.selectStreamToTopologicalPositionDescStmt, err = db.Prepare(selectStreamToTopologicalPositionDescSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertEventInTopologyStmt, insertEventInTopologySQL}, + {&s.selectEventIDsInRangeASCStmt, selectEventIDsInRangeASCSQL}, + {&s.selectEventIDsInRangeDESCStmt, selectEventIDsInRangeDESCSQL}, + {&s.selectPositionInTopologyStmt, selectPositionInTopologySQL}, + {&s.selectStreamToTopologicalPositionAscStmt, selectStreamToTopologicalPositionAscSQL}, + {&s.selectStreamToTopologicalPositionDescStmt, selectStreamToTopologicalPositionDescSQL}, + {&s.purgeEventsTopologyStmt, purgeEventsTopologySQL}, + }.Prepare(db) } // insertEventInTopology inserts the given event in the room's topology, based @@ -174,3 +169,10 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition( } return } + +func (s *outputRoomEventsTopologyStatements) PurgeEventsTopology( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeEventsTopologyStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index 4ef51b103..5d5200abc 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -64,6 +64,9 @@ const selectPeekingDevicesSQL = "" + const selectMaxPeekIDSQL = "" + "SELECT MAX(id) FROM syncapi_peeks" +const purgePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1" + type peekStatements struct { db *sql.DB streamIDStatements *StreamIDStatements @@ -73,6 +76,7 @@ type peekStatements struct { selectPeeksInRangeStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt selectMaxPeekIDStmt *sql.Stmt + purgePeeksStmt *sql.Stmt } func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) { @@ -84,25 +88,15 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks db: db, streamIDStatements: streamID, } - if s.insertPeekStmt, err = db.Prepare(insertPeekSQL); err != nil { - return nil, err - } - if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { - return nil, err - } - if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { - return nil, err - } - if s.selectPeeksInRangeStmt, err = db.Prepare(selectPeeksInRangeSQL); err != nil { - return nil, err - } - if s.selectPeekingDevicesStmt, err = db.Prepare(selectPeekingDevicesSQL); err != nil { - return nil, err - } - if s.selectMaxPeekIDStmt, err = db.Prepare(selectMaxPeekIDSQL); err != nil { - return nil, err - } - return s, nil + return s, sqlutil.StatementList{ + {&s.insertPeekStmt, insertPeekSQL}, + {&s.deletePeekStmt, deletePeekSQL}, + {&s.deletePeeksStmt, deletePeeksSQL}, + {&s.selectPeeksInRangeStmt, selectPeeksInRangeSQL}, + {&s.selectPeekingDevicesStmt, selectPeekingDevicesSQL}, + {&s.selectMaxPeekIDStmt, selectMaxPeekIDSQL}, + {&s.purgePeeksStmt, purgePeeksSQL}, + }.Prepare(db) } func (s *peekStatements) InsertPeek( @@ -204,3 +198,10 @@ func (s *peekStatements) SelectMaxPeekID( } return } + +func (s *peekStatements) PurgePeeks( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgePeeksStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/sqlite3/receipt_table.go b/syncapi/storage/sqlite3/receipt_table.go index a4a9b4395..ca3d80fb4 100644 --- a/syncapi/storage/sqlite3/receipt_table.go +++ b/syncapi/storage/sqlite3/receipt_table.go @@ -58,12 +58,16 @@ const selectRoomReceipts = "" + const selectMaxReceiptIDSQL = "" + "SELECT MAX(id) FROM syncapi_receipts" +const purgeReceiptsSQL = "" + + "DELETE FROM syncapi_receipts WHERE room_id = $1" + type receiptStatements struct { db *sql.DB streamIDStatements *StreamIDStatements upsertReceipt *sql.Stmt selectRoomReceipts *sql.Stmt selectMaxReceiptID *sql.Stmt + purgeReceiptsStmt *sql.Stmt } func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) { @@ -84,16 +88,12 @@ func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Re db: db, streamIDStatements: streamID, } - if r.upsertReceipt, err = db.Prepare(upsertReceipt); err != nil { - return nil, fmt.Errorf("unable to prepare upsertReceipt statement: %w", err) - } - if r.selectRoomReceipts, err = db.Prepare(selectRoomReceipts); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - if r.selectMaxReceiptID, err = db.Prepare(selectMaxReceiptIDSQL); err != nil { - return nil, fmt.Errorf("unable to prepare selectRoomReceipts statement: %w", err) - } - return r, nil + return r, sqlutil.StatementList{ + {&r.upsertReceipt, upsertReceipt}, + {&r.selectRoomReceipts, selectRoomReceipts}, + {&r.selectMaxReceiptID, selectMaxReceiptIDSQL}, + {&r.purgeReceiptsStmt, purgeReceiptsSQL}, + }.Prepare(db) } // UpsertReceipt creates new user receipts @@ -153,3 +153,10 @@ func (s *receiptStatements) SelectMaxReceiptID( } return } + +func (s *receiptStatements) PurgeReceipts( + ctx context.Context, txn *sql.Tx, roomID string, +) error { + _, err := sqlutil.TxStmt(txn, s.purgeReceiptsStmt).ExecContext(ctx, roomID) + return err +} diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 8366a67dc..145e197cc 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -39,6 +39,7 @@ type Invites interface { // for the room. SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (invites map[string]*gomatrixserverlib.HeaderedEvent, retired map[string]*gomatrixserverlib.HeaderedEvent, maxID types.StreamPosition, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgeInvites(ctx context.Context, txn *sql.Tx, roomID string) error } type Peeks interface { @@ -48,6 +49,7 @@ type Peeks interface { SelectPeeksInRange(ctxt context.Context, txn *sql.Tx, userID, deviceID string, r types.Range) (peeks []types.Peek, err error) SelectPeekingDevices(ctxt context.Context, txn *sql.Tx) (peekingDevices map[string][]types.PeekingDevice, err error) SelectMaxPeekID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgePeeks(ctx context.Context, txn *sql.Tx, roomID string) error } type Events interface { @@ -75,6 +77,8 @@ type Events interface { SelectContextEvent(ctx context.Context, txn *sql.Tx, roomID, eventID string) (int, gomatrixserverlib.HeaderedEvent, error) SelectContextBeforeEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) ([]*gomatrixserverlib.HeaderedEvent, error) SelectContextAfterEvent(ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter) (int, []*gomatrixserverlib.HeaderedEvent, error) + + PurgeEvents(ctx context.Context, txn *sql.Tx, roomID string) error ReIndex(ctx context.Context, txn *sql.Tx, limit, offset int64, types []string) (map[int64]gomatrixserverlib.HeaderedEvent, error) } @@ -93,6 +97,7 @@ type Topology interface { SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error) + PurgeEventsTopology(ctx context.Context, txn *sql.Tx, roomID string) error } type CurrentRoomState interface { @@ -146,6 +151,7 @@ type BackwardsExtremities interface { SelectBackwardExtremitiesForRoom(ctx context.Context, txn *sql.Tx, roomID string) (bwExtrems map[string][]string, err error) // DeleteBackwardExtremity removes a backwards extremity for a room, if one existed. DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error) + PurgeBackwardExtremities(ctx context.Context, txn *sql.Tx, roomID string) error } // SendToDevice tracks send-to-device messages which are sent to individual @@ -181,12 +187,14 @@ type Receipts interface { UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) + PurgeReceipts(ctx context.Context, txn *sql.Tx, roomID string) error } type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembershipCount(ctx context.Context, txn *sql.Tx, roomID, membership string, pos types.StreamPosition) (count int, err error) SelectMembershipForUser(ctx context.Context, txn *sql.Tx, roomID, userID string, pos int64) (membership string, topologicalPos int, err error) + PurgeMemberships(ctx context.Context, txn *sql.Tx, roomID string) error SelectMemberships( ctx context.Context, txn *sql.Tx, roomID string, pos types.TopologyToken, @@ -198,6 +206,7 @@ type NotificationData interface { UpsertRoomUnreadCounts(ctx context.Context, txn *sql.Tx, userID, roomID string, notificationCount, highlightCount int) (types.StreamPosition, error) SelectUserUnreadCountsForRooms(ctx context.Context, txn *sql.Tx, userID string, roomIDs []string) (map[string]*eventutil.NotificationData, error) SelectMaxID(ctx context.Context, txn *sql.Tx) (int64, error) + PurgeNotificationData(ctx context.Context, txn *sql.Tx, roomID string) error } type Ignores interface { From ce2bfc3f2e507a012044906af7f25c9dc52873d7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 20 Jan 2023 12:45:56 +0100 Subject: [PATCH 07/14] Make tests more reliable (#2948) When using `testrig.CreateBase` and then using that base for other `NewInternalAPI` calls, we never actually shutdown the components. `testrig.CreateBase` returns a `close` function, which only removes the database, so still running components have issues connecting to the database, since we ripped it out underneath it - which can result in "Disk I/O" or "pq deadlock detected" issues. --- federationapi/federationapi.go | 5 +---- federationapi/federationapi_test.go | 6 +++--- federationapi/routing/routing.go | 17 ++++++++++++----- setup/base/base.go | 2 ++ test/testrig/base.go | 12 ++++++++++-- 5 files changed, 28 insertions(+), 14 deletions(-) diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index 87eb751f5..ce0ce98e9 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -85,10 +85,7 @@ func AddPublicRoutes( } routing.Setup( - base.PublicFederationAPIMux, - base.PublicKeyAPIMux, - base.PublicWellKnownAPIMux, - cfg, + base, rsAPI, f, keyRing, federation, userAPI, keyAPI, mscCfg, servers, producer, diff --git a/federationapi/federationapi_test.go b/federationapi/federationapi_test.go index 68a06a033..7009230cc 100644 --- a/federationapi/federationapi_test.go +++ b/federationapi/federationapi_test.go @@ -273,12 +273,12 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) { cfg.Global.ServerName = gomatrixserverlib.ServerName("localhost") cfg.Global.PrivateKey = privKey cfg.Global.JetStream.InMemory = true - base := base.NewBaseDendrite(cfg, "Monolith") + b := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) keyRing := &test.NopJSONVerifier{} // TODO: This is pretty fragile, as if anything calls anything on these nils this test will break. // Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing. - federationapi.AddPublicRoutes(base, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) - baseURL, cancel := test.ListenAndServe(t, base.PublicFederationAPIMux, true) + federationapi.AddPublicRoutes(b, nil, nil, keyRing, nil, &internal.FederationInternalAPI{}, nil, nil) + baseURL, cancel := test.ListenAndServe(t, b.PublicFederationAPIMux, true) defer cancel() serverName := gomatrixserverlib.ServerName(strings.TrimPrefix(baseURL, "https://")) diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 0a3ab7a88..04eb3d067 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -32,6 +32,7 @@ import ( keyserverAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" userapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/gomatrixserverlib" @@ -49,8 +50,7 @@ import ( // applied: // nolint: gocyclo func Setup( - fedMux, keyMux, wkMux *mux.Router, - cfg *config.FederationAPI, + base *base.BaseDendrite, rsAPI roomserverAPI.FederationRoomserverAPI, fsAPI *fedInternal.FederationInternalAPI, keys gomatrixserverlib.JSONVerifier, @@ -61,9 +61,16 @@ func Setup( servers federationAPI.ServersInRoomProvider, producer *producers.SyncAPIProducer, ) { - prometheus.MustRegister( - pduCountTotal, eduCountTotal, - ) + fedMux := base.PublicFederationAPIMux + keyMux := base.PublicKeyAPIMux + wkMux := base.PublicWellKnownAPIMux + cfg := &base.Cfg.FederationAPI + + if base.EnableMetrics { + prometheus.MustRegister( + pduCountTotal, eduCountTotal, + ) + } v2keysmux := keyMux.PathPrefix("/v2").Subrouter() v1fedmux := fedMux.PathPrefix("/v1").Subrouter() diff --git a/setup/base/base.go b/setup/base/base.go index d3adbf53f..ff38209fb 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -264,6 +264,8 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string, options ...Base // Close implements io.Closer func (b *BaseDendrite) Close() error { + b.ProcessContext.ShutdownDendrite() + b.ProcessContext.WaitForShutdown() return b.tracerCloser.Close() } diff --git a/test/testrig/base.go b/test/testrig/base.go index 7bc26a5c5..52e6ef5f1 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -62,7 +62,12 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f MaxIdleConnections: 2, ConnMaxLifetimeSeconds: 60, } - return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), close + base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) + return base, func() { + base.ShutdownDendrite() + base.WaitForShutdown() + close() + } case test.DBTypeSQLite: cfg.Defaults(config.DefaultOpts{ Generate: true, @@ -72,7 +77,10 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) - return base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics), func() { + base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) + return base, func() { + base.ShutdownDendrite() + base.WaitForShutdown() // cleanup db files. This risks getting out of sync as we add more database strings :( dbFiles := []config.DataSource{ cfg.FederationAPI.Database.ConnectionString, From a2b486091218e761adc0a00ce19ed4b600e489a2 Mon Sep 17 00:00:00 2001 From: Bernhard Feichtinger <43303168+BieHDC@users.noreply.github.com> Date: Fri, 20 Jan 2023 13:13:36 +0100 Subject: [PATCH 08/14] Fix oversight in cmd/generate-config (#2946) The -dir argument was ignored for media_api->base_path. Signed-off-by: `Bernhard Feichtinger <43303168+BieHDC@users.noreply.github.com>` --- cmd/generate-config/main.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 33b18c471..5f75f5e4d 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -54,6 +54,9 @@ func main() { } else { cfg.Global.DatabaseOptions.ConnectionString = uri } + cfg.MediaAPI.BasePath = config.Path(filepath.Join(*dirPath, "media")) + cfg.Global.JetStream.StoragePath = config.Path(*dirPath) + cfg.SyncAPI.Fulltext.IndexPath = config.Path(filepath.Join(*dirPath, "searchindex")) cfg.Logging = []config.LogrusHook{ { Type: "file", From caf310fd7976ed3fe8abbbf8cb72d380c7efd3c2 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 20 Jan 2023 15:18:06 +0100 Subject: [PATCH 09/14] AWSY missing federation tests (#2943) In an attempt to fix the missing AWSY tests and to get to 100% server-server compliance. --- are-we-synapse-yet.list | 10 +- go.mod | 2 +- go.sum | 4 + roomserver/internal/input/input_events.go | 107 +++++++++++----------- sytest-blacklist | 38 +------- sytest-whitelist | 14 ++- 6 files changed, 83 insertions(+), 92 deletions(-) diff --git a/are-we-synapse-yet.list b/are-we-synapse-yet.list index 81c0f8049..585374738 100644 --- a/are-we-synapse-yet.list +++ b/are-we-synapse-yet.list @@ -936,4 +936,12 @@ fst Room state after a rejected message event is the same as before fst Room state after a rejected state event is the same as before fpb Federation publicRoom Name/topic keys are correct fed New federated private chats get full presence information (SYN-115) (10 subtests) -dvk Rejects invalid device keys \ No newline at end of file +dvk Rejects invalid device keys +rmv User can create and send/receive messages in a room with version 10 +rmv local user can join room with version 10 +rmv User can invite local user to room with version 10 +rmv remote user can join room with version 10 +rmv User can invite remote user to room with version 10 +rmv Remote user can backfill in a room with version 10 +rmv Can reject invites over federation for rooms with version 10 +rmv Can receive redactions from regular users over federation in room version 10 \ No newline at end of file diff --git a/go.mod b/go.mod index 2d7174150..a86dd2cb8 100644 --- a/go.mod +++ b/go.mod @@ -22,7 +22,7 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab + github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/mattn/go-sqlite3 v1.14.15 diff --git a/go.sum b/go.sum index b12f65eab..e5cd67bed 100644 --- a/go.sum +++ b/go.sum @@ -350,6 +350,10 @@ github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8= github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45 h1:zGrmcm2M4F4f+zk5JXAkw3oHa/zXhOh5XVGBdl7GdPo= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 h1:P7me2oCmksST9B4+1I1nA+XrnDQwIqAWmy6ntQrXwc8= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 4179fc1ef..67edb3217 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -24,6 +24,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/tidwall/gjson" "github.com/matrix-org/gomatrixserverlib" @@ -40,7 +41,6 @@ import ( "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/types" ) @@ -166,6 +166,7 @@ func (r *Inputer) processRoomEvent( missingPrev = !input.HasState && len(missingPrevIDs) > 0 } + // If we have missing events (auth or prev), we build a list of servers to ask if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ RoomID: event.RoomID(), @@ -200,59 +201,8 @@ func (r *Inputer) processRoomEvent( } } - // First of all, check that the auth events of the event are known. - // If they aren't then we will ask the federation API for them. isRejected := false - authEvents := gomatrixserverlib.NewAuthEvents(nil) - knownEvents := map[string]*types.Event{} - if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.fetchAuthEvents: %w", err) - } - - // Check if the event is allowed by its auth events. If it isn't then - // we consider the event to be "rejected" — it will still be persisted. var rejectionErr error - if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { - isRejected = true - logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) - } - - // Accumulate the auth event NIDs. - authEventIDs := event.AuthEventIDs() - authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) - for _, authEventID := range authEventIDs { - if _, ok := knownEvents[authEventID]; !ok { - // Unknown auth events only really matter if the event actually failed - // auth. If it passed auth then we can assume that everything that was - // known was sufficient, even if extraneous auth events were specified - // but weren't found. - if isRejected { - if event.StateKey() != nil { - return fmt.Errorf( - "missing auth event %s for state event %s (type %q, state key %q)", - authEventID, event.EventID(), event.Type(), *event.StateKey(), - ) - } else { - return fmt.Errorf( - "missing auth event %s for timeline event %s (type %q)", - authEventID, event.EventID(), event.Type(), - ) - } - } - } else { - authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) - } - } - - var softfail bool - if input.Kind == api.KindNew { - // Check that the event passes authentication checks based on the - // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) - if err != nil { - logger.WithError(err).Warn("Error authing soft-failed event") - } - } // At this point we are checking whether we know all of the prev events, and // if we know the state before the prev events. This is necessary before we @@ -314,6 +264,59 @@ func (r *Inputer) processRoomEvent( } } + // Check that the auth events of the event are known. + // If they aren't then we will ask the federation API for them. + authEvents := gomatrixserverlib.NewAuthEvents(nil) + knownEvents := map[string]*types.Event{} + if err = r.fetchAuthEvents(ctx, logger, roomInfo, virtualHost, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { + return fmt.Errorf("r.fetchAuthEvents: %w", err) + } + + // Check if the event is allowed by its auth events. If it isn't then + // we consider the event to be "rejected" — it will still be persisted. + if err = gomatrixserverlib.Allowed(event, &authEvents); err != nil { + isRejected = true + rejectionErr = err + logger.WithError(rejectionErr).Warnf("Event %s not allowed by auth events", event.EventID()) + } + + // Accumulate the auth event NIDs. + authEventIDs := event.AuthEventIDs() + authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) + for _, authEventID := range authEventIDs { + if _, ok := knownEvents[authEventID]; !ok { + // Unknown auth events only really matter if the event actually failed + // auth. If it passed auth then we can assume that everything that was + // known was sufficient, even if extraneous auth events were specified + // but weren't found. + if isRejected { + if event.StateKey() != nil { + return fmt.Errorf( + "missing auth event %s for state event %s (type %q, state key %q)", + authEventID, event.EventID(), event.Type(), *event.StateKey(), + ) + } else { + return fmt.Errorf( + "missing auth event %s for timeline event %s (type %q)", + authEventID, event.EventID(), event.Type(), + ) + } + } + } else { + authEventNIDs = append(authEventNIDs, knownEvents[authEventID].EventNID) + } + } + + var softfail bool + if input.Kind == api.KindNew { + // Check that the event passes authentication checks based on the + // current room state. + softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) + if err != nil { + logger.WithError(err).Warn("Error authing soft-failed event") + } + } + // Get the state before the event so that we can work out if the event was // allowed at the time, and also to get the history visibility. We won't // bother doing this if the event was already rejected as it just ends up diff --git a/sytest-blacklist b/sytest-blacklist index 99cfbabc8..bb0ee368f 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -1,54 +1,18 @@ -# Relies on a rejected PL event which will never be accepted into the DAG - -# Caused by - -Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state - -# We don't implement lazy membership loading yet - +# Blacklisted due to https://github.com/matrix-org/matrix-spec/issues/942 The only membership state included in a gapped incremental sync is for senders in the timeline -# Blacklisted out of flakiness after #1479 - -Invited user can reject local invite after originator leaves -Invited user can reject invite for empty room -If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes - -# Blacklisted due to flakiness - -Forgotten room messages cannot be paginated - -# Blacklisted due to flakiness after #1774 - -Local device key changes get to remote servers with correct prev_id - -# we don't support groups - -Remove group category -Remove group role - # Flakey - AS-ghosted users can use rooms themselves AS-ghosted users can use rooms via AS Events in rooms with AS-hosted room aliases are sent to AS server Inviting an AS-hosted user asks the AS server Accesing an AS-hosted room alias asks the AS server -# More flakey - -Guest users can join guest_access rooms - # This will fail in HTTP API mode, so blacklisted for now - If a device list update goes missing, the server resyncs on the next one # Might be a bug in the test because leaves do appear :-( - Leaves are present in non-gapped incremental syncs -# Below test was passing for the wrong reason, failing correctly since #2858 -New federated private chats get full presence information (SYN-115) - # We don't have any state to calculate m.room.guest_access when accepting invites Guest users can accept invites to private rooms over federation \ No newline at end of file diff --git a/sytest-whitelist b/sytest-whitelist index 215889a49..1f6ecc29e 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -766,4 +766,16 @@ remote user has tags copied to the new room Local and remote users' homeservers remove a room from their public directory on upgrade Guest users denied access over federation if guest access prohibited Guest users are kicked from guest_access rooms on revocation of guest_access -Guest users are kicked from guest_access rooms on revocation of guest_access over federation \ No newline at end of file +Guest users are kicked from guest_access rooms on revocation of guest_access over federation +User can create and send/receive messages in a room with version 10 +local user can join room with version 10 +User can invite local user to room with version 10 +remote user can join room with version 10 +User can invite remote user to room with version 10 +Remote user can backfill in a room with version 10 +Can reject invites over federation for rooms with version 10 +Can receive redactions from regular users over federation in room version 10 +New federated private chats get full presence information (SYN-115) +/state returns M_NOT_FOUND for an outlier +/state_ids returns M_NOT_FOUND for an outlier +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state \ No newline at end of file From 25cb65acdbb6702e84a6bcb6245d6d23d90c2359 Mon Sep 17 00:00:00 2001 From: Catalan Lover <48515417+FSG-Cat@users.noreply.github.com> Date: Fri, 20 Jan 2023 15:41:29 +0100 Subject: [PATCH 10/14] Change Default Room version to 10 (#2933) This PR implements [MSC3904](https://github.com/matrix-org/matrix-spec-proposals/pull/3904). This PR is almost identical to #2781 but this PR is also filed well technically 1 day before the MSC passes FCP but well everyone knows this MSC is expected to have passed FCP on monday so im refiling this change today on saturday as i was doing prep work for monday. I assume that this PR wont be counted as clogging the queue since by the next time i expect to be a work day for this project this PR will be implementing an FCP passed disposition merge MSC. Also as for the lack of tests i belive that this simple change does not need to pass new tests due to that these tests are expected to already have been passed by the successful use of Dendrite with Room version 10 already. ### Pull Request Checklist * [X] I have added tests for PR _or_ I have justified why this PR doesn't need tests. * [X] Pull request includes a [sign off](https://github.com/matrix-org/dendrite/blob/main/docs/CONTRIBUTING.md#sign-off) Signed-off-by: Catalan Lover Co-authored-by: Till <2353100+S7evinK@users.noreply.github.com> Co-authored-by: kegsay --- roomserver/version/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/roomserver/version/version.go b/roomserver/version/version.go index 729d00a80..c40d8e0f7 100644 --- a/roomserver/version/version.go +++ b/roomserver/version/version.go @@ -23,7 +23,7 @@ import ( // DefaultRoomVersion contains the room version that will, by // default, be used to create new rooms on this server. func DefaultRoomVersion() gomatrixserverlib.RoomVersion { - return gomatrixserverlib.RoomVersionV9 + return gomatrixserverlib.RoomVersionV10 } // RoomVersions returns a map of all known room versions to this From 430932f0f161dd836c98082ff97b57beedec02e6 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Fri, 20 Jan 2023 16:20:01 +0100 Subject: [PATCH 11/14] Version 0.11.0 (#2949) --- CHANGES.md | 14 ++++++++++++++ helm/dendrite/Chart.yaml | 4 ++-- helm/dendrite/README.md | 7 ++++--- helm/dendrite/values.yaml | 2 +- internal/version.go | 4 ++-- 5 files changed, 23 insertions(+), 8 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index fa8230659..e1f7affb5 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,5 +1,19 @@ # Changelog +## Dendrite 0.11.0 (2023-01-20) + +The last three missing federation API Sytests have been fixed - bringing us to 100% server-server Synapse parity, with client-server parity at 93% 🎉 + +### Features + +* Added `/_dendrite/admin/purgeRoom/{roomID}` to clean up the database +* The default room version was updated to 10 (contributed by [FSG-Cat](https://github.com/FSG-Cat)) + +### Fixes + +* An oversight in the `create-config` binary, which now correctly sets the media path if specified (contributed by [BieHDC](https://github.com/BieHDC)) +* The Helm chart now uses the `$.Chart.AppVersion` as the default image version to pull, with the possibility to override it (contributed by [genofire](https://github.com/genofire)) + ## Dendrite 0.10.9 (2023-01-17) ### Features diff --git a/helm/dendrite/Chart.yaml b/helm/dendrite/Chart.yaml index 6e6641c8d..174fc5496 100644 --- a/helm/dendrite/Chart.yaml +++ b/helm/dendrite/Chart.yaml @@ -1,7 +1,7 @@ apiVersion: v2 name: dendrite -version: "0.10.9" -appVersion: "0.10.9" +version: "0.11.0" +appVersion: "0.11.0" description: Dendrite Matrix Homeserver type: application keywords: diff --git a/helm/dendrite/README.md b/helm/dendrite/README.md index cb850d655..6a1658429 100644 --- a/helm/dendrite/README.md +++ b/helm/dendrite/README.md @@ -1,6 +1,6 @@ # dendrite -![Version: 0.10.8](https://img.shields.io/badge/Version-0.10.8-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.10.8](https://img.shields.io/badge/AppVersion-0.10.8-informational?style=flat-square) +![Version: 0.11.0](https://img.shields.io/badge/Version-0.11.0-informational?style=flat-square) ![Type: application](https://img.shields.io/badge/Type-application-informational?style=flat-square) ![AppVersion: 0.11.0](https://img.shields.io/badge/AppVersion-0.11.0-informational?style=flat-square) Dendrite Matrix Homeserver Status: **NOT PRODUCTION READY** @@ -41,8 +41,9 @@ Create a folder `appservices` and place your configurations in there. The confi | Key | Type | Default | Description | |-----|------|---------|-------------| -| image.name | string | `"ghcr.io/matrix-org/dendrite-monolith:v0.10.8"` | Docker repository/image to use | +| image.repository | string | `"ghcr.io/matrix-org/dendrite-monolith"` | Docker repository/image to use | | image.pullPolicy | string | `"IfNotPresent"` | Kubernetes pullPolicy | +| image.tag | string | `""` | Overrides the image tag whose default is the chart appVersion. | | signing_key.create | bool | `true` | Create a new signing key, if not exists | | signing_key.existingSecret | string | `""` | Use an existing secret | | resources | object | sets some sane default values | Default resource requests/limits. | @@ -144,4 +145,4 @@ Create a folder `appservices` and place your configurations in there. The confi | ingress.annotations | object | `{}` | Extra, custom annotations | | ingress.tls | list | `[]` | | | service.type | string | `"ClusterIP"` | | -| service.port | int | `80` | | +| service.port | int | `8008` | | diff --git a/helm/dendrite/values.yaml b/helm/dendrite/values.yaml index 87027a886..848241ab6 100644 --- a/helm/dendrite/values.yaml +++ b/helm/dendrite/values.yaml @@ -3,7 +3,7 @@ image: repository: "ghcr.io/matrix-org/dendrite-monolith" # -- Kubernetes pullPolicy pullPolicy: IfNotPresent - # Overrides the image tag whose default is the chart appVersion. + # -- Overrides the image tag whose default is the chart appVersion. tag: "" diff --git a/internal/version.go b/internal/version.go index ff31dd784..fbe4a01b0 100644 --- a/internal/version.go +++ b/internal/version.go @@ -16,8 +16,8 @@ var build string const ( VersionMajor = 0 - VersionMinor = 10 - VersionPatch = 9 + VersionMinor = 11 + VersionPatch = 0 VersionTag = "" // example: "rc1" ) From 48fa869fa3578741d1d5775d30f24f6b097ab995 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 23 Jan 2023 13:17:15 +0100 Subject: [PATCH 12/14] Use `t.TempDir` for SQLite databases, so tests don't rip out each others databases (#2950) This should hopefully finally fix issues about `disk I/O error` as seen [here](https://gitlab.alpinelinux.org/alpine/aports/-/jobs/955030/raw) Hopefully this will also fix `SSL accept attempt failed` issues by disabling HTTP keep alives when generating a config for CI. --- cmd/generate-config/main.go | 1 + internal/log.go | 2 ++ internal/log_unix.go | 2 ++ setup/jetstream/helpers.go | 5 +++++ sytest-blacklist | 1 + sytest-whitelist | 7 ++++++- test/db.go | 12 +++++------- test/testrig/base.go | 37 ++++++++++++++----------------------- 8 files changed, 36 insertions(+), 31 deletions(-) diff --git a/cmd/generate-config/main.go b/cmd/generate-config/main.go index 5f75f5e4d..56a145653 100644 --- a/cmd/generate-config/main.go +++ b/cmd/generate-config/main.go @@ -70,6 +70,7 @@ func main() { cfg.AppServiceAPI.DisableTLSValidation = true cfg.ClientAPI.RateLimiting.Enabled = false cfg.FederationAPI.DisableTLSValidation = false + cfg.FederationAPI.DisableHTTPKeepalives = true // don't hit matrix.org when running tests!!! cfg.FederationAPI.KeyPerspectives = config.KeyPerspectives{} cfg.MediaAPI.BasePath = config.Path(filepath.Join(*dirPath, "media")) diff --git a/internal/log.go b/internal/log.go index d7e852c81..da6e20418 100644 --- a/internal/log.go +++ b/internal/log.go @@ -24,6 +24,7 @@ import ( "path/filepath" "runtime" "strings" + "sync" "github.com/matrix-org/util" @@ -37,6 +38,7 @@ import ( // this unfortunately results in us adding the same hook multiple times. // This map ensures we only ever add one level hook. var stdLevelLogAdded = make(map[logrus.Level]bool) +var levelLogAddedMu = &sync.Mutex{} type utcFormatter struct { logrus.Formatter diff --git a/internal/log_unix.go b/internal/log_unix.go index b38e7c2e8..8f34c320d 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -85,6 +85,8 @@ func checkSyslogHookParams(params map[string]interface{}) { } func setupStdLogHook(level logrus.Level) { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() if stdLevelLogAdded[level] { return } diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go index c1ce9583f..533652160 100644 --- a/setup/jetstream/helpers.go +++ b/setup/jetstream/helpers.go @@ -77,6 +77,11 @@ func JetStreamConsumer( // The consumer was deleted so stop. return } else { + // Unfortunately, there's no ErrServerShutdown or similar, so we need to compare the string + if err.Error() == "nats: Server Shutdown" { + logrus.WithContext(ctx).Warn("nats server shutting down") + return + } // Something else went wrong, so we'll panic. sentry.CaptureException(err) logrus.WithContext(ctx).WithField("subject", subj).Fatal(err) diff --git a/sytest-blacklist b/sytest-blacklist index bb0ee368f..49a3cc870 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -7,6 +7,7 @@ AS-ghosted users can use rooms via AS Events in rooms with AS-hosted room aliases are sent to AS server Inviting an AS-hosted user asks the AS server Accesing an AS-hosted room alias asks the AS server +If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes # This will fail in HTTP API mode, so blacklisted for now If a device list update goes missing, the server resyncs on the next one diff --git a/sytest-whitelist b/sytest-whitelist index 1f6ecc29e..c61e0bc3c 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -778,4 +778,9 @@ Can receive redactions from regular users over federation in room version 10 New federated private chats get full presence information (SYN-115) /state returns M_NOT_FOUND for an outlier /state_ids returns M_NOT_FOUND for an outlier -Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state \ No newline at end of file +Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state +Invited user can reject invite for empty room +Invited user can reject local invite after originator leaves +Guest users can join guest_access rooms +Forgotten room messages cannot be paginated +Local device key changes get to remote servers with correct prev_id \ No newline at end of file diff --git a/test/db.go b/test/db.go index 17f637e18..54ded6adb 100644 --- a/test/db.go +++ b/test/db.go @@ -22,6 +22,7 @@ import ( "os" "os/exec" "os/user" + "path/filepath" "testing" "github.com/lib/pq" @@ -103,13 +104,10 @@ func currentUser() string { // TODO: namespace for concurrent package tests func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { if dbType == DBTypeSQLite { - // this will be made in the current working directory which namespaces concurrent package runs correctly - dbname := "dendrite_test.db" + // this will be made in the t.TempDir, which is unique per test + dbname := filepath.Join(t.TempDir(), "dendrite_test.db") return fmt.Sprintf("file:%s", dbname), func() { - err := os.Remove(dbname) - if err != nil { - t.Fatalf("failed to cleanup sqlite db '%s': %s", dbname, err) - } + t.Cleanup(func() {}) // removes the t.TempDir } } @@ -176,7 +174,7 @@ func WithAllDatabases(t *testing.T, testFn func(t *testing.T, db DBType)) { for dbName, dbType := range dbs { dbt := dbType t.Run(dbName, func(tt *testing.T) { - //tt.Parallel() + tt.Parallel() testFn(tt, dbt) }) } diff --git a/test/testrig/base.go b/test/testrig/base.go index 52e6ef5f1..9773da223 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -15,18 +15,14 @@ package testrig import ( - "errors" "fmt" - "io/fs" - "os" - "strings" + "path/filepath" "testing" - "github.com/nats-io/nats.go" - "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" + "github.com/nats-io/nats.go" ) func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, func()) { @@ -77,27 +73,22 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) + + // Use a temp dir provided by go for tests, this will be cleanup by a call to t.CleanUp() + tempDir := t.TempDir() + cfg.FederationAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "federationapi.db")) + cfg.KeyServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "keyserver.db")) + cfg.MSCs.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "mscs.db")) + cfg.MediaAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "mediaapi.db")) + cfg.RoomServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "roomserver.db")) + cfg.SyncAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "syncapi.db")) + cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "userapi.db")) + base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) return base, func() { base.ShutdownDendrite() base.WaitForShutdown() - // cleanup db files. This risks getting out of sync as we add more database strings :( - dbFiles := []config.DataSource{ - cfg.FederationAPI.Database.ConnectionString, - cfg.KeyServer.Database.ConnectionString, - cfg.MSCs.Database.ConnectionString, - cfg.MediaAPI.Database.ConnectionString, - cfg.RoomServer.Database.ConnectionString, - cfg.SyncAPI.Database.ConnectionString, - cfg.UserAPI.AccountDatabase.ConnectionString, - } - for _, fileURI := range dbFiles { - path := strings.TrimPrefix(string(fileURI), "file:") - err := os.Remove(path) - if err != nil && !errors.Is(err, fs.ErrNotExist) { - t.Fatalf("failed to cleanup sqlite db '%s': %s", fileURI, err) - } - } + t.Cleanup(func() {}) // removes t.TempDir, where all database files are created } default: t.Fatalf("unknown db type: %v", dbType) From 5b73592f5a4dddf64184fcbe33f4c1835c656480 Mon Sep 17 00:00:00 2001 From: devonh Date: Mon, 23 Jan 2023 17:55:12 +0000 Subject: [PATCH 13/14] Initial Store & Forward Implementation (#2917) This adds store & forward relays into dendrite for p2p. A few things have changed: - new relay api serves new http endpoints for s&f federation - updated outbound federation queueing which will attempt to forward using s&f if appropriate - database entries to track s&f relays for other nodes --- .gitignore | 1 + build/gobind-pinecone/monolith.go | 201 ++++- build/gobind-pinecone/monolith_test.go | 198 +++++ cmd/dendrite-demo-pinecone/ARCHITECTURE.md | 59 ++ cmd/dendrite-demo-pinecone/README.md | 39 + cmd/dendrite-demo-pinecone/main.go | 125 ++- federationapi/api/api.go | 25 +- federationapi/federationapi.go | 5 +- federationapi/internal/api.go | 9 +- .../internal/federationclient_test.go | 202 +++++ federationapi/internal/perform.go | 73 +- federationapi/internal/perform_test.go | 190 ++++ federationapi/inthttp/client.go | 12 + federationapi/queue/destinationqueue.go | 80 +- federationapi/queue/queue.go | 25 +- federationapi/queue/queue_test.go | 436 ++++------ federationapi/routing/profile_test.go | 94 ++ federationapi/routing/query_test.go | 94 ++ federationapi/routing/routing.go | 14 +- federationapi/routing/send.go | 339 +------- federationapi/routing/send_test.go | 605 ++----------- federationapi/statistics/statistics.go | 186 +++- federationapi/statistics/statistics_test.go | 58 +- federationapi/storage/interface.go | 47 +- .../storage/postgres/assumed_offline_table.go | 107 +++ .../storage/postgres/relay_servers_table.go | 137 +++ federationapi/storage/postgres/storage.go | 10 + .../storage/shared/receipt/receipt.go | 42 + federationapi/storage/shared/storage.go | 184 +++- federationapi/storage/shared/storage_edus.go | 29 +- federationapi/storage/shared/storage_pdus.go | 27 +- .../storage/sqlite3/assumed_offline_table.go | 107 +++ .../storage/sqlite3/relay_servers_table.go | 148 ++++ federationapi/storage/sqlite3/storage.go | 13 +- federationapi/storage/storage_test.go | 103 ++- federationapi/storage/tables/interface.go | 27 + .../tables/relay_servers_table_test.go | 224 +++++ go.mod | 20 +- go.sum | 42 +- internal/log.go | 2 + internal/log_unix.go | 4 +- internal/transactionrequest.go | 356 ++++++++ internal/transactionrequest_test.go | 820 ++++++++++++++++++ mediaapi/routing/routing.go | 26 +- relayapi/api/api.go | 56 ++ relayapi/internal/api.go | 53 ++ relayapi/internal/perform.go | 141 +++ relayapi/internal/perform_test.go | 121 +++ relayapi/relayapi.go | 74 ++ relayapi/relayapi_test.go | 154 ++++ relayapi/routing/relaytxn.go | 74 ++ relayapi/routing/relaytxn_test.go | 220 +++++ relayapi/routing/routing.go | 123 +++ relayapi/routing/sendrelay.go | 77 ++ relayapi/routing/sendrelay_test.go | 209 +++++ relayapi/storage/interface.go | 47 + .../postgres/relay_queue_json_table.go | 113 +++ .../storage/postgres/relay_queue_table.go | 156 ++++ relayapi/storage/postgres/storage.go | 64 ++ relayapi/storage/shared/storage.go | 170 ++++ .../storage/sqlite3/relay_queue_json_table.go | 137 +++ relayapi/storage/sqlite3/relay_queue_table.go | 168 ++++ relayapi/storage/sqlite3/storage.go | 64 ++ relayapi/storage/storage.go | 46 + relayapi/storage/tables/interface.go | 66 ++ .../tables/relay_queue_json_table_test.go | 173 ++++ .../storage/tables/relay_queue_table_test.go | 229 +++++ setup/base/base.go | 6 + setup/config/config.go | 5 +- setup/config/config_federationapi.go | 7 + setup/config/config_relayapi.go | 52 ++ setup/config/config_test.go | 29 +- setup/monolith.go | 7 + test/db.go | 1 - test/memory_federation_db.go | 488 +++++++++++ test/memory_relay_db.go | 140 +++ test/testrig/base.go | 4 +- 77 files changed, 7646 insertions(+), 1373 deletions(-) create mode 100644 build/gobind-pinecone/monolith_test.go create mode 100644 cmd/dendrite-demo-pinecone/ARCHITECTURE.md create mode 100644 federationapi/internal/federationclient_test.go create mode 100644 federationapi/internal/perform_test.go create mode 100644 federationapi/routing/profile_test.go create mode 100644 federationapi/routing/query_test.go create mode 100644 federationapi/storage/postgres/assumed_offline_table.go create mode 100644 federationapi/storage/postgres/relay_servers_table.go create mode 100644 federationapi/storage/shared/receipt/receipt.go create mode 100644 federationapi/storage/sqlite3/assumed_offline_table.go create mode 100644 federationapi/storage/sqlite3/relay_servers_table.go create mode 100644 federationapi/storage/tables/relay_servers_table_test.go create mode 100644 internal/transactionrequest.go create mode 100644 internal/transactionrequest_test.go create mode 100644 relayapi/api/api.go create mode 100644 relayapi/internal/api.go create mode 100644 relayapi/internal/perform.go create mode 100644 relayapi/internal/perform_test.go create mode 100644 relayapi/relayapi.go create mode 100644 relayapi/relayapi_test.go create mode 100644 relayapi/routing/relaytxn.go create mode 100644 relayapi/routing/relaytxn_test.go create mode 100644 relayapi/routing/routing.go create mode 100644 relayapi/routing/sendrelay.go create mode 100644 relayapi/routing/sendrelay_test.go create mode 100644 relayapi/storage/interface.go create mode 100644 relayapi/storage/postgres/relay_queue_json_table.go create mode 100644 relayapi/storage/postgres/relay_queue_table.go create mode 100644 relayapi/storage/postgres/storage.go create mode 100644 relayapi/storage/shared/storage.go create mode 100644 relayapi/storage/sqlite3/relay_queue_json_table.go create mode 100644 relayapi/storage/sqlite3/relay_queue_table.go create mode 100644 relayapi/storage/sqlite3/storage.go create mode 100644 relayapi/storage/storage.go create mode 100644 relayapi/storage/tables/interface.go create mode 100644 relayapi/storage/tables/relay_queue_json_table_test.go create mode 100644 relayapi/storage/tables/relay_queue_table_test.go create mode 100644 setup/config/config_relayapi.go create mode 100644 test/memory_federation_db.go create mode 100644 test/memory_relay_db.go diff --git a/.gitignore b/.gitignore index e4f0112c4..fe5e82797 100644 --- a/.gitignore +++ b/.gitignore @@ -56,6 +56,7 @@ dendrite.yaml # Database files *.db +*.db-journal # Log files *.log* diff --git a/build/gobind-pinecone/monolith.go b/build/gobind-pinecone/monolith.go index b8f8111d2..ff61ea6c8 100644 --- a/build/gobind-pinecone/monolith.go +++ b/build/gobind-pinecone/monolith.go @@ -41,13 +41,16 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/relayapi" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" userapiAPI "github.com/matrix-org/dendrite/userapi/api" @@ -67,24 +70,27 @@ import ( ) const ( - PeerTypeRemote = pineconeRouter.PeerTypeRemote - PeerTypeMulticast = pineconeRouter.PeerTypeMulticast - PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth - PeerTypeBonjour = pineconeRouter.PeerTypeBonjour + PeerTypeRemote = pineconeRouter.PeerTypeRemote + PeerTypeMulticast = pineconeRouter.PeerTypeMulticast + PeerTypeBluetooth = pineconeRouter.PeerTypeBluetooth + PeerTypeBonjour = pineconeRouter.PeerTypeBonjour + relayServerRetryInterval = time.Second * 30 ) type DendriteMonolith struct { - logger logrus.Logger - PineconeRouter *pineconeRouter.Router - PineconeMulticast *pineconeMulticast.Multicast - PineconeQUIC *pineconeSessions.Sessions - PineconeManager *pineconeConnections.ConnectionManager - StorageDirectory string - CacheDirectory string - listener net.Listener - httpServer *http.Server - processContext *process.ProcessContext - userAPI userapiAPI.UserInternalAPI + logger logrus.Logger + baseDendrite *base.BaseDendrite + PineconeRouter *pineconeRouter.Router + PineconeMulticast *pineconeMulticast.Multicast + PineconeQUIC *pineconeSessions.Sessions + PineconeManager *pineconeConnections.ConnectionManager + StorageDirectory string + CacheDirectory string + listener net.Listener + httpServer *http.Server + userAPI userapiAPI.UserInternalAPI + federationAPI api.FederationInternalAPI + relayServersQueried map[gomatrixserverlib.ServerName]bool } func (m *DendriteMonolith) PublicKey() string { @@ -326,6 +332,7 @@ func (m *DendriteMonolith) Start() { cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", filepath.Join(m.StorageDirectory, prefix))) cfg.MediaAPI.BasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) cfg.MediaAPI.AbsBasePath = config.Path(filepath.Join(m.CacheDirectory, "media")) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(m.StorageDirectory, prefix))) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.ClientAPI.RegistrationDisabled = false cfg.ClientAPI.OpenRegistrationWithoutVerificationEnabled = true @@ -335,9 +342,9 @@ func (m *DendriteMonolith) Start() { panic(err) } - base := base.NewBaseDendrite(cfg, "Monolith") + base := base.NewBaseDendrite(cfg, "Monolith", base.DisableMetrics) + m.baseDendrite = base base.ConfigureAdminEndpoints() - defer base.Close() // nolint: errcheck federation := conn.CreateFederationClient(base, m.PineconeQUIC) @@ -346,11 +353,11 @@ func (m *DendriteMonolith) Start() { rsAPI := roomserver.NewInternalAPI(base) - fsAPI := federationapi.NewInternalAPI( + m.federationAPI = federationapi.NewInternalAPI( base, federation, rsAPI, base.Caches, keyRing, true, ) - keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI) + keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, m.federationAPI, rsAPI) m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient()) keyAPI.SetUserAPI(m.userAPI) @@ -358,10 +365,24 @@ func (m *DendriteMonolith) Start() { // The underlying roomserver implementation needs to be able to call the fedsender. // This is different to rsAPI which can be the http client which doesn't need this dependency - rsAPI.SetFederationAPI(fsAPI, keyRing) + rsAPI.SetFederationAPI(m.federationAPI, keyRing) userProvider := users.NewPineconeUserProvider(m.PineconeRouter, m.PineconeQUIC, m.userAPI, federation) - roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, fsAPI, federation) + roomProvider := rooms.NewPineconeRoomProvider(m.PineconeRouter, m.PineconeQUIC, m.federationAPI, federation) + + js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &base.Cfg.FederationAPI, + UserAPI: m.userAPI, + } + relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer) monolith := setup.Monolith{ Config: base.Cfg, @@ -370,10 +391,11 @@ func (m *DendriteMonolith) Start() { KeyRing: keyRing, AppserviceAPI: asAPI, - FederationAPI: fsAPI, + FederationAPI: m.federationAPI, RoomserverAPI: rsAPI, UserAPI: m.userAPI, KeyAPI: keyAPI, + RelayAPI: relayAPI, ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } @@ -411,8 +433,6 @@ func (m *DendriteMonolith) Start() { Handler: h2c.NewHandler(pMux, h2s), } - m.processContext = base.ProcessContext - go func() { m.logger.Info("Listening on ", cfg.Global.ServerName) @@ -420,7 +440,7 @@ func (m *DendriteMonolith) Start() { case net.ErrClosed, http.ErrServerClosed: m.logger.Info("Stopped listening on ", cfg.Global.ServerName) default: - m.logger.Fatal(err) + m.logger.Error("Stopped listening on ", cfg.Global.ServerName) } }() go func() { @@ -430,33 +450,44 @@ func (m *DendriteMonolith) Start() { case net.ErrClosed, http.ErrServerClosed: m.logger.Info("Stopped listening on ", cfg.Global.ServerName) default: - m.logger.Fatal(err) + m.logger.Error("Stopped listening on ", cfg.Global.ServerName) } }() go func(ch <-chan pineconeEvents.Event) { eLog := logrus.WithField("pinecone", "events") + stopRelayServerSync := make(chan bool) + + relayRetriever := RelayServerRetriever{ + Context: context.Background(), + ServerName: gomatrixserverlib.ServerName(m.PineconeRouter.PublicKey().String()), + FederationAPI: m.federationAPI, + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + RelayAPI: monolith.RelayAPI, + running: *atomic.NewBool(false), + } + relayRetriever.InitializeRelayServers(eLog) for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: + if !relayRetriever.running.Load() { + go relayRetriever.SyncRelayServers(stopRelayServerSync) + } case pineconeEvents.PeerRemoved: - case pineconeEvents.TreeParentUpdate: - case pineconeEvents.SnakeDescUpdate: - case pineconeEvents.TreeRootAnnUpdate: - case pineconeEvents.SnakeEntryAdded: - case pineconeEvents.SnakeEntryRemoved: + if relayRetriever.running.Load() && m.PineconeRouter.TotalPeerCount() == 0 { + stopRelayServerSync <- true + } case pineconeEvents.BroadcastReceived: - eLog.Info("Broadcast received from: ", e.PeerID) + // eLog.Info("Broadcast received from: ", e.PeerID) req := &api.PerformWakeupServersRequest{ ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, } res := &api.PerformWakeupServersResponse{} - if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + if err := m.federationAPI.PerformWakeupServers(base.Context(), req, res); err != nil { + eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) } - case pineconeEvents.BandwidthReport: default: } } @@ -464,12 +495,106 @@ func (m *DendriteMonolith) Start() { } func (m *DendriteMonolith) Stop() { - m.processContext.ShutdownDendrite() + m.baseDendrite.Close() + m.baseDendrite.WaitForShutdown() _ = m.listener.Close() m.PineconeMulticast.Stop() _ = m.PineconeQUIC.Close() _ = m.PineconeRouter.Close() - m.processContext.WaitForComponentsToFinish() +} + +type RelayServerRetriever struct { + Context context.Context + ServerName gomatrixserverlib.ServerName + FederationAPI api.FederationInternalAPI + RelayAPI relayServerAPI.RelayInternalAPI + relayServersQueried map[gomatrixserverlib.ServerName]bool + queriedServersMutex sync.Mutex + running atomic.Bool +} + +func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} + response := api.P2PQueryRelayServersResponse{} + err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) + if err != nil { + eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + } + for _, server := range response.RelayServers { + m.relayServersQueried[server] = false + } + + eLog.Infof("Registered relay servers: %v", response.RelayServers) +} + +func (m *RelayServerRetriever) SyncRelayServers(stop <-chan bool) { + defer m.running.Store(false) + + t := time.NewTimer(relayServerRetryInterval) + for { + relayServersToQuery := []gomatrixserverlib.ServerName{} + func() { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + for server, complete := range m.relayServersQueried { + if !complete { + relayServersToQuery = append(relayServersToQuery, server) + } + } + }() + if len(relayServersToQuery) == 0 { + // All relay servers have been synced. + return + } + m.queryRelayServers(relayServersToQuery) + t.Reset(relayServerRetryInterval) + + select { + case <-stop: + if !t.Stop() { + <-t.C + } + return + case <-t.C: + } + } +} + +func (m *RelayServerRetriever) GetQueriedServerStatus() map[gomatrixserverlib.ServerName]bool { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + + result := map[gomatrixserverlib.ServerName]bool{} + for server, queried := range m.relayServersQueried { + result[server] = queried + } + return result +} + +func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("querying relay servers for any available transactions") + for _, server := range relayServers { + userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) + if err != nil { + return + } + err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) + if err == nil { + func() { + m.queriedServersMutex.Lock() + defer m.queriedServersMutex.Unlock() + m.relayServersQueried[server] = true + }() + // TODO : What happens if your relay receives new messages after this point? + // Should you continue to check with them, or should they try and contact you? + // They could send a "new_async_events" message your way maybe? + // Then you could mark them as needing to be queried again. + // What if you miss this message? + // Maybe you should try querying them again after a certain period of time as a backup? + } else { + logrus.Errorf("Failed querying relay server: %s", err.Error()) + } + } } const MaxFrameSize = types.MaxFrameSize diff --git a/build/gobind-pinecone/monolith_test.go b/build/gobind-pinecone/monolith_test.go new file mode 100644 index 000000000..edcf22bbe --- /dev/null +++ b/build/gobind-pinecone/monolith_test.go @@ -0,0 +1,198 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gobind + +import ( + "context" + "fmt" + "net" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/api" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "gotest.tools/v3/poll" +) + +var TestBuf = []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + +type TestNetConn struct { + net.Conn + shouldFail bool +} + +func (t *TestNetConn) Read(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + n := copy(b, TestBuf) + return n, nil + } +} + +func (t *TestNetConn) Write(b []byte) (int, error) { + if t.shouldFail { + return 0, fmt.Errorf("Failed") + } else { + return len(b), nil + } +} + +func (t *TestNetConn) Close() error { + if t.shouldFail { + return fmt.Errorf("Failed") + } else { + return nil + } +} + +func TestConduitStoresPort(t *testing.T) { + conduit := Conduit{port: 7} + assert.Equal(t, 7, conduit.Port()) +} + +func TestConduitRead(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + b := make([]byte, len(TestBuf)) + bytes, err := conduit.Read(b) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) + assert.Equal(t, TestBuf, b) +} + +func TestConduitReadCopy(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + result, err := conduit.ReadCopy() + assert.NoError(t, err) + assert.Equal(t, TestBuf, result) +} + +func TestConduitWrite(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + bytes, err := conduit.Write(TestBuf) + assert.NoError(t, err) + assert.Equal(t, len(TestBuf), bytes) +} + +func TestConduitClose(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + assert.True(t, conduit.closed.Load()) +} + +func TestConduitReadClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + b := make([]byte, len(TestBuf)) + _, err = conduit.Read(b) + assert.Error(t, err) +} + +func TestConduitReadCopyClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.ReadCopy() + assert.Error(t, err) +} + +func TestConduitWriteClosed(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{}} + err := conduit.Close() + assert.NoError(t, err) + _, err = conduit.Write(TestBuf) + assert.Error(t, err) +} + +func TestConduitReadCopyFails(t *testing.T) { + conduit := Conduit{conn: &TestNetConn{shouldFail: true}} + _, err := conduit.ReadCopy() + assert.Error(t, err) +} + +var testRelayServers = []gomatrixserverlib.ServerName{"relay1", "relay2"} + +type FakeFedAPI struct { + api.FederationInternalAPI +} + +func (f *FakeFedAPI) P2PQueryRelayServers(ctx context.Context, req *api.P2PQueryRelayServersRequest, res *api.P2PQueryRelayServersResponse) error { + res.RelayServers = testRelayServers + return nil +} + +type FakeRelayAPI struct { + relayServerAPI.RelayInternalAPI +} + +func (r *FakeRelayAPI) PerformRelayServerSync(ctx context.Context, userID gomatrixserverlib.UserID, relayServer gomatrixserverlib.ServerName) error { + return nil +} + +func TestRelayRetrieverInitialization(t *testing.T) { + retriever := RelayServerRetriever{ + Context: context.Background(), + ServerName: "server", + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + FederationAPI: &FakeFedAPI{}, + RelayAPI: &FakeRelayAPI{}, + } + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) +} + +func TestRelayRetrieverSync(t *testing.T) { + retriever := RelayServerRetriever{ + Context: context.Background(), + ServerName: "server", + relayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + FederationAPI: &FakeFedAPI{}, + RelayAPI: &FakeRelayAPI{}, + } + + retriever.InitializeRelayServers(logrus.WithField("test", "relay")) + relayServers := retriever.GetQueriedServerStatus() + assert.Equal(t, 2, len(relayServers)) + + stopRelayServerSync := make(chan bool) + go retriever.SyncRelayServers(stopRelayServerSync) + + check := func(log poll.LogT) poll.Result { + relayServers := retriever.GetQueriedServerStatus() + for _, queried := range relayServers { + if !queried { + return poll.Continue("waiting for all servers to be queried") + } + } + + stopRelayServerSync <- true + return poll.Success() + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestMonolithStarts(t *testing.T) { + monolith := DendriteMonolith{} + monolith.Start() + monolith.PublicKey() + monolith.Stop() +} diff --git a/cmd/dendrite-demo-pinecone/ARCHITECTURE.md b/cmd/dendrite-demo-pinecone/ARCHITECTURE.md new file mode 100644 index 000000000..1b0941053 --- /dev/null +++ b/cmd/dendrite-demo-pinecone/ARCHITECTURE.md @@ -0,0 +1,59 @@ +## Relay Server Architecture + +Relay Servers function similar to the way physical mail drop boxes do. +A node can have many associated relay servers. Matrix events can be sent to them instead of to the destination node, and the destination node will eventually retrieve them from the relay server. +Nodes that want to send events to an offline node need to know what relay servers are associated with their intended destination. +Currently this is manually configured in the dendrite database. In the future this information could be configurable in the app and shared automatically via other means. + +Currently events are sent as complete Matrix Transactions. +Transactions include a list of PDUs, (which contain, among other things, lists of authorization events, previous events, and signatures) a list of EDUs, and other information about the transaction. +There is no additional information sent along with the transaction other than what is typically added to them during Matrix federation today. +In the future this will probably need to change in order to handle more complex room state resolution during p2p usage. + +### Relay Server Architecture + +``` + 0 +--------------------+ + +----------------------------------------+ | P2P Node A | + | Relay Server | | +--------+ | + | | | | Client | | + | +--------------------+ | | +--------+ | + | | Relay Server API | | | | | + | | | | | V | + | .--------. 2 | +-------------+ | | 1 | +------------+ | + | |`--------`| <----- | Forwarder | <------------- | Homeserver | | + | | Database | | +-------------+ | | | +------------+ | + | `----------` | | | +--------------------+ + | ^ | | | + | | 4 | +-------------+ | | + | `------------ | Retriever | <------. +--------------------+ + | | +-------------+ | | | | P2P Node B | + | | | | | | +--------+ | + | +--------------------+ | | | | Client | | + | | | | +--------+ | + +----------------------------------------+ | | | | + | | V | + 3 | | +------------+ | + `------ | Homeserver | | + | +------------+ | + +--------------------+ +``` + +- 0: This relay server is currently only acting on behalf of `P2P Node B`. It will only receive, and later forward events that are destined for `P2P Node B`. +- 1: When `P2P Node A` fails sending directly to `P2P Node B` (after a configurable number of attempts), it checks for any known relay servers associated with `P2P Node B` and sends to all of them. + - If sending to any of the relay servers succeeds, that transaction is considered to be successfully sent. +- 2: The relay server `forwarder` stores the transaction json in its database and marks it as destined for `P2P Node B`. +- 3: When `P2P Node B` comes online, it queries all its relay servers for any missed messages. +- 4: The relay server `retriever` will look in its database for any transactions that are destined for `P2P Node B` and returns them one at a time. + +For now, it is important that we don’t design out a hybrid approach of having both sender-side and recipient-side relay servers. +Both approaches make sense and determining which makes for a better experience depends on the use case. + +#### Sender-Side Relay Servers + +If we are running around truly ad-hoc, and I don't know when or where you will be able to pick up messages, then having a sender designated server makes sense to give things the best chance at making their way to the destination. +But in order to achieve this, you are either relying on p2p presence broadcasts for the relay to know when to try forwarding (which means you are in a pretty small network), or the relay just keeps on periodically attempting to forward to the destination which will lead to a lot of extra traffic on the network. + +#### Recipient-Side Relay Servers + +If we have agreed to some static relay server before going off and doing other things, or if we are talking about more global p2p federation, then having a recipient designated relay server can cut down on redundant traffic since it will sit there idle until the recipient pulls events from it. diff --git a/cmd/dendrite-demo-pinecone/README.md b/cmd/dendrite-demo-pinecone/README.md index d6dd95905..5cacd0924 100644 --- a/cmd/dendrite-demo-pinecone/README.md +++ b/cmd/dendrite-demo-pinecone/README.md @@ -24,3 +24,42 @@ Then point your favourite Matrix client to the homeserver URL`http://localhost: If your peering connection is operational then you should see a `Connected TCP:` line in the log output. If not then try a different peer. Once logged in, you should be able to open the room directory or join a room by its ID. + +## Store & Forward Relays + +To test out the store & forward relay functionality, you need a minimum of 3 instances. +One instance will act as the relay, and the other two instances will be the users trying to communicate. +Then you can send messages between the two nodes and watch as the relay is used if the receiving node is offline. + +### Launching the Nodes + +Relay Server: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir relay/ -listen "[::]:49000" +``` + +Node 1: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir node-1/ -peer "[::]:49000" -port 8007 +``` + +Node 2: +``` +go run cmd/dendrite-demo-pinecone/main.go -dir node-2/ -peer "[::]:49000" -port 8009 +``` + +### Database Setup + +At the moment, the database must be manually configured. +For both `Node 1` and `Node 2` add the following entries to their respective `relay_server` table in the federationapi database: +``` +server_name: {node_1_public_key}, relay_server_name: {relay_public_key} +server_name: {node_2_public_key}, relay_server_name: {relay_public_key} +``` + +After editing the database you will need to relaunch the nodes for the changes to be picked up by dendrite. + +### Testing + +Now you can run two separate instances of element and connect them to `Node 1` and `Node 2`. +You can shutdown one of the nodes and continue sending messages. If you wait long enough, the message will be sent to the relay server. (you can see this in the log output of the relay server) diff --git a/cmd/dendrite-demo-pinecone/main.go b/cmd/dendrite-demo-pinecone/main.go index 3f627b41d..a813c37a2 100644 --- a/cmd/dendrite-demo-pinecone/main.go +++ b/cmd/dendrite-demo-pinecone/main.go @@ -38,16 +38,21 @@ import ( "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" "github.com/matrix-org/dendrite/federationapi" "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/keyserver" + "github.com/matrix-org/dendrite/relayapi" + relayServerAPI "github.com/matrix-org/dendrite/relayapi/api" "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/setup" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/gomatrixserverlib" + "go.uber.org/atomic" pineconeConnections "github.com/matrix-org/pinecone/connections" pineconeMulticast "github.com/matrix-org/pinecone/multicast" @@ -66,6 +71,8 @@ var ( instanceDir = flag.String("dir", ".", "the directory to store the databases in (if --config not specified)") ) +const relayServerRetryInterval = time.Second * 30 + // nolint:gocyclo func main() { flag.Parse() @@ -139,6 +146,7 @@ func main() { cfg.RoomServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-roomserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.KeyServer.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-keyserver.db", filepath.Join(*instanceDir, *instanceName))) cfg.FederationAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-federationapi.db", filepath.Join(*instanceDir, *instanceName))) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-relayapi.db", filepath.Join(*instanceDir, *instanceName))) cfg.MSCs.MSCs = []string{"msc2836", "msc2946"} cfg.MSCs.Database.ConnectionString = config.DataSource(fmt.Sprintf("file:%s-mscs.db", filepath.Join(*instanceDir, *instanceName))) cfg.ClientAPI.RegistrationDisabled = false @@ -224,6 +232,20 @@ func main() { userProvider := users.NewPineconeUserProvider(pRouter, pQUIC, userAPI, federation) roomProvider := rooms.NewPineconeRoomProvider(pRouter, pQUIC, fsAPI, federation) + js, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: base.Cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: base.Cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &base.Cfg.FederationAPI, + UserAPI: userAPI, + } + relayAPI := relayapi.NewRelayInternalAPI(base, federation, rsAPI, keyRing, producer) + monolith := setup.Monolith{ Config: base.Cfg, Client: conn.CreateClient(base, pQUIC), @@ -235,6 +257,7 @@ func main() { RoomserverAPI: rsAPI, UserAPI: userAPI, KeyAPI: keyAPI, + RelayAPI: relayAPI, ExtPublicRoomsProvider: roomProvider, ExtUserDirectoryProvider: userProvider, } @@ -305,27 +328,38 @@ func main() { go func(ch <-chan pineconeEvents.Event) { eLog := logrus.WithField("pinecone", "events") + relayServerSyncRunning := atomic.NewBool(false) + stopRelayServerSync := make(chan bool) + + m := RelayServerRetriever{ + Context: context.Background(), + ServerName: gomatrixserverlib.ServerName(pRouter.PublicKey().String()), + FederationAPI: fsAPI, + RelayServersQueried: make(map[gomatrixserverlib.ServerName]bool), + RelayAPI: monolith.RelayAPI, + } + m.InitializeRelayServers(eLog) for event := range ch { switch e := event.(type) { case pineconeEvents.PeerAdded: + if !relayServerSyncRunning.Load() { + go m.syncRelayServers(stopRelayServerSync, *relayServerSyncRunning) + } case pineconeEvents.PeerRemoved: - case pineconeEvents.TreeParentUpdate: - case pineconeEvents.SnakeDescUpdate: - case pineconeEvents.TreeRootAnnUpdate: - case pineconeEvents.SnakeEntryAdded: - case pineconeEvents.SnakeEntryRemoved: + if relayServerSyncRunning.Load() && pRouter.TotalPeerCount() == 0 { + stopRelayServerSync <- true + } case pineconeEvents.BroadcastReceived: - eLog.Info("Broadcast received from: ", e.PeerID) + // eLog.Info("Broadcast received from: ", e.PeerID) req := &api.PerformWakeupServersRequest{ ServerNames: []gomatrixserverlib.ServerName{gomatrixserverlib.ServerName(e.PeerID)}, } res := &api.PerformWakeupServersResponse{} if err := fsAPI.PerformWakeupServers(base.Context(), req, res); err != nil { - logrus.WithError(err).Error("Failed to wakeup destination", e.PeerID) + eLog.WithError(err).Error("Failed to wakeup destination", e.PeerID) } - case pineconeEvents.BandwidthReport: default: } } @@ -333,3 +367,78 @@ func main() { base.WaitForShutdown() } + +type RelayServerRetriever struct { + Context context.Context + ServerName gomatrixserverlib.ServerName + FederationAPI api.FederationInternalAPI + RelayServersQueried map[gomatrixserverlib.ServerName]bool + RelayAPI relayServerAPI.RelayInternalAPI +} + +func (m *RelayServerRetriever) InitializeRelayServers(eLog *logrus.Entry) { + request := api.P2PQueryRelayServersRequest{Server: gomatrixserverlib.ServerName(m.ServerName)} + response := api.P2PQueryRelayServersResponse{} + err := m.FederationAPI.P2PQueryRelayServers(m.Context, &request, &response) + if err != nil { + eLog.Warnf("Failed obtaining list of this node's relay servers: %s", err.Error()) + } + for _, server := range response.RelayServers { + m.RelayServersQueried[server] = false + } + + eLog.Infof("Registered relay servers: %v", response.RelayServers) +} + +func (m *RelayServerRetriever) syncRelayServers(stop <-chan bool, running atomic.Bool) { + defer running.Store(false) + + t := time.NewTimer(relayServerRetryInterval) + for { + relayServersToQuery := []gomatrixserverlib.ServerName{} + for server, complete := range m.RelayServersQueried { + if !complete { + relayServersToQuery = append(relayServersToQuery, server) + } + } + if len(relayServersToQuery) == 0 { + // All relay servers have been synced. + return + } + m.queryRelayServers(relayServersToQuery) + t.Reset(relayServerRetryInterval) + + select { + case <-stop: + // We have been asked to stop syncing, drain the timer and return. + if !t.Stop() { + <-t.C + } + return + case <-t.C: + // The timer has expired. Continue to the next loop iteration. + } + } +} + +func (m *RelayServerRetriever) queryRelayServers(relayServers []gomatrixserverlib.ServerName) { + logrus.Info("querying relay servers for any available transactions") + for _, server := range relayServers { + userID, err := gomatrixserverlib.NewUserID("@user:"+string(m.ServerName), false) + if err != nil { + return + } + err = m.RelayAPI.PerformRelayServerSync(context.Background(), *userID, server) + if err == nil { + m.RelayServersQueried[server] = true + // TODO : What happens if your relay receives new messages after this point? + // Should you continue to check with them, or should they try and contact you? + // They could send a "new_async_events" message your way maybe? + // Then you could mark them as needing to be queried again. + // What if you miss this message? + // Maybe you should try querying them again after a certain period of time as a backup? + } else { + logrus.Errorf("Failed querying relay server: %s", err.Error()) + } + } +} diff --git a/federationapi/api/api.go b/federationapi/api/api.go index 50d0339e4..417b08521 100644 --- a/federationapi/api/api.go +++ b/federationapi/api/api.go @@ -18,6 +18,7 @@ type FederationInternalAPI interface { gomatrixserverlib.KeyDatabase ClientFederationAPI RoomserverFederationAPI + P2PFederationAPI QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error) @@ -30,7 +31,6 @@ type FederationInternalAPI interface { request *PerformBroadcastEDURequest, response *PerformBroadcastEDUResponse, ) error - PerformWakeupServers( ctx context.Context, request *PerformWakeupServersRequest, @@ -71,6 +71,15 @@ type RoomserverFederationAPI interface { LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } +type P2PFederationAPI interface { + // Relay Server sync api used in the pinecone demos. + P2PQueryRelayServers( + ctx context.Context, + request *P2PQueryRelayServersRequest, + response *P2PQueryRelayServersResponse, + ) error +} + // KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver // implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in // this interface are of type FederationClientError @@ -82,6 +91,7 @@ type KeyserverFederationAPI interface { // an interface for gmsl.FederationClient - contains functions called by federationapi only. type FederationClient interface { + P2PFederationClient gomatrixserverlib.KeyClient SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) @@ -110,6 +120,11 @@ type FederationClient interface { LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) } +type P2PFederationClient interface { + P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) + P2PGetTransactionFromRelay(ctx context.Context, u gomatrixserverlib.UserID, prev gomatrixserverlib.RelayEntry, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetRelayTransaction, err error) +} + // FederationClientError is returned from FederationClient methods in the event of a problem. type FederationClientError struct { Err string @@ -233,3 +248,11 @@ type InputPublicKeysRequest struct { type InputPublicKeysResponse struct { } + +type P2PQueryRelayServersRequest struct { + Server gomatrixserverlib.ServerName +} + +type P2PQueryRelayServersResponse struct { + RelayServers []gomatrixserverlib.ServerName +} diff --git a/federationapi/federationapi.go b/federationapi/federationapi.go index ce0ce98e9..ed9a545d6 100644 --- a/federationapi/federationapi.go +++ b/federationapi/federationapi.go @@ -113,7 +113,10 @@ func NewInternalAPI( _ = federationDB.RemoveAllServersFromBlacklist() } - stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1) + stats := statistics.NewStatistics( + federationDB, + cfg.FederationMaxRetries+1, + cfg.P2PFederationRetriesUntilAssumedOffline+1) js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream) diff --git a/federationapi/internal/api.go b/federationapi/internal/api.go index 14056eafc..99773a750 100644 --- a/federationapi/internal/api.go +++ b/federationapi/internal/api.go @@ -109,13 +109,14 @@ func NewFederationInternalAPI( func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) { stats := a.statistics.ForServer(s) - until, blacklisted := stats.BackoffInfo() - if blacklisted { + if stats.Blacklisted() { return stats, &api.FederationClientError{ Blacklisted: true, } } + now := time.Now() + until := stats.BackoffInfo() if until != nil && now.Before(*until) { return stats, &api.FederationClientError{ RetryAfter: time.Until(*until), @@ -163,7 +164,7 @@ func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted( RetryAfter: retryAfter, } } - stats.Success() + stats.Success(statistics.SendDirect) return res, nil } @@ -171,7 +172,7 @@ func (a *FederationInternalAPI) doRequestIfNotBlacklisted( s gomatrixserverlib.ServerName, request func() (interface{}, error), ) (interface{}, error) { stats := a.statistics.ForServer(s) - if _, blacklisted := stats.BackoffInfo(); blacklisted { + if blacklisted := stats.Blacklisted(); blacklisted { return stats, &api.FederationClientError{ Err: fmt.Sprintf("server %q is blacklisted", s), Blacklisted: true, diff --git a/federationapi/internal/federationclient_test.go b/federationapi/internal/federationclient_test.go new file mode 100644 index 000000000..49137e2d8 --- /dev/null +++ b/federationapi/internal/federationclient_test.go @@ -0,0 +1,202 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "testing" + + "github.com/matrix-org/dendrite/federationapi/queue" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + FailuresUntilAssumedOffline = 3 + FailuresUntilBlacklist = 8 +) + +func (t *testFedClient) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) { + t.queryKeysCalled = true + if t.shouldFail { + return gomatrixserverlib.RespQueryKeys{}, fmt.Errorf("Failure") + } + return gomatrixserverlib.RespQueryKeys{}, nil +} + +func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) { + t.claimKeysCalled = true + if t.shouldFail { + return gomatrixserverlib.RespClaimKeys{}, fmt.Errorf("Failure") + } + return gomatrixserverlib.RespClaimKeys{}, nil +} + +func TestFederationClientQueryKeys(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.Nil(t, err) + assert.True(t, fedClient.queryKeysCalled) +} + +func TestFederationClientQueryKeysBlacklisted(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + testDB.AddServerToBlacklist("server") + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.False(t, fedClient.queryKeysCalled) +} + +func TestFederationClientQueryKeysFailure(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{shouldFail: true} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.True(t, fedClient.queryKeysCalled) +} + +func TestFederationClientClaimKeys(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil) + assert.Nil(t, err) + assert.True(t, fedClient.claimKeysCalled) +} + +func TestFederationClientClaimKeysBlacklisted(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + testDB.AddServerToBlacklist("server") + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "server", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedapi := FederationInternalAPI{ + db: testDB, + cfg: &cfg, + statistics: &stats, + federation: fedClient, + queues: queues, + } + _, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil) + assert.NotNil(t, err) + assert.False(t, fedClient.claimKeysCalled) +} diff --git a/federationapi/internal/perform.go b/federationapi/internal/perform.go index d86d07e03..552942f28 100644 --- a/federationapi/internal/perform.go +++ b/federationapi/internal/perform.go @@ -14,6 +14,7 @@ import ( "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/consumers" + "github.com/matrix-org/dendrite/federationapi/statistics" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/version" ) @@ -24,6 +25,10 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( request *api.PerformDirectoryLookupRequest, response *api.PerformDirectoryLookupResponse, ) (err error) { + if !r.shouldAttemptDirectFederation(request.ServerName) { + return fmt.Errorf("relay servers have no meaningful response for directory lookup.") + } + dir, err := r.federation.LookupRoomAlias( ctx, r.cfg.Matrix.ServerName, @@ -36,7 +41,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup( } response.RoomID = dir.RoomID response.ServerNames = dir.Servers - r.statistics.ForServer(request.ServerName).Success() + r.statistics.ForServer(request.ServerName).Success(statistics.SendDirect) return nil } @@ -144,6 +149,10 @@ func (r *FederationInternalAPI) performJoinUsingServer( supportedVersions []gomatrixserverlib.RoomVersion, unsigned map[string]interface{}, ) error { + if !r.shouldAttemptDirectFederation(serverName) { + return fmt.Errorf("relay servers have no meaningful response for join.") + } + _, origin, err := r.cfg.Matrix.SplitLocalID('@', userID) if err != nil { return err @@ -164,7 +173,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.MakeJoin: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // Set all the fields to be what they should be, this should be a no-op // but it's possible that the remote server returned us something "odd" @@ -219,7 +228,7 @@ func (r *FederationInternalAPI) performJoinUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.SendJoin: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // If the remote server returned an event in the "event" key of // the send_join request then we should use that instead. It may @@ -407,6 +416,10 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( serverName gomatrixserverlib.ServerName, supportedVersions []gomatrixserverlib.RoomVersion, ) error { + if !r.shouldAttemptDirectFederation(serverName) { + return fmt.Errorf("relay servers have no meaningful response for outbound peek.") + } + // create a unique ID for this peek. // for now we just use the room ID again. In future, if we ever // support concurrent peeks to the same room with different filters @@ -446,7 +459,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer( r.statistics.ForServer(serverName).Failure() return fmt.Errorf("r.federation.Peek: %w", err) } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) // Work out if we support the room version that has been supplied in // the peek response. @@ -516,6 +529,10 @@ func (r *FederationInternalAPI) PerformLeave( // Try each server that we were provided until we land on one that // successfully completes the make-leave send-leave dance. for _, serverName := range request.ServerNames { + if !r.shouldAttemptDirectFederation(serverName) { + continue + } + // Try to perform a make_leave using the information supplied in the // request. respMakeLeave, err := r.federation.MakeLeave( @@ -585,7 +602,7 @@ func (r *FederationInternalAPI) PerformLeave( continue } - r.statistics.ForServer(serverName).Success() + r.statistics.ForServer(serverName).Success(statistics.SendDirect) return nil } @@ -616,6 +633,12 @@ func (r *FederationInternalAPI) PerformInvite( return fmt.Errorf("gomatrixserverlib.SplitID: %w", err) } + // TODO (devon): This should be allowed via a relay. Currently only transactions + // can be sent to relays. Would need to extend relays to handle invites. + if !r.shouldAttemptDirectFederation(destination) { + return fmt.Errorf("relay servers have no meaningful response for invite.") + } + logrus.WithFields(logrus.Fields{ "event_id": request.Event.EventID(), "user_id": *request.Event.StateKey(), @@ -682,12 +705,8 @@ func (r *FederationInternalAPI) PerformWakeupServers( func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) { for _, srv := range destinations { - // Check the statistics cache for the blacklist status to prevent hitting - // the database unnecessarily. - if r.queues.IsServerBlacklisted(srv) { - _ = r.db.RemoveServerFromBlacklist(srv) - } - r.queues.RetryServer(srv) + wasBlacklisted := r.statistics.ForServer(srv).MarkServerAlive() + r.queues.RetryServer(srv, wasBlacklisted) } } @@ -719,7 +738,9 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error { return fmt.Errorf("auth chain response is missing m.room.create event") } -func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion { +func setDefaultRoomVersionFromJoinEvent( + joinEvent gomatrixserverlib.EventBuilder, +) gomatrixserverlib.RoomVersion { // if auth events are not event references we know it must be v3+ // we have to do these shenanigans to satisfy sytest, specifically for: // "Outbound federation rejects m.room.create events with an unknown room version" @@ -802,3 +823,31 @@ func federatedAuthProvider( return returning, nil } } + +// P2PQueryRelayServers implements api.FederationInternalAPI +func (r *FederationInternalAPI) P2PQueryRelayServers( + ctx context.Context, + request *api.P2PQueryRelayServersRequest, + response *api.P2PQueryRelayServersResponse, +) error { + logrus.Infof("Getting relay servers for: %s", request.Server) + relayServers, err := r.db.P2PGetRelayServersForServer(ctx, request.Server) + if err != nil { + return err + } + + response.RelayServers = relayServers + return nil +} + +func (r *FederationInternalAPI) shouldAttemptDirectFederation( + destination gomatrixserverlib.ServerName, +) bool { + var shouldRelay bool + stats := r.statistics.ForServer(destination) + if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 { + shouldRelay = true + } + + return !shouldRelay +} diff --git a/federationapi/internal/perform_test.go b/federationapi/internal/perform_test.go new file mode 100644 index 000000000..e8e0d00a3 --- /dev/null +++ b/federationapi/internal/perform_test.go @@ -0,0 +1,190 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "testing" + + "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/queue" + "github.com/matrix-org/dendrite/federationapi/statistics" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type testFedClient struct { + api.FederationClient + queryKeysCalled bool + claimKeysCalled bool + shouldFail bool +} + +func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { + return gomatrixserverlib.RespDirectory{}, nil +} + +func TestPerformWakeupServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + testDB.AddServerToBlacklist(server) + testDB.SetServerAssumedOffline(context.Background(), server) + blacklisted, err := testDB.IsServerBlacklisted(server) + assert.NoError(t, err) + assert.True(t, blacklisted) + offline, err := testDB.IsServerAssumedOffline(context.Background(), server) + assert.NoError(t, err) + assert.True(t, offline) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformWakeupServersRequest{ + ServerNames: []gomatrixserverlib.ServerName{server}, + } + res := api.PerformWakeupServersResponse{} + err = fedAPI.PerformWakeupServers(context.Background(), &req, &res) + assert.NoError(t, err) + + blacklisted, err = testDB.IsServerBlacklisted(server) + assert.NoError(t, err) + assert.False(t, blacklisted) + offline, err = testDB.IsServerAssumedOffline(context.Background(), server) + assert.NoError(t, err) + assert.False(t, offline) +} + +func TestQueryRelayServers(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"} + err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers) + assert.NoError(t, err) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.P2PQueryRelayServersRequest{ + Server: server, + } + res := api.P2PQueryRelayServersResponse{} + err = fedAPI.P2PQueryRelayServers(context.Background(), &req, &res) + assert.NoError(t, err) + + assert.Equal(t, len(relayServers), len(res.RelayServers)) +} + +func TestPerformDirectoryLookup(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: "relay", + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformDirectoryLookupRequest{ + RoomAlias: "room", + ServerName: "server", + } + res := api.PerformDirectoryLookupResponse{} + err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res) + assert.NoError(t, err) +} + +func TestPerformDirectoryLookupRelaying(t *testing.T) { + testDB := test.NewInMemoryFederationDatabase() + + server := gomatrixserverlib.ServerName("wakeup") + testDB.SetServerAssumedOffline(context.Background(), server) + testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"}) + + cfg := config.FederationAPI{ + Matrix: &config.Global{ + SigningIdentity: gomatrixserverlib.SigningIdentity{ + ServerName: server, + }, + }, + } + fedClient := &testFedClient{} + stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline) + queues := queue.NewOutgoingQueues( + testDB, process.NewProcessContext(), + false, + cfg.Matrix.ServerName, fedClient, nil, &stats, + nil, + ) + fedAPI := NewFederationInternalAPI( + testDB, &cfg, nil, fedClient, &stats, nil, queues, nil, + ) + + req := api.PerformDirectoryLookupRequest{ + RoomAlias: "room", + ServerName: server, + } + res := api.PerformDirectoryLookupResponse{} + err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res) + assert.Error(t, err) +} diff --git a/federationapi/inthttp/client.go b/federationapi/inthttp/client.go index 6eefdc7cd..6130a567d 100644 --- a/federationapi/inthttp/client.go +++ b/federationapi/inthttp/client.go @@ -24,6 +24,7 @@ const ( FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest" FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU" FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers" + FederationAPIQueryRelayServers = "/federationapi/queryRelayServers" FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices" FederationAPIClaimKeysPath = "/federationapi/client/claimKeys" @@ -510,3 +511,14 @@ func (h *httpFederationInternalAPI) QueryPublicKeys( h.httpClient, ctx, request, response, ) } + +func (h *httpFederationInternalAPI) P2PQueryRelayServers( + ctx context.Context, + request *api.P2PQueryRelayServersRequest, + response *api.P2PQueryRelayServersResponse, +) error { + return httputil.CallInternalRPCAPI( + "QueryRelayServers", h.federationAPIURL+FederationAPIQueryRelayServers, + h.httpClient, ctx, request, response, + ) +} diff --git a/federationapi/queue/destinationqueue.go b/federationapi/queue/destinationqueue.go index a4a87fe99..51350916d 100644 --- a/federationapi/queue/destinationqueue.go +++ b/federationapi/queue/destinationqueue.go @@ -29,7 +29,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -70,7 +70,7 @@ type destinationQueue struct { // Send event adds the event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination) return @@ -84,8 +84,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re oq.pendingMutex.Lock() if len(oq.pendingPDUs) < maxPDUsInMemory { oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ - pdu: event, - receipt: receipt, + pdu: event, + dbReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -101,7 +101,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re // sendEDU adds the EDU event to the pending queue for the destination. // If the queue is empty then it starts a background goroutine to // start sending events to that destination. -func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) { +func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, dbReceipt *receipt.Receipt) { if event == nil { logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination) return @@ -115,8 +115,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share oq.pendingMutex.Lock() if len(oq.pendingEDUs) < maxEDUsInMemory { oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ - edu: event, - receipt: receipt, + edu: event, + dbReceipt: dbReceipt, }) } else { oq.overflowed.Store(true) @@ -210,10 +210,10 @@ func (oq *destinationQueue) getPendingFromDatabase() { gotPDUs := map[string]struct{}{} gotEDUs := map[string]struct{}{} for _, pdu := range oq.pendingPDUs { - gotPDUs[pdu.receipt.String()] = struct{}{} + gotPDUs[pdu.dbReceipt.String()] = struct{}{} } for _, edu := range oq.pendingEDUs { - gotEDUs[edu.receipt.String()] = struct{}{} + gotEDUs[edu.dbReceipt.String()] = struct{}{} } overflowed := false @@ -371,7 +371,7 @@ func (oq *destinationQueue) backgroundSend() { // If we have pending PDUs or EDUs then construct a transaction. // Try sending the next transaction and see what happens. - terr := oq.nextTransaction(toSendPDUs, toSendEDUs) + terr, sendMethod := oq.nextTransaction(toSendPDUs, toSendEDUs) if terr != nil { // We failed to send the transaction. Mark it as a failure. _, blacklisted := oq.statistics.Failure() @@ -388,18 +388,19 @@ func (oq *destinationQueue) backgroundSend() { return } } else { - oq.handleTransactionSuccess(pduCount, eduCount) + oq.handleTransactionSuccess(pduCount, eduCount, sendMethod) } } } // nextTransaction creates a new transaction from the pending event // queue and sends it. -// Returns an error if the transaction wasn't sent. +// Returns an error if the transaction wasn't sent. And whether the success +// was to a relay server or not. func (oq *destinationQueue) nextTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) error { +) (err error, sendMethod statistics.SendMethod) { // Create the transaction. t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus) logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) @@ -407,7 +408,37 @@ func (oq *destinationQueue) nextTransaction( // Try to send the transaction to the destination server. ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5) defer cancel() - _, err := oq.client.SendTransaction(ctx, t) + + relayServers := oq.statistics.KnownRelayServers() + if oq.statistics.AssumedOffline() && len(relayServers) > 0 { + sendMethod = statistics.SendViaRelay + relaySuccess := false + logrus.Infof("Sending to relay servers: %v", relayServers) + // TODO : how to pass through actual userID here?!?!?!?! + userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false) + if userErr != nil { + return userErr, sendMethod + } + + // Attempt sending to each known relay server. + for _, relayServer := range relayServers { + _, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer) + if relayErr != nil { + err = relayErr + } else { + // If sending to one of the relay servers succeeds, consider the send successful. + relaySuccess = true + } + } + + // Clear the error if sending to any of the relay servers succeeded. + if relaySuccess { + err = nil + } + } else { + sendMethod = statistics.SendDirect + _, err = oq.client.SendTransaction(ctx, t) + } switch errResponse := err.(type) { case nil: // Clean up the transaction in the database. @@ -427,7 +458,7 @@ func (oq *destinationQueue) nextTransaction( oq.transactionIDMutex.Lock() oq.transactionID = "" oq.transactionIDMutex.Unlock() - return nil + return nil, sendMethod case gomatrix.HTTPError: // Report that we failed to send the transaction and we // will retry again, subject to backoff. @@ -437,13 +468,13 @@ func (oq *destinationQueue) nextTransaction( // to a 400-ish error code := errResponse.Code logrus.Debug("Transaction failed with HTTP", code) - return err + return err, sendMethod default: logrus.WithFields(logrus.Fields{ "destination": oq.destination, logrus.ErrorKey: err, }).Debugf("Failed to send transaction %q", t.TransactionID) - return err + return err, sendMethod } } @@ -453,7 +484,7 @@ func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) createTransaction( pdus []*queuedPDU, edus []*queuedEDU, -) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) { +) (gomatrixserverlib.Transaction, []*receipt.Receipt, []*receipt.Receipt) { // If there's no projected transaction ID then generate one. If // the transaction succeeds then we'll set it back to "" so that // we generate a new one next time. If it fails, we'll preserve @@ -474,8 +505,8 @@ func (oq *destinationQueue) createTransaction( t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) t.TransactionID = oq.transactionID - var pduReceipts []*shared.Receipt - var eduReceipts []*shared.Receipt + var pduReceipts []*receipt.Receipt + var eduReceipts []*receipt.Receipt // Go through PDUs that we retrieved from the database, if any, // and add them into the transaction. @@ -487,7 +518,7 @@ func (oq *destinationQueue) createTransaction( // Append the JSON of the event, since this is a json.RawMessage type in the // gomatrixserverlib.Transaction struct t.PDUs = append(t.PDUs, pdu.pdu.JSON()) - pduReceipts = append(pduReceipts, pdu.receipt) + pduReceipts = append(pduReceipts, pdu.dbReceipt) } // Do the same for pending EDUS in the queue. @@ -497,7 +528,7 @@ func (oq *destinationQueue) createTransaction( continue } t.EDUs = append(t.EDUs, *edu.edu) - eduReceipts = append(eduReceipts, edu.receipt) + eduReceipts = append(eduReceipts, edu.dbReceipt) } return t, pduReceipts, eduReceipts @@ -530,10 +561,11 @@ func (oq *destinationQueue) blacklistDestination() { // handleTransactionSuccess updates the cached event queues as well as the success and // backoff information for this server. -func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) { +func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int, sendMethod statistics.SendMethod) { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. - oq.statistics.Success() + + oq.statistics.Success(sendMethod) oq.pendingMutex.Lock() defer oq.pendingMutex.Unlock() diff --git a/federationapi/queue/queue.go b/federationapi/queue/queue.go index 75b1b36be..5d6b8d44c 100644 --- a/federationapi/queue/queue.go +++ b/federationapi/queue/queue.go @@ -30,7 +30,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/process" ) @@ -138,13 +138,13 @@ func NewOutgoingQueues( } type queuedPDU struct { - receipt *shared.Receipt - pdu *gomatrixserverlib.HeaderedEvent + dbReceipt *receipt.Receipt + pdu *gomatrixserverlib.HeaderedEvent } type queuedEDU struct { - receipt *shared.Receipt - edu *gomatrixserverlib.EDU + dbReceipt *receipt.Receipt + edu *gomatrixserverlib.EDU } func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue { @@ -374,24 +374,13 @@ func (oqs *OutgoingQueues) SendEDU( return nil } -// IsServerBlacklisted returns whether or not the provided server is currently -// blacklisted. -func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool { - return oqs.statistics.ForServer(srv).Blacklisted() -} - // RetryServer attempts to resend events to the given server if we had given up. -func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) { +func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName, wasBlacklisted bool) { if oqs.disabled { return } - serverStatistics := oqs.statistics.ForServer(srv) - forceWakeup := serverStatistics.Blacklisted() - serverStatistics.RemoveBlacklist() - serverStatistics.ClearBackoff() - if queue := oqs.getQueue(srv); queue != nil { - queue.wakeQueueIfEventsPending(forceWakeup) + queue.wakeQueueIfEventsPending(wasBlacklisted) } } diff --git a/federationapi/queue/queue_test.go b/federationapi/queue/queue_test.go index c317edc21..36e2ccbc2 100644 --- a/federationapi/queue/queue_test.go +++ b/federationapi/queue/queue_test.go @@ -18,7 +18,6 @@ import ( "context" "encoding/json" "fmt" - "sync" "testing" "time" @@ -26,13 +25,11 @@ import ( "gotest.tools/v3/poll" "github.com/matrix-org/gomatrixserverlib" - "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/statistics" "github.com/matrix-org/dendrite/federationapi/storage" - "github.com/matrix-org/dendrite/federationapi/storage/shared" rsapi "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/process" @@ -57,7 +54,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } else { // Fake Database - db := createDatabase() + db := test.NewInMemoryFederationDatabase() b := struct { ProcessContext *process.ProcessContext }{ProcessContext: process.NewProcessContext()} @@ -65,220 +62,6 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase } } -func createDatabase() storage.Database { - return &fakeDatabase{ - pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), - blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), - pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent), - pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU), - associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}), - } -} - -type fakeDatabase struct { - storage.Database - dbMutex sync.Mutex - pendingPDUServers map[gomatrixserverlib.ServerName]struct{} - pendingEDUServers map[gomatrixserverlib.ServerName]struct{} - blacklistedServers map[gomatrixserverlib.ServerName]struct{} - pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent - pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU - associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} - associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{} -} - -var nidMutex sync.Mutex -var nid = int64(0) - -func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal([]byte(js), &event); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingPDUs[&receipt] = &event - return &receipt, nil - } - - var edu gomatrixserverlib.EDU - if err := json.Unmarshal([]byte(js), &edu); err == nil { - nidMutex.Lock() - defer nidMutex.Unlock() - nid++ - receipt := shared.NewReceipt(nid) - d.pendingEDUs[&receipt] = &edu - return &receipt, nil - } - - return nil, errors.New("Failed to determine type of json to store") -} - -func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - pduCount := 0 - pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent) - if receipts, ok := d.associatedPDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingPDUs[receipt]; ok { - pdus[receipt] = event - pduCount++ - if pduCount == limit { - break - } - } - } - } - return pdus, nil -} - -func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - eduCount := 0 - edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU) - if receipts, ok := d.associatedEDUs[serverName]; ok { - for receipt := range receipts { - if event, ok := d.pendingEDUs[receipt]; ok { - edus[receipt] = event - eduCount++ - if eduCount == limit { - break - } - } - } - } - return edus, nil -} - -func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingPDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedPDUs[destination]; !ok { - d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedPDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("PDU doesn't exist") - } -} - -func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if _, ok := d.pendingEDUs[receipt]; ok { - for destination := range destinations { - if _, ok := d.associatedEDUs[destination]; !ok { - d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{}) - } - d.associatedEDUs[destination][receipt] = struct{}{} - } - - return nil - } else { - return errors.New("EDU doesn't exist") - } -} - -func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if pdus, ok := d.associatedPDUs[serverName]; ok { - for _, receipt := range receipts { - delete(pdus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - if edus, ok := d.associatedEDUs[serverName]; ok { - for _, receipt := range receipts { - delete(edus, receipt) - } - } - - return nil -} - -func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingPDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - servers := []gomatrixserverlib.ServerName{} - for server := range d.pendingEDUServers { - servers = append(servers, server) - } - return servers, nil -} - -func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers[serverName] = struct{}{} - return nil -} - -func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - delete(d.blacklistedServers, serverName) - return nil -} - -func (d *fakeDatabase) RemoveAllServersFromBlacklist() error { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) - return nil -} - -func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { - d.dbMutex.Lock() - defer d.dbMutex.Unlock() - - isBlacklisted := false - if _, ok := d.blacklistedServers[serverName]; ok { - isBlacklisted = true - } - - return isBlacklisted, nil -} - type stubFederationRoomServerAPI struct { rsapi.FederationRoomserverAPI } @@ -290,8 +73,10 @@ func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Cont type stubFederationClient struct { api.FederationClient - shouldTxSucceed bool - txCount atomic.Uint32 + shouldTxSucceed bool + shouldTxRelaySucceed bool + txCount atomic.Uint32 + txRelayCount atomic.Uint32 } func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) { @@ -304,6 +89,16 @@ func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixse return gomatrixserverlib.RespSend{}, result } +func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) { + var result error + if !f.shouldTxRelaySucceed { + result = fmt.Errorf("relay transaction failed") + } + + f.txRelayCount.Add(1) + return gomatrixserverlib.EmptyResp{}, result +} + func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent { t.Helper() content := `{"type":"m.room.message"}` @@ -319,15 +114,18 @@ func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU { return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping} } -func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { +func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32, shouldTxSucceed bool, shouldTxRelaySucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) { db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase) fc := &stubFederationClient{ - shouldTxSucceed: shouldTxSucceed, - txCount: *atomic.NewUint32(0), + shouldTxSucceed: shouldTxSucceed, + shouldTxRelaySucceed: shouldTxRelaySucceed, + txCount: *atomic.NewUint32(0), + txRelayCount: *atomic.NewUint32(0), } rs := &stubFederationRoomServerAPI{} - stats := statistics.NewStatistics(db, failuresUntilBlacklist) + + stats := statistics.NewStatistics(db, failuresUntilBlacklist, failuresUntilAssumedOffline) signingInfo := []*gomatrixserverlib.SigningIdentity{ { KeyID: "ed21019:auto", @@ -344,7 +142,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -373,7 +171,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -402,7 +200,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -432,7 +230,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -462,7 +260,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -513,7 +311,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -564,7 +362,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -596,7 +394,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -628,7 +426,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -662,7 +460,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(2) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -696,7 +494,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -730,8 +528,8 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -747,7 +545,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(1) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -781,8 +579,8 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) assert.NoError(t, dbErr) @@ -801,7 +599,7 @@ func TestSendPDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -845,7 +643,7 @@ func TestSendEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -889,7 +687,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) { // test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { // db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true) - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -940,7 +738,7 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) { t.Parallel() failuresUntilBlacklist := uint32(16) destination := gomatrixserverlib.ServerName("remotehost") - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false) defer close() defer func() { pc.ShutdownDendrite() @@ -978,7 +776,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { destination := gomatrixserverlib.ServerName("remotehost") destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}} test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true) + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, dbType, true) // NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up. defer close() defer func() { @@ -1023,8 +821,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) fc.shouldTxSucceed = true - db.RemoveServerFromBlacklist(destination) - queues.RetryServer(destination) + wasBlacklisted := dest.statistics.MarkServerAlive() + queues.RetryServer(destination, wasBlacklisted) checkRetry := func(log poll.LogT) poll.Result { pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200) assert.NoError(t, dbErrPDU) @@ -1038,3 +836,147 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) { poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond)) }) } + +func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(7) + failuresUntilAssumedOffline := uint32(2) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == failuresUntilAssumedOffline { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 1 { + if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val { + return poll.Success() + } + return poll.Continue("waiting for server to be assumed offline") + } + return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) +} + +func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + failuresUntilAssumedOffline := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + relayServers := []gomatrixserverlib.ServerName{"relayserver"} + queues.statistics.ForServer(destination).AddRelayServers(relayServers) + + ev := mustCreatePDU(t) + err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + if fc.txRelayCount.Load() == 1 { + data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data)) + } + return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load()) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) + assert.Equal(t, true, assumedOffline) +} + +func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) { + t.Parallel() + failuresUntilBlacklist := uint32(16) + failuresUntilAssumedOffline := uint32(1) + destination := gomatrixserverlib.ServerName("remotehost") + db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false) + defer close() + defer func() { + pc.ShutdownDendrite() + <-pc.WaitForShutdown() + }() + + relayServers := []gomatrixserverlib.ServerName{"relayserver"} + queues.statistics.ForServer(destination).AddRelayServers(relayServers) + + ev := mustCreateEDU(t) + err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination}) + assert.NoError(t, err) + + check := func(log poll.LogT) poll.Result { + if fc.txCount.Load() == 1 { + if fc.txRelayCount.Load() == 1 { + data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100) + assert.NoError(t, dbErr) + if len(data) == 0 { + return poll.Success() + } + return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data)) + } + return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load()) + } + return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load()) + } + poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond)) + + assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination) + assert.Equal(t, true, assumedOffline) +} diff --git a/federationapi/routing/profile_test.go b/federationapi/routing/profile_test.go new file mode 100644 index 000000000..763656081 --- /dev/null +++ b/federationapi/routing/profile_test.go @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/hex" + "io" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + userAPI "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" +) + +type fakeUserAPI struct { + userAPI.FederationUserAPI +} + +func (u *fakeUserAPI) QueryProfile(ctx context.Context, req *userAPI.QueryProfileRequest, res *userAPI.QueryProfileResponse) error { + return nil +} + +func TestHandleQueryProfile(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedClient := fakeFedClient{} + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + userapi := fakeUserAPI{} + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin))) + type queryContent struct{} + content := queryContent{} + err := req.SetContent(content) + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + // vars := map[string]string{"room_alias": "#room:server"} + w := httptest.NewRecorder() + // httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + data, _ := io.ReadAll(res.Body) + println(string(data)) + assert.Equal(t, 200, res.StatusCode) + }) +} diff --git a/federationapi/routing/query_test.go b/federationapi/routing/query_test.go new file mode 100644 index 000000000..21f35bf0c --- /dev/null +++ b/federationapi/routing/query_test.go @@ -0,0 +1,94 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/hex" + "io" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedclient "github.com/matrix-org/dendrite/federationapi/api" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" +) + +type fakeFedClient struct { + fedclient.FederationClient +} + +func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) { + return +} + +func TestHandleQueryDirectory(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedClient := fakeFedClient{} + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true) + userapi := fakeUserAPI{} + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server")) + type queryContent struct{} + content := queryContent{} + err := req.SetContent(content) + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) + } + // vars := map[string]string{"room_alias": "#room:server"} + w := httptest.NewRecorder() + // httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + data, _ := io.ReadAll(res.Body) + println(string(data)) + assert.Equal(t, 200, res.StatusCode) + }) +} diff --git a/federationapi/routing/routing.go b/federationapi/routing/routing.go index 04eb3d067..5eb30c6ec 100644 --- a/federationapi/routing/routing.go +++ b/federationapi/routing/routing.go @@ -41,6 +41,12 @@ import ( "github.com/sirupsen/logrus" ) +const ( + SendRouteName = "Send" + QueryDirectoryRouteName = "QueryDirectory" + QueryProfileRouteName = "QueryProfile" +) + // Setup registers HTTP handlers with the given ServeMux. // The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly // path unescape twice (once from the router, once from MakeFedAPI). We need to have this enabled @@ -68,7 +74,7 @@ func Setup( if base.EnableMetrics { prometheus.MustRegister( - pduCountTotal, eduCountTotal, + internal.PDUCountTotal, internal.EDUCountTotal, ) } @@ -138,7 +144,7 @@ func Setup( cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer, ) }, - )).Methods(http.MethodPut, http.MethodOptions) + )).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName) v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI( "federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, @@ -248,7 +254,7 @@ func Setup( httpReq, federation, cfg, rsAPI, fsAPI, ) }, - )).Methods(http.MethodGet) + )).Methods(http.MethodGet).Name(QueryDirectoryRouteName) v1fedmux.Handle("/query/profile", MakeFedAPI( "federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, @@ -257,7 +263,7 @@ func Setup( httpReq, userAPI, cfg, ) }, - )).Methods(http.MethodGet) + )).Methods(http.MethodGet).Name(QueryProfileRouteName) v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI( "federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup, diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index a146d85bd..67b513c90 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -17,26 +17,20 @@ package routing import ( "context" "encoding/json" - "fmt" "net/http" "sync" "time" - "github.com/getsentry/sentry-go" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" - "github.com/prometheus/client_golang/prometheus" - "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/clientapi/jsonerror" federationAPI "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/federationapi/producers" - "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/config" - syncTypes "github.com/matrix-org/dendrite/syncapi/types" ) const ( @@ -56,26 +50,6 @@ const ( MetricsWorkMissingPrevEvents = "missing_prev_events" ) -var ( - pduCountTotal = prometheus.NewCounterVec( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "recv_pdus", - Help: "Number of incoming PDUs from remote servers with labels for success", - }, - []string{"status"}, // 'success' or 'total' - ) - eduCountTotal = prometheus.NewCounter( - prometheus.CounterOpts{ - Namespace: "dendrite", - Subsystem: "federationapi", - Name: "recv_edus", - Help: "Number of incoming EDUs from remote servers", - }, - ) -) - var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse // Send implements /_matrix/federation/v1/send/{txnID} @@ -123,18 +97,6 @@ func Send( defer close(ch) defer inFlightTxnsPerOrigin.Delete(index) - t := txnReq{ - rsAPI: rsAPI, - keys: keys, - ourServerName: cfg.Matrix.ServerName, - federation: federation, - servers: servers, - keyAPI: keyAPI, - roomsMu: mu, - producer: producer, - inboundPresenceEnabled: cfg.Matrix.Presence.EnableInbound, - } - var txnEvents struct { PDUs []json.RawMessage `json:"pdus"` EDUs []gomatrixserverlib.EDU `json:"edus"` @@ -155,16 +117,23 @@ func Send( } } - // TODO: Really we should have a function to convert FederationRequest to txnReq - t.PDUs = txnEvents.PDUs - t.EDUs = txnEvents.EDUs - t.Origin = request.Origin() - t.TransactionID = txnID - t.Destination = cfg.Matrix.ServerName + t := internal.NewTxnReq( + rsAPI, + keyAPI, + cfg.Matrix.ServerName, + keys, + mu, + producer, + cfg.Matrix.Presence.EnableInbound, + txnEvents.PDUs, + txnEvents.EDUs, + request.Origin(), + txnID, + cfg.Matrix.ServerName) util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs)) - resp, jsonErr := t.processTransaction(httpReq.Context()) + resp, jsonErr := t.ProcessTransaction(httpReq.Context()) if jsonErr != nil { util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed") return *jsonErr @@ -181,283 +150,3 @@ func Send( ch <- res return res } - -type txnReq struct { - gomatrixserverlib.Transaction - rsAPI api.FederationRoomserverAPI - keyAPI keyapi.FederationKeyAPI - ourServerName gomatrixserverlib.ServerName - keys gomatrixserverlib.JSONVerifier - federation txnFederationClient - roomsMu *internal.MutexByRoom - servers federationAPI.ServersInRoomProvider - producer *producers.SyncAPIProducer - inboundPresenceEnabled bool -} - -// A subset of FederationClient functionality that txn requires. Useful for testing. -type txnFederationClient interface { - LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - res gomatrixserverlib.RespState, err error, - ) - LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) - GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) - LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, - roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) -} - -func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - t.processEDUs(ctx) - }() - - results := make(map[string]gomatrixserverlib.PDUResult) - roomVersions := make(map[string]gomatrixserverlib.RoomVersion) - getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { - if v, ok := roomVersions[roomID]; ok { - return v - } - verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) - return "" - } - roomVersions[roomID] = verRes.RoomVersion - return verRes.RoomVersion - } - - for _, pdu := range t.PDUs { - pduCountTotal.WithLabelValues("total").Inc() - var header struct { - RoomID string `json:"room_id"` - } - if err := json.Unmarshal(pdu, &header); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") - // We don't know the event ID at this point so we can't return the - // failure in the PDU results - continue - } - roomVersion := getRoomVersion(header.RoomID) - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) - if err != nil { - if _, ok := err.(gomatrixserverlib.BadJSONError); ok { - // Room version 6 states that homeservers should strictly enforce canonical JSON - // on PDUs. - // - // This enforces that the entire transaction is rejected if a single bad PDU is - // sent. It is unclear if this is the correct behaviour or not. - // - // See https://github.com/matrix-org/synapse/issues/7543 - return nil, &util.JSONResponse{ - Code: 400, - JSON: jsonerror.BadJSON("PDU contains bad JSON"), - } - } - util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) - continue - } - if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { - continue - } - if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: "Forbidden by server ACLs", - } - continue - } - if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { - util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - continue - } - - // pass the event to the roomserver which will do auth checks - // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently - // discarded by the caller of this function - if err = api.SendEvents( - ctx, - t.rsAPI, - api.KindNew, - []*gomatrixserverlib.HeaderedEvent{ - event.Headered(roomVersion), - }, - t.Destination, - t.Origin, - api.DoNotSendToOtherServers, - nil, - true, - ); err != nil { - util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) - results[event.EventID()] = gomatrixserverlib.PDUResult{ - Error: err.Error(), - } - continue - } - - results[event.EventID()] = gomatrixserverlib.PDUResult{} - pduCountTotal.WithLabelValues("success").Inc() - } - - wg.Wait() - return &gomatrixserverlib.RespSend{PDUs: results}, nil -} - -// nolint:gocyclo -func (t *txnReq) processEDUs(ctx context.Context) { - for _, e := range t.EDUs { - eduCountTotal.Inc() - switch e.Type { - case gomatrixserverlib.MTyping: - // https://matrix.org/docs/spec/server_server/latest#typing-notifications - var typingPayload struct { - RoomID string `json:"room_id"` - UserID string `json:"user_id"` - Typing bool `json:"typing"` - } - if err := json.Unmarshal(e.Content, &typingPayload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") - continue - } - if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { - util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") - } - case gomatrixserverlib.MDirectToDevice: - // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema - var directPayload gomatrixserverlib.ToDeviceMessage - if err := json.Unmarshal(e.Content, &directPayload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") - continue - } - if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - for userID, byUser := range directPayload.Messages { - for deviceID, message := range byUser { - // TODO: check that the user and the device actually exist here - if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { - sentry.CaptureException(err) - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "sender": directPayload.Sender, - "user_id": userID, - "device_id": deviceID, - }).Error("Failed to send send-to-device event to JetStream") - } - } - } - case gomatrixserverlib.MDeviceListUpdate: - if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { - sentry.CaptureException(err) - util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") - } - case gomatrixserverlib.MReceipt: - // https://matrix.org/docs/spec/server_server/r0.1.4#receipts - payload := map[string]types.FederationReceiptMRead{} - - if err := json.Unmarshal(e.Content, &payload); err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") - continue - } - - for roomID, receipt := range payload { - for userID, mread := range receipt.User { - _, domain, err := gomatrixserverlib.SplitID('@', userID) - if err != nil { - util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") - continue - } - if t.Origin != domain { - util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) - continue - } - if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { - util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ - "sender": t.Origin, - "user_id": userID, - "room_id": roomID, - "events": mread.EventIDs, - }).Error("Failed to send receipt event to JetStream") - continue - } - } - } - case types.MSigningKeyUpdate: - if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil { - sentry.CaptureException(err) - logrus.WithError(err).Errorf("Failed to process signing key update") - } - case gomatrixserverlib.MPresence: - if t.inboundPresenceEnabled { - if err := t.processPresence(ctx, e); err != nil { - logrus.WithError(err).Errorf("Failed to process presence update") - } - } - default: - util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") - } - } -} - -// processPresence handles m.receipt events -func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error { - payload := types.Presence{} - if err := json.Unmarshal(e.Content, &payload); err != nil { - return err - } - for _, content := range payload.Push { - if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { - continue - } else if serverName == t.ourServerName { - continue - } else if serverName != t.Origin { - continue - } - presence, ok := syncTypes.PresenceFromString(content.Presence) - if !ok { - continue - } - if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil { - return err - } - } - return nil -} - -// processReceiptEvent sends receipt events to JetStream -func (t *txnReq) processReceiptEvent(ctx context.Context, - userID, roomID, receiptType string, - timestamp gomatrixserverlib.Timestamp, - eventIDs []string, -) error { - if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { - return nil - } else if serverName == t.ourServerName { - return nil - } else if serverName != t.Origin { - return nil - } - // store every event - for _, eventID := range eventIDs { - if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { - return fmt.Errorf("unable to set receipt event: %w", err) - } - } - - return nil -} diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index b8bfe0221..d7feee0e5 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -1,552 +1,87 @@ -package routing +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test import ( - "context" + "encoding/hex" "encoding/json" - "fmt" + "net/http/httptest" "testing" - "time" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/api" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + fedAPI "github.com/matrix-org/dendrite/federationapi" + fedInternal "github.com/matrix-org/dendrite/federationapi/internal" + "github.com/matrix-org/dendrite/federationapi/routing" + "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" + "golang.org/x/crypto/ed25519" ) const ( - testOrigin = gomatrixserverlib.ServerName("kaer.morhen") - testDestination = gomatrixserverlib.ServerName("white.orchard") + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") ) -var ( - testRoomVersion = gomatrixserverlib.RoomVersionV1 - testData = []json.RawMessage{ - []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), - // messages - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), - []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), - } - testEvents = []*gomatrixserverlib.HeaderedEvent{} - testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) -) +type sendContent struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` +} -func init() { - for _, j := range testData { - e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) +func TestHandleSend(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath() + base.PublicFederationAPIMux = fedMux + base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin + base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false + fedapi := fedAPI.NewInternalAPI(base, nil, nil, nil, nil, true) + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + r, ok := fedapi.(*fedInternal.FederationInternalAPI) + if !ok { + panic("This is a programming error.") + } + routing.Setup(base, nil, r, keyRing, nil, nil, nil, &base.Cfg.MSCs, nil, nil) + + handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234") + content := sendContent{} + err := req.SetContent(content) if err != nil { - panic("cannot load test data: " + err.Error()) + t.Fatalf("Error: %s", err.Error()) } - h := e.Headered(testRoomVersion) - testEvents = append(testEvents, h) - if e.StateKey() != nil { - testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: e.Type(), - StateKey: *e.StateKey(), - }] = h + req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk) + httpReq, err := req.HTTPRequest() + if err != nil { + t.Fatalf("Error: %s", err.Error()) } - } + vars := map[string]string{"txnID": "1234"} + w := httptest.NewRecorder() + httpReq = mux.SetURLVars(httpReq, vars) + handler(w, httpReq) + + res := w.Result() + assert.Equal(t, 200, res.StatusCode) + }) } - -type testRoomserverAPI struct { - api.RoomserverInternalAPITrace - inputRoomEvents []api.InputRoomEvent - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse -} - -func (t *testRoomserverAPI) InputRoomEvents( - ctx context.Context, - request *api.InputRoomEventsRequest, - response *api.InputRoomEventsResponse, -) error { - t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) - for _, ire := range request.InputRoomEvents { - fmt.Println("InputRoomEvents: ", ire.Event.EventID()) - } - return nil -} - -// Query the latest events and state for a room from the room server. -func (t *testRoomserverAPI) QueryLatestEventsAndState( - ctx context.Context, - request *api.QueryLatestEventsAndStateRequest, - response *api.QueryLatestEventsAndStateResponse, -) error { - r := t.queryLatestEventsAndState(request) - response.RoomExists = r.RoomExists - response.RoomVersion = testRoomVersion - response.LatestEvents = r.LatestEvents - response.StateEvents = r.StateEvents - response.Depth = r.Depth - return nil -} - -// Query the state after a list of events in a room from the room server. -func (t *testRoomserverAPI) QueryStateAfterEvents( - ctx context.Context, - request *api.QueryStateAfterEventsRequest, - response *api.QueryStateAfterEventsResponse, -) error { - response.RoomVersion = testRoomVersion - res := t.queryStateAfterEvents(request) - response.PrevEventsExist = res.PrevEventsExist - response.RoomExists = res.RoomExists - response.StateEvents = res.StateEvents - return nil -} - -// Query a list of events by event ID. -func (t *testRoomserverAPI) QueryEventsByID( - ctx context.Context, - request *api.QueryEventsByIDRequest, - response *api.QueryEventsByIDResponse, -) error { - res := t.queryEventsByID(request) - response.Events = res.Events - return nil -} - -// Query if a server is joined to a room -func (t *testRoomserverAPI) QueryServerJoinedToRoom( - ctx context.Context, - request *api.QueryServerJoinedToRoomRequest, - response *api.QueryServerJoinedToRoomResponse, -) error { - response.RoomExists = true - response.IsInRoom = true - return nil -} - -// Asks for the room version for a given room. -func (t *testRoomserverAPI) QueryRoomVersionForRoom( - ctx context.Context, - request *api.QueryRoomVersionForRoomRequest, - response *api.QueryRoomVersionForRoomResponse, -) error { - response.RoomVersion = testRoomVersion - return nil -} - -func (t *testRoomserverAPI) QueryServerBannedFromRoom( - ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse, -) error { - res.Banned = false - return nil -} - -type txnFedClient struct { - state map[string]gomatrixserverlib.RespState // event_id to response - stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response - getEvent map[string]gomatrixserverlib.Transaction // event_id to response - getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) -} - -func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( - res gomatrixserverlib.RespState, err error, -) { - fmt.Println("testFederationClient.LookupState", eventID) - r, ok := c.state[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /state for event %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) { - fmt.Println("testFederationClient.LookupStateIDs", eventID) - r, ok := c.stateIDs[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) { - fmt.Println("testFederationClient.GetEvent", eventID) - r, ok := c.getEvent[eventID] - if !ok { - err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID) - return - } - res = r - return -} -func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, - roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) { - return c.getMissingEvents(missing) -} - -func mustCreateTransaction(rsAPI api.FederationRoomserverAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq { - t := &txnReq{ - rsAPI: rsAPI, - keys: &test.NopJSONVerifier{}, - federation: fedClient, - roomsMu: internal.NewMutexByRoom(), - } - t.PDUs = pdus - t.Origin = testOrigin - t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) - t.Destination = testDestination - return t -} - -func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) { - res, err := txn.processTransaction(context.Background()) - if err != nil { - t.Errorf("txn.processTransaction returned an error: %v", err) - return - } - if len(res.PDUs) != len(txn.PDUs) { - t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) - return - } -NextPDU: - for eventID, result := range res.PDUs { - if result.Error == "" { - continue - } - for _, eventIDWantError := range pdusWithErrors { - if eventID == eventIDWantError { - break NextPDU - } - } - t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) - } -} - -/* -func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []*gomatrixserverlib.HeaderedEvent) { -NextTuple: - for _, t := range tuples { - for _, o := range omitTuples { - if t == o { - break NextTuple - } - } - h, ok := testStateEvents[t] - if ok { - result = append(result, h) - } - } - return -} -*/ - -func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { - for _, g := range got { - fmt.Println("GOT ", g.Event.EventID()) - } - if len(got) != len(want) { - t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) - return - } - for i := range got { - if got[i].Event.EventID() != want[i].EventID() { - t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) - } - } -} - -// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on -// to the roomserver. It's the most basic test possible. -func TestBasicTransaction(t *testing.T) { - rsAPI := &testRoomserverAPI{} - pdus := []json.RawMessage{ - testData[len(testData)-1], // a message event - } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) -} - -// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver -// as it does the auth check. -func TestTransactionFailAuthChecks(t *testing.T) { - rsAPI := &testRoomserverAPI{} - pdus := []json.RawMessage{ - testData[len(testData)-1], // a message event - } - txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus) - mustProcessTransaction(t, txn, []string{}) - // expect message to be sent to the roomserver - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) -} - -// The purpose of this test is to make sure that when an event is received for which we do not know the prev_events, -// we request them from /get_missing_events. It works by setting PrevEventsExist=false in the roomserver query response, -// resulting in a call to /get_missing_events which returns the missing prev event. Both events should be processed in -// topological order and sent to the roomserver. -/* -func TestTransactionFetchMissingPrevEvents(t *testing.T) { - haveEvent := testEvents[len(testEvents)-3] - prevEvent := testEvents[len(testEvents)-2] - inputEvent := testEvents[len(testEvents)-1] - - var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions - rsAPI = &testRoomserverAPI{ - queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { - res := api.QueryEventsByIDResponse{} - for _, ev := range testEvents { - for _, id := range req.EventIDs { - if ev.EventID() == id { - res.Events = append(res.Events, ev) - } - } - } - return res - }, - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: true, - StateEvents: testEvents[:5], - } - }, - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - missingPrevEvent := []string{"missing_prev_event"} - if len(req.PrevEventIDs) == 1 { - switch req.PrevEventIDs[0] { - case haveEvent.EventID(): - missingPrevEvent = []string{} - case prevEvent.EventID(): - // we only have this event if we've been send prevEvent - if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() { - missingPrevEvent = []string{} - } - } - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, - queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { - return api.QueryLatestEventsAndStateResponse{ - RoomExists: true, - Depth: haveEvent.Depth(), - LatestEvents: []gomatrixserverlib.EventReference{ - haveEvent.EventReference(), - }, - StateEvents: fromStateTuples(req.StateToFetch, nil), - } - }, - } - - cli := &txnFedClient{ - getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { - if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) { - t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, haveEvent.EventID()) - } - if !reflect.DeepEqual(missing.LatestEvents, []string{inputEvent.EventID()}) { - t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, inputEvent.EventID()) - } - return gomatrixserverlib.RespMissingEvents{ - Events: []*gomatrixserverlib.Event{ - prevEvent.Unwrap(), - }, - }, nil - }, - } - - pdus := []json.RawMessage{ - inputEvent.JSON(), - } - txn := mustCreateTransaction(rsAPI, cli, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent}) -} - -// The purpose of this test is to check that when there are missing prev_events and we still haven't been able to fill -// in the hole with /get_missing_events that the state BEFORE the events we want to persist is fetched via /state_ids -// and /event. It works by setting PrevEventsExist=false in the roomserver query response, resulting in -// a call to /get_missing_events which returns 1 out of the 2 events it needs to fill in the gap. Synapse and Dendrite -// both give up after 1x /get_missing_events call, relying on requesting the state AFTER the missing event in order to -// continue. The DAG looks something like: -// FE GME TXN -// A ---> B ---> C ---> D -// TXN=event in the txn, GME=response to /get_missing_events, FE=roomserver's forward extremity. Should result in: -// - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event. -// - state resolution is done to check C is allowed. -// This results in B being sent as an outlier FIRST, then C,D. -func TestTransactionFetchMissingStateByStateIDs(t *testing.T) { - eventA := testEvents[len(testEvents)-5] - // this is also len(testEvents)-4 - eventB := testStateEvents[gomatrixserverlib.StateKeyTuple{ - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }] - eventC := testEvents[len(testEvents)-3] - eventD := testEvents[len(testEvents)-2] - fmt.Println("a:", eventA.EventID()) - fmt.Println("b:", eventB.EventID()) - fmt.Println("c:", eventC.EventID()) - fmt.Println("d:", eventD.EventID()) - var rsAPI *testRoomserverAPI - rsAPI = &testRoomserverAPI{ - queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse { - omitTuples := []gomatrixserverlib.StateKeyTuple{ - { - EventType: gomatrixserverlib.MRoomPowerLevels, - StateKey: "", - }, - } - askingForEvent := req.PrevEventIDs[0] - haveEventB := false - haveEventC := false - for _, ev := range rsAPI.inputRoomEvents { - switch ev.Event.EventID() { - case eventB.EventID(): - haveEventB = true - omitTuples = nil // include event B now - case eventC.EventID(): - haveEventC = true - } - } - prevEventExists := false - if askingForEvent == eventC.EventID() { - prevEventExists = haveEventC - } else if askingForEvent == eventB.EventID() { - prevEventExists = haveEventB - } - var stateEvents []*gomatrixserverlib.HeaderedEvent - if prevEventExists { - stateEvents = fromStateTuples(req.StateToFetch, omitTuples) - } - return api.QueryStateAfterEventsResponse{ - PrevEventsExist: prevEventExists, - RoomExists: true, - StateEvents: stateEvents, - } - }, - - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - askingForEvent := req.PrevEventIDs[0] - haveEventB := false - haveEventC := false - for _, ev := range rsAPI.inputRoomEvents { - switch ev.Event.EventID() { - case eventB.EventID(): - haveEventB = true - case eventC.EventID(): - haveEventC = true - } - } - prevEventExists := false - if askingForEvent == eventC.EventID() { - prevEventExists = haveEventC - } else if askingForEvent == eventB.EventID() { - prevEventExists = haveEventB - } - - var missingPrevEvent []string - if !prevEventExists { - missingPrevEvent = []string{"test"} - } - - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: missingPrevEvent, - } - }, - - queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse { - omitTuples := []gomatrixserverlib.StateKeyTuple{ - {EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""}, - } - return api.QueryLatestEventsAndStateResponse{ - RoomExists: true, - Depth: eventA.Depth(), - LatestEvents: []gomatrixserverlib.EventReference{ - eventA.EventReference(), - }, - StateEvents: fromStateTuples(req.StateToFetch, omitTuples), - } - }, - queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse { - var res api.QueryEventsByIDResponse - fmt.Println("queryEventsByID ", req.EventIDs) - for _, wantEventID := range req.EventIDs { - for _, ev := range testStateEvents { - // roomserver is missing the power levels event unless it's been sent to us recently as an outlier - if wantEventID == eventB.EventID() { - fmt.Println("Asked for pl event") - for _, inEv := range rsAPI.inputRoomEvents { - fmt.Println("recv ", inEv.Event.EventID()) - if inEv.Event.EventID() == wantEventID { - res.Events = append(res.Events, inEv.Event) - break - } - } - continue - } - if ev.EventID() == wantEventID { - res.Events = append(res.Events, ev) - } - } - } - return res - }, - } - // /state_ids for event B returns every state event but B (it's the state before) - var authEventIDs []string - var stateEventIDs []string - for _, ev := range testStateEvents { - if ev.EventID() == eventB.EventID() { - continue - } - // state res checks what auth events you give it, and this isn't a valid auth event - if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility { - authEventIDs = append(authEventIDs, ev.EventID()) - } - stateEventIDs = append(stateEventIDs, ev.EventID()) - } - cli := &txnFedClient{ - stateIDs: map[string]gomatrixserverlib.RespStateIDs{ - eventB.EventID(): { - StateEventIDs: stateEventIDs, - AuthEventIDs: authEventIDs, - }, - }, - // /event for event B returns it - getEvent: map[string]gomatrixserverlib.Transaction{ - eventB.EventID(): { - PDUs: []json.RawMessage{ - eventB.JSON(), - }, - }, - }, - // /get_missing_events should be done exactly once - getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) { - if !reflect.DeepEqual(missing.EarliestEvents, []string{eventA.EventID()}) { - t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, eventA.EventID()) - } - if !reflect.DeepEqual(missing.LatestEvents, []string{eventD.EventID()}) { - t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, eventD.EventID()) - } - // just return event C, not event B so /state_ids logic kicks in as there will STILL be missing prev_events - return gomatrixserverlib.RespMissingEvents{ - Events: []*gomatrixserverlib.Event{ - eventC.Unwrap(), - }, - }, nil - }, - } - - pdus := []json.RawMessage{ - eventD.JSON(), - } - txn := mustCreateTransaction(rsAPI, cli, pdus) - mustProcessTransaction(t, txn, nil) - assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD}) -} -*/ diff --git a/federationapi/statistics/statistics.go b/federationapi/statistics/statistics.go index 0a44375c6..866c09336 100644 --- a/federationapi/statistics/statistics.go +++ b/federationapi/statistics/statistics.go @@ -1,6 +1,7 @@ package statistics import ( + "context" "math" "math/rand" "sync" @@ -28,14 +29,24 @@ type Statistics struct { // just blacklist the host altogether? The backoff is exponential, // so the max time here to attempt is 2**failures seconds. FailuresUntilBlacklist uint32 + + // How many times should we tolerate consecutive failures before we + // mark the destination as offline. At this point we should attempt + // to send messages to the user's async relay servers if we know them. + FailuresUntilAssumedOffline uint32 } -func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics { +func NewStatistics( + db storage.Database, + failuresUntilBlacklist uint32, + failuresUntilAssumedOffline uint32, +) Statistics { return Statistics{ - DB: db, - FailuresUntilBlacklist: failuresUntilBlacklist, - backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), - servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), + DB: db, + FailuresUntilBlacklist: failuresUntilBlacklist, + FailuresUntilAssumedOffline: failuresUntilAssumedOffline, + backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer), + servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics), } } @@ -50,8 +61,9 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS if !found { s.mutex.Lock() server = &ServerStatistics{ - statistics: s, - serverName: serverName, + statistics: s, + serverName: serverName, + knownRelayServers: []gomatrixserverlib.ServerName{}, } s.servers[serverName] = server s.mutex.Unlock() @@ -61,24 +73,49 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS } else { server.blacklisted.Store(blacklisted) } + assumedOffline, err := s.DB.IsServerAssumedOffline(context.Background(), serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName) + } else { + server.assumedOffline.Store(assumedOffline) + } + + knownRelayServers, err := s.DB.P2PGetRelayServersForServer(context.Background(), serverName) + if err != nil { + logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName) + } else { + server.relayMutex.Lock() + server.knownRelayServers = knownRelayServers + server.relayMutex.Unlock() + } } return server } +type SendMethod uint8 + +const ( + SendDirect SendMethod = iota + SendViaRelay +) + // ServerStatistics contains information about our interactions with a // remote federated host, e.g. how many times we were successful, how // many times we failed etc. It also manages the backoff time and black- // listing a remote host if it remains uncooperative. type ServerStatistics struct { - statistics *Statistics // - serverName gomatrixserverlib.ServerName // - blacklisted atomic.Bool // is the node blacklisted - backoffStarted atomic.Bool // is the backoff started - backoffUntil atomic.Value // time.Time until this backoff interval ends - backoffCount atomic.Uint32 // number of times BackoffDuration has been called - successCounter atomic.Uint32 // how many times have we succeeded? - backoffNotifier func() // notifies destination queue when backoff completes - notifierMutex sync.Mutex + statistics *Statistics // + serverName gomatrixserverlib.ServerName // + blacklisted atomic.Bool // is the node blacklisted + assumedOffline atomic.Bool // is the node assumed to be offline + backoffStarted atomic.Bool // is the backoff started + backoffUntil atomic.Value // time.Time until this backoff interval ends + backoffCount atomic.Uint32 // number of times BackoffDuration has been called + successCounter atomic.Uint32 // how many times have we succeeded? + backoffNotifier func() // notifies destination queue when backoff completes + notifierMutex sync.Mutex + knownRelayServers []gomatrixserverlib.ServerName + relayMutex sync.Mutex } const maxJitterMultiplier = 1.4 @@ -113,13 +150,19 @@ func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) { // attempt, which increases the sent counter and resets the idle and // failure counters. If a host was blacklisted at this point then // we will unblacklist it. -func (s *ServerStatistics) Success() { +// `relay` specifies whether the success was to the actual destination +// or one of their relay servers. +func (s *ServerStatistics) Success(method SendMethod) { s.cancel() s.backoffCount.Store(0) - s.successCounter.Inc() - if s.statistics.DB != nil { - if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { - logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + // NOTE : Sending to the final destination vs. a relay server has + // slightly different semantics. + if method == SendDirect { + s.successCounter.Inc() + if s.blacklisted.Load() && s.statistics.DB != nil { + if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName) + } } } } @@ -139,7 +182,18 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { // start a goroutine which will wait out the backoff and // unset the backoffStarted flag when done. if s.backoffStarted.CompareAndSwap(false, true) { - if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist { + backoffCount := s.backoffCount.Inc() + + if backoffCount >= s.statistics.FailuresUntilAssumedOffline { + s.assumedOffline.CompareAndSwap(false, true) + if s.statistics.DB != nil { + if err := s.statistics.DB.SetServerAssumedOffline(context.Background(), s.serverName); err != nil { + logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName) + } + } + } + + if backoffCount >= s.statistics.FailuresUntilBlacklist { s.blacklisted.Store(true) if s.statistics.DB != nil { if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil { @@ -157,13 +211,21 @@ func (s *ServerStatistics) Failure() (time.Time, bool) { s.backoffUntil.Store(until) s.statistics.backoffMutex.Lock() - defer s.statistics.backoffMutex.Unlock() s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished) + s.statistics.backoffMutex.Unlock() } return s.backoffUntil.Load().(time.Time), false } +// MarkServerAlive removes the assumed offline and blacklisted statuses from this server. +// Returns whether the server was blacklisted before this point. +func (s *ServerStatistics) MarkServerAlive() bool { + s.removeAssumedOffline() + wasBlacklisted := s.removeBlacklist() + return wasBlacklisted +} + // ClearBackoff stops the backoff timer for this destination if it is running // and removes the timer from the backoffTimers map. func (s *ServerStatistics) ClearBackoff() { @@ -191,13 +253,13 @@ func (s *ServerStatistics) backoffFinished() { } // BackoffInfo returns information about the current or previous backoff. -// Returns the last backoffUntil time and whether the server is currently blacklisted or not. -func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) { +// Returns the last backoffUntil time. +func (s *ServerStatistics) BackoffInfo() *time.Time { until, ok := s.backoffUntil.Load().(time.Time) if ok { - return &until, s.blacklisted.Load() + return &until } - return nil, s.blacklisted.Load() + return nil } // Blacklisted returns true if the server is blacklisted and false @@ -206,10 +268,33 @@ func (s *ServerStatistics) Blacklisted() bool { return s.blacklisted.Load() } -// RemoveBlacklist removes the blacklisted status from the server. -func (s *ServerStatistics) RemoveBlacklist() { +// AssumedOffline returns true if the server is assumed offline and false +// otherwise. +func (s *ServerStatistics) AssumedOffline() bool { + return s.assumedOffline.Load() +} + +// removeBlacklist removes the blacklisted status from the server. +// Returns whether the server was blacklisted. +func (s *ServerStatistics) removeBlacklist() bool { + var wasBlacklisted bool + + if s.Blacklisted() { + wasBlacklisted = true + _ = s.statistics.DB.RemoveServerFromBlacklist(s.serverName) + } s.cancel() s.backoffCount.Store(0) + + return wasBlacklisted +} + +// removeAssumedOffline removes the assumed offline status from the server. +func (s *ServerStatistics) removeAssumedOffline() { + if s.AssumedOffline() { + _ = s.statistics.DB.RemoveServerAssumedOffline(context.Background(), s.serverName) + } + s.assumedOffline.Store(false) } // SuccessCount returns the number of successful requests. This is @@ -217,3 +302,46 @@ func (s *ServerStatistics) RemoveBlacklist() { func (s *ServerStatistics) SuccessCount() uint32 { return s.successCounter.Load() } + +// KnownRelayServers returns the list of relay servers associated with this +// server. +func (s *ServerStatistics) KnownRelayServers() []gomatrixserverlib.ServerName { + s.relayMutex.Lock() + defer s.relayMutex.Unlock() + return s.knownRelayServers +} + +func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.ServerName) { + seenSet := make(map[gomatrixserverlib.ServerName]bool) + uniqueList := []gomatrixserverlib.ServerName{} + for _, srv := range relayServers { + if seenSet[srv] { + continue + } + seenSet[srv] = true + uniqueList = append(uniqueList, srv) + } + + err := s.statistics.DB.P2PAddRelayServersForServer(context.Background(), s.serverName, uniqueList) + if err != nil { + logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList) + return + } + + for _, newServer := range uniqueList { + alreadyKnown := false + knownRelayServers := s.KnownRelayServers() + for _, srv := range knownRelayServers { + if srv == newServer { + alreadyKnown = true + } + } + if !alreadyKnown { + { + s.relayMutex.Lock() + s.knownRelayServers = append(s.knownRelayServers, newServer) + s.relayMutex.Unlock() + } + } + } +} diff --git a/federationapi/statistics/statistics_test.go b/federationapi/statistics/statistics_test.go index 6aa997f44..183b9aa0c 100644 --- a/federationapi/statistics/statistics_test.go +++ b/federationapi/statistics/statistics_test.go @@ -4,17 +4,26 @@ import ( "math" "testing" "time" + + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + FailuresUntilAssumedOffline = 3 + FailuresUntilBlacklist = 8 ) func TestBackoff(t *testing.T) { - stats := NewStatistics(nil, 7) + stats := NewStatistics(nil, FailuresUntilBlacklist, FailuresUntilAssumedOffline) server := ServerStatistics{ statistics: &stats, serverName: "test.com", } // Start by checking that counting successes works. - server.Success() + server.Success(SendDirect) if successes := server.SuccessCount(); successes != 1 { t.Fatalf("Expected success count 1, got %d", successes) } @@ -31,9 +40,8 @@ func TestBackoff(t *testing.T) { // side effects since a backoff is already in progress. If it does // then we'll fail. until, blacklisted := server.Failure() - - // Get the duration. - _, blacklist := server.BackoffInfo() + blacklist := server.Blacklisted() + assumedOffline := server.AssumedOffline() duration := time.Until(until) // Unset the backoff, or otherwise our next call will think that @@ -41,16 +49,43 @@ func TestBackoff(t *testing.T) { server.cancel() server.backoffStarted.Store(false) + if i >= stats.FailuresUntilAssumedOffline { + if !assumedOffline { + t.Fatalf("Backoff %d should have resulted in assuming the destination was offline but didn't", i) + } + } + + // Check if we should be assumed offline by now. + if i >= stats.FailuresUntilAssumedOffline { + if !assumedOffline { + t.Fatalf("Backoff %d should have resulted in assumed offline but didn't", i) + } else { + t.Logf("Backoff %d is assumed offline as expected", i) + } + } else { + if assumedOffline { + t.Fatalf("Backoff %d should not have resulted in assumed offline but did", i) + } else { + t.Logf("Backoff %d is not assumed offline as expected", i) + } + } + // Check if we should be blacklisted by now. if i >= stats.FailuresUntilBlacklist { if !blacklist { t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i) } else if blacklist != blacklisted { - t.Fatalf("BackoffInfo and Failure returned different blacklist values") + t.Fatalf("Blacklisted and Failure returned different blacklist values") } else { t.Logf("Backoff %d is blacklisted as expected", i) continue } + } else { + if blacklist { + t.Fatalf("Backoff %d should not have resulted in blacklist but did", i) + } else { + t.Logf("Backoff %d is not blacklisted as expected", i) + } } // Check if the duration is what we expect. @@ -69,3 +104,14 @@ func TestBackoff(t *testing.T) { } } } + +func TestRelayServersListing(t *testing.T) { + stats := NewStatistics(test.NewInMemoryFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline) + server := ServerStatistics{statistics: &stats} + server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + relayServers := server.KnownRelayServers() + assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) + server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"}) + relayServers = server.KnownRelayServers() + assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers) +} diff --git a/federationapi/storage/interface.go b/federationapi/storage/interface.go index 2b4d905fc..4f5300af1 100644 --- a/federationapi/storage/interface.go +++ b/federationapi/storage/interface.go @@ -20,11 +20,12 @@ import ( "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/dendrite/federationapi/storage/shared" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/types" ) type Database interface { + P2PDatabase gomatrixserverlib.KeyDatabase UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) @@ -34,16 +35,16 @@ type Database interface { // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) - StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) + StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error) - GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) - GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) + GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) + GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) - AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error - AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error + AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error + AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error - CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error - CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error + CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error + CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) @@ -54,6 +55,18 @@ type Database interface { RemoveAllServersFromBlacklist() error IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) + // Adds the server to the list of assumed offline servers. + // If the server already exists in the table, nothing happens and returns success. + SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + // Removes the server from the list of assumed offline servers. + // If the server doesn't exist in the table, nothing happens and returns success. + RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error + // Purges all entries from the assumed offline table. + RemoveAllServersAssumedOffline(ctx context.Context) error + // Gets whether the provided server is present in the table. + // If it is present, returns true. If not, returns false. + IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error) + AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) @@ -74,3 +87,21 @@ type Database interface { PurgeRoom(ctx context.Context, roomID string) error } + +type P2PDatabase interface { + // Stores the given list of servers as relay servers for the provided destination server. + // Providing duplicates will only lead to a single entry and won't lead to an error. + P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Get the list of relay servers associated with the provided destination server. + // If no entry exists in the table, an empty list is returned and does not result in an error. + P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + + // Deletes any entries for the provided destination server that match the provided relayServers list. + // If any of the provided servers don't match an entry, nothing happens and no error is returned. + P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + + // Deletes all entries for the provided destination server. + // If the destination server doesn't exist in the table, nothing happens and no error is returned. + P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error +} diff --git a/federationapi/storage/postgres/assumed_offline_table.go b/federationapi/storage/postgres/assumed_offline_table.go new file mode 100644 index 000000000..5695d2e54 --- /dev/null +++ b/federationapi/storage/postgres/assumed_offline_table.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT PRIMARY KEY NOT NULL +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "TRUNCATE federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL}, + {&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL}, + {&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL}, + {&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL}, + }.Prepare(db) +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/postgres/relay_servers_table.go b/federationapi/storage/postgres/relay_servers_table.go new file mode 100644 index 000000000..f7267978f --- /dev/null +++ b/federationapi/storage/postgres/relay_servers_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayServersSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_relay_servers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The relay server name for a given destination + relay_server_name TEXT NOT NULL, + UNIQUE (server_name, relay_server_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx + ON federationsender_relay_servers (server_name); +` + +const insertRelayServersSQL = "" + + "INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectRelayServersSQL = "" + + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" + +const deleteRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)" + +const deleteAllRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" + +type relayServersStatements struct { + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + deleteRelayServersStmt *sql.Stmt + deleteAllRelayServersStmt *sql.Stmt +} + +func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) { + s = &relayServersStatements{ + db: db, + } + _, err = db.Exec(relayServersSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertRelayServersStmt, insertRelayServersSQL}, + {&s.selectRelayServersStmt, selectRelayServersSQL}, + {&s.deleteRelayServersStmt, deleteRelayServersSQL}, + {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, + }.Prepare(db) +} + +func (s *relayServersStatements) InsertRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + for _, relayServer := range relayServers { + stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { + return err + } + } + return nil +} + +func (s *relayServersStatements) SelectRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var relayServer string + if err = rows.Scan(&relayServer); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(relayServer)) + } + return result, nil +} + +func (s *relayServersStatements) DeleteRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers)) + return err +} + +func (s *relayServersStatements) DeleteAllRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/postgres/storage.go b/federationapi/storage/postgres/storage.go index fe84e932e..b81f128e7 100644 --- a/federationapi/storage/postgres/storage.go +++ b/federationapi/storage/postgres/storage.go @@ -62,6 +62,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewPostgresAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } + relayServers, err := NewPostgresRelayServersTable(d.db) + if err != nil { + return nil, err + } inboundPeeks, err := NewPostgresInboundPeeksTable(d.db) if err != nil { return nil, err @@ -104,6 +112,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, FederationInboundPeeks: inboundPeeks, FederationOutboundPeeks: outboundPeeks, NotaryServerKeysJSON: notaryJSON, diff --git a/federationapi/storage/shared/receipt/receipt.go b/federationapi/storage/shared/receipt/receipt.go new file mode 100644 index 000000000..b347269c1 --- /dev/null +++ b/federationapi/storage/shared/receipt/receipt.go @@ -0,0 +1,42 @@ +// Copyright 2023 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// A Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. +// We don't actually export the NIDs but we need the caller to be able +// to pass them back so that we can clean up if the transaction sends +// successfully. + +package receipt + +import "fmt" + +// Receipt is a wrapper type used to represent a nid that corresponds to a unique row entry +// in some database table. +// The internal nid value cannot be modified after a Receipt has been created. +// This guarantees a receipt will always refer to the same table entry that it was created +// to represent. +type Receipt struct { + nid int64 +} + +func NewReceipt(nid int64) Receipt { + return Receipt{nid: nid} +} + +func (r *Receipt) GetNID() int64 { + return r.nid +} + +func (r *Receipt) String() string { + return fmt.Sprintf("%d", r.nid) +} diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index 6cda55725..6769637bc 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -20,6 +20,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/dendrite/federationapi/storage/tables" "github.com/matrix-org/dendrite/federationapi/types" "github.com/matrix-org/dendrite/internal/caching" @@ -37,6 +38,8 @@ type Database struct { FederationQueueJSON tables.FederationQueueJSON FederationJoinedHosts tables.FederationJoinedHosts FederationBlacklist tables.FederationBlacklist + FederationAssumedOffline tables.FederationAssumedOffline + FederationRelayServers tables.FederationRelayServers FederationOutboundPeeks tables.FederationOutboundPeeks FederationInboundPeeks tables.FederationInboundPeeks NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON @@ -44,22 +47,6 @@ type Database struct { ServerSigningKeys tables.FederationServerSigningKeys } -// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs. -// We don't actually export the NIDs but we need the caller to be able -// to pass them back so that we can clean up if the transaction sends -// successfully. -type Receipt struct { - nid int64 -} - -func NewReceipt(nid int64) Receipt { - return Receipt{nid: nid} -} - -func (r *Receipt) String() string { - return fmt.Sprintf("%d", r.nid) -} - // UpdateRoom updates the joined hosts for a room and returns what the joined // hosts were before the update, or nil if this was a duplicate message. // This is called when we receive a message from kafka, so we pass in @@ -113,11 +100,18 @@ func (d *Database) GetJoinedHosts( // GetAllJoinedHosts returns the currently joined hosts for // all rooms known to the federation sender. // Returns an error if something goes wrong. -func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx) } -func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { +func (d *Database) GetJoinedHostsForRooms( + ctx context.Context, + roomIDs []string, + excludeSelf, + excludeBlacklisted bool, +) ([]gomatrixserverlib.ServerName, error) { servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted) if err != nil { return nil, err @@ -139,7 +133,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, // metadata entries. func (d *Database) StoreJSON( ctx context.Context, js string, -) (*Receipt, error) { +) (*receipt.Receipt, error) { var nid int64 var err error _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { @@ -149,18 +143,21 @@ func (d *Database) StoreJSON( if err != nil { return nil, fmt.Errorf("d.insertQueueJSON: %w", err) } - return &Receipt{ - nid: nid, - }, nil + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil } -func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName) }) } -func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error { +func (d *Database) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName) }) @@ -172,51 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error { }) } -func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) { +func (d *Database) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName) } -func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName) + }) +} + +func (d *Database) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName) + }) +} + +func (d *Database) RemoveAllServersAssumedOffline( + ctx context.Context, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn) + }) +} + +func (d *Database) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName) +} + +func (d *Database) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers) + }) +} + +func (d *Database) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName) +} + +func (d *Database) P2PRemoveRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers) + }) +} + +func (d *Database) P2PRemoveAllRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName) + }) +} + +func (d *Database) AddOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { +func (d *Database) GetOutboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID, + peekID string, +) (*types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { +func (d *Database) GetOutboundPeeks( + ctx context.Context, + roomID string, +) ([]types.OutboundPeek, error) { return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID) } -func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) AddInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { +func (d *Database) RenewInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, + renewalInterval int64, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval) }) } -func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { +func (d *Database) GetInboundPeek( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + roomID string, + peekID string, +) (*types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID) } -func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { +func (d *Database) GetInboundPeeks( + ctx context.Context, + roomID string, +) ([]types.InboundPeek, error) { return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID) } -func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { +func (d *Database) UpdateNotaryKeys( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + serverKeys gomatrixserverlib.ServerKeys, +) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { validUntil := serverKeys.ValidUntilTS // Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid. @@ -251,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv } func (d *Database) GetNotaryKeys( - ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID, + ctx context.Context, + serverName gomatrixserverlib.ServerName, + optKeyIDs []gomatrixserverlib.KeyID, ) (sks []gomatrixserverlib.ServerKeys, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs) diff --git a/federationapi/storage/shared/storage_edus.go b/federationapi/storage/shared/storage_edus.go index be8355f31..cff1ade6f 100644 --- a/federationapi/storage/shared/storage_edus.go +++ b/federationapi/storage/shared/storage_edus.go @@ -22,6 +22,7 @@ import ( "fmt" "time" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{ func (d *Database) AssociateEDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration, ) error { @@ -62,12 +63,12 @@ func (d *Database) AssociateEDUWithDestinations( var err error for destination := range destinations { err = d.FederationQueueEDUs.InsertQueueEDU( - ctx, // context - txn, // SQL transaction - eduType, // EDU type for coalescing - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table - expiresAt, // The timestamp this EDU will expire + ctx, // context + txn, // SQL transaction + eduType, // EDU type for coalescing + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table + expiresAt, // The timestamp this EDU will expire ) } return err @@ -81,10 +82,10 @@ func (d *Database) GetPendingEDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - edus map[*Receipt]*gomatrixserverlib.EDU, + edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error, ) { - edus = make(map[*Receipt]*gomatrixserverlib.EDU) + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit) if err != nil { @@ -94,7 +95,8 @@ func (d *Database) GetPendingEDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok { - edus[&Receipt{nid}] = edu + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = edu } else { retrieve = append(retrieve, nid) } @@ -110,7 +112,8 @@ func (d *Database) GetPendingEDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - edus[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + edus[&newReceipt] = &event d.Cache.StoreFederationQueuedEDU(nid, &event) } @@ -124,7 +127,7 @@ func (d *Database) GetPendingEDUs( func (d *Database) CleanEDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -132,7 +135,7 @@ func (d *Database) CleanEDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/shared/storage_pdus.go b/federationapi/storage/shared/storage_pdus.go index da4cb979d..854e00553 100644 --- a/federationapi/storage/shared/storage_pdus.go +++ b/federationapi/storage/shared/storage_pdus.go @@ -21,6 +21,7 @@ import ( "errors" "fmt" + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" "github.com/matrix-org/gomatrixserverlib" ) @@ -30,17 +31,17 @@ import ( func (d *Database) AssociatePDUWithDestinations( ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, - receipt *Receipt, + dbReceipt *receipt.Receipt, ) error { return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { var err error for destination := range destinations { err = d.FederationQueuePDUs.InsertQueuePDU( - ctx, // context - txn, // SQL transaction - "", // transaction ID - destination, // destination server name - receipt.nid, // NID from the federationapi_queue_json table + ctx, // context + txn, // SQL transaction + "", // transaction ID + destination, // destination server name + dbReceipt.GetNID(), // NID from the federationapi_queue_json table ) } return err @@ -54,7 +55,7 @@ func (d *Database) GetPendingPDUs( serverName gomatrixserverlib.ServerName, limit int, ) ( - events map[*Receipt]*gomatrixserverlib.HeaderedEvent, + events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error, ) { // Strictly speaking this doesn't need to be using the writer @@ -62,7 +63,7 @@ func (d *Database) GetPendingPDUs( // a guarantee of transactional isolation, it's actually useful // to know in SQLite mode that nothing else is trying to modify // the database. - events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent) + events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit) if err != nil { @@ -72,7 +73,8 @@ func (d *Database) GetPendingPDUs( retrieve := make([]int64, 0, len(nids)) for _, nid := range nids { if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok { - events[&Receipt{nid}] = event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = event } else { retrieve = append(retrieve, nid) } @@ -88,7 +90,8 @@ func (d *Database) GetPendingPDUs( if err := json.Unmarshal(blob, &event); err != nil { return fmt.Errorf("json.Unmarshal: %w", err) } - events[&Receipt{nid}] = &event + newReceipt := receipt.NewReceipt(nid) + events[&newReceipt] = &event d.Cache.StoreFederationQueuedPDU(nid, &event) } @@ -103,7 +106,7 @@ func (d *Database) GetPendingPDUs( func (d *Database) CleanPDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, - receipts []*Receipt, + receipts []*receipt.Receipt, ) error { if len(receipts) == 0 { return errors.New("expected receipt") @@ -111,7 +114,7 @@ func (d *Database) CleanPDUs( nids := make([]int64, len(receipts)) for i := range receipts { - nids[i] = receipts[i].nid + nids[i] = receipts[i].GetNID() } return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { diff --git a/federationapi/storage/sqlite3/assumed_offline_table.go b/federationapi/storage/sqlite3/assumed_offline_table.go new file mode 100644 index 000000000..ff2afb4da --- /dev/null +++ b/federationapi/storage/sqlite3/assumed_offline_table.go @@ -0,0 +1,107 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const assumedOfflineSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_assumed_offline( + -- The assumed offline server name + server_name TEXT PRIMARY KEY NOT NULL +); +` + +const insertAssumedOfflineSQL = "" + + "INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" + + " ON CONFLICT DO NOTHING" + +const selectAssumedOfflineSQL = "" + + "SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline WHERE server_name = $1" + +const deleteAllAssumedOfflineSQL = "" + + "DELETE FROM federationsender_assumed_offline" + +type assumedOfflineStatements struct { + db *sql.DB + insertAssumedOfflineStmt *sql.Stmt + selectAssumedOfflineStmt *sql.Stmt + deleteAssumedOfflineStmt *sql.Stmt + deleteAllAssumedOfflineStmt *sql.Stmt +} + +func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) { + s = &assumedOfflineStatements{ + db: db, + } + _, err = db.Exec(assumedOfflineSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL}, + {&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL}, + {&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL}, + {&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL}, + }.Prepare(db) +} + +func (s *assumedOfflineStatements) InsertAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) SelectAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (bool, error) { + stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt) + res, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return false, err + } + defer res.Close() // nolint:errcheck + // The query will return the server name if the server is assume offline, and + // will return no rows if not. By calling Next, we find out if a row was + // returned or not - we don't care about the value itself. + return res.Next(), nil +} + +func (s *assumedOfflineStatements) DeleteAssumedOffline( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx, serverName) + return err +} + +func (s *assumedOfflineStatements) DeleteAllAssumedOffline( + ctx context.Context, txn *sql.Tx, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt) + _, err := stmt.ExecContext(ctx) + return err +} diff --git a/federationapi/storage/sqlite3/relay_servers_table.go b/federationapi/storage/sqlite3/relay_servers_table.go new file mode 100644 index 000000000..27c3cca2c --- /dev/null +++ b/federationapi/storage/sqlite3/relay_servers_table.go @@ -0,0 +1,148 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayServersSchema = ` +CREATE TABLE IF NOT EXISTS federationsender_relay_servers ( + -- The destination server name + server_name TEXT NOT NULL, + -- The relay server name for a given destination + relay_server_name TEXT NOT NULL, + UNIQUE (server_name, relay_server_name) +); + +CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx + ON federationsender_relay_servers (server_name); +` + +const insertRelayServersSQL = "" + + "INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" + + " ON CONFLICT DO NOTHING" + +const selectRelayServersSQL = "" + + "SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1" + +const deleteRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)" + +const deleteAllRelayServersSQL = "" + + "DELETE FROM federationsender_relay_servers WHERE server_name = $1" + +type relayServersStatements struct { + db *sql.DB + insertRelayServersStmt *sql.Stmt + selectRelayServersStmt *sql.Stmt + // deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic + deleteAllRelayServersStmt *sql.Stmt +} + +func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) { + s = &relayServersStatements{ + db: db, + } + _, err = db.Exec(relayServersSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertRelayServersStmt, insertRelayServersSQL}, + {&s.selectRelayServersStmt, selectRelayServersSQL}, + {&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL}, + }.Prepare(db) +} + +func (s *relayServersStatements) InsertRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + for _, relayServer := range relayServers { + stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil { + return err + } + } + return nil +} + +func (s *relayServersStatements) SelectRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt) + rows, err := stmt.QueryContext(ctx, serverName) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var relayServer string + if err = rows.Scan(&relayServer); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(relayServer)) + } + return result, nil +} + +func (s *relayServersStatements) DeleteRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1) + deleteStmt, err := s.db.Prepare(deleteSQL) + if err != nil { + return err + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + params := make([]interface{}, len(relayServers)+1) + params[0] = serverName + for i, v := range relayServers { + params[i+1] = v + } + + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *relayServersStatements) DeleteAllRelayServers( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt) + if _, err := stmt.ExecContext(ctx, serverName); err != nil { + return err + } + return nil +} diff --git a/federationapi/storage/sqlite3/storage.go b/federationapi/storage/sqlite3/storage.go index d13b5defc..1e7e41a2c 100644 --- a/federationapi/storage/sqlite3/storage.go +++ b/federationapi/storage/sqlite3/storage.go @@ -1,5 +1,4 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// Copyright 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -61,6 +60,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, if err != nil { return nil, err } + assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db) + if err != nil { + return nil, err + } + relayServers, err := NewSQLiteRelayServersTable(d.db) + if err != nil { + return nil, err + } outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db) if err != nil { return nil, err @@ -103,6 +110,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, FederationQueueEDUs: queueEDUs, FederationQueueJSON: queueJSON, FederationBlacklist: blacklist, + FederationAssumedOffline: assumedOffline, + FederationRelayServers: relayServers, FederationOutboundPeeks: outboundPeeks, FederationInboundPeeks: inboundPeeks, NotaryServerKeysJSON: notaryKeys, diff --git a/federationapi/storage/storage_test.go b/federationapi/storage/storage_test.go index 5b57d40d4..1d2a13e81 100644 --- a/federationapi/storage/storage_test.go +++ b/federationapi/storage/storage_test.go @@ -6,14 +6,13 @@ import ( "testing" "time" - "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" - "github.com/stretchr/testify/assert" - "github.com/matrix-org/dendrite/federationapi/storage" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/test" "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/stretchr/testify/assert" ) func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) { @@ -246,3 +245,99 @@ func TestInboundPeeking(t *testing.T) { assert.ElementsMatch(t, gotPeekIDs, peekIDs) }) } + +func TestServersAssumedOffline(t *testing.T) { + server1 := gomatrixserverlib.ServerName("server1") + server2 := gomatrixserverlib.ServerName("server2") + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + + // Set server1 & server2 as assumed offline. + err := db.SetServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + err = db.SetServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + + // Ensure both servers are assumed offline. + isOffline, err := db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.True(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.True(t, isOffline) + + // Set server1 as not assumed offline. + err = db.RemoveServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + + // Ensure both servers have correct state. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.False(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.True(t, isOffline) + + // Re-set server1 as assumed offline. + err = db.SetServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + + // Ensure server1 is assumed offline. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.True(t, isOffline) + + err = db.RemoveAllServersAssumedOffline(context.Background()) + assert.Nil(t, err) + + // Ensure both servers have correct state. + isOffline, err = db.IsServerAssumedOffline(context.Background(), server1) + assert.Nil(t, err) + assert.False(t, isOffline) + isOffline, err = db.IsServerAssumedOffline(context.Background(), server2) + assert.Nil(t, err) + assert.False(t, isOffline) + }) +} + +func TestRelayServersStored(t *testing.T) { + server := gomatrixserverlib.ServerName("server") + relayServer1 := gomatrixserverlib.ServerName("relayserver1") + relayServer2 := gomatrixserverlib.ServerName("relayserver2") + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, closeDB := mustCreateFederationDatabase(t, dbType) + defer closeDB() + + err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + assert.Nil(t, err) + + relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Equal(t, relayServer1, relayServers[0]) + + err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1}) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Zero(t, len(relayServers)) + + err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2}) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Equal(t, relayServer1, relayServers[0]) + assert.Equal(t, relayServer2, relayServers[1]) + + err = db.P2PRemoveAllRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + + relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server) + assert.Nil(t, err) + assert.Zero(t, len(relayServers)) + }) +} diff --git a/federationapi/storage/tables/interface.go b/federationapi/storage/tables/interface.go index 2b36edb46..762504e45 100644 --- a/federationapi/storage/tables/interface.go +++ b/federationapi/storage/tables/interface.go @@ -49,6 +49,19 @@ type FederationQueueJSON interface { SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) } +type FederationQueueTransactions interface { + InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +type FederationTransactionJSON interface { + InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} + type FederationJoinedHosts interface { InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error @@ -66,6 +79,20 @@ type FederationBlacklist interface { DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error } +type FederationAssumedOffline interface { + InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error) + DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error + DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error +} + +type FederationRelayServers interface { + InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error) + DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error + DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error +} + type FederationOutboundPeeks interface { InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error) diff --git a/federationapi/storage/tables/relay_servers_table_test.go b/federationapi/storage/tables/relay_servers_table_test.go new file mode 100644 index 000000000..b41211551 --- /dev/null +++ b/federationapi/storage/tables/relay_servers_table_test.go @@ -0,0 +1,224 @@ +package tables_test + +import ( + "context" + "database/sql" + "testing" + + "github.com/matrix-org/dendrite/federationapi/storage/postgres" + "github.com/matrix-org/dendrite/federationapi/storage/sqlite3" + "github.com/matrix-org/dendrite/federationapi/storage/tables" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + server1 = "server1" + server2 = "server2" + server3 = "server3" + server4 = "server4" +) + +type RelayServersDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.FederationRelayServers +} + +func mustCreateRelayServersTable( + t *testing.T, + dbType test.DBType, +) (database RelayServersDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.FederationRelayServers + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayServersTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayServersTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayServersDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func Equal(a, b []gomatrixserverlib.ServerName) bool { + if len(a) != len(b) { + return false + } + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} + +func TestShouldInsertRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldInsertRelayServersWithDuplicates(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2} + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + // Insert the same list again, this shouldn't fail and should have no effect. + err = db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldGetRelayServersUnknownDestination(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + + // Query relay servers for a destination that doesn't exist in the table. + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + + if !Equal(relayServers, []gomatrixserverlib.ServerName{}) { + t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers) + } + }) +} + +func TestShouldDeleteCorrectRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + relayServers1 := []gomatrixserverlib.ServerName{server2, server3} + relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4} + + err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) + } + err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4}) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error()) + } + + expectedRelayServers := []gomatrixserverlib.ServerName{server3} + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} + +func TestShouldDeleteAllRelayServers(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateRelayServersTable(t, dbType) + defer close() + expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3} + + err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + err = db.Table.DeleteAllRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error()) + } + + expectedRelayServers1 := []gomatrixserverlib.ServerName{} + relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers1) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers) + } + relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2) + if err != nil { + t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error()) + } + if !Equal(relayServers, expectedRelayServers) { + t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers) + } + }) +} diff --git a/go.mod b/go.mod index a86dd2cb8..871e94eb3 100644 --- a/go.mod +++ b/go.mod @@ -22,9 +22,9 @@ require ( github.com/matrix-org/dugong v0.0.0-20210921133753-66e6b1c67e2e github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 - github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 - github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 - github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 + github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f + github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb + github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.15 github.com/nats-io/nats-server/v2 v2.9.8 github.com/nats-io/nats.go v1.20.0 @@ -37,17 +37,17 @@ require ( github.com/prometheus/client_golang v1.13.0 github.com/sirupsen/logrus v1.9.0 github.com/stretchr/testify v1.8.1 - github.com/tidwall/gjson v1.14.3 + github.com/tidwall/gjson v1.14.4 github.com/tidwall/sjson v1.2.5 github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-lib v2.4.1+incompatible github.com/yggdrasil-network/yggdrasil-go v0.4.6 go.uber.org/atomic v1.10.0 - golang.org/x/crypto v0.1.0 + golang.org/x/crypto v0.5.0 golang.org/x/image v0.1.0 golang.org/x/mobile v0.0.0-20221020085226-b36e6246172e - golang.org/x/net v0.1.0 - golang.org/x/term v0.1.0 + golang.org/x/net v0.5.0 + golang.org/x/term v0.4.0 gopkg.in/h2non/bimg.v1 v1.1.9 gopkg.in/yaml.v2 v2.4.0 gotest.tools/v3 v3.4.0 @@ -119,12 +119,12 @@ require ( github.com/prometheus/procfs v0.8.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20220927061507-ef77025ab5aa // indirect github.com/tidwall/match v1.1.1 // indirect - github.com/tidwall/pretty v1.2.0 // indirect + github.com/tidwall/pretty v1.2.1 // indirect go.etcd.io/bbolt v1.3.6 // indirect golang.org/x/exp v0.0.0-20221031165847-c99f073a8326 // indirect golang.org/x/mod v0.6.0 // indirect - golang.org/x/sys v0.1.0 // indirect - golang.org/x/text v0.4.0 // indirect + golang.org/x/sys v0.4.0 // indirect + golang.org/x/text v0.6.0 // indirect golang.org/x/time v0.1.0 // indirect golang.org/x/tools v0.2.0 // indirect google.golang.org/protobuf v1.28.1 // indirect diff --git a/go.sum b/go.sum index e5cd67bed..1ca3e8a80 100644 --- a/go.sum +++ b/go.sum @@ -348,16 +348,12 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91 h1:s7fexw github.com/matrix-org/go-sqlite3-js v0.0.0-20220419092513-28aa791a1c91/go.mod h1:e+cg2q7C7yE5QnAXgzo512tgFh1RbQLC0+jozuegKgo= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530 h1:kHKxCOLcHH8r4Fzarl4+Y3K5hjothkVW5z7T1dUM11U= github.com/matrix-org/gomatrix v0.0.0-20220926102614-ceba4d9f7530/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab h1:ChaQdT2mpxMm3GRXNOZzLDQ/wOnlKZ8o60LmZGOjdj8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230105074811-965b10ae73ab/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45 h1:zGrmcm2M4F4f+zk5JXAkw3oHa/zXhOh5XVGBdl7GdPo= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119135050-7da03ab58f45/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8 h1:P7me2oCmksST9B4+1I1nA+XrnDQwIqAWmy6ntQrXwc8= -github.com/matrix-org/gomatrixserverlib v0.0.0-20230119151039-d8748f6d5dc8/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= -github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847 h1:auIBCi7gfZuvztD0aPr1G/J5Ya5vWr79M/+TJqwD/JM= -github.com/matrix-org/pinecone v0.0.0-20221118192051-fef26631b847/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= -github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= -github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f h1:niRWEVkeeekpjxwnMhKn8PD0PUloDsNXP8W+Ez/co/M= +github.com/matrix-org/gomatrixserverlib v0.0.0-20230119205614-cb888d80b00f/go.mod h1:Mtifyr8q8htcBeugvlDnkBcNUy5LO8OzUoplAf1+mb4= +github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb h1:2L+ltfNKab56FoBBqAvbBLjoAbxwwoZie+B8d+Mp3JI= +github.com/matrix-org/pinecone v0.11.1-0.20230111184901-61850f0e63cb/go.mod h1:F3GHppRuHCTDeoOmmgjZMeJdbql91+RSGGsATWfC7oc= +github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 h1:6z4KxomXSIGWqhHcfzExgkH3Z3UkIXry4ibJS4Aqz2Y= +github.com/matrix-org/util v0.0.0-20221111132719-399730281e66/go.mod h1:iBI1foelCqA09JJgPV0FYz4qA5dUXYOxMi57FxKBdd4= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.16 h1:bq3VjFmv/sOjHtdEhmkEV4x1AJtvUvOJ2PFAZ5+peKQ= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= @@ -494,12 +490,13 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.14.3 h1:9jvXn7olKEHU1S9vwoMGliaT8jq1vJ7IH/n9zD9Dnlw= -github.com/tidwall/gjson v1.14.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4= +github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE= @@ -543,8 +540,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210314154223-e6e6c4f2bb5b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU= -golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -625,8 +622,8 @@ golang.org/x/net v0.0.0-20210525063256-abc453219eb5/go.mod h1:9nx3DQGgdP8bBQD5qx golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0= -golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= +golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw= +golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -701,12 +698,12 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20221010170243-090e33056c14/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= -golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.1.0 h1:g6Z6vPFA9dYBAF7DWcH6sCcOntplXsDKcliusYijMlw= -golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.4.0 h1:O7UWfv5+A2qiuulQk30kVinPoMtoIPeVaKLEgLpVkvg= +golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -714,8 +711,9 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= diff --git a/internal/log.go b/internal/log.go index da6e20418..9e8656c5b 100644 --- a/internal/log.go +++ b/internal/log.go @@ -101,6 +101,8 @@ func SetupPprof() { // SetupStdLogging configures the logging format to standard output. Typically, it is called when the config is not yet loaded. func SetupStdLogging() { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() logrus.SetReportCaller(true) logrus.SetFormatter(&utcFormatter{ &logrus.TextFormatter{ diff --git a/internal/log_unix.go b/internal/log_unix.go index 8f34c320d..859427041 100644 --- a/internal/log_unix.go +++ b/internal/log_unix.go @@ -32,6 +32,8 @@ import ( // If something fails here it means that the logging was improperly configured, // so we just exit with the error func SetupHookLogging(hooks []config.LogrusHook, componentName string) { + levelLogAddedMu.Lock() + defer levelLogAddedMu.Unlock() for _, hook := range hooks { // Check we received a proper logging level level, err := logrus.ParseLevel(hook.Level) @@ -85,8 +87,6 @@ func checkSyslogHookParams(params map[string]interface{}) { } func setupStdLogHook(level logrus.Level) { - levelLogAddedMu.Lock() - defer levelLogAddedMu.Unlock() if stdLevelLogAdded[level] { return } diff --git a/internal/transactionrequest.go b/internal/transactionrequest.go new file mode 100644 index 000000000..95673fc14 --- /dev/null +++ b/internal/transactionrequest.go @@ -0,0 +1,356 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "sync" + + "github.com/getsentry/sentry-go" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/federationapi/types" + keyapi "github.com/matrix-org/dendrite/keyserver/api" + "github.com/matrix-org/dendrite/roomserver/api" + syncTypes "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" +) + +var ( + PDUCountTotal = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "federationapi", + Name: "recv_pdus", + Help: "Number of incoming PDUs from remote servers with labels for success", + }, + []string{"status"}, // 'success' or 'total' + ) + EDUCountTotal = prometheus.NewCounter( + prometheus.CounterOpts{ + Namespace: "dendrite", + Subsystem: "federationapi", + Name: "recv_edus", + Help: "Number of incoming EDUs from remote servers", + }, + ) +) + +type TxnReq struct { + gomatrixserverlib.Transaction + rsAPI api.FederationRoomserverAPI + keyAPI keyapi.FederationKeyAPI + ourServerName gomatrixserverlib.ServerName + keys gomatrixserverlib.JSONVerifier + roomsMu *MutexByRoom + producer *producers.SyncAPIProducer + inboundPresenceEnabled bool +} + +func NewTxnReq( + rsAPI api.FederationRoomserverAPI, + keyAPI keyapi.FederationKeyAPI, + ourServerName gomatrixserverlib.ServerName, + keys gomatrixserverlib.JSONVerifier, + roomsMu *MutexByRoom, + producer *producers.SyncAPIProducer, + inboundPresenceEnabled bool, + pdus []json.RawMessage, + edus []gomatrixserverlib.EDU, + origin gomatrixserverlib.ServerName, + transactionID gomatrixserverlib.TransactionID, + destination gomatrixserverlib.ServerName, +) TxnReq { + t := TxnReq{ + rsAPI: rsAPI, + keyAPI: keyAPI, + ourServerName: ourServerName, + keys: keys, + roomsMu: roomsMu, + producer: producer, + inboundPresenceEnabled: inboundPresenceEnabled, + } + + t.PDUs = pdus + t.EDUs = edus + t.Origin = origin + t.TransactionID = transactionID + t.Destination = destination + + return t +} + +func (t *TxnReq) ProcessTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if t.producer != nil { + t.processEDUs(ctx) + } + }() + + results := make(map[string]gomatrixserverlib.PDUResult) + roomVersions := make(map[string]gomatrixserverlib.RoomVersion) + getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { + if v, ok := roomVersions[roomID]; ok { + return v + } + verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} + verRes := api.QueryRoomVersionForRoomResponse{} + if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID) + return "" + } + roomVersions[roomID] = verRes.RoomVersion + return verRes.RoomVersion + } + + for _, pdu := range t.PDUs { + PDUCountTotal.WithLabelValues("total").Inc() + var header struct { + RoomID string `json:"room_id"` + } + if err := json.Unmarshal(pdu, &header); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event") + // We don't know the event ID at this point so we can't return the + // failure in the PDU results + continue + } + roomVersion := getRoomVersion(header.RoomID) + event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) + if err != nil { + if _, ok := err.(gomatrixserverlib.BadJSONError); ok { + // Room version 6 states that homeservers should strictly enforce canonical JSON + // on PDUs. + // + // This enforces that the entire transaction is rejected if a single bad PDU is + // sent. It is unclear if this is the correct behaviour or not. + // + // See https://github.com/matrix-org/synapse/issues/7543 + return nil, &util.JSONResponse{ + Code: 400, + JSON: jsonerror.BadJSON("PDU contains bad JSON"), + } + } + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu)) + continue + } + if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") { + continue + } + if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) { + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: "Forbidden by server ACLs", + } + continue + } + if err = event.VerifyEventSignatures(ctx, t.keys); err != nil { + util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID()) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue + } + + // pass the event to the roomserver which will do auth checks + // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently + // discarded by the caller of this function + if err = api.SendEvents( + ctx, + t.rsAPI, + api.KindNew, + []*gomatrixserverlib.HeaderedEvent{ + event.Headered(roomVersion), + }, + t.Destination, + t.Origin, + api.DoNotSendToOtherServers, + nil, + true, + ); err != nil { + util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err) + results[event.EventID()] = gomatrixserverlib.PDUResult{ + Error: err.Error(), + } + continue + } + + results[event.EventID()] = gomatrixserverlib.PDUResult{} + PDUCountTotal.WithLabelValues("success").Inc() + } + + wg.Wait() + return &gomatrixserverlib.RespSend{PDUs: results}, nil +} + +// nolint:gocyclo +func (t *TxnReq) processEDUs(ctx context.Context) { + for _, e := range t.EDUs { + EDUCountTotal.Inc() + switch e.Type { + case gomatrixserverlib.MTyping: + // https://matrix.org/docs/spec/server_server/latest#typing-notifications + var typingPayload struct { + RoomID string `json:"room_id"` + UserID string `json:"user_id"` + Typing bool `json:"typing"` + } + if err := json.Unmarshal(e.Content, &typingPayload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event") + continue + } + if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil { + util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream") + } + case gomatrixserverlib.MDirectToDevice: + // https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema + var directPayload gomatrixserverlib.ToDeviceMessage + if err := json.Unmarshal(e.Content, &directPayload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events") + continue + } + if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + for userID, byUser := range directPayload.Messages { + for deviceID, message := range byUser { + // TODO: check that the user and the device actually exist here + if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil { + sentry.CaptureException(err) + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": directPayload.Sender, + "user_id": userID, + "device_id": deviceID, + }).Error("Failed to send send-to-device event to JetStream") + } + } + } + case gomatrixserverlib.MDeviceListUpdate: + if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil { + sentry.CaptureException(err) + util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate") + } + case gomatrixserverlib.MReceipt: + // https://matrix.org/docs/spec/server_server/r0.1.4#receipts + payload := map[string]types.FederationReceiptMRead{} + + if err := json.Unmarshal(e.Content, &payload); err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event") + continue + } + + for roomID, receipt := range payload { + for userID, mread := range receipt.User { + _, domain, err := gomatrixserverlib.SplitID('@', userID) + if err != nil { + util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender") + continue + } + if t.Origin != domain { + util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin) + continue + } + if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil { + util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{ + "sender": t.Origin, + "user_id": userID, + "room_id": roomID, + "events": mread.EventIDs, + }).Error("Failed to send receipt event to JetStream") + continue + } + } + } + case types.MSigningKeyUpdate: + if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil { + sentry.CaptureException(err) + logrus.WithError(err).Errorf("Failed to process signing key update") + } + case gomatrixserverlib.MPresence: + if t.inboundPresenceEnabled { + if err := t.processPresence(ctx, e); err != nil { + logrus.WithError(err).Errorf("Failed to process presence update") + } + } + default: + util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU") + } + } +} + +// processPresence handles m.receipt events +func (t *TxnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error { + payload := types.Presence{} + if err := json.Unmarshal(e.Content, &payload); err != nil { + return err + } + for _, content := range payload.Push { + if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil { + continue + } else if serverName == t.ourServerName { + continue + } else if serverName != t.Origin { + continue + } + presence, ok := syncTypes.PresenceFromString(content.Presence) + if !ok { + continue + } + if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil { + return err + } + } + return nil +} + +// processReceiptEvent sends receipt events to JetStream +func (t *TxnReq) processReceiptEvent(ctx context.Context, + userID, roomID, receiptType string, + timestamp gomatrixserverlib.Timestamp, + eventIDs []string, +) error { + if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil { + return nil + } else if serverName == t.ourServerName { + return nil + } else if serverName != t.Origin { + return nil + } + // store every event + for _, eventID := range eventIDs { + if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil { + return fmt.Errorf("unable to set receipt event: %w", err) + } + } + + return nil +} diff --git a/internal/transactionrequest_test.go b/internal/transactionrequest_test.go new file mode 100644 index 000000000..dd1bd3502 --- /dev/null +++ b/internal/transactionrequest_test.go @@ -0,0 +1,820 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "testing" + "time" + + "github.com/matrix-org/dendrite/federationapi/producers" + keyAPI "github.com/matrix-org/dendrite/keyserver/api" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/setup/process" + "github.com/matrix-org/dendrite/syncapi/types" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/nats-io/nats.go" + "github.com/stretchr/testify/assert" + "go.uber.org/atomic" + "gotest.tools/v3/poll" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") + testDestination = gomatrixserverlib.ServerName("white.orchard") +) + +var ( + invalidSignatures = json.RawMessage(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localishhost","sender":"@userid:localhost","signatures":{"localhost":{"ed2559:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiaQiWAQ"}},"type":"m.room.member"}`) + testData = []json.RawMessage{ + []byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`), + // messages + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`), + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`), + } + testEvent = []byte(`{"auth_events":["$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg","$BcEcbZnlFLB5rxSNSZNBn6fO3jU/TKAJ79wfKyCQLiU"],"content":{"body":"Test Message"},"depth":3917,"hashes":{"sha256":"cNAWtlHIegrji0mMA6x1rhpYCccY8W1NsWZqSpJFhjs"},"origin":"localhost","origin_server_ts":0,"prev_events":["$4GDB0bVjkWwS3G4noUZCq5oLWzpBYpwzdMcf7gj24CI"],"room_id":"!roomid:localhost","sender":"@userid:localhost","signatures":{"localhost":{"ed25519:auto":"NKym6Kcy3u9mGUr21Hjfe3h7DfDilDhN5PqztT0QZ4NTZ+8Y7owseLolQVXp+TvNjecvzdDywsXXVvGiuQiWAQ"}},"type":"m.room.message"}`) + testRoomVersion = gomatrixserverlib.RoomVersionV1 + testEvents = []*gomatrixserverlib.HeaderedEvent{} + testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent) +) + +type FakeRsAPI struct { + rsAPI.RoomserverInternalAPI + shouldFailQuery bool + bannedFromRoom bool + shouldEventsFail bool +} + +func (r *FakeRsAPI) QueryRoomVersionForRoom( + ctx context.Context, + req *rsAPI.QueryRoomVersionForRoomRequest, + res *rsAPI.QueryRoomVersionForRoomResponse, +) error { + if r.shouldFailQuery { + return fmt.Errorf("Failure") + } + res.RoomVersion = gomatrixserverlib.RoomVersionV10 + return nil +} + +func (r *FakeRsAPI) QueryServerBannedFromRoom( + ctx context.Context, + req *rsAPI.QueryServerBannedFromRoomRequest, + res *rsAPI.QueryServerBannedFromRoomResponse, +) error { + if r.bannedFromRoom { + res.Banned = true + } else { + res.Banned = false + } + return nil +} + +func (r *FakeRsAPI) InputRoomEvents( + ctx context.Context, + req *rsAPI.InputRoomEventsRequest, + res *rsAPI.InputRoomEventsResponse, +) error { + if r.shouldEventsFail { + return fmt.Errorf("Failure") + } + return nil +} + +func TestEmptyTransactionRequest(t *testing.T) { + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", nil, nil, nil, false, []json.RawMessage{}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func TestProcessTransactionRequestPDU(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUs(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, append(testData, testEvent), []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestBadPDU(t *testing.T) { + pdu := json.RawMessage("{\"room_id\":\"asdf\"}") + pdu2 := json.RawMessage("\"roomid\":\"asdf\"") + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{pdu, pdu2, testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.Empty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUQueryFailure(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{shouldFailQuery: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func TestProcessTransactionRequestPDUBannedFromRoom(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{bannedFromRoom: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUInvalidSignature(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{invalidSignatures}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func TestProcessTransactionRequestPDUSendFail(t *testing.T) { + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{shouldEventsFail: true}, nil, "ourserver", keyRing, nil, nil, false, []json.RawMessage{testEvent}, []gomatrixserverlib.EDU{}, "", "", "") + txnRes, jsonRes := txn.ProcessTransaction(context.Background()) + + assert.Nil(t, jsonRes) + assert.Equal(t, 1, len(txnRes.PDUs)) + for _, result := range txnRes.PDUs { + assert.NotEmpty(t, result.Error) + } +} + +func createTransactionWithEDU(ctx *process.ProcessContext, edus []gomatrixserverlib.EDU) (TxnReq, nats.JetStreamContext, *config.Dendrite) { + cfg := &config.Dendrite{} + cfg.Defaults(config.DefaultOpts{ + Generate: true, + Monolithic: true, + }) + cfg.Global.JetStream.InMemory = true + natsInstance := &jetstream.NATSInstance{} + js, _ := natsInstance.Prepare(ctx, &cfg.Global.JetStream) + producer := &producers.SyncAPIProducer{ + JetStream: js, + TopicReceiptEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + TopicSendToDeviceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + TopicTypingEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + TopicPresenceEvent: cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + TopicDeviceListUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + TopicSigningKeyUpdate: cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + Config: &cfg.FederationAPI, + UserAPI: nil, + } + keyRing := &test.NopJSONVerifier{} + txn := NewTxnReq(&FakeRsAPI{}, nil, "ourserver", keyRing, nil, producer, true, []json.RawMessage{}, edus, "kaer.morhen", "", "ourserver") + return txn, js, cfg +} + +func TestProcessTransactionRequestEDUTyping(t *testing.T) { + var err error + roomID := "!roomid:kaer.morhen" + userID := "@userid:kaer.morhen" + typing := true + edu := gomatrixserverlib.EDU{Type: "m.typing"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "room_id": roomID, + "user_id": userID, + "typing": typing, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.typing"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + room := msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, room) + user := msg.Header.Get(jetstream.UserID) + assert.Equal(t, userID, user) + typ, parseErr := strconv.ParseBool(msg.Header.Get("typing")) + if parseErr != nil { + return true + } + assert.Equal(t, typing, typ) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputTypingEvent), + cfg.Global.JetStream.Durable("TestTypingConsumer"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUToDevice(t *testing.T) { + var err error + sender := "@userid:kaer.morhen" + messageID := "$x4MKEPRSF6OGlo0qpnsP3BfSmYX5HhVlykOsQH3ECyg" + msgType := "m.dendrite.test" + edu := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "sender": sender, + "type": msgType, + "message_id": messageID, + "messages": map[string]interface{}{ + "@alice:example.org": map[string]interface{}{ + "IWHQUZUIAH": map[string]interface{}{ + "algorithm": "m.megolm.v1.aes-sha2", + "room_id": "!Cuyf34gef24t:localhost", + "session_id": "X3lUlvLELLYxeTx4yOVu6UDpasGEVO0Jbu+QFnm0cKQ", + "session_key": "AgAAAADxKHa9uFxcXzwYoNueL5Xqi69IkD4sni8LlfJL7qNBEY...", + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.direct_to_device"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputSendToDeviceEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, sender, output.Sender) + assert.Equal(t, msgType, output.Type) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputSendToDeviceEvent), + cfg.Global.JetStream.Durable("TestToDevice"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUDeviceListUpdate(t *testing.T) { + var err error + deviceID := "QBUAZIFURK" + userID := "@john:example.com" + edu := gomatrixserverlib.EDU{Type: "m.device_list_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "device_display_name": "Mobile", + "device_id": deviceID, + "key": "value", + "keys": map[string]interface{}{ + "algorithms": []string{ + "m.olm.v1.curve25519-aes-sha2", + "m.megolm.v1.aes-sha2", + }, + "device_id": "JLAFKJWSCS", + "keys": map[string]interface{}{ + "curve25519:JLAFKJWSCS": "3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI", + "ed25519:JLAFKJWSCS": "lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI", + }, + "signatures": map[string]interface{}{ + "@alice:example.com": map[string]interface{}{ + "ed25519:JLAFKJWSCS": "dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA", + }, + }, + "user_id": "@alice:example.com", + }, + "prev_id": []int{ + 5, + }, + "stream_id": 6, + "user_id": userID, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.device_list_update"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output gomatrixserverlib.DeviceListUpdateEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + assert.Equal(t, userID, output.UserID) + assert.Equal(t, deviceID, output.DeviceID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputDeviceListUpdate), + cfg.Global.JetStream.Durable("TestDeviceListUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUReceipt(t *testing.T) { + var err error + roomID := "!some_room:example.org" + edu := gomatrixserverlib.EDU{Type: "m.receipt"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:kaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.receipt"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + badUser := gomatrixserverlib.EDU{Type: "m.receipt"} + if badUser.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "johnkaer.morhen": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badDomain := gomatrixserverlib.EDU{Type: "m.receipt"} + if badDomain.Content, err = json.Marshal(map[string]interface{}{ + roomID: map[string]interface{}{ + "m.read": map[string]interface{}{ + "@john:bad.domain": map[string]interface{}{ + "data": map[string]interface{}{ + "ts": 1533358089009, + }, + "event_ids": []string{ + "$read_this_event:matrix.org", + }, + }, + }, + }, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + edus := []gomatrixserverlib.EDU{badEDU, badUser, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output types.OutputReceiptEvent + output.RoomID = msg.Header.Get(jetstream.RoomID) + assert.Equal(t, roomID, output.RoomID) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputReceiptEvent), + cfg.Global.JetStream.Durable("TestReceipt"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUSigningKeyUpdate(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.signing_key_update"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + var output keyAPI.CrossSigningKeyUpdate + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + println(err.Error()) + return true + } + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.InputSigningKeyUpdate), + cfg.Global.JetStream.Durable("TestSigningKeyUpdate"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUPresence(t *testing.T) { + var err error + userID := "@john:kaer.morhen" + presence := "online" + edu := gomatrixserverlib.EDU{Type: "m.presence"} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "push": []map[string]interface{}{{ + "currently_active": true, + "last_active_ago": 5000, + "presence": presence, + "status_msg": "Making cupcakes", + "user_id": userID, + }}, + }); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + badEDU := gomatrixserverlib.EDU{Type: "m.presence"} + badEDU.Content = gomatrixserverlib.RawJSON("badjson") + edus := []gomatrixserverlib.EDU{badEDU, edu} + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, js, cfg := createTransactionWithEDU(ctx, edus) + received := atomic.NewBool(false) + onMessage := func(ctx context.Context, msgs []*nats.Msg) bool { + msg := msgs[0] // Guaranteed to exist if onMessage is called + + userIDRes := msg.Header.Get(jetstream.UserID) + presenceRes := msg.Header.Get("presence") + assert.Equal(t, userID, userIDRes) + assert.Equal(t, presence, presenceRes) + + received.Store(true) + return true + } + err = jetstream.JetStreamConsumer( + ctx.Context(), js, cfg.Global.JetStream.Prefixed(jetstream.OutputPresenceEvent), + cfg.Global.JetStream.Durable("TestPresence"), 1, + onMessage, nats.DeliverAll(), nats.ManualAck(), + ) + assert.Nil(t, err) + + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) + + check := func(log poll.LogT) poll.Result { + if received.Load() { + return poll.Success() + } + return poll.Continue("waiting for events to be processed") + } + poll.WaitOn(t, check, poll.WithTimeout(2*time.Second), poll.WithDelay(10*time.Millisecond)) +} + +func TestProcessTransactionRequestEDUUnhandled(t *testing.T) { + var err error + edu := gomatrixserverlib.EDU{Type: "m.unhandled"} + if edu.Content, err = json.Marshal(map[string]interface{}{}); err != nil { + t.Errorf("failed to marshal EDU JSON") + } + + ctx := process.NewProcessContext() + defer ctx.ShutdownDendrite() + txn, _, _ := createTransactionWithEDU(ctx, []gomatrixserverlib.EDU{edu}) + txnRes, jsonRes := txn.ProcessTransaction(ctx.Context()) + + assert.Nil(t, jsonRes) + assert.Zero(t, len(txnRes.PDUs)) +} + +func init() { + for _, j := range testData { + e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion) + if err != nil { + panic("cannot load test data: " + err.Error()) + } + h := e.Headered(testRoomVersion) + testEvents = append(testEvents, h) + if e.StateKey() != nil { + testStateEvents[gomatrixserverlib.StateKeyTuple{ + EventType: e.Type(), + StateKey: *e.StateKey(), + }] = h + } + } +} + +type testRoomserverAPI struct { + rsAPI.RoomserverInternalAPITrace + inputRoomEvents []rsAPI.InputRoomEvent + queryStateAfterEvents func(*rsAPI.QueryStateAfterEventsRequest) rsAPI.QueryStateAfterEventsResponse + queryEventsByID func(req *rsAPI.QueryEventsByIDRequest) rsAPI.QueryEventsByIDResponse + queryLatestEventsAndState func(*rsAPI.QueryLatestEventsAndStateRequest) rsAPI.QueryLatestEventsAndStateResponse +} + +func (t *testRoomserverAPI) InputRoomEvents( + ctx context.Context, + request *rsAPI.InputRoomEventsRequest, + response *rsAPI.InputRoomEventsResponse, +) error { + t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...) + for _, ire := range request.InputRoomEvents { + fmt.Println("InputRoomEvents: ", ire.Event.EventID()) + } + return nil +} + +// Query the latest events and state for a room from the room server. +func (t *testRoomserverAPI) QueryLatestEventsAndState( + ctx context.Context, + request *rsAPI.QueryLatestEventsAndStateRequest, + response *rsAPI.QueryLatestEventsAndStateResponse, +) error { + r := t.queryLatestEventsAndState(request) + response.RoomExists = r.RoomExists + response.RoomVersion = testRoomVersion + response.LatestEvents = r.LatestEvents + response.StateEvents = r.StateEvents + response.Depth = r.Depth + return nil +} + +// Query the state after a list of events in a room from the room server. +func (t *testRoomserverAPI) QueryStateAfterEvents( + ctx context.Context, + request *rsAPI.QueryStateAfterEventsRequest, + response *rsAPI.QueryStateAfterEventsResponse, +) error { + response.RoomVersion = testRoomVersion + res := t.queryStateAfterEvents(request) + response.PrevEventsExist = res.PrevEventsExist + response.RoomExists = res.RoomExists + response.StateEvents = res.StateEvents + return nil +} + +// Query a list of events by event ID. +func (t *testRoomserverAPI) QueryEventsByID( + ctx context.Context, + request *rsAPI.QueryEventsByIDRequest, + response *rsAPI.QueryEventsByIDResponse, +) error { + res := t.queryEventsByID(request) + response.Events = res.Events + return nil +} + +// Query if a server is joined to a room +func (t *testRoomserverAPI) QueryServerJoinedToRoom( + ctx context.Context, + request *rsAPI.QueryServerJoinedToRoomRequest, + response *rsAPI.QueryServerJoinedToRoomResponse, +) error { + response.RoomExists = true + response.IsInRoom = true + return nil +} + +// Asks for the room version for a given room. +func (t *testRoomserverAPI) QueryRoomVersionForRoom( + ctx context.Context, + request *rsAPI.QueryRoomVersionForRoomRequest, + response *rsAPI.QueryRoomVersionForRoomResponse, +) error { + response.RoomVersion = testRoomVersion + return nil +} + +func (t *testRoomserverAPI) QueryServerBannedFromRoom( + ctx context.Context, req *rsAPI.QueryServerBannedFromRoomRequest, res *rsAPI.QueryServerBannedFromRoomResponse, +) error { + res.Banned = false + return nil +} + +func mustCreateTransaction(rsAPI rsAPI.FederationRoomserverAPI, pdus []json.RawMessage) *TxnReq { + t := NewTxnReq( + rsAPI, + nil, + "", + &test.NopJSONVerifier{}, + NewMutexByRoom(), + nil, + false, + pdus, + nil, + testOrigin, + gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())), + testDestination) + t.PDUs = pdus + t.Origin = testOrigin + t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + t.Destination = testDestination + return &t +} + +func mustProcessTransaction(t *testing.T, txn *TxnReq, pdusWithErrors []string) { + res, err := txn.ProcessTransaction(context.Background()) + if err != nil { + t.Errorf("txn.processTransaction returned an error: %v", err) + return + } + if len(res.PDUs) != len(txn.PDUs) { + t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs)) + return + } +NextPDU: + for eventID, result := range res.PDUs { + if result.Error == "" { + continue + } + for _, eventIDWantError := range pdusWithErrors { + if eventID == eventIDWantError { + break NextPDU + } + } + t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error) + } +} + +func assertInputRoomEvents(t *testing.T, got []rsAPI.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) { + for _, g := range got { + fmt.Println("GOT ", g.Event.EventID()) + } + if len(got) != len(want) { + t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want)) + return + } + for i := range got { + if got[i].Event.EventID() != want[i].EventID() { + t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID()) + } + } +} + +// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on +// to the roomserver. It's the most basic test possible. +func TestBasicTransaction(t *testing.T) { + rsAPI := &testRoomserverAPI{} + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, pdus) + mustProcessTransaction(t, txn, nil) + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} + +// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver +// as it does the auth check. +func TestTransactionFailAuthChecks(t *testing.T) { + rsAPI := &testRoomserverAPI{} + pdus := []json.RawMessage{ + testData[len(testData)-1], // a message event + } + txn := mustCreateTransaction(rsAPI, pdus) + mustProcessTransaction(t, txn, []string{}) + // expect message to be sent to the roomserver + assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]}) +} diff --git a/mediaapi/routing/routing.go b/mediaapi/routing/routing.go index 9dcfa955f..50af2f884 100644 --- a/mediaapi/routing/routing.go +++ b/mediaapi/routing/routing.go @@ -108,13 +108,16 @@ func makeDownloadAPI( activeRemoteRequests *types.ActiveRemoteRequests, activeThumbnailGeneration *types.ActiveThumbnailGeneration, ) http.HandlerFunc { - counterVec := promauto.NewCounterVec( - prometheus.CounterOpts{ - Name: name, - Help: "Total number of media_api requests for either thumbnails or full downloads", - }, - []string{"code"}, - ) + var counterVec *prometheus.CounterVec + if cfg.Matrix.Metrics.Enabled { + counterVec = promauto.NewCounterVec( + prometheus.CounterOpts{ + Name: name, + Help: "Total number of media_api requests for either thumbnails or full downloads", + }, + []string{"code"}, + ) + } httpHandler := func(w http.ResponseWriter, req *http.Request) { req = util.RequestWithLogging(req) @@ -166,5 +169,12 @@ func makeDownloadAPI( vars["downloadName"], ) } - return promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + + var handlerFunc http.HandlerFunc + if counterVec != nil { + handlerFunc = promhttp.InstrumentHandlerCounter(counterVec, http.HandlerFunc(httpHandler)) + } else { + handlerFunc = http.HandlerFunc(httpHandler) + } + return handlerFunc } diff --git a/relayapi/api/api.go b/relayapi/api/api.go new file mode 100644 index 000000000..9db393225 --- /dev/null +++ b/relayapi/api/api.go @@ -0,0 +1,56 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayInternalAPI is used to query information from the relay server. +type RelayInternalAPI interface { + RelayServerAPI + + // Retrieve from external relay server all transactions stored for us and process them. + PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, + ) error +} + +// RelayServerAPI exposes the store & query transaction functionality of a relay server. +type RelayServerAPI interface { + // Store transactions for forwarding to the destination at a later time. + PerformStoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, + userID gomatrixserverlib.UserID, + ) error + + // Obtain the oldest stored transaction for the specified userID. + QueryTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + previousEntry gomatrixserverlib.RelayEntry, + ) (QueryRelayTransactionsResponse, error) +} + +type QueryRelayTransactionsResponse struct { + Transaction gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id"` + EntriesQueued bool `json:"entries_queued"` +} diff --git a/relayapi/internal/api.go b/relayapi/internal/api.go new file mode 100644 index 000000000..3ff8c2add --- /dev/null +++ b/relayapi/internal/api.go @@ -0,0 +1,53 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + fedAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/gomatrixserverlib" +) + +type RelayInternalAPI struct { + db storage.Database + fedClient fedAPI.FederationClient + rsAPI rsAPI.RoomserverInternalAPI + keyRing *gomatrixserverlib.KeyRing + producer *producers.SyncAPIProducer + presenceEnabledInbound bool + serverName gomatrixserverlib.ServerName +} + +func NewRelayInternalAPI( + db storage.Database, + fedClient fedAPI.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, + presenceEnabledInbound bool, + serverName gomatrixserverlib.ServerName, +) *RelayInternalAPI { + return &RelayInternalAPI{ + db: db, + fedClient: fedClient, + rsAPI: rsAPI, + keyRing: keyRing, + producer: producer, + presenceEnabledInbound: presenceEnabledInbound, + serverName: serverName, + } +} diff --git a/relayapi/internal/perform.go b/relayapi/internal/perform.go new file mode 100644 index 000000000..594299334 --- /dev/null +++ b/relayapi/internal/perform.go @@ -0,0 +1,141 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// PerformRelayServerSync implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformRelayServerSync( + ctx context.Context, + userID gomatrixserverlib.UserID, + relayServer gomatrixserverlib.ServerName, +) error { + // Providing a default RelayEntry (EntryID = 0) is done to ask the relay if there are any + // transactions available for this node. + prevEntry := gomatrixserverlib.RelayEntry{} + asyncResponse, err := r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) + if err != nil { + logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Txn) + + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + for asyncResponse.EntriesQueued { + // There are still more entries available for this node from the relay. + logrus.Infof("Retrieving next entry from relay, previous: %v", prevEntry) + asyncResponse, err = r.fedClient.P2PGetTransactionFromRelay(ctx, userID, prevEntry, relayServer) + prevEntry = gomatrixserverlib.RelayEntry{EntryID: asyncResponse.EntryID} + if err != nil { + logrus.Errorf("P2PGetTransactionFromRelay: %s", err.Error()) + return err + } + r.processTransaction(&asyncResponse.Txn) + } + + return nil +} + +// PerformStoreTransaction implements api.RelayInternalAPI +func (r *RelayInternalAPI) PerformStoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, + userID gomatrixserverlib.UserID, +) error { + logrus.Warnf("Storing transaction for %v", userID) + receipt, err := r.db.StoreTransaction(ctx, transaction) + if err != nil { + logrus.Errorf("db.StoreTransaction: %s", err.Error()) + return err + } + err = r.db.AssociateTransactionWithDestinations( + ctx, + map[gomatrixserverlib.UserID]struct{}{ + userID: {}, + }, + transaction.TransactionID, + receipt) + + return err +} + +// QueryTransactions implements api.RelayInternalAPI +func (r *RelayInternalAPI) QueryTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + previousEntry gomatrixserverlib.RelayEntry, +) (api.QueryRelayTransactionsResponse, error) { + logrus.Infof("QueryTransactions for %s", userID.Raw()) + if previousEntry.EntryID > 0 { + logrus.Infof("Cleaning previous entry (%v) from db for %s", + previousEntry.EntryID, + userID.Raw(), + ) + prevReceipt := receipt.NewReceipt(previousEntry.EntryID) + err := r.db.CleanTransactions(ctx, userID, []*receipt.Receipt{&prevReceipt}) + if err != nil { + logrus.Errorf("db.CleanTransactions: %s", err.Error()) + return api.QueryRelayTransactionsResponse{}, err + } + } + + transaction, receipt, err := r.db.GetTransaction(ctx, userID) + if err != nil { + logrus.Errorf("db.GetTransaction: %s", err.Error()) + return api.QueryRelayTransactionsResponse{}, err + } + + response := api.QueryRelayTransactionsResponse{} + if transaction != nil && receipt != nil { + logrus.Infof("Obtained transaction (%v) for %s", transaction.TransactionID, userID.Raw()) + response.Transaction = *transaction + response.EntryID = receipt.GetNID() + response.EntriesQueued = true + } else { + logrus.Infof("No more entries in the queue for %s", userID.Raw()) + response.EntryID = 0 + response.EntriesQueued = false + } + + return response, nil +} + +func (r *RelayInternalAPI) processTransaction(txn *gomatrixserverlib.Transaction) { + logrus.Warn("Processing transaction from relay server") + mu := internal.NewMutexByRoom() + t := internal.NewTxnReq( + r.rsAPI, + nil, + r.serverName, + r.keyRing, + mu, + r.producer, + r.presenceEnabledInbound, + txn.PDUs, + txn.EDUs, + txn.Origin, + txn.TransactionID, + txn.Destination) + + t.ProcessTransaction(context.TODO()) +} diff --git a/relayapi/internal/perform_test.go b/relayapi/internal/perform_test.go new file mode 100644 index 000000000..fb71b7d0e --- /dev/null +++ b/relayapi/internal/perform_test.go @@ -0,0 +1,121 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "fmt" + "testing" + + fedAPI "github.com/matrix-org/dendrite/federationapi/api" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type testFedClient struct { + fedAPI.FederationClient + shouldFail bool + queryCount uint + queueDepth uint +} + +func (f *testFedClient) P2PGetTransactionFromRelay( + ctx context.Context, + u gomatrixserverlib.UserID, + prev gomatrixserverlib.RelayEntry, + relayServer gomatrixserverlib.ServerName, +) (res gomatrixserverlib.RespGetRelayTransaction, err error) { + f.queryCount++ + if f.shouldFail { + return res, fmt.Errorf("Error") + } + + res = gomatrixserverlib.RespGetRelayTransaction{ + Txn: gomatrixserverlib.Transaction{}, + EntryID: 0, + } + if f.queueDepth > 0 { + res.EntriesQueued = true + } else { + res.EntriesQueued = false + } + f.queueDepth -= 1 + + return +} + +func TestPerformRelayServerSync(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.NoError(t, err) +} + +func TestPerformRelayServerSyncFedError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{shouldFail: true} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.Error(t, err) +} + +func TestPerformRelayServerSyncRunsUntilQueueEmpty(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.Nil(t, err, "Invalid userID") + + fedClient := &testFedClient{queueDepth: 2} + relayAPI := NewRelayInternalAPI( + &db, fedClient, nil, nil, nil, false, "", + ) + + err = relayAPI.PerformRelayServerSync(context.Background(), *userID, gomatrixserverlib.ServerName("relay")) + assert.NoError(t, err) + assert.Equal(t, uint(3), fedClient.queryCount) +} diff --git a/relayapi/relayapi.go b/relayapi/relayapi.go new file mode 100644 index 000000000..f9f9d4ff9 --- /dev/null +++ b/relayapi/relayapi.go @@ -0,0 +1,74 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relayapi + +import ( + "github.com/matrix-org/dendrite/federationapi/producers" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage" + rsAPI "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" +) + +// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component. +func AddPublicRoutes( + base *base.BaseDendrite, + keyRing gomatrixserverlib.JSONVerifier, + relayAPI api.RelayInternalAPI, +) { + fedCfg := &base.Cfg.FederationAPI + + relay, ok := relayAPI.(*internal.RelayInternalAPI) + if !ok { + panic("relayapi.AddPublicRoutes called with a RelayInternalAPI impl which was not " + + "RelayInternalAPI. This is a programming error.") + } + + routing.Setup( + base.PublicFederationAPIMux, + fedCfg, + relay, + keyRing, + ) +} + +func NewRelayInternalAPI( + base *base.BaseDendrite, + fedClient *gomatrixserverlib.FederationClient, + rsAPI rsAPI.RoomserverInternalAPI, + keyRing *gomatrixserverlib.KeyRing, + producer *producers.SyncAPIProducer, +) api.RelayInternalAPI { + cfg := &base.Cfg.RelayAPI + + relayDB, err := storage.NewDatabase(base, &cfg.Database, base.Caches, base.Cfg.Global.IsLocalServerName) + if err != nil { + logrus.WithError(err).Panic("failed to connect to relay db") + } + + return internal.NewRelayInternalAPI( + relayDB, + fedClient, + rsAPI, + keyRing, + producer, + base.Cfg.Global.Presence.EnableInbound, + base.Cfg.Global.ServerName, + ) +} diff --git a/relayapi/relayapi_test.go b/relayapi/relayapi_test.go new file mode 100644 index 000000000..dfa06811d --- /dev/null +++ b/relayapi/relayapi_test.go @@ -0,0 +1,154 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package relayapi_test + +import ( + "crypto/ed25519" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing" + "github.com/matrix-org/dendrite/relayapi" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/dendrite/test/testrig" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func TestCreateNewRelayInternalAPI(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + assert.NotNil(t, relayAPI) + }) +} + +func TestCreateRelayInternalInvalidDatabasePanics(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + if dbType == test.DBTypeSQLite { + base.Cfg.RelayAPI.Database.ConnectionString = "file:" + } else { + base.Cfg.RelayAPI.Database.ConnectionString = "test" + } + defer close() + + assert.Panics(t, func() { + relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + }) + }) +} + +func TestCreateInvalidRelayPublicRoutesPanics(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + assert.Panics(t, func() { + relayapi.AddPublicRoutes(base, nil, nil) + }) + }) +} + +func createGetRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, userID string) *http.Request { + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("GET", origin, serverName, "/_matrix/federation/v1/relay_txn/"+userID) + content := gomatrixserverlib.RelayEntry{EntryID: 0} + req.SetContent(content) + req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) + httpreq, _ := req.HTTPRequest() + vars := map[string]string{"userID": userID} + httpreq = mux.SetURLVars(httpreq, vars) + return httpreq +} + +type sendRelayContent struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` +} + +func createSendRelayTxnHTTPRequest(serverName gomatrixserverlib.ServerName, txnID string, userID string) *http.Request { + _, sk, _ := ed25519.GenerateKey(nil) + keyID := signing.KeyID + pk := sk.Public().(ed25519.PublicKey) + origin := gomatrixserverlib.ServerName(hex.EncodeToString(pk)) + req := gomatrixserverlib.NewFederationRequest("PUT", origin, serverName, "/_matrix/federation/v1/send_relay/"+txnID+"/"+userID) + content := sendRelayContent{} + req.SetContent(content) + req.Sign(origin, gomatrixserverlib.KeyID(keyID), sk) + httpreq, _ := req.HTTPRequest() + vars := map[string]string{"userID": userID, "txnID": txnID} + httpreq = mux.SetURLVars(httpreq, vars) + return httpreq +} + +func TestCreateRelayPublicRoutes(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + base, close := testrig.CreateBaseDendrite(t, dbType) + defer close() + + relayAPI := relayapi.NewRelayInternalAPI(base, nil, nil, nil, nil) + assert.NotNil(t, relayAPI) + + serverKeyAPI := &signing.YggdrasilKeys{} + keyRing := serverKeyAPI.KeyRing() + relayapi.AddPublicRoutes(base, keyRing, relayAPI) + + testCases := []struct { + name string + req *http.Request + wantCode int + wantJoinedRooms []string + }{ + { + name: "relay_txn invalid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "user:local"), + wantCode: 400, + }, + { + name: "relay_txn valid user id", + req: createGetRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "@user:local"), + wantCode: 200, + }, + { + name: "send_relay invalid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "user:local"), + wantCode: 400, + }, + { + name: "send_relay valid user id", + req: createSendRelayTxnHTTPRequest(base.Cfg.Global.ServerName, "123", "@user:local"), + wantCode: 200, + }, + } + + for _, tc := range testCases { + w := httptest.NewRecorder() + base.PublicFederationAPIMux.ServeHTTP(w, tc.req) + if w.Code != tc.wantCode { + t.Fatalf("%s: got HTTP %d want %d", tc.name, w.Code, tc.wantCode) + } + } + }) +} diff --git a/relayapi/routing/relaytxn.go b/relayapi/routing/relaytxn.go new file mode 100644 index 000000000..1b11b0ecd --- /dev/null +++ b/relayapi/routing/relaytxn.go @@ -0,0 +1,74 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +type RelayTransactionResponse struct { + Transaction gomatrixserverlib.Transaction `json:"transaction"` + EntryID int64 `json:"entry_id,omitempty"` + EntriesQueued bool `json:"entries_queued"` +} + +// GetTransactionFromRelay implements /_matrix/federation/v1/relay_txn/{userID} +// This endpoint can be extracted into a separate relay server service. +func GetTransactionFromRelay( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + relayAPI api.RelayInternalAPI, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + logrus.Infof("Handling relay_txn for %s", userID.Raw()) + + previousEntry := gomatrixserverlib.RelayEntry{} + if err := json.Unmarshal(fedReq.Content(), &previousEntry); err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("invalid json provided"), + } + } + if previousEntry.EntryID < 0 { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("Invalid entry id provided. Must be >= 0."), + } + } + logrus.Infof("Previous entry provided: %v", previousEntry.EntryID) + + response, err := relayAPI.QueryTransactions(httpReq.Context(), userID, previousEntry) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + } + } + + return util.JSONResponse{ + Code: http.StatusOK, + JSON: RelayTransactionResponse{ + Transaction: response.Transaction, + EntryID: response.EntryID, + EntriesQueued: response.EntriesQueued, + }, + } +} diff --git a/relayapi/routing/relaytxn_test.go b/relayapi/routing/relaytxn_test.go new file mode 100644 index 000000000..a47fdb198 --- /dev/null +++ b/relayapi/routing/relaytxn_test.go @@ -0,0 +1,220 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +func createQuery( + userID gomatrixserverlib.UserID, + prevEntry gomatrixserverlib.RelayEntry, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/relay_txn/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("GET", userID.Domain(), "relay", path) + request.SetContent(prevEntry) + + return request +} + +func TestGetEmptyDatabaseReturnsNothing(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.Equal(t, false, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} + +func TestGetInvalidPrevEntryFails(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + + _, err = db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: -1}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusInternalServerError, response.Code) +} + +func TestGetReturnsSavedTransaction(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + receipt, err := db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + assert.NoError(t, err, "Failed to associate transaction with user") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Transaction) + + // And once more to clear the queue + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.False(t, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} + +func TestGetReturnsMultipleSavedTransactions(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + transaction := createTransaction() + receipt, err := db.StoreTransaction(context.Background(), transaction) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction.TransactionID, + receipt) + assert.NoError(t, err, "Failed to associate transaction with user") + + transaction2 := createTransaction() + receipt2, err := db.StoreTransaction(context.Background(), transaction2) + assert.NoError(t, err, "Failed to store transaction") + + err = db.AssociateTransactionWithDestinations( + context.Background(), + map[gomatrixserverlib.UserID]struct{}{ + *userID: {}, + }, + transaction2.TransactionID, + receipt2) + assert.NoError(t, err, "Failed to associate transaction with user") + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + request := createQuery(*userID, gomatrixserverlib.RelayEntry{}) + response := routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse := response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction, jsonResponse.Transaction) + + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.True(t, jsonResponse.EntriesQueued) + assert.Equal(t, transaction2, jsonResponse.Transaction) + + // And once more to clear the queue + request = createQuery(*userID, gomatrixserverlib.RelayEntry{EntryID: jsonResponse.EntryID}) + response = routing.GetTransactionFromRelay(httpReq, &request, relayAPI, *userID) + assert.Equal(t, http.StatusOK, response.Code) + + jsonResponse = response.JSON.(routing.RelayTransactionResponse) + assert.False(t, jsonResponse.EntriesQueued) + assert.Equal(t, gomatrixserverlib.Transaction{}, jsonResponse.Transaction) + + count, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err) + assert.Zero(t, count) +} diff --git a/relayapi/routing/routing.go b/relayapi/routing/routing.go new file mode 100644 index 000000000..6df0cdc5f --- /dev/null +++ b/relayapi/routing/routing.go @@ -0,0 +1,123 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "fmt" + "net/http" + "time" + + "github.com/getsentry/sentry-go" + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/internal/httputil" + relayInternal "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" +) + +// Setup registers HTTP handlers with the given ServeMux. +// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly +// path unescape twice (once from the router, once from MakeRelayAPI). We need to have this enabled +// so we can decode paths like foo/bar%2Fbaz as [foo, bar/baz] - by default it will decode to [foo, bar, baz] +// +// Due to Setup being used to call many other functions, a gocyclo nolint is +// applied: +// nolint: gocyclo +func Setup( + fedMux *mux.Router, + cfg *config.FederationAPI, + relayAPI *relayInternal.RelayInternalAPI, + keys gomatrixserverlib.JSONVerifier, +) { + v1fedmux := fedMux.PathPrefix("/v1").Subrouter() + + v1fedmux.Handle("/send_relay/{txnID}/{userID}", MakeRelayAPI( + "send_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return SendTransactionToRelay( + httpReq, request, relayAPI, gomatrixserverlib.TransactionID(vars["txnID"]), + *userID, + ) + }, + )).Methods(http.MethodPut, http.MethodOptions) + + v1fedmux.Handle("/relay_txn/{userID}", MakeRelayAPI( + "get_relay_transaction", "", cfg.Matrix.IsLocalServerName, keys, + func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest, vars map[string]string) util.JSONResponse { + userID, err := gomatrixserverlib.NewUserID(vars["userID"], false) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidUsername("Username was invalid"), + } + } + return GetTransactionFromRelay(httpReq, request, relayAPI, *userID) + }, + )).Methods(http.MethodGet, http.MethodOptions) +} + +// MakeRelayAPI makes an http.Handler that checks matrix relay authentication. +func MakeRelayAPI( + metricsName string, serverName gomatrixserverlib.ServerName, + isLocalServerName func(gomatrixserverlib.ServerName) bool, + keyRing gomatrixserverlib.JSONVerifier, + f func(*http.Request, *gomatrixserverlib.FederationRequest, map[string]string) util.JSONResponse, +) http.Handler { + h := func(req *http.Request) util.JSONResponse { + fedReq, errResp := gomatrixserverlib.VerifyHTTPRequest( + req, time.Now(), serverName, isLocalServerName, keyRing, + ) + if fedReq == nil { + return errResp + } + // add the user to Sentry, if enabled + hub := sentry.GetHubFromContext(req.Context()) + if hub != nil { + hub.Scope().SetTag("origin", string(fedReq.Origin())) + hub.Scope().SetTag("uri", fedReq.RequestURI()) + } + defer func() { + if r := recover(); r != nil { + if hub != nil { + hub.CaptureException(fmt.Errorf("%s panicked", req.URL.Path)) + } + // re-panic to return the 500 + panic(r) + } + }() + vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) + if err != nil { + return util.MatrixErrorResponse(400, "M_UNRECOGNISED", "badly encoded query params") + } + + jsonRes := f(req, fedReq, vars) + // do not log 4xx as errors as they are client fails, not server fails + if hub != nil && jsonRes.Code >= 500 { + hub.Scope().SetExtra("response", jsonRes) + hub.CaptureException(fmt.Errorf("%s returned HTTP %d", req.URL.Path, jsonRes.Code)) + } + return jsonRes + } + return httputil.MakeExternalAPI(metricsName, h) +} diff --git a/relayapi/routing/sendrelay.go b/relayapi/routing/sendrelay.go new file mode 100644 index 000000000..a7027f293 --- /dev/null +++ b/relayapi/routing/sendrelay.go @@ -0,0 +1,77 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "encoding/json" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/relayapi/api" + "github.com/matrix-org/gomatrixserverlib" + "github.com/matrix-org/util" + "github.com/sirupsen/logrus" +) + +// SendTransactionToRelay implements PUT /_matrix/federation/v1/relay_txn/{txnID}/{userID} +// This endpoint can be extracted into a separate relay server service. +func SendTransactionToRelay( + httpReq *http.Request, + fedReq *gomatrixserverlib.FederationRequest, + relayAPI api.RelayInternalAPI, + txnID gomatrixserverlib.TransactionID, + userID gomatrixserverlib.UserID, +) util.JSONResponse { + var txnEvents struct { + PDUs []json.RawMessage `json:"pdus"` + EDUs []gomatrixserverlib.EDU `json:"edus"` + } + + if err := json.Unmarshal(fedReq.Content(), &txnEvents); err != nil { + logrus.Info("The request body could not be decoded into valid JSON." + err.Error()) + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.NotJSON("The request body could not be decoded into valid JSON." + err.Error()), + } + } + + // Transactions are limited in size; they can have at most 50 PDUs and 100 EDUs. + // https://matrix.org/docs/spec/server_server/latest#transactions + if len(txnEvents.PDUs) > 50 || len(txnEvents.EDUs) > 100 { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("max 50 pdus / 100 edus"), + } + } + + t := gomatrixserverlib.Transaction{} + t.PDUs = txnEvents.PDUs + t.EDUs = txnEvents.EDUs + t.Origin = fedReq.Origin() + t.TransactionID = txnID + t.Destination = userID.Domain() + + util.GetLogger(httpReq.Context()).Warnf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, fedReq.Origin(), len(t.PDUs), len(t.EDUs)) + + err := relayAPI.PerformStoreTransaction(httpReq.Context(), t, userID) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: jsonerror.BadJSON("could not store the transaction for forwarding"), + } + } + + return util.JSONResponse{Code: 200} +} diff --git a/relayapi/routing/sendrelay_test.go b/relayapi/routing/sendrelay_test.go new file mode 100644 index 000000000..d9ed75002 --- /dev/null +++ b/relayapi/routing/sendrelay_test.go @@ -0,0 +1,209 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing_test + +import ( + "context" + "encoding/json" + "net/http" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/internal" + "github.com/matrix-org/dendrite/relayapi/routing" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") +) + +func createTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + return txn +} + +func createFederationRequest( + userID gomatrixserverlib.UserID, + txnID gomatrixserverlib.TransactionID, + origin gomatrixserverlib.ServerName, + destination gomatrixserverlib.ServerName, + content interface{}, +) gomatrixserverlib.FederationRequest { + var federationPathPrefixV1 = "/_matrix/federation/v1" + path := federationPathPrefixV1 + "/send_relay/" + string(txnID) + "/" + userID.Raw() + request := gomatrixserverlib.NewFederationRequest("PUT", origin, destination, path) + request.SetContent(content) + + return request +} + +func TestForwardEmptyReturnsOk(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.Equal(t, 200, response.Code) +} + +func TestForwardBadJSONReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field bool `json:"pdus"` + } + content := BadData{ + Field: false, + } + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyPDUsReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field []json.RawMessage `json:"pdus"` + } + content := BadData{ + Field: []json.RawMessage{}, + } + for i := 0; i < 51; i++ { + content.Field = append(content.Field, []byte{}) + } + assert.Greater(t, len(content.Field), 50) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestForwardTooManyEDUsReturnsError(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + type BadData struct { + Field []gomatrixserverlib.EDU `json:"edus"` + } + content := BadData{ + Field: []gomatrixserverlib.EDU{}, + } + for i := 0; i < 101; i++ { + content.Field = append(content.Field, gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}) + } + assert.Greater(t, len(content.Field), 100) + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, content) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay(httpReq, &request, relayAPI, "1", *userID) + + assert.NotEqual(t, 200, response.Code) +} + +func TestUniqueTransactionStoredInDatabase(t *testing.T) { + testDB := test.NewInMemoryRelayDatabase() + db := shared.Database{ + Writer: sqlutil.NewDummyWriter(), + RelayQueue: testDB, + RelayQueueJSON: testDB, + } + httpReq := &http.Request{} + userID, err := gomatrixserverlib.NewUserID("@local:domain", false) + assert.NoError(t, err, "Invalid userID") + + txn := createTransaction() + request := createFederationRequest(*userID, txn.TransactionID, txn.Origin, txn.Destination, txn) + + relayAPI := internal.NewRelayInternalAPI( + &db, nil, nil, nil, nil, false, "", + ) + + response := routing.SendTransactionToRelay( + httpReq, &request, relayAPI, txn.TransactionID, *userID) + transaction, _, err := db.GetTransaction(context.Background(), *userID) + assert.NoError(t, err, "Failed retrieving transaction") + + transactionCount, err := db.GetTransactionCount(context.Background(), *userID) + assert.NoError(t, err, "Failed retrieving transaction count") + + assert.Equal(t, 200, response.Code) + assert.Equal(t, int64(1), transactionCount) + assert.Equal(t, txn.TransactionID, transaction.TransactionID) +} diff --git a/relayapi/storage/interface.go b/relayapi/storage/interface.go new file mode 100644 index 000000000..f5f9a06e5 --- /dev/null +++ b/relayapi/storage/interface.go @@ -0,0 +1,47 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package storage + +import ( + "context" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database interface { + // Adds a new transaction to the queue json table. + // Adding a duplicate transaction will result in a new row being added and a new unique nid. + // return: unique nid representing this entry. + StoreTransaction(ctx context.Context, txn gomatrixserverlib.Transaction) (*receipt.Receipt, error) + + // Adds a new transaction_id: server_name mapping with associated json table nid to the queue + // entry table for each provided destination. + AssociateTransactionWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.UserID]struct{}, transactionID gomatrixserverlib.TransactionID, dbReceipt *receipt.Receipt) error + + // Removes every server_name: receipt pair provided from the queue entries table. + // Will then remove every entry for each receipt provided from the queue json table. + // If any of the entries don't exist in either table, nothing will happen for that entry and + // an error will not be generated. + CleanTransactions(ctx context.Context, userID gomatrixserverlib.UserID, receipts []*receipt.Receipt) error + + // Gets the oldest transaction for the provided server_name. + // If no transactions exist, returns nil and no error. + GetTransaction(ctx context.Context, userID gomatrixserverlib.UserID) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) + + // Gets the number of transactions being stored for the provided server_name. + // If the server doesn't exist in the database then 0 is returned with no error. + GetTransactionCount(ctx context.Context, userID gomatrixserverlib.UserID) (int64, error) +} diff --git a/relayapi/storage/postgres/relay_queue_json_table.go b/relayapi/storage/postgres/relay_queue_json_table.go new file mode 100644 index 000000000..74410fc88 --- /dev/null +++ b/relayapi/storage/postgres/relay_queue_json_table.go @@ -0,0 +1,113 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid BIGSERIAL, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); +` + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + + " VALUES ($1)" + + " RETURNING json_nid" + +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid = ANY($1)" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + + " WHERE json_nid = ANY($1)" + +type relayQueueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + deleteJSONStmt *sql.Stmt + selectJSONStmt *sql.Stmt +} + +func NewPostgresRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ + db: db, + } + _, err = s.db.Exec(relayQueueJSONSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertQueueJSONSQL}, + {&s.deleteJSONStmt, deleteQueueJSONSQL}, + {&s.selectJSONStmt, selectQueueJSONSQL}, + }.Prepare(db) +} + +func (s *relayQueueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (int64, error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + var lastid int64 + if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil { + return 0, err + } + return lastid, nil +} + +func (s *relayQueueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) + _, err := stmt.ExecContext(ctx, pq.Int64Array(nids)) + return err +} + +func (s *relayQueueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, s.selectJSONStmt) + rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, err + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/relayapi/storage/postgres/relay_queue_table.go b/relayapi/storage/postgres/relay_queue_table.go new file mode 100644 index 000000000..e97cf8cc0 --- /dev/null +++ b/relayapi/storage/postgres/relay_queue_table.go @@ -0,0 +1,156 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The destination server that we will send the event to. + server_name TEXT NOT NULL, + -- The JSON NID from the relayapi_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); +` + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid = ANY($2)" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + + " WHERE server_name = $1" + + " ORDER BY json_nid" + + " LIMIT $2" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + + " WHERE server_name = $1" + +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + deleteQueueEntriesStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt +} + +func NewPostgresRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ + db: db, + } + _, err = s.db.Exec(relayQueueSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertQueueEntryStmt, insertQueueEntrySQL}, + {&s.deleteQueueEntriesStmt, deleteQueueEntriesSQL}, + {&s.selectQueueEntriesStmt, selectQueueEntriesSQL}, + {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL}, + }.Prepare(db) +} + +func (s *relayQueueStatements) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *relayQueueStatements) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + stmt := sqlutil.TxStmt(txn, s.deleteQueueEntriesStmt) + _, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs)) + return err +} + +func (s *relayQueueStatements) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *relayQueueStatements) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/relayapi/storage/postgres/storage.go b/relayapi/storage/postgres/storage.go new file mode 100644 index 000000000..1042beba7 --- /dev/null +++ b/relayapi/storage/postgres/storage.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the relayapi +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil { + return nil, err + } + queue, err := NewPostgresRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewPostgresRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/shared/storage.go b/relayapi/storage/shared/storage.go new file mode 100644 index 000000000..0993707bf --- /dev/null +++ b/relayapi/storage/shared/storage.go @@ -0,0 +1,170 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package shared + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/gomatrixserverlib" +) + +type Database struct { + DB *sql.DB + IsLocalServerName func(gomatrixserverlib.ServerName) bool + Cache caching.FederationCache + Writer sqlutil.Writer + RelayQueue tables.RelayQueue + RelayQueueJSON tables.RelayQueueJSON +} + +func (d *Database) StoreTransaction( + ctx context.Context, + transaction gomatrixserverlib.Transaction, +) (*receipt.Receipt, error) { + var err error + jsonTransaction, err := json.Marshal(transaction) + if err != nil { + return nil, fmt.Errorf("failed to marshal: %w", err) + } + + var nid int64 + _ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + nid, err = d.RelayQueueJSON.InsertQueueJSON(ctx, txn, string(jsonTransaction)) + return err + }) + if err != nil { + return nil, fmt.Errorf("d.insertQueueJSON: %w", err) + } + + newReceipt := receipt.NewReceipt(nid) + return &newReceipt, nil +} + +func (d *Database) AssociateTransactionWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.UserID]struct{}, + transactionID gomatrixserverlib.TransactionID, + dbReceipt *receipt.Receipt, +) error { + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + var lastErr error + for destination := range destinations { + destination := destination + err := d.RelayQueue.InsertQueueEntry( + ctx, + txn, + transactionID, + destination.Domain(), + dbReceipt.GetNID(), + ) + if err != nil { + lastErr = fmt.Errorf("d.insertQueueEntry: %w", err) + } + } + return lastErr + }) + + return err +} + +func (d *Database) CleanTransactions( + ctx context.Context, + userID gomatrixserverlib.UserID, + receipts []*receipt.Receipt, +) error { + nids := make([]int64, len(receipts)) + for i, dbReceipt := range receipts { + nids[i] = dbReceipt.GetNID() + } + + err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + deleteEntryErr := d.RelayQueue.DeleteQueueEntries(ctx, txn, userID.Domain(), nids) + // TODO : If there are still queue entries for any of these nids for other destinations + // then we shouldn't delete the json entries. + // But this can't happen with the current api design. + // There will only ever be one server entry for each nid since each call to send_relay + // only accepts a single server name and inside there we create a new json entry. + // So for multiple destinations we would call send_relay multiple times and have multiple + // json entries of the same transaction. + // + // TLDR; this works as expected right now but can easily be optimised in the future. + deleteJSONErr := d.RelayQueueJSON.DeleteQueueJSON(ctx, txn, nids) + + if deleteEntryErr != nil { + return fmt.Errorf("d.deleteQueueEntries: %w", deleteEntryErr) + } + if deleteJSONErr != nil { + return fmt.Errorf("d.deleteQueueJSON: %w", deleteJSONErr) + } + return nil + }) + + return err +} + +func (d *Database) GetTransaction( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (*gomatrixserverlib.Transaction, *receipt.Receipt, error) { + entriesRequested := 1 + nids, err := d.RelayQueue.SelectQueueEntries(ctx, nil, userID.Domain(), entriesRequested) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueEntries: %w", err) + } + if len(nids) == 0 { + return nil, nil, nil + } + firstNID := nids[0] + + txns := map[int64][]byte{} + err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + txns, err = d.RelayQueueJSON.SelectQueueJSON(ctx, txn, nids) + return err + }) + if err != nil { + return nil, nil, fmt.Errorf("d.SelectQueueJSON: %w", err) + } + + transaction := &gomatrixserverlib.Transaction{} + if _, ok := txns[firstNID]; !ok { + return nil, nil, fmt.Errorf("Failed retrieving json blob for transaction: %d", firstNID) + } + + err = json.Unmarshal(txns[firstNID], transaction) + if err != nil { + return nil, nil, fmt.Errorf("Unmarshal transaction: %w", err) + } + + newReceipt := receipt.NewReceipt(firstNID) + return transaction, &newReceipt, nil +} + +func (d *Database) GetTransactionCount( + ctx context.Context, + userID gomatrixserverlib.UserID, +) (int64, error) { + count, err := d.RelayQueue.SelectQueueEntryCount(ctx, nil, userID.Domain()) + if err != nil { + return 0, fmt.Errorf("d.SelectQueueEntryCount: %w", err) + } + return count, nil +} diff --git a/relayapi/storage/sqlite3/relay_queue_json_table.go b/relayapi/storage/sqlite3/relay_queue_json_table.go new file mode 100644 index 000000000..502da3b00 --- /dev/null +++ b/relayapi/storage/sqlite3/relay_queue_json_table.go @@ -0,0 +1,137 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" +) + +const relayQueueJSONSchema = ` +-- The relayapi_queue_json table contains event contents that +-- we are storing for future forwarding. +CREATE TABLE IF NOT EXISTS relayapi_queue_json ( + -- The JSON NID. This allows cross-referencing to find the JSON blob. + json_nid INTEGER PRIMARY KEY AUTOINCREMENT, + -- The JSON body. Text so that we preserve UTF-8. + json_body TEXT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_json_json_nid_idx + ON relayapi_queue_json (json_nid); +` + +const insertQueueJSONSQL = "" + + "INSERT INTO relayapi_queue_json (json_body)" + + " VALUES ($1)" + +const deleteQueueJSONSQL = "" + + "DELETE FROM relayapi_queue_json WHERE json_nid IN ($1)" + +const selectQueueJSONSQL = "" + + "SELECT json_nid, json_body FROM relayapi_queue_json" + + " WHERE json_nid IN ($1)" + +type relayQueueJSONStatements struct { + db *sql.DB + insertJSONStmt *sql.Stmt + //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic + //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteRelayQueueJSONTable(db *sql.DB) (s *relayQueueJSONStatements, err error) { + s = &relayQueueJSONStatements{ + db: db, + } + _, err = db.Exec(relayQueueJSONSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertJSONStmt, insertQueueJSONSQL}, + }.Prepare(db) +} + +func (s *relayQueueJSONStatements) InsertQueueJSON( + ctx context.Context, txn *sql.Tx, json string, +) (lastid int64, err error) { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return 0, fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err = res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("res.LastInsertId: %w", err) + } + return +} + +func (s *relayQueueJSONStatements) DeleteQueueJSON( + ctx context.Context, txn *sql.Tx, nids []int64, +) error { + deleteSQL := strings.Replace(deleteQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(nids)) + for k, v := range nids { + iNIDs[k] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) + return err +} + +func (s *relayQueueJSONStatements) SelectQueueJSON( + ctx context.Context, txn *sql.Tx, jsonNIDs []int64, +) (map[int64][]byte, error) { + selectSQL := strings.Replace(selectQueueJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) + selectStmt, err := txn.Prepare(selectSQL) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(jsonNIDs)) + for k, v := range jsonNIDs { + iNIDs[k] = v + } + + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, selectStmt) + rows, err := stmt.QueryContext(ctx, iNIDs...) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "selectQueueJSON: rows.close() failed") + for rows.Next() { + var nid int64 + var blob []byte + if err = rows.Scan(&nid, &blob); err != nil { + return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err) + } + blobs[nid] = blob + } + return blobs, err +} diff --git a/relayapi/storage/sqlite3/relay_queue_table.go b/relayapi/storage/sqlite3/relay_queue_table.go new file mode 100644 index 000000000..49c6b4de5 --- /dev/null +++ b/relayapi/storage/sqlite3/relay_queue_table.go @@ -0,0 +1,168 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/gomatrixserverlib" +) + +const relayQueueSchema = ` +CREATE TABLE IF NOT EXISTS relayapi_queue ( + -- The transaction ID that was generated before persisting the event. + transaction_id TEXT NOT NULL, + -- The domain part of the user ID the m.room.member event is for. + server_name TEXT NOT NULL, + -- The JSON NID from the relayapi_queue_json table. + json_nid BIGINT NOT NULL +); + +CREATE UNIQUE INDEX IF NOT EXISTS relayapi_queue_queue_json_nid_idx + ON relayapi_queue (json_nid, server_name); +CREATE INDEX IF NOT EXISTS relayapi_queue_json_nid_idx + ON relayapi_queue (json_nid); +CREATE INDEX IF NOT EXISTS relayapi_queue_server_name_idx + ON relayapi_queue (server_name); +` + +const insertQueueEntrySQL = "" + + "INSERT INTO relayapi_queue (transaction_id, server_name, json_nid)" + + " VALUES ($1, $2, $3)" + +const deleteQueueEntriesSQL = "" + + "DELETE FROM relayapi_queue WHERE server_name = $1 AND json_nid IN ($2)" + +const selectQueueEntriesSQL = "" + + "SELECT json_nid FROM relayapi_queue" + + " WHERE server_name = $1" + + " ORDER BY json_nid" + + " LIMIT $2" + +const selectQueueEntryCountSQL = "" + + "SELECT COUNT(*) FROM relayapi_queue" + + " WHERE server_name = $1" + +type relayQueueStatements struct { + db *sql.DB + insertQueueEntryStmt *sql.Stmt + selectQueueEntriesStmt *sql.Stmt + selectQueueEntryCountStmt *sql.Stmt + // deleteQueueEntriesStmt *sql.Stmt - prepared at runtime due to variadic +} + +func NewSQLiteRelayQueueTable( + db *sql.DB, +) (s *relayQueueStatements, err error) { + s = &relayQueueStatements{ + db: db, + } + _, err = db.Exec(relayQueueSchema) + if err != nil { + return + } + + return s, sqlutil.StatementList{ + {&s.insertQueueEntryStmt, insertQueueEntrySQL}, + {&s.selectQueueEntriesStmt, selectQueueEntriesSQL}, + {&s.selectQueueEntryCountStmt, selectQueueEntryCountSQL}, + }.Prepare(db) +} + +func (s *relayQueueStatements) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEntryStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err +} + +func (s *relayQueueStatements) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + deleteSQL := strings.Replace(deleteQueueEntriesSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueEntries s.db.Prepare: %w", err) + } + + params := make([]interface{}, len(jsonNIDs)+1) + params[0] = serverName + for k, v := range jsonNIDs { + params[k+1] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, params...) + return err +} + +func (s *relayQueueStatements) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueEntriesStmt) + rows, err := stmt.QueryContext(ctx, serverName, limit) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []int64 + for rows.Next() { + var nid int64 + if err = rows.Scan(&nid); err != nil { + return nil, err + } + result = append(result, nid) + } + + return result, rows.Err() +} + +func (s *relayQueueStatements) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueEntryCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no rows referencing a given + // JSON NID but it's not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} diff --git a/relayapi/storage/sqlite3/storage.go b/relayapi/storage/sqlite3/storage.go new file mode 100644 index 000000000..3ed4ab046 --- /dev/null +++ b/relayapi/storage/sqlite3/storage.go @@ -0,0 +1,64 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "database/sql" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/shared" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// Database stores information needed by the federation sender +type Database struct { + shared.Database + db *sql.DB + writer sqlutil.Writer +} + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (*Database, error) { + var d Database + var err error + if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil { + return nil, err + } + queue, err := NewSQLiteRelayQueueTable(d.db) + if err != nil { + return nil, err + } + queueJSON, err := NewSQLiteRelayQueueJSONTable(d.db) + if err != nil { + return nil, err + } + d.Database = shared.Database{ + DB: d.db, + IsLocalServerName: isLocalServerName, + Cache: cache, + Writer: d.writer, + RelayQueue: queue, + RelayQueueJSON: queueJSON, + } + return &d, nil +} diff --git a/relayapi/storage/storage.go b/relayapi/storage/storage.go new file mode 100644 index 000000000..16ecbcfb7 --- /dev/null +++ b/relayapi/storage/storage.go @@ -0,0 +1,46 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !wasm +// +build !wasm + +package storage + +import ( + "fmt" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/setup/base" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +// NewDatabase opens a new database +func NewDatabase( + base *base.BaseDendrite, + dbProperties *config.DatabaseOptions, + cache caching.FederationCache, + isLocalServerName func(gomatrixserverlib.ServerName) bool, +) (Database, error) { + switch { + case dbProperties.ConnectionString.IsSQLite(): + return sqlite3.NewDatabase(base, dbProperties, cache, isLocalServerName) + case dbProperties.ConnectionString.IsPostgres(): + return postgres.NewDatabase(base, dbProperties, cache, isLocalServerName) + default: + return nil, fmt.Errorf("unexpected database type") + } +} diff --git a/relayapi/storage/tables/interface.go b/relayapi/storage/tables/interface.go new file mode 100644 index 000000000..9056a5678 --- /dev/null +++ b/relayapi/storage/tables/interface.go @@ -0,0 +1,66 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables + +import ( + "context" + "database/sql" + + "github.com/matrix-org/gomatrixserverlib" +) + +// RelayQueue table contains a mapping of server name to transaction id and the corresponding nid. +// These are the transactions being stored for the given destination server. +// The nids correspond to entries in the RelayQueueJSON table. +type RelayQueue interface { + // Adds a new transaction_id: server_name mapping with associated json table nid to the table. + // Will ensure only one transaction id is present for each server_name: nid mapping. + // Adding duplicates will silently do nothing. + InsertQueueEntry(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error + + // Removes multiple entries from the table corresponding the the list of nids provided. + // If any of the provided nids don't match a row in the table, that deletion is considered + // successful. + DeleteQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error + + // Get a list of nids associated with the provided server name. + // Returns up to `limit` nids. The entries are returned oldest first. + // Will return an empty list if no matches were found. + SelectQueueEntries(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error) + + // Get the number of entries in the table associated with the provided server name. + // If there are no matching rows, a count of 0 is returned with err set to nil. + SelectQueueEntryCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error) +} + +// RelayQueueJSON table contains a map of nid to the raw transaction json. +type RelayQueueJSON interface { + // Adds a new transaction to the table. + // Adding a duplicate transaction will result in a new row being added and a new unique nid. + // return: unique nid representing this entry. + InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error) + + // Removes multiple nids from the table. + // If any of the provided nids don't match a row in the table, that deletion is considered + // successful. + DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error + + // Get the transaction json corresponding to the provided nids. + // Will return a partial result containing any matching nid from the table. + // Will return an empty map if no matches were found. + // It is the caller's responsibility to deal with the results appropriately. + // return: map indexed by nid of each matching transaction json. + SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error) +} diff --git a/relayapi/storage/tables/relay_queue_json_table_test.go b/relayapi/storage/tables/relay_queue_json_table_test.go new file mode 100644 index 000000000..efa3363e5 --- /dev/null +++ b/relayapi/storage/tables/relay_queue_json_table_test.go @@ -0,0 +1,173 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +const ( + testOrigin = gomatrixserverlib.ServerName("kaer.morhen") +) + +func mustCreateTransaction() gomatrixserverlib.Transaction { + txn := gomatrixserverlib.Transaction{} + txn.PDUs = []json.RawMessage{ + []byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`), + } + txn.Origin = testOrigin + + return txn +} + +type RelayQueueJSONDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.RelayQueueJSON +} + +func mustCreateQueueJSONTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueJSONDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.RelayQueueJSON + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayQueueJSONTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayQueueJSONTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayQueueJSONDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + _, err = db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + var storedJSON map[int64][]byte + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, 1, len(storedJSON)) + + var storedTx gomatrixserverlib.Transaction + json.Unmarshal(storedJSON[1], &storedTx) + + assert.Equal(t, transaction, storedTx) + }) +} + +func TestShouldDeleteTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueJSONTable(t, dbType) + defer close() + + transaction := mustCreateTransaction() + tx, err := json.Marshal(transaction) + if err != nil { + t.Fatalf("Invalid transaction: %s", err.Error()) + } + + nid, err := db.Table.InsertQueueJSON(ctx, nil, string(tx)) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + storedJSON := map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + storedJSON = map[int64][]byte{} + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + storedJSON, err = db.Table.SelectQueueJSON(ctx, txn, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, 0, len(storedJSON)) + }) +} diff --git a/relayapi/storage/tables/relay_queue_table_test.go b/relayapi/storage/tables/relay_queue_table_test.go new file mode 100644 index 000000000..99f9922c0 --- /dev/null +++ b/relayapi/storage/tables/relay_queue_table_test.go @@ -0,0 +1,229 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tables_test + +import ( + "context" + "database/sql" + "fmt" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/relayapi/storage/postgres" + "github.com/matrix-org/dendrite/relayapi/storage/sqlite3" + "github.com/matrix-org/dendrite/relayapi/storage/tables" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/dendrite/test" + "github.com/matrix-org/gomatrixserverlib" + "github.com/stretchr/testify/assert" +) + +type RelayQueueDatabase struct { + DB *sql.DB + Writer sqlutil.Writer + Table tables.RelayQueue +} + +func mustCreateQueueTable( + t *testing.T, + dbType test.DBType, +) (database RelayQueueDatabase, close func()) { + t.Helper() + connStr, close := test.PrepareDBConnectionString(t, dbType) + db, err := sqlutil.Open(&config.DatabaseOptions{ + ConnectionString: config.DataSource(connStr), + }, sqlutil.NewExclusiveWriter()) + assert.NoError(t, err) + var tab tables.RelayQueue + switch dbType { + case test.DBTypePostgres: + tab, err = postgres.NewPostgresRelayQueueTable(db) + assert.NoError(t, err) + case test.DBTypeSQLite: + tab, err = sqlite3.NewSQLiteRelayQueueTable(db) + assert.NoError(t, err) + } + assert.NoError(t, err) + + database = RelayQueueDatabase{ + DB: db, + Writer: sqlutil.NewDummyWriter(), + Table: tab, + } + return database, close +} + +func TestShoudInsertQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + }) +} + +func TestShouldRetrieveInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, nid, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) + }) +} + +func TestShouldRetrieveOldestInsertedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(2) + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName = gomatrixserverlib.ServerName("domain") + oldestNID := int64(1) + err = db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, oldestNID) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + retrievedNids, err := db.Table.SelectQueueEntries(ctx, nil, serverName, 1) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, oldestNID, retrievedNids[0]) + assert.Equal(t, 1, len(retrievedNids)) + + retrievedNids, err = db.Table.SelectQueueEntries(ctx, nil, serverName, 10) + if err != nil { + t.Fatalf("Failed retrieving transaction: %s", err.Error()) + } + + assert.Equal(t, oldestNID, retrievedNids[0]) + assert.Equal(t, nid, retrievedNids[1]) + assert.Equal(t, 2, len(retrievedNids)) + }) +} + +func TestShouldDeleteQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(0), count) + }) +} + +func TestShouldDeleteOnlySpecifiedQueueTransaction(t *testing.T) { + ctx := context.Background() + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := mustCreateQueueTable(t, dbType) + defer close() + + transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano())) + serverName := gomatrixserverlib.ServerName("domain") + nid := int64(1) + transactionID2 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d2", time.Now().UnixNano())) + serverName2 := gomatrixserverlib.ServerName("domain2") + nid2 := int64(2) + transactionID3 := gomatrixserverlib.TransactionID(fmt.Sprintf("%d3", time.Now().UnixNano())) + + err := db.Table.InsertQueueEntry(ctx, nil, transactionID, serverName, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueEntry(ctx, nil, transactionID2, serverName2, nid) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + err = db.Table.InsertQueueEntry(ctx, nil, transactionID3, serverName, nid2) + if err != nil { + t.Fatalf("Failed inserting transaction: %s", err.Error()) + } + + _ = db.Writer.Do(db.DB, nil, func(txn *sql.Tx) error { + err = db.Table.DeleteQueueEntries(ctx, txn, serverName, []int64{nid}) + return err + }) + if err != nil { + t.Fatalf("Failed deleting transaction: %s", err.Error()) + } + + count, err := db.Table.SelectQueueEntryCount(ctx, nil, serverName) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(1), count) + + count, err = db.Table.SelectQueueEntryCount(ctx, nil, serverName2) + if err != nil { + t.Fatalf("Failed retrieving transaction count: %s", err.Error()) + } + assert.Equal(t, int64(1), count) + }) +} diff --git a/setup/base/base.go b/setup/base/base.go index ff38209fb..de8f81517 100644 --- a/setup/base/base.go +++ b/setup/base/base.go @@ -595,6 +595,12 @@ func (b *BaseDendrite) WaitForShutdown() { logrus.Warnf("failed to flush all Sentry events!") } } + if b.Fulltext != nil { + err := b.Fulltext.Close() + if err != nil { + logrus.Warnf("failed to close full text search!") + } + } logrus.Warnf("Dendrite is exiting now") } diff --git a/setup/config/config.go b/setup/config/config.go index 41d2b6674..2b38cd512 100644 --- a/setup/config/config.go +++ b/setup/config/config.go @@ -62,6 +62,7 @@ type Dendrite struct { RoomServer RoomServer `yaml:"room_server"` SyncAPI SyncAPI `yaml:"sync_api"` UserAPI UserAPI `yaml:"user_api"` + RelayAPI RelayAPI `yaml:"relay_api"` MSCs MSCs `yaml:"mscs"` @@ -349,6 +350,7 @@ func (c *Dendrite) Defaults(opts DefaultOpts) { c.SyncAPI.Defaults(opts) c.UserAPI.Defaults(opts) c.AppServiceAPI.Defaults(opts) + c.RelayAPI.Defaults(opts) c.MSCs.Defaults(opts) c.Wiring() } @@ -361,7 +363,7 @@ func (c *Dendrite) Verify(configErrs *ConfigErrors, isMonolith bool) { &c.Global, &c.ClientAPI, &c.FederationAPI, &c.KeyServer, &c.MediaAPI, &c.RoomServer, &c.SyncAPI, &c.UserAPI, - &c.AppServiceAPI, &c.MSCs, + &c.AppServiceAPI, &c.RelayAPI, &c.MSCs, } { c.Verify(configErrs, isMonolith) } @@ -377,6 +379,7 @@ func (c *Dendrite) Wiring() { c.SyncAPI.Matrix = &c.Global c.UserAPI.Matrix = &c.Global c.AppServiceAPI.Matrix = &c.Global + c.RelayAPI.Matrix = &c.Global c.MSCs.Matrix = &c.Global c.ClientAPI.Derived = &c.Derived diff --git a/setup/config/config_federationapi.go b/setup/config/config_federationapi.go index 0f853865f..6c198018d 100644 --- a/setup/config/config_federationapi.go +++ b/setup/config/config_federationapi.go @@ -18,6 +18,12 @@ type FederationAPI struct { // The default value is 16 if not specified, which is circa 18 hours. FederationMaxRetries uint32 `yaml:"send_max_retries"` + // P2P Feature: How many consecutive failures that we should tolerate when + // sending federation requests to a specific server until we should assume they + // are offline. If we assume they are offline then we will attempt to send + // messages to their relay server if we know of one that is appropriate. + P2PFederationRetriesUntilAssumedOffline uint32 `yaml:"p2p_retries_until_assumed_offline"` + // FederationDisableTLSValidation disables the validation of X.509 TLS certs // on remote federation endpoints. This is not recommended in production! DisableTLSValidation bool `yaml:"disable_tls_validation"` @@ -43,6 +49,7 @@ func (c *FederationAPI) Defaults(opts DefaultOpts) { c.Database.Defaults(10) } c.FederationMaxRetries = 16 + c.P2PFederationRetriesUntilAssumedOffline = 2 c.DisableTLSValidation = false c.DisableHTTPKeepalives = false if opts.Generate { diff --git a/setup/config/config_relayapi.go b/setup/config/config_relayapi.go new file mode 100644 index 000000000..5a6b093d4 --- /dev/null +++ b/setup/config/config_relayapi.go @@ -0,0 +1,52 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +type RelayAPI struct { + Matrix *Global `yaml:"-"` + + InternalAPI InternalAPIOptions `yaml:"internal_api,omitempty"` + ExternalAPI ExternalAPIOptions `yaml:"external_api,omitempty"` + + // The database stores information used by the relay queue to + // forward transactions to remote servers. + Database DatabaseOptions `yaml:"database,omitempty"` +} + +func (c *RelayAPI) Defaults(opts DefaultOpts) { + if !opts.Monolithic { + c.InternalAPI.Listen = "http://localhost:7775" + c.InternalAPI.Connect = "http://localhost:7775" + c.ExternalAPI.Listen = "http://[::]:8075" + c.Database.Defaults(10) + } + if opts.Generate { + if !opts.Monolithic { + c.Database.ConnectionString = "file:relayapi.db" + } + } +} + +func (c *RelayAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { + if isMonolith { // polylith required configs below + return + } + if c.Matrix.DatabaseOptions.ConnectionString == "" { + checkNotEmpty(configErrs, "relay_api.database.connection_string", string(c.Database.ConnectionString)) + } + checkURL(configErrs, "relay_api.external_api.listen", string(c.ExternalAPI.Listen)) + checkURL(configErrs, "relay_api.internal_api.listen", string(c.InternalAPI.Listen)) + checkURL(configErrs, "relay_api.internal_api.connect", string(c.InternalAPI.Connect)) +} diff --git a/setup/config/config_test.go b/setup/config/config_test.go index 3408bf46d..ffbf4c3c5 100644 --- a/setup/config/config_test.go +++ b/setup/config/config_test.go @@ -20,11 +20,12 @@ import ( "testing" "github.com/matrix-org/gomatrixserverlib" + "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" ) func TestLoadConfigRelative(t *testing.T) { - _, err := loadConfig("/my/config/dir", []byte(testConfig), + cfg, err := loadConfig("/my/config/dir", []byte(testConfig), mockReadFile{ "/my/config/dir/matrix_key.pem": testKey, "/my/config/dir/tls_cert.pem": testCert, @@ -34,6 +35,15 @@ func TestLoadConfigRelative(t *testing.T) { if err != nil { t.Error("failed to load config:", err) } + + configErrors := &ConfigErrors{} + cfg.Verify(configErrors, false) + if len(*configErrors) > 0 { + for _, err := range *configErrors { + logrus.Errorf("Configuration error: %s", err) + } + t.Error("configuration verification failed") + } } const testConfig = ` @@ -68,6 +78,8 @@ global: display_name: "Server alerts" avatar: "" room_name: "Server Alerts" + jetstream: + addresses: ["test"] app_service_api: internal_api: listen: http://localhost:7777 @@ -84,7 +96,7 @@ client_api: connect: http://localhost:7771 external_api: listen: http://[::]:8071 - registration_disabled: false + registration_disabled: true registration_shared_secret: "" enable_registration_captcha: false recaptcha_public_key: "" @@ -112,6 +124,8 @@ federation_api: connect: http://localhost:7772 external_api: listen: http://[::]:8072 + database: + connection_string: file:federationapi.db key_server: internal_api: listen: http://localhost:7779 @@ -194,6 +208,17 @@ user_api: max_open_conns: 100 max_idle_conns: 2 conn_max_lifetime: -1 +relay_api: + internal_api: + listen: http://localhost:7775 + connect: http://localhost:7775 + external_api: + listen: http://[::]:8075 + database: + connection_string: file:relayapi.db +mscs: + database: + connection_string: file:mscs.db tracing: enabled: false jaeger: diff --git a/setup/monolith.go b/setup/monolith.go index 41a897024..5bbe4019e 100644 --- a/setup/monolith.go +++ b/setup/monolith.go @@ -23,6 +23,8 @@ import ( "github.com/matrix-org/dendrite/internal/transactions" keyAPI "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/mediaapi" + "github.com/matrix-org/dendrite/relayapi" + relayAPI "github.com/matrix-org/dendrite/relayapi/api" roomserverAPI "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/setup/base" "github.com/matrix-org/dendrite/setup/config" @@ -44,6 +46,7 @@ type Monolith struct { RoomserverAPI roomserverAPI.RoomserverInternalAPI UserAPI userapi.UserInternalAPI KeyAPI keyAPI.KeyInternalAPI + RelayAPI relayAPI.RelayInternalAPI // Optional ExtPublicRoomsProvider api.ExtraPublicRoomsProvider @@ -71,4 +74,8 @@ func (m *Monolith) AddAllPublicRoutes(base *base.BaseDendrite) { syncapi.AddPublicRoutes( base, m.UserAPI, m.RoomserverAPI, m.KeyAPI, ) + + if m.RelayAPI != nil { + relayapi.AddPublicRoutes(base, m.KeyRing, m.RelayAPI) + } } diff --git a/test/db.go b/test/db.go index 54ded6adb..d2f405d49 100644 --- a/test/db.go +++ b/test/db.go @@ -101,7 +101,6 @@ func currentUser() string { // Returns the connection string to use and a close function which must be called when the test finishes. // Calling this function twice will return the same database, which will have data from previous tests // unless close() is called. -// TODO: namespace for concurrent package tests func PrepareDBConnectionString(t *testing.T, dbType DBType) (connStr string, close func()) { if dbType == DBTypeSQLite { // this will be made in the t.TempDir, which is unique per test diff --git a/test/memory_federation_db.go b/test/memory_federation_db.go new file mode 100644 index 000000000..cc9e1e8fd --- /dev/null +++ b/test/memory_federation_db.go @@ -0,0 +1,488 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "encoding/json" + "errors" + "sync" + "time" + + "github.com/matrix-org/dendrite/federationapi/storage/shared/receipt" + "github.com/matrix-org/dendrite/federationapi/types" + "github.com/matrix-org/gomatrixserverlib" +) + +var nidMutex sync.Mutex +var nid = int64(0) + +type InMemoryFederationDatabase struct { + dbMutex sync.Mutex + pendingPDUServers map[gomatrixserverlib.ServerName]struct{} + pendingEDUServers map[gomatrixserverlib.ServerName]struct{} + blacklistedServers map[gomatrixserverlib.ServerName]struct{} + assumedOffline map[gomatrixserverlib.ServerName]struct{} + pendingPDUs map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent + pendingEDUs map[*receipt.Receipt]*gomatrixserverlib.EDU + associatedPDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} + associatedEDUs map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{} + relayServers map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName +} + +func NewInMemoryFederationDatabase() *InMemoryFederationDatabase { + return &InMemoryFederationDatabase{ + pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}), + blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}), + assumedOffline: make(map[gomatrixserverlib.ServerName]struct{}), + pendingPDUs: make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent), + pendingEDUs: make(map[*receipt.Receipt]*gomatrixserverlib.EDU), + associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), + associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*receipt.Receipt]struct{}), + relayServers: make(map[gomatrixserverlib.ServerName][]gomatrixserverlib.ServerName), + } +} + +func (d *InMemoryFederationDatabase) StoreJSON( + ctx context.Context, + js string, +) (*receipt.Receipt, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal([]byte(js), &event); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + newReceipt := receipt.NewReceipt(nid) + d.pendingPDUs[&newReceipt] = &event + return &newReceipt, nil + } + + var edu gomatrixserverlib.EDU + if err := json.Unmarshal([]byte(js), &edu); err == nil { + nidMutex.Lock() + defer nidMutex.Unlock() + nid++ + newReceipt := receipt.NewReceipt(nid) + d.pendingEDUs[&newReceipt] = &edu + return &newReceipt, nil + } + + return nil, errors.New("Failed to determine type of json to store") +} + +func (d *InMemoryFederationDatabase) GetPendingPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + pduCount := 0 + pdus = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent) + if receipts, ok := d.associatedPDUs[serverName]; ok { + for dbReceipt := range receipts { + if event, ok := d.pendingPDUs[dbReceipt]; ok { + pdus[dbReceipt] = event + pduCount++ + if pduCount == limit { + break + } + } + } + } + return pdus, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + limit int, +) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + eduCount := 0 + edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU) + if receipts, ok := d.associatedEDUs[serverName]; ok { + for dbReceipt := range receipts { + if event, ok := d.pendingEDUs[dbReceipt]; ok { + edus[dbReceipt] = event + eduCount++ + if eduCount == limit { + break + } + } + } + } + return edus, nil +} + +func (d *InMemoryFederationDatabase) AssociatePDUWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.ServerName]struct{}, + dbReceipt *receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingPDUs[dbReceipt]; ok { + for destination := range destinations { + if _, ok := d.associatedPDUs[destination]; !ok { + d.associatedPDUs[destination] = make(map[*receipt.Receipt]struct{}) + } + d.associatedPDUs[destination][dbReceipt] = struct{}{} + } + + return nil + } else { + return errors.New("PDU doesn't exist") + } +} + +func (d *InMemoryFederationDatabase) AssociateEDUWithDestinations( + ctx context.Context, + destinations map[gomatrixserverlib.ServerName]struct{}, + dbReceipt *receipt.Receipt, + eduType string, + expireEDUTypes map[string]time.Duration, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if _, ok := d.pendingEDUs[dbReceipt]; ok { + for destination := range destinations { + if _, ok := d.associatedEDUs[destination]; !ok { + d.associatedEDUs[destination] = make(map[*receipt.Receipt]struct{}) + } + d.associatedEDUs[destination][dbReceipt] = struct{}{} + } + + return nil + } else { + return errors.New("EDU doesn't exist") + } +} + +func (d *InMemoryFederationDatabase) CleanPDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipts []*receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if pdus, ok := d.associatedPDUs[serverName]; ok { + for _, dbReceipt := range receipts { + delete(pdus, dbReceipt) + } + } + + return nil +} + +func (d *InMemoryFederationDatabase) CleanEDUs( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + receipts []*receipt.Receipt, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if edus, ok := d.associatedEDUs[serverName]; ok { + for _, dbReceipt := range receipts { + delete(edus, dbReceipt) + } + } + + return nil +} + +func (d *InMemoryFederationDatabase) GetPendingPDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if pdus, ok := d.associatedPDUs[serverName]; ok { + count = int64(len(pdus)) + } + return count, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + var count int64 + if edus, ok := d.associatedEDUs[serverName]; ok { + count = int64(len(edus)) + } + return count, nil +} + +func (d *InMemoryFederationDatabase) GetPendingPDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingPDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *InMemoryFederationDatabase) GetPendingEDUServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + servers := []gomatrixserverlib.ServerName{} + for server := range d.pendingEDUServers { + servers = append(servers, server) + } + return servers, nil +} + +func (d *InMemoryFederationDatabase) AddServerToBlacklist( + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers[serverName] = struct{}{} + return nil +} + +func (d *InMemoryFederationDatabase) RemoveServerFromBlacklist( + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.blacklistedServers, serverName) + return nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersFromBlacklist() error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *InMemoryFederationDatabase) IsServerBlacklisted( + serverName gomatrixserverlib.ServerName, +) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + isBlacklisted := false + if _, ok := d.blacklistedServers[serverName]; ok { + isBlacklisted = true + } + + return isBlacklisted, nil +} + +func (d *InMemoryFederationDatabase) SetServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline[serverName] = struct{}{} + return nil +} + +func (d *InMemoryFederationDatabase) RemoveServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + delete(d.assumedOffline, serverName) + return nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffine( + ctx context.Context, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + d.assumedOffline = make(map[gomatrixserverlib.ServerName]struct{}) + return nil +} + +func (d *InMemoryFederationDatabase) IsServerAssumedOffline( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (bool, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + assumedOffline := false + if _, ok := d.assumedOffline[serverName]; ok { + assumedOffline = true + } + + return assumedOffline, nil +} + +func (d *InMemoryFederationDatabase) P2PGetRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) ([]gomatrixserverlib.ServerName, error) { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + knownRelayServers := []gomatrixserverlib.ServerName{} + if relayServers, ok := d.relayServers[serverName]; ok { + knownRelayServers = relayServers + } + + return knownRelayServers, nil +} + +func (d *InMemoryFederationDatabase) P2PAddRelayServersForServer( + ctx context.Context, + serverName gomatrixserverlib.ServerName, + relayServers []gomatrixserverlib.ServerName, +) error { + d.dbMutex.Lock() + defer d.dbMutex.Unlock() + + if knownRelayServers, ok := d.relayServers[serverName]; ok { + for _, relayServer := range relayServers { + alreadyKnown := false + for _, knownRelayServer := range knownRelayServers { + if relayServer == knownRelayServer { + alreadyKnown = true + } + } + if !alreadyKnown { + d.relayServers[serverName] = append(d.relayServers[serverName], relayServer) + } + } + } else { + d.relayServers[serverName] = relayServers + } + + return nil +} + +func (d *InMemoryFederationDatabase) FetchKeys(ctx context.Context, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) FetcherName() string { + return "" +} + +func (d *InMemoryFederationDatabase) StoreKeys(ctx context.Context, results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult) error { + return nil +} + +func (d *InMemoryFederationDatabase) UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) RemoveAllServersAssumedOffline(ctx context.Context) error { + return nil +} + +func (d *InMemoryFederationDatabase) P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error { + return nil +} + +func (d *InMemoryFederationDatabase) P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error { + return nil +} + +func (d *InMemoryFederationDatabase) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error { + return nil +} + +func (d *InMemoryFederationDatabase) GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) { + return nil, nil +} + +func (d *InMemoryFederationDatabase) DeleteExpiredEDUs(ctx context.Context) error { + return nil +} + +func (d *InMemoryFederationDatabase) PurgeRoom(ctx context.Context, roomID string) error { + return nil +} diff --git a/test/memory_relay_db.go b/test/memory_relay_db.go new file mode 100644 index 000000000..db93919df --- /dev/null +++ b/test/memory_relay_db.go @@ -0,0 +1,140 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package test + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + + "github.com/matrix-org/gomatrixserverlib" +) + +type InMemoryRelayDatabase struct { + nid int64 + nidMutex sync.Mutex + transactions map[int64]json.RawMessage + associations map[gomatrixserverlib.ServerName][]int64 +} + +func NewInMemoryRelayDatabase() *InMemoryRelayDatabase { + return &InMemoryRelayDatabase{ + nid: 1, + nidMutex: sync.Mutex{}, + transactions: make(map[int64]json.RawMessage), + associations: make(map[gomatrixserverlib.ServerName][]int64), + } +} + +func (d *InMemoryRelayDatabase) InsertQueueEntry( + ctx context.Context, + txn *sql.Tx, + transactionID gomatrixserverlib.TransactionID, + serverName gomatrixserverlib.ServerName, + nid int64, +) error { + if _, ok := d.associations[serverName]; !ok { + d.associations[serverName] = []int64{} + } + d.associations[serverName] = append(d.associations[serverName], nid) + return nil +} + +func (d *InMemoryRelayDatabase) DeleteQueueEntries( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + jsonNIDs []int64, +) error { + for _, nid := range jsonNIDs { + for index, associatedNID := range d.associations[serverName] { + if associatedNID == nid { + d.associations[serverName] = append(d.associations[serverName][:index], d.associations[serverName][index+1:]...) + } + } + } + + return nil +} + +func (d *InMemoryRelayDatabase) SelectQueueEntries( + ctx context.Context, + txn *sql.Tx, serverName gomatrixserverlib.ServerName, + limit int, +) ([]int64, error) { + results := []int64{} + resultCount := limit + if limit > len(d.associations[serverName]) { + resultCount = len(d.associations[serverName]) + } + if resultCount > 0 { + for i := 0; i < resultCount; i++ { + results = append(results, d.associations[serverName][i]) + } + } + + return results, nil +} + +func (d *InMemoryRelayDatabase) SelectQueueEntryCount( + ctx context.Context, + txn *sql.Tx, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + return int64(len(d.associations[serverName])), nil +} + +func (d *InMemoryRelayDatabase) InsertQueueJSON( + ctx context.Context, + txn *sql.Tx, + json string, +) (int64, error) { + d.nidMutex.Lock() + defer d.nidMutex.Unlock() + + nid := d.nid + d.transactions[nid] = []byte(json) + d.nid++ + + return nid, nil +} + +func (d *InMemoryRelayDatabase) DeleteQueueJSON( + ctx context.Context, + txn *sql.Tx, + nids []int64, +) error { + for _, nid := range nids { + delete(d.transactions, nid) + } + + return nil +} + +func (d *InMemoryRelayDatabase) SelectQueueJSON( + ctx context.Context, + txn *sql.Tx, + jsonNIDs []int64, +) (map[int64][]byte, error) { + result := make(map[int64][]byte) + for _, nid := range jsonNIDs { + if transaction, ok := d.transactions[nid]; ok { + result[nid] = transaction + } + } + + return result, nil +} diff --git a/test/testrig/base.go b/test/testrig/base.go index 9773da223..dfc0d8aaf 100644 --- a/test/testrig/base.go +++ b/test/testrig/base.go @@ -67,9 +67,10 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f case test.DBTypeSQLite: cfg.Defaults(config.DefaultOpts{ Generate: true, - Monolithic: false, // because we need a database per component + Monolithic: true, }) cfg.Global.ServerName = "test" + // use a distinct prefix else concurrent postgres/sqlite runs will clash since NATS will use // the file system event with InMemory=true :( cfg.Global.JetStream.TopicPrefix = fmt.Sprintf("Test_%d_", dbType) @@ -83,6 +84,7 @@ func CreateBaseDendrite(t *testing.T, dbType test.DBType) (*base.BaseDendrite, f cfg.RoomServer.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "roomserver.db")) cfg.SyncAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "syncapi.db")) cfg.UserAPI.AccountDatabase.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "userapi.db")) + cfg.RelayAPI.Database.ConnectionString = config.DataSource(filepath.Join("file://", tempDir, "relayapi.db")) base := base.NewBaseDendrite(&cfg, "Test", base.DisableMetrics) return base, func() { From ace44458b25768099f7b86663f2bb45ddf0d39c9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Thu, 26 Jan 2023 08:25:39 +0100 Subject: [PATCH 14/14] Bump commonmarker from 0.23.6 to 0.23.7 in /docs (#2952) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [commonmarker](https://github.com/gjtorikian/commonmarker) from 0.23.6 to 0.23.7.
Release notes

Sourced from commonmarker's releases.

v0.23.7

What's Changed

Full Changelog: https://github.com/gjtorikian/commonmarker/compare/v0.23.6...v0.23.7

v0.23.7.pre1

What's Changed

Full Changelog: https://github.com/gjtorikian/commonmarker/compare/v0.23.6...v0.23.7.pre1

Changelog

Sourced from commonmarker's changelog.

Changelog

v1.0.0.pre6 (2023-01-09)

Full Changelog

Closed issues:

  • Cargo.lock prevents Ruby 3.2.0 from installing commonmarker v1.0.0.pre4 #211

Merged pull requests:

  • always use rb_sys (don't use Ruby's emerging cargo tooling where available) #213 (kivikakk)

v1.0.0.pre5 (2023-01-08)

Full Changelog

Merged pull requests:

v1.0.0.pre4 (2022-12-28)

Full Changelog

Closed issues:

  • Will the cmark-gfm branch continue to be maintained for awhile? #207

Merged pull requests:

v1.0.0.pre3 (2022-11-30)

Full Changelog

Closed issues:

  • Code block incorrectly parsed in commonmarker 1.0.0.pre #202

Merged pull requests:

... (truncated)

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=commonmarker&package-manager=bundler&previous-version=0.23.6&new-version=0.23.7)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself) - `@dependabot use these labels` will set the current labels as the default for future PRs for this repo and language - `@dependabot use these reviewers` will set the current reviewers as the default for future PRs for this repo and language - `@dependabot use these assignees` will set the current assignees as the default for future PRs for this repo and language - `@dependabot use this milestone` will set the current milestone as the default for future PRs for this repo and language You can disable automated security fix PRs for this repo from the [Security Alerts page](https://github.com/matrix-org/dendrite/network/alerts).
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/Gemfile.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Gemfile.lock b/docs/Gemfile.lock index 509a8cbcf..5d79365f8 100644 --- a/docs/Gemfile.lock +++ b/docs/Gemfile.lock @@ -14,7 +14,7 @@ GEM execjs coffee-script-source (1.11.1) colorator (1.1.0) - commonmarker (0.23.6) + commonmarker (0.23.7) concurrent-ruby (1.1.10) dnsruby (1.61.9) simpleidn (~> 0.1)