mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Fix race in incomplete transactions
This commit is contained in:
parent
cfbdff3c32
commit
d6c94064af
|
|
@ -18,6 +18,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationsender/storage"
|
||||
|
|
@ -48,6 +49,7 @@ type destinationQueue struct {
|
|||
statistics *types.ServerStatistics // statistics about this remote server
|
||||
incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send
|
||||
incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send
|
||||
transactionMutex sync.Mutex // protects transactionID and transactionCount
|
||||
transactionID gomatrixserverlib.TransactionID // last transaction ID
|
||||
transactionCount int // how many events in this transaction so far
|
||||
pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent
|
||||
|
|
@ -93,11 +95,13 @@ func (oq *destinationQueue) sendEvent(nid int64) {
|
|||
// one made up yet, or if we've exceeded the number of maximum
|
||||
// events allowed in a single transaction. We'll reset the counter
|
||||
// when we do.
|
||||
oq.transactionMutex.Lock()
|
||||
if oq.transactionID == "" || oq.transactionCount >= maxPDUsPerTransaction {
|
||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
||||
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
|
||||
oq.transactionCount = 0
|
||||
}
|
||||
oq.transactionMutex.Unlock()
|
||||
// Create a database entry that associates the given PDU NID with
|
||||
// this destination queue. We'll then be able to retrieve the PDU
|
||||
// later.
|
||||
|
|
@ -164,12 +168,6 @@ func (oq *destinationQueue) backgroundSend() {
|
|||
defer oq.running.Store(false)
|
||||
|
||||
for {
|
||||
// For now we don't know the next transaction ID. Set it to an
|
||||
// empty one. The next step will populate it if we have pending
|
||||
// PDUs in the database. Otherwise we'll generate one later on,
|
||||
// e.g. in response to EDUs.
|
||||
transactionID := gomatrixserverlib.TransactionID("")
|
||||
|
||||
// If we have nothing to do then wait either for incoming events, or
|
||||
// until we hit an idle timeout.
|
||||
if oq.pendingPDUs.Load() == 0 && len(oq.pendingEDUs) == 0 && len(oq.pendingInvites) == 0 {
|
||||
|
|
@ -228,17 +226,8 @@ func (oq *destinationQueue) backgroundSend() {
|
|||
|
||||
// If we have pending PDUs or EDUs then construct a transaction.
|
||||
if oq.pendingPDUs.Load() > 0 || len(oq.pendingEDUs) > 0 {
|
||||
// If we haven't got a transaction ID then we should generate
|
||||
// one. Ideally we'd know this already because something queued
|
||||
// in the database would give us one, but if we're dealing with
|
||||
// EDUs alone, we won't go via the database so we'll make one.
|
||||
if transactionID == "" {
|
||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
||||
transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
|
||||
}
|
||||
|
||||
// Try sending the next transaction and see what happens.
|
||||
transaction, terr := oq.nextTransaction(transactionID, oq.pendingEDUs)
|
||||
transaction, terr := oq.nextTransaction(oq.pendingEDUs)
|
||||
if terr != nil {
|
||||
// We failed to send the transaction.
|
||||
if giveUp := oq.statistics.Failure(); giveUp {
|
||||
|
|
@ -253,7 +242,6 @@ func (oq *destinationQueue) backgroundSend() {
|
|||
// If we successfully sent the transaction then clear out
|
||||
// the pending events and EDUs, and wipe our transaction ID.
|
||||
oq.statistics.Success()
|
||||
oq.transactionID = ""
|
||||
// Clean up the in-memory buffers.
|
||||
oq.cleanPendingEDUs()
|
||||
}
|
||||
|
|
@ -305,9 +293,19 @@ func (oq *destinationQueue) cleanPendingInvites() {
|
|||
// queue and sends it. Returns true if a transaction was sent or
|
||||
// false otherwise.
|
||||
func (oq *destinationQueue) nextTransaction(
|
||||
transactionID gomatrixserverlib.TransactionID,
|
||||
pendingEDUs []*gomatrixserverlib.EDU,
|
||||
) (bool, error) {
|
||||
// Before we do anything, we need to roll over the transaction
|
||||
// ID that is being used to coalesce events into the next TX.
|
||||
// Otherwise it's possible that we'll pick up an incomplete
|
||||
// transaction and end up nuking the rest of the events at the
|
||||
// cleanup stage.
|
||||
oq.transactionMutex.Lock()
|
||||
oq.transactionID = ""
|
||||
oq.transactionCount = 0
|
||||
oq.transactionMutex.Unlock()
|
||||
|
||||
// Create the transaction.
|
||||
t := gomatrixserverlib.Transaction{
|
||||
PDUs: []json.RawMessage{},
|
||||
EDUs: []gomatrixserverlib.EDU{},
|
||||
|
|
@ -316,42 +314,50 @@ func (oq *destinationQueue) nextTransaction(
|
|||
t.Destination = oq.destination
|
||||
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
|
||||
// Ask the database for any pending PDUs from the next transaction.
|
||||
// maxPDUsPerTransaction is an upper limit but we probably won't
|
||||
// actually retrieve that many events.
|
||||
txid, pdus, err := oq.db.GetNextTransactionPDUs(
|
||||
context.TODO(), // context
|
||||
oq.destination, // server name
|
||||
maxPDUsPerTransaction, // how many events to retrieve
|
||||
maxPDUsPerTransaction, // max events to retrieve
|
||||
)
|
||||
if err != nil {
|
||||
log.WithError(err).Errorf("failed to get next transaction PDUs for server %q", oq.destination)
|
||||
return false, fmt.Errorf("oq.db.GetNextTransactionPDUs: %w", err)
|
||||
}
|
||||
|
||||
// If we didn't get anything from the database and there are no
|
||||
// pending EDUs then there's nothing to do - stop here.
|
||||
if len(pdus) == 0 && len(pendingEDUs) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if txid != "" {
|
||||
// The database supplied us with a transaction ID to use
|
||||
// from a failed PDU so use that.
|
||||
t.TransactionID = txid
|
||||
} else {
|
||||
// Otherwise, use the one that the function call gave us.
|
||||
// This would happen if it's EDUs only.
|
||||
t.TransactionID = transactionID
|
||||
// Pick out the transaction ID from the database. If we didn't
|
||||
// get a transaction ID (i.e. because there are no PDUs but only
|
||||
// EDUs) then generate a transaction ID.
|
||||
t.TransactionID = txid
|
||||
if t.TransactionID == "" {
|
||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
||||
t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
|
||||
}
|
||||
|
||||
// Go through PDUs that we retrieved from the database, if any,
|
||||
// and add them into the transaction.
|
||||
for _, pdu := range pdus {
|
||||
// Append the JSON of the event, since this is a json.RawMessage type in the
|
||||
// gomatrixserverlib.Transaction struct
|
||||
t.PDUs = append(t.PDUs, (*pdu).JSON())
|
||||
}
|
||||
|
||||
// Do the same for pending EDUs in the queue.
|
||||
for _, edu := range pendingEDUs {
|
||||
t.EDUs = append(t.EDUs, *edu)
|
||||
}
|
||||
|
||||
logrus.WithField("server_name", oq.destination).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
|
||||
|
||||
// Try to send the transaction to the destination server.
|
||||
// TODO: we should check for 500-ish fails vs 400-ish here,
|
||||
// since we shouldn't queue things indefinitely in response
|
||||
// to a 400-ish error
|
||||
|
|
@ -367,7 +373,7 @@ func (oq *destinationQueue) nextTransaction(
|
|||
t.Destination,
|
||||
t.TransactionID,
|
||||
); err != nil {
|
||||
log.WithError(err).Errorf("failed to clean transaction %q for server %q", transactionID, oq.destination)
|
||||
log.WithError(err).Errorf("failed to clean transaction %q for server %q", t.TransactionID, t.Destination)
|
||||
}
|
||||
return true, nil
|
||||
case gomatrix.HTTPError:
|
||||
|
|
|
|||
Loading…
Reference in a new issue