diff --git a/internal/sqlutil/sql.go b/internal/sqlutil/sql.go index e026f452c..8d0d2dfa5 100644 --- a/internal/sqlutil/sql.go +++ b/internal/sqlutil/sql.go @@ -40,11 +40,6 @@ type Transaction interface { // You MUST check the error returned from this function to be sure that the transaction // was applied correctly. For example, 'database is locked' errors in sqlite will happen here. func EndTransaction(txn Transaction, succeeded *bool) error { - if txn == nil { - // Sometimes in SQLite mode we have nil transactions. If that's the case - // then we are working outside of a transaction and should do nothing here. - return nil - } if *succeeded { return txn.Commit() } else { diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 883df0d0f..86c02d07e 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -61,6 +61,22 @@ func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *typ }, nil } +// Implements sqlutil.Transaction +func (u *RoomUpdater) Commit() error { + if u.txn == nil { // SQLite mode probably + return nil + } + return u.txn.Commit() +} + +// Implements sqlutil.Transaction +func (u *RoomUpdater) Rollback() error { + if u.txn == nil { // SQLite mode probably + return nil + } + return u.txn.Commit() +} + // RoomVersion implements types.RoomRecentEventsUpdater func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) { return u.roomInfo.RoomVersion diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index ad2d39964..2df88534d 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -649,7 +649,7 @@ func (d *Database) storeEvent( if err != nil { return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) } - defer sqlutil.EndTransactionWithCheck(updater.txn, &succeeded, &err) + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) } if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)