Simplify /send
endpoint significantly
This commit is contained in:
parent
0f5049279c
commit
8236478dc3
|
@ -33,7 +33,6 @@ import (
|
|||
"github.com/matrix-org/util"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -88,67 +87,7 @@ func init() {
|
|||
)
|
||||
}
|
||||
|
||||
type sendFIFOQueue struct {
|
||||
tasks []*inputTask
|
||||
count int
|
||||
mutex sync.Mutex
|
||||
notifs chan struct{}
|
||||
}
|
||||
|
||||
func newSendFIFOQueue() *sendFIFOQueue {
|
||||
q := &sendFIFOQueue{
|
||||
notifs: make(chan struct{}, 1),
|
||||
}
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *sendFIFOQueue) push(frame *inputTask) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
q.tasks = append(q.tasks, frame)
|
||||
q.count++
|
||||
select {
|
||||
case q.notifs <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// pop returns the first item of the queue, if there is one.
|
||||
// The second return value will indicate if a task was returned.
|
||||
func (q *sendFIFOQueue) pop() (*inputTask, bool) {
|
||||
q.mutex.Lock()
|
||||
defer q.mutex.Unlock()
|
||||
if q.count == 0 {
|
||||
return nil, false
|
||||
}
|
||||
frame := q.tasks[0]
|
||||
q.tasks[0] = nil
|
||||
q.tasks = q.tasks[1:]
|
||||
q.count--
|
||||
if q.count == 0 {
|
||||
// Force a GC of the underlying array, since it might have
|
||||
// grown significantly if the queue was hammered for some reason
|
||||
q.tasks = nil
|
||||
}
|
||||
return frame, true
|
||||
}
|
||||
|
||||
type inputTask struct {
|
||||
ctx context.Context
|
||||
t *txnReq
|
||||
event *gomatrixserverlib.HeaderedEvent
|
||||
wg *sync.WaitGroup
|
||||
err error // written back by worker, only safe to read when all tasks are done
|
||||
duration time.Duration // written back by worker, only safe to read when all tasks are done
|
||||
}
|
||||
|
||||
type inputWorker struct {
|
||||
running atomic.Bool
|
||||
input *sendFIFOQueue
|
||||
}
|
||||
|
||||
var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse
|
||||
var inputWorkers sync.Map // room ID -> *inputWorker
|
||||
|
||||
// Send implements /_matrix/federation/v1/send/{txnID}
|
||||
func Send(
|
||||
|
@ -261,7 +200,6 @@ type txnReq struct {
|
|||
federation txnFederationClient
|
||||
roomsMu *internal.MutexByRoom
|
||||
servers federationAPI.ServersInRoomProvider
|
||||
work string
|
||||
}
|
||||
|
||||
// A subset of FederationClient functionality that txn requires. Useful for testing.
|
||||
|
@ -276,9 +214,28 @@ type txnFederationClient interface {
|
|||
}
|
||||
|
||||
func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) {
|
||||
results := make(map[string]gomatrixserverlib.PDUResult)
|
||||
var wg sync.WaitGroup
|
||||
var tasks []*inputTask
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
t.processEDUs(ctx)
|
||||
}()
|
||||
|
||||
results := make(map[string]gomatrixserverlib.PDUResult)
|
||||
roomVersions := make(map[string]gomatrixserverlib.RoomVersion)
|
||||
getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion {
|
||||
if v, ok := roomVersions[roomID]; ok {
|
||||
return v
|
||||
}
|
||||
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
|
||||
verRes := api.QueryRoomVersionForRoomResponse{}
|
||||
if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID)
|
||||
return ""
|
||||
}
|
||||
roomVersions[roomID] = verRes.RoomVersion
|
||||
return verRes.RoomVersion
|
||||
}
|
||||
|
||||
for _, pdu := range t.PDUs {
|
||||
pduCountTotal.WithLabelValues("total").Inc()
|
||||
|
@ -291,15 +248,8 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
|
|||
// failure in the PDU results
|
||||
continue
|
||||
}
|
||||
verReq := api.QueryRoomVersionForRoomRequest{RoomID: header.RoomID}
|
||||
verRes := api.QueryRoomVersionForRoomResponse{}
|
||||
if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Warn("Transaction: Failed to query room version for room", verReq.RoomID)
|
||||
// We don't know the event ID at this point so we can't return the
|
||||
// failure in the PDU results
|
||||
continue
|
||||
}
|
||||
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, verRes.RoomVersion)
|
||||
roomVersion := getRoomVersion(header.RoomID)
|
||||
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
|
||||
if err != nil {
|
||||
if _, ok := err.(gomatrixserverlib.BadJSONError); ok {
|
||||
// Room version 6 states that homeservers should strictly enforce canonical JSON
|
||||
|
@ -330,96 +280,30 @@ func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.Res
|
|||
}
|
||||
continue
|
||||
}
|
||||
v, _ := inputWorkers.LoadOrStore(event.RoomID(), &inputWorker{
|
||||
input: newSendFIFOQueue(),
|
||||
})
|
||||
worker := v.(*inputWorker)
|
||||
wg.Add(1)
|
||||
task := &inputTask{
|
||||
ctx: ctx,
|
||||
t: t,
|
||||
event: event.Headered(verRes.RoomVersion),
|
||||
wg: &wg,
|
||||
}
|
||||
tasks = append(tasks, task)
|
||||
worker.input.push(task)
|
||||
if worker.running.CAS(false, true) {
|
||||
go worker.run()
|
||||
}
|
||||
|
||||
// pass the event to the roomserver which will do auth checks
|
||||
// If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently
|
||||
// discarded by the caller of this function
|
||||
_ = api.SendEvents(
|
||||
context.Background(),
|
||||
t.rsAPI,
|
||||
api.KindNew,
|
||||
[]*gomatrixserverlib.HeaderedEvent{
|
||||
event.Headered(roomVersion),
|
||||
},
|
||||
t.Origin,
|
||||
api.DoNotSendToOtherServers,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
|
||||
results[event.EventID()] = gomatrixserverlib.PDUResult{}
|
||||
}
|
||||
|
||||
t.processEDUs(ctx)
|
||||
wg.Wait()
|
||||
|
||||
for _, task := range tasks {
|
||||
if task.err != nil {
|
||||
results[task.event.EventID()] = gomatrixserverlib.PDUResult{
|
||||
// Error: task.err.Error(), TODO: this upsets tests if uncommented
|
||||
}
|
||||
} else {
|
||||
results[task.event.EventID()] = gomatrixserverlib.PDUResult{}
|
||||
}
|
||||
}
|
||||
|
||||
if c := len(results); c > 0 {
|
||||
util.GetLogger(ctx).Debugf("Processed %d PDUs from %v in transaction %q", c, t.Origin, t.TransactionID)
|
||||
}
|
||||
return &gomatrixserverlib.RespSend{PDUs: results}, nil
|
||||
}
|
||||
|
||||
func (t *inputWorker) run() {
|
||||
defer t.running.Store(false)
|
||||
for {
|
||||
task, ok := t.input.pop()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if task == nil {
|
||||
continue
|
||||
}
|
||||
func() {
|
||||
defer task.wg.Done()
|
||||
select {
|
||||
case <-task.ctx.Done():
|
||||
task.err = context.DeadlineExceeded
|
||||
pduCountTotal.WithLabelValues("expired").Inc()
|
||||
return
|
||||
default:
|
||||
evStart := time.Now()
|
||||
// TODO: Is 5 minutes too long?
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
|
||||
task.err = task.t.processEvent(ctx, task.event)
|
||||
cancel()
|
||||
task.duration = time.Since(evStart)
|
||||
if err := task.err; err != nil {
|
||||
switch err.(type) {
|
||||
case *gomatrixserverlib.NotAllowed:
|
||||
processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeRejected).Observe(
|
||||
float64(time.Since(evStart).Nanoseconds()) / 1000.,
|
||||
)
|
||||
util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", true).Warn(
|
||||
"Failed to process incoming federation event, skipping",
|
||||
)
|
||||
task.err = nil // make "rejected" failures silent
|
||||
default:
|
||||
processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeFail).Observe(
|
||||
float64(time.Since(evStart).Nanoseconds()) / 1000.,
|
||||
)
|
||||
util.GetLogger(task.ctx).WithError(err).WithField("event_id", task.event.EventID()).WithField("rejected", false).Warn(
|
||||
"Failed to process incoming federation event, skipping",
|
||||
)
|
||||
}
|
||||
} else {
|
||||
pduCountTotal.WithLabelValues("success").Inc()
|
||||
processEventSummary.WithLabelValues(task.t.work, MetricsOutcomeOK).Observe(
|
||||
float64(time.Since(evStart).Nanoseconds()) / 1000.,
|
||||
)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func (t *txnReq) processEDUs(ctx context.Context) {
|
||||
for _, e := range t.EDUs {
|
||||
eduCountTotal.Inc()
|
||||
|
@ -561,19 +445,3 @@ func (t *txnReq) processDeviceListUpdate(ctx context.Context, e gomatrixserverli
|
|||
util.GetLogger(ctx).WithError(inputRes.Error).WithField("user_id", payload.UserID).Error("failed to InputDeviceListUpdate")
|
||||
}
|
||||
}
|
||||
|
||||
func (t *txnReq) processEvent(_ context.Context, e *gomatrixserverlib.HeaderedEvent) error {
|
||||
// pass the event to the roomserver which will do auth checks
|
||||
// If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently
|
||||
// discarded by the caller of this function
|
||||
return api.SendEvents(
|
||||
context.Background(),
|
||||
t.rsAPI,
|
||||
api.KindNew,
|
||||
[]*gomatrixserverlib.HeaderedEvent{e},
|
||||
t.Origin,
|
||||
api.DoNotSendToOtherServers,
|
||||
nil,
|
||||
false,
|
||||
)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue