diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index 3ab939c68..c30567250 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -255,7 +255,7 @@ func (oq *destinationQueue) backgroundSend() { } // Try sending the next transaction and see what happens. - transaction, terr := oq.nextTransaction(transactionID, oq.pendingPDUs, oq.pendingEDUs, oq.statistics.SuccessCount()) + transaction, terr := oq.nextTransaction(transactionID, oq.pendingPDUs, oq.pendingEDUs) if terr != nil { // We failed to send the transaction. if giveUp := oq.statistics.Failure(); giveUp { @@ -328,7 +328,6 @@ func (oq *destinationQueue) nextTransaction( transactionID gomatrixserverlib.TransactionID, pendingPDUs []*gomatrixserverlib.HeaderedEvent, pendingEDUs []*gomatrixserverlib.EDU, - sentCounter uint32, ) (bool, error) { t := gomatrixserverlib.Transaction{ PDUs: []json.RawMessage{}, diff --git a/federationsender/storage/postgres/queue_json_table.go b/federationsender/storage/postgres/queue_json_table.go index 0f6ebe2ae..1095c6ad4 100644 --- a/federationsender/storage/postgres/queue_json_table.go +++ b/federationsender/storage/postgres/queue_json_table.go @@ -83,10 +83,10 @@ func (s *queueJSONStatements) insertQueueJSON( } func (s *queueJSONStatements) deleteQueueJSON( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, nids []int64, ) error { stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) - _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs)) + _, err := stmt.ExecContext(ctx, pq.Int64Array(nids)) return err } diff --git a/federationsender/storage/postgres/queue_pdus_table.go b/federationsender/storage/postgres/queue_pdus_table.go index 4ba22bc20..4e8b9f7e9 100644 --- a/federationsender/storage/postgres/queue_pdus_table.go +++ b/federationsender/storage/postgres/queue_pdus_table.go @@ -56,11 +56,16 @@ const selectQueuePDUsByTransactionSQL = "" + " WHERE server_name = $1 AND transaction_id = $2" + " LIMIT $3" +const selectQueueReferenceJSONCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_pdus" + + " WHERE json_nid = $1" + type queuePDUsStatements struct { - insertQueuePDUStmt *sql.Stmt - deleteQueueTransactionPDUsStmt *sql.Stmt - selectQueueNextTransactionIDStmt *sql.Stmt - selectQueuePDUsByTransactionStmt *sql.Stmt + insertQueuePDUStmt *sql.Stmt + deleteQueueTransactionPDUsStmt *sql.Stmt + selectQueueNextTransactionIDStmt *sql.Stmt + selectQueuePDUsByTransactionStmt *sql.Stmt + selectQueueReferenceJSONCountStmt *sql.Stmt } func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { @@ -80,6 +85,9 @@ func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { return } + if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil { + return + } return } @@ -112,8 +120,8 @@ func (s *queuePDUsStatements) deleteQueueTransaction( func (s *queuePDUsStatements) selectQueueNextTransactionID( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (string, error) { - var transactionID string +) (gomatrixserverlib.TransactionID, error) { + var transactionID gomatrixserverlib.TransactionID stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID) if err == sql.ErrNoRows { @@ -122,8 +130,23 @@ func (s *queuePDUsStatements) selectQueueNextTransactionID( return transactionID, err } +func (s *queuePDUsStatements) selectQueueReferenceJSONCount( + ctx context.Context, txn *sql.Tx, jsonNID int64, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) + err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + if err == sql.ErrNoRows { + return -1, nil + } + return count, err +} + func (s *queuePDUsStatements) selectQueuePDUs( - ctx context.Context, txn *sql.Tx, serverName string, transactionID string, limit int, + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + transactionID gomatrixserverlib.TransactionID, + limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsByTransactionStmt) rows, err := stmt.QueryContext(ctx, serverName, transactionID, limit) diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index fb69a2f0f..b90ac874a 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -186,7 +186,7 @@ func (d *Database) GetNextTransactionPDUs( return "", nil, nil } - nids, err := d.selectQueuePDUs(ctx, nil, string(serverName), transactionID, limit) + nids, err := d.selectQueuePDUs(ctx, nil, serverName, transactionID, limit) if err != nil { return "", nil, fmt.Errorf("d.selectQueuePDUs: %w", err) } @@ -216,5 +216,32 @@ func (d *Database) CleanTransactionPDUs( serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, ) error { - return d.deleteQueueTransaction(ctx, nil, serverName, transactionID) + nids, err := d.selectQueuePDUs(ctx, nil, serverName, transactionID, 50) + if err != nil { + return fmt.Errorf("d.selectQueuePDUs: %w", err) + } + + if err = d.deleteQueueTransaction(ctx, nil, serverName, transactionID); err != nil { + return fmt.Errorf("d.deleteQueueTransaction: %w", err) + } + + var count int64 + var deleteNIDs []int64 + for _, nid := range nids { + count, err = d.selectQueueReferenceJSONCount(ctx, nil, nid) + if err != nil { + return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) + } + if count == 0 { + deleteNIDs = append(deleteNIDs, nid) + } + } + + if len(deleteNIDs) > 0 { + if err = d.deleteQueueJSON(ctx, nil, deleteNIDs); err != nil { + return fmt.Errorf("d.deleteQueueJSON: %w", err) + } + } + + return nil }