Fix race in incomplete transactions

This commit is contained in:
Neil Alexander 2020-07-01 11:14:52 +01:00
parent cfbdff3c32
commit d6c94064af

View file

@ -18,6 +18,7 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/matrix-org/dendrite/federationsender/storage"
@ -48,6 +49,7 @@ type destinationQueue struct {
statistics *types.ServerStatistics // statistics about this remote server
incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send
incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send
transactionMutex sync.Mutex // protects transactionID and transactionCount
transactionID gomatrixserverlib.TransactionID // last transaction ID
transactionCount int // how many events in this transaction so far
pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent
@ -93,11 +95,13 @@ func (oq *destinationQueue) sendEvent(nid int64) {
// one made up yet, or if we've exceeded the number of maximum
// events allowed in a single transaction. We'll reset the counter
// when we do.
oq.transactionMutex.Lock()
if oq.transactionID == "" || oq.transactionCount >= maxPDUsPerTransaction {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
oq.transactionCount = 0
}
oq.transactionMutex.Unlock()
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
@ -164,12 +168,6 @@ func (oq *destinationQueue) backgroundSend() {
defer oq.running.Store(false)
for {
// For now we don't know the next transaction ID. Set it to an
// empty one. The next step will populate it if we have pending
// PDUs in the database. Otherwise we'll generate one later on,
// e.g. in response to EDUs.
transactionID := gomatrixserverlib.TransactionID("")
// If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout.
if oq.pendingPDUs.Load() == 0 && len(oq.pendingEDUs) == 0 && len(oq.pendingInvites) == 0 {
@ -228,17 +226,8 @@ func (oq *destinationQueue) backgroundSend() {
// If we have pending PDUs or EDUs then construct a transaction.
if oq.pendingPDUs.Load() > 0 || len(oq.pendingEDUs) > 0 {
// If we haven't got a transaction ID then we should generate
// one. Ideally we'd know this already because something queued
// in the database would give us one, but if we're dealing with
// EDUs alone, we won't go via the database so we'll make one.
if transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
// Try sending the next transaction and see what happens.
transaction, terr := oq.nextTransaction(transactionID, oq.pendingEDUs)
transaction, terr := oq.nextTransaction(oq.pendingEDUs)
if terr != nil {
// We failed to send the transaction.
if giveUp := oq.statistics.Failure(); giveUp {
@ -253,7 +242,6 @@ func (oq *destinationQueue) backgroundSend() {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success()
oq.transactionID = ""
// Clean up the in-memory buffers.
oq.cleanPendingEDUs()
}
@ -305,9 +293,19 @@ func (oq *destinationQueue) cleanPendingInvites() {
// queue and sends it. Returns true if a transaction was sent or
// false otherwise.
func (oq *destinationQueue) nextTransaction(
transactionID gomatrixserverlib.TransactionID,
pendingEDUs []*gomatrixserverlib.EDU,
) (bool, error) {
// Before we do anything, we need to roll over the transaction
// ID that is being used to coalesce events into the next TX.
// Otherwise it's possible that we'll pick up an incomplete
// transaction and end up nuking the rest of the events at the
// cleanup stage.
oq.transactionMutex.Lock()
oq.transactionID = ""
oq.transactionCount = 0
oq.transactionMutex.Unlock()
// Create the transaction.
t := gomatrixserverlib.Transaction{
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
@ -316,42 +314,50 @@ func (oq *destinationQueue) nextTransaction(
t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
// Ask the database for any pending PDUs from the next transaction.
// maxPDUsPerTransaction is an upper limit but we probably won't
// actually retrieve that many events.
txid, pdus, err := oq.db.GetNextTransactionPDUs(
context.TODO(), // context
oq.destination, // server name
maxPDUsPerTransaction, // how many events to retrieve
maxPDUsPerTransaction, // max events to retrieve
)
if err != nil {
log.WithError(err).Errorf("failed to get next transaction PDUs for server %q", oq.destination)
return false, fmt.Errorf("oq.db.GetNextTransactionPDUs: %w", err)
}
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if len(pdus) == 0 && len(pendingEDUs) == 0 {
return false, nil
}
if txid != "" {
// The database supplied us with a transaction ID to use
// from a failed PDU so use that.
// Pick out the transaction ID from the database. If we didn't
// get a transaction ID (i.e. because there are no PDUs but only
// EDUs) then generate a transaction ID.
t.TransactionID = txid
} else {
// Otherwise, use the one that the function call gave us.
// This would happen if it's EDUs only.
t.TransactionID = transactionID
if t.TransactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, (*pdu).JSON())
}
// Do the same for pending EDUs in the queue.
for _, edu := range pendingEDUs {
t.EDUs = append(t.EDUs, *edu)
}
logrus.WithField("server_name", oq.destination).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
// Try to send the transaction to the destination server.
// TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response
// to a 400-ish error
@ -367,7 +373,7 @@ func (oq *destinationQueue) nextTransaction(
t.Destination,
t.TransactionID,
); err != nil {
log.WithError(err).Errorf("failed to clean transaction %q for server %q", transactionID, oq.destination)
log.WithError(err).Errorf("failed to clean transaction %q for server %q", t.TransactionID, t.Destination)
}
return true, nil
case gomatrix.HTTPError: