diff --git a/federationapi/routing/send.go b/federationapi/routing/send.go index ae8457dfb..18339bdae 100644 --- a/federationapi/routing/send.go +++ b/federationapi/routing/send.go @@ -33,7 +33,6 @@ import ( "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" - "go.uber.org/atomic" ) const ( @@ -88,67 +87,7 @@ func init() { ) } -type sendFIFOQueue struct { - tasks []*inputTask - count int - mutex sync.Mutex - notifs chan struct{} -} - -func newSendFIFOQueue() *sendFIFOQueue { - q := &sendFIFOQueue{ - notifs: make(chan struct{}, 1), - } - return q -} - -func (q *sendFIFOQueue) push(frame *inputTask) { - q.mutex.Lock() - defer q.mutex.Unlock() - q.tasks = append(q.tasks, frame) - q.count++ - select { - case q.notifs <- struct{}{}: - default: - } -} - -// pop returns the first item of the queue, if there is one. -// The second return value will indicate if a task was returned. -func (q *sendFIFOQueue) pop() (*inputTask, bool) { - q.mutex.Lock() - defer q.mutex.Unlock() - if q.count == 0 { - return nil, false - } - frame := q.tasks[0] - q.tasks[0] = nil - q.tasks = q.tasks[1:] - q.count-- - if q.count == 0 { - // Force a GC of the underlying array, since it might have - // grown significantly if the queue was hammered for some reason - q.tasks = nil - } - return frame, true -} - -type inputTask struct { - ctx context.Context - t *txnReq - event *gomatrixserverlib.HeaderedEvent - wg *sync.WaitGroup - err error // written back by worker, only safe to read when all tasks are done - duration time.Duration // written back by worker, only safe to read when all tasks are done -} - -type inputWorker struct { - running atomic.Bool - input *sendFIFOQueue -} - var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse -var inputWorkers sync.Map // room ID -> *inputWorker // Send implements /_matrix/federation/v1/send/{txnID} func Send( @@ -261,7 +200,6 @@ type txnReq struct { federation txnFederationClient roomsMu *internal.MutexByRoom servers federationAPI.ServersInRoomProvider - work string } // A subset of FederationClient functionality that txn requires. Useful for testing. @@ -276,9 +214,28 @@ type txnFederationClient interface { } func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { - results := make(map[string]gomatrixserverlib.PDUResult) var wg sync.WaitGroup - var tasks []*inputTask + wg.Add(1) + go func() { + defer wg.Done() + t.processEDUs(ctx) + }() + + results := make(map[string]gomatrixserverlib.PDUResult) + roomVersions := make(map[string]gomatrixserverlib.RoomVersion) + getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion { + if v, ok := roomVersions[roomID]; ok { + return v + } + verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID} + verRes := api.QueryRoomVersionForRoomResponse{} + if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { + util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) + return "" + } + roomVersions[roomID] = verRes.RoomVersion + return verRes.RoomVersion + } for _, pdu := range t.PDUs { pduCountTotal.WithLabelValues("total").Inc() @@ -291,15 +248,8 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res // failure in the PDU results continue } - verReq := api.QueryRoomVersionForRoomRequest{RoomID: header.RoomID} - verRes := api.QueryRoomVersionForRoomResponse{} - if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil { - util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID) - // We don't know the event ID at this point so we can't return the - // failure in the PDU results - continue - } - event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, verRes.RoomVersion) + roomVersion := getRoomVersion(header.RoomID) + event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion) if err != nil { if _, ok := err.(gomatrixserverlib.BadJSONError); ok { // Room version 6 states that homeservers should strictly enforce canonical JSON @@ -330,96 +280,30 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res } continue } - v, _ := inputWorkers.LoadOrStore(event.RoomID(), &inputWorker{ - input: newSendFIFOQueue(), - }) - worker := v.(*inputWorker) - wg.Add(1) - task := &inputTask{ - ctx: ctx, - t: t, - event: event.Headered(verRes.RoomVersion), - wg: &wg, - } - tasks = append(tasks, task) - worker.input.push(task) - if worker.running.CAS(false, true) { - go worker.run() - } + + // pass the event to the roomserver which will do auth checks + // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently + // discarded by the caller of this function + _ = api.SendEvents( + context.Background(), + t.rsAPI, + api.KindNew, + []*gomatrixserverlib.HeaderedEvent{ + event.Headered(roomVersion), + }, + t.Origin, + api.DoNotSendToOtherServers, + nil, + false, + ) + + results[event.EventID()] = gomatrixserverlib.PDUResult{} } - t.processEDUs(ctx) wg.Wait() - - for _, task := range tasks { - if task.err != nil { - results[task.event.EventID()] = gomatrixserverlib.PDUResult{ - // Error: task.err.Error(), TODO: this upsets tests if uncommented - } - } else { - results[task.event.EventID()] = gomatrixserverlib.PDUResult{} - } - } - - if c := len(results); c > 0 { - util.GetLogger(ctx).Debugf("Processed %d PDUs from %v in transaction %q", c, t.Origin, t.TransactionID) - } return &gomatrixserverlib.RespSend{PDUs: results}, nil } -func (t *inputWorker) run() { - defer t.running.Store(false) - for { - task, ok := t.input.pop() - if !ok { - return - } - if task == nil { - continue - } - func() { - defer task.wg.Done() - select { - case <-task.ctx.Done(): - task.err = context.DeadlineExceeded - pduCountTotal.WithLabelValues("expired").Inc() - return - default: - evStart := time.Now() - // TODO: Is 5 minutes too long? - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - task.err = task.t.processEvent(ctx, task.event) - cancel() - task.duration = time.Since(evStart) - if err := task.err; err != nil { - switch err.(type) { - case *gomatrixserverlib.NotAllowed: - processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeRejected).Observe( - float64(time.Since(evStart).Nanoseconds()) / 1000., - ) - util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", true).Warn( - "Failed to process incoming federation event, skipping", - ) - task.err = nil // make "rejected" failures silent - default: - processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeFail).Observe( - float64(time.Since(evStart).Nanoseconds()) / 1000., - ) - util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", false).Warn( - "Failed to process incoming federation event, skipping", - ) - } - } else { - pduCountTotal.WithLabelValues("success").Inc() - processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeOK).Observe( - float64(time.Since(evStart).Nanoseconds()) / 1000., - ) - } - } - }() - } -} - func (t *txnReq) processEDUs(ctx context.Context) { for _, e := range t.EDUs { eduCountTotal.Inc() @@ -561,19 +445,3 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli util.GetLogger(ctx).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate") } } - -func (t *txnReq) processEvent(_ context.Context, e *gomatrixserverlib.HeaderedEvent) error { - // pass the event to the roomserver which will do auth checks - // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently - // discarded by the caller of this function - return api.SendEvents( - context.Background(), - t.rsAPI, - api.KindNew, - []*gomatrixserverlib.HeaderedEvent{e}, - t.Origin, - api.DoNotSendToOtherServers, - nil, - false, - ) -}