From f008173a5ade76ac753040068851ab1ab0de7dfd Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 25 Jan 2022 14:13:52 +0000 Subject: [PATCH] Some context refactoring --- roomserver/internal/input/input.go | 7 +------ roomserver/internal/input/input_events.go | 13 ++++++++++++- roomserver/internal/input/input_missing.go | 22 +++++++++++++++++----- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index c6abbae1e..c29fa930c 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -43,9 +43,6 @@ var keyContentFields = map[string]string{ "m.room.member": "membership", } -// TODO: Does this value make sense? -const MaximumProcessingTime = time.Minute * 2 - type Inputer struct { DB storage.Database JetStream nats.JetStreamContext @@ -85,10 +82,8 @@ func (r *Inputer) Start() error { roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Inc() r.workerForRoom(roomID).Act(nil, func() { _ = msg.InProgress() // resets the acknowledgement wait timer - ctx, cancel := context.WithTimeout(context.Background(), MaximumProcessingTime) - defer cancel() defer roomserverInputBackpressure.With(prometheus.Labels{"room_id": roomID}).Dec() - if err := r.processRoomEvent(ctx, &inputRoomEvent); err != nil { + if err := r.processRoomEvent(context.Background(), &inputRoomEvent); err != nil { sentry.CaptureException(err) } else { hooks.Run(hooks.KindNewEventPersisted, inputRoomEvent.Event) diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 7fcc9cd59..adde852b7 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -39,6 +39,9 @@ func init() { prometheus.MustRegister(processRoomEventDuration) } +// TODO: Does this value make sense? +const MaximumProcessingTime = time.Minute * 2 + var processRoomEventDuration = prometheus.NewHistogramVec( prometheus.HistogramOpts{ Namespace: "dendrite", @@ -65,11 +68,19 @@ func (r *Inputer) processRoomEvent( ctx context.Context, input *api.InputRoomEvent, ) (err error) { - // Before we do anything, make sure the context hasn't expired for this pending task. select { case <-ctx.Done(): + // Before we do anything, make sure the context hasn't expired for this pending task. + // If it has then we'll give up straight away — it's probably a synchronous input + // request and the caller has already given up, but the inbox task was still queued. return context.DeadlineExceeded default: + // Otherwise we're going to wrap the context with a time limit. We'll allow no more + // than MaximumProcessingTime for everything that we need to do for this event, or + // it's possible that we could end up wedging the roomserver for a very long time. + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, MaximumProcessingTime) + defer cancel() } // Measure how long it takes to process this event. diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index aff11e3d3..50cd841c4 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -373,9 +373,9 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve var missingResp *gomatrixserverlib.RespMissingEvents for server := range t.servers { var m gomatrixserverlib.RespMissingEvents - rctx, cancel := context.WithTimeout(ctx, time.Second*30) + reqctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() - if m, err = t.federation.LookupMissingEvents(rctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ + if m, err = t.federation.LookupMissingEvents(reqctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, // The latest event IDs that the sender already has. These are skipped when retrieving the previous events of latest_events. EarliestEvents: latestEvents, @@ -387,7 +387,12 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve } else { logger.WithError(err).Errorf("%s pushed us an event but %q did not respond to /get_missing_events", t.origin, server) if errors.Is(err, context.DeadlineExceeded) { - break + select { + case <-reqctx.Done(): // this server took too long + continue + case <-ctx.Done(): // the input request timed out + return nil, context.DeadlineExceeded + } } } } @@ -638,11 +643,18 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs var event *gomatrixserverlib.Event found := false for serverName := range t.servers { - txn, err := t.federation.GetEvent(ctx, serverName, missingEventID) + reqctx, cancel := context.WithTimeout(ctx, time.Second*30) + defer cancel() + txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) if err != nil || len(txn.PDUs) == 0 { util.GetLogger(ctx).WithError(err).WithField("event_id", missingEventID).Warn("Failed to get missing /event for event ID") if errors.Is(err, context.DeadlineExceeded) { - break + select { + case <-reqctx.Done(): // this server took too long + continue + case <-ctx.Done(): // the input request timed out + return nil, context.DeadlineExceeded + } } continue }