diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index 42c0c8f2d..7bf0d521f 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -105,14 +105,12 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { // StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { - return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { - for _, ref := range previousEventReferences { - if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { - return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) - } + for _, ref := range previousEventReferences { + if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil { + return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err) } - return nil - }) + } + return nil } func (u *RoomUpdater) Events( diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index ff71fc9f1..c6884c5a8 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -680,13 +680,18 @@ func (d *Database) StoreEvent( if roomInfo == nil && len(prevEvents) > 0 { return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID()) } - updater, err = d.GetRoomUpdater(ctx, roomInfo) - if err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err) - } - defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) - if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { - return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err) + if err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + updater, err = d.GetRoomUpdater(ctx, roomInfo) + if err != nil { + return fmt.Errorf("GetRoomUpdater: %w", err) + } + defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err) + if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil { + return fmt.Errorf("updater.StorePreviousEvents: %w", err) + } + return nil + }); err != nil { + return 0, 0, types.StateAtEvent{}, nil, "", err } succeeded = true }