diff --git a/federationsender/storage/postgres/queue_pdus_table.go b/federationsender/storage/postgres/queue_pdus_table.go index 4e8b9f7e9..13b3b49c2 100644 --- a/federationsender/storage/postgres/queue_pdus_table.go +++ b/federationsender/storage/postgres/queue_pdus_table.go @@ -134,7 +134,7 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount( ctx context.Context, txn *sql.Tx, jsonNID int64, ) (int64, error) { var count int64 - stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) + stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt) err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) if err == sql.ErrNoRows { return -1, nil diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index 619904919..153dec049 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -156,18 +156,20 @@ func (d *Database) AssociatePDUWithDestination( serverName gomatrixserverlib.ServerName, nids []int64, ) error { - for _, nid := range nids { - if err := d.insertQueuePDU( - ctx, // context - nil, // SQL transaction - transactionID, // transaction ID - serverName, // destination server name - nid, // NID from the federationsender_queue_json table - ); err != nil { - return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + for _, nid := range nids { + if err := d.insertQueuePDU( + ctx, // context + txn, // SQL transaction + transactionID, // transaction ID + serverName, // destination server name + nid, // NID from the federationsender_queue_json table + ); err != nil { + return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) + } } - } - return nil + return nil + }) } // GetNextTransactionPDUs retrieves events from the database for @@ -176,36 +178,42 @@ func (d *Database) GetNextTransactionPDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, limit int, -) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) { - transactionID, err := d.selectQueueNextTransactionID(ctx, nil, serverName) - if err != nil { - return "", nil, fmt.Errorf("d.selectQueueNextTransactionID: %w", err) - } - - if transactionID == "" { - return "", nil, nil - } - - nids, err := d.selectQueuePDUs(ctx, nil, serverName, transactionID, limit) - if err != nil { - return "", nil, fmt.Errorf("d.selectQueuePDUs: %w", err) - } - - blobs, err := d.selectQueueJSON(ctx, nil, nids) - if err != nil { - return "", nil, fmt.Errorf("d.selectJSON: %w", err) - } - - var events []*gomatrixserverlib.HeaderedEvent - for _, blob := range blobs { - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal(blob, &event); err != nil { - return "", nil, fmt.Errorf("json.Unmarshal: %w", err) +) ( + transactionID gomatrixserverlib.TransactionID, + events []*gomatrixserverlib.HeaderedEvent, + err error, +) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + transactionID, err = d.selectQueueNextTransactionID(ctx, txn, serverName) + if err != nil { + return fmt.Errorf("d.selectQueueNextTransactionID: %w", err) } - events = append(events, &event) - } - return gomatrixserverlib.TransactionID(transactionID), events, nil + if transactionID == "" { + return nil + } + + nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, limit) + if err != nil { + return fmt.Errorf("d.selectQueuePDUs: %w", err) + } + + blobs, err := d.selectQueueJSON(ctx, txn, nids) + if err != nil { + return fmt.Errorf("d.selectJSON: %w", err) + } + + for _, blob := range blobs { + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal(blob, &event); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) + } + events = append(events, &event) + } + + return nil + }) + return } // CleanTransactionPDUs cleans up all associated events for a @@ -216,39 +224,41 @@ func (d *Database) CleanTransactionPDUs( serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, ) error { - fmt.Println("Cleaning up after transaction", transactionID) + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + fmt.Println("Cleaning up after transaction", transactionID) - nids, err := d.selectQueuePDUs(ctx, nil, serverName, transactionID, 50) - if err != nil { - return fmt.Errorf("d.selectQueuePDUs: %w", err) - } - - fmt.Println("Transaction", transactionID, "has", len(nids), "NIDs") - - if err = d.deleteQueueTransaction(ctx, nil, serverName, transactionID); err != nil { - return fmt.Errorf("d.deleteQueueTransaction: %w", err) - } - - var count int64 - var deleteNIDs []int64 - for _, nid := range nids { - count, err = d.selectQueueReferenceJSONCount(ctx, nil, nid) + nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, 50) if err != nil { - return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) + return fmt.Errorf("d.selectQueuePDUs: %w", err) } - if count == 0 { - deleteNIDs = append(deleteNIDs, nid) + + fmt.Println("Transaction", transactionID, "has", len(nids), "NIDs") + + if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil { + return fmt.Errorf("d.deleteQueueTransaction: %w", err) } - } - fmt.Println("There are", len(deleteNIDs), "unreferenced NIDs ready for deletion") - fmt.Println("There are", len(nids)-len(deleteNIDs), "NIDs still referenced") - - if len(deleteNIDs) > 0 { - if err = d.deleteQueueJSON(ctx, nil, deleteNIDs); err != nil { - return fmt.Errorf("d.deleteQueueJSON: %w", err) + var count int64 + var deleteNIDs []int64 + for _, nid := range nids { + count, err = d.selectQueueReferenceJSONCount(ctx, txn, nid) + if err != nil { + return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) + } + if count == 0 { + deleteNIDs = append(deleteNIDs, nid) + } } - } - return nil + fmt.Println("There are", len(deleteNIDs), "unreferenced NIDs ready for deletion") + fmt.Println("There are", len(nids)-len(deleteNIDs), "NIDs still referenced") + + if len(deleteNIDs) > 0 { + if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil { + return fmt.Errorf("d.deleteQueueJSON: %w", err) + } + } + + return nil + }) }