From 9ea9c20307caff2c0a6c308277c0c0656c62dfa5 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 2 Jul 2020 15:40:17 +0100 Subject: [PATCH] Make sure that the federation sender knows how many pending events are in the database when the worker starts --- federationsender/queue/destinationqueue.go | 32 +++++++++++++------ federationsender/storage/interface.go | 1 + .../storage/postgres/queue_pdus_table.go | 23 +++++++++++++ federationsender/storage/postgres/storage.go | 9 ++++++ .../storage/sqlite3/queue_pdus_table.go | 23 +++++++++++++ federationsender/storage/sqlite3/storage.go | 9 ++++++ 6 files changed, 87 insertions(+), 10 deletions(-) diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index a736b3852..a6e41b8d4 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -52,7 +52,7 @@ type destinationQueue struct { transactionIDMutex sync.Mutex // protects transactionID transactionID gomatrixserverlib.TransactionID // last transaction ID transactionCount atomic.Int32 // how many events in this transaction so far - pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent + pendingPDUs atomic.Int64 // how many PDUs are waiting to be sent pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend wakeServerCh chan bool // interrupts idle wait @@ -91,6 +91,7 @@ func (oq *destinationQueue) sendEvent(nid int64) { // If the destination is blacklisted then drop the event. return } + oq.wakeQueueIfNeeded() // Create a transaction ID. We'll either do this if we don't have // one made up yet, or if we've exceeded the number of maximum // events allowed in a single tranaction. We'll reset the counter @@ -117,10 +118,6 @@ func (oq *destinationQueue) sendEvent(nid int64) { // We've successfully added a PDU to the transaction so increase // the counter. 
oq.transactionCount.Add(1) - // If the queue isn't running at this point then start it. - if !oq.running.Load() { - go oq.backgroundSend() - } // Signal that we've sent a new PDU. This will cause the queue to // wake up if it's asleep. The return to the Add function will only // be 1 if the previous value was 0, e.g. nothing was waiting before. @@ -137,9 +134,7 @@ func (oq *destinationQueue) sendEDU(ev *gomatrixserverlib.EDU) { // If the destination is blacklisted then drop the event. return } - if !oq.running.Load() { - go oq.backgroundSend() - } + oq.wakeQueueIfNeeded() oq.incomingEDUs <- ev } @@ -151,10 +146,27 @@ func (oq *destinationQueue) sendInvite(ev *gomatrixserverlib.InviteV2Request) { // If the destination is blacklisted then drop the event. return } + oq.wakeQueueIfNeeded() + oq.incomingInvites <- ev +} + +func (oq *destinationQueue) wakeQueueIfNeeded() { if !oq.running.Load() { + // Look up how many events are pending in this queue. We need + // to do this so that the queue thinks it has work to do. + count, err := oq.db.GetPendingPDUCount( + context.TODO(), + oq.destination, + ) + if err == nil { + oq.pendingPDUs.Store(count) + log.Printf("Destination queue %q has %d pending PDUs", oq.destination, count) + } else { + log.WithError(err).Errorf("Can't get pending PDU count for %q destination queue", oq.destination) + } + // Then start the queue. go oq.backgroundSend() } - oq.incomingInvites <- ev } // backgroundSend is the worker goroutine for sending events. @@ -366,7 +378,7 @@ func (oq *destinationQueue) nextTransaction( case nil: // No error was returned so the transaction looks to have // been successfully sent. - oq.pendingPDUs.Sub(int32(len(t.PDUs))) + oq.pendingPDUs.Sub(int64(len(t.PDUs))) // Clean up the transaction in the database. 
if err = oq.db.CleanTransactionPDUs( context.TODO(), diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index f4df93fa4..09d74ed7e 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -30,4 +30,5 @@ type Database interface { AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nids []int64) error GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) CleanTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID) error + GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) } diff --git a/federationsender/storage/postgres/queue_pdus_table.go b/federationsender/storage/postgres/queue_pdus_table.go index ef7a9f41e..bc22825d8 100644 --- a/federationsender/storage/postgres/queue_pdus_table.go +++ b/federationsender/storage/postgres/queue_pdus_table.go @@ -59,12 +59,17 @@ const selectQueueReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" +const selectQueuePDUsCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_pdus" + + " WHERE server_name = $1" + type queuePDUsStatements struct { insertQueuePDUStmt *sql.Stmt deleteQueueTransactionPDUsStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt selectQueuePDUsByTransactionStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt + selectQueuePDUsCountStmt *sql.Stmt } func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { @@ -87,6 +92,9 @@ func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil { return } + if s.selectQueuePDUsCountStmt, err = 
db.Prepare(selectQueuePDUsCountSQL); err != nil { + return + } return } @@ -144,6 +152,21 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount( return count, err } +func (s *queuePDUsStatements) selectQueuePDUCount( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no PDUs queued for a given + // server name; that is not an error condition. Just return as if + // there's a zero count. + return 0, nil + } + return count, err +} + func (s *queuePDUsStatements) selectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index 18d1532a4..be28c15dc 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -255,3 +255,12 @@ func (d *Database) CleanTransactionPDUs( return nil }) } + +// GetPendingPDUCount returns the number of PDUs waiting to be +// sent for a given servername. 
+func (d *Database) GetPendingPDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + return d.selectQueuePDUCount(ctx, nil, serverName) +} diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go index dc08fd707..955ff507d 100644 --- a/federationsender/storage/sqlite3/queue_pdus_table.go +++ b/federationsender/storage/sqlite3/queue_pdus_table.go @@ -60,12 +60,17 @@ const selectQueueReferenceJSONCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE json_nid = $1" +const selectQueuePDUsCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_pdus" + + " WHERE server_name = $1" + type queuePDUsStatements struct { insertQueuePDUStmt *sql.Stmt deleteQueueTransactionPDUsStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt selectQueuePDUsByTransactionStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt + selectQueuePDUsCountStmt *sql.Stmt } func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { @@ -88,6 +93,9 @@ func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil { return } + if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { + return + } return } @@ -142,6 +150,21 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount( return count, err } +func (s *queuePDUsStatements) selectQueuePDUCount( + ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt) + err := stmt.QueryRowContext(ctx, serverName).Scan(&count) + if err == sql.ErrNoRows { + // It's acceptable for there to be no PDUs queued for a given + // server name; that is not an error condition. Just return as if + // there's a zero count. 
+ return 0, nil + } + return count, err +} + func (s *queuePDUsStatements) selectQueuePDUs( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index 7629ecd21..30ac81bfd 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -261,3 +261,12 @@ func (d *Database) CleanTransactionPDUs( return nil }) } + +// GetPendingPDUCount returns the number of PDUs waiting to be +// sent for a given servername. +func (d *Database) GetPendingPDUCount( + ctx context.Context, + serverName gomatrixserverlib.ServerName, +) (int64, error) { + return d.selectQueuePDUCount(ctx, nil, serverName) +}