Fix race in incomplete transactions

This commit is contained in:
Neil Alexander 2020-07-01 11:14:52 +01:00
parent cfbdff3c32
commit d6c94064af

View file

@ -18,6 +18,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sync"
"time" "time"
"github.com/matrix-org/dendrite/federationsender/storage" "github.com/matrix-org/dendrite/federationsender/storage"
@ -48,6 +49,7 @@ type destinationQueue struct {
statistics *types.ServerStatistics // statistics about this remote server statistics *types.ServerStatistics // statistics about this remote server
incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send
incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send
transactionMutex sync.Mutex // protects transactionID and transactionCount
transactionID gomatrixserverlib.TransactionID // last transaction ID transactionID gomatrixserverlib.TransactionID // last transaction ID
transactionCount int // how many events in this transaction so far transactionCount int // how many events in this transaction so far
pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent
@ -93,11 +95,13 @@ func (oq *destinationQueue) sendEvent(nid int64) {
// one made up yet, or if we've exceeded the number of maximum // one made up yet, or if we've exceeded the number of maximum
// events allowed in a single transaction. We'll reset the counter // events allowed in a single transaction. We'll reset the counter
// when we do. // when we do.
oq.transactionMutex.Lock()
if oq.transactionID == "" || oq.transactionCount >= maxPDUsPerTransaction { if oq.transactionID == "" || oq.transactionCount >= maxPDUsPerTransaction {
now := gomatrixserverlib.AsTimestamp(time.Now()) now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
oq.transactionCount = 0 oq.transactionCount = 0
} }
oq.transactionMutex.Unlock()
// Create a database entry that associates the given PDU NID with // Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU // this destination queue. We'll then be able to retrieve the PDU
// later. // later.
@ -164,12 +168,6 @@ func (oq *destinationQueue) backgroundSend() {
defer oq.running.Store(false) defer oq.running.Store(false)
for { for {
// For now we don't know the next transaction ID. Set it to an
// empty one. The next step will populate it if we have pending
// PDUs in the database. Otherwise we'll generate one later on,
// e.g. in response to EDUs.
transactionID := gomatrixserverlib.TransactionID("")
// If we have nothing to do then wait either for incoming events, or // If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout. // until we hit an idle timeout.
if oq.pendingPDUs.Load() == 0 && len(oq.pendingEDUs) == 0 && len(oq.pendingInvites) == 0 { if oq.pendingPDUs.Load() == 0 && len(oq.pendingEDUs) == 0 && len(oq.pendingInvites) == 0 {
@ -228,17 +226,8 @@ func (oq *destinationQueue) backgroundSend() {
// If we have pending PDUs or EDUs then construct a transaction. // If we have pending PDUs or EDUs then construct a transaction.
if oq.pendingPDUs.Load() > 0 || len(oq.pendingEDUs) > 0 { if oq.pendingPDUs.Load() > 0 || len(oq.pendingEDUs) > 0 {
// If we haven't got a transaction ID then we should generate
// one. Ideally we'd know this already because something queued
// in the database would give us one, but if we're dealing with
// EDUs alone, we won't go via the database so we'll make one.
if transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
// Try sending the next transaction and see what happens. // Try sending the next transaction and see what happens.
transaction, terr := oq.nextTransaction(transactionID, oq.pendingEDUs) transaction, terr := oq.nextTransaction(oq.pendingEDUs)
if terr != nil { if terr != nil {
// We failed to send the transaction. // We failed to send the transaction.
if giveUp := oq.statistics.Failure(); giveUp { if giveUp := oq.statistics.Failure(); giveUp {
@ -253,7 +242,6 @@ func (oq *destinationQueue) backgroundSend() {
// If we successfully sent the transaction then clear out // If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID. // the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success() oq.statistics.Success()
oq.transactionID = ""
// Clean up the in-memory buffers. // Clean up the in-memory buffers.
oq.cleanPendingEDUs() oq.cleanPendingEDUs()
} }
@ -305,9 +293,19 @@ func (oq *destinationQueue) cleanPendingInvites() {
// queue and sends it. Returns true if a transaction was sent or // queue and sends it. Returns true if a transaction was sent or
// false otherwise. // false otherwise.
func (oq *destinationQueue) nextTransaction( func (oq *destinationQueue) nextTransaction(
transactionID gomatrixserverlib.TransactionID,
pendingEDUs []*gomatrixserverlib.EDU, pendingEDUs []*gomatrixserverlib.EDU,
) (bool, error) { ) (bool, error) {
// Before we do anything, we need to roll over the transaction
// ID that is being used to coalesce events into the next TX.
// Otherwise it's possible that we'll pick up an incomplete
// transaction and end up nuking the rest of the events at the
// cleanup stage.
oq.transactionMutex.Lock()
oq.transactionID = ""
oq.transactionCount = 0
oq.transactionMutex.Unlock()
// Create the transaction.
t := gomatrixserverlib.Transaction{ t := gomatrixserverlib.Transaction{
PDUs: []json.RawMessage{}, PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{}, EDUs: []gomatrixserverlib.EDU{},
@ -316,42 +314,50 @@ func (oq *destinationQueue) nextTransaction(
t.Destination = oq.destination t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now()) t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
// Ask the database for any pending PDUs from the next transaction.
// maxPDUsPerTransaction is an upper limit but we probably won't
// actually retrieve that many events.
txid, pdus, err := oq.db.GetNextTransactionPDUs( txid, pdus, err := oq.db.GetNextTransactionPDUs(
context.TODO(), // context context.TODO(), // context
oq.destination, // server name oq.destination, // server name
maxPDUsPerTransaction, // how many events to retrieve maxPDUsPerTransaction, // max events to retrieve
) )
if err != nil { if err != nil {
log.WithError(err).Errorf("failed to get next transaction PDUs for server %q", oq.destination) log.WithError(err).Errorf("failed to get next transaction PDUs for server %q", oq.destination)
return false, fmt.Errorf("oq.db.GetNextTransactionPDUs: %w", err) return false, fmt.Errorf("oq.db.GetNextTransactionPDUs: %w", err)
} }
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if len(pdus) == 0 && len(pendingEDUs) == 0 { if len(pdus) == 0 && len(pendingEDUs) == 0 {
return false, nil return false, nil
} }
if txid != "" { // Pick out the transaction ID from the database. If we didn't
// The database supplied us with a transaction ID to use // get a transaction ID (i.e. because there are no PDUs but only
// from a failed PDU so use that. // EDUs) then generate a transaction ID.
t.TransactionID = txid t.TransactionID = txid
} else { if t.TransactionID == "" {
// Otherwise, use the one that the function call gave us. now := gomatrixserverlib.AsTimestamp(time.Now())
// This would happen if it's EDUs only. t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
t.TransactionID = transactionID
} }
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus { for _, pdu := range pdus {
// Append the JSON of the event, since this is a json.RawMessage type in the // Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct // gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, (*pdu).JSON()) t.PDUs = append(t.PDUs, (*pdu).JSON())
} }
// Do the same for pending EDUs in the queue.
for _, edu := range pendingEDUs { for _, edu := range pendingEDUs {
t.EDUs = append(t.EDUs, *edu) t.EDUs = append(t.EDUs, *edu)
} }
logrus.WithField("server_name", oq.destination).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs)) logrus.WithField("server_name", oq.destination).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
// Try to send the transaction to the destination server.
// TODO: we should check for 500-ish fails vs 400-ish here, // TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response // since we shouldn't queue things indefinitely in response
// to a 400-ish error // to a 400-ish error
@ -367,7 +373,7 @@ func (oq *destinationQueue) nextTransaction(
t.Destination, t.Destination,
t.TransactionID, t.TransactionID,
); err != nil { ); err != nil {
log.WithError(err).Errorf("failed to clean transaction %q for server %q", transactionID, oq.destination) log.WithError(err).Errorf("failed to clean transaction %q for server %q", t.TransactionID, t.Destination)
} }
return true, nil return true, nil
case gomatrix.HTTPError: case gomatrix.HTTPError: