mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Make sure that the federation sender knows how many pending events are in the database when the worker starts
This commit is contained in:
parent
7089050733
commit
9ea9c20307
|
|
@ -52,7 +52,7 @@ type destinationQueue struct {
|
||||||
transactionIDMutex sync.Mutex // protects transactionID
|
transactionIDMutex sync.Mutex // protects transactionID
|
||||||
transactionID gomatrixserverlib.TransactionID // last transaction ID
|
transactionID gomatrixserverlib.TransactionID // last transaction ID
|
||||||
transactionCount atomic.Int32 // how many events in this transaction so far
|
transactionCount atomic.Int32 // how many events in this transaction so far
|
||||||
pendingPDUs atomic.Int32 // how many PDUs are waiting to be sent
|
pendingPDUs atomic.Int64 // how many PDUs are waiting to be sent
|
||||||
pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend
|
pendingEDUs []*gomatrixserverlib.EDU // owned by backgroundSend
|
||||||
pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend
|
pendingInvites []*gomatrixserverlib.InviteV2Request // owned by backgroundSend
|
||||||
wakeServerCh chan bool // interrupts idle wait
|
wakeServerCh chan bool // interrupts idle wait
|
||||||
|
|
@ -91,6 +91,7 @@ func (oq *destinationQueue) sendEvent(nid int64) {
|
||||||
// If the destination is blacklisted then drop the event.
|
// If the destination is blacklisted then drop the event.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
oq.wakeQueueIfNeeded()
|
||||||
// Create a transaction ID. We'll either do this if we don't have
|
// Create a transaction ID. We'll either do this if we don't have
|
||||||
// one made up yet, or if we've exceeded the number of maximum
|
// one made up yet, or if we've exceeded the number of maximum
|
||||||
// events allowed in a single tranaction. We'll reset the counter
|
// events allowed in a single tranaction. We'll reset the counter
|
||||||
|
|
@ -117,10 +118,6 @@ func (oq *destinationQueue) sendEvent(nid int64) {
|
||||||
// We've successfully added a PDU to the transaction so increase
|
// We've successfully added a PDU to the transaction so increase
|
||||||
// the counter.
|
// the counter.
|
||||||
oq.transactionCount.Add(1)
|
oq.transactionCount.Add(1)
|
||||||
// If the queue isn't running at this point then start it.
|
|
||||||
if !oq.running.Load() {
|
|
||||||
go oq.backgroundSend()
|
|
||||||
}
|
|
||||||
// Signal that we've sent a new PDU. This will cause the queue to
|
// Signal that we've sent a new PDU. This will cause the queue to
|
||||||
// wake up if it's asleep. The return to the Add function will only
|
// wake up if it's asleep. The return to the Add function will only
|
||||||
// be 1 if the previous value was 0, e.g. nothing was waiting before.
|
// be 1 if the previous value was 0, e.g. nothing was waiting before.
|
||||||
|
|
@ -137,9 +134,7 @@ func (oq *destinationQueue) sendEDU(ev *gomatrixserverlib.EDU) {
|
||||||
// If the destination is blacklisted then drop the event.
|
// If the destination is blacklisted then drop the event.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if !oq.running.Load() {
|
oq.wakeQueueIfNeeded()
|
||||||
go oq.backgroundSend()
|
|
||||||
}
|
|
||||||
oq.incomingEDUs <- ev
|
oq.incomingEDUs <- ev
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -151,10 +146,27 @@ func (oq *destinationQueue) sendInvite(ev *gomatrixserverlib.InviteV2Request) {
|
||||||
// If the destination is blacklisted then drop the event.
|
// If the destination is blacklisted then drop the event.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
oq.wakeQueueIfNeeded()
|
||||||
|
oq.incomingInvites <- ev
|
||||||
|
}
|
||||||
|
|
||||||
|
func (oq *destinationQueue) wakeQueueIfNeeded() {
|
||||||
if !oq.running.Load() {
|
if !oq.running.Load() {
|
||||||
|
// Look up how many events are pending in this queue. We need
|
||||||
|
// to do this so that the queue thinks it has work to do.
|
||||||
|
count, err := oq.db.GetPendingPDUCount(
|
||||||
|
context.TODO(),
|
||||||
|
oq.destination,
|
||||||
|
)
|
||||||
|
if err == nil {
|
||||||
|
oq.pendingPDUs.Store(count)
|
||||||
|
log.Printf("Destination queue %q has %d pending PDUs", oq.destination, count)
|
||||||
|
} else {
|
||||||
|
log.WithError(err).Errorf("Can't get pending PDU count for %q destination queue", oq.destination)
|
||||||
|
}
|
||||||
|
// Then start the queue.
|
||||||
go oq.backgroundSend()
|
go oq.backgroundSend()
|
||||||
}
|
}
|
||||||
oq.incomingInvites <- ev
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// backgroundSend is the worker goroutine for sending events.
|
// backgroundSend is the worker goroutine for sending events.
|
||||||
|
|
@ -366,7 +378,7 @@ func (oq *destinationQueue) nextTransaction(
|
||||||
case nil:
|
case nil:
|
||||||
// No error was returned so the transaction looks to have
|
// No error was returned so the transaction looks to have
|
||||||
// been successfully sent.
|
// been successfully sent.
|
||||||
oq.pendingPDUs.Sub(int32(len(t.PDUs)))
|
oq.pendingPDUs.Sub(int64(len(t.PDUs)))
|
||||||
// Clean up the transaction in the database.
|
// Clean up the transaction in the database.
|
||||||
if err = oq.db.CleanTransactionPDUs(
|
if err = oq.db.CleanTransactionPDUs(
|
||||||
context.TODO(),
|
context.TODO(),
|
||||||
|
|
|
||||||
|
|
@ -30,4 +30,5 @@ type Database interface {
|
||||||
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nids []int64) error
|
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nids []int64) error
|
||||||
GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error)
|
GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error)
|
||||||
CleanTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID) error
|
CleanTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID) error
|
||||||
|
GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -59,12 +59,17 @@ const selectQueueReferenceJSONCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||||
" WHERE json_nid = $1"
|
" WHERE json_nid = $1"
|
||||||
|
|
||||||
|
const selectQueuePDUsCountSQL = "" +
|
||||||
|
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||||
|
" WHERE server_name = $1"
|
||||||
|
|
||||||
type queuePDUsStatements struct {
|
type queuePDUsStatements struct {
|
||||||
insertQueuePDUStmt *sql.Stmt
|
insertQueuePDUStmt *sql.Stmt
|
||||||
deleteQueueTransactionPDUsStmt *sql.Stmt
|
deleteQueueTransactionPDUsStmt *sql.Stmt
|
||||||
selectQueueNextTransactionIDStmt *sql.Stmt
|
selectQueueNextTransactionIDStmt *sql.Stmt
|
||||||
selectQueuePDUsByTransactionStmt *sql.Stmt
|
selectQueuePDUsByTransactionStmt *sql.Stmt
|
||||||
selectQueueReferenceJSONCountStmt *sql.Stmt
|
selectQueueReferenceJSONCountStmt *sql.Stmt
|
||||||
|
selectQueuePDUsCountStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) {
|
func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
@ -87,6 +92,9 @@ func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) {
|
||||||
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil {
|
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -144,6 +152,21 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount(
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *queuePDUsStatements) selectQueuePDUCount(
|
||||||
|
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||||
|
) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
// It's acceptable for there to be no rows referencing a given
|
||||||
|
// JSON NID but it's not an error condition. Just return as if
|
||||||
|
// there's a zero count.
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) selectQueuePDUs(
|
func (s *queuePDUsStatements) selectQueuePDUs(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
|
|
||||||
|
|
@ -255,3 +255,12 @@ func (d *Database) CleanTransactionPDUs(
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPendingPDUCount returns the number of PDUs waiting to be
|
||||||
|
// sent for a given servername.
|
||||||
|
func (d *Database) GetPendingPDUCount(
|
||||||
|
ctx context.Context,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
) (int64, error) {
|
||||||
|
return d.selectQueuePDUCount(ctx, nil, serverName)
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -60,12 +60,17 @@ const selectQueueReferenceJSONCountSQL = "" +
|
||||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||||
" WHERE json_nid = $1"
|
" WHERE json_nid = $1"
|
||||||
|
|
||||||
|
const selectQueuePDUsCountSQL = "" +
|
||||||
|
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||||
|
" WHERE server_name = $1"
|
||||||
|
|
||||||
type queuePDUsStatements struct {
|
type queuePDUsStatements struct {
|
||||||
insertQueuePDUStmt *sql.Stmt
|
insertQueuePDUStmt *sql.Stmt
|
||||||
deleteQueueTransactionPDUsStmt *sql.Stmt
|
deleteQueueTransactionPDUsStmt *sql.Stmt
|
||||||
selectQueueNextTransactionIDStmt *sql.Stmt
|
selectQueueNextTransactionIDStmt *sql.Stmt
|
||||||
selectQueuePDUsByTransactionStmt *sql.Stmt
|
selectQueuePDUsByTransactionStmt *sql.Stmt
|
||||||
selectQueueReferenceJSONCountStmt *sql.Stmt
|
selectQueueReferenceJSONCountStmt *sql.Stmt
|
||||||
|
selectQueuePDUsCountStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) {
|
func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) {
|
||||||
|
|
@ -88,6 +93,9 @@ func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) {
|
||||||
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil {
|
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -142,6 +150,21 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount(
|
||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *queuePDUsStatements) selectQueuePDUCount(
|
||||||
|
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||||
|
) (int64, error) {
|
||||||
|
var count int64
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
|
||||||
|
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
// It's acceptable for there to be no rows referencing a given
|
||||||
|
// JSON NID but it's not an error condition. Just return as if
|
||||||
|
// there's a zero count.
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return count, err
|
||||||
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) selectQueuePDUs(
|
func (s *queuePDUsStatements) selectQueuePDUs(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
|
|
||||||
|
|
@ -261,3 +261,12 @@ func (d *Database) CleanTransactionPDUs(
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPendingPDUCount returns the number of PDUs waiting to be
|
||||||
|
// sent for a given servername.
|
||||||
|
func (d *Database) GetPendingPDUCount(
|
||||||
|
ctx context.Context,
|
||||||
|
serverName gomatrixserverlib.ServerName,
|
||||||
|
) (int64, error) {
|
||||||
|
return d.selectQueuePDUCount(ctx, nil, serverName)
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue