Simplify /send endpoint significantly

This commit is contained in:
Neil Alexander 2022-01-17 11:40:42 +00:00
parent 0f5049279c
commit 8236478dc3
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -33,7 +33,6 @@ import (
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"go.uber.org/atomic"
) )
const ( const (
@ -88,67 +87,7 @@ func init() {
) )
} }
type sendFIFOQueue struct {
tasks []*inputTask
count int
mutex sync.Mutex
notifs chan struct{}
}
func newSendFIFOQueue() *sendFIFOQueue {
q := &sendFIFOQueue{
notifs: make(chan struct{}, 1),
}
return q
}
func (q *sendFIFOQueue) push(frame *inputTask) {
q.mutex.Lock()
defer q.mutex.Unlock()
q.tasks = append(q.tasks, frame)
q.count++
select {
case q.notifs <- struct{}{}:
default:
}
}
// pop returns the first item of the queue, if there is one.
// The second return value will indicate if a task was returned.
func (q *sendFIFOQueue) pop() (*inputTask, bool) {
q.mutex.Lock()
defer q.mutex.Unlock()
if q.count == 0 {
return nil, false
}
frame := q.tasks[0]
q.tasks[0] = nil
q.tasks = q.tasks[1:]
q.count--
if q.count == 0 {
// Force a GC of the underlying array, since it might have
// grown significantly if the queue was hammered for some reason
q.tasks = nil
}
return frame, true
}
type inputTask struct {
ctx context.Context
t *txnReq
event *gomatrixserverlib.HeaderedEvent
wg *sync.WaitGroup
err error // written back by worker, only safe to read when all tasks are done
duration time.Duration // written back by worker, only safe to read when all tasks are done
}
type inputWorker struct {
running atomic.Bool
input *sendFIFOQueue
}
var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse
var inputWorkers sync.Map // room ID -> *inputWorker
// Send implements /_matrix/federation/v1/send/{txnID} // Send implements /_matrix/federation/v1/send/{txnID}
func Send( func Send(
@ -261,7 +200,6 @@ type txnReq struct {
federation txnFederationClient federation txnFederationClient
roomsMu *internal.MutexByRoom roomsMu *internal.MutexByRoom
servers federationAPI.ServersInRoomProvider servers federationAPI.ServersInRoomProvider
work string
} }
// A subset of FederationClient functionality that txn requires. Useful for testing. // A subset of FederationClient functionality that txn requires. Useful for testing.
@ -276,9 +214,28 @@ type txnFederationClient interface {
} }
func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) { func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) {
results := make(map[string]gomatrixserverlib.PDUResult)
var wg sync.WaitGroup var wg sync.WaitGroup
var tasks []*inputTask wg.Add(1)
go func() {
defer wg.Done()
t.processEDUs(ctx)
}()
results := make(map[string]gomatrixserverlib.PDUResult)
roomVersions := make(map[string]gomatrixserverlib.RoomVersion)
getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion {
if v, ok := roomVersions[roomID]; ok {
return v
}
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID)
return ""
}
roomVersions[roomID] = verRes.RoomVersion
return verRes.RoomVersion
}
for _, pdu := range t.PDUs { for _, pdu := range t.PDUs {
pduCountTotal.WithLabelValues("total").Inc() pduCountTotal.WithLabelValues("total").Inc()
@ -291,15 +248,8 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
// failure in the PDU results // failure in the PDU results
continue continue
} }
verReq := api.QueryRoomVersionForRoomRequest{RoomID: header.RoomID} roomVersion := getRoomVersion(header.RoomID)
verRes := api.QueryRoomVersionForRoomResponse{} event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID)
// We don't know the event ID at this point so we can't return the
// failure in the PDU results
continue
}
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, verRes.RoomVersion)
if err != nil { if err != nil {
if _, ok := err.(gomatrixserverlib.BadJSONError); ok { if _, ok := err.(gomatrixserverlib.BadJSONError); ok {
// Room version 6 states that homeservers should strictly enforce canonical JSON // Room version 6 states that homeservers should strictly enforce canonical JSON
@ -330,96 +280,30 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
} }
continue continue
} }
v, _ := inputWorkers.LoadOrStore(event.RoomID(), &inputWorker{
input: newSendFIFOQueue(), // pass the event to the roomserver which will do auth checks
}) // If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently
worker := v.(*inputWorker) // discarded by the caller of this function
wg.Add(1) _ = api.SendEvents(
task := &inputTask{ context.Background(),
ctx: ctx, t.rsAPI,
t: t, api.KindNew,
event: event.Headered(verRes.RoomVersion), []*gomatrixserverlib.HeaderedEvent{
wg: &wg, event.Headered(roomVersion),
} },
tasks = append(tasks, task) t.Origin,
worker.input.push(task) api.DoNotSendToOtherServers,
if worker.running.CAS(false, true) { nil,
go worker.run() false,
} )
results[event.EventID()] = gomatrixserverlib.PDUResult{}
} }
t.processEDUs(ctx)
wg.Wait() wg.Wait()
for _, task := range tasks {
if task.err != nil {
results[task.event.EventID()] = gomatrixserverlib.PDUResult{
// Error: task.err.Error(), TODO: this upsets tests if uncommented
}
} else {
results[task.event.EventID()] = gomatrixserverlib.PDUResult{}
}
}
if c := len(results); c > 0 {
util.GetLogger(ctx).Debugf("Processed %d PDUs from %v in transaction %q", c, t.Origin, t.TransactionID)
}
return &gomatrixserverlib.RespSend{PDUs: results}, nil return &gomatrixserverlib.RespSend{PDUs: results}, nil
} }
func (t *inputWorker) run() {
defer t.running.Store(false)
for {
task, ok := t.input.pop()
if !ok {
return
}
if task == nil {
continue
}
func() {
defer task.wg.Done()
select {
case <-task.ctx.Done():
task.err = context.DeadlineExceeded
pduCountTotal.WithLabelValues("expired").Inc()
return
default:
evStart := time.Now()
// TODO: Is 5 minutes too long?
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
task.err = task.t.processEvent(ctx, task.event)
cancel()
task.duration = time.Since(evStart)
if err := task.err; err != nil {
switch err.(type) {
case *gomatrixserverlib.NotAllowed:
processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeRejected).Observe(
float64(time.Since(evStart).Nanoseconds()) / 1000.,
)
util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", true).Warn(
"Failed to process incoming federation event, skipping",
)
task.err = nil // make "rejected" failures silent
default:
processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeFail).Observe(
float64(time.Since(evStart).Nanoseconds()) / 1000.,
)
util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", false).Warn(
"Failed to process incoming federation event, skipping",
)
}
} else {
pduCountTotal.WithLabelValues("success").Inc()
processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeOK).Observe(
float64(time.Since(evStart).Nanoseconds()) / 1000.,
)
}
}
}()
}
}
func (t *txnReq) processEDUs(ctx context.Context) { func (t *txnReq) processEDUs(ctx context.Context) {
for _, e := range t.EDUs { for _, e := range t.EDUs {
eduCountTotal.Inc() eduCountTotal.Inc()
@ -561,19 +445,3 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
util.GetLogger(ctx).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate") util.GetLogger(ctx).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate")
} }
} }
func (t *txnReq) processEvent(_ context.Context, e *gomatrixserverlib.HeaderedEvent) error {
// pass the event to the roomserver which will do auth checks
// If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently
// discarded by the caller of this function
return api.SendEvents(
context.Background(),
t.rsAPI,
api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{e},
t.Origin,
api.DoNotSendToOtherServers,
nil,
false,
)
}