Thread safety on transaction ID/count

This commit is contained in:
Neil Alexander 2020-07-01 11:19:48 +01:00
parent d6c94064af
commit ea1e8397fc

View file

@ -49,9 +49,9 @@ type destinationQueue struct {
statistics *types.ServerStatistics // statistics about this remote server statistics *types.ServerStatistics // statistics about this remote server
incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send incomingInvites chan *gomatrixserverlib.InviteV2Request // invites to send
incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send incomingEDUs chan *gomatrixserverlib.EDU // EDUs to send
transactionMutex sync.Mutex // protects transactionID and transactionCount transactionIDMutex sync.Mutex // protects transactionID
transactionID gomatrixserverlib.TransactionID // last transaction ID transactionID gomatrixserverlib.TransactionID // last transaction ID
transactionCount int // how many events in this transaction so far transactionCount atomic.Int32 // how many events in this transaction so far
pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent
pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend
pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend
@ -95,13 +95,13 @@ func (oq *destinationQueue) sendEvent(nid int64) {
// one made up yet, or if we've exceeded the number of maximum // one made up yet, or if we've exceeded the number of maximum
// events allowed in a single tranaction. We'll reset the counter // events allowed in a single tranaction. We'll reset the counter
// when we do. // when we do.
oq.transactionMutex.Lock() oq.transactionIDMutex.Lock()
if oq.transactionID == "" || oq.transactionCount >= maxPDUsPerTransaction { if oq.transactionID == "" || oq.transactionCount.Load() >= maxPDUsPerTransaction {
now := gomatrixserverlib.AsTimestamp(time.Now()) now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount())) oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
oq.transactionCount = 0 oq.transactionCount.Store(0)
} }
oq.transactionMutex.Unlock() oq.transactionIDMutex.Unlock()
// Create a database entry that associates the given PDU NID with // Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU // this destination queue. We'll then be able to retrieve the PDU
// later. // later.
@ -116,7 +116,7 @@ func (oq *destinationQueue) sendEvent(nid int64) {
} }
// We've successfully added a PDU to the transaction so increase // We've successfully added a PDU to the transaction so increase
// the counter. // the counter.
oq.transactionCount++ oq.transactionCount.Add(1)
// If the queue isn't running at this point then start it. // If the queue isn't running at this point then start it.
if !oq.running.Load() { if !oq.running.Load() {
go oq.backgroundSend() go oq.backgroundSend()
@ -300,10 +300,10 @@ func (oq *destinationQueue) nextTransaction(
// Otherwise it's possible that we'll pick up an incomplete // Otherwise it's possible that we'll pick up an incomplete
// transaction and end up nuking the rest of the events at the // transaction and end up nuking the rest of the events at the
// cleanup stage. // cleanup stage.
oq.transactionMutex.Lock() oq.transactionIDMutex.Lock()
oq.transactionID = "" oq.transactionID = ""
oq.transactionCount = 0 oq.transactionIDMutex.Unlock()
oq.transactionMutex.Unlock() oq.transactionCount.Store(0)
// Create the transaction. // Create the transaction.
t := gomatrixserverlib.Transaction{ t := gomatrixserverlib.Transaction{