Extend TransactionWriter to use optional existing transaction, use that for FS SQLite database writes

This commit is contained in:
Neil Alexander 2020-07-20 14:39:58 +01:00
parent dabb304d99
commit 7b078ed8e3
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
8 changed files with 100 additions and 64 deletions

View file

@ -61,6 +61,7 @@ const selectAllJoinedHostsSQL = "" +
type joinedHostsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
@ -70,6 +71,7 @@ type joinedHostsStatements struct {
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
s = &joinedHostsStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err = db.Exec(joinedHostsSchema)
if err != nil {
@ -96,14 +98,17 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
roomID, eventID string,
serverName gomatrixserverlib.ServerName,
) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
return err
})
}
func (s *joinedHostsStatements) DeleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
for _, eventID := range eventIDs {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
@ -111,6 +116,7 @@ func (s *joinedHostsStatements) DeleteJoinedHosts(
}
}
return nil
})
}
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(

View file

@ -59,6 +59,7 @@ const selectQueueServerNamesSQL = "" +
type queueEDUsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
@ -69,6 +70,7 @@ type queueEDUsStatements struct {
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
s = &queueEDUsStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err = db.Exec(queueEDUsSchema)
if err != nil {
@ -99,6 +101,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext(
ctx,
@ -108,6 +111,7 @@ func (s *queueEDUsStatements) InsertQueueEDU(
nid, // JSON blob NID
)
return err
})
}
func (s *queueEDUsStatements) SelectQueueEDU(

View file

@ -50,6 +50,7 @@ const selectJSONSQL = "" +
type queueJSONStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
@ -58,6 +59,7 @@ type queueJSONStatements struct {
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
s = &queueJSONStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err = db.Exec(queueJSONSchema)
if err != nil {
@ -71,17 +73,20 @@ func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
func (s *queueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (int64, error) {
) (lastid int64, err error) {
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
res, err := stmt.ExecContext(ctx, json)
if err != nil {
return 0, fmt.Errorf("stmt.QueryContext: %w", err)
return fmt.Errorf("stmt.QueryContext: %w", err)
}
lastid, err := res.LastInsertId()
lastid, err = res.LastInsertId()
if err != nil {
return 0, fmt.Errorf("res.LastInsertId: %w", err)
return fmt.Errorf("res.LastInsertId: %w", err)
}
return lastid, nil
return nil
})
return
}
func (s *queueJSONStatements) DeleteQueueJSON(
@ -98,9 +103,11 @@ func (s *queueJSONStatements) DeleteQueueJSON(
iNIDs[k] = v
}
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, iNIDs...)
return err
})
}
func (s *queueJSONStatements) SelectQueueJSON(

View file

@ -69,6 +69,7 @@ const selectQueuePDUsServerNamesSQL = "" +
type queuePDUsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertQueuePDUStmt *sql.Stmt
deleteQueueTransactionPDUsStmt *sql.Stmt
selectQueueNextTransactionIDStmt *sql.Stmt
@ -81,6 +82,7 @@ type queuePDUsStatements struct {
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
s = &queuePDUsStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err = db.Exec(queuePDUsSchema)
if err != nil {
@ -117,6 +119,7 @@ func (s *queuePDUsStatements) InsertQueuePDU(
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
_, err := stmt.ExecContext(
ctx,
@ -125,6 +128,7 @@ func (s *queuePDUsStatements) InsertQueuePDU(
nid, // JSON blob NID
)
return err
})
}
func (s *queuePDUsStatements) DeleteQueuePDUTransaction(
@ -132,9 +136,11 @@ func (s *queuePDUsStatements) DeleteQueuePDUTransaction(
serverName gomatrixserverlib.ServerName,
transactionID gomatrixserverlib.TransactionID,
) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt)
_, err := stmt.ExecContext(ctx, serverName, transactionID)
return err
})
}
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(

View file

@ -44,6 +44,7 @@ const updateRoomSQL = "" +
type roomStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertRoomStmt *sql.Stmt
selectRoomForUpdateStmt *sql.Stmt
updateRoomStmt *sql.Stmt
@ -52,6 +53,7 @@ type roomStatements struct {
func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
s = &roomStatements{
db: db,
writer: sqlutil.NewTransactionWriter(),
}
_, err = db.Exec(roomSchema)
if err != nil {
@ -75,8 +77,10 @@ func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
func (s *roomStatements) InsertRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
_, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
return err
})
}
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
@ -99,7 +103,9 @@ func (s *roomStatements) SelectRoomForUpdate(
func (s *roomStatements) UpdateRoom(
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
) error {
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
return err
})
}

View file

@ -29,9 +29,6 @@ type Database struct {
shared.Database
sqlutil.PartitionOffsetStatements
db *sql.DB
queuePDUsWriter *sqlutil.TransactionWriter
queueEDUsWriter *sqlutil.TransactionWriter
queueJSONWriter *sqlutil.TransactionWriter
}
// NewDatabase opens a new database

View file

@ -131,14 +131,17 @@ func NewTransactionWriter() *TransactionWriter {
// transactionWriterTask represents a specific task.
type transactionWriterTask struct {
db *sql.DB
txn *sql.Tx
f func(txn *sql.Tx) error
wait chan error
}
// Do queues a task to be run by a TransactionWriter. The function
// provided will be ran within a transaction as supplied by the
// database parameter. This will block until the task is finished.
func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error {
// txn parameter if one is supplied, and if not, will take out a
// new transaction from the database supplied in the database
// parameter. Either way, this will block until the task is done.
func (w *TransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
if w.todo == nil {
return errors.New("not initialised")
}
@ -147,6 +150,7 @@ func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error {
}
task := transactionWriterTask{
db: db,
txn: txn,
f: f,
wait: make(chan error, 1),
}
@ -164,9 +168,15 @@ func (w *TransactionWriter) run() {
}
defer w.running.Store(false)
for task := range w.todo {
if task.txn != nil {
task.wait <- task.f(task.txn)
} else if task.db != nil {
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
return task.f(txn)
})
} else {
panic("expected database or transaction but got neither")
}
close(task.wait)
}
}

View file

@ -1114,7 +1114,7 @@ func (d *Database) StoreNewSendForDeviceMessage(
}
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place.
err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error {
err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AddSendToDeviceEvent(
ctx, txn, userID, deviceID, string(j),
)
@ -1179,7 +1179,7 @@ func (d *Database) CleanSendToDeviceUpdates(
// If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place.
err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error {
err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)