diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index ce706768e..e2314ebbe 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -250,6 +250,12 @@ func (oq *destinationQueue) backgroundSend() { oq.cleanPendingEDUs() oq.cleanPendingInvites() return + } else { + // We haven't been told to give up terminally yet but we still have + // PDUs waiting to be sent. By sending a message into the wake chan, + // the next loop iteration will try processing these PDUs again, + // subject to the backoff. + oq.wakeServerCh <- true } } else if transaction { // If we successfully sent the transaction then clear out diff --git a/federationsender/queue/queue.go b/federationsender/queue/queue.go index 492d5f553..bc7ec0f93 100644 --- a/federationsender/queue/queue.go +++ b/federationsender/queue/queue.go @@ -51,7 +51,7 @@ func NewOutgoingQueues( statistics *types.Statistics, signing *SigningInfo, ) *OutgoingQueues { - return &OutgoingQueues{ + queues := &OutgoingQueues{ db: db, rsAPI: rsAPI, origin: origin, @@ -60,6 +60,15 @@ func NewOutgoingQueues( signing: signing, queues: map[gomatrixserverlib.ServerName]*destinationQueue{}, } + // Look up which servers we have pending items for and then rehydrate those queues. + if serverNames, err := db.GetPendingServerNames(context.Background()); err == nil { + for _, serverName := range serverNames { + queues.getQueue(serverName).wakeQueueIfNeeded() + } + } else { + log.WithError(err).Error("Failed to get server names for destination queue hydration") + } + return queues } // TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index 09d74ed7e..4bf36c247 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -31,4 +31,5 @@ type Database interface { GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) CleanTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID) error GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error) + GetPendingServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) } diff --git a/federationsender/storage/postgres/queue_pdus_table.go b/federationsender/storage/postgres/queue_pdus_table.go index bc22825d8..dab6003e9 100644 --- a/federationsender/storage/postgres/queue_pdus_table.go +++ b/federationsender/storage/postgres/queue_pdus_table.go @@ -63,6 +63,9 @@ const selectQueuePDUsCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE server_name = $1" +const selectQueueServerNamesSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" + type queuePDUsStatements struct { insertQueuePDUStmt *sql.Stmt deleteQueueTransactionPDUsStmt *sql.Stmt @@ -70,6 +73,7 @@ type queuePDUsStatements struct { selectQueuePDUsByTransactionStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt selectQueuePDUsCountStmt *sql.Stmt + selectQueueServerNamesStmt *sql.Stmt } func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { @@ -95,6 +99,9 @@ func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { return } + if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { + return + } return } @@ -190,3 +197,24 @@ func (s *queuePDUsStatements) selectQueuePDUs( return result, rows.Err() } + +func (s *queuePDUsStatements) selectQueueServerNames( + ctx context.Context, txn *sql.Tx, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName gomatrixserverlib.ServerName + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, serverName) + } + + return result, rows.Err() +} diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index be28c15dc..80686e090 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -264,3 +264,11 @@ func (d *Database) GetPendingPDUCount( ) (int64, error) { return d.selectQueuePDUCount(ctx, nil, serverName) } + +// GetPendingServerNames returns the server names that have PDUs +// waiting to be sent. +func (d *Database) GetPendingServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + return d.selectQueueServerNames(ctx, nil) +} diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go index 955ff507d..33eef91ed 100644 --- a/federationsender/storage/sqlite3/queue_pdus_table.go +++ b/federationsender/storage/sqlite3/queue_pdus_table.go @@ -64,6 +64,9 @@ const selectQueuePDUsCountSQL = "" + "SELECT COUNT(*) FROM federationsender_queue_pdus" + " WHERE server_name = $1" +const selectQueueServerNamesSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_queue_pdus" + type queuePDUsStatements struct { insertQueuePDUStmt *sql.Stmt deleteQueueTransactionPDUsStmt *sql.Stmt @@ -71,6 +74,7 @@ type queuePDUsStatements struct { selectQueuePDUsByTransactionStmt *sql.Stmt selectQueueReferenceJSONCountStmt *sql.Stmt selectQueuePDUsCountStmt *sql.Stmt + selectQueueServerNamesStmt *sql.Stmt } func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { @@ -96,6 +100,9 @@ func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil { return } + if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil { + return + } return } @@ -188,3 +195,24 @@ func (s *queuePDUsStatements) selectQueuePDUs( return result, rows.Err() } + +func (s *queuePDUsStatements) selectQueueServerNames( + ctx context.Context, txn *sql.Tx, +) ([]gomatrixserverlib.ServerName, error) { + stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt) + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed") + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName gomatrixserverlib.ServerName + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, serverName) + } + + return result, rows.Err() +} diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index 30ac81bfd..7ba51fb52 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -270,3 +270,11 @@ func (d *Database) GetPendingPDUCount( ) (int64, error) { return d.selectQueuePDUCount(ctx, nil, serverName) } + +// GetPendingServerNames returns the server names that have PDUs +// waiting to be sent. +func (d *Database) GetPendingServerNames( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + return d.selectQueueServerNames(ctx, nil) +}