diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 45020d551..766d4f205 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -179,22 +179,19 @@ func (d *Database) RoomNIDExcludingStubs(ctx context.Context, roomID string) (ro func (d *Database) LatestEventIDs( ctx context.Context, roomNID types.RoomNID, ) (references []gomatrixserverlib.EventReference, currentStateSnapshotNID types.StateSnapshotNID, depth int64, err error) { - err = sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error { - var eventNIDs []types.EventNID - eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, txn, roomNID) - if err != nil { - return err - } - references, err = d.EventsTable.BulkSelectEventReference(ctx, txn, eventNIDs) - if err != nil { - return err - } - depth, err = d.EventsTable.SelectMaxEventDepth(ctx, txn, eventNIDs) - if err != nil { - return err - } - return nil - }) + var eventNIDs []types.EventNID + eventNIDs, currentStateSnapshotNID, err = d.RoomsTable.SelectLatestEventNIDs(ctx, nil, roomNID) + if err != nil { + return + } + references, err = d.EventsTable.BulkSelectEventReference(ctx, nil, eventNIDs) + if err != nil { + return + } + depth, err = d.EventsTable.SelectMaxEventDepth(ctx, nil, eventNIDs) + if err != nil { + return + } return } @@ -351,7 +348,12 @@ func (d *Database) MembershipUpdater( if err != nil { return nil, err } - return NewMembershipUpdater(ctx, d, txn, roomID, targetUserID, targetLocal, roomVersion) + var updater *MembershipUpdater + _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + updater, err = NewMembershipUpdater(ctx, d, txn, roomID, targetUserID, targetLocal, roomVersion) + return nil + }) + return updater, err } func (d *Database) GetLatestEventsForUpdate( @@ -361,7 +363,12 @@ func (d *Database) GetLatestEventsForUpdate( if err != nil { return nil, err } - return NewLatestEventsUpdater(ctx, d, txn, roomNID) + var updater *LatestEventsUpdater + _ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { + updater, err = NewLatestEventsUpdater(ctx, d, txn, roomNID) + return nil + }) + return updater, err } // nolint:gocyclo @@ -383,7 +390,7 @@ func (d *Database) StoreEvent( err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if txnAndSessionID != nil { if err = d.TransactionsTable.InsertTransaction( - ctx, nil, txnAndSessionID.TransactionID, + ctx, txn, txnAndSessionID.TransactionID, txnAndSessionID.SessionID, event.Sender(), event.EventID(), ); err != nil { return fmt.Errorf("d.TransactionsTable.InsertTransaction: %w", err)