diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index c629e469a..b6c3aa23a 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "sync" "time" "github.com/matrix-org/dendrite/federationsender/storage" @@ -48,6 +49,7 @@ type destinationQueue struct { statistics *types.ServerStatistics // statistics about this remote server incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send + transactionMutex sync.Mutex // protects transactionID and transactionCount transactionID gomatrixserverlib.TransactionID // last transaction ID transactionCount int // how many events in this transaction so far pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent @@ -93,11 +95,13 @@ func (oq *destinationQueue) sendEvent(nid int64) { // one made up yet, or if we've exceeded the number of maximum // events allowed in a single tranaction. We'll reset the counter // when we do. + oq.transactionMutex.Lock() if oq.transactionID == "" || oq.transactionCount >= maxPDUsPerTransaction { now := gomatrixserverlib.AsTimestamp(time.Now()) oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) oq.transactionCount = 0 } + oq.transactionMutex.Unlock() // Create a database entry that associates the given PDU NID with // this destination queue. We'll then be able to retrieve the PDU // later. @@ -164,12 +168,6 @@ func (oq *destinationQueue) backgroundSend() { defer oq.running.Store(false) for { - // For now we don't know the next transaction ID. Set it to an - // empty one. The next step will populate it if we have pending - // PDUs in the database. Otherwise we'll generate one later on, - // e.g. in response to EDUs. 
- transactionID := gomatrixserverlib.TransactionID("") - // If we have nothing to do then wait either for incoming events, or // until we hit an idle timeout. if oq.pendingPDUs.Load() == 0 && len(oq.pendingEDUs) == 0 && len(oq.pendingInvites) == 0 { @@ -228,17 +226,8 @@ func (oq *destinationQueue) backgroundSend() { // If we have pending PDUs or EDUs then construct a transaction. if oq.pendingPDUs.Load() > 0 || len(oq.pendingEDUs) > 0 { - // If we haven't got a transaction ID then we should generate - // one. Ideally we'd know this already because something queued - // in the database would give us one, but if we're dealing with - // EDUs alone, we won't go via the database so we'll make one. - if transactionID == "" { - now := gomatrixserverlib.AsTimestamp(time.Now()) - transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) - } - // Try sending the next transaction and see what happens. - transaction, terr := oq.nextTransaction(transactionID, oq.pendingEDUs) + transaction, terr := oq.nextTransaction(oq.pendingEDUs) if terr != nil { // We failed to send the transaction. if giveUp := oq.statistics.Failure(); giveUp { @@ -253,7 +242,6 @@ func (oq *destinationQueue) backgroundSend() { // If we successfully sent the transaction then clear out // the pending events and EDUs, and wipe our transaction ID. oq.statistics.Success() - oq.transactionID = "" // Clean up the in-memory buffers. oq.cleanPendingEDUs() } @@ -305,9 +293,19 @@ func (oq *destinationQueue) cleanPendingInvites() { // queue and sends it. Returns true if a transaction was sent or // false otherwise. func (oq *destinationQueue) nextTransaction( - transactionID gomatrixserverlib.TransactionID, pendingEDUs []*gomatrixserverlib.EDU, ) (bool, error) { + // Before we do anything, we need to roll over the transaction + // ID that is being used to coalesce events into the next TX. 
+ // Otherwise it's possible that we'll pick up an incomplete + // transaction and end up nuking the rest of the events at the + // cleanup stage. + oq.transactionMutex.Lock() + oq.transactionID = "" + oq.transactionCount = 0 + oq.transactionMutex.Unlock() + + // Create the transaction. t := gomatrixserverlib.Transaction{ PDUs: []json.RawMessage{}, EDUs: []gomatrixserverlib.EDU{}, @@ -316,42 +314,50 @@ func (oq *destinationQueue) nextTransaction( t.Destination = oq.destination t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) + // Ask the database for any pending PDUs from the next transaction. + // maxPDUsPerTransaction is an upper limit but we probably won't + // actually retrieve that many events. txid, pdus, err := oq.db.GetNextTransactionPDUs( context.TODO(), // context oq.destination, // server name - maxPDUsPerTransaction, // how many events to retrieve + maxPDUsPerTransaction, // max events to retrieve ) if err != nil { log.WithError(err).Errorf("failed to get next transaction PDUs for server %q", oq.destination) return false, fmt.Errorf("oq.db.GetNextTransactionPDUs: %w", err) } + // If we didn't get anything from the database and there are no + // pending EDUs then there's nothing to do - stop here. if len(pdus) == 0 && len(pendingEDUs) == 0 { return false, nil } - if txid != "" { - // The database supplied us with a transaction ID to use - // from a failed PDU so use that. - t.TransactionID = txid - } else { - // Otherwise, use the one that the function call gave us. - // This would happen if it's EDUs only. - t.TransactionID = transactionID + // Pick out the transaction ID from the database. If we didn't + // get a transaction ID (i.e. because there are no PDUs but only + // EDUs) then generate a transaction ID. 
+ t.TransactionID = txid + if t.TransactionID == "" { + now := gomatrixserverlib.AsTimestamp(time.Now()) + t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) } + // Go through PDUs that we retrieved from the database, if any, + // and add them into the transaction. for _, pdu := range pdus { // Append the JSON of the event, since this is a json.RawMessage type in the // gomatrixserverlib.Transaction struct t.PDUs = append(t.PDUs, (*pdu).JSON()) } + // Do the same for pending EDUs in the queue. for _, edu := range pendingEDUs { t.EDUs = append(t.EDUs, *edu) } logrus.WithField("server_name", oq.destination).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) + // Try to send the transaction to the destination server. // TODO: we should check for 500-ish fails vs 400-ish here, // since we shouldn't queue things indefinitely in response // to a 400-ish error @@ -367,7 +373,7 @@ func (oq *destinationQueue) nextTransaction( t.Destination, t.TransactionID, ); err != nil { - log.WithError(err).Errorf("failed to clean transaction %q for server %q", transactionID, oq.destination) + log.WithError(err).Errorf("failed to clean transaction %q for server %q", t.TransactionID, t.Destination) } return true, nil case gomatrix.HTTPError: