diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index b0dc54bf4..c1ec40ba9 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -57,9 +57,9 @@ type destinationQueue struct { statistics *statistics.ServerStatistics // statistics about this remote server transactionIDMutex sync.Mutex // protects transactionID transactionID gomatrixserverlib.TransactionID // last transaction ID - transactionCount atomic.Int32 // how many events in this transaction so far notifyPDUs chan *queuedPDU // interrupts idle wait for PDUs notifyEDUs chan *queuedEDU // interrupts idle wait for EDUs + notifyOverflow chan struct{} // interrupts idle wait for overflowed PDUs/EDUs pendingPDUs []*queuedPDU // owned by backgroundSender goroutine once started pendingEDUs []*queuedEDU // owned by backgroundSender goroutine once started interruptBackoff chan bool // interrupts backoff @@ -73,32 +73,18 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re log.Errorf("attempt to send nil PDU with destination %q", oq.destination) return } - // Create a transaction ID. We'll either do this if we don't have - // one made up yet, or if we've exceeded the number of maximum - // events allowed in a single tranaction. We'll reset the counter - // when we do. - oq.transactionIDMutex.Lock() - if oq.transactionID == "" || oq.transactionCount.Load() >= maxPDUsPerTransaction { - now := gomatrixserverlib.AsTimestamp(time.Now()) - oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) - oq.transactionCount.Store(0) - } - oq.transactionIDMutex.Unlock() // Create a database entry that associates the given PDU NID with // this destination queue. We'll then be able to retrieve the PDU // later. if err := oq.db.AssociatePDUWithDestination( context.TODO(), - oq.transactionID, // the current transaction ID - oq.destination, // the destination server name - receipt, // NIDs from federationsender_queue_json table + "", // the current transaction ID, TODO: do something about this + oq.destination, // the destination server name + receipt, // NIDs from federationsender_queue_json table ); err != nil { log.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination) return } - // We've successfully added a PDU to the transaction so increase - // the counter. - oq.transactionCount.Add(1) // Check if the destination is blacklisted. If it isn't then wake // up the queue. if !oq.statistics.Blacklisted() { @@ -135,9 +121,6 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share log.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination) return } - // We've successfully added an EDU to the transaction so increase - // the counter. - oq.transactionCount.Add(1) // Check if the destination is blacklisted. If it isn't then wake // up the queue. if !oq.statistics.Blacklisted() { @@ -177,12 +160,14 @@ func (oq *destinationQueue) wakeQueueIfNeeded() { func (oq *destinationQueue) getPendingFromDatabase() { // Check to see if there's anything to do for this server // in the database. + retrieved := false ctx := context.Background() if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 { // We have room in memory for some PDUs - let's request no more than that. if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil { for receipt, pdu := range pdus { oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu}) + retrieved = true } } else { logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination) @@ -193,6 +178,7 @@ func (oq *destinationQueue) getPendingFromDatabase() { if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil { for receipt, edu := range edus { oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu}) + retrieved = true } } else { logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination) @@ -203,6 +189,9 @@ func (oq *destinationQueue) getPendingFromDatabase() { if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory { oq.overflowed.Store(false) } + if retrieved { + oq.notifyOverflow <- struct{}{} + } } // backgroundSend is the worker goroutine for sending events. @@ -226,6 +215,9 @@ func (oq *destinationQueue) backgroundSend() { // until we hit an idle timeout. awaitSelect: select { + case <-oq.notifyOverflow: + // getPendingFromDatabase has woken us up because of pending + // work. case pdu := <-oq.notifyPDUs: // We were woken up because there are new PDUs waiting in the // database. @@ -339,15 +331,16 @@ func (oq *destinationQueue) nextTransaction( pdus []*queuedPDU, edus []*queuedEDU, ) (bool, int, int, error) { - // Before we do anything, we need to roll over the transaction - // ID that is being used to coalesce events into the next TX. - // Otherwise it's possible that we'll pick up an incomplete - // transaction and end up nuking the rest of the events at the - // cleanup stage. + // If there's no projected transaction ID then generate one. If + // the transaction succeeds then we'll set it back to "" so that + // we generate a new one next time. If it fails, we'll preserve + // it so that we retry with the same transaction ID. oq.transactionIDMutex.Lock() - oq.transactionID = "" + if oq.transactionID == "" { + now := gomatrixserverlib.AsTimestamp(time.Now()) + oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) + } oq.transactionIDMutex.Unlock() - oq.transactionCount.Store(0) // Create the transaction. t := gomatrixserverlib.Transaction{ @@ -357,6 +350,7 @@ func (oq *destinationQueue) nextTransaction( t.Origin = oq.origin t.Destination = oq.destination t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) + t.TransactionID = oq.transactionID // If we didn't get anything from the database and there are no // pending EDUs then there's nothing to do - stop here. @@ -364,14 +358,6 @@ func (oq *destinationQueue) nextTransaction( return false, 0, 0, nil } - // Pick out the transaction ID from the database. If we didn't - // get a transaction ID (i.e. because there are no PDUs but only - // EDUs) then generate a transaction ID. - if t.TransactionID == "" { - now := gomatrixserverlib.AsTimestamp(time.Now()) - t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) - } - var pduReceipts []*shared.Receipt var eduReceipts []*shared.Receipt @@ -420,6 +406,10 @@ func (oq *destinationQueue) nextTransaction( log.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination) } } + // Reset the transaction ID. + oq.transactionIDMutex.Lock() + oq.transactionID = "" + oq.transactionIDMutex.Unlock() return true, len(t.PDUs), len(t.EDUs), nil case gomatrix.HTTPError: // Report that we failed to send the transaction and we diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index 96c173687..5ec2ccd29 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -125,6 +125,7 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d statistics: oqs.statistics.ForServer(destination), notifyPDUs: make(chan *queuedPDU, 16), notifyEDUs: make(chan *queuedEDU, 16), + notifyOverflow: make(chan struct{}, 1), interruptBackoff: make(chan bool), signing: oqs.signing, }