Fix query, use transactions in postgres

This commit is contained in:
Neil Alexander 2020-06-30 15:11:51 +01:00
parent ebaa0cf5d4
commit 61ff558fef
2 changed files with 78 additions and 68 deletions

View file

@ -134,7 +134,7 @@ func (s *queuePDUsStatements) selectQueueReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64, ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) { ) (int64, error) {
var count int64 var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return -1, nil return -1, nil

View file

@ -156,18 +156,20 @@ func (d *Database) AssociatePDUWithDestination(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nids []int64, nids []int64,
) error { ) error {
for _, nid := range nids { return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.insertQueuePDU( for _, nid := range nids {
ctx, // context if err := d.insertQueuePDU(
nil, // SQL transaction ctx, // context
transactionID, // transaction ID txn, // SQL transaction
serverName, // destination server name transactionID, // transaction ID
nid, // NID from the federationsender_queue_json table serverName, // destination server name
); err != nil { nid, // NID from the federationsender_queue_json table
return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) ); err != nil {
return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err)
}
} }
} return nil
return nil })
} }
// GetNextTransactionPDUs retrieves events from the database for // GetNextTransactionPDUs retrieves events from the database for
@ -176,36 +178,42 @@ func (d *Database) GetNextTransactionPDUs(
ctx context.Context, ctx context.Context,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
limit int, limit int,
) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) { ) (
transactionID, err := d.selectQueueNextTransactionID(ctx, nil, serverName) transactionID gomatrixserverlib.TransactionID,
if err != nil { events []*gomatrixserverlib.HeaderedEvent,
return "", nil, fmt.Errorf("d.selectQueueNextTransactionID: %w", err) err error,
} ) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if transactionID == "" { transactionID, err = d.selectQueueNextTransactionID(ctx, txn, serverName)
return "", nil, nil if err != nil {
} return fmt.Errorf("d.selectQueueNextTransactionID: %w", err)
nids, err := d.selectQueuePDUs(ctx, nil, serverName, transactionID, limit)
if err != nil {
return "", nil, fmt.Errorf("d.selectQueuePDUs: %w", err)
}
blobs, err := d.selectQueueJSON(ctx, nil, nids)
if err != nil {
return "", nil, fmt.Errorf("d.selectJSON: %w", err)
}
var events []*gomatrixserverlib.HeaderedEvent
for _, blob := range blobs {
var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(blob, &event); err != nil {
return "", nil, fmt.Errorf("json.Unmarshal: %w", err)
} }
events = append(events, &event)
}
return gomatrixserverlib.TransactionID(transactionID), events, nil if transactionID == "" {
return nil
}
nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, limit)
if err != nil {
return fmt.Errorf("d.selectQueuePDUs: %w", err)
}
blobs, err := d.selectQueueJSON(ctx, txn, nids)
if err != nil {
return fmt.Errorf("d.selectJSON: %w", err)
}
for _, blob := range blobs {
var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
events = append(events, &event)
}
return nil
})
return
} }
// CleanTransactionPDUs cleans up all associated events for a // CleanTransactionPDUs cleans up all associated events for a
@ -216,39 +224,41 @@ func (d *Database) CleanTransactionPDUs(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID, transactionID gomatrixserverlib.TransactionID,
) error { ) error {
fmt.Println("Cleaning up after transaction", transactionID) return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
fmt.Println("Cleaning up after transaction", transactionID)
nids, err := d.selectQueuePDUs(ctx, nil, serverName, transactionID, 50) nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, 50)
if err != nil {
return fmt.Errorf("d.selectQueuePDUs: %w", err)
}
fmt.Println("Transaction", transactionID, "has", len(nids), "NIDs")
if err = d.deleteQueueTransaction(ctx, nil, serverName, transactionID); err != nil {
return fmt.Errorf("d.deleteQueueTransaction: %w", err)
}
var count int64
var deleteNIDs []int64
for _, nid := range nids {
count, err = d.selectQueueReferenceJSONCount(ctx, nil, nid)
if err != nil { if err != nil {
return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) return fmt.Errorf("d.selectQueuePDUs: %w", err)
} }
if count == 0 {
deleteNIDs = append(deleteNIDs, nid) fmt.Println("Transaction", transactionID, "has", len(nids), "NIDs")
if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil {
return fmt.Errorf("d.deleteQueueTransaction: %w", err)
} }
}
fmt.Println("There are", len(deleteNIDs), "unreferenced NIDs ready for deletion") var count int64
fmt.Println("There are", len(nids)-len(deleteNIDs), "NIDs still referenced") var deleteNIDs []int64
for _, nid := range nids {
if len(deleteNIDs) > 0 { count, err = d.selectQueueReferenceJSONCount(ctx, txn, nid)
if err = d.deleteQueueJSON(ctx, nil, deleteNIDs); err != nil { if err != nil {
return fmt.Errorf("d.deleteQueueJSON: %w", err) return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err)
}
if count == 0 {
deleteNIDs = append(deleteNIDs, nid)
}
} }
}
return nil fmt.Println("There are", len(deleteNIDs), "unreferenced NIDs ready for deletion")
fmt.Println("There are", len(nids)-len(deleteNIDs), "NIDs still referenced")
if len(deleteNIDs) > 0 {
if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil {
return fmt.Errorf("d.deleteQueueJSON: %w", err)
}
}
return nil
})
} }