mirror of
https://github.com/matrix-org/dendrite.git
synced 2025-12-21 05:43:09 -06:00
Extend TransactionWriter to use optional existing transaction, use that for FS SQLite database writes
This commit is contained in:
parent
dabb304d99
commit
7b078ed8e3
|
|
@ -61,6 +61,7 @@ const selectAllJoinedHostsSQL = "" +
|
||||||
|
|
||||||
type joinedHostsStatements struct {
|
type joinedHostsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertJoinedHostsStmt *sql.Stmt
|
insertJoinedHostsStmt *sql.Stmt
|
||||||
deleteJoinedHostsStmt *sql.Stmt
|
deleteJoinedHostsStmt *sql.Stmt
|
||||||
selectJoinedHostsStmt *sql.Stmt
|
selectJoinedHostsStmt *sql.Stmt
|
||||||
|
|
@ -69,7 +70,8 @@ type joinedHostsStatements struct {
|
||||||
|
|
||||||
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
||||||
s = &joinedHostsStatements{
|
s = &joinedHostsStatements{
|
||||||
db: db,
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
}
|
}
|
||||||
_, err = db.Exec(joinedHostsSchema)
|
_, err = db.Exec(joinedHostsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -96,21 +98,25 @@ func (s *joinedHostsStatements) InsertJoinedHosts(
|
||||||
roomID, eventID string,
|
roomID, eventID string,
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
|
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
|
||||||
return err
|
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *joinedHostsStatements) DeleteJoinedHosts(
|
func (s *joinedHostsStatements) DeleteJoinedHosts(
|
||||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||||
) error {
|
) error {
|
||||||
for _, eventID := range eventIDs {
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
|
for _, eventID := range eventIDs {
|
||||||
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
|
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
|
||||||
return err
|
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
return nil
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
|
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,7 @@ const selectQueueServerNamesSQL = "" +
|
||||||
|
|
||||||
type queueEDUsStatements struct {
|
type queueEDUsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertQueueEDUStmt *sql.Stmt
|
insertQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUStmt *sql.Stmt
|
selectQueueEDUStmt *sql.Stmt
|
||||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||||
|
|
@ -68,7 +69,8 @@ type queueEDUsStatements struct {
|
||||||
|
|
||||||
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
|
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
|
||||||
s = &queueEDUsStatements{
|
s = &queueEDUsStatements{
|
||||||
db: db,
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
}
|
}
|
||||||
_, err = db.Exec(queueEDUsSchema)
|
_, err = db.Exec(queueEDUsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -99,15 +101,17 @@ func (s *queueEDUsStatements) InsertQueueEDU(
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
nid int64,
|
nid int64,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err := stmt.ExecContext(
|
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
|
||||||
ctx,
|
_, err := stmt.ExecContext(
|
||||||
userID, // destination user ID
|
ctx,
|
||||||
deviceID, // destination device ID
|
userID, // destination user ID
|
||||||
serverName, // destination server name
|
deviceID, // destination device ID
|
||||||
nid, // JSON blob NID
|
serverName, // destination server name
|
||||||
)
|
nid, // JSON blob NID
|
||||||
return err
|
)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queueEDUsStatements) SelectQueueEDU(
|
func (s *queueEDUsStatements) SelectQueueEDU(
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ const selectJSONSQL = "" +
|
||||||
|
|
||||||
type queueJSONStatements struct {
|
type queueJSONStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertJSONStmt *sql.Stmt
|
insertJSONStmt *sql.Stmt
|
||||||
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
|
|
@ -57,7 +58,8 @@ type queueJSONStatements struct {
|
||||||
|
|
||||||
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
|
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
|
||||||
s = &queueJSONStatements{
|
s = &queueJSONStatements{
|
||||||
db: db,
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
}
|
}
|
||||||
_, err = db.Exec(queueJSONSchema)
|
_, err = db.Exec(queueJSONSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -71,17 +73,20 @@ func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
|
||||||
|
|
||||||
func (s *queueJSONStatements) InsertQueueJSON(
|
func (s *queueJSONStatements) InsertQueueJSON(
|
||||||
ctx context.Context, txn *sql.Tx, json string,
|
ctx context.Context, txn *sql.Tx, json string,
|
||||||
) (int64, error) {
|
) (lastid int64, err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
|
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
res, err := stmt.ExecContext(ctx, json)
|
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
|
||||||
if err != nil {
|
res, err := stmt.ExecContext(ctx, json)
|
||||||
return 0, fmt.Errorf("stmt.QueryContext: %w", err)
|
if err != nil {
|
||||||
}
|
return fmt.Errorf("stmt.QueryContext: %w", err)
|
||||||
lastid, err := res.LastInsertId()
|
}
|
||||||
if err != nil {
|
lastid, err = res.LastInsertId()
|
||||||
return 0, fmt.Errorf("res.LastInsertId: %w", err)
|
if err != nil {
|
||||||
}
|
return fmt.Errorf("res.LastInsertId: %w", err)
|
||||||
return lastid, nil
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queueJSONStatements) DeleteQueueJSON(
|
func (s *queueJSONStatements) DeleteQueueJSON(
|
||||||
|
|
@ -98,9 +103,11 @@ func (s *queueJSONStatements) DeleteQueueJSON(
|
||||||
iNIDs[k] = v
|
iNIDs[k] = v
|
||||||
}
|
}
|
||||||
|
|
||||||
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err = stmt.ExecContext(ctx, iNIDs...)
|
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
||||||
return err
|
_, err = stmt.ExecContext(ctx, iNIDs...)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queueJSONStatements) SelectQueueJSON(
|
func (s *queueJSONStatements) SelectQueueJSON(
|
||||||
|
|
|
||||||
|
|
@ -69,6 +69,7 @@ const selectQueuePDUsServerNamesSQL = "" +
|
||||||
|
|
||||||
type queuePDUsStatements struct {
|
type queuePDUsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertQueuePDUStmt *sql.Stmt
|
insertQueuePDUStmt *sql.Stmt
|
||||||
deleteQueueTransactionPDUsStmt *sql.Stmt
|
deleteQueueTransactionPDUsStmt *sql.Stmt
|
||||||
selectQueueNextTransactionIDStmt *sql.Stmt
|
selectQueueNextTransactionIDStmt *sql.Stmt
|
||||||
|
|
@ -80,7 +81,8 @@ type queuePDUsStatements struct {
|
||||||
|
|
||||||
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
||||||
s = &queuePDUsStatements{
|
s = &queuePDUsStatements{
|
||||||
db: db,
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
}
|
}
|
||||||
_, err = db.Exec(queuePDUsSchema)
|
_, err = db.Exec(queuePDUsSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -117,14 +119,16 @@ func (s *queuePDUsStatements) InsertQueuePDU(
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
nid int64,
|
nid int64,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err := stmt.ExecContext(
|
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
|
||||||
ctx,
|
_, err := stmt.ExecContext(
|
||||||
transactionID, // the transaction ID that we initially attempted
|
ctx,
|
||||||
serverName, // destination server name
|
transactionID, // the transaction ID that we initially attempted
|
||||||
nid, // JSON blob NID
|
serverName, // destination server name
|
||||||
)
|
nid, // JSON blob NID
|
||||||
return err
|
)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) DeleteQueuePDUTransaction(
|
func (s *queuePDUsStatements) DeleteQueuePDUTransaction(
|
||||||
|
|
@ -132,9 +136,11 @@ func (s *queuePDUsStatements) DeleteQueuePDUTransaction(
|
||||||
serverName gomatrixserverlib.ServerName,
|
serverName gomatrixserverlib.ServerName,
|
||||||
transactionID gomatrixserverlib.TransactionID,
|
transactionID gomatrixserverlib.TransactionID,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err := stmt.ExecContext(ctx, serverName, transactionID)
|
stmt := sqlutil.TxStmt(txn, s.deleteQueueTransactionPDUsStmt)
|
||||||
return err
|
_, err := stmt.ExecContext(ctx, serverName, transactionID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
|
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ const updateRoomSQL = "" +
|
||||||
|
|
||||||
type roomStatements struct {
|
type roomStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
writer *sqlutil.TransactionWriter
|
||||||
insertRoomStmt *sql.Stmt
|
insertRoomStmt *sql.Stmt
|
||||||
selectRoomForUpdateStmt *sql.Stmt
|
selectRoomForUpdateStmt *sql.Stmt
|
||||||
updateRoomStmt *sql.Stmt
|
updateRoomStmt *sql.Stmt
|
||||||
|
|
@ -51,7 +52,8 @@ type roomStatements struct {
|
||||||
|
|
||||||
func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
|
func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
|
||||||
s = &roomStatements{
|
s = &roomStatements{
|
||||||
db: db,
|
db: db,
|
||||||
|
writer: sqlutil.NewTransactionWriter(),
|
||||||
}
|
}
|
||||||
_, err = db.Exec(roomSchema)
|
_, err = db.Exec(roomSchema)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
@ -75,8 +77,10 @@ func NewSQLiteRoomsTable(db *sql.DB) (s *roomStatements, err error) {
|
||||||
func (s *roomStatements) InsertRoom(
|
func (s *roomStatements) InsertRoom(
|
||||||
ctx context.Context, txn *sql.Tx, roomID string,
|
ctx context.Context, txn *sql.Tx, roomID string,
|
||||||
) error {
|
) error {
|
||||||
_, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
return err
|
_, err := sqlutil.TxStmt(txn, s.insertRoomStmt).ExecContext(ctx, roomID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
// selectRoomForUpdate locks the row for the room and returns the last_event_id.
|
||||||
|
|
@ -99,7 +103,9 @@ func (s *roomStatements) SelectRoomForUpdate(
|
||||||
func (s *roomStatements) UpdateRoom(
|
func (s *roomStatements) UpdateRoom(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
|
ctx context.Context, txn *sql.Tx, roomID, lastEventID string,
|
||||||
) error {
|
) error {
|
||||||
stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
|
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||||
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
|
stmt := sqlutil.TxStmt(txn, s.updateRoomStmt)
|
||||||
return err
|
_, err := stmt.ExecContext(ctx, roomID, lastEventID)
|
||||||
|
return err
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -28,10 +28,7 @@ import (
|
||||||
type Database struct {
|
type Database struct {
|
||||||
shared.Database
|
shared.Database
|
||||||
sqlutil.PartitionOffsetStatements
|
sqlutil.PartitionOffsetStatements
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
queuePDUsWriter *sqlutil.TransactionWriter
|
|
||||||
queueEDUsWriter *sqlutil.TransactionWriter
|
|
||||||
queueJSONWriter *sqlutil.TransactionWriter
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDatabase opens a new database
|
// NewDatabase opens a new database
|
||||||
|
|
|
||||||
|
|
@ -131,14 +131,17 @@ func NewTransactionWriter() *TransactionWriter {
|
||||||
// transactionWriterTask represents a specific task.
|
// transactionWriterTask represents a specific task.
|
||||||
type transactionWriterTask struct {
|
type transactionWriterTask struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
txn *sql.Tx
|
||||||
f func(txn *sql.Tx) error
|
f func(txn *sql.Tx) error
|
||||||
wait chan error
|
wait chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do queues a task to be run by a TransactionWriter. The function
|
// Do queues a task to be run by a TransactionWriter. The function
|
||||||
// provided will be ran within a transaction as supplied by the
|
// provided will be ran within a transaction as supplied by the
|
||||||
// database parameter. This will block until the task is finished.
|
// txn parameter if one is supplied, and if not, will take out a
|
||||||
func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error {
|
// new transaction from the database supplied in the database
|
||||||
|
// parameter. Either way, this will block until the task is done.
|
||||||
|
func (w *TransactionWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
|
||||||
if w.todo == nil {
|
if w.todo == nil {
|
||||||
return errors.New("not initialised")
|
return errors.New("not initialised")
|
||||||
}
|
}
|
||||||
|
|
@ -147,6 +150,7 @@ func (w *TransactionWriter) Do(db *sql.DB, f func(txn *sql.Tx) error) error {
|
||||||
}
|
}
|
||||||
task := transactionWriterTask{
|
task := transactionWriterTask{
|
||||||
db: db,
|
db: db,
|
||||||
|
txn: txn,
|
||||||
f: f,
|
f: f,
|
||||||
wait: make(chan error, 1),
|
wait: make(chan error, 1),
|
||||||
}
|
}
|
||||||
|
|
@ -164,9 +168,15 @@ func (w *TransactionWriter) run() {
|
||||||
}
|
}
|
||||||
defer w.running.Store(false)
|
defer w.running.Store(false)
|
||||||
for task := range w.todo {
|
for task := range w.todo {
|
||||||
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
|
if task.txn != nil {
|
||||||
return task.f(txn)
|
task.wait <- task.f(task.txn)
|
||||||
})
|
} else if task.db != nil {
|
||||||
|
task.wait <- WithTransaction(task.db, func(txn *sql.Tx) error {
|
||||||
|
return task.f(txn)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
panic("expected database or transaction but got neither")
|
||||||
|
}
|
||||||
close(task.wait)
|
close(task.wait)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1114,7 +1114,7 @@ func (d *Database) StoreNewSendForDeviceMessage(
|
||||||
}
|
}
|
||||||
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
|
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
|
||||||
// that we don't lock the table for writes in more than one place.
|
// that we don't lock the table for writes in more than one place.
|
||||||
err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error {
|
err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
return d.AddSendToDeviceEvent(
|
return d.AddSendToDeviceEvent(
|
||||||
ctx, txn, userID, deviceID, string(j),
|
ctx, txn, userID, deviceID, string(j),
|
||||||
)
|
)
|
||||||
|
|
@ -1179,7 +1179,7 @@ func (d *Database) CleanSendToDeviceUpdates(
|
||||||
// If we need to write to the database then we'll ask the SendToDeviceWriter to
|
// If we need to write to the database then we'll ask the SendToDeviceWriter to
|
||||||
// do that for us. It'll guarantee that we don't lock the table for writes in
|
// do that for us. It'll guarantee that we don't lock the table for writes in
|
||||||
// more than one place.
|
// more than one place.
|
||||||
err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error {
|
err = d.SendToDeviceWriter.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
// Delete any send-to-device messages marked for deletion.
|
// Delete any send-to-device messages marked for deletion.
|
||||||
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
|
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
|
||||||
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
|
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue