diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index 7651f7a7f..d6fee0f8d 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -57,11 +57,10 @@ type destinationQueue struct { statistics *statistics.ServerStatistics // statistics about this remote server transactionIDMutex sync.Mutex // protects transactionID transactionID gomatrixserverlib.TransactionID // last transaction ID - notifyPDUs chan *queuedPDU // interrupts idle wait for PDUs that have just been queued - notifyEDUs chan *queuedEDU // interrupts idle wait for EDUs that have just been queued - notifyOverflow chan struct{} // interrupts idle wait for overflowed PDUs/EDUs from the database - pendingPDUs []*queuedPDU // PDUs waiting to be sent, owned by backgroundSender goroutine once started - pendingEDUs []*queuedEDU // EDUs waiting to be sent, owned by backgroundSender goroutine once started + notify chan struct{} // interrupts idle wait for overflowed PDUs/EDUs from the database + pendingPDUs []*queuedPDU // PDUs waiting to be sent + pendingEDUs []*queuedEDU // EDUs waiting to be sent + pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs interruptBackoff chan bool // interrupts backoff } @@ -85,17 +84,23 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re log.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination) return } + // If there's room in memory to hold the event then add it to the + // list. + oq.pendingMutex.Lock() + if len(oq.pendingPDUs) < maxPDUsInMemory { + oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{ + pdu: event, + receipt: receipt, + }) + } + oq.pendingMutex.Unlock() // Check if the destination is blacklisted. If it isn't then wake // up the queue. if !oq.statistics.Blacklisted() { // Wake up the queue if it's asleep. oq.wakeQueueIfNeeded() - // Queue the PDU. select { - case oq.notifyPDUs <- &queuedPDU{ - receipt: receipt, - pdu: event, - }: + case oq.notify <- struct{}{}: default: } } @@ -120,6 +125,16 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share log.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination) return } + // If there's room in memory to hold the event then add it to the + // list. + oq.pendingMutex.Lock() + if len(oq.pendingEDUs) < maxEDUsInMemory { + oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{ + edu: event, + receipt: receipt, + }) + } + oq.pendingMutex.Unlock() // Check if the destination is blacklisted. If it isn't then wake // up the queue. if !oq.statistics.Blacklisted() { @@ -127,10 +142,7 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share oq.wakeQueueIfNeeded() // Queue the EDU. select { - case oq.notifyEDUs <- &queuedEDU{ - receipt: receipt, - edu: event, - }: + case oq.notify <- struct{}{}: default: } } @@ -154,15 +166,14 @@ func (oq *destinationQueue) wakeQueueIfNeeded() { // getPendingFromDatabase will look at the database and see if // there are any persisted events that haven't been sent to this -// destination yet. If so, they will be queued up. This function -// MUST be called from backgroundSend() goroutine ONLY because -// it modifies oq.pendingPDUs/oq.pendingEDUs and they aren't -// mutexed. +// destination yet. If so, they will be queued up. func (oq *destinationQueue) getPendingFromDatabase() { // Check to see if there's anything to do for this server // in the database. retrieved := false ctx := context.Background() + oq.pendingMutex.Lock() + defer oq.pendingMutex.Unlock() if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 { // We have room in memory for some PDUs - let's request no more than that. if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil { @@ -191,7 +202,7 @@ func (oq *destinationQueue) getPendingFromDatabase() { oq.overflowed.Store(false) } if retrieved { - oq.notifyOverflow <- struct{}{} + oq.notify <- struct{}{} } } @@ -214,47 +225,11 @@ func (oq *destinationQueue) backgroundSend() { // If we have nothing to do then wait either for incoming events, or // until we hit an idle timeout. - awaitSelect: select { - case <-oq.notifyOverflow: - // getPendingFromDatabase has woken us up because of pending - // work. - case pdu := <-oq.notifyPDUs: - // We were woken up because there are new PDUs waiting in the - // database. - if len(oq.pendingPDUs) > maxPDUsInMemory { - oq.overflowed.Store(true) - break awaitSelect - } - oq.pendingPDUs = append(oq.pendingPDUs, pdu) - pendingPDULoop: - for i := 1; i < maxPDUsInMemory-len(oq.pendingPDUs); i++ { - select { - case pdu := <-oq.notifyPDUs: - oq.pendingPDUs = append(oq.pendingPDUs, pdu) - default: - break pendingPDULoop - } - } - - case edu := <-oq.notifyEDUs: - // We were woken up because there are new PDUs waiting in the - // database. - if len(oq.pendingEDUs) > maxEDUsInMemory { - oq.overflowed.Store(true) - break awaitSelect - } - oq.pendingEDUs = append(oq.pendingEDUs, edu) - pendingEDULoop: - for i := 1; i < maxEDUsInMemory-len(oq.pendingEDUs); i++ { - select { - case edu := <-oq.notifyEDUs: - oq.pendingEDUs = append(oq.pendingEDUs, edu) - default: - break pendingEDULoop - } - } - + case <-oq.notify: + // There's work to do, either because getPendingFromDatabase + // told us there is, or because a new event has come in via + // sendEvent/sendEDU. case <-time.After(queueIdleTimeout): // The worker is idle so stop the goroutine. It'll get // restarted automatically the next time we have an event to @@ -272,6 +247,7 @@ func (oq *destinationQueue) backgroundSend() { // has exceeded a maximum allowable value. Clean up the in-memory // buffers at this point. The PDU clean-up is already on a defer. log.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination) + oq.pendingMutex.Lock() for i := range oq.pendingPDUs { oq.pendingPDUs[i] = nil } @@ -280,6 +256,7 @@ func (oq *destinationQueue) backgroundSend() { } oq.pendingPDUs = nil oq.pendingEDUs = nil + oq.pendingMutex.Lock() return } if until != nil && until.After(time.Now()) { @@ -293,6 +270,7 @@ func (oq *destinationQueue) backgroundSend() { } } + oq.pendingMutex.RLock() pduCount := len(oq.pendingPDUs) eduCount := len(oq.pendingEDUs) if pduCount > maxPDUsPerTransaction { @@ -305,13 +283,16 @@ func (oq *destinationQueue) backgroundSend() { // If we have pending PDUs or EDUs then construct a transaction. // Try sending the next transaction and see what happens. transaction, pc, ec, terr := oq.nextTransaction(oq.pendingPDUs[:pduCount], oq.pendingEDUs[:eduCount]) + oq.pendingMutex.RUnlock() if terr != nil { // We failed to send the transaction. Mark it as a failure. oq.statistics.Failure() + } else if transaction { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. oq.statistics.Success() + oq.pendingMutex.Lock() for i := range oq.pendingPDUs { oq.pendingPDUs[i] = nil } @@ -320,6 +301,7 @@ func (oq *destinationQueue) backgroundSend() { } oq.pendingPDUs = oq.pendingPDUs[pc:] oq.pendingEDUs = oq.pendingEDUs[ec:] + oq.pendingMutex.Unlock() } } } diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index 5ec2ccd29..048bba301 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -123,9 +123,7 @@ func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *d destination: destination, client: oqs.client, statistics: oqs.statistics.ForServer(destination), - notifyPDUs: make(chan *queuedPDU, 16), - notifyEDUs: make(chan *queuedEDU, 16), - notifyOverflow: make(chan struct{}, 1), + notify: make(chan struct{}, 1), interruptBackoff: make(chan bool), signing: oqs.signing, }