Extend TransactionWriter to use optional existing transaction, use that for FS SQLite database writes

This commit is contained in:
Neil Alexander 2020-07-20 14:39:58 +01:00
parent dabb304d99
commit 7b078ed8e3
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 100 additions and 64 deletions

View file

@ -61,6 +61,7 @@ const selectAllJoinedHostsSQL = "" +
type joinedHostsStatements struct { type joinedHostsStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
@ -70,6 +71,7 @@ type joinedHostsStatements struct {
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
s = &joinedHostsStatements{ s = &joinedHostsStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(joinedHostsSchema) _, err = db.Exec(joinedHostsSchema)
if err != nil { if err != nil {
@ -96,14 +98,17 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
roomID, eventID string, roomID, eventID string,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt) stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName) _, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
return err return err
})
} }
func (s *joinedHostsStatements) DeleteJoinedHosts( func (s *joinedHostsStatements) DeleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string, ctx context.Context, txn *sql.Tx, eventIDs []string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
for _, eventID := range eventIDs { for _, eventID := range eventIDs {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt) stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
if _, err := stmt.ExecContext(ctx, eventID); err != nil { if _, err := stmt.ExecContext(ctx, eventID); err != nil {
@ -111,6 +116,7 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
} }
} }
return nil return nil
})
} }
func (s *joinedHostsStatements) SelectJoinedHostsWithTx( func (s *joinedHostsStatements) SelectJoinedHostsWithTx(

View file

@ -59,6 +59,7 @@ const selectQueueServerNamesSQL = "" +
type queueEDUsStatements struct { type queueEDUsStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertQueueEDUStmt *sql.Stmt insertQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt selectQueueEDUReferenceJSONCountStmt *sql.Stmt
@ -69,6 +70,7 @@ type queueEDUsStatements struct {
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) { func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
s = &queueEDUsStatements{ s = &queueEDUsStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(queueEDUsSchema) _, err = db.Exec(queueEDUsSchema)
if err != nil { if err != nil {
@ -99,6 +101,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nid int64, nid int64,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt) stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, ctx,
@ -108,6 +111,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
nid, // JSON blob NID nid, // JSON blob NID
) )
return err return err
})
} }
func (s *queueEDUsStatements) SelectQueueEDU( func (s *queueEDUsStatements) SelectQueueEDU(

View file

@ -50,6 +50,7 @@ const selectJSONSQL = "" +
type queueJSONStatements struct { type queueJSONStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertJSONStmt *sql.Stmt insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic //deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic //selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
@ -58,6 +59,7 @@ type queueJSONStatements struct {
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) { func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
s = &queueJSONStatements{ s = &queueJSONStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(queueJSONSchema) _, err = db.Exec(queueJSONSchema)
if err != nil { if err != nil {
@ -71,17 +73,20 @@ func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
func (s *queueJSONStatements) InsertQueueJSON( func (s *queueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string, ctx context.Context, txn *sql.Tx, json string,
) (int64, error) { ) (lastid int64, err error) {
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt) stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
res, err := stmt.ExecContext(ctx, json) res, err := stmt.ExecContext(ctx, json)
if err != nil { if err != nil {
return 0, fmt.Errorf("stmt.QueryContext: %w", err) return fmt.Errorf("stmt.QueryContext: %w", err)
} }
lastid, err := res.LastInsertId() lastid, err = res.LastInsertId()
if err != nil { if err != nil {
return 0, fmt.Errorf("res.LastInsertId: %w", err) return fmt.Errorf("res.LastInsertId: %w", err)
} }
return lastid, nil return nil
})
return
} }
func (s *queueJSONStatements) DeleteQueueJSON( func (s *queueJSONStatements) DeleteQueueJSON(
@ -98,9 +103,11 @@ func (s *queueJSONStatements) DeleteQueueJSON(
iNIDs[k] = v iNIDs[k] = v
} }
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, deleteStmt) stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, iNIDs...) _, err = stmt.ExecContext(ctx, iNIDs...)
return err return err
})
} }
func (s *queueJSONStatements) SelectQueueJSON( func (s *queueJSONStatements) SelectQueueJSON(

View file

@ -69,6 +69,7 @@ const selectQueuePDUsServerNamesSQL = "" +
type queuePDUsStatements struct { type queuePDUsStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertQueuePDUStmt *sql.Stmt insertQueuePDUStmt *sql.Stmt
deleteQueueTransactionPDUsStmt *sql.Stmt deleteQueueTransactionPDUsStmt *sql.Stmt
selectQueueNextTransactionIDStmt *sql.Stmt selectQueueNextTransactionIDStmt *sql.Stmt
@ -81,6 +82,7 @@ type queuePDUsStatements struct {
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) { func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
s = &queuePDUsStatements{ s = &queuePDUsStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(queuePDUsSchema) _, err = db.Exec(queuePDUsSchema)
if err != nil { if err != nil {
@ -117,6 +119,7 @@ func (s *queuePDUsStatements) InsertQueuePDU(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
nid int64, nid int64,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt) stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
_, err := stmt.ExecContext( _, err := stmt.ExecContext(
ctx, ctx,
@ -125,6 +128,7 @@ func (s *queuePDUsStatements) InsertQueuePDU(
nid, // JSON blob NID nid, // JSON blob NID
) )
return err return err
})
} }
func (s *queuePDUsStatements) DeleteQueuePDUTransaction( func (s *queuePDUsStatements) DeleteQueuePDUTransaction(
@ -132,9 +136,11 @@ func (s *queuePDUsStatements) DeleteQueuePDUTransaction(
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID, transactionID gomatrixserverlib.TransactionID,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt) stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt)
_, err := stmt.ExecContext(ctx, serverName, transactionID) _, err := stmt.ExecContext(ctx, serverName, transactionID)
return err return err
})
} }
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID( func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(

View file

@ -44,6 +44,7 @@ const updateRoomSQL = "" +
type roomStatements struct { type roomStatements struct {
db *sql.DB db *sql.DB
writer *sqlutil.TransactionWriter
insertRoomStmt *sql.Stmt insertRoomStmt *sql.Stmt
selectRoomForUpdateStmt *sql.Stmt selectRoomForUpdateStmt *sql.Stmt
updateRoomStmt *sql.Stmt updateRoomStmt *sql.Stmt
@ -52,6 +53,7 @@ type roomStatements struct {
func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) { func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
s = &roomStatements{ s = &roomStatements{
db: db, db: db,
writer: sqlutil.NewTransactionWriter(),
} }
_, err = db.Exec(roomSchema) _, err = db.Exec(roomSchema)
if err != nil { if err != nil {
@ -75,8 +77,10 @@ func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
func (s *roomStatements) InsertRoom( func (s *roomStatements) InsertRoom(
ctx context.Context, txn *sql.Tx, roomID string, ctx context.Context, txn *sql.Tx, roomID string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID) _, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
return err return err
})
} }
// selectRoomForUpdate locks the row for the room and returns the last_event_id. // selectRoomForUpdate locks the row for the room and returns the last_event_id.
@ -99,7 +103,9 @@ func (s *roomStatements) SelectRoomForUpdate(
func (s *roomStatements) UpdateRoom( func (s *roomStatements) UpdateRoom(
ctx context.Context, txn *sql.Tx, roomID, lastEventID string, ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
) error { ) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.updateRoomStmt) stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
_, err := stmt.ExecContext(ctx, roomID, lastEventID) _, err := stmt.ExecContext(ctx, roomID, lastEventID)
return err return err
})
} }

View file

@ -29,9 +29,6 @@ type Database struct {
shared.Database shared.Database
sqlutil.PartitionOffsetStatements sqlutil.PartitionOffsetStatements
db *sql.DB db *sql.DB
queuePDUsWriter *sqlutil.TransactionWriter
queueEDUsWriter *sqlutil.TransactionWriter
queueJSONWriter *sqlutil.TransactionWriter
} }
// NewDatabase opens a new database // NewDatabase opens a new database

View file

@ -131,14 +131,17 @@ func NewTransactionWriter() *TransactionWriter {
// transactionWriterTask represents a specific task. // transactionWriterTask represents a specific task.
type transactionWriterTask struct { type transactionWriterTask struct {
db *sql.DB db *sql.DB
txn *sql.Tx
f func(txn *sql.Tx) error f func(txn *sql.Tx) error
wait chan error wait chan error
} }
// Do queues a task to be run by a TransactionWriter. The function // Do queues a task to be run by a TransactionWriter. The function
// provided will be ran within a transaction as supplied by the // provided will be ran within a transaction as supplied by the
// database parameter. This will block until the task is finished. // txn parameter if one is supplied, and if not, will take out a
func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error { // new transaction from the database supplied in the database
// parameter. Either way, this will block until the task is done.
func (w *TransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
if w.todo == nil { if w.todo == nil {
return errors.New("not initialised") return errors.New("not initialised")
} }
@ -147,6 +150,7 @@ func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error {
} }
task := transactionWriterTask{ task := transactionWriterTask{
db: db, db: db,
txn: txn,
f: f, f: f,
wait: make(chan error, 1), wait: make(chan error, 1),
} }
@ -164,9 +168,15 @@ func (w *TransactionWriter) run() {
} }
defer w.running.Store(false) defer w.running.Store(false)
for task := range w.todo { for task := range w.todo {
if task.txn != nil {
task.wait <- task.f(task.txn)
} else if task.db != nil {
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error { task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
return task.f(txn) return task.f(txn)
}) })
} else {
panic("expected database or transaction but got neither")
}
close(task.wait) close(task.wait)
} }
} }

View file

@ -1114,7 +1114,7 @@ func (d *Database) StoreNewSendForDeviceMessage(
} }
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee // Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place. // that we don't lock the table for writes in more than one place.
err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AddSendToDeviceEvent( return d.AddSendToDeviceEvent(
ctx, txn, userID, deviceID, string(j), ctx, txn, userID, deviceID, string(j),
) )
@ -1179,7 +1179,7 @@ func (d *Database) CleanSendToDeviceUpdates(
// If we need to write to the database then we'll ask the SendToDeviceWriter to // If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in // do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place. // more than one place.
err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error { err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion. // Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil { if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e) return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)