mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Fix race in incomplete transactions
This commit is contained in:
parent
cfbdff3c32
commit
d6c94064af
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationsender/storage"
|
"github.com/matrix-org/dendrite/federationsender/storage"
|
||||||
|
|
@ -48,6 +49,7 @@ type destinationQueue struct {
|
||||||
statistics *types.ServerStatistics // statistics about this remote server
|
statistics *types.ServerStatistics // statistics about this remote server
|
||||||
incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send
|
incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send
|
||||||
incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send
|
incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send
|
||||||
|
transactionMutex sync.Mutex // protects transactionID and transactionCount
|
||||||
transactionID gomatrixserverlib.TransactionID // last transaction ID
|
transactionID gomatrixserverlib.TransactionID // last transaction ID
|
||||||
transactionCount int // how many events in this transaction so far
|
transactionCount int // how many events in this transaction so far
|
||||||
pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent
|
pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent
|
||||||
|
|
@ -93,11 +95,13 @@ func (oq *destinationQueue) sendEvent(nid int64) {
|
||||||
// one made up yet, or if we've exceeded the number of maximum
|
// one made up yet, or if we've exceeded the number of maximum
|
||||||
// events allowed in a single tranaction. We'll reset the counter
|
// events allowed in a single tranaction. We'll reset the counter
|
||||||
// when we do.
|
// when we do.
|
||||||
|
oq.transactionMutex.Lock()
|
||||||
if oq.transactionID == "" || oq.transactionCount >= maxPDUsPerTransaction {
|
if oq.transactionID == "" || oq.transactionCount >= maxPDUsPerTransaction {
|
||||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
now := gomatrixserverlib.AsTimestamp(time.Now())
|
||||||
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
|
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
|
||||||
oq.transactionCount = 0
|
oq.transactionCount = 0
|
||||||
}
|
}
|
||||||
|
oq.transactionMutex.Unlock()
|
||||||
// Create a database entry that associates the given PDU NID with
|
// Create a database entry that associates the given PDU NID with
|
||||||
// this destination queue. We'll then be able to retrieve the PDU
|
// this destination queue. We'll then be able to retrieve the PDU
|
||||||
// later.
|
// later.
|
||||||
|
|
@ -164,12 +168,6 @@ func (oq *destinationQueue) backgroundSend() {
|
||||||
defer oq.running.Store(false)
|
defer oq.running.Store(false)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
// For now we don't know the next transaction ID. Set it to an
|
|
||||||
// empty one. The next step will populate it if we have pending
|
|
||||||
// PDUs in the database. Otherwise we'll generate one later on,
|
|
||||||
// e.g. in response to EDUs.
|
|
||||||
transactionID := gomatrixserverlib.TransactionID("")
|
|
||||||
|
|
||||||
// If we have nothing to do then wait either for incoming events, or
|
// If we have nothing to do then wait either for incoming events, or
|
||||||
// until we hit an idle timeout.
|
// until we hit an idle timeout.
|
||||||
if oq.pendingPDUs.Load() == 0 && len(oq.pendingEDUs) == 0 && len(oq.pendingInvites) == 0 {
|
if oq.pendingPDUs.Load() == 0 && len(oq.pendingEDUs) == 0 && len(oq.pendingInvites) == 0 {
|
||||||
|
|
@ -228,17 +226,8 @@ func (oq *destinationQueue) backgroundSend() {
|
||||||
|
|
||||||
// If we have pending PDUs or EDUs then construct a transaction.
|
// If we have pending PDUs or EDUs then construct a transaction.
|
||||||
if oq.pendingPDUs.Load() > 0 || len(oq.pendingEDUs) > 0 {
|
if oq.pendingPDUs.Load() > 0 || len(oq.pendingEDUs) > 0 {
|
||||||
// If we haven't got a transaction ID then we should generate
|
|
||||||
// one. Ideally we'd know this already because something queued
|
|
||||||
// in the database would give us one, but if we're dealing with
|
|
||||||
// EDUs alone, we won't go via the database so we'll make one.
|
|
||||||
if transactionID == "" {
|
|
||||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
|
||||||
transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try sending the next transaction and see what happens.
|
// Try sending the next transaction and see what happens.
|
||||||
transaction, terr := oq.nextTransaction(transactionID, oq.pendingEDUs)
|
transaction, terr := oq.nextTransaction(oq.pendingEDUs)
|
||||||
if terr != nil {
|
if terr != nil {
|
||||||
// We failed to send the transaction.
|
// We failed to send the transaction.
|
||||||
if giveUp := oq.statistics.Failure(); giveUp {
|
if giveUp := oq.statistics.Failure(); giveUp {
|
||||||
|
|
@ -253,7 +242,6 @@ func (oq *destinationQueue) backgroundSend() {
|
||||||
// If we successfully sent the transaction then clear out
|
// If we successfully sent the transaction then clear out
|
||||||
// the pending events and EDUs, and wipe our transaction ID.
|
// the pending events and EDUs, and wipe our transaction ID.
|
||||||
oq.statistics.Success()
|
oq.statistics.Success()
|
||||||
oq.transactionID = ""
|
|
||||||
// Clean up the in-memory buffers.
|
// Clean up the in-memory buffers.
|
||||||
oq.cleanPendingEDUs()
|
oq.cleanPendingEDUs()
|
||||||
}
|
}
|
||||||
|
|
@ -305,9 +293,19 @@ func (oq *destinationQueue) cleanPendingInvites() {
|
||||||
// queue and sends it. Returns true if a transaction was sent or
|
// queue and sends it. Returns true if a transaction was sent or
|
||||||
// false otherwise.
|
// false otherwise.
|
||||||
func (oq *destinationQueue) nextTransaction(
|
func (oq *destinationQueue) nextTransaction(
|
||||||
transactionID gomatrixserverlib.TransactionID,
|
|
||||||
pendingEDUs []*gomatrixserverlib.EDU,
|
pendingEDUs []*gomatrixserverlib.EDU,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
|
// Before we do anything, we need to roll over the transaction
|
||||||
|
// ID that is being used to coalesce events into the next TX.
|
||||||
|
// Otherwise it's possible that we'll pick up an incomplete
|
||||||
|
// transaction and end up nuking the rest of the events at the
|
||||||
|
// cleanup stage.
|
||||||
|
oq.transactionMutex.Lock()
|
||||||
|
oq.transactionID = ""
|
||||||
|
oq.transactionCount = 0
|
||||||
|
oq.transactionMutex.Unlock()
|
||||||
|
|
||||||
|
// Create the transaction.
|
||||||
t := gomatrixserverlib.Transaction{
|
t := gomatrixserverlib.Transaction{
|
||||||
PDUs: []json.RawMessage{},
|
PDUs: []json.RawMessage{},
|
||||||
EDUs: []gomatrixserverlib.EDU{},
|
EDUs: []gomatrixserverlib.EDU{},
|
||||||
|
|
@ -316,42 +314,50 @@ func (oq *destinationQueue) nextTransaction(
|
||||||
t.Destination = oq.destination
|
t.Destination = oq.destination
|
||||||
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
|
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
|
||||||
|
|
||||||
|
// Ask the database for any pending PDUs from the next transaction.
|
||||||
|
// maxPDUsPerTransaction is an upper limit but we probably won't
|
||||||
|
// actually retrieve that many events.
|
||||||
txid, pdus, err := oq.db.GetNextTransactionPDUs(
|
txid, pdus, err := oq.db.GetNextTransactionPDUs(
|
||||||
context.TODO(), // context
|
context.TODO(), // context
|
||||||
oq.destination, // server name
|
oq.destination, // server name
|
||||||
maxPDUsPerTransaction, // how many events to retrieve
|
maxPDUsPerTransaction, // max events to retrieve
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Errorf("failed to get next transaction PDUs for server %q", oq.destination)
|
log.WithError(err).Errorf("failed to get next transaction PDUs for server %q", oq.destination)
|
||||||
return false, fmt.Errorf("oq.db.GetNextTransactionPDUs: %w", err)
|
return false, fmt.Errorf("oq.db.GetNextTransactionPDUs: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we didn't get anything from the database and there are no
|
||||||
|
// pending EDUs then there's nothing to do - stop here.
|
||||||
if len(pdus) == 0 && len(pendingEDUs) == 0 {
|
if len(pdus) == 0 && len(pendingEDUs) == 0 {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if txid != "" {
|
// Pick out the transaction ID from the database. If we didn't
|
||||||
// The database supplied us with a transaction ID to use
|
// get a transaction ID (i.e. because there are no PDUs but only
|
||||||
// from a failed PDU so use that.
|
// EDUs) then generate a transaction ID.
|
||||||
t.TransactionID = txid
|
t.TransactionID = txid
|
||||||
} else {
|
if t.TransactionID == "" {
|
||||||
// Otherwise, use the one that the function call gave us.
|
now := gomatrixserverlib.AsTimestamp(time.Now())
|
||||||
// This would happen if it's EDUs only.
|
t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
|
||||||
t.TransactionID = transactionID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Go through PDUs that we retrieved from the database, if any,
|
||||||
|
// and add them into the transaction.
|
||||||
for _, pdu := range pdus {
|
for _, pdu := range pdus {
|
||||||
// Append the JSON of the event, since this is a json.RawMessage type in the
|
// Append the JSON of the event, since this is a json.RawMessage type in the
|
||||||
// gomatrixserverlib.Transaction struct
|
// gomatrixserverlib.Transaction struct
|
||||||
t.PDUs = append(t.PDUs, (*pdu).JSON())
|
t.PDUs = append(t.PDUs, (*pdu).JSON())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do the same for pending EDUS in the queue.
|
||||||
for _, edu := range pendingEDUs {
|
for _, edu := range pendingEDUs {
|
||||||
t.EDUs = append(t.EDUs, *edu)
|
t.EDUs = append(t.EDUs, *edu)
|
||||||
}
|
}
|
||||||
|
|
||||||
logrus.WithField("server_name", oq.destination).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
|
logrus.WithField("server_name", oq.destination).Infof("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
|
||||||
|
|
||||||
|
// Try to send the transaction to the destination server.
|
||||||
// TODO: we should check for 500-ish fails vs 400-ish here,
|
// TODO: we should check for 500-ish fails vs 400-ish here,
|
||||||
// since we shouldn't queue things indefinitely in response
|
// since we shouldn't queue things indefinitely in response
|
||||||
// to a 400-ish error
|
// to a 400-ish error
|
||||||
|
|
@ -367,7 +373,7 @@ func (oq *destinationQueue) nextTransaction(
|
||||||
t.Destination,
|
t.Destination,
|
||||||
t.TransactionID,
|
t.TransactionID,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
log.WithError(err).Errorf("failed to clean transaction %q for server %q", transactionID, oq.destination)
|
log.WithError(err).Errorf("failed to clean transaction %q for server %q", t.TransactionID, t.Destination)
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
case gomatrix.HTTPError:
|
case gomatrix.HTTPError:
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue