diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index bdbf5e7cb..e66efb097 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -276,9 +276,10 @@ func (s *eventStatements) BulkSelectStateAtEventByID( } func (s *eventStatements) UpdateEventState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - _, err := s.updateEventStateStmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) + stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt) + _, err := stmt.ExecContext(ctx, int64(eventNID), int64(stateNID)) return err } diff --git a/roomserver/storage/postgres/published_table.go b/roomserver/storage/postgres/published_table.go index 23a9b067e..440ae7842 100644 --- a/roomserver/storage/postgres/published_table.go +++ b/roomserver/storage/postgres/published_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -63,9 +64,10 @@ func NewPostgresPublishedTable(db *sql.DB) (tables.Published, error) { } func (s *publishedStatements) UpsertRoomPublished( - ctx context.Context, roomID string, published bool, + ctx context.Context, txn *sql.Tx, roomID string, published bool, ) (err error) { - _, err = s.upsertPublishedStmt.ExecContext(ctx, roomID, published) + stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) + _, err = stmt.ExecContext(ctx, roomID, published) return } diff --git a/roomserver/storage/postgres/room_aliases_table.go b/roomserver/storage/postgres/room_aliases_table.go index 85042c54f..b603a673c 100644 --- a/roomserver/storage/postgres/room_aliases_table.go +++ b/roomserver/storage/postgres/room_aliases_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -77,9 +78,10 @@ func NewPostgresRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { } func (s *roomAliasesStatements) InsertRoomAlias( - ctx context.Context, alias string, roomID string, creatorUserID string, + ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, ) (err error) { - _, err = s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) + stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt) + _, err = stmt.ExecContext(ctx, alias, roomID, creatorUserID) return } @@ -125,8 +127,9 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias( } func (s *roomAliasesStatements) DeleteRoomAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) (err error) { - _, err = s.deleteRoomAliasStmt.ExecContext(ctx, alias) + stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt) + _, err = stmt.ExecContext(ctx, alias) return } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 9e36cba6c..aee3873a5 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -114,8 +114,8 @@ func (d *Database) EventNIDs( func (d *Database) SetState( ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { - return d.EventsTable.UpdateEventState(ctx, eventNID, stateNID) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID) }) } @@ -224,8 +224,8 @@ func (d *Database) GetRoomVersionForRoomNID( } func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { - return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { - return d.RoomAliasesTable.InsertRoomAlias(ctx, alias, roomID, creatorUserID) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.RoomAliasesTable.InsertRoomAlias(ctx, txn, alias, roomID, creatorUserID) }) } @@ -244,8 +244,8 @@ func (d *Database) GetCreatorIDForAlias( } func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error { - return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { - return d.RoomAliasesTable.DeleteRoomAlias(ctx, alias) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.RoomAliasesTable.DeleteRoomAlias(ctx, txn, alias) }) } @@ -471,8 +471,8 @@ func (d *Database) StoreEvent( } func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool) error { - return d.Writer.Do(d.DB, nil, func(_ *sql.Tx) error { - return d.PublishedTable.UpsertRoomPublished(ctx, roomID, publish) + return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + return d.PublishedTable.UpsertRoomPublished(ctx, txn, roomID, publish) }) } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 26ea1d415..a866c85d0 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -279,9 +279,10 @@ func (s *eventStatements) BulkSelectStateAtEventByID( } func (s *eventStatements) UpdateEventState( - ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID, + ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID, ) error { - _, err := s.updateEventStateStmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) + stmt := sqlutil.TxStmt(txn, s.updateEventStateStmt) + _, err := stmt.ExecContext(ctx, int64(stateNID), int64(eventNID)) return err } diff --git a/roomserver/storage/sqlite3/published_table.go b/roomserver/storage/sqlite3/published_table.go index 1d6ccd561..dcf6f697a 100644 --- a/roomserver/storage/sqlite3/published_table.go +++ b/roomserver/storage/sqlite3/published_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -65,9 +66,10 @@ func NewSqlitePublishedTable(db *sql.DB) (tables.Published, error) { } func (s *publishedStatements) UpsertRoomPublished( - ctx context.Context, roomID string, published bool, + ctx context.Context, txn *sql.Tx, roomID string, published bool, ) error { - _, err := s.upsertPublishedStmt.ExecContext(ctx, roomID, published) + stmt := sqlutil.TxStmt(txn, s.upsertPublishedStmt) + _, err := stmt.ExecContext(ctx, roomID, published) return err } diff --git a/roomserver/storage/sqlite3/room_aliases_table.go b/roomserver/storage/sqlite3/room_aliases_table.go index a16e97aa5..f053e3981 100644 --- a/roomserver/storage/sqlite3/room_aliases_table.go +++ b/roomserver/storage/sqlite3/room_aliases_table.go @@ -20,6 +20,7 @@ import ( "database/sql" "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" ) @@ -81,9 +82,10 @@ func NewSqliteRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) { } func (s *roomAliasesStatements) InsertRoomAlias( - ctx context.Context, alias string, roomID string, creatorUserID string, + ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string, ) error { - _, err := s.insertRoomAliasStmt.ExecContext(ctx, alias, roomID, creatorUserID) + stmt := sqlutil.TxStmt(txn, s.insertRoomAliasStmt) + _, err := stmt.ExecContext(ctx, alias, roomID, creatorUserID) return err } @@ -131,8 +133,9 @@ func (s *roomAliasesStatements) SelectCreatorIDFromAlias( } func (s *roomAliasesStatements) DeleteRoomAlias( - ctx context.Context, alias string, + ctx context.Context, txn *sql.Tx, alias string, ) error { - _, err := s.deleteRoomAliasStmt.ExecContext(ctx, alias) + stmt := sqlutil.TxStmt(txn, s.deleteRoomAliasStmt) + _, err := stmt.ExecContext(ctx, alias) return err } diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 78273b3cc..47c12c2ca 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -42,7 +42,7 @@ type Events interface { // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error) - UpdateEventState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error + UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error) UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error SelectEventID(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (eventID string, err error) @@ -84,11 +84,11 @@ type StateBlock interface { } type RoomAliases interface { - InsertRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) (err error) + InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error) SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error) SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error) SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error) - DeleteRoomAlias(ctx context.Context, alias string) (err error) + DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error) } type PreviousEvents interface { @@ -123,7 +123,7 @@ type Membership interface { } type Published interface { - UpsertRoomPublished(ctx context.Context, roomID string, published bool) (err error) + UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error) SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error) SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error) }