From 7b078ed8e3e690e85145791db762da53108d2743 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 20 Jul 2020 14:39:58 +0100 Subject: [PATCH] Extend TransactionWriter to use optional existing transaction, use that for FS SQLite database writes --- .../storage/sqlite3/joined_hosts_table.go | 26 ++++++++----- .../storage/sqlite3/queue_edus_table.go | 24 +++++++----- .../storage/sqlite3/queue_json_table.go | 37 +++++++++++-------- .../storage/sqlite3/queue_pdus_table.go | 30 +++++++++------ .../storage/sqlite3/room_table.go | 18 ++++++--- federationsender/storage/sqlite3/storage.go | 5 +-- internal/sqlutil/sql.go | 20 +++++++--- syncapi/storage/shared/syncserver.go | 4 +- 8 files changed, 100 insertions(+), 64 deletions(-) diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index 4338e8182..bd917c61a 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -61,6 +61,7 @@ const selectAllJoinedHostsSQL = "" + type joinedHostsStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt @@ -69,7 +70,8 @@ type joinedHostsStatements struct { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { s = &joinedHostsStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err = db.Exec(joinedHostsSchema) if err != nil { @@ -96,21 +98,25 @@ func (s *joinedHostsStatements) InsertJoinedHosts( roomID, eventID string, serverName gomatrixserverlib.ServerName, ) error { - stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) - _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) + _, err := stmt.ExecContext(ctx, roomID, eventID, serverName) + return err + }) } func (s *joinedHostsStatements) DeleteJoinedHosts( ctx context.Context, txn *sql.Tx, eventIDs []string, ) error { - for _, eventID := range eventIDs { - stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) - if _, err := stmt.ExecContext(ctx, eventID); err != nil { - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + for _, eventID := range eventIDs { + stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) + if _, err := stmt.ExecContext(ctx, eventID); err != nil { + return err + } } - } - return nil + return nil + }) } func (s *joinedHostsStatements) SelectJoinedHostsWithTx( diff --git a/federationsender/storage/sqlite3/queue_edus_table.go b/federationsender/storage/sqlite3/queue_edus_table.go index 46b44c047..aefe36e85 100644 --- a/federationsender/storage/sqlite3/queue_edus_table.go +++ b/federationsender/storage/sqlite3/queue_edus_table.go @@ -59,6 +59,7 @@ const selectQueueServerNamesSQL = "" + type queueEDUsStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt @@ -68,7 +69,8 @@ type queueEDUsStatements struct { func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { s = &queueEDUsStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err = db.Exec(queueEDUsSchema) if err != nil { @@ -99,15 +101,17 @@ func (s *queueEDUsStatements) InsertQueueEDU( serverName gomatrixserverlib.ServerName, nid int64, ) error { - stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) - _, err := stmt.ExecContext( - ctx, - userID, // destination user ID - deviceID, // destination device ID - serverName, // destination server name - nid, // JSON blob NID - ) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) + _, err := stmt.ExecContext( + ctx, + userID, // destination user ID + deviceID, // destination device ID + serverName, // destination server name + nid, // JSON blob NID + ) + return err + }) } func (s *queueEDUsStatements) SelectQueueEDU( diff --git a/federationsender/storage/sqlite3/queue_json_table.go b/federationsender/storage/sqlite3/queue_json_table.go index 95e6cd206..46dfd9ab1 100644 --- a/federationsender/storage/sqlite3/queue_json_table.go +++ b/federationsender/storage/sqlite3/queue_json_table.go @@ -50,6 +50,7 @@ const selectJSONSQL = "" + type queueJSONStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertJSONStmt *sql.Stmt //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic @@ -57,7 +58,8 @@ type queueJSONStatements struct { func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { s = &queueJSONStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err = db.Exec(queueJSONSchema) if err != nil { @@ -71,17 +73,20 @@ func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { func (s *queueJSONStatements) InsertQueueJSON( ctx context.Context, txn *sql.Tx, json string, -) (int64, error) { - stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) - res, err := stmt.ExecContext(ctx, json) - if err != nil { - return 0, fmt.Errorf("stmt.QueryContext: %w", err) - } - lastid, err := res.LastInsertId() - if err != nil { - return 0, fmt.Errorf("res.LastInsertId: %w", err) - } - return lastid, nil +) (lastid int64, err error) { + err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) + res, err := stmt.ExecContext(ctx, json) + if err != nil { + return fmt.Errorf("stmt.QueryContext: %w", err) + } + lastid, err = res.LastInsertId() + if err != nil { + return fmt.Errorf("res.LastInsertId: %w", err) + } + return nil + }) + return } func (s *queueJSONStatements) DeleteQueueJSON( @@ -98,9 +103,11 @@ func (s *queueJSONStatements) DeleteQueueJSON( iNIDs[k] = v } - stmt := sqlutil.TxStmt(txn, deleteStmt) - _, err = stmt.ExecContext(ctx, iNIDs...) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, deleteStmt) + _, err = stmt.ExecContext(ctx, iNIDs...) + return err + }) } func (s *queueJSONStatements) SelectQueueJSON( diff --git a/federationsender/storage/sqlite3/queue_pdus_table.go b/federationsender/storage/sqlite3/queue_pdus_table.go index de278c4ef..1fc5680c3 100644 --- a/federationsender/storage/sqlite3/queue_pdus_table.go +++ b/federationsender/storage/sqlite3/queue_pdus_table.go @@ -69,6 +69,7 @@ const selectQueuePDUsServerNamesSQL = "" + type queuePDUsStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertQueuePDUStmt *sql.Stmt deleteQueueTransactionPDUsStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt @@ -80,7 +81,8 @@ type queuePDUsStatements struct { func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { s = &queuePDUsStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err = db.Exec(queuePDUsSchema) if err != nil { @@ -117,14 +119,16 @@ func (s *queuePDUsStatements) InsertQueuePDU( serverName gomatrixserverlib.ServerName, nid int64, ) error { - stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) - _, err := stmt.ExecContext( - ctx, - transactionID, // the transaction ID that we initially attempted - serverName, // destination server name - nid, // JSON blob NID - ) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) + _, err := stmt.ExecContext( + ctx, + transactionID, // the transaction ID that we initially attempted + serverName, // destination server name + nid, // JSON blob NID + ) + return err + }) } func (s *queuePDUsStatements) DeleteQueuePDUTransaction( @@ -132,9 +136,11 @@ func (s *queuePDUsStatements) DeleteQueuePDUTransaction( serverName gomatrixserverlib.ServerName, transactionID gomatrixserverlib.TransactionID, ) error { - stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt) - _, err := stmt.ExecContext(ctx, serverName, transactionID) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt) + _, err := stmt.ExecContext(ctx, serverName, transactionID) + return err + }) } func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( diff --git a/federationsender/storage/sqlite3/room_table.go b/federationsender/storage/sqlite3/room_table.go index 0710ccca3..517938745 100644 --- a/federationsender/storage/sqlite3/room_table.go +++ b/federationsender/storage/sqlite3/room_table.go @@ -44,6 +44,7 @@ const updateRoomSQL = "" + type roomStatements struct { db *sql.DB + writer *sqlutil.TransactionWriter insertRoomStmt *sql.Stmt selectRoomForUpdateStmt *sql.Stmt updateRoomStmt *sql.Stmt @@ -51,7 +52,8 @@ type roomStatements struct { func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) { s = &roomStatements{ - db: db, + db: db, + writer: sqlutil.NewTransactionWriter(), } _, err = db.Exec(roomSchema) if err != nil { @@ -75,8 +77,10 @@ func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) { func (s *roomStatements) InsertRoom( ctx context.Context, txn *sql.Tx, roomID string, ) error { - _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) + return err + }) } // selectRoomForUpdate locks the row for the room and returns the last_event_id. @@ -99,7 +103,9 @@ func (s *roomStatements) SelectRoomForUpdate( func (s *roomStatements) UpdateRoom( ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ) error { - stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) - _, err := stmt.ExecContext(ctx, roomID, lastEventID) - return err + return s.writer.Do(s.db, txn, func(txn *sql.Tx) error { + stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) + _, err := stmt.ExecContext(ctx, roomID, lastEventID) + return err + }) } diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index a24b6c354..545a229c6 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -28,10 +28,7 @@ import ( type Database struct { shared.Database sqlutil.PartitionOffsetStatements - db *sql.DB - queuePDUsWriter *sqlutil.TransactionWriter - queueEDUsWriter *sqlutil.TransactionWriter - queueJSONWriter *sqlutil.TransactionWriter + db *sql.DB } // NewDatabase opens a new database diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index a25a4a5b6..2ec6ce291 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -131,14 +131,17 @@ func NewTransactionWriter() *TransactionWriter { // transactionWriterTask represents a specific task. type transactionWriterTask struct { db *sql.DB + txn *sql.Tx f func(txn *sql.Tx) error wait chan error } // Do queues a task to be run by a TransactionWriter. The function // provided will be ran within a transaction as supplied by the -// database parameter. This will block until the task is finished. -func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error { +// txn parameter if one is supplied, and if not, will take out a +// new transaction from the database supplied in the database +// parameter. Either way, this will block until the task is done. +func (w *TransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { if w.todo == nil { return errors.New("not initialised") } @@ -147,6 +150,7 @@ func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error { } task := transactionWriterTask{ db: db, + txn: txn, f: f, wait: make(chan error, 1), } @@ -164,9 +168,15 @@ func (w *TransactionWriter) run() { } defer w.running.Store(false) for task := range w.todo { - task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { - return task.f(txn) - }) + if task.txn != nil { + task.wait <- task.f(task.txn) + } else if task.db != nil { + task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { + return task.f(txn) + }) + } else { + panic("expected database or transaction but got neither") + } close(task.wait) } } diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 38b503cd0..320792914 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1114,7 +1114,7 @@ func (d *Database) StoreNewSendForDeviceMessage( } // Delegate the database write task to the SendToDeviceWriter. It'll guarantee // that we don't lock the table for writes in more than one place. - err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { + err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { return d.AddSendToDeviceEvent( ctx, txn, userID, deviceID, string(j), ) @@ -1179,7 +1179,7 @@ func (d *Database) CleanSendToDeviceUpdates( // If we need to write to the database then we'll ask the SendToDeviceWriter to // do that for us. It'll guarantee that we don't lock the table for writes in // more than one place. - err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { + err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error { // Delete any send-to-device messages marked for deletion. if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)