diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index c3869f3de..59cabf5da 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -32,12 +32,8 @@ func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *typ // we will just run with a normal database transaction. It'll either // succeed, processing a create event which creates the room, or it won't. if roomInfo == nil { - tx, err := d.DB.Begin() - if err != nil { - return nil, fmt.Errorf("d.DB.Begin: %w", err) - } return &RoomUpdater{ - transaction{ctx, tx}, d, nil, nil, "", 0, + transaction{ctx, txn}, d, nil, nil, "", 0, }, nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index b82f5984d..de78dd9f4 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -588,25 +588,27 @@ func (d *Database) storeEvent( // events updater because it somewhat works as a mutex, ensuring // that there's a row-level lock on the latest room events (well, // on Postgres at least). - var roomInfo *types.RoomInfo if prevEvents := event.PrevEvents(); len(prevEvents) > 0 { - roomInfo, err = d.RoomInfo(ctx, event.RoomID()) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) - } - if roomInfo == nil && len(prevEvents) > 0 { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) - } // Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of // GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This // function only does SELECTs though so the created txn (at this point) is just a read txn like // any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater // to do writes however then this will need to go inside `Writer.Do`. + succeeded := false if updater == nil { + var roomInfo *types.RoomInfo + roomInfo, err = d.RoomInfo(ctx, event.RoomID()) + if err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err) + } + if roomInfo == nil && len(prevEvents) > 0 { + return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) + } updater, err = d.GetRoomUpdater(ctx, roomInfo) if err != nil { return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) } + defer sqlutil.EndTransactionWithCheck(updater.txn, &succeeded, &err) } // Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents // and EndTransaction in a writer then it's possible for a new write txn to be made between the two @@ -618,9 +620,8 @@ func (d *Database) storeEvent( if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { return fmt.Errorf("updater.StorePreviousEvents: %w", err) } - succeeded := true - err = sqlutil.EndTransaction(updater, &succeeded) - return err + succeeded = true + return nil }) if err != nil { return 0, 0, types.StateAtEvent{}, nil, "", err