From 0df982a2e50021183fa478d99b2e463d512ff230 Mon Sep 17 00:00:00 2001 From: Till Faelligen <2353100+S7evinK@users.noreply.github.com> Date: Thu, 13 Jul 2023 14:17:48 +0200 Subject: [PATCH 1/3] Update NATS again [skip ci] --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index c954678ea..08ebb623e 100644 --- a/go.mod +++ b/go.mod @@ -26,7 +26,7 @@ require ( github.com/matrix-org/pinecone v0.11.1-0.20230210171230-8c3b24f2649a github.com/matrix-org/util v0.0.0-20221111132719-399730281e66 github.com/mattn/go-sqlite3 v1.14.17 - github.com/nats-io/nats-server/v2 v2.9.15 + github.com/nats-io/nats-server/v2 v2.9.19 github.com/nats-io/nats.go v1.27.0 github.com/neilalexander/utp v0.1.1-0.20210727203401-54ae7b1cd5f9 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 diff --git a/go.sum b/go.sum index 4ea627260..3c1c327cf 100644 --- a/go.sum +++ b/go.sum @@ -243,8 +243,8 @@ github.com/mschoch/smat v0.2.0 h1:8imxQsjDm8yFEAVBe7azKmKSgzSkZXDuKkSq9374khM= github.com/mschoch/smat v0.2.0/go.mod h1:kc9mz7DoBKqDyiRL7VZN8KvXQMWeTaVnttLRXOlotKw= github.com/nats-io/jwt/v2 v2.4.1 h1:Y35W1dgbbz2SQUYDPCaclXcuqleVmpbRa7646Jf2EX4= github.com/nats-io/jwt/v2 v2.4.1/go.mod h1:24BeQtRwxRV8ruvC4CojXlx/WQ/VjuwlYiH+vu/+ibI= -github.com/nats-io/nats-server/v2 v2.9.15 h1:MuwEJheIwpvFgqvbs20W8Ish2azcygjf4Z0liVu2I4c= -github.com/nats-io/nats-server/v2 v2.9.15/go.mod h1:QlCTy115fqpx4KSOPFIxSV7DdI6OxtZsGOL1JLdeRlE= +github.com/nats-io/nats-server/v2 v2.9.19 h1:OF9jSKZGo425C/FcVVIvNgpd36CUe7aVTTXEZRJk6kA= +github.com/nats-io/nats-server/v2 v2.9.19/go.mod h1:aTb/xtLCGKhfTFLxP591CMWfkdgBmcUUSkiSOe5A3gw= github.com/nats-io/nats.go v1.27.0 h1:3o9fsPhmoKm+yK7rekH2GtWoE+D9jFbw8N3/ayI1C00= github.com/nats-io/nats.go v1.27.0/go.mod h1:XpbWUlOElGwTYbMR7imivs7jJj9GtK7ypv321Wp6pjc= github.com/nats-io/nkeys v0.4.4 h1:xvBJ8d69TznjcQl9t6//Q5xXuVhyYiSos6RPtvQNTwA= From f12982472c71b8daf3de682c2807989ee695d2cf Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 13 Jul 2023 14:18:37 +0200 Subject: [PATCH 2/3] Tweaks around `/messages` (#3149) Try to mitigate some issues with `/messages` --- syncapi/routing/messages.go | 105 ++++++++++-------- syncapi/routing/routing.go | 5 + syncapi/storage/interface.go | 7 +- .../output_room_events_topology_table.go | 24 ++-- syncapi/storage/shared/storage_sync.go | 8 +- .../output_room_events_topology_table.go | 23 +++- syncapi/storage/storage_test.go | 38 ++++++- syncapi/storage/tables/interface.go | 6 +- syncapi/storage/tables/topology_test.go | 50 +++++---- syncapi/syncapi.go | 3 + syncapi/syncapi_test.go | 1 + 11 files changed, 182 insertions(+), 88 deletions(-) diff --git a/syncapi/routing/messages.go b/syncapi/routing/messages.go index c38716185..23a095449 100644 --- a/syncapi/routing/messages.go +++ b/syncapi/routing/messages.go @@ -53,6 +53,7 @@ type messagesReq struct { wasToProvided bool backwardOrdering bool filter *synctypes.RoomEventFilter + didBackfill bool } type messagesResp struct { @@ -251,18 +252,19 @@ func OnIncomingMessagesRequest( } // If start and end are equal, we either reached the beginning or something else - // is wrong. To avoid endless loops from clients, set end to 0 an empty string - if start == end { + // is wrong. If we have nothing to return set end to 0. + if start == end || len(clientEvents) == 0 { end = types.TopologyToken{} } util.GetLogger(req.Context()).WithFields(logrus.Fields{ - "from": from.String(), - "to": to.String(), - "limit": filter.Limit, - "backwards": backwardOrdering, - "return_start": start.String(), - "return_end": end.String(), + "request_from": from.String(), + "request_to": to.String(), + "limit": filter.Limit, + "backwards": backwardOrdering, + "response_start": start.String(), + "response_end": end.String(), + "backfilled": mReq.didBackfill, }).Info("Responding") res := messagesResp{ @@ -284,11 +286,6 @@ func OnIncomingMessagesRequest( })...) } - // If we didn't return any events, set the end to an empty string, so it will be omitted - // in the response JSON. - if len(res.Chunk) == 0 { - res.End = "" - } if fromStream != nil { res.StartStream = fromStream.String() } @@ -328,11 +325,12 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv ) { emptyToken := types.TopologyToken{} // Retrieve the events from the local database. - streamEvents, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) + streamEvents, _, end, err := r.snapshot.GetEventsInTopologicalRange(r.ctx, r.from, r.to, r.roomID, r.filter, r.backwardOrdering) if err != nil { err = fmt.Errorf("GetEventsInRange: %w", err) - return []synctypes.ClientEvent{}, emptyToken, emptyToken, err + return []synctypes.ClientEvent{}, *r.from, emptyToken, err } + end.Decrement() var events []*rstypes.HeaderedEvent util.GetLogger(r.ctx).WithFields(logrus.Fields{ @@ -346,32 +344,54 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv // on the ordering), or we've reached a backward extremity. if len(streamEvents) == 0 { if events, err = r.handleEmptyEventsSlice(); err != nil { - return []synctypes.ClientEvent{}, emptyToken, emptyToken, err + return []synctypes.ClientEvent{}, *r.from, emptyToken, err } } else { if events, err = r.handleNonEmptyEventsSlice(streamEvents); err != nil { - return []synctypes.ClientEvent{}, emptyToken, emptyToken, err + return []synctypes.ClientEvent{}, *r.from, emptyToken, err } } // If we didn't get any event, we don't need to proceed any further. if len(events) == 0 { - return []synctypes.ClientEvent{}, *r.from, *r.to, nil + return []synctypes.ClientEvent{}, *r.from, emptyToken, nil } - // Get the position of the first and the last event in the room's topology. - // This position is currently determined by the event's depth, so we could - // also use it instead of retrieving from the database. However, if we ever - // change the way topological positions are defined (as depth isn't the most - // reliable way to define it), it would be easier and less troublesome to - // only have to change it in one place, i.e. the database. - start, end, err = r.getStartEnd(events) + // Apply room history visibility filter + startTime := time.Now() + filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages") if err != nil { - return []synctypes.ClientEvent{}, *r.from, *r.to, err + return []synctypes.ClientEvent{}, *r.from, *r.to, nil + } + logrus.WithFields(logrus.Fields{ + "duration": time.Since(startTime), + "room_id": r.roomID, + "events_before": len(events), + "events_after": len(filteredEvents), + }).Debug("applied history visibility (messages)") + + // No events left after applying history visibility + if len(filteredEvents) == 0 { + return []synctypes.ClientEvent{}, *r.from, emptyToken, nil + } + + // If we backfilled in the process of getting events, we need + // to re-fetch the start/end positions + if r.didBackfill { + _, end, err = r.getStartEnd(filteredEvents) + if err != nil { + return []synctypes.ClientEvent{}, *r.from, *r.to, err + } } // Sort the events to ensure we send them in the right order. if r.backwardOrdering { + if events[len(events)-1].Type() == spec.MRoomCreate { + // NOTSPEC: We've hit the beginning of the room so there's really nowhere + // else to go. This seems to fix Element iOS from looping on /messages endlessly. + end = types.TopologyToken{} + } + // This reverses the array from old->new to new->old reversed := func(in []*rstypes.HeaderedEvent) []*rstypes.HeaderedEvent { out := make([]*rstypes.HeaderedEvent, len(in)) @@ -380,24 +400,14 @@ func (r *messagesReq) retrieveEvents(ctx context.Context, rsAPI api.SyncRoomserv } return out } - events = reversed(events) - } - if len(events) == 0 { - return []synctypes.ClientEvent{}, *r.from, *r.to, nil + filteredEvents = reversed(filteredEvents) } - // Apply room history visibility filter - startTime := time.Now() - filteredEvents, err := internal.ApplyHistoryVisibilityFilter(r.ctx, r.snapshot, r.rsAPI, events, nil, r.device.UserID, "messages") - logrus.WithFields(logrus.Fields{ - "duration": time.Since(startTime), - "room_id": r.roomID, - "events_before": len(events), - "events_after": len(filteredEvents), - }).Debug("applied history visibility (messages)") + start = *r.from + return synctypes.ToClientEvents(gomatrixserverlib.ToPDUs(filteredEvents), synctypes.FormatAll, func(roomID spec.RoomID, senderID spec.SenderID) (*spec.UserID, error) { return rsAPI.QueryUserIDForSender(ctx, roomID, senderID) - }), start, end, err + }), start, end, nil } func (r *messagesReq) getStartEnd(events []*rstypes.HeaderedEvent) (start, end types.TopologyToken, err error) { @@ -450,6 +460,7 @@ func (r *messagesReq) handleEmptyEventsSlice() ( if err != nil { return } + r.didBackfill = true } else { // If not, it means the slice was empty because we reached the room's // creation, so return an empty slice. @@ -499,7 +510,7 @@ func (r *messagesReq) handleNonEmptyEventsSlice(streamEvents []types.StreamEvent if err != nil { return } - + r.didBackfill = true // Append the PDUs to the list to send back to the client. events = append(events, pdus...) } @@ -561,15 +572,17 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] if res.HistoryVisibility == "" { res.HistoryVisibility = gomatrixserverlib.HistoryVisibilityShared } - for i := range res.Events { + events := res.Events + for i := range events { + events[i].Visibility = res.HistoryVisibility _, err = r.db.WriteEvent( context.Background(), - res.Events[i], + events[i], []*rstypes.HeaderedEvent{}, []string{}, []string{}, nil, true, - res.HistoryVisibility, + events[i].Visibility, ) if err != nil { return nil, err @@ -577,14 +590,10 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][] } // we may have got more than the requested limit so resize now - events := res.Events if len(events) > limit { // last `limit` events events = events[len(events)-limit:] } - for _, ev := range events { - ev.Visibility = res.HistoryVisibility - } return events, nil } diff --git a/syncapi/routing/routing.go b/syncapi/routing/routing.go index 8542c0b73..a837e1696 100644 --- a/syncapi/routing/routing.go +++ b/syncapi/routing/routing.go @@ -43,6 +43,7 @@ func Setup( cfg *config.SyncAPI, lazyLoadCache caching.LazyLoadCache, fts fulltext.Indexer, + rateLimits *httputil.RateLimits, ) { v1unstablemux := csMux.PathPrefix("/{apiversion:(?:v1|unstable)}/").Subrouter() v3mux := csMux.PathPrefix("/{apiversion:(?:r0|v3)}/").Subrouter() @@ -53,6 +54,10 @@ func Setup( }, httputil.WithAllowGuests())).Methods(http.MethodGet, http.MethodOptions) v3mux.Handle("/rooms/{roomID}/messages", httputil.MakeAuthAPI("room_messages", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { + // not specced, but ensure we're rate limiting requests to this endpoint + if r := rateLimits.Limit(req, device); r != nil { + return *r + } vars, err := httputil.URLDecodeMapValues(mux.Vars(req)) if err != nil { return util.ErrorResponse(err) diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 243b2592a..dca5d1a14 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -81,8 +81,11 @@ type DatabaseTransaction interface { // If no data is retrieved, returns an empty map // If there was an issue with the retrieval, returns an error GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *synctypes.EventFilter) (map[string][]string, types.StreamPosition, error) - // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. If backwardsOrdering is true, the most recent event must be first, else last. - GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *synctypes.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error) + // GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit. + // If backwardsOrdering is true, the most recent event must be first, else last. + // Returns the filtered StreamEvents on success. Returns **unfiltered** StreamEvents and ErrNoEventsForFilter if + // the provided filter removed all events, this can be used to still calculate the start/end position. (e.g for `/messages`) + GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, filter *synctypes.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, start, end types.TopologyToken, err error) // EventPositionInTopology returns the depth and stream position of the given event. EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error) // BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events. diff --git a/syncapi/storage/postgres/output_room_events_topology_table.go b/syncapi/storage/postgres/output_room_events_topology_table.go index 7140a92fc..b281f3300 100644 --- a/syncapi/storage/postgres/output_room_events_topology_table.go +++ b/syncapi/storage/postgres/output_room_events_topology_table.go @@ -48,14 +48,14 @@ const insertEventInTopologySQL = "" + " RETURNING topological_position" const selectEventIDsInRangeASCSQL = "" + - "SELECT event_id FROM syncapi_output_room_events_topology" + + "SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE room_id = $1 AND (" + "(topological_position > $2 AND topological_position < $3) OR" + "(topological_position = $4 AND stream_position >= $5)" + ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6" const selectEventIDsInRangeDESCSQL = "" + - "SELECT event_id FROM syncapi_output_room_events_topology" + + "SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE room_id = $1 AND (" + "(topological_position > $2 AND topological_position < $3) OR" + "(topological_position = $4 AND stream_position <= $5)" + @@ -113,12 +113,13 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( } // SelectEventIDsInRange selects the IDs of events which positions are within a -// given range in a given room's topological order. +// given range in a given room's topological order. Returns the start/end topological tokens for +// the returned eventIDs. // Returns an empty slice if no events match the given range. func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool, -) (eventIDs []string, err error) { +) (eventIDs []string, start, end types.TopologyToken, err error) { // Decide on the selection's order according to whether chronological order // is requested or not. var stmt *sql.Stmt @@ -132,7 +133,7 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit) if err == sql.ErrNoRows { // If no event matched the request, return an empty slice. - return []string{}, nil + return []string{}, start, end, nil } else if err != nil { return } @@ -140,14 +141,23 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( // Return the IDs. var eventID string + var token types.TopologyToken + var tokens []types.TopologyToken for rows.Next() { - if err = rows.Scan(&eventID); err != nil { + if err = rows.Scan(&eventID, &token.Depth, &token.PDUPosition); err != nil { return } eventIDs = append(eventIDs, eventID) + tokens = append(tokens, token) } - return eventIDs, rows.Err() + // The values are already ordered by SQL, so we can use them as is. + if len(tokens) > 0 { + start = tokens[0] + end = tokens[len(tokens)-1] + } + + return eventIDs, start, end, rows.Err() } // SelectPositionInTopology returns the position of a given event in the diff --git a/syncapi/storage/shared/storage_sync.go b/syncapi/storage/shared/storage_sync.go index 8e79b71df..cd17fdc69 100644 --- a/syncapi/storage/shared/storage_sync.go +++ b/syncapi/storage/shared/storage_sync.go @@ -237,7 +237,7 @@ func (d *DatabaseTransaction) GetEventsInTopologicalRange( roomID string, filter *synctypes.RoomEventFilter, backwardOrdering bool, -) (events []types.StreamEvent, err error) { +) (events []types.StreamEvent, start, end types.TopologyToken, err error) { var minDepth, maxDepth, maxStreamPosForMaxDepth types.StreamPosition if backwardOrdering { // Backward ordering means the 'from' token has a higher depth than the 'to' token @@ -255,7 +255,7 @@ func (d *DatabaseTransaction) GetEventsInTopologicalRange( // Select the event IDs from the defined range. var eIDs []string - eIDs, err = d.Topology.SelectEventIDsInRange( + eIDs, start, end, err = d.Topology.SelectEventIDsInRange( ctx, d.txn, roomID, minDepth, maxDepth, maxStreamPosForMaxDepth, filter.Limit, !backwardOrdering, ) if err != nil { @@ -264,6 +264,10 @@ func (d *DatabaseTransaction) GetEventsInTopologicalRange( // Retrieve the events' contents using their IDs. events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eIDs, filter, true) + if err != nil { + return + } + return } diff --git a/syncapi/storage/sqlite3/output_room_events_topology_table.go b/syncapi/storage/sqlite3/output_room_events_topology_table.go index 68b75f5b1..614e1df9e 100644 --- a/syncapi/storage/sqlite3/output_room_events_topology_table.go +++ b/syncapi/storage/sqlite3/output_room_events_topology_table.go @@ -44,14 +44,14 @@ const insertEventInTopologySQL = "" + " ON CONFLICT DO NOTHING" const selectEventIDsInRangeASCSQL = "" + - "SELECT event_id FROM syncapi_output_room_events_topology" + + "SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE room_id = $1 AND (" + "(topological_position > $2 AND topological_position < $3) OR" + "(topological_position = $4 AND stream_position >= $5)" + ") ORDER BY topological_position ASC, stream_position ASC LIMIT $6" const selectEventIDsInRangeDESCSQL = "" + - "SELECT event_id FROM syncapi_output_room_events_topology" + + "SELECT event_id, topological_position, stream_position FROM syncapi_output_room_events_topology" + " WHERE room_id = $1 AND (" + "(topological_position > $2 AND topological_position < $3) OR" + "(topological_position = $4 AND stream_position <= $5)" + @@ -111,11 +111,15 @@ func (s *outputRoomEventsTopologyStatements) InsertEventInTopology( return types.StreamPosition(event.Depth()), err } +// SelectEventIDsInRange selects the IDs of events which positions are within a +// given range in a given room's topological order. Returns the start/end topological tokens for +// the returned eventIDs. +// Returns an empty slice if no events match the given range. func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool, -) (eventIDs []string, err error) { +) (eventIDs []string, start, end types.TopologyToken, err error) { // Decide on the selection's order according to whether chronological order // is requested or not. var stmt *sql.Stmt @@ -129,18 +133,27 @@ func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange( rows, err := stmt.QueryContext(ctx, roomID, minDepth, maxDepth, maxDepth, maxStreamPos, limit) if err == sql.ErrNoRows { // If no event matched the request, return an empty slice. - return []string{}, nil + return []string{}, start, end, nil } else if err != nil { return } // Return the IDs. var eventID string + var token types.TopologyToken + var tokens []types.TopologyToken for rows.Next() { - if err = rows.Scan(&eventID); err != nil { + if err = rows.Scan(&eventID, &token.Depth, &token.PDUPosition); err != nil { return } eventIDs = append(eventIDs, eventID) + tokens = append(tokens, token) + } + + // The values are already ordered by SQL, so we can use them as is. + if len(tokens) > 0 { + start = tokens[0] + end = tokens[len(tokens)-1] } return diff --git a/syncapi/storage/storage_test.go b/syncapi/storage/storage_test.go index f57b0d618..ce7ca3fc7 100644 --- a/syncapi/storage/storage_test.go +++ b/syncapi/storage/storage_test.go @@ -213,12 +213,48 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) { // backpaginate 5 messages starting at the latest position. filter := &synctypes.RoomEventFilter{Limit: 5} - paginatedEvents, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) + paginatedEvents, start, end, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) if err != nil { t.Fatalf("GetEventsInTopologicalRange returned an error: %s", err) } gots := snapshot.StreamEventsToEvents(context.Background(), nil, paginatedEvents, nil) test.AssertEventsEqual(t, gots, test.Reversed(events[len(events)-5:])) + assert.Equal(t, types.TopologyToken{Depth: 15, PDUPosition: 15}, start) + assert.Equal(t, types.TopologyToken{Depth: 11, PDUPosition: 11}, end) + }) + }) +} + +// The purpose of this test is to ensure that backfilling returns no start/end if a given filter removes +// all events. +func TestGetEventsInRangeWithTopologyTokenNoEventsForFilter(t *testing.T) { + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + db, close := MustCreateDatabase(t, dbType) + defer close() + alice := test.NewUser(t) + r := test.NewRoom(t, alice) + for i := 0; i < 10; i++ { + r.CreateAndInsert(t, alice, "m.room.message", map[string]interface{}{"body": fmt.Sprintf("hi %d", i)}) + } + events := r.Events() + _ = MustWriteEvents(t, db, events) + + WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) { + from := types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64} + t.Logf("max topo pos = %+v", from) + // head towards the beginning of time + to := types.TopologyToken{} + + // backpaginate 20 messages starting at the latest position. + notTypes := []string{spec.MRoomRedaction} + senders := []string{alice.ID} + filter := &synctypes.RoomEventFilter{Limit: 20, NotTypes: ¬Types, Senders: &senders} + paginatedEvents, start, end, err := snapshot.GetEventsInTopologicalRange(ctx, &from, &to, r.ID, filter, true) + assert.NoError(t, err) + assert.Equal(t, 0, len(paginatedEvents)) + // Even if we didn't get anything back due to the filter, we should still have start/end + assert.Equal(t, types.TopologyToken{Depth: 15, PDUPosition: 15}, start) + assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 1}, end) }) }) } diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 854292bd2..f5c66c42d 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -89,11 +89,11 @@ type Topology interface { // InsertEventInTopology inserts the given event in the room's topology, based on the event's depth. // `pos` is the stream position of this event in the events table, and is used to order events which have the same depth. InsertEventInTopology(ctx context.Context, txn *sql.Tx, event *rstypes.HeaderedEvent, pos types.StreamPosition) (topoPos types.StreamPosition, err error) - // SelectEventIDsInRange selects the IDs of events whose depths are within a given range in a given room's topological order. - // Events with `minDepth` are *exclusive*, as is the event which has exactly `minDepth`,`maxStreamPos`. + // SelectEventIDsInRange selects the IDs and the topological position of events whose depths are within a given range in a given room's topological order. + // Events with `minDepth` are *exclusive*, as is the event which has exactly `minDepth`,`maxStreamPos`. Returns the eventIDs and start/end topological tokens. // `maxStreamPos` is only used when events have the same depth as `maxDepth`, which results in events less than `maxStreamPos` being returned. // Returns an empty slice if no events match the given range. - SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error) + SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, start, end types.TopologyToken, err error) // SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to. SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error) // SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room. diff --git a/syncapi/storage/tables/topology_test.go b/syncapi/storage/tables/topology_test.go index f4f75bdf3..7691cc5f8 100644 --- a/syncapi/storage/tables/topology_test.go +++ b/syncapi/storage/tables/topology_test.go @@ -13,6 +13,7 @@ import ( "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/test" + "github.com/stretchr/testify/assert" ) func newTopologyTable(t *testing.T, dbType test.DBType) (tables.Topology, *sql.DB, func()) { @@ -60,28 +61,37 @@ func TestTopologyTable(t *testing.T) { highestPos = topoPos + 1 } // check ordering works without limit - eventIDs, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true) - if err != nil { - return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) - } + eventIDs, start, end, err := tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, true) + assert.NoError(t, err, "failed to SelectEventIDsInRange") test.AssertEventIDsEqual(t, eventIDs, events[:]) - eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false) - if err != nil { - return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) - } - test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:])) - // check ordering works with limit - eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true) - if err != nil { - return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) - } - test.AssertEventIDsEqual(t, eventIDs, events[:3]) - eventIDs, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false) - if err != nil { - return fmt.Errorf("failed to SelectEventIDsInRange: %s", err) - } - test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:])) + assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 0}, start) + assert.Equal(t, types.TopologyToken{Depth: 5, PDUPosition: 4}, end) + eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 100, false) + assert.NoError(t, err, "failed to SelectEventIDsInRange") + test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[:])) + assert.Equal(t, types.TopologyToken{Depth: 5, PDUPosition: 4}, start) + assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 0}, end) + + // check ordering works with limit + eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, true) + assert.NoError(t, err, "failed to SelectEventIDsInRange") + test.AssertEventIDsEqual(t, eventIDs, events[:3]) + assert.Equal(t, types.TopologyToken{Depth: 1, PDUPosition: 0}, start) + assert.Equal(t, types.TopologyToken{Depth: 3, PDUPosition: 2}, end) + + eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, room.ID, 0, highestPos, highestPos, 3, false) + assert.NoError(t, err, "failed to SelectEventIDsInRange") + test.AssertEventIDsEqual(t, eventIDs, test.Reversed(events[len(events)-3:])) + assert.Equal(t, types.TopologyToken{Depth: 5, PDUPosition: 4}, start) + assert.Equal(t, types.TopologyToken{Depth: 3, PDUPosition: 2}, end) + + // Check that we return no values for invalid rooms + eventIDs, start, end, err = tab.SelectEventIDsInRange(ctx, txn, "!doesnotexist:localhost", 0, highestPos, highestPos, 10, false) + assert.NoError(t, err, "failed to SelectEventIDsInRange") + assert.Equal(t, 0, len(eventIDs)) + assert.Equal(t, types.TopologyToken{}, start) + assert.Equal(t, types.TopologyToken{}, end) return nil }) if err != nil { diff --git a/syncapi/syncapi.go b/syncapi/syncapi.go index 64a4af757..af6bddc7a 100644 --- a/syncapi/syncapi.go +++ b/syncapi/syncapi.go @@ -144,8 +144,11 @@ func AddPublicRoutes( logrus.WithError(err).Panicf("failed to start receipts consumer") } + rateLimits := httputil.NewRateLimits(&dendriteCfg.ClientAPI.RateLimiting) + routing.Setup( routers.Client, requestPool, syncDB, userAPI, rsAPI, &dendriteCfg.SyncAPI, caches, fts, + rateLimits, ) } diff --git a/syncapi/syncapi_test.go b/syncapi/syncapi_test.go index 19815b79b..996b21e90 100644 --- a/syncapi/syncapi_test.go +++ b/syncapi/syncapi_test.go @@ -433,6 +433,7 @@ func testHistoryVisibility(t *testing.T, dbType test.DBType) { } cfg, processCtx, close := testrig.CreateConfig(t, dbType) + cfg.ClientAPI.RateLimiting = config.RateLimiting{Enabled: false} routers := httputil.NewRouters() cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) caches := caching.NewRistrettoCache(128*1024*1024, time.Hour, caching.DisableMetrics) From 5267cc0f54db37b8a71a4caa7148e1dff7ae27c1 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Thu, 13 Jul 2023 14:19:08 +0200 Subject: [PATCH 3/3] Optimise getting local members and membership counts (#3150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The previous version was getting **ALL** membership events (as `ClientEvents`, so going through `NewEventFromTrustedJSONWithID`) for a given room. Now we are querying only locally joined users as `ClientEvents`, which should **significantly** reduce allocations. Take for example a large room with 2k membership events, but only 1 local user - avoiding 1999 `NewEventFromTrustedJSONWithID` calls just to calculate the `roomSize` which we can also query by other means. This is also getting called for every `OutputRoomEvent` in the userAPI. Benchmark with 1 local user and 100 remote users. ``` pkg: github.com/matrix-org/dendrite/userapi/consumers cpu: 12th Gen Intel(R) Core(TM) i5-12500H │ old.txt │ new.txt │ │ sec/op │ sec/op vs base │ LocalRoomMembers-16 375.9µ ± 7% 327.6µ ± 6% -12.85% (p=0.000 n=10) │ old.txt │ new.txt │ │ B/op │ B/op vs base │ LocalRoomMembers-16 79.426Ki ± 0% 8.507Ki ± 0% -89.29% (p=0.000 n=10) │ old.txt │ new.txt │ │ allocs/op │ allocs/op vs base │ LocalRoomMembers-16 1015.0 ± 0% 277.0 ± 0% -72.71% (p=0.000 n=10) ``` --- roomserver/api/api.go | 1 + roomserver/internal/query/query.go | 14 +++++ userapi/consumers/roomserver.go | 32 +++++------ userapi/consumers/roomserver_test.go | 81 ++++++++++++++++++++++++++++ 4 files changed, 109 insertions(+), 19 deletions(-) diff --git a/roomserver/api/api.go b/roomserver/api/api.go index ab56529c5..c29406a1a 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -227,6 +227,7 @@ type UserRoomserverAPI interface { QueryMembershipsForRoom(ctx context.Context, req *QueryMembershipsForRoomRequest, res *QueryMembershipsForRoomResponse) error PerformAdminEvacuateUser(ctx context.Context, userID string) (affected []string, err error) PerformJoin(ctx context.Context, req *PerformJoinRequest) (roomID string, joinedVia spec.ServerName, err error) + JoinedUserCount(ctx context.Context, roomID string) (int, error) } type FederationRoomserverAPI interface { diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 626d3c13e..39e3bd0ec 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -974,6 +974,20 @@ func (r *Queryer) LocallyJoinedUsers(ctx context.Context, roomVersion gomatrixse return joinedUsers, nil } +func (r *Queryer) JoinedUserCount(ctx context.Context, roomID string) (int, error) { + info, err := r.DB.RoomInfo(ctx, roomID) + if err != nil { + return 0, err + } + if info == nil { + return 0, nil + } + + // TODO: this can be further optimised by just using a SELECT COUNT query + nids, err := r.DB.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false) + return len(nids), err +} + // nolint:gocyclo func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, roomID spec.RoomID, senderID spec.SenderID) (string, error) { // Look up if we know anything about the room. If it doesn't exist diff --git a/userapi/consumers/roomserver.go b/userapi/consumers/roomserver.go index 9a9a407ce..1f866ef4d 100644 --- a/userapi/consumers/roomserver.go +++ b/userapi/consumers/roomserver.go @@ -405,18 +405,25 @@ func newLocalMembership(event *synctypes.ClientEvent) (*localMembership, error) // localRoomMembers fetches the current local members of a room, and // the total number of members. func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID string) ([]*localMembership, int, error) { + // Get only locally joined users to avoid unmarshalling and caching + // membership events we only use to calculate the room size. req := &rsapi.QueryMembershipsForRoomRequest{ RoomID: roomID, JoinedOnly: true, + LocalOnly: true, } var res rsapi.QueryMembershipsForRoomResponse - - // XXX: This could potentially race if the state for the event is not known yet - // e.g. the event came over federation but we do not have the full state persisted. if err := s.rsAPI.QueryMembershipsForRoom(ctx, req, &res); err != nil { return nil, 0, err } + // Since we only queried locally joined users above, + // we also need to ask the roomserver about the joined user count. + totalCount, err := s.rsAPI.JoinedUserCount(ctx, roomID) + if err != nil { + return nil, 0, err + } + var members []*localMembership for _, event := range res.JoinEvents { // Filter out invalid join events @@ -426,31 +433,18 @@ func (s *OutputRoomEventConsumer) localRoomMembers(ctx context.Context, roomID s if *event.StateKey == "" { continue } - _, serverName, err := gomatrixserverlib.SplitID('@', *event.StateKey) - if err != nil { - log.WithError(err).Error("failed to get servername from statekey") - continue - } - // Only get memberships for our server - if serverName != s.serverName { - continue - } + // We're going to trust the Query from above to really just return + // local users member, err := newLocalMembership(&event) if err != nil { log.WithError(err).Errorf("Parsing MemberContent") continue } - if member.Membership != spec.Join { - continue - } - if member.Domain != s.cfg.Matrix.ServerName { - continue - } members = append(members, member) } - return members, len(res.JoinEvents), nil + return members, totalCount, nil } // roomName returns the name in the event (if type==m.room.name), or diff --git a/userapi/consumers/roomserver_test.go b/userapi/consumers/roomserver_test.go index 4dc81e74a..49dd5b238 100644 --- a/userapi/consumers/roomserver_test.go +++ b/userapi/consumers/roomserver_test.go @@ -2,16 +2,22 @@ package consumers import ( "context" + "crypto/ed25519" "reflect" "sync" "testing" "time" + "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/matrix-org/dendrite/test/testrig" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib/spec" "github.com/stretchr/testify/assert" + "golang.org/x/crypto/bcrypt" "github.com/matrix-org/dendrite/internal/pushrules" rsapi "github.com/matrix-org/dendrite/roomserver/api" @@ -139,6 +145,42 @@ func Test_evaluatePushRules(t *testing.T) { }) } +func TestLocalRoomMembers(t *testing.T) { + alice := test.NewUser(t) + _, sk, err := ed25519.GenerateKey(nil) + assert.NoError(t, err) + bob := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk)) + charlie := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk)) + + room := test.NewRoom(t, alice) + room.CreateAndInsert(t, bob, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(bob.ID)) + room.CreateAndInsert(t, charlie, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(charlie.ID)) + + test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { + cfg, processCtx, close := testrig.CreateConfig(t, dbType) + defer close() + + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + db, err := storage.NewUserDatabase(processCtx.Context(), cm, &cfg.UserAPI.AccountDatabase, cfg.Global.ServerName, bcrypt.MinCost, 1000, 1000, "") + assert.NoError(t, err) + + err = rsapi.SendEvents(processCtx.Context(), rsAPI, rsapi.KindNew, room.Events(), "", "test", "test", nil, false) + assert.NoError(t, err) + + consumer := OutputRoomEventConsumer{db: db, rsAPI: rsAPI, serverName: "test", cfg: &cfg.UserAPI} + members, count, err := consumer.localRoomMembers(processCtx.Context(), room.ID) + assert.NoError(t, err) + assert.Equal(t, 3, count) + expectedLocalMember := &localMembership{UserID: alice.ID, Localpart: alice.Localpart, Domain: "test", MemberContent: gomatrixserverlib.MemberContent{Membership: spec.Join}} + assert.Equal(t, expectedLocalMember, members[0]) + }) + +} + func TestMessageStats(t *testing.T) { type args struct { eventType string @@ -257,3 +299,42 @@ func TestMessageStats(t *testing.T) { } }) } + +func BenchmarkLocalRoomMembers(b *testing.B) { + t := &testing.T{} + + cfg, processCtx, close := testrig.CreateConfig(t, test.DBTypePostgres) + defer close() + cm := sqlutil.NewConnectionManager(processCtx, cfg.Global.DatabaseOptions) + natsInstance := &jetstream.NATSInstance{} + caches := caching.NewRistrettoCache(8*1024*1024, time.Hour, caching.DisableMetrics) + rsAPI := roomserver.NewInternalAPI(processCtx, cfg, cm, natsInstance, caches, caching.DisableMetrics) + rsAPI.SetFederationAPI(nil, nil) + db, err := storage.NewUserDatabase(processCtx.Context(), cm, &cfg.UserAPI.AccountDatabase, cfg.Global.ServerName, bcrypt.MinCost, 1000, 1000, "") + assert.NoError(b, err) + + consumer := OutputRoomEventConsumer{db: db, rsAPI: rsAPI, serverName: "test", cfg: &cfg.UserAPI} + _, sk, err := ed25519.GenerateKey(nil) + assert.NoError(b, err) + + alice := test.NewUser(t) + room := test.NewRoom(t, alice) + + for i := 0; i < 100; i++ { + user := test.NewUser(t, test.WithSigningServer("notlocalhost", "ed25519:abc", sk)) + room.CreateAndInsert(t, user, spec.MRoomMember, map[string]string{"membership": spec.Join}, test.WithStateKey(user.ID)) + } + + err = rsapi.SendEvents(processCtx.Context(), rsAPI, rsapi.KindNew, room.Events(), "", "test", "test", nil, false) + assert.NoError(b, err) + + expectedLocalMember := &localMembership{UserID: alice.ID, Localpart: alice.Localpart, Domain: "test", MemberContent: gomatrixserverlib.MemberContent{Membership: spec.Join}} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + members, count, err := consumer.localRoomMembers(processCtx.Context(), room.ID) + assert.NoError(b, err) + assert.Equal(b, 101, count) + assert.Equal(b, expectedLocalMember, members[0]) + } +}