From 18245f224a810d9c0f80654e2f5cf12c5bf320b9 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Tue, 30 Jun 2020 16:14:57 +0100 Subject: [PATCH] Fix SQLite --- cmd/dendrite-demo-libp2p/main.go | 3 +- cmd/dendrite-demo-yggdrasil/main.go | 3 +- .../storage/sqlite3/queue_json_table.go | 72 +++++++---- .../storage/sqlite3/queue_pdus_table.go | 37 ++++-- federationsender/storage/sqlite3/storage.go | 117 ++++++++++++------ 5 files changed, 157 insertions(+), 75 deletions(-) diff --git a/cmd/dendrite-demo-libp2p/main.go b/cmd/dendrite-demo-libp2p/main.go index b7e86b77c..4bb7a96c2 100644 --- a/cmd/dendrite-demo-libp2p/main.go +++ b/cmd/dendrite-demo-libp2p/main.go @@ -130,8 +130,9 @@ func main() { cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s-serverkey.db", *instanceName)) cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName)) cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) - cfg.Database.PublicRoomsAPI = config.DataSource(fmt.Sprintf("file:%s-publicroomsa.db", *instanceName)) + cfg.Database.PublicRoomsAPI = config.DataSource(fmt.Sprintf("file:%s-publicrooms.db", *instanceName)) cfg.Database.Naffka = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName)) + cfg.Database.CurrentState = config.DataSource(fmt.Sprintf("file:%s-currentstate.db", *instanceName)) if err = cfg.Derive(); err != nil { panic(err) } diff --git a/cmd/dendrite-demo-yggdrasil/main.go b/cmd/dendrite-demo-yggdrasil/main.go index 5de674021..cef34c7ed 100644 --- a/cmd/dendrite-demo-yggdrasil/main.go +++ b/cmd/dendrite-demo-yggdrasil/main.go @@ -74,7 +74,8 @@ func main() { cfg.Database.ServerKey = config.DataSource(fmt.Sprintf("file:%s-serverkey.db", *instanceName)) cfg.Database.FederationSender = config.DataSource(fmt.Sprintf("file:%s-federationsender.db", *instanceName)) cfg.Database.AppService = config.DataSource(fmt.Sprintf("file:%s-appservice.db", *instanceName)) - cfg.Database.PublicRoomsAPI = config.DataSource(fmt.Sprintf("file:%s-publicroomsa.db", *instanceName)) + cfg.Database.PublicRoomsAPI = config.DataSource(fmt.Sprintf("file:%s-publicrooms.db", *instanceName)) + cfg.Database.CurrentState = config.DataSource(fmt.Sprintf("file:%s-currentstate.db", *instanceName)) cfg.Database.Naffka = config.DataSource(fmt.Sprintf("file:%s-naffka.db", *instanceName)) if err = cfg.Derive(); err != nil { panic(err) diff --git a/federationsender/storage/sqlite3/queue_json_table.go b/federationsender/storage/sqlite3/queue_json_table.go index 0267159d6..01b7160db 100644 --- a/federationsender/storage/sqlite3/queue_json_table.go +++ b/federationsender/storage/sqlite3/queue_json_table.go @@ -18,8 +18,9 @@ package sqlite3 import ( "context" "database/sql" + "fmt" + "strings" - "github.com/lib/pq" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" ) @@ -30,7 +31,7 @@ const queueJSONSchema = ` CREATE TABLE IF NOT EXISTS federationsender_queue_json ( -- The JSON NID. This allows the federationsender_queue_retry table to -- cross-reference to find the JSON blob. - json_nid BIGSERIAL, + json_nid INTEGER PRIMARY KEY AUTOINCREMENT, -- The JSON body. Text so that we preserve UTF-8. json_body TEXT NOT NULL ); @@ -38,20 +39,19 @@ CREATE TABLE IF NOT EXISTS federationsender_queue_json ( const insertJSONSQL = "" + "INSERT INTO federationsender_queue_json (json_body)" + - " VALUES ($1)" + - " RETURNING json_nid" + " VALUES ($1)" const deleteJSONSQL = "" + - "DELETE FROM federationsender_queue_json WHERE json_nid = ANY($1)" + "DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)" const selectJSONSQL = "" + "SELECT json_nid, json_body FROM federationsender_queue_json" + - " WHERE json_nid = ANY($1)" + " WHERE json_nid IN ($1)" type queueJSONStatements struct { insertJSONStmt *sql.Stmt - deleteJSONStmt *sql.Stmt - selectJSONStmt *sql.Stmt + //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic + //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic } func (s *queueJSONStatements) prepare(db *sql.DB) (err error) { @@ -62,12 +62,6 @@ func (s *queueJSONStatements) prepare(db *sql.DB) (err error) { if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil { return } - if s.deleteJSONStmt, err = db.Prepare(deleteJSONSQL); err != nil { - return - } - if s.selectJSONStmt, err = db.Prepare(selectJSONSQL); err != nil { - return - } return } @@ -75,36 +69,62 @@ func (s *queueJSONStatements) insertQueueJSON( ctx context.Context, txn *sql.Tx, json string, ) (int64, error) { stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) - var lastid int64 - if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil { - return 0, err + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return 0, fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err := res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("res.LastInsertId: %w", err) } return lastid, nil } func (s *queueJSONStatements) deleteQueueJSON( - ctx context.Context, txn *sql.Tx, eventIDs []string, + ctx context.Context, txn *sql.Tx, nids []int64, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt) - _, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs)) + deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + deleteStmt, err := txn.Prepare(deleteSQL) + if err != nil { + return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(nids)) + for k, v := range nids { + iNIDs[k] = v + } + + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) return err } -func (s *queueJSONStatements) selectJSON( +func (s *queueJSONStatements) selectQueueJSON( ctx context.Context, txn *sql.Tx, jsonNIDs []int64, ) (map[int64][]byte, error) { - blobs := map[int64][]byte{} - stmt := sqlutil.TxStmt(txn, s.selectJSONStmt) - rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs)) + selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1) + selectStmt, err := txn.Prepare(selectSQL) if err != nil { - return nil, err + return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err) + } + + iNIDs := make([]interface{}, len(jsonNIDs)) + for k, v := range jsonNIDs { + iNIDs[k] = v + } + + blobs := map[int64][]byte{} + stmt := sqlutil.TxStmt(txn, selectStmt) + rows, err := stmt.QueryContext(ctx, iNIDs...) + if err != nil { + return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err) } defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed") for rows.Next() { var nid int64 var blob []byte if err = rows.Scan(&nid, &blob); err != nil { - return nil, err + return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err) } blobs[nid] = blob } diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go index 5bfa528e1..dc08fd707 100644 --- a/federationsender/storage/sqlite3/queue_pdus_table.go +++ b/federationsender/storage/sqlite3/queue_pdus_table.go @@ -56,11 +56,16 @@ const selectQueuePDUsByTransactionSQL = "" + " WHERE server_name = $1 AND transaction_id = $2" + " LIMIT $3" +const selectQueueReferenceJSONCountSQL = "" + + "SELECT COUNT(*) FROM federationsender_queue_pdus" + + " WHERE json_nid = $1" + type queuePDUsStatements struct { - insertQueuePDUStmt *sql.Stmt - deleteQueueTransactionPDUsStmt *sql.Stmt - selectQueueNextTransactionIDStmt *sql.Stmt - selectQueuePDUsByTransactionStmt *sql.Stmt + insertQueuePDUStmt *sql.Stmt + deleteQueueTransactionPDUsStmt *sql.Stmt + selectQueueNextTransactionIDStmt *sql.Stmt + selectQueuePDUsByTransactionStmt *sql.Stmt + selectQueueReferenceJSONCountStmt *sql.Stmt } func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { @@ -80,6 +85,9 @@ func (s *queuePDUsStatements) prepare(db *sql.DB) (err error) { if s.selectQueuePDUsByTransactionStmt, err = db.Prepare(selectQueuePDUsByTransactionSQL); err != nil { return } + if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueueReferenceJSONCountSQL); err != nil { + return + } return } @@ -112,8 +120,8 @@ func (s *queuePDUsStatements) deleteQueueTransaction( func (s *queuePDUsStatements) selectQueueNextTransactionID( ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, -) (string, error) { - var transactionID string +) (gomatrixserverlib.TransactionID, error) { + var transactionID gomatrixserverlib.TransactionID stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt) err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID) if err == sql.ErrNoRows { @@ -122,8 +130,23 @@ func (s *queuePDUsStatements) selectQueueNextTransactionID( return transactionID, err } +func (s *queuePDUsStatements) selectQueueReferenceJSONCount( + ctx context.Context, txn *sql.Tx, jsonNID int64, +) (int64, error) { + var count int64 + stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt) + err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count) + if err == sql.ErrNoRows { + return -1, nil + } + return count, err +} + func (s *queuePDUsStatements) selectQueuePDUs( - ctx context.Context, txn *sql.Tx, serverName string, transactionID string, limit int, + ctx context.Context, txn *sql.Tx, + serverName gomatrixserverlib.ServerName, + transactionID gomatrixserverlib.TransactionID, + limit int, ) ([]int64, error) { stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsByTransactionStmt) rows, err := stmt.QueryContext(ctx, serverName, transactionID, limit) diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index f5adaa10b..7629ecd21 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -162,18 +162,20 @@ func (d *Database) AssociatePDUWithDestination( serverName gomatrixserverlib.ServerName, nids []int64, ) error { - for _, nid := range nids { - if err := d.insertQueuePDU( - ctx, // context - nil, // SQL transaction - transactionID, // transaction ID - serverName, // destination server name - nid, // NID from the federationsender_queue_json table - ); err != nil { - return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + for _, nid := range nids { + if err := d.insertQueuePDU( + ctx, // context + txn, // SQL transaction + transactionID, // transaction ID + serverName, // destination server name + nid, // NID from the federationsender_queue_json table + ); err != nil { + return fmt.Errorf("d.insertQueueRetryStmt.ExecContext: %w", err) + } } - } - return nil + return nil + }) } // GetNextTransactionPDUs retrieves events from the database for @@ -182,36 +184,42 @@ func (d *Database) GetNextTransactionPDUs( ctx context.Context, serverName gomatrixserverlib.ServerName, limit int, -) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) { - transactionID, err := d.selectQueueNextTransactionID(ctx, nil, serverName) - if err != nil { - return "", nil, fmt.Errorf("d.selectQueueNextTransactionID: %w", err) - } - - if transactionID == "" { - return "", nil, nil - } - - nids, err := d.selectQueuePDUs(ctx, nil, string(serverName), transactionID, limit) - if err != nil { - return "", nil, fmt.Errorf("d.selectQueuePDUs: %w", err) - } - - blobs, err := d.selectJSON(ctx, nil, nids) - if err != nil { - return "", nil, fmt.Errorf("d.selectJSON: %w", err) - } - - var events []*gomatrixserverlib.HeaderedEvent - for _, blob := range blobs { - var event gomatrixserverlib.HeaderedEvent - if err := json.Unmarshal(blob, &event); err != nil { - return "", nil, fmt.Errorf("json.Unmarshal: %w", err) +) ( + transactionID gomatrixserverlib.TransactionID, + events []*gomatrixserverlib.HeaderedEvent, + err error, +) { + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + transactionID, err = d.selectQueueNextTransactionID(ctx, txn, serverName) + if err != nil { + return fmt.Errorf("d.selectQueueNextTransactionID: %w", err) } - events = append(events, &event) - } - return gomatrixserverlib.TransactionID(transactionID), events, nil + if transactionID == "" { + return nil + } + + nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, limit) + if err != nil { + return fmt.Errorf("d.selectQueuePDUs: %w", err) + } + + blobs, err := d.selectQueueJSON(ctx, txn, nids) + if err != nil { + return fmt.Errorf("d.selectJSON: %w", err) + } + + for _, blob := range blobs { + var event gomatrixserverlib.HeaderedEvent + if err := json.Unmarshal(blob, &event); err != nil { + return fmt.Errorf("json.Unmarshal: %w", err) + } + events = append(events, &event) + } + + return nil + }) + return } // CleanTransactionPDUs cleans up all associated events for a @@ -222,5 +230,34 @@ func (d *Database) CleanTransactionPDUs( serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, ) error { - return d.deleteQueueTransaction(ctx, nil, serverName, transactionID) + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + nids, err := d.selectQueuePDUs(ctx, txn, serverName, transactionID, 50) + if err != nil { + return fmt.Errorf("d.selectQueuePDUs: %w", err) + } + + if err = d.deleteQueueTransaction(ctx, txn, serverName, transactionID); err != nil { + return fmt.Errorf("d.deleteQueueTransaction: %w", err) + } + + var count int64 + var deleteNIDs []int64 + for _, nid := range nids { + count, err = d.selectQueueReferenceJSONCount(ctx, txn, nid) + if err != nil { + return fmt.Errorf("d.selectQueueReferenceJSONCount: %w", err) + } + if count == 0 { + deleteNIDs = append(deleteNIDs, nid) + } + } + + if len(deleteNIDs) > 0 { + if err = d.deleteQueueJSON(ctx, txn, deleteNIDs); err != nil { + return fmt.Errorf("d.deleteQueueJSON: %w", err) + } + } + + return nil + }) }