Fix query, use transactions in postgres

This commit is contained in:
Neil Alexander 2020-06-30 15:11:51 +01:00
parent ebaa0cf5d4
commit 61ff558fef
2 changed files with 78 additions and 68 deletions

View file

@ -134,7 +134,7 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64, ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) { ) (int64, error) {
var count int64 var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return -1, nil return -1, nil

View file

@ -156,10 +156,11 @@ func (d *Database) AssociatePDUWithDestination(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nids []int64, nids []int64,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
for _, nid := range nids { for _, nid := range nids {
if err := d.insertQueuePDU( if err := d.insertQueuePDU(
ctx, // context ctx, // context
nil, // SQL transaction txn, // SQL transaction
transactionID, // transaction ID transactionID, // transaction ID
serverName, // destination server name serverName, // destination server name
nid, // NID from the federationsender_queue_json table nid, // NID from the federationsender_queue_json table
@ -168,6 +169,7 @@ func (d *Database) AssociatePDUWithDestination(
} }
} }
return nil return nil
})
} }
// GetNextTransactionPDUs retrieves events from the database for // GetNextTransactionPDUs retrieves events from the database for
@ -176,36 +178,42 @@ func (d *Database) GetNextTransactionPDUs(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
limit int, limit int,
) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) { ) (
transactionID, err := d.selectQueueNextTransactionID(ctx, nil, serverName) transactionID gomatrixserverlib.TransactionID,
events []*gomatrixserverlib.HeaderedEvent,
err error,
) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
transactionID, err = d.selectQueueNextTransactionID(ctx, txn, serverName)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("d.selectQueueNextTransactionID: %w", err) return fmt.Errorf("d.selectQueueNextTransactionID: %w", err)
} }
if transactionID == "" { if transactionID == "" {
return "", nil, nil return nil
} }
nids, err := d.selectQueuePDUs(ctx, nil, serverName, transactionID, limit) nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, limit)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("d.selectQueuePDUs: %w", err) return fmt.Errorf("d.selectQueuePDUs: %w", err)
} }
blobs, err := d.selectQueueJSON(ctx, nil, nids) blobs, err := d.selectQueueJSON(ctx, txn, nids)
if err != nil { if err != nil {
return "", nil, fmt.Errorf("d.selectJSON: %w", err) return fmt.Errorf("d.selectJSON: %w", err)
} }
var events []*gomatrixserverlib.HeaderedEvent
for _, blob := range blobs { for _, blob := range blobs {
var event gomatrixserverlib.HeaderedEvent var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(blob, &event); err != nil { if err := json.Unmarshal(blob, &event); err != nil {
return "", nil, fmt.Errorf("json.Unmarshal: %w", err) return fmt.Errorf("json.Unmarshal: %w", err)
} }
events = append(events, &event) events = append(events, &event)
} }
return gomatrixserverlib.TransactionID(transactionID), events, nil return nil
})
return
} }
// CleanTransactionPDUs cleans up all associated events for a // CleanTransactionPDUs cleans up all associated events for a
@ -216,23 +224,24 @@ func (d *Database) CleanTransactionPDUs(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID, transactionID gomatrixserverlib.TransactionID,
) error { ) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
fmt.Println("Cleaning up after transaction", transactionID) fmt.Println("Cleaning up after transaction", transactionID)
nids, err := d.selectQueuePDUs(ctx, nil, serverName, transactionID, 50) nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, 50)
if err != nil { if err != nil {
return fmt.Errorf("d.selectQueuePDUs: %w", err) return fmt.Errorf("d.selectQueuePDUs: %w", err)
} }
fmt.Println("Transaction", transactionID, "has", len(nids), "NIDs") fmt.Println("Transaction", transactionID, "has", len(nids), "NIDs")
if err = d.deleteQueueTransaction(ctx, nil, serverName, transactionID); err != nil { if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil {
return fmt.Errorf("d.deleteQueueTransaction: %w", err) return fmt.Errorf("d.deleteQueueTransaction: %w", err)
} }
var count int64 var count int64
var deleteNIDs []int64 var deleteNIDs []int64
for _, nid := range nids { for _, nid := range nids {
count, err = d.selectQueueReferenceJSONCount(ctx, nil, nid) count, err = d.selectQueueReferenceJSONCount(ctx, txn, nid)
if err != nil { if err != nil {
return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err)
} }
@ -245,10 +254,11 @@ func (d *Database) CleanTransactionPDUs(
fmt.Println("There are", len(nids)-len(deleteNIDs), "NIDs still referenced") fmt.Println("There are", len(nids)-len(deleteNIDs), "NIDs still referenced")
if len(deleteNIDs) > 0 { if len(deleteNIDs) > 0 {
if err = d.deleteQueueJSON(ctx, nil, deleteNIDs); err != nil { if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil {
return fmt.Errorf("d.deleteQueueJSON: %w", err) return fmt.Errorf("d.deleteQueueJSON: %w", err)
} }
} }
return nil return nil
})
} }