diff --git a/roomserver/storage/shared/membership_updater.go b/roomserver/storage/shared/membership_updater.go index 5ddf6d84d..9184a1f40 100644 --- a/roomserver/storage/shared/membership_updater.go +++ b/roomserver/storage/shared/membership_updater.go @@ -18,21 +18,9 @@ type membershipUpdater struct { } func NewMembershipUpdater( - ctx context.Context, d *Database, roomID, targetUserID string, + ctx context.Context, d *Database, txn *sql.Tx, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, - useTxns bool, ) (types.MembershipUpdater, error) { - txn, err := d.DB.Begin() - if err != nil { - return nil, err - } - succeeded := false - defer func() { - if !succeeded { - txn.Rollback() // nolint: errcheck - } - }() - roomNID, err := d.assignRoomNID(ctx, txn, roomID, roomVersion) if err != nil { return nil, err @@ -43,17 +31,7 @@ func NewMembershipUpdater( return nil, err } - updater, err := d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) - if err != nil { - return nil, err - } - - succeeded = true - if !useTxns { - txn.Commit() // nolint: errcheck - updater.transaction.txn = nil - } - return updater, nil + return d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal) } func (d *Database) membershipUpdaterTxn( diff --git a/roomserver/storage/shared/room_recent_events_updater.go b/roomserver/storage/shared/room_recent_events_updater.go index 8131f712d..b4c2153d6 100644 --- a/roomserver/storage/shared/room_recent_events_updater.go +++ b/roomserver/storage/shared/room_recent_events_updater.go @@ -17,11 +17,7 @@ type roomRecentEventsUpdater struct { currentStateSnapshotNID types.StateSnapshotNID } -func NewRoomRecentEventsUpdater(d *Database, ctx context.Context, roomNID types.RoomNID, useTxns bool) (types.RoomRecentEventsUpdater, error) { - txn, err := d.DB.Begin() - if err != nil { - return nil, err - } +func NewRoomRecentEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomNID types.RoomNID) (types.RoomRecentEventsUpdater, error) { eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err := d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomNID) if err != nil { @@ -41,10 +37,6 @@ func NewRoomRecentEventsUpdater(d *Database, ctx context.Context, roomNID types. return nil, err } } - if !useTxns { - txn.Commit() // nolint: errcheck - txn = nil - } return &roomRecentEventsUpdater{ transaction{ctx, txn}, d, roomNID, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, }, nil diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 5494e4654..cd1ef3759 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -333,13 +333,21 @@ func (d *Database) MembershipUpdater( ctx context.Context, roomID, targetUserID string, targetLocal bool, roomVersion gomatrixserverlib.RoomVersion, ) (types.MembershipUpdater, error) { - return NewMembershipUpdater(ctx, d, roomID, targetUserID, targetLocal, roomVersion, true) + txn, err := d.DB.Begin() + if err != nil { + return nil, err + } + return NewMembershipUpdater(ctx, d, txn, roomID, targetUserID, targetLocal, roomVersion) } func (d *Database) GetLatestEventsForUpdate( ctx context.Context, roomNID types.RoomNID, ) (types.RoomRecentEventsUpdater, error) { - return NewRoomRecentEventsUpdater(d, ctx, roomNID, true) + txn, err := d.DB.Begin() + if err != nil { + return nil, err + } + return NewRoomRecentEventsUpdater(ctx, d, txn, roomNID) } func (d *Database) StoreEvent( diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index ae3140d7d..375a3f448 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -146,7 +146,7 @@ func (d *Database) GetLatestEventsForUpdate( // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewRoomRecentEventsUpdater(&d.Database, ctx, roomNID, false) + return shared.NewRoomRecentEventsUpdater(ctx, &d.Database, nil, roomNID) } func (d *Database) MembershipUpdater( @@ -159,5 +159,5 @@ func (d *Database) MembershipUpdater( // 'database is locked' errors. As sqlite doesn't support multi-process on the // same DB anyway, and we only execute updates sequentially, the only worries // are for rolling back when things go wrong. (atomicity) - return shared.NewMembershipUpdater(ctx, &d.Database, roomID, targetUserID, targetLocal, roomVersion, false) + return shared.NewMembershipUpdater(ctx, &d.Database, nil, roomID, targetUserID, targetLocal, roomVersion) }