Merge branch 'master' into kegan/device-mgmt

This commit is contained in:
Kegsay 2020-07-09 17:55:56 +01:00 committed by GitHub
commit 433b585e38
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 69 additions and 66 deletions

View file

@ -32,6 +32,7 @@ import (
) )
const maxPDUsPerTransaction = 50 const maxPDUsPerTransaction = 50
const queueIdleTimeout = time.Second * 30
// destinationQueue is a queue of events for a single destination. // destinationQueue is a queue of events for a single destination.
// It is responsible for sending the events to the destination and // It is responsible for sending the events to the destination and
@ -52,7 +53,6 @@ type destinationQueue struct {
transactionIDMutex sync.Mutex // protects transactionID transactionIDMutex sync.Mutex // protects transactionID
transactionID gomatrixserverlib.TransactionID // last transaction ID transactionID gomatrixserverlib.TransactionID // last transaction ID
transactionCount atomic.Int32 // how many events in this transaction so far transactionCount atomic.Int32 // how many events in this transaction so far
pendingPDUs atomic.Int64 // how many PDUs are waiting to be sent
pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend
pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend
notifyPDUs chan bool // interrupts idle wait for PDUs notifyPDUs chan bool // interrupts idle wait for PDUs
@ -68,7 +68,6 @@ func (oq *destinationQueue) sendEvent(nid int64) {
log.Infof("%s is blacklisted; dropping event", oq.destination) log.Infof("%s is blacklisted; dropping event", oq.destination)
return return
} }
oq.wakeQueueIfNeeded()
// Create a transaction ID. We'll either do this if we don't have // Create a transaction ID. We'll either do this if we don't have
// one made up yet, or if we've exceeded the number of maximum // one made up yet, or if we've exceeded the number of maximum
// events allowed in a single tranaction. We'll reset the counter // events allowed in a single tranaction. We'll reset the counter
@ -95,11 +94,13 @@ func (oq *destinationQueue) sendEvent(nid int64) {
// We've successfully added a PDU to the transaction so increase // We've successfully added a PDU to the transaction so increase
// the counter. // the counter.
oq.transactionCount.Add(1) oq.transactionCount.Add(1)
// Signal that we've sent a new PDU. This will cause the queue to // Wake up the queue if it's asleep.
// wake up if it's asleep. The return to the Add function will only oq.wakeQueueIfNeeded()
// be 1 if the previous value was 0, e.g. nothing was waiting before. // If we're blocking on waiting PDUs then tell the queue that we
if oq.pendingPDUs.Add(1) == 1 { // have work to do.
oq.notifyPDUs <- true select {
case oq.notifyPDUs <- true:
default:
} }
} }
@ -138,26 +139,33 @@ func (oq *destinationQueue) wakeQueueIfNeeded() {
} }
// If we aren't running then wake up the queue. // If we aren't running then wake up the queue.
if !oq.running.Load() { if !oq.running.Load() {
// Look up how many events are pending in this queue. We need // Start the queue.
// to do this so that the queue thinks it has work to do.
count, err := oq.db.GetPendingPDUCount(
context.TODO(),
oq.destination,
)
if err == nil {
oq.pendingPDUs.Store(count)
log.Printf("Destination queue %q has %d pending PDUs", oq.destination, count)
} else {
log.WithError(err).Errorf("Can't get pending PDU count for %q destination queue", oq.destination)
}
if count > 0 {
oq.notifyPDUs <- true
}
// Then start the queue.
go oq.backgroundSend() go oq.backgroundSend()
} }
} }
// waitForPDUs returns a channel for pending PDUs, which will be
// used in backgroundSend select. It returns a closed channel if
// there is something pending right now, or an open channel if
// we're waiting for something.
func (oq *destinationQueue) waitForPDUs() chan bool {
pendingPDUs, err := oq.db.GetPendingPDUCount(context.TODO(), oq.destination)
if err != nil {
log.WithError(err).Errorf("Failed to get pending PDU count on queue %q", oq.destination)
}
// If there are PDUs pending right now then we'll return a closed
// channel. This will mean that the backgroundSend will not block.
if pendingPDUs > 0 {
ch := make(chan bool, 1)
close(ch)
return ch
}
// If there are no PDUs pending right now then instead we'll return
// the notify channel, so that backgroundSend can pick up normal
// notifications from sendEvent.
return oq.notifyPDUs
}
// backgroundSend is the worker goroutine for sending events. // backgroundSend is the worker goroutine for sending events.
// nolint:gocyclo // nolint:gocyclo
func (oq *destinationQueue) backgroundSend() { func (oq *destinationQueue) backgroundSend() {
@ -169,12 +177,15 @@ func (oq *destinationQueue) backgroundSend() {
defer oq.running.Store(false) defer oq.running.Store(false)
for { for {
pendingPDUs := false
// If we have nothing to do then wait either for incoming events, or // If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout. // until we hit an idle timeout.
select { select {
case <-oq.notifyPDUs: case <-oq.waitForPDUs():
// We were woken up because there are new PDUs waiting in the // We were woken up because there are new PDUs waiting in the
// database. // database.
pendingPDUs = true
case edu := <-oq.incomingEDUs: case edu := <-oq.incomingEDUs:
// EDUs are handled in-memory for now. We will try to keep // EDUs are handled in-memory for now. We will try to keep
// the ordering intact. // the ordering intact.
@ -204,10 +215,11 @@ func (oq *destinationQueue) backgroundSend() {
for len(oq.incomingInvites) > 0 { for len(oq.incomingInvites) > 0 {
oq.pendingInvites = append(oq.pendingInvites, <-oq.incomingInvites) oq.pendingInvites = append(oq.pendingInvites, <-oq.incomingInvites)
} }
case <-time.After(time.Second * 30): case <-time.After(queueIdleTimeout):
// The worker is idle so stop the goroutine. It'll get // The worker is idle so stop the goroutine. It'll get
// restarted automatically the next time we have an event to // restarted automatically the next time we have an event to
// send. // send.
log.Infof("Queue %q has been idle for %s, going to sleep", oq.destination, queueIdleTimeout)
return return
} }
@ -220,12 +232,13 @@ func (oq *destinationQueue) backgroundSend() {
select { select {
case <-time.After(duration): case <-time.After(duration):
case <-oq.interruptBackoff: case <-oq.interruptBackoff:
log.Infof("Interrupting backoff for %q", oq.destination)
} }
oq.backingOff.Store(false) oq.backingOff.Store(false)
} }
// If we have pending PDUs or EDUs then construct a transaction. // If we have pending PDUs or EDUs then construct a transaction.
if oq.pendingPDUs.Load() > 0 || len(oq.pendingEDUs) > 0 { if pendingPDUs || len(oq.pendingEDUs) > 0 {
// Try sending the next transaction and see what happens. // Try sending the next transaction and see what happens.
transaction, terr := oq.nextTransaction(oq.pendingEDUs) transaction, terr := oq.nextTransaction(oq.pendingEDUs)
if terr != nil { if terr != nil {
@ -236,6 +249,7 @@ func (oq *destinationQueue) backgroundSend() {
// buffers at this point. The PDU clean-up is already on a defer. // buffers at this point. The PDU clean-up is already on a defer.
oq.cleanPendingEDUs() oq.cleanPendingEDUs()
oq.cleanPendingInvites() oq.cleanPendingInvites()
log.Infof("Blacklisting %q due to errors", oq.destination)
return return
} else { } else {
// We haven't been told to give up terminally yet but we still have // We haven't been told to give up terminally yet but we still have
@ -262,6 +276,7 @@ func (oq *destinationQueue) backgroundSend() {
if giveUp := oq.statistics.Failure(); giveUp { if giveUp := oq.statistics.Failure(); giveUp {
// It's been suggested that we should give up because // It's been suggested that we should give up because
// the backoff has exceeded a maximum allowable value. // the backoff has exceeded a maximum allowable value.
log.Infof("Blacklisting %q due to errors", oq.destination)
return return
} }
} else if sent > 0 { } else if sent > 0 {
@ -273,17 +288,6 @@ func (oq *destinationQueue) backgroundSend() {
oq.cleanPendingInvites() oq.cleanPendingInvites()
} }
} }
// If something else has come along while we were busy sending
// the previous transaction then we don't want the next loop
// iteration to sleep. Send a message if someone else hasn't
// already sent a wake-up.
if oq.pendingPDUs.Load() > 0 {
select {
case oq.notifyPDUs <- true:
default:
}
}
} }
} }
@ -349,17 +353,6 @@ func (oq *destinationQueue) nextTransaction(
// If we didn't get anything from the database and there are no // If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here. // pending EDUs then there's nothing to do - stop here.
if len(pdus) == 0 && len(pendingEDUs) == 0 { if len(pdus) == 0 && len(pendingEDUs) == 0 {
log.Warnf("Expected PDUs/EDUs for destination %q but got none", oq.destination)
// This shouldn't really happen but since it has, let's check
// how many events are *really* in the database that are waiting.
if count, cerr := oq.db.GetPendingPDUCount(
context.TODO(),
oq.destination,
); cerr == nil {
oq.pendingPDUs.Store(count)
} else {
log.Warnf("Failed to retrieve pending PDU count for %q", oq.destination)
}
return false, nil return false, nil
} }
@ -396,9 +389,6 @@ func (oq *destinationQueue) nextTransaction(
_, err = oq.client.SendTransaction(ctx, t) _, err = oq.client.SendTransaction(ctx, t)
switch err.(type) { switch err.(type) {
case nil: case nil:
// No error was returned so the transaction looks to have
// been successfully sent.
oq.pendingPDUs.Sub(int64(len(t.PDUs)))
// Clean up the transaction in the database. // Clean up the transaction in the database.
if err = oq.db.CleanTransactionPDUs( if err = oq.db.CleanTransactionPDUs(
context.Background(), context.Background(),

View file

@ -35,7 +35,9 @@ type Database struct {
queuePDUsStatements queuePDUsStatements
queueJSONStatements queueJSONStatements
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
db *sql.DB db *sql.DB
queuePDUsWriter *sqlutil.TransactionWriter
queueJSONWriter *sqlutil.TransactionWriter
} }
// NewDatabase opens a new database // NewDatabase opens a new database
@ -74,6 +76,9 @@ func (d *Database) prepare() error {
return err return err
} }
d.queuePDUsWriter = sqlutil.NewTransactionWriter()
d.queueJSONWriter = sqlutil.NewTransactionWriter()
return d.PartitionOffsetStatements.Prepare(d.db, "federationsender") return d.PartitionOffsetStatements.Prepare(d.db, "federationsender")
} }
@ -145,12 +150,16 @@ func (d *Database) GetJoinedHosts(
// metadata entries. // metadata entries.
func (d *Database) StoreJSON( func (d *Database) StoreJSON(
ctx context.Context, js string, ctx context.Context, js string,
) (int64, error) { ) (nid int64, err error) {
nid, err := d.insertQueueJSON(ctx, nil, js) err = d.queueJSONWriter.Do(d.db, func(txn *sql.Tx) error {
if err != nil { n, e := d.insertQueueJSON(ctx, nil, js)
return 0, fmt.Errorf("d.insertQueueJSON: %w", err) if e != nil {
} return fmt.Errorf("d.insertQueueJSON: %w", e)
return nid, nil }
nid = n
return nil
})
return
} }
// AssociatePDUWithDestination creates an association that the // AssociatePDUWithDestination creates an association that the
@ -162,7 +171,7 @@ func (d *Database) AssociatePDUWithDestination(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nids []int64, nids []int64,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { return d.queuePDUsWriter.Do(d.db, func(txn *sql.Tx) error {
for _, nid := range nids { for _, nid := range nids {
if err := d.insertQueuePDU( if err := d.insertQueuePDU(
ctx, // context ctx, // context
@ -230,18 +239,18 @@ func (d *Database) CleanTransactionPDUs(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID, transactionID gomatrixserverlib.TransactionID,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { var err error
nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, 50) var nids []int64
var deleteNIDs []int64
if err = d.queuePDUsWriter.Do(d.db, func(txn *sql.Tx) error {
nids, err = d.selectQueuePDUs(ctx, txn, serverName, transactionID, 50)
if err != nil { if err != nil {
return fmt.Errorf("d.selectQueuePDUs: %w", err) return fmt.Errorf("d.selectQueuePDUs: %w", err)
} }
if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil { if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil {
return fmt.Errorf("d.deleteQueueTransaction: %w", err) return fmt.Errorf("d.deleteQueueTransaction: %w", err)
} }
var count int64 var count int64
var deleteNIDs []int64
for _, nid := range nids { for _, nid := range nids {
count, err = d.selectQueueReferenceJSONCount(ctx, txn, nid) count, err = d.selectQueueReferenceJSONCount(ctx, txn, nid)
if err != nil { if err != nil {
@ -251,15 +260,19 @@ func (d *Database) CleanTransactionPDUs(
deleteNIDs = append(deleteNIDs, nid) deleteNIDs = append(deleteNIDs, nid)
} }
} }
return nil
}); err != nil {
return err
}
err = d.queueJSONWriter.Do(d.db, func(txn *sql.Tx) error {
if len(deleteNIDs) > 0 { if len(deleteNIDs) > 0 {
if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil { if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil {
return fmt.Errorf("d.deleteQueueJSON: %w", err) return fmt.Errorf("d.deleteQueueJSON: %w", err)
} }
} }
return nil return nil
}) })
return err
} }
// GetPendingPDUCount returns the number of PDUs waiting to be // GetPendingPDUCount returns the number of PDUs waiting to be