Fix transaction coalescing

This commit is contained in:
Neil Alexander 2020-12-07 09:54:51 +00:00
parent 91ac6e9880
commit 26cc2d480b
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
2 changed files with 27 additions and 36 deletions

View file

@ -57,9 +57,9 @@ type destinationQueue struct {
statistics *statistics.ServerStatistics // statistics about this remote server statistics *statistics.ServerStatistics // statistics about this remote server
transactionIDMutex sync.Mutex // protects transactionID transactionIDMutex sync.Mutex // protects transactionID
transactionID gomatrixserverlib.TransactionID // last transaction ID transactionID gomatrixserverlib.TransactionID // last transaction ID
transactionCount atomic.Int32 // how many events in this transaction so far
notifyPDUs chan *queuedPDU // interrupts idle wait for PDUs notifyPDUs chan *queuedPDU // interrupts idle wait for PDUs
notifyEDUs chan *queuedEDU // interrupts idle wait for EDUs notifyEDUs chan *queuedEDU // interrupts idle wait for EDUs
notifyOverflow chan struct{} // interrupts idle wait for overflowed PDUs/EDUs
pendingPDUs []*queuedPDU // owned by backgroundSender goroutine once started pendingPDUs []*queuedPDU // owned by backgroundSender goroutine once started
pendingEDUs []*queuedEDU // owned by backgroundSender goroutine once started pendingEDUs []*queuedEDU // owned by backgroundSender goroutine once started
interruptBackoff chan bool // interrupts backoff interruptBackoff chan bool // interrupts backoff
@ -73,32 +73,18 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
log.Errorf("attempt to send nil PDU with destination %q", oq.destination) log.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return return
} }
// Create a transaction ID. We'll either do this if we don't have
// one made up yet, or if we've exceeded the number of maximum
// events allowed in a single tranaction. We'll reset the counter
// when we do.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" || oq.transactionCount.Load() >= maxPDUsPerTransaction {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
oq.transactionCount.Store(0)
}
oq.transactionIDMutex.Unlock()
// Create a database entry that associates the given PDU NID with // Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU // this destination queue. We'll then be able to retrieve the PDU
// later. // later.
if err := oq.db.AssociatePDUWithDestination( if err := oq.db.AssociatePDUWithDestination(
context.TODO(), context.TODO(),
oq.transactionID, // the current transaction ID "", // the current transaction ID, TODO: do something about this
oq.destination, // the destination server name oq.destination, // the destination server name
receipt, // NIDs from federationsender_queue_json table receipt, // NIDs from federationsender_queue_json table
); err != nil { ); err != nil {
log.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination) log.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination)
return return
} }
// We've successfully added a PDU to the transaction so increase
// the counter.
oq.transactionCount.Add(1)
// Check if the destination is blacklisted. If it isn't then wake // Check if the destination is blacklisted. If it isn't then wake
// up the queue. // up the queue.
if !oq.statistics.Blacklisted() { if !oq.statistics.Blacklisted() {
@ -135,9 +121,6 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
log.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination) log.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
return return
} }
// We've successfully added an EDU to the transaction so increase
// the counter.
oq.transactionCount.Add(1)
// Check if the destination is blacklisted. If it isn't then wake // Check if the destination is blacklisted. If it isn't then wake
// up the queue. // up the queue.
if !oq.statistics.Blacklisted() { if !oq.statistics.Blacklisted() {
@ -177,12 +160,14 @@ func (oq *destinationQueue) wakeQueueIfNeeded() {
func (oq *destinationQueue) getPendingFromDatabase() { func (oq *destinationQueue) getPendingFromDatabase() {
// Check to see if there's anything to do for this server // Check to see if there's anything to do for this server
// in the database. // in the database.
retrieved := false
ctx := context.Background() ctx := context.Background()
if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 { if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
// We have room in memory for some PDUs - let's request no more than that. // We have room in memory for some PDUs - let's request no more than that.
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil { if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil {
for receipt, pdu := range pdus { for receipt, pdu := range pdus {
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu}) oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
retrieved = true
} }
} else { } else {
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination) logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
@ -193,6 +178,7 @@ func (oq *destinationQueue) getPendingFromDatabase() {
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil { if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil {
for receipt, edu := range edus { for receipt, edu := range edus {
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu}) oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
retrieved = true
} }
} else { } else {
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination) logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
@ -203,6 +189,9 @@ func (oq *destinationQueue) getPendingFromDatabase() {
if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory { if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory {
oq.overflowed.Store(false) oq.overflowed.Store(false)
} }
if retrieved {
oq.notifyOverflow <- struct{}{}
}
} }
// backgroundSend is the worker goroutine for sending events. // backgroundSend is the worker goroutine for sending events.
@ -226,6 +215,9 @@ func (oq *destinationQueue) backgroundSend() {
// until we hit an idle timeout. // until we hit an idle timeout.
awaitSelect: awaitSelect:
select { select {
case <-oq.notifyOverflow:
// getPendingFromDatabase has woken us up because of pending
// work.
case pdu := <-oq.notifyPDUs: case pdu := <-oq.notifyPDUs:
// We were woken up because there are new PDUs waiting in the // We were woken up because there are new PDUs waiting in the
// database. // database.
@ -339,15 +331,16 @@ func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU, pdus []*queuedPDU,
edus []*queuedEDU, edus []*queuedEDU,
) (bool, int, int, error) { ) (bool, int, int, error) {
// Before we do anything, we need to roll over the transaction // If there's no projected transaction ID then generate one. If
// ID that is being used to coalesce events into the next TX. // the transaction succeeds then we'll set it back to "" so that
// Otherwise it's possible that we'll pick up an incomplete // we generate a new one next time. If it fails, we'll preserve
// transaction and end up nuking the rest of the events at the // it so that we retry with the same transaction ID.
// cleanup stage.
oq.transactionIDMutex.Lock() oq.transactionIDMutex.Lock()
oq.transactionID = "" if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock() oq.transactionIDMutex.Unlock()
oq.transactionCount.Store(0)
// Create the transaction. // Create the transaction.
t := gomatrixserverlib.Transaction{ t := gomatrixserverlib.Transaction{
@ -357,6 +350,7 @@ func (oq *destinationQueue) nextTransaction(
t.Origin = oq.origin t.Origin = oq.origin
t.Destination = oq.destination t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
// If we didn't get anything from the database and there are no // If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here. // pending EDUs then there's nothing to do - stop here.
@ -364,14 +358,6 @@ func (oq *destinationQueue) nextTransaction(
return false, 0, 0, nil return false, 0, 0, nil
} }
// Pick out the transaction ID from the database. If we didn't
// get a transaction ID (i.e. because there are no PDUs but only
// EDUs) then generate a transaction ID.
if t.TransactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
var pduReceipts []*shared.Receipt var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt var eduReceipts []*shared.Receipt
@ -420,6 +406,10 @@ func (oq *destinationQueue) nextTransaction(
log.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination) log.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
} }
} }
// Reset the transaction ID.
oq.transactionIDMutex.Lock()
oq.transactionID = ""
oq.transactionIDMutex.Unlock()
return true, len(t.PDUs), len(t.EDUs), nil return true, len(t.PDUs), len(t.EDUs), nil
case gomatrix.HTTPError: case gomatrix.HTTPError:
// Report that we failed to send the transaction and we // Report that we failed to send the transaction and we

View file

@ -125,6 +125,7 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d
statistics: oqs.statistics.ForServer(destination), statistics: oqs.statistics.ForServer(destination),
notifyPDUs: make(chan *queuedPDU, 16), notifyPDUs: make(chan *queuedPDU, 16),
notifyEDUs: make(chan *queuedEDU, 16), notifyEDUs: make(chan *queuedEDU, 16),
notifyOverflow: make(chan struct{}, 1),
interruptBackoff: make(chan bool), interruptBackoff: make(chan bool),
signing: oqs.signing, signing: oqs.signing,
} }