From 65e85e3da1ae159db800911a25263a3c2468a766 Mon Sep 17 00:00:00 2001 From: Kegan Dougal Date: Wed, 7 Apr 2021 17:46:53 +0100 Subject: [PATCH] Optimise memory usage when calling /g_m_e --- federationapi/routing/send.go | 36 +++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index b48d6c0b8..b7cf0c2d0 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -620,7 +620,9 @@ func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserv return gomatrixserverlib.Allowed(e, &authUsingState) } -func (t *txnReq) processEventWithMissingState(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) error { +func (t *txnReq) processEventWithMissingState( + ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion, +) error { // Do this with a fresh context, so that we keep working even if the // original request times out. With any luck, by the time the remote // side retries, we'll have fetched the missing state. @@ -803,6 +805,14 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix return respState, false, nil } +func (t *txnReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent { + if cached, exists := t.haveEvents[ev.EventID()]; exists { + return cached + } + t.haveEvents[ev.EventID()] = ev + return ev +} + func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *gomatrixserverlib.RespState { var res api.QueryStateAfterEventsResponse err := t.rsAPI.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ @@ -810,15 +820,21 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event PrevEventIDs: []string{eventID}, }, &res) if err != nil || !res.PrevEventsExist { - util.GetLogger(ctx).WithError(err).Warnf("failed to query state after %s locally", eventID) + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally", eventID) return nil } + stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents)) for i, ev := range res.StateEvents { - t.haveEvents[ev.EventID()] = res.StateEvents[i] + // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this + // processEvent request, which is better for memory. + stateEvents[i] = t.cacheAndReturn(ev) } + // we should never access res.StateEvents again so we delete it here to make GC faster + res.StateEvents = nil + var authEvents []*gomatrixserverlib.Event missingAuthEvents := map[string]bool{} - for _, ev := range res.StateEvents { + for _, ev := range stateEvents { for _, ae := range ev.AuthEventIDs() { if aev, ok := t.haveEvents[ae]; ok { authEvents = append(authEvents, aev.Unwrap()) @@ -843,14 +859,13 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event return nil } for i := range queryRes.Events { - evID := queryRes.Events[i].EventID() - t.haveEvents[evID] = queryRes.Events[i] - authEvents = append(authEvents, queryRes.Events[i].Unwrap()) + authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) } + queryRes.Events = nil } return &gomatrixserverlib.RespState{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(res.StateEvents), + StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), AuthEvents: authEvents, } } @@ -860,8 +875,6 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event func (t *txnReq) lookupStateBeforeEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, roomID, eventID string) ( *gomatrixserverlib.RespState, error) { - util.GetLogger(ctx).Infof("lookupStateBeforeEvent %s", eventID) - // Attempt to fetch the missing state using /state_ids and /events return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion) } @@ -992,7 +1005,6 @@ Event: } } - // we processed everything! return newEvents, nil } @@ -1011,7 +1023,7 @@ func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion) ( *gomatrixserverlib.RespState, error) { - util.GetLogger(ctx).Infof("lookupMissingStateViaStateIDs %s", eventID) + util.GetLogger(ctx).WithField("room_id", roomID).Infof("lookupMissingStateViaStateIDs %s", eventID) // fetch the state event IDs at the time of the event stateIDs, err := t.federation.LookupStateIDs(ctx, t.Origin, roomID, eventID) if err != nil {