From 64f26be89e9260028ddb05001ec408709b76dfa6 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 22 Jun 2026 17:41:53 +0000 Subject: [PATCH 01/10] Datapipes producer/consumer refactor + stream overlap Refactor the datapipe prefetch path into a thread-safe host producer (_load_host) plus a main-thread consumer (_consume) with a FIFO submit/consume primitive (io_pump.IOPump), so all device/Warp kernels launch on the consuming thread. Build deferred-sync stream overlap on top: _consume records the preprocessing CUDA event into _events_pending and the DataLoader does one-batch lookahead, inserting compute_stream.wait_event just before each yield so batch N+1 preprocessing overlaps batch N compute. - New io_pump.py (FIFO pump); producer/consumer protocols; _rng fork_generator; core.function_spec.warp_stream_from_torch; refactored readers (base/numpy/zarr/tensorstore_zarr) and datapipes __init__. - MeshDataset parallel disk read + pin (serialize_load_consume=False); DomainMeshReader drop_interior_cells / drop_in_file_boundaries; volume configs enable both. - radius_search pinned non_blocking H2D; recipe train loop pinned async loss D2H. Opt-in timing + torch.profiler labels; streaming + reader tests, docs, and the iterable-dataset tutorial. No Warp keepalive machinery: with the Warp-free SDF (parent branch) the datapipe no longer launches Warp kernels in _consume. --- docs/api/datapipes/physicsnemo.datapipes.rst | 60 ++ .../datasets/drivaer_ml_volume.yaml | 10 + .../datasets/highlift_volume.yaml | 6 + .../unified_external_aero_recipe/src/train.py | 103 ++- examples/minimal/datapipes/README.md | 38 ++ .../tutorial_5_iterable_online_simulation.py | 194 ++++++ physicsnemo/core/function_spec.py | 46 +- physicsnemo/datapipes/RNG.md | 79 ++- physicsnemo/datapipes/__init__.py | 3 +- physicsnemo/datapipes/_rng.py | 59 ++ physicsnemo/datapipes/dataloader.py | 394 +++++++++-- physicsnemo/datapipes/datapipes.md | 203 +++--- physicsnemo/datapipes/dataset.py | 187 +++--- physicsnemo/datapipes/io_pump.py | 220 +++++++ physicsnemo/datapipes/mesh_dataset.py | 224 ++++--- physicsnemo/datapipes/multi_dataset.py | 70 ++ physicsnemo/datapipes/protocols.py | 527 +++++++++++++-- physicsnemo/datapipes/readers/base.py | 54 +- physicsnemo/datapipes/readers/mesh.py | 149 ++++- physicsnemo/datapipes/readers/numpy.py | 32 +- .../datapipes/readers/tensorstore_zarr.py | 22 +- physicsnemo/datapipes/readers/zarr.py | 22 +- .../datapipes/transforms/mesh/transforms.py | 24 +- .../neighbors/radius_search/_warp_impl.py | 63 +- test/datapipes/core/test_dataset.py | 57 ++ test/datapipes/core/test_streaming.py | 617 ++++++++++++++++++ .../readers/test_numpy_consolidated.py | 104 +++ .../transforms/test_mesh_augmentations.py | 9 +- 28 files changed, 3030 insertions(+), 546 deletions(-) create mode 100644 examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py create mode 100644 physicsnemo/datapipes/io_pump.py create mode 100644 test/datapipes/core/test_streaming.py diff --git a/docs/api/datapipes/physicsnemo.datapipes.rst b/docs/api/datapipes/physicsnemo.datapipes.rst index 82b438f25b..047360e65b 100644 --- a/docs/api/datapipes/physicsnemo.datapipes.rst +++ b/docs/api/datapipes/physicsnemo.datapipes.rst @@ -139,6 +139,29 @@ of ``pin_memory`` from the ``DataLoader`` class to the ``Reader`` classes. This is because of the much earlier GPU data transfer in the PhysicsNeMo datapipe compared to PyTorch. +The ``DataLoader`` drives one of two mutually-exclusive paths, selected by +dataset type: + +- **Map-style preload path** (:class:`~physicsnemo.datapipes.DatasetBase`, + e.g. ``Dataset``, ``MeshDataset``): a dedicated dispatcher thread keeps a + *bounded* number of samples in flight by pulling the index stream + **lazily** under backpressure and submitting host-only loads to a worker + pool. The main thread consumes the resulting samples in order + (host-to-device transfer plus transforms on a preprocessing stream) and + reassembles batches from boundary markers, so the full epoch is never + materialized up front and irregular batch sizes are supported. +- **Iterable generator path** (:class:`~physicsnemo.datapipes.IterableDatasetBase`): + a generator dataset driven entirely on the main thread (no sampler, no + worker pool); see `Iterable Datasets`_ below. + +In both paths, **all device-kernel launches happen on the single main +thread**. This is the real constraint for Warp-based transforms: Warp may +launch on any CUDA stream as long as the launch comes from the main thread +and Warp's current stream is bound to the torch stream in use. +Preprocessing therefore runs on a side (preprocessing) stream and is +ordered against the compute stream with a CUDA event, so it overlaps +training without ever blocking the host. + .. autoclass:: physicsnemo.datapipes.dataloader.DataLoader :members: :show-inheritance: @@ -179,6 +202,43 @@ consistent keys. Because the exact collation details differ by dataset, the :show-inheritance: +Iterable Datasets +^^^^^^^^^^^^^^^^^ + +Map-style datasets (``Dataset``, ``MeshDataset``) assume a fixed length and a +sampler that hands out indices. Some workloads have neither: an online +simulation, a procedural generator, or any source that produces samples on +the fly with no meaningful ``__len__``. For those, subclass +:class:`~physicsnemo.datapipes.IterableDatasetBase` and yield samples from +``__iter__``. The ``DataLoader`` detects an iterable dataset automatically +and switches to the main-thread-only generator path; ``shuffle`` and +``sampler`` are ignored (a warning is issued) and ``len(loader)`` raises, +since the length is unknown. + +An iterable dataset chooses one of two emission modes via the +``yields_batches`` attribute: + +- ``yields_batches = False`` (default): ``__iter__`` yields individual + samples and the ``DataLoader`` collates them into batches of + ``batch_size`` (honoring ``drop_last``). +- ``yields_batches = True``: ``__iter__`` yields fully-formed batches and the + ``DataLoader`` passes them through without further collation, which is the + natural fit for a generator that already produces a batch per step. + +Reproducibility follows a per-``(epoch, position)`` scheme rather than the +map-style per-``(epoch, index)`` scheme: implement ``set_epoch`` and/or +``set_generator`` to seed deterministically from the iteration position. +Because the generator runs on the main thread, Warp kernels inside it are +safe on any preprocessing stream the ``DataLoader`` binds. See the online +simulation tutorial in the +`examples directory `_ +for a runnable Warp ``Darcy2D`` generator wired through this path. + +.. autoclass:: physicsnemo.datapipes.IterableDatasetBase + :members: + :show-inheritance: + + Readers ^^^^^^^ diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/drivaer_ml_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/drivaer_ml_volume.yaml index d6a27d5335..244899cea9 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/drivaer_ml_volume.yaml +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/drivaer_ml_volume.yaml @@ -32,6 +32,16 @@ pipeline: pattern: "run_*/domain_*.pdmsh" subsample_n_points: ${sampling_resolution} subsample_n_cells: ${sampling_resolution} + # The volume model consumes the interior as a point cloud (interior.points + + # interior.point_data); dropping interior tet topology makes the point + # subsample a cheap contiguous block read instead of a full slice_points + # remap (the dominant per-sample IO cost otherwise). + drop_interior_cells: true + # The in-file `vehicle` surface boundary is unused by the volume pipeline + # (SDF comes from the injected stl_geometry below; model/collate use only + # the interior). Drop it so we don't subsample (expensive, GIL-held) or pin + # it every sample. + drop_in_file_boundaries: true extra_boundaries: stl_geometry: pattern: "*_single_solid.stl.pmsh" diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/highlift_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/highlift_volume.yaml index b0c6f8f602..c4026ec01f 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/highlift_volume.yaml +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/highlift_volume.yaml @@ -34,6 +34,12 @@ pipeline: pattern: "geo_LHC*/*.pdmsh" subsample_n_points: ${sampling_resolution} subsample_n_cells: ${sampling_resolution} + # Interior consumed as a point cloud; drop tet topology so the point + # subsample is a cheap contiguous block read (see drivaer_ml_volume.yaml). + drop_interior_cells: true + # In-file boundaries unused by the volume pipeline (SDF uses stl_geometry); + # drop them to skip per-sample subsample + pin. + drop_in_file_boundaries: true extra_boundaries: stl_geometry: pattern: "*.stl.pmsh" diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py index 7a48ea8957..00888ac4a7 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py @@ -271,18 +271,23 @@ def forward_pass( targets: TensorDict = batch["targets"] ### Inputs keep their native dtype; autocast handles model-internal precision. - with get_autocast_context(precision): - output = model(**forward_kwargs) + with torch.profiler.record_function("forward_pass: model forward"): + with get_autocast_context(precision): + output = model(**forward_kwargs) - pred_td = normalize_output_to_tensordict(output, target_config, output_type) + with torch.profiler.record_function("forward_pass: normalize_output_to_tensordict"): + pred_td = normalize_output_to_tensordict(output, target_config, output_type) ### Loss runs in float32 to avoid bf16 precision loss in the reduction. - pred_f32 = pred_td.float() - target_f32 = targets.float() + with torch.profiler.record_function("forward_pass: .float() cast"): + pred_f32 = pred_td.float() + target_f32 = targets.float() - loss, loss_td = loss_calculator(pred_f32, target_f32) + with torch.profiler.record_function("forward_pass: loss_calculator"): + loss, loss_td = loss_calculator(pred_f32, target_f32) with torch.no_grad(): - metric_td = metric_calculator(pred_f32, target_f32) + with torch.profiler.record_function("forward_pass: metric_calculator"): + metric_td = metric_calculator(pred_f32, target_f32) ### Detach (don't sync) the per-field TDs so the caller controls when ### a D2H copy happens; running ``.item()`` here would serialise the ### forward kernels against the host. ``TensorDict.detach()`` walks @@ -365,31 +370,55 @@ def _run_epoch( n_local = 0 num_steps = len(dataloader) epoch_t0 = time.perf_counter() + ### Single pinned scalar buffer reused every step so the loss D2H + ### transfer is async (non_blocking=True from device to pinned host + ### memory). The copy is issued right after forward_pass and read + ### just before the logger line; by then backward + optimizer.step + ### have run, giving the GPU time to complete the copy without + ### blocking the host. + _loss_pinned = ( + torch.zeros(1, pin_memory=True) if torch.cuda.is_available() else None + ) with grad_ctx: step_t0 = time.perf_counter() for i, batch in enumerate(dataloader): - batch = recursive_to_device(batch, dist_manager.device) + with torch.profiler.record_function("_run_epoch: recursive_to_device"): + batch = recursive_to_device(batch, dist_manager.device) + + with torch.profiler.record_function("_run_epoch: forward_pass"): + loss, losses, metrics = forward_pass( + batch, + model, + precision, + loss_calculator, + metric_calculator, + output_type=output_type, + target_config=target_config, + ) - loss, losses, metrics = forward_pass( - batch, - model, - precision, - loss_calculator, - metric_calculator, - output_type=output_type, - target_config=target_config, - ) + ### Kick off the async D2H copy of the scalar loss value into the + ### pinned buffer. Backward + optimizer.step run while the copy is + ### in flight, so by the time we call .item() below the transfer + ### is already done and there is no host stall. + if _loss_pinned is not None: + _loss_pinned.copy_(loss.detach(), non_blocking=True) if is_train: - optimizer.zero_grad() + with torch.profiler.record_function("_run_epoch: optimizer.zero_grad"): + optimizer.zero_grad() if precision == "float16" and scaler is not None: - scaler.scale(loss).backward() - scaler.step(optimizer) - scaler.update() + with torch.profiler.record_function("_run_epoch: scaler.backward"): + scaler.scale(loss).backward() + with torch.profiler.record_function("_run_epoch: scaler.step"): + scaler.step(optimizer) + with torch.profiler.record_function("_run_epoch: scaler.update"): + scaler.update() else: - loss.backward() - optimizer.step() + with torch.profiler.record_function("_run_epoch: loss.backward"): + loss.backward() + with torch.profiler.record_function("_run_epoch: optimizer.step"): + optimizer.step() if cfg.training.get("scheduler_update_mode", "epoch") == "step": scheduler.step() @@ -407,9 +436,14 @@ def _run_epoch( total_metrics_td.add_(metrics) n_local += 1 - ### Per-step sync for the print line; lands after backward + - ### optimizer.step so it overlaps with queued GPU work. - this_loss = loss.detach().item() + ### Read the loss scalar from the pinned buffer; the async copy + ### was issued before backward so it has had the full backward + + ### optimizer.step to complete without stalling the host. + with torch.profiler.record_function("_run_epoch: loss.item() D2H readback"): + if _loss_pinned is not None: + this_loss = _loss_pinned.item() + else: + this_loss = loss.detach().item() total_loss += this_loss step_dt = time.perf_counter() - step_t0 @@ -719,13 +753,16 @@ def benchmark_io_epoch( ) for name, t in named_tensors: v_flat = t.float() if t.is_floating_point() else t.to(torch.float32) - logger.info( - f" {name:30s} " - f"min={v_flat.min().item(): .6e} " - f"mean={v_flat.mean().item(): .6e} " - f"std={v_flat.std().item(): .6e} " - f"max={v_flat.max().item(): .6e}" - ) + with torch.profiler.record_function( + "benchmark_io: tensor stats .item() D2H" + ): + logger.info( + f" {name:30s} " + f"min={v_flat.min().item(): .6e} " + f"mean={v_flat.mean().item(): .6e} " + f"std={v_flat.std().item(): .6e} " + f"max={v_flat.max().item(): .6e}" + ) if max_steps is not None and i + 1 >= max_steps: break diff --git a/examples/minimal/datapipes/README.md b/examples/minimal/datapipes/README.md index 6934467232..949d615e44 100644 --- a/examples/minimal/datapipes/README.md +++ b/examples/minimal/datapipes/README.md @@ -215,3 +215,41 @@ python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud \ subsample.n_points=5000 ``` + +### Tutorial 5: Iterable Datasets for Online Simulation + +**File:** `tutorial_5_iterable_online_simulation.py` + +Not every dataset is map-style. When data is *generated* on the fly -- +an online physics simulation, a procedural sampler, an unbounded stream -- +there is no fixed length and no index to address. PhysicsNeMo models these +with `IterableDatasetBase`: + +- Iterable datasets only support iteration: no `__len__`, no `__getitem__`, + no sampler, and no prefetch worker pool. +- They run entirely on the **main thread**, so they may freely launch Warp + kernels and use CUDA streams. Warp's only requirement is a single + launching thread, which the main thread satisfies -- this is what makes + an online GPU simulation safe on this path. +- The `DataLoader` still drives generation on a preprocessing stream (when + `use_streams=True`) and hands each item to the compute stream via a CUDA + event, so generating the next batch can overlap training on the current + one. + +This tutorial wraps the built-in Warp `Darcy2D` flow generator (a +multigrid Jacobi solver) as an iterable dataset and iterates it through the +`DataLoader`. Because `Darcy2D` produces a full batch per step, the wrapper +is *self-batching* (`yields_batches = True`) and the loader passes each +batch through unchanged. + +**When to use the iterable path vs the map/descriptor path:** use the +map-style `Dataset` when you have a fixed corpus on disk addressable by +index (storage-backed, benefits from threaded prefetch). Use an +`IterableDatasetBase` when samples are produced by a generator/simulation, +the length is unbounded or unknown, or the producer itself must launch +device kernels. + +```bash +# Requires a CUDA device (the Darcy solver runs Warp kernels on the GPU) +python tutorial_5_iterable_online_simulation.py +``` diff --git a/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py b/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py new file mode 100644 index 0000000000..09c14d9e90 --- /dev/null +++ b/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py @@ -0,0 +1,194 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tutorial 5: Iterable datasets for online simulation. + +Most datasets are *map-style*: a fixed number of samples addressed by +index, read from storage. Some workloads instead *generate* data on the +fly -- an online physics simulation, a procedural sampler, a streaming +source with no fixed length. These are *iterable* datasets. + +PhysicsNeMo models this with :class:`IterableDatasetBase`. Unlike a +map-style dataset, an iterable dataset: + +- has no length and no indexing -- it only supports iteration; +- is driven entirely on the **main thread** (no worker pool), so it may + freely launch Warp kernels and use CUDA streams. This is exactly the + property that makes an online GPU simulation safe here: Warp's + constraint is a single launching thread, which the main thread + satisfies. + +This tutorial wraps the built-in Warp ``Darcy2D`` flow generator -- which +solves the 2D Darcy equation with a multigrid Jacobi solver and yields a +ready-made batch each step -- as an iterable dataset and drives it through +the PhysicsNeMo :class:`DataLoader`. + +Run with:: + + python tutorial_5_iterable_online_simulation.py + +Requires a CUDA device (the Darcy solver runs Warp kernels on the GPU). +""" + +from __future__ import annotations + +import time + +import numpy as np +import torch + +from physicsnemo.datapipes import DataLoader, IterableDatasetBase +from physicsnemo.datapipes.benchmarks.darcy import Darcy2D + + +class DarcyOnlineDataset(IterableDatasetBase): + """Online 2D Darcy-flow simulation as an iterable dataset. + + Wraps :class:`~physicsnemo.datapipes.benchmarks.darcy.Darcy2D`, whose + iterator runs the solver and yields a full ``{"permeability", "darcy"}`` + batch per step. The underlying generator is infinite, so this wrapper + caps it at ``num_batches`` per epoch to give the loader a finite stream. + + Because ``Darcy2D`` already produces a complete batch, this is a + *self-batching* dataset: we set :attr:`yields_batches` so the loader + passes each batch through unchanged instead of re-collating. + + Parameters + ---------- + num_batches : int + Number of batches to emit per epoch. + resolution : int, default=64 + Simulation grid resolution. + batch_size : int, default=8 + Number of simulations per batch. + device : str, default="cuda" + Device the Warp solver runs on. + base_seed : int, default=0 + Base seed for reproducible permeability sampling. + """ + + # Darcy2D emits a full batch per step; do not re-collate. + yields_batches = True + + def __init__( + self, + num_batches: int, + *, + resolution: int = 64, + batch_size: int = 8, + device: str = "cuda", + base_seed: int = 0, + ) -> None: + self._sim = Darcy2D( + resolution=resolution, + batch_size=batch_size, + device=device, + normaliser={"permeability": (1.25, 0.75), "darcy": (4.52e-2, 2.79e-2)}, + ) + self._num_batches = num_batches + self._base_seed = base_seed + self._epoch = 0 + + def set_epoch(self, epoch: int) -> None: + """Select the epoch so each epoch draws a distinct, reproducible stream.""" + self._epoch = epoch + + def __iter__(self): + # One solver iterator drives the simulation; we pull a bounded + # number of steps from the otherwise-infinite generator. + sim_iter = iter(self._sim) + for position in range(self._num_batches): + # Per-(epoch, position) seeding: the stream is reproducible + # across runs and distinct across epochs and positions. There is + # no stable sample index for a generator, so we key on the + # monotonic emission position instead. + seed = np.random.SeedSequence( + [self._base_seed, self._epoch, position] + ).generate_state(1)[0] + np.random.seed(int(seed)) + yield next(sim_iter) + + +def main() -> None: + """Run the online Darcy simulation through the PhysicsNeMo ``DataLoader``. + + Builds the iterable :class:`DarcyOnlineDataset`, drives it for a few + epochs with stream overlap enabled, and prints per-batch and per-epoch + host/CUDA timing so the generation/compute overlap is visible. No-ops + on machines without CUDA, since the Warp Darcy solver requires a GPU. + """ + if not torch.cuda.is_available(): + print("This tutorial requires a CUDA device (Warp Darcy solver). Skipping.") + return + + num_epochs = 5 + num_batches = 16 + dataset = DarcyOnlineDataset(num_batches=num_batches, resolution=64, batch_size=8) + + # use_streams=True runs each simulation step on a preprocessing stream + # and hands the result to the compute stream via a CUDA event, so + # generation of the next batch can overlap training on the current one. + loader = DataLoader(dataset, use_streams=True, seed=0) + + # Iterable datasets have no length: this will take the exception path.oOOh, + try: + len(loader) + except TypeError as exc: + print(f"len(loader) is undefined for iterable datasets: {exc}") + + for epoch in range(num_epochs): + loader.set_epoch(epoch) + print(f"\nEpoch {epoch}") + host_times = [] + cuda_events = [] + epoch_start = time.perf_counter() + prev_host = epoch_start + cuda_start = torch.cuda.Event(enable_timing=True) + cuda_start.record(torch.cuda.current_stream()) + for i, batch in enumerate(loader): + host_now = time.perf_counter() + permeability = batch["permeability"] + darcy = batch["darcy"] + cuda_end = torch.cuda.Event(enable_timing=True) + cuda_end.record(torch.cuda.current_stream()) + cuda_events.append((cuda_start, cuda_end)) + cuda_start = cuda_end + + host_times.append(host_now - prev_host) + prev_host = host_now + print( + f" batch {i}: permeability {tuple(permeability.shape)} " + f"on {permeability.device}, darcy {tuple(darcy.shape)}, " + f"host_dt={host_times[-1]:.4f}s" + ) + + torch.cuda.synchronize() + cuda_times_ms = [start.elapsed_time(end) for start, end in cuda_events] + epoch_wall = time.perf_counter() - epoch_start + mean_host = sum(host_times) / len(host_times) + mean_cuda = sum(cuda_times_ms) / len(cuda_times_ms) + print( + f" epoch summary: batches={len(host_times)}, wall={epoch_wall:.3f}s, " + f"host_mean={mean_host:.4f}s, cuda_mean={mean_cuda:.2f}ms, " + f"cuda_min={min(cuda_times_ms):.2f}ms, cuda_max={max(cuda_times_ms):.2f}ms" + ) + + # Train as usual; the batches are ordinary device tensors. + + +if __name__ == "__main__": + main() diff --git a/physicsnemo/core/function_spec.py b/physicsnemo/core/function_spec.py index a6bc1a54f8..d076406a4a 100644 --- a/physicsnemo/core/function_spec.py +++ b/physicsnemo/core/function_spec.py @@ -29,6 +29,44 @@ from physicsnemo.core.version_check import check_version_spec +# Cache of Warp stream wrappers keyed by the underlying CUDA stream handle. +# +# ``warp.stream_from_torch`` wraps a torch-owned (external) CUDA stream, and the +# resulting ``warp.Stream`` unregisters that handle from Warp on ``__del__``. +# Creating a fresh wrapper on every launch therefore churns register/unregister +# on a shared stream; unregistering while another wrapper -- or an in-flight +# kernel -- still uses the stream corrupts it (illegal memory access). Keeping +# one long-lived wrapper per handle registers each stream exactly once. +_WARP_STREAM_CACHE: Dict[int, Any] = {} + + +def warp_stream_from_torch(torch_stream: "torch.cuda.Stream") -> Any: + """Return a cached Warp stream wrapping *torch_stream*. + + Wrapping a torch stream registers it with Warp; the wrapper unregisters it + on garbage collection. Caching one wrapper per CUDA stream handle keeps the + registration stable for the process lifetime, which is required when the + same torch stream is bound by nested Warp scopes (e.g. an outer + preprocessing scope and an inner functional launch). + + Parameters + ---------- + torch_stream : torch.cuda.Stream + Torch CUDA stream to wrap. + + Returns + ------- + warp.Stream + Cached Warp stream sharing ``torch_stream``'s underlying CUDA handle. + """ + wp = importlib.import_module("warp") + handle = torch_stream.cuda_stream + cached = _WARP_STREAM_CACHE.get(handle) + if cached is None: + cached = wp.stream_from_torch(torch_stream) + _WARP_STREAM_CACHE[handle] = cached + return cached + @dataclass(frozen=True) class Implementation: @@ -123,7 +161,7 @@ class FunctionSpec: from physicsnemo.core.function_spec import FunctionSpec wp.init() - wp.config.log_level = wp.LOG_WARNING + wp.config.quiet = True @wp.kernel def _identity_kernel( @@ -687,11 +725,13 @@ def warp_launch_context(tensor: torch.Tensor): Warp device and stream. """ try: - wp = importlib.import_module("warp") + importlib.import_module("warp") except ImportError as exc: raise ImportError("warp is not available") from exc if tensor.device.type == "cuda": - stream = wp.stream_from_torch(torch.cuda.current_stream(tensor.device)) + # Reuse a cached wrapper so binding the current torch stream does + # not churn Warp's stream registration (see warp_stream_from_torch). + stream = warp_stream_from_torch(torch.cuda.current_stream(tensor.device)) device = None else: stream = None diff --git a/physicsnemo/datapipes/RNG.md b/physicsnemo/datapipes/RNG.md index f390d544de..da23c2dbea 100644 --- a/physicsnemo/datapipes/RNG.md +++ b/physicsnemo/datapipes/RNG.md @@ -32,6 +32,21 @@ master seed using `fork_generator(parent, n)`. Each child is seeded with and stable across runs. Children are created on the **same device** as the parent. +For RNG that must be reproducible regardless of *execution order* (e.g. +reader subsampling, which runs on a pool of worker threads), `_rng.py` +also provides coordinate-based seeding: + +- **`derive_seed(base_seed, *coords)`** — mixes a base seed with integer + coordinates (typically `epoch` and sample `index`) into a single + well-mixed 64-bit seed via `numpy.random.SeedSequence`. The result + depends only on the inputs, not on call order or thread. +- **`spawn_generator(base_seed, *coords, device=...)`** — returns a fresh + `torch.Generator` seeded with `derive_seed(base_seed, *coords)`. + +Because each call returns an independent generator seeded purely from its +coordinates, draws are reproducible irrespective of order and safe to +compute concurrently from multiple threads (no shared mutable state). + ### DataLoader When `seed` is set the DataLoader: @@ -70,9 +85,12 @@ sub-dataset. ### Epoch reseeding `DataLoader.set_epoch(epoch)` propagates to the sampler and dataset. -Each component with a generator reseeds it with +The sampler and stochastic transforms reseed their generators with `initial_seed() + epoch`, producing a different but deterministic -random sequence every epoch. +random sequence every epoch. Readers instead store the epoch and fold +it into each sample's derived seed (see [Readers](#readers)), so their +per-sample RNG also varies deterministically per epoch without relying +on a shared, sequentially-drawn generator. ## Generator tree @@ -152,18 +170,63 @@ standalone use. ## Readers -The `Reader` base class defines no-op `set_generator` / `set_epoch`. -Readers that use randomness override them: - -| Reader | Randomness | Generator support | +Reader subsampling runs on the dataset's worker-thread pool (the threaded +`prefetch` producer path; `Dataset` defaults to `num_workers=2`), so the +*order* in which samples are drawn is non-deterministic. A single shared, +sequentially-drawn generator would therefore not be reproducible with +`num_workers > 1`. To avoid this, readers derive RNG **per +`(base_seed, epoch, index)`** instead of from one shared stream: + +- **`set_generator(g)`** stores `g.initial_seed()` as the reader's base + seed (it does *not* keep the generator itself). +- **`set_epoch(e)`** stores the epoch. +- Each `reader[index]` then calls + `spawn_generator(base_seed, epoch, index)` to obtain a fresh generator + for that sample's draws (the `Reader` base class exposes this as + `_index_generator(index)`). + +The draw for a given sample depends only on `(base_seed, epoch, index)`, +so it is **identical regardless of read order or worker thread** — +reproducible for any `num_workers` — while still differing across indices +and across epochs. When no seed has been set, the per-sample generator is +`None` and draws fall back to the global default RNG. + +Transforms remain reproducible because they run on the main thread in +sampler order (via the consume stage), so their sequentially-drawn +generators are unaffected by the threaded producer. + +| Reader | Randomness | Per-`(seed, epoch, index)` RNG | |---|---|---| | `MeshReader` | `torch.randint` (contiguous block selection) | Yes | | `DomainMeshReader` | `torch.randint` | Yes | | `NumpyReader` | `torch.randint` (coordinated subsampling) | Yes | | `ZarrReader` | `torch.randint` | Yes | | `TensorStoreZarrReader` | `torch.randint` | Yes | -| `HDF5Reader` | None | No-op (inherited) | -| `VTKReader` | None | No-op (inherited) | +| `HDF5Reader` | None | n/a (inherited base) | +| `VTKReader` | None | n/a (inherited base) | + +## Iterable & descriptor paths: per-`(epoch, position)` seeding + +Map-style datasets have a stable sample `index`, so readers key their +per-sample RNG on `(base_seed, epoch, index)` (see [Readers](#readers)). +Generator-style (`IterableDatasetBase`) and future descriptor-keyed +sources have **no stable index**: samples are produced in sequence with no +addressable position in a corpus. They therefore key on the **monotonic +emission position** within the epoch instead: + +- **map-style:** `derive_seed(base_seed, epoch, index)` — reproducible for + any read order / `num_workers`, since the index is intrinsic to the + sample. +- **iterable / descriptor:** `derive_seed(base_seed, epoch, position)` — + where `position` is a 0-based counter of emissions in the current epoch. + Reproducible across runs and distinct across epochs and positions. + +Both schemes use the same `derive_seed`/`spawn_generator` primitives; only +the coordinate that stands in for "which sample" differs. The iterable +path runs entirely on the main thread in emission order, so the position +counter is unambiguous (there is no worker-thread reordering to defend +against). See `tutorial_5_iterable_online_simulation.py` for a worked +example seeding an online Darcy-flow simulation per `(epoch, position)`. ## Current limitations diff --git a/physicsnemo/datapipes/__init__.py b/physicsnemo/datapipes/__init__.py index d3902c5459..50e4e2aa4a 100644 --- a/physicsnemo/datapipes/__init__.py +++ b/physicsnemo/datapipes/__init__.py @@ -42,7 +42,7 @@ from physicsnemo.datapipes.dataset import Dataset from physicsnemo.datapipes.mesh_dataset import MeshDataset from physicsnemo.datapipes.multi_dataset import MultiDataset -from physicsnemo.datapipes.protocols import DatasetBase +from physicsnemo.datapipes.protocols import DatasetBase, IterableDatasetBase from physicsnemo.datapipes.readers import ( DomainMeshReader, HDF5Reader, @@ -105,6 +105,7 @@ # "TensorDict", # Re-export from tensordict "DatasetBase", + "IterableDatasetBase", "Dataset", "MeshDataset", "DataLoader", diff --git a/physicsnemo/datapipes/_rng.py b/physicsnemo/datapipes/_rng.py index d76a374df0..5e20c2715a 100644 --- a/physicsnemo/datapipes/_rng.py +++ b/physicsnemo/datapipes/_rng.py @@ -23,9 +23,68 @@ from __future__ import annotations +import numpy as np import torch +def derive_seed(base_seed: int, *coords: int) -> int: + """Deterministically mix a base seed with integer coordinates. + + Combines ``base_seed`` with arbitrary integer ``coords`` (typically + ``epoch`` and sample ``index``) into a single well-mixed 64-bit seed + using :class:`numpy.random.SeedSequence`. The result depends only on + the inputs, not on call order or thread, so per-sample RNG derived + from it is reproducible and safe to compute concurrently. + + Parameters + ---------- + base_seed : int + Base seed (e.g. a generator's ``initial_seed()``). + *coords : int + Additional non-negative integer coordinates to fold in, such as + ``(epoch, index)``. + + Returns + ------- + int + A deterministic 64-bit seed. + """ + seq = np.random.SeedSequence([int(base_seed), *(int(c) for c in coords)]) + return int(seq.generate_state(1, dtype=np.uint64)[0]) + + +def spawn_generator( + base_seed: int, + *coords: int, + device: torch.device | str = "cpu", +) -> torch.Generator: + """Create a fresh :class:`torch.Generator` seeded from mixed coordinates. + + Returns an independent generator whose seed is + :func:`derive_seed(base_seed, *coords) `. Because each + call returns a new generator seeded purely from its inputs, draws are + reproducible regardless of execution order and can be made + concurrently from multiple threads without sharing mutable state. + + Parameters + ---------- + base_seed : int + Base seed (e.g. a generator's ``initial_seed()``). + *coords : int + Additional integer coordinates to fold in, such as ``(epoch, index)``. + device : torch.device or str, default="cpu" + Device the generator is created on. + + Returns + ------- + torch.Generator + A new generator seeded deterministically from the inputs. + """ + generator = torch.Generator(device=device) + generator.manual_seed(derive_seed(base_seed, *coords)) + return generator + + def fork_generator( parent: torch.Generator, n: int, diff --git a/physicsnemo/datapipes/dataloader.py b/physicsnemo/datapipes/dataloader.py index 4e6b7bc61f..914e3f34b1 100644 --- a/physicsnemo/datapipes/dataloader.py +++ b/physicsnemo/datapipes/dataloader.py @@ -25,6 +25,8 @@ from __future__ import annotations +import itertools +import warnings from typing import Any, Callable, Iterator, Optional, Sequence import torch @@ -33,7 +35,13 @@ from physicsnemo.datapipes._rng import fork_generator from physicsnemo.datapipes.collate import Collator, get_collator -from physicsnemo.datapipes.protocols import DatasetBase +from physicsnemo.datapipes.io_pump import BATCH_BOUNDARY, IOPump +from physicsnemo.datapipes.protocols import ( + DatasetBase, + IterableDatasetBase, + preprocessing_stream, + record_stream, +) from physicsnemo.datapipes.registry import register @@ -57,6 +65,38 @@ class DataLoader: - Compatible with PyTorch samplers (DistributedSampler, etc.) - Familiar torch DataLoader interface + Two data paths + -------------- + The path is selected by dataset type: + + - **Map-style** (:class:`~physicsnemo.datapipes.protocols.DatasetBase`): + a dispatcher thread (:class:`~physicsnemo.datapipes.io_pump.IOPump`) + lazily submits sample loads to a worker pool and forwards batch + boundaries, while the main thread consumes handles (host-to-device + transfer + transforms on a preprocessing stream). + - **Iterable** (:class:`~physicsnemo.datapipes.protocols.IterableDatasetBase`): + a generator dataset driven main-thread-only (no sampler, no pump, no + worker pool). ``len()`` is undefined and ``shuffle``/``sampler`` are + ignored; generation runs on a preprocessing stream with the same + event handoff so it overlaps training. + + Concurrency model + ----------------- + A dedicated dispatcher thread keeps the I/O pipeline primed by + submitting sample loads ahead of consumption, bounded by + ``prefetch_factor`` batches worth of in-flight samples. The main + thread is the sole consumer: it performs all host-to-device transfers + and GPU transforms (including Warp kernels) on the prefetch streams. + Warp's invariant is the single launching thread, not a single stream, + so transforms run on the assigned preprocessing stream and overlap the + compute stream. + + For the pipeline to stay primed, the main thread must not block: keep + reader output (optionally) pinned (so host-to-device copies are asynchronous) and + avoid host readbacks (``.item()``, ``wp.synchronize()``), data- + dependent shapes, and GIL-bound pure-Python transforms on the launch + path. + Examples -------- >>> from physicsnemo.datapipes import DataLoader, Dataset, HDF5Reader, Normalize @@ -80,7 +120,7 @@ class DataLoader: def __init__( self, - dataset: DatasetBase, + dataset: DatasetBase | IterableDatasetBase, *, batch_size: int = 1, shuffle: bool = False, @@ -104,9 +144,10 @@ def __init__( Parameters ---------- - dataset : DatasetBase - Dataset to load from. Any subclass of :class:`DatasetBase` - (e.g. :class:`Dataset`, :class:`MeshDataset`). + dataset : DatasetBase or IterableDatasetBase + Dataset to load from. A map-style :class:`DatasetBase` + (e.g. :class:`Dataset`, :class:`MeshDataset`) or an + :class:`IterableDatasetBase` generator dataset. batch_size : int, default=1 Number of samples per batch. shuffle : bool, default=False @@ -151,6 +192,9 @@ def __init__( self.num_streams = num_streams self.use_streams = use_streams and torch.cuda.is_available() self._seed = seed + # Iterable (generator) datasets are driven main-thread-only: no + # sampler, no worker-pool prefetch (see _iter_iterable). + self._iterable = isinstance(dataset, IterableDatasetBase) # Build master generator and fork for sampler + dataset sampler_generator: torch.Generator | None = None @@ -163,14 +207,18 @@ def __init__( if hasattr(dataset, "set_generator"): dataset.set_generator(forks[1]) - # Handle sampler - if sampler is not None: + # Handle sampler. Iterable datasets have no indices, so they carry + # no sampler and ignore shuffle. + if self._iterable: + if sampler is not None or shuffle: + warnings.warn( + "shuffle/sampler are ignored for iterable datasets; " + "the generator controls sample order.", + stacklevel=2, + ) + self.sampler = None + elif sampler is not None: self.sampler = sampler - # For DistributedSampler, propagate seed if available - if seed is not None and hasattr(sampler, "seed"): - # DistributedSampler exposes seed as a constructor arg - # but it's read-only; users should pass seed at construction. - pass elif shuffle: self.sampler = RandomSampler(dataset, generator=sampler_generator) else: @@ -179,7 +227,8 @@ def __init__( # Handle collation self.collate_fn = get_collator(collate_fn, collate_metadata=collate_metadata) - # Create CUDA streams for prefetching + # Create CUDA streams: prefetch uses several round-robin streams; the + # iterable path uses the first as its preprocessing stream. self._streams: list[torch.cuda.Stream] = [] if self.use_streams: for _ in range(num_streams): @@ -193,7 +242,15 @@ def __len__(self) -> int: ------- int Number of batches in the dataloader. + + Raises + ------ + TypeError + If the dataset is iterable (generator-style), which has no + defined length. """ + if self._iterable: + raise TypeError("len() is undefined for an iterable (generator) dataset") n_samples = ( len(self.sampler) if hasattr(self.sampler, "__len__") else len(self.dataset) ) @@ -226,8 +283,15 @@ def __iter__( """ Iterate over batches. - Uses stream-based prefetching when enabled to overlap IO, - GPU transfers, and computation. + Uses the self-priming :class:`IOPump` to overlap host-side I/O + (on the dataset's worker threads) with main-thread consumption + whenever ``prefetch_factor > 0``. This threaded producer path is + independent of CUDA streams: when streams are enabled (and CUDA + is available) each prefetched sample is also assigned a stream so + the host-to-device copy and GPU transforms overlap; otherwise the + same path runs with ``stream=None`` (still overlapping disk I/O + with the main thread). Set ``prefetch_factor=0`` for fully + synchronous iteration. Yields ------ @@ -236,7 +300,9 @@ def __iter__( or tuple of (batched TensorDict, list of metadata dicts) if collate_metadata=True. """ - if self.prefetch_factor > 0 and self.use_streams: + if self._iterable: + yield from self._iter_iterable() + elif self.prefetch_factor > 0: yield from self._iter_prefetch() else: yield from self._iter_simple() @@ -256,59 +322,271 @@ def _iter_simple( samples = [self.dataset[idx] for idx in batch_indices] yield self.collate_fn(samples) + def _work_stream(self) -> Iterator[Any]: + """Lazily yield sampler indices delimited by :data:`BATCH_BOUNDARY`. + + Buffers at most one batch of indices (never the whole epoch), so + arbitrarily long samplers stream without up-front materialization. + A boundary is emitted after each full batch; a trailing partial + batch is emitted only when ``drop_last`` is False. + + Yields + ------ + int or object + Sample indices, with :data:`BATCH_BOUNDARY` after each batch. + """ + batch: list[int] = [] + for index in self.sampler: + batch.append(index) + if len(batch) == self.batch_size: + yield from batch + yield BATCH_BOUNDARY + batch = [] + if batch and not self.drop_last: + yield from batch + yield BATCH_BOUNDARY + def _iter_prefetch( self, ) -> Iterator[TensorDict | tuple[TensorDict, list[dict[str, Any]]]]: """ - Iteration with stream-based prefetching. - - Strategy: - - 1. Prefetch `prefetch_factor` batches worth of samples - 2. As we yield batches, prefetch more to keep the pipeline full - 3. Each sample in a batch uses a different stream for overlap + Iteration driven by a self-priming prefetch pump with one-batch lookahead. + + A dedicated dispatcher thread (the :class:`IOPump`) lazily pulls + the index stream and submits sample loads to the dataset's worker + pool, keeping a bounded number of samples in flight regardless of + the consumer's cadence. The main thread is a pure drain loop: it + pulls ready handles in order, runs the per-sample consume step + (host-to-device transfer plus GPU transforms, including Warp, on + the assigned stream), and reassembles batches from the boundary + markers the pump forwards. + + Stream assignment is optional and decoupled from the threaded + producer: when CUDA streams are enabled a stream is round-robined + per sample (so preprocessing overlaps the previous batch's compute); + otherwise dispatch passes ``stream=None`` and the path still + overlaps host-side I/O with main-thread consumption. + + Because dispatch lives off the main thread, the pipeline stays + primed even while the main thread is blocked launching kernels or + running the model. All device-kernel launches happen here, on the + single main thread. + + One-batch lookahead for preprocessing stream overlap + ---------------------------------------------------- + Before yielding batch N, this method eagerly drains the pump for + batch N+1's items and calls ``consume(..., defer_sync=True)`` on + each. ``consume`` enqueues the host-to-device transfer and GPU + transforms on a preprocessing stream asynchronously and *records* a + CUDA event, but -- with ``defer_sync=True`` -- does **not** make the + compute stream wait on it yet. The wait is inserted here, by + :func:`gate_compute_stream`, immediately before the owning batch is + yielded. + + This ordering is the whole point of the lookahead. If ``_consume`` + made the compute stream wait on batch N+1's event during the + lookahead drain (i.e. before yielding batch N), that wait would be + enqueued on the compute stream *ahead* of batch N's model kernels, + so batch N's model would block on batch N+1's preprocessing -- the + opposite of overlap. By deferring the wait, the compute-stream order + becomes ``..., model_{N-1}, wait(prep_N), model_N, ...``: batch N's + preprocessing (already in flight on its own stream) overlaps batch + N-1's compute, and each model only blocks on its own batch's + preprocessing. + + With ``prefetch_factor >= 2`` the IOPump has already dispatched + batch N+1's disk-I/O ahead of time, so the ``future.result()`` + inside ``consume()`` returns immediately and the lookahead drain + adds negligible host-side latency. Yields ------ TensorDict or tuple[TensorDict, list[dict[str, Any]]] Collated batch. """ - # Collect all batches upfront for prefetch planning - all_batches = list(self._generate_batches()) - if not all_batches: - return - - num_prefetch_batches = min(self.prefetch_factor, len(all_batches)) - stream_idx = 0 - - # Start initial prefetch - prefetched_up_to = 0 - for batch_idx in range(num_prefetch_batches): - for sample_idx in all_batches[batch_idx]: - stream = self._streams[stream_idx % self.num_streams] - self.dataset.prefetch(sample_idx, stream=stream) - stream_idx += 1 - prefetched_up_to = batch_idx + 1 - - # Yield batches and prefetch more - for batch_idx, batch_indices in enumerate(all_batches): - # Collect samples (uses prefetched if available) - samples = [self.dataset[idx] for idx in batch_indices] - batch = self.collate_fn(samples) - - # Prefetch next batch if available - next_prefetch_idx = prefetched_up_to - if next_prefetch_idx < len(all_batches): - for sample_idx in all_batches[next_prefetch_idx]: - stream = self._streams[stream_idx % self.num_streams] - self.dataset.prefetch(sample_idx, stream=stream) - stream_idx += 1 - prefetched_up_to += 1 - - yield batch + # Streams are an optional accelerator on top of the threaded + # producer; only assign them when actually available. + use_streams = self.use_streams and len(self._streams) > 0 + + # The compute stream the training loop runs on. Captured once on the + # main thread; the deferred per-batch preprocessing waits are + # enqueued onto it right before each batch is yielded. + compute_stream = torch.cuda.current_stream() if use_streams else None + + # Round-robin a stream per sample at dispatch time (when enabled); + # submit returns a handle the consumer resolves in order. + stream_counter = itertools.count() + + def dispatch(index: int) -> Any: + stream = ( + self._streams[next(stream_counter) % self.num_streams] + if use_streams + else None + ) + return self.dataset.submit(index, stream=stream) + + # Depth = prefetch_factor batches worth of samples kept in flight + # (at least the stream count when streams drive the overlap). + depth = max(self.prefetch_factor * self.batch_size, 1) + if use_streams: + depth = max(depth, self.num_streams) + + pump = IOPump(self._work_stream(), dispatch, depth=depth) + pump_iter = iter(pump) + + def drain_one_batch() -> tuple[list[Any], bool]: + """Consume pump items up to the next BATCH_BOUNDARY. + + Each non-boundary item is passed through + :meth:`~physicsnemo.datapipes.protocols.DatasetBase.consume` + with ``defer_sync=True``, which enqueues the host-to-device copy + and GPU transforms on a preprocessing stream (asynchronously) + and records a CUDA event *without* making the compute stream wait + on it. The wait is inserted later by :func:`gate_compute_stream`. + + Returns + ------- + tuple[list, bool] + ``(samples, has_batch)`` where *has_batch* is ``True`` when + a ``BATCH_BOUNDARY`` was found and ``False`` when the pump + was exhausted without one (only possible for an empty or + fully-consumed source). + """ + samples: list[Any] = [] + for item in pump_iter: + if item is BATCH_BOUNDARY: + return samples, True + samples.append(self.dataset.consume(item, defer_sync=True)) + return samples, False + + def gate_compute_stream(events: list) -> None: + """Make the compute stream wait on a batch's preprocessing events. + + Called immediately before a batch is yielded, so the wait is + ordered *after* the previous batch's model kernels (already + enqueued by the prior iteration's ``yield``). The batch's + preprocessing -- launched on its side stream during the lookahead + drain -- thus overlaps the previous batch's compute, and the + model only blocks on its own batch's preprocessing. + + A no-op when streams are disabled (``events`` is empty). + """ + if compute_stream is not None: + for event in events: + compute_stream.wait_event(event) + + try: + # Prime: consume the first batch's items, enqueueing their + # preprocessing stream work (event recorded, compute-stream wait + # deferred to gate_compute_stream below). + current_samples, has_first = drain_one_batch() + # Per-batch preprocessing events whose compute-stream wait was + # deferred; gate_compute_stream consumes these right before yield. + current_events = self.dataset._pop_events() + if not has_first: + # Source was empty or had no boundary; yield whatever arrived. + if current_samples: + gate_compute_stream(current_events) + yield self.collate_fn(current_samples) + return + + while True: + # Eagerly drain the NEXT batch before yielding the current + # one. This enqueues batch N+1's H2D transfer and GPU + # transforms on the preprocessing streams so they run + # concurrently with the training loop's compute-stream work + # for batch N (forward / backward / optimizer). + next_samples, has_next = drain_one_batch() + next_events = self.dataset._pop_events() + + # Gate the compute stream on the CURRENT batch's preprocessing + # *now* -- after the previous iteration's yield enqueued the + # previous batch's model work, and after the NEXT batch's + # preprocessing was launched above on its own stream. This is + # what lets the next batch's preprocessing overlap this + # batch's compute instead of blocking it. + gate_compute_stream(current_events) + + # Yield the current batch. While the training loop runs, + # the preprocessing streams are already working on the next + # batch. + yield self.collate_fn(current_samples) + + if not has_next: + # Pump exhausted after the lookahead drain; yield the + # final partial batch (empty when drop_last trimmed it) + # then stop. + if next_samples: + gate_compute_stream(next_events) + yield self.collate_fn(next_samples) + break + + current_samples = next_samples + current_events = next_events + finally: + # Stop the dispatcher (handles early break / exhaustion) and + # drop any prefetched-but-unconsumed handles. + pump.stop() + self.dataset.cancel_prefetch() + + def _iter_iterable( + self, + ) -> Iterator[Any]: + """ + Main-thread-only iteration for generator (iterable) datasets. + + There is no worker pool: the dataset's generator runs on the main + thread, so it may freely launch Warp kernels / use streams. Each + item is generated on a preprocessing stream (when streams are + enabled) and handed to the compute stream via a CUDA event, so + generation of the next item can overlap training on the current + one. A generator that forces a host readback simply serializes + itself. + + Two emission modes are supported (see :class:`IterableDatasetBase`): + per-sample items are collated into ``batch_size`` batches (with + ``drop_last`` trimming the trailing partial batch); a self-batching + generator (``yields_batches = True``) has each batch passed through + unchanged. - # Clean up any remaining prefetch state - self.dataset.cancel_prefetch() + Yields + ------ + Any + Collated batches, or generator-produced batches when the + dataset is self-batching. + """ + use_stream = self.use_streams and len(self._streams) > 0 + prep_stream = self._streams[0] if use_stream else None + compute_stream = torch.cuda.current_stream() if use_stream else None + self_batching = getattr(self.dataset, "yields_batches", False) + + iterator = iter(self.dataset) + samples: list[Any] = [] + while True: + # Generate the next item on the preprocessing stream, then order + # it before the compute stream without blocking the host. + with preprocessing_stream(prep_stream): + try: + item = next(iterator) + except StopIteration: + break + if use_stream: + record_stream(item, compute_stream) + event = torch.cuda.Event() + event.record(prep_stream) + compute_stream.wait_event(event) + + if self_batching: + yield item + continue + + samples.append(item) + if len(samples) == self.batch_size: + yield self.collate_fn(samples) + samples = [] + + if not self_batching and samples and not self.drop_last: + yield self.collate_fn(samples) def set_epoch(self, epoch: int) -> None: """ diff --git a/physicsnemo/datapipes/datapipes.md b/physicsnemo/datapipes/datapipes.md index b41ca1d846..0aa1dc9929 100644 --- a/physicsnemo/datapipes/datapipes.md +++ b/physicsnemo/datapipes/datapipes.md @@ -141,106 +141,144 @@ preprocessing. Threads are a natural fit: duplication overhead. - **I/O concurrency** -- the GIL is released during disk reads and CUDA kernel launches, so multiple threads usefully overlap I/O with GPU work. -- **Stream parallelism** -- each prefetched sample is assigned its own - CUDA stream, allowing host-to-device transfers and GPU transforms to - run concurrently with the main training computation. +- **Stream parallelism** -- when enabled, each prefetched sample is + assigned a CUDA stream so its host-to-device transfer can overlap with + the main training computation. -### Thread-pool prefetch +### Producer / consumer split -`DatasetBase` owns a `ThreadPoolExecutor` (configurable via -`num_workers`, default 2). Calling `prefetch(index)` submits the -load-and-transform pipeline to the pool and stashes the `Future`: +Prefetching is split into two stages so that **no device kernels are +launched off the main thread** -- a hard requirement for Warp-based +transforms, which must share the model's single launching thread: -```python -def prefetch(self, index, stream=None): - if index in self._prefetch_futures: - return - executor = self._ensure_executor() - self._prefetch_futures[index] = executor.submit(self._load, index) -``` +- `_load_host` is the **producer**. It runs on a worker thread and does + only thread-safe work: reading, decoding, and staging into pinned host + memory. It returns a `HostPayload`. +- `_consume` is the **consumer**. It runs on whatever thread calls + `__getitem__` (the main thread, in practice) and performs the + host-to-device transfer and device transforms (including Warp kernels). -`__getitem__` pops the `Future` if one exists, otherwise loads -synchronously: +`DatasetBase` owns a `ThreadPoolExecutor` (configurable via +`num_workers`) and exposes a FIFO prefetch primitive. `submit(work_item, +stream=...)` runs only the producer on the pool and returns a +`PrefetchHandle` bundling the future with the stream the consumer should +use; `consume(handle)` resolves it on the calling thread: ```python -def __getitem__(self, index): - future = self._prefetch_futures.pop(index, None) - if future is not None: - return future.result() - return self._load(index) -``` - -This means the DataLoader can keep the next batch loading in background -threads while the current batch is being consumed by the model. +def submit(self, work_item, stream=None): + future = self._executor.submit(self._load_host, work_item) + return PrefetchHandle(future=future, stream=stream) -### CUDA stream overlap - -When GPU execution is available, `Dataset` (and `MeshDataset`) override -`prefetch` to run device transfer and transforms on a caller-supplied -CUDA stream, then record an event for later synchronization: - -```python -def _load_and_transform(self, index, stream=None): - result = _PrefetchResult(index=index) - data, metadata = self.reader[index] # CPU I/O in worker thread - - if stream is not None: - with torch.cuda.stream(stream): - data = data.to(device, non_blocking=True) # H2D on stream - data = self.transforms(data) # GPU transforms on stream - result.event = torch.cuda.Event() - result.event.record(stream) # mark completion - - result.data, result.metadata = data, metadata - return result +def consume(self, handle): + payload = handle.future.result() # re-raises producer errors + return self._consume(payload, handle.stream) # H2D + transforms here ``` -On retrieval, `__getitem__` synchronizes the event before returning: +Correlation is purely by handle identity (FIFO), so work items need not +be hashable, unique, or even integers -- an `int` index is just the +common case. The index-keyed `prefetch(index)` / `__getitem__(index)` +convenience API is a thin layer over `submit`/`consume` for random +access, and is what map-style tests and `MultiDataset` use. + +### Self-priming dispatch (IOPump) + +The threaded producer is driven by `IOPump`, a dedicated dispatcher +thread that keeps a *bounded* number of samples in flight regardless of +the consumer's cadence. It pulls a work-item stream **lazily** (one item +per free backpressure slot, so an arbitrarily long or unbounded source +never materializes up front), calls `submit` for each, and hands the +returned handles back to the main thread in FIFO order. The source +interleaves `BATCH_BOUNDARY` markers between work items; the pump forwards +them in place without consuming a slot, so the consumer reassembles +dynamically-sized batches from the boundaries -- the DataLoader never +builds the epoch's batch list in advance. Because dispatch lives off the +main thread, the pipeline stays primed even while the main thread is busy +launching kernels or running the model. This path is active whenever +`prefetch_factor > 0`; set `prefetch_factor=0` for fully synchronous +iteration. + +### CUDA stream handoff + +CUDA streams are an *optional* accelerator layered on top of the threaded +producer. When `use_streams=True` (and CUDA is available), each sample is +round-robined a **preprocessing stream**. The consumer runs *both* the +host-to-device copy and the transforms on that stream, then hands the +result to the compute stream via a CUDA **event** (never a host +`synchronize()`): ```python -if result.event is not None: - result.event.synchronize() -return result.data, result.metadata +def _consume(self, payload, stream=None): + data = payload.data + if device is not None and stream is not None: + compute_stream = torch.cuda.current_stream() + # Bind torch AND Warp to the preprocessing stream. + with preprocessing_stream(stream): # torch + wp.ScopedStream + data = data.to(device, non_blocking=True) # H2D on prep stream + data = self.transforms(data) # transforms on SAME stream + data.record_stream(compute_stream) # keep memory alive + event = torch.cuda.Event() + event.record(stream) + compute_stream.wait_event(event) # order, no host block + else: + data = self.transforms(data) + return data, payload.metadata ``` -The `DataLoader` owns a pool of `num_streams` CUDA streams (default 4) -and round-robins them across samples. It also maintains a sliding -prefetch window of `prefetch_factor` batches (default 2) ahead of the -current yield position: - -```python -# Prefetch the next batch as we yield the current one -for sample_idx in all_batches[next_prefetch_idx]: - stream = self._streams[stream_idx % self.num_streams] - self.dataset.prefetch(sample_idx, stream=stream) - stream_idx += 1 -``` +**The single launching thread -- not a single stream -- is Warp's real +invariant.** Warp kernels may run on any CUDA stream provided they are +launched from the main thread *and* Warp's current stream matches torch's. +`preprocessing_stream` (in `protocols.py`) binds both via +`wp.ScopedStream(wp.stream_from_torch(stream))`, so transforms (including +Warp mesh-query / BVH kernels) run correctly on the side stream. A +previous `cudaErrorIllegalAddress` here was a torch/Warp stream +*divergence* (data on a side stream, the Warp kernel on Warp's own +stream), not a prohibition on non-default streams; binding both fixes it +and lets GPU preprocessing genuinely overlap training. `record_stream` +keeps the device tensors from being recycled while the compute stream +reads them; the pinned host source is held by the caching host allocator +until the copy completes. ### Concurrency timeline -The diagram below shows how threads and streams overlap for a two-sample -batch with `prefetch_factor=1`: +With everything launched from the main thread, the worker pool, the +preprocessing stream, and the compute stream form a triple buffer: ```text -Main thread Worker 1 Worker 2 Stream 1 Stream 2 - │ │ │ │ │ - ├─prefetch(0,S1)─►│ │ │ │ - ├─prefetch(1,S2)─────────────────────►│ │ │ - │ ├─ Read (I/O) │ │ │ - │ │ ├─ Read (I/O) │ │ - │ ├─ to(device) ─────────────────────────►│ │ - │ ├─ transforms ─────────────────────────►│ │ - │ ├─ event.record() ─────────────────────►│ │ - │ │ ├─ to(device) ─────────────────►│ - │ │ ├─ transforms ─────────────────►│ - │ │ ├─ event.record() ─────────────►│ - ├─ event.synchronize() ×2 │ │ │ - ├─ collate + yield batch │ │ │ - │ │ │ │ │ +Worker pool │ load N+2 ─ load N+1 ... (host I/O + thread-safe CPU work) +Preprocess stream │ H2D + Warp transforms for N+1 +Compute stream │ train N ``` -While the main thread consumes batch N, worker threads are already -loading batch N+1 on different streams. +GPU preprocessing of batch N+1 genuinely overlaps training of batch N on +a separate stream; the two are ordered by a CUDA event, never a host-side +`synchronize`. A transform (or generator) that forces a host readback +simply serializes itself -- a property of that code, not of the pipeline. + +### Two data paths: map/descriptor vs iterable + +The DataLoader selects one of two mutually-exclusive paths by dataset +type: + +- **Preload path (`DatasetBase`)** -- map-style and descriptor-keyed + datasets. Uses the worker pool + `IOPump` described above: workers do + thread-safe host I/O, the main thread consumes handles (H2D + transforms + on the preprocessing stream). This is the path for storage-backed data + addressable by index. +- **Generator path (`IterableDatasetBase`)** -- iterable datasets that + *produce* data (online simulation, procedural samplers, unbounded + streams). Driven **main-thread-only**: no sampler, no pump, no worker + pool. `__iter__` may freely launch Warp kernels and use CUDA streams + (the single-launching-thread invariant holds), and the loader still + drives generation on a preprocessing stream with the same event handoff, + so generation of batch N+1 overlaps training of batch N. + +An iterable dataset yields either per-sample `(data, metadata)` (the +loader collates `batch_size` of them, `drop_last` trims the tail) or, when +`yields_batches = True`, ready-made batches that the loader passes through +unchanged. Iterable datasets have no length: `len(loader)` raises +`TypeError`, and `shuffle`/`sampler` are ignored. See +`examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py` for +a Warp `Darcy2D` online simulation wired through this path. ### Pinned memory @@ -258,8 +296,9 @@ loader.disable_prefetch() # synchronous, single-stream -- easy to debug loader.enable_prefetch() # re-enable after debugging ``` -Setting `use_streams=False` or `prefetch_factor=0` at construction time -also forces synchronous execution. +`use_streams=False` keeps the threaded producer but drops the CUDA +stream handoff (the consumer copies and transforms on the default +stream); `prefetch_factor=0` forces fully synchronous execution. ## RNG and reproducibility diff --git a/physicsnemo/datapipes/dataset.py b/physicsnemo/datapipes/dataset.py index 727503e800..c56a767053 100644 --- a/physicsnemo/datapipes/dataset.py +++ b/physicsnemo/datapipes/dataset.py @@ -31,7 +31,12 @@ from tensordict import TensorDict from physicsnemo.datapipes._rng import fork_generator -from physicsnemo.datapipes.protocols import DatasetBase, _PrefetchResult +from physicsnemo.datapipes.protocols import ( + DatasetBase, + HostPayload, + preprocessing_stream, + record_stream, +) from physicsnemo.datapipes.readers.base import Reader from physicsnemo.datapipes.registry import register from physicsnemo.datapipes.transforms.base import Transform @@ -55,10 +60,20 @@ class Dataset(DatasetBase): Prefetching Model ----------------- - The dataset supports prefetching samples using a thread pool. - When a CUDA stream is provided, GPU operations (device transfer, - GPU transforms) happen on that stream, allowing overlap with - other computation. + The dataset supports prefetching samples using a thread pool. The + work is split into a thread-safe *producer* stage and a main-thread + *consumer* stage: + + - :meth:`_load_host` (producer, worker thread) reads the sample into + host memory (pinned when the reader is configured with + ``pin_memory=True``). It launches no device kernels. + - :meth:`_consume` (consumer, calling thread) performs the + host-to-device transfer and the GPU transforms on the assigned + CUDA stream. + + This keeps all device-kernel launches (notably Warp transforms) on + the consuming thread, which must be the same single thread the model + launches from. >>> # Start prefetching >>> dataset.prefetch(0, stream=stream0) # doctest: +SKIP @@ -242,94 +257,71 @@ def set_epoch(self, epoch: int) -> None: t.set_epoch(epoch) # ------------------------------------------------------------------ - # Stream-aware prefetch (overrides DatasetBase defaults) + # Producer / consumer split (overrides DatasetBase defaults) # ------------------------------------------------------------------ - def _load_and_transform( - self, - index: int, - stream: Optional[torch.cuda.Stream] = None, - ) -> _PrefetchResult: + def _load_host(self, work_item: int) -> HostPayload: """ - Load a sample and apply transforms with optional CUDA stream. + Producer stage: read a sample into host memory. + + Runs on a worker thread and launches no device kernels. Pinning + is the reader's responsibility: construct the reader with + ``pin_memory=True`` to stage tensors in pinned memory so the + consumer's host-to-device copy can be asynchronous. Parameters ---------- - index : int - Sample index. - stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. + work_item : int + Sample index to read from the reader. Returns ------- - _PrefetchResult - PrefetchResult with data, metadata, or error. + HostPayload + Payload carrying the host data and metadata, or a captured + error. """ - result = _PrefetchResult(index=index) - try: - data, metadata = self.reader[index] - - if self.target_device is not None: - if stream is not None: - with torch.cuda.stream(stream): - data = data.to(self.target_device, non_blocking=True) - else: - data = data.to(self.target_device, non_blocking=True) - - if self.transforms is not None: - if stream is not None: - with torch.cuda.stream(stream): - data = self.transforms(data) - result.event = torch.cuda.Event() - result.event.record(stream) - else: - data = self.transforms(data) - - result.data = data - result.metadata = metadata - - except Exception as e: - result.error = e - - return result + data, metadata = self.reader[work_item] + return HostPayload(work_item=work_item, data=data, metadata=metadata) + except Exception as e: # noqa: BLE001 + return HostPayload(work_item=work_item, error=e) - def prefetch( + def _consume( self, - index: int, + payload: HostPayload, stream: Optional[torch.cuda.Stream] = None, - ) -> None: + *, + defer_sync: bool = False, + ) -> tuple[TensorDict, dict[str, Any]]: """ - Start prefetching a sample asynchronously. - - When a CUDA stream is provided, GPU operations (device transfer - and transforms) run on that stream for overlap with computation. + Consumer stage: device transfer + transforms on the calling thread. + + Runs on whatever thread calls this (the main thread, so any Warp + kernels in the transforms share the model's launching thread). When + a CUDA ``stream`` is assigned, the host-to-device copy *and* the + transforms run on that preprocessing stream -- Warp bound to it via + :func:`preprocessing_stream` -- so this sample's preprocessing + overlaps the previous batch's training on the compute stream. The + result is tagged via ``record_stream`` so the caching allocator does + not recycle it while training reads it, and a CUDA event orders the + preprocessing before the compute stream (never a host-side + synchronize). Parameters ---------- - index : int - Sample index to prefetch. + payload : HostPayload + Producer payload from :meth:`_load_host`. stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. - """ - if index in self._prefetch_futures: - return - - executor = self._ensure_executor() - future = executor.submit(self._load_and_transform, index, stream) - self._prefetch_futures[index] = future - - def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: - """ - Get a transformed sample by index. - - If the index was prefetched, returns the prefetched result - (waiting for completion if necessary). Otherwise loads synchronously. - - Parameters - ---------- - index : int - Sample index. + Preprocessing stream for the host-to-device transfer and + transforms. ``None`` runs on the current stream. + defer_sync : bool, default=False + When False, ``compute_stream.wait_event`` is enqueued here so the + result is immediately safe to use on the current stream. When + True, the recorded event is appended to :attr:`_events_pending` + and the DataLoader enqueues the wait just before the batch is + yielded -- after the previous batch's model work -- so this + batch's preprocessing overlaps the previous batch's compute + instead of blocking it. Returns ------- @@ -338,26 +330,45 @@ def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: Raises ------ - IndexError - If index is out of range. Exception - If prefetch failed, re-raises the error. + If the producer captured an error, re-raises it. """ - future = self._prefetch_futures.pop(index, None) + if payload.error is not None: + raise payload.error - if future is not None: - result = future.result() + data = payload.data + metadata = payload.metadata - if isinstance(result, _PrefetchResult): - if result.error is not None: - raise result.error - if result.event is not None: - torch.cuda.current_stream().wait_event(result.event) - return result.data, result.metadata + device_is_cuda = ( + self.target_device is not None + and torch.device(self.target_device).type == "cuda" + ) + use_stream = stream is not None and device_is_cuda + compute_stream = torch.cuda.current_stream() if use_stream else None - return result + with preprocessing_stream(stream if use_stream else None): + if self.target_device is not None: + data = data.to(self.target_device, non_blocking=True) + if self.transforms is not None: + data = self.transforms(data) + + if use_stream: + # Tag the memory so the allocator keeps it alive for the compute + # stream, then record an event marking the preprocessing's + # completion on the prep stream. + record_stream(data, compute_stream) + event = torch.cuda.Event() + event.record(stream) + if defer_sync: + # Defer the compute-stream wait to the DataLoader so it lands + # after the previous batch's model work (real overlap). + self._events_pending.append(event) + else: + # Inline ordering for standalone callers (no DataLoader to + # insert the wait at the right point). + compute_stream.wait_event(event) - return self._load(index) + return data, metadata @property def field_names(self) -> list[str]: diff --git a/physicsnemo/datapipes/io_pump.py b/physicsnemo/datapipes/io_pump.py new file mode 100644 index 0000000000..516ea3c8c8 --- /dev/null +++ b/physicsnemo/datapipes/io_pump.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +IOPump - A self-driving I/O producer that keeps the pipeline primed. + +The pump owns a dedicated dispatcher thread that pulls work items from a +source iterator *lazily* and submits each for background loading, keeping +a *bounded* number of samples in flight at all times. Pulling lazily means +the source may be arbitrarily large (or effectively unbounded): the +dispatcher only advances it when a backpressure slot is free, so memory +stays bounded regardless of source length. + +Dispatch is decoupled from the consumer's cadence: the dispatcher keeps +topping the pipeline off while the main thread is busy launching kernels or +running the model, so a ready sample is (almost) always waiting when the +consumer asks for the next one. + +The pump is agnostic to *how* a sample is produced and consumed: it drives +opaque work items through a user-provided ``dispatch_fn`` (which starts the +background load and returns a handle) and hands those handles back to the +consumer in the exact order they were pulled. Correlation is purely +positional (FIFO), so work items need not be hashable or unique. + +A source may interleave :data:`BATCH_BOUNDARY` markers between work items; +the pump forwards them to the consumer in order without consuming a +backpressure slot, letting the consumer reassemble dynamically-sized +batches without knowing the batch layout up front. +""" + +from __future__ import annotations + +import queue +import threading +from typing import Any, Callable, Iterable, Iterator + +# Public marker a source yields to delimit the end of one batch. A distinct +# sentinel object so it can never collide with a real work item. +BATCH_BOUNDARY = object() + +# Internal marker the dispatcher pushes once the source is exhausted (or the +# pump is stopped), telling the consumer to finish iterating. +_DONE = object() + + +class _PumpError: + """Wraps an exception raised on the dispatcher thread. + + Forwarded through the ready queue so the consumer re-raises it on the + main thread instead of blocking forever waiting for items that will + never arrive. + """ + + __slots__ = ("exc",) + + def __init__(self, exc: BaseException) -> None: + self.exc = exc + + +class IOPump: + """Bounded, self-driving prefetch dispatcher. + + A dedicated dispatcher thread pulls work items from ``source``, + acquires a backpressure slot, calls ``dispatch_fn(work_item)`` to start + the background load, and makes the returned handle available to the + consumer in FIFO order via iteration. Slots are released as the + consumer advances, keeping at most ``depth`` samples in flight. + + Parameters + ---------- + source : Iterable + Work items to load, optionally interleaved with + :data:`BATCH_BOUNDARY` markers. Consumed lazily, one item at a + time, only as backpressure slots free up. + dispatch_fn : Callable[[Any], Any] + Called on the dispatcher thread to start loading a work item (for + example ``dataset.submit(work_item, stream=...)``). It must be + non-blocking and thread-safe and must not launch device kernels; + it returns an opaque handle that the consumer later turns into a + sample. + depth : int + Maximum number of samples dispatched but not yet consumed. Acts as + both the backpressure valve and the jitter buffer that hides + consumer stalls. Clamped to at least 1. + + Notes + ----- + A pump instance is single-consumer. Iterate it with a single thread + (the main/launcher thread). Call :meth:`stop` (or use it as a context + manager) to tear down the dispatcher thread; already-submitted loads + are left to complete and are reaped by the owning dataset. + """ + + def __init__( + self, + source: Iterable[Any], + dispatch_fn: Callable[[Any], Any], + depth: int, + ) -> None: + self._source = source + self._dispatch_fn = dispatch_fn + self._depth = max(1, int(depth)) + self._slots = threading.Semaphore(self._depth) + self._ready_queue: queue.Queue = queue.Queue() + self._stop = threading.Event() + self._thread = threading.Thread( + target=self._run, + name="datapipe_pump", + daemon=True, + ) + self._thread.start() + + # ------------------------------------------------------------------ + # Dispatcher thread + # ------------------------------------------------------------------ + + def _run(self) -> None: + """Dispatcher loop: keep ``depth`` samples in flight, in order.""" + source = iter(self._source) + while not self._stop.is_set(): + try: + item = next(source) + except StopIteration: + break + except BaseException as exc: # noqa: BLE001 + # A failing source must surface to the consumer, not hang it. + self._ready_queue.put(_PumpError(exc)) + return + + if item is BATCH_BOUNDARY: + # Boundaries are bookkeeping, not work: forward without + # consuming a slot. + self._ready_queue.put(BATCH_BOUNDARY) + continue + + # Backpressure: block until the consumer frees a slot. This is + # also where lazy pulling is enforced -- the source is not + # advanced again until there is room in flight. + self._slots.acquire() + if self._stop.is_set(): + break + try: + handle = self._dispatch_fn(item) + except BaseException as exc: # noqa: BLE001 + # A failing dispatch must surface to the consumer, not hang it. + self._ready_queue.put(_PumpError(exc)) + return + self._ready_queue.put(handle) + + self._ready_queue.put(_DONE) + + # ------------------------------------------------------------------ + # Consumer side (single consumer / the main thread) + # ------------------------------------------------------------------ + + def __iter__(self) -> Iterator[Any]: + """Yield ready handles (and batch boundaries) in FIFO order. + + Yields each loaded sample's handle in the order its work item was + pulled, and forwards :data:`BATCH_BOUNDARY` markers in place. + Releases a backpressure slot after each handle is consumed (i.e. + on the next iteration), so the dispatcher can refill the pipeline + as the consumer advances. Returns once the source is exhausted. + + Yields + ------ + object + Either a handle returned by ``dispatch_fn`` or + :data:`BATCH_BOUNDARY`. + """ + while True: + item = self._ready_queue.get() + if item is _DONE: + return + if isinstance(item, _PumpError): + raise item.exc + if item is BATCH_BOUNDARY: + yield BATCH_BOUNDARY + continue + yield item + # Consumer has finished with this sample; free a slot. + self._slots.release() + + # ------------------------------------------------------------------ + # Teardown + # ------------------------------------------------------------------ + + def stop(self) -> None: + """Stop the dispatcher thread and release its resources. + + Idempotent. Unblocks the dispatcher if it is waiting on a slot, + then joins it briefly. In-flight background loads already submitted + via ``dispatch_fn`` are not cancelled; the owning dataset reaps + them. + """ + if self._stop.is_set(): + return + self._stop.set() + # Unblock the dispatcher if it is parked acquiring a slot. + self._slots.release() + self._thread.join(timeout=5.0) + + def __enter__(self) -> "IOPump": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.stop() diff --git a/physicsnemo/datapipes/mesh_dataset.py b/physicsnemo/datapipes/mesh_dataset.py index 6ca740bb4b..2d2e471f7e 100644 --- a/physicsnemo/datapipes/mesh_dataset.py +++ b/physicsnemo/datapipes/mesh_dataset.py @@ -28,8 +28,14 @@ import torch from tensordict import TensorDict +from physicsnemo.datapipes import _timing from physicsnemo.datapipes._rng import fork_generator -from physicsnemo.datapipes.protocols import DatasetBase, _PrefetchResult +from physicsnemo.datapipes.protocols import ( + DatasetBase, + HostPayload, + preprocessing_stream, + record_stream, +) from physicsnemo.datapipes.readers.mesh import DomainMeshReader, MeshReader from physicsnemo.datapipes.registry import register from physicsnemo.datapipes.transforms.mesh.base import MeshTransform @@ -88,12 +94,19 @@ def __init__( device : str or torch.device, optional If set, move mesh data to this device after loading (before transforms). num_workers : int, default=1 - Number of worker threads for prefetching. Defaults to 1 - because mesh transforms construct new Mesh objects internally - and tensordict's ``_device_recorder`` is not safe for - concurrent TensorDict construction across threads. + Number of worker threads for the prefetch pool. Worker threads + run :meth:`_load_host` (disk read + pin_memory) concurrently; + GPU operations (H2D transfer, transforms, Warp kernels) always + run on the main thread in :meth:`_consume`. """ - super().__init__(num_workers=num_workers) + # _load_host (disk read + pin_memory) is host-side only and launches no + # device kernels; all GPU work (H2D transfer, transforms, Triton SDF) + # runs on the main thread in _consume and is ordered via CUDA stream + # events. With the Warp-free (torch/Triton) SDF there is no wp.Mesh + # lifetime race to guard against, so load and consume need not be + # serialized: disabling serialization lets all num_workers threads read + # in parallel and makes prefetch_factor actually scale I/O throughput. + super().__init__(num_workers=num_workers, serialize_load_consume=False) self.reader = reader self.transforms = list(transforms) if transforms else [] self._device = torch.device(device) if isinstance(device, str) else device @@ -165,16 +178,21 @@ def _load( self, index: int ) -> tuple[Union[Mesh, DomainMesh, TensorDict], dict[str, Any]]: """Synchronous load: reader -> device transfer -> transforms.""" - data, metadata = self.reader[index] + with torch.profiler.record_function("MeshDataset._load: reader[index]"): + data, metadata = self.reader[index] if self._device is not None: - data = data.to(self._device) + with torch.profiler.record_function("MeshDataset._load: data.to(device)"): + data = data.to(self._device) for t in self.transforms: - if isinstance(data, DomainMesh): - data = t.apply_to_domain(data) - else: - data = t(data) + with torch.profiler.record_function( + f"MeshDataset._load: transform {type(t).__name__}" + ): + if isinstance(data, DomainMesh): + data = t.apply_to_domain(data) + else: + data = t(data) return data, metadata @@ -182,101 +200,67 @@ def __len__(self) -> int: return len(self.reader) # ------------------------------------------------------------------ - # Stream-aware prefetch (overrides DatasetBase defaults) + # Producer / consumer split (overrides DatasetBase defaults) # ------------------------------------------------------------------ - def _load_and_transform( - self, - index: int, - stream: Optional[torch.cuda.Stream] = None, - ) -> _PrefetchResult: - """Load a sample and apply transforms with optional CUDA stream. + def _load_host(self, work_item: int) -> HostPayload: + """Producer stage: read a mesh sample on a worker thread. + + Launches no device kernels: it only reads the raw sample. Device + transfer and mesh transforms happen later in :meth:`_consume` on + the consuming thread. Parameters ---------- - index : int - Sample index. - stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. + work_item : int + Sample index to read from the reader. Returns ------- - _PrefetchResult - Result with data, metadata, or error. + HostPayload + Payload carrying the host data and metadata, or a captured + error. """ - result = _PrefetchResult(index=index) - try: - data, metadata = self.reader[index] - - if self._device is not None: - if stream is not None: - with torch.cuda.stream(stream): - data = data.to(self._device, non_blocking=True) - else: - data = data.to(self._device, non_blocking=True) - - for t in self.transforms: - if stream is not None: - with torch.cuda.stream(stream): - if isinstance(data, DomainMesh): - data = t.apply_to_domain(data) - else: - data = t(data) - else: - if isinstance(data, DomainMesh): - data = t.apply_to_domain(data) - else: - data = t(data) - - if stream is not None: - result.event = torch.cuda.Event() - result.event.record(stream) - - result.data = data - result.metadata = metadata - - except Exception as e: - result.error = e - - return result + data, metadata = self.reader[work_item] + return HostPayload(work_item=work_item, data=data, metadata=metadata) + except Exception as e: # noqa: BLE001 + return HostPayload(work_item=work_item, error=e) - def prefetch( + def _consume( self, - index: int, + payload: HostPayload, stream: Optional[torch.cuda.Stream] = None, - ) -> None: - """Start prefetching a sample asynchronously. - - When a CUDA stream is provided, GPU operations (device transfer - and transforms) run on that stream for overlap with computation. - - Parameters - ---------- - index : int - Sample index to prefetch. - stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. - """ - if index in self._prefetch_futures: - return - - executor = self._ensure_executor() - future = executor.submit(self._load_and_transform, index, stream) - self._prefetch_futures[index] = future - - def __getitem__( - self, index: int + *, + defer_sync: bool = False, ) -> tuple[Union[Mesh, DomainMesh, TensorDict], dict[str, Any]]: - """Get a transformed sample by index. - - If the index was prefetched, returns the prefetched result - (waiting for completion if necessary). Otherwise loads synchronously. + """Consumer stage: device transfer + transforms on the calling thread. + + Runs on whatever thread calls this (the main thread, so any Warp + mesh-query kernels in the transforms share the model's launching + thread). When a CUDA ``stream`` is assigned, the host-to-device + copy *and* the transforms run on that preprocessing stream -- Warp + bound to it via :func:`preprocessing_stream` -- so this sample's + preprocessing overlaps the previous batch's training on the compute + stream. The result is tagged via ``record_stream`` and a CUDA event + orders the preprocessing before the compute stream (not a host-side + synchronize). Parameters ---------- - index : int - Sample index. + payload : HostPayload + Producer payload from :meth:`_load_host`. + stream : torch.cuda.Stream, optional + Preprocessing stream for the host-to-device transfer and + transforms. ``None`` runs on the current stream. + defer_sync : bool, default=False + When False, ``compute_stream.wait_event`` is enqueued here so the + result is immediately safe to use on the current stream. When + True, the recorded event is appended to :attr:`_events_pending` + and the DataLoader enqueues the wait just before the batch is + yielded -- after the previous batch's model work -- so this + batch's preprocessing overlaps the previous batch's compute + instead of blocking it. Returns ------- @@ -286,23 +270,59 @@ def __getitem__( Raises ------ Exception - If prefetch failed, re-raises the error. + If the producer captured an error, re-raises it. """ - future = self._prefetch_futures.pop(index, None) + if payload.error is not None: + raise payload.error - if future is not None: - result = future.result() + data = payload.data + metadata = payload.metadata + + def _apply_transforms(d: Any) -> Any: + for t in self.transforms: + if isinstance(d, DomainMesh): + d = t.apply_to_domain(d) + else: + d = t(d) + return d - if isinstance(result, _PrefetchResult): - if result.error is not None: - raise result.error - if result.event is not None: - torch.cuda.current_stream().wait_event(result.event) - return result.data, result.metadata + device_is_cuda = ( + self._device is not None and torch.device(self._device).type == "cuda" + ) + use_stream = stream is not None and device_is_cuda + compute_stream = torch.cuda.current_stream() if use_stream else None - return result + with preprocessing_stream(stream if use_stream else None): + if self._device is not None: + with torch.profiler.record_function( + "MeshDataset._consume: data.to(device)" + ): + with _timing.record("consume/h2d"): + data = data.to(self._device, non_blocking=True) + with torch.profiler.record_function( + "MeshDataset._consume: _apply_transforms" + ): + with _timing.record("consume/transforms"): + data = _apply_transforms(data) + + if use_stream: + # Tag the memory so the allocator keeps it alive for the compute + # stream, then record an event marking the preprocessing's + # completion on the prep stream. + record_stream(data, compute_stream) + event = torch.cuda.Event() + event.record(stream) + if defer_sync: + # Defer the compute-stream wait to the DataLoader so it lands + # after the previous batch's model work (real overlap). + self._events_pending.append(event) + else: + # Inline ordering for standalone callers (no DataLoader to + # insert the wait at the right point). + compute_stream.wait_event(event) - return self._load(index) + _timing.tick() + return data, metadata def close(self) -> None: """Close the dataset and stop prefetching. diff --git a/physicsnemo/datapipes/multi_dataset.py b/physicsnemo/datapipes/multi_dataset.py index 2547bf5f95..6fc6a7470a 100644 --- a/physicsnemo/datapipes/multi_dataset.py +++ b/physicsnemo/datapipes/multi_dataset.py @@ -304,6 +304,76 @@ def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: metadata[DATASET_INDEX_METADATA_KEY] = ds_id return data, metadata + def submit(self, index: int, stream: Optional[Any] = None) -> tuple[int, Any]: + """ + Submit a global index for background loading (FIFO prefetch primitive). + + Maps the global index to its owning sub-dataset and delegates to + that dataset's :meth:`~DatasetBase.submit`. The returned handle is + wrapped with the owning dataset id so :meth:`consume` can restore + the ``dataset_index`` metadata. + + Parameters + ---------- + index : int + Global sample index to load. + stream : object, optional + CUDA stream for the consume step. + + Returns + ------- + tuple[int, PrefetchHandle] + ``(dataset_index, handle)`` to pass to :meth:`consume`. + """ + ds_id, local_i = self._index_to_dataset_and_local(index) + handle = self._datasets[ds_id].submit(local_i, stream=stream) + return ds_id, handle + + def consume( + self, handle: tuple[int, Any], *, defer_sync: bool = False + ) -> tuple[TensorDict, dict[str, Any]]: + """ + Resolve a :meth:`submit` handle into ``(data, metadata)``. + + Parameters + ---------- + handle : tuple[int, PrefetchHandle] + The ``(dataset_index, handle)`` returned by :meth:`submit`. + defer_sync : bool, default=False + Forwarded to the owning sub-dataset's + :meth:`~DatasetBase.consume`. When True, the sub-dataset records + its preprocessing event into its own ``_events_pending`` instead + of making the compute stream wait on it; :meth:`_pop_events` + collects those events for the DataLoader. + + Returns + ------- + tuple[TensorDict, dict[str, Any]] + Sample and metadata, enriched with ``dataset_index``. + """ + ds_id, inner = handle + data, metadata = self._datasets[ds_id].consume(inner, defer_sync=defer_sync) + metadata = dict(metadata) + metadata[DATASET_INDEX_METADATA_KEY] = ds_id + return data, metadata + + def _pop_events(self) -> list: + """Aggregate and clear pending preprocessing events across sub-datasets. + + Each sub-dataset records its deferred preprocessing CUDA events on + itself, so the DataLoader retrieves them through the + :class:`MultiDataset` by gathering from every constituent. + + Returns + ------- + list + CUDA events recorded by the sub-datasets since the last pop. + """ + collected: list = [] + for ds in self._datasets: + collected.extend(ds._pop_events()) + return collected + def prefetch( self, index: int, diff --git a/physicsnemo/datapipes/protocols.py b/physicsnemo/datapipes/protocols.py index 69d4363f3b..d5dc61b49e 100644 --- a/physicsnemo/datapipes/protocols.py +++ b/physicsnemo/datapipes/protocols.py @@ -15,16 +15,26 @@ # limitations under the License. """ -Base class for dataset components. +Base classes for dataset components. -Provides :class:`DatasetBase`, an ABC that owns the thread-based prefetch -infrastructure shared by :class:`Dataset`, :class:`MeshDataset`, and any -future dataset implementations. The user-facing extension points are -**Readers** and **Transforms**, not dataset subclasses. +Provides two abstractions consumed by :class:`~physicsnemo.datapipes.DataLoader`: + +- :class:`DatasetBase` -- map-style datasets. Owns the thread-based + prefetch infrastructure (a producer/consumer split plus a FIFO + ``submit``/``consume`` primitive) shared by :class:`Dataset`, + :class:`MeshDataset`, and any future implementation. +- :class:`IterableDatasetBase` -- generator-style datasets that produce + data directly on the main thread (online simulation and other + stream-sensitive workloads). No prefetch, no length, no indexing. + +The user-facing extension points are **Readers** and **Transforms**, not +dataset subclasses. """ from __future__ import annotations +import contextlib +import threading from abc import ABC, abstractmethod from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass, field @@ -32,62 +42,393 @@ import torch +from physicsnemo.core.function_spec import warp_stream_from_torch + +try: + import warp as wp + + _HAS_WARP = True +except ImportError: # pragma: no cover - warp is normally installed + wp = None + _HAS_WARP = False + + +@contextlib.contextmanager +def preprocessing_stream(stream: Optional["torch.cuda.Stream"]): + """Bind torch (and Warp) to *stream* for the host-to-device + transforms. + + Within the block both torch's current stream and -- when Warp is + installed -- Warp's current stream are set to *stream*, so Warp kernels + launched by transforms run on the same stream the data was copied on. + The single launching thread is the real Warp invariant; the stream is + free, but torch and Warp must agree on which one. A ``None`` stream is + a no-op (run on the current stream). + + Parameters + ---------- + stream : torch.cuda.Stream, optional + Stream to bind, or ``None`` to run on the current stream. + """ + if stream is None: + yield + return + with torch.cuda.stream(stream): + if _HAS_WARP: + # Use the cached wrapper: a fresh stream_from_torch per call would + # register/unregister the same CUDA handle on every consume and + # collide with the inner wrapper a Warp functional launch creates + # for the same stream, corrupting it (illegal memory access). + with wp.ScopedStream(warp_stream_from_torch(stream)): + yield + else: + yield + + +def record_stream(obj: Any, stream: "torch.cuda.Stream") -> None: + """Tag *obj*'s device tensors with *stream* for the caching allocator. + + Recurses into ``TensorDict``/tensor/mesh objects (which expose + ``record_stream``) and plain containers, so memory allocated on a + preprocessing stream is not recycled while the compute stream reads + it. Only CUDA tensors are tagged; CPU tensors (``record_stream`` is + unimplemented on the CPU backend) and objects without device memory + are skipped. + + Parameters + ---------- + obj : Any + Item to tag (tensor, TensorDict, mesh, dict, or sequence). + stream : torch.cuda.Stream + Stream that will consume the memory. + """ + if isinstance(obj, torch.Tensor): + # record_stream is a CUDA-only caching-allocator hint; it is + # unimplemented on the CPU backend, so only tag CUDA tensors. + if obj.is_cuda: + obj.record_stream(stream) + return + record = getattr(obj, "record_stream", None) + if callable(record): + device = getattr(obj, "device", None) + device_type = getattr(device, "type", None) + if device_type == "cuda": + obj.record_stream(stream) + elif device_type is None: + # Device-less container (e.g. a mixed-device TensorDict): recurse + # so CUDA leaves are tagged and CPU leaves are skipped. + values = getattr(obj, "values", None) + if callable(values): + for value in values(): + record_stream(value, stream) + # device_type == "cpu": no-op (nothing for the allocator to track) + elif isinstance(obj, dict): + for value in obj.values(): + record_stream(value, stream) + elif isinstance(obj, (list, tuple)): + for value in obj: + record_stream(value, stream) + @dataclass -class _PrefetchResult: - """Result of a stream-aware prefetch operation. +class HostPayload: + """A sample produced by the (thread-safe) I/O stage, staged on the host. - Used by :class:`Dataset` and :class:`MeshDataset` to carry data, - metadata, and an optional CUDA event through the prefetch pipeline. + A ``HostPayload`` is the boundary object between the I/O producer and + the main-thread consumer. It carries a CPU ``TensorDict`` (ideally + pinned, so the subsequent host-to-device copy can be asynchronous) + plus metadata. It is produced by a worker thread, which must not + launch device kernels (in particular Warp kernels). + + Parameters + ---------- + work_item : Any + The work item this payload was produced from -- an ``int`` index + for map-style datasets, or any opaque descriptor for + descriptor-driven sources. + data : Any, optional + Host ``TensorDict`` (or mesh) payload. ``None`` on error. + metadata : dict, optional + Per-sample metadata produced by the reader. + error : Exception, optional + Exception captured during production, re-raised on consumption. """ - index: int + work_item: Any data: Any = None metadata: Optional[dict[str, Any]] = field(default=None) error: Optional[Exception] = field(default=None) - event: Optional[torch.cuda.Event] = field(default=None) + + +@dataclass +class PrefetchHandle: + """Handle to one in-flight prefetch returned by :meth:`DatasetBase.submit`. + + Bundles the producer ``Future`` with the CUDA stream the consumer + should use for the host-to-device copy and transforms. The handle is + correlated to its sample purely by identity/order, so opaque work + items need not be hashable or unique. + + Parameters + ---------- + future : concurrent.futures.Future + Future resolving to the producer's :class:`HostPayload`. + stream : torch.cuda.Stream, optional + Stream assigned for the consume step, or ``None`` for the current + stream. + """ + + future: Future + stream: Optional[torch.cuda.Stream] = None class DatasetBase(ABC): - """Abstract base for datasets compatible with :class:`DataLoader`. + """Abstract base for map-style datasets compatible with :class:`DataLoader`. + + Subclasses implement :meth:`_load` (the synchronous data-loading + pipeline) and :meth:`__len__`. Everything else -- indexing, the + ``submit``/``consume`` prefetch primitive, the index-keyed + ``prefetch``/``__getitem__`` convenience API, cancellation, and + cleanup -- is provided here. - Subclasses implement :meth:`_load` (the actual data-loading pipeline) - and :meth:`__len__`. Everything else — ``__getitem__`` with prefetch - cache lookup, thread-pool prefetching, cancellation, cleanup — is - provided here. + Producer / consumer split + -------------------------- + Prefetching is split into two stages so that **no device kernels are + launched off the main thread** (a hard requirement for Warp + transforms, which must share the model's single launching thread): - Both :class:`Dataset` and :class:`MeshDataset` override - :meth:`prefetch` and :meth:`__getitem__` to add CUDA-stream - support via :class:`_PrefetchResult`. + - :meth:`_load_host` is the **producer**. It runs on a worker thread + and performs only thread-safe work: reading, decoding, and staging + into host memory. It returns a :class:`HostPayload`. + - :meth:`_consume` is the **consumer**. It runs on the thread that + calls :meth:`consume` / :meth:`__getitem__` (the main thread, in + practice) and performs the host-to-device transfer and device + transforms (including Warp kernels) on the assigned CUDA stream. + + :class:`Dataset` and :class:`MeshDataset` override both hooks to + perform the real split. The default implementations fall back to + running the full :meth:`_load` on the worker for any subclass that + does not override them. """ - def __init__(self, *, num_workers: int = 2) -> None: - self._prefetch_futures: dict[int, Future] = {} + def __init__( + self, + *, + num_workers: int = 2, + serialize_load_consume: bool = False, + ) -> None: self._executor: Optional[ThreadPoolExecutor] = None self._num_workers = num_workers + self._lock = threading.Lock() + self._stage_lock = threading.Lock() + self._serialize_load_consume = serialize_load_consume + # Futures still in flight, tracked so close() can drain them. + self._inflight: set[Future] = set() + # Index-keyed handles backing the prefetch()/__getitem__ compat API. + self._prefetch_handles: dict[int, PrefetchHandle] = {} + # Per-sample preprocessing CUDA events recorded by _consume when + # invoked with defer_sync=True. The compute-stream wait on these is + # deferred to the DataLoader, which inserts it right before the batch + # is yielded (i.e. after the previous batch's model work is enqueued) + # so preprocessing genuinely overlaps the prior batch's compute. + # Always accessed on the main thread so no locking is needed. + self._events_pending: list = [] + + def _pop_events(self) -> list: + """Return and clear the pending preprocessing-event list. + + Called by the DataLoader after each ``consume()`` (or batch of + consumes) made with ``defer_sync=True`` to retrieve the CUDA events + that ``_consume`` recorded but did not yet wait on. The DataLoader + inserts ``compute_stream.wait_event`` for each returned event right + before the corresponding batch is yielded. + + Returns + ------- + list + CUDA events recorded during the most recent deferred + ``consume()`` call(s). Empty when none were recorded (e.g. no + stream, CPU target, or ``defer_sync=False``). + """ + lst = self._events_pending + self._events_pending = [] + return lst @abstractmethod def _load(self, index: int) -> tuple[Any, dict[str, Any]]: """Load and return a single sample ``(data, metadata)``. - This is the hook that subclasses must implement. It is called - both synchronously (from ``__getitem__``) and asynchronously - (from the prefetch thread pool). + This is the synchronous, full-pipeline hook that subclasses must + implement. It is called directly from :meth:`__getitem__` when the + index was not prefetched. """ ... @abstractmethod def __len__(self) -> int: ... + # ------------------------------------------------------------------ + # Producer / consumer hooks (overridden by Dataset / MeshDataset) + # ------------------------------------------------------------------ + + def _load_host(self, work_item: Any) -> HostPayload: + """Producer stage: load *work_item* into a :class:`HostPayload`. + + Runs on a worker thread and must not launch device kernels. The + default implementation runs the full :meth:`_load` pipeline for + backward compatibility; subclasses that use device transforms + override this to stop at host staging. + + Parameters + ---------- + work_item : Any + Work item to load (an ``int`` index by default). + + Returns + ------- + HostPayload + Payload carrying the host data and metadata, or a captured + error. + """ + try: + data, metadata = self._load(work_item) + return HostPayload(work_item=work_item, data=data, metadata=metadata) + except Exception as e: # noqa: BLE001 + return HostPayload(work_item=work_item, error=e) + + def _load_host_guarded(self, work_item: Any) -> HostPayload: + if not self._serialize_load_consume: + return self._load_host(work_item) + with self._stage_lock: + return self._load_host(work_item) + + def _consume( + self, + payload: HostPayload, + stream: Optional[torch.cuda.Stream] = None, + *, + defer_sync: bool = False, + ) -> tuple[Any, dict[str, Any]]: + """Consumer stage: turn a :class:`HostPayload` into ``(data, metadata)``. + + Runs on the calling (main) thread. The default implementation + unwraps the payload and re-raises any captured error; the + ``stream`` and ``defer_sync`` arguments are ignored. Subclasses + override this to perform the host-to-device transfer and device + transforms. + + Parameters + ---------- + payload : HostPayload + Producer payload from :meth:`_load_host`. + stream : torch.cuda.Stream, optional + Stream for the consume step. + defer_sync : bool, default=False + When True (and a stream is used), record the preprocessing CUDA + event into :attr:`_events_pending` instead of making the compute + stream wait on it here. The DataLoader then inserts the wait + just before the batch is yielded so preprocessing overlaps the + previous batch's compute. Ignored by this default implementation. + + Returns + ------- + tuple[Any, dict[str, Any]] + The sample data and its metadata. + """ + if payload.error is not None: + raise payload.error + return payload.data, payload.metadata + + def _consume_guarded( + self, + payload: HostPayload, + stream: Optional[torch.cuda.Stream] = None, + *, + defer_sync: bool = False, + ) -> tuple[Any, dict[str, Any]]: + if not self._serialize_load_consume: + return self._consume(payload, stream, defer_sync=defer_sync) + with self._stage_lock: + return self._consume(payload, stream, defer_sync=defer_sync) + + # ------------------------------------------------------------------ + # FIFO prefetch primitive (used by the DataLoader's pump) + # ------------------------------------------------------------------ + + def submit( + self, + work_item: Any, + stream: Optional[torch.cuda.Stream] = None, + ) -> PrefetchHandle: + """Submit *work_item* for background loading and return its handle. + + Only the (thread-safe) producer stage runs on the worker pool. The + returned :class:`PrefetchHandle` is later passed to :meth:`consume` + on the main thread. Safe to call from a dispatcher thread distinct + from the consumer. + + Parameters + ---------- + work_item : Any + Work item to load. + stream : torch.cuda.Stream, optional + Stream the consumer should use for this sample. + + Returns + ------- + PrefetchHandle + Handle bundling the producer future and the assigned stream. + """ + executor = self._ensure_executor() + future = executor.submit(self._load_host_guarded, work_item) + with self._lock: + self._inflight.add(future) + future.add_done_callback(self._discard_inflight) + return PrefetchHandle(future=future, stream=stream) + + def consume( + self, handle: PrefetchHandle, *, defer_sync: bool = False + ) -> tuple[Any, dict[str, Any]]: + """Resolve a :meth:`submit` handle into ``(data, metadata)``. + + Blocks until the producer future is ready, then runs the consumer + stage on the calling thread (host-to-device transfer and device + transforms on the handle's stream). + + Parameters + ---------- + handle : PrefetchHandle + Handle returned by :meth:`submit`. + defer_sync : bool, default=False + Forwarded to :meth:`_consume`. When True, the compute-stream + wait on this sample's preprocessing event is deferred to the + caller (the DataLoader) via :attr:`_events_pending`. Standalone + callers leave this False so the wait is enqueued inline and the + result is immediately safe to use on the current stream. + + Returns + ------- + tuple[Any, dict[str, Any]] + The sample data and its metadata. + """ + payload = handle.future.result() # re-raises producer errors via _consume + return self._consume_guarded(payload, handle.stream, defer_sync=defer_sync) + # ------------------------------------------------------------------ # Concrete interface # ------------------------------------------------------------------ def __getitem__(self, index: int) -> tuple[Any, dict[str, Any]]: - """Return sample *index*, using a prefetched result when available.""" - future = self._prefetch_futures.pop(index, None) - if future is not None: - return future.result() # re-raises on error + """Return sample *index*, using a prefetched result when available. + + When the index was prefetched via :meth:`prefetch`, the pending + handle is consumed on the calling thread (so device transforms run + here, not on the worker). Otherwise the sample is loaded + synchronously. + """ + with self._lock: + handle = self._prefetch_handles.pop(index, None) + if handle is not None: + return self.consume(handle) return self._load(index) def prefetch( @@ -95,32 +436,47 @@ def prefetch( index: int, stream: Optional[torch.cuda.Stream] = None, ) -> None: - """Submit *index* for background loading in a worker thread. + """Start prefetching *index*, retrievable later via :meth:`__getitem__`. + + Index-keyed convenience wrapper around :meth:`submit`. A repeated + prefetch of an in-flight index is a no-op. - The ``stream`` parameter is accepted for interface compatibility - but ignored by the default implementation. :class:`Dataset` - overrides this to run GPU transfers on the given stream. + Parameters + ---------- + index : int + Sample index to prefetch. + stream : torch.cuda.Stream, optional + Stream the consumer should use for this sample. """ - if index in self._prefetch_futures: - return - executor = self._ensure_executor() - self._prefetch_futures[index] = executor.submit(self._load, index) + with self._lock: + if index in self._prefetch_handles: + return + handle = self.submit(index, stream) + with self._lock: + if index in self._prefetch_handles: + return # raced; the extra handle's future is reaped on completion + self._prefetch_handles[index] = handle def cancel_prefetch(self, index: Optional[int] = None) -> None: - """Discard prefetch results (already-running tasks still complete).""" - if index is None: - self._prefetch_futures.clear() - else: - self._prefetch_futures.pop(index, None) + """Discard prefetch handles (already-running tasks still complete).""" + with self._lock: + if index is None: + self._prefetch_handles.clear() + else: + self._prefetch_handles.pop(index, None) def close(self) -> None: """Drain in-flight prefetches and shut down the thread pool.""" - for future in self._prefetch_futures.values(): + with self._lock: + futures = list(self._inflight) + self._prefetch_handles.clear() + for future in futures: try: future.result(timeout=30.0) except Exception: # noqa: BLE001, S110 pass - self._prefetch_futures.clear() + with self._lock: + self._inflight.clear() if self._executor is not None: self._executor.shutdown(wait=True) @@ -130,13 +486,19 @@ def close(self) -> None: # Helpers # ------------------------------------------------------------------ + def _discard_inflight(self, future: Future) -> None: + """Done-callback: drop a finished future from the in-flight set.""" + with self._lock: + self._inflight.discard(future) + def _ensure_executor(self) -> ThreadPoolExecutor: - if self._executor is None: - self._executor = ThreadPoolExecutor( - max_workers=self._num_workers, - thread_name_prefix="datapipe_prefetch", - ) - return self._executor + with self._lock: + if self._executor is None: + self._executor = ThreadPoolExecutor( + max_workers=self._num_workers, + thread_name_prefix="datapipe_prefetch", + ) + return self._executor def __iter__(self) -> Iterator[tuple[Any, dict[str, Any]]]: for i in range(len(self)): @@ -147,3 +509,70 @@ def __enter__(self) -> "DatasetBase": def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() + + +class IterableDatasetBase(ABC): + """Abstract base for generator-style datasets driven on the main thread. + + Unlike :class:`DatasetBase`, an iterable dataset has no length and no + indexing: it produces data by iteration only. The + :class:`~physicsnemo.datapipes.DataLoader` drives it entirely on the + main thread (no worker pool), so :meth:`__iter__` may freely launch + Warp kernels and use CUDA streams -- the property that makes online + simulation safe here but unsafe on the worker-pool preload path. + + Emission modes + -------------- + - **Per-sample** (default, ``yields_batches = False``): :meth:`__iter__` + yields ``(data, metadata)`` for one sample at a time and the loader + collates ``batch_size`` of them. + - **Self-batching** (``yields_batches = True``): :meth:`__iter__` yields + a fully-formed batch already and the loader passes it through + unchanged (``batch_size``/``drop_last`` do not apply). + + Subclasses may optionally implement :meth:`set_epoch` and + :meth:`set_generator` for reproducible seeding. + """ + + # When True, __iter__ yields ready-made batches and the loader does not + # re-collate (e.g. an online simulator that produces a batch per step). + yields_batches: bool = False + + @abstractmethod + def __iter__(self) -> Iterator[Any]: + """Yield samples ``(data, metadata)`` or ready-made batches. + + Returns + ------- + Iterator + Per-sample ``(data, metadata)`` tuples, or full batches when + :attr:`yields_batches` is True. + """ + ... + + def set_epoch(self, epoch: int) -> None: + """Reseed for *epoch* (no-op by default). + + Parameters + ---------- + epoch : int + Current epoch number. + """ + + def set_generator(self, generator: torch.Generator) -> None: + """Seed the dataset's randomness from *generator* (no-op by default). + + Parameters + ---------- + generator : torch.Generator + Parent generator supplied by the DataLoader. + """ + + def __enter__(self) -> "IterableDatasetBase": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def close(self) -> None: + """Release any resources held by the dataset (no-op by default).""" diff --git a/physicsnemo/datapipes/readers/base.py b/physicsnemo/datapipes/readers/base.py index 17ce9b0b86..9ec8f2abe3 100644 --- a/physicsnemo/datapipes/readers/base.py +++ b/physicsnemo/datapipes/readers/base.py @@ -31,6 +31,8 @@ import torch from tensordict import TensorDict +from physicsnemo.datapipes._rng import spawn_generator + logger = logging.getLogger(__name__) @@ -110,6 +112,11 @@ def __init__( self.pin_memory = pin_memory self.include_index_in_metadata = include_index_in_metadata self._coordinated_subsampling_config = coordinated_subsampling + # Base seed + epoch for deterministic per-index RNG. See + # :meth:`_index_generator`. ``None`` means no seed was provided + # (random draws fall back to the global default RNG). + self._seed_base: int | None = None + self._epoch: int = 0 @abstractmethod def _load_sample(self, index: int) -> dict[str, torch.Tensor]: @@ -279,28 +286,61 @@ def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: raise RuntimeError(error_msg) from e def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible random sampling. + """Assign a base seed for reproducible, order-independent sampling. + + Stores ``generator.initial_seed()`` as the base seed. Per-sample + generators are then derived deterministically from + ``(base_seed, epoch, index)`` via :meth:`_index_generator`, so + random draws are reproducible regardless of the order in which + samples are read or which worker thread reads them. - Override in subclasses that use randomness (e.g. subsampling). - The default implementation is a no-op. + Subclasses that use randomness should draw from + :meth:`_index_generator` rather than a shared generator. Readers + with no randomness inherit this harmlessly. Parameters ---------- generator : torch.Generator - Generator to use for random draws. + Generator whose ``initial_seed()`` seeds all per-sample RNG. """ + self._seed_base = generator.initial_seed() def set_epoch(self, epoch: int) -> None: - """Reseed the reader's RNG for a new epoch. + """Set the epoch used to vary per-sample RNG deterministically. - Override in subclasses that use randomness. - The default implementation is a no-op. + The epoch is folded into each sample's derived seed (see + :meth:`_index_generator`), so each epoch produces a different but + reproducible sequence. Parameters ---------- epoch : int Current epoch number. """ + self._epoch = epoch + + def _index_generator(self, index: int) -> torch.Generator | None: + """Return a fresh generator seeded for sample *index* this epoch. + + Derives an independent :class:`torch.Generator` from + ``(base_seed, epoch, index)``. Because the seed depends only on + those values, the draw for a given sample is identical regardless + of read order or worker thread. Returns ``None`` when no base + seed has been set (the unseeded fallback). + + Parameters + ---------- + index : int + Sample index. + + Returns + ------- + torch.Generator or None + A per-sample generator, or ``None`` if no seed was provided. + """ + if self._seed_base is None: + return None + return spawn_generator(self._seed_base, self._epoch, index) def close(self) -> None: """ diff --git a/physicsnemo/datapipes/readers/mesh.py b/physicsnemo/datapipes/readers/mesh.py index 603dcf13bd..9746b19031 100644 --- a/physicsnemo/datapipes/readers/mesh.py +++ b/physicsnemo/datapipes/readers/mesh.py @@ -30,6 +30,8 @@ import torch +from physicsnemo.datapipes import _timing +from physicsnemo.datapipes._rng import spawn_generator from physicsnemo.datapipes.registry import register from physicsnemo.mesh import DomainMesh, Mesh @@ -55,7 +57,8 @@ def _contiguous_block_slice( """ if total <= k: return slice(0, total) - start = torch.randint(0, total - k + 1, (1,), generator=generator).item() + with torch.profiler.record_function("mesh_reader: randint.item() scalar readback"): + start = torch.randint(0, total - k + 1, (1,), generator=generator).item() return slice(start, start + k) @@ -188,7 +191,10 @@ def __init__( self.include_index_in_metadata = include_index_in_metadata self.subsample_n_points = subsample_n_points self.subsample_n_cells = subsample_n_cells - self._subsample_generator: torch.Generator | None = None + # Base seed + epoch for deterministic per-index RNG (see + # :meth:`set_generator`). ``None`` means unseeded. + self._seed_base: int | None = None + self._epoch: int = 0 if not self._root.exists(): raise FileNotFoundError(f"Path not found: {self._root}") @@ -212,38 +218,43 @@ def __len__(self) -> int: return len(self._paths) def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling. + """Assign a base seed for reproducible, order-independent subsampling. Called by :class:`MeshDataset` when the DataLoader provides a - seed. Replaces any previously assigned generator. + seed. Stores ``generator.initial_seed()`` as the base seed; each + sample then derives its own generator from + ``(base_seed, epoch, index)``, so subsampling is reproducible + regardless of read order or worker thread. Parameters ---------- generator : torch.Generator - Generator to use for contiguous block selection. + Generator whose ``initial_seed()`` seeds all per-sample RNG. """ - self._subsample_generator = generator + self._seed_base = generator.initial_seed() def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch. + """Set the epoch used to vary per-sample RNG deterministically. - Produces a different (but deterministic) sequence of contiguous - blocks each epoch when a generator has been assigned via - :meth:`set_generator`. + The epoch is folded into each sample's derived seed, producing a + different (but deterministic) sequence of contiguous blocks each + epoch when a base seed has been assigned via :meth:`set_generator`. """ - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) + self._epoch = epoch def __getitem__(self, index: int) -> tuple[Mesh, dict[str, Any]]: mesh = self._load_sample(index) + generator = ( + None + if self._seed_base is None + else spawn_generator(self._seed_base, self._epoch, index) + ) mesh = _subsample_mesh( mesh, self.subsample_n_cells, self.subsample_n_points, - generator=self._subsample_generator, + generator=generator, ) if self.pin_memory: @@ -286,6 +297,8 @@ def __init__( subsample_n_points: int | None = None, subsample_n_cells: int | None = None, extra_boundaries: dict[str, dict] | None = None, + drop_interior_cells: bool = False, + drop_in_file_boundaries: bool = False, ) -> None: """ Initialize the domain mesh reader. @@ -332,14 +345,41 @@ def __init__( extra_boundaries: stl_geometry: pattern: "*_single_solid.stl.pmsh" + drop_interior_cells : bool, default=False + If True, discard the interior mesh's cell connectivity (and + cell_data) immediately after load, turning it into a point + cloud. This makes ``subsample_n_points`` take the cheap + contiguous-block path instead of the expensive + ``slice_points`` remap (which allocates an ``n_points`` map + and scatter-reads the full cell array from the memmap). Use + for point-based models that consume only ``interior.points`` + and ``interior.point_data`` (e.g. GeoTransolver volume) and + never the interior tet/cell topology. Boundaries are + unaffected, so surface normals etc. still work. + drop_in_file_boundaries : bool, default=False + If True, discard the boundaries stored *in* the DomainMesh + file immediately after load (before subsampling and pinning). + ``extra_boundaries`` are added afterwards and are therefore + unaffected. Use when the model consumes only the interior + (plus any ``extra_boundaries``) and never the in-file + boundaries -- e.g. a volume pipeline whose SDF comes from an + injected STL, where the in-file car-surface boundary would + otherwise be subsampled (an expensive ``slice_points`` remap, + GIL-held, that blocks worker-thread overlap) and pinned every + sample for nothing. """ self._root = Path(path) self._pattern = pattern self.pin_memory = pin_memory self.include_index_in_metadata = include_index_in_metadata + self.drop_interior_cells = drop_interior_cells + self.drop_in_file_boundaries = drop_in_file_boundaries self.subsample_n_points = subsample_n_points self.subsample_n_cells = subsample_n_cells - self._subsample_generator: torch.Generator | None = None + # Base seed + epoch for deterministic per-index RNG (see + # :meth:`set_generator`). ``None`` means unseeded. + self._seed_base: int | None = None + self._epoch: int = 0 self._extra_boundaries = extra_boundaries or {} if not self._root.exists(): @@ -359,54 +399,91 @@ def __len__(self) -> int: return len(self._paths) def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling. + """Assign a base seed for reproducible, order-independent subsampling. Called by :class:`MeshDataset` when the DataLoader provides a - seed. Replaces any previously assigned generator. + seed. Stores ``generator.initial_seed()`` as the base seed; each + sample then derives its own generator from + ``(base_seed, epoch, index)``, so subsampling is reproducible + regardless of read order or worker thread. Parameters ---------- generator : torch.Generator - Generator to use for contiguous block selection. + Generator whose ``initial_seed()`` seeds all per-sample RNG. """ - self._subsample_generator = generator + self._seed_base = generator.initial_seed() def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch. + """Set the epoch used to vary per-sample RNG deterministically. - Produces a different (but deterministic) sequence of contiguous - blocks each epoch when a generator has been assigned via - :meth:`set_generator`. + The epoch is folded into each sample's derived seed, producing a + different (but deterministic) sequence of contiguous blocks each + epoch when a base seed has been assigned via :meth:`set_generator`. """ - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) + self._epoch = epoch def __getitem__(self, index: int) -> tuple[DomainMesh, dict[str, Any]]: - dm = self._load_sample(index) + with _timing.record("producer/domain_load"): + dm = self._load_sample(index) + + # Trim unused data before subsample/pin. Both references are lazy (no + # memmap materialization here): + # - drop_interior_cells: turn the interior into a point cloud so its + # point subsample takes the cheap contiguous-block path instead of a + # full slice_points remap + scattered reads. + # - drop_in_file_boundaries: skip the in-file boundaries entirely so we + # don't subsample (an expensive, GIL-held slice_points remap that + # starves worker-thread overlap) or pin a surface the model ignores. + if (self.drop_interior_cells and dm.interior.n_cells > 0) or ( + self.drop_in_file_boundaries and len(dm.boundary_names) > 0 + ): + interior = dm.interior + if self.drop_interior_cells and interior.n_cells > 0: + interior = Mesh( + points=interior.points, + point_data=interior.point_data, + global_data=interior.global_data, + ) + boundaries = {} if self.drop_in_file_boundaries else dm.boundaries + dm = DomainMesh( + interior=interior, + boundaries=boundaries, + global_data=dm.global_data, + ) if self.subsample_n_cells is not None or self.subsample_n_points is not None: + generator = ( + None + if self._seed_base is None + else spawn_generator(self._seed_base, self._epoch, index) + ) sub_kw = dict( n_cells=self.subsample_n_cells, n_points=self.subsample_n_points, - generator=self._subsample_generator, + generator=generator, ) - dm = DomainMesh( - interior=_subsample_mesh(dm.interior, **sub_kw), - boundaries={ + with _timing.record("producer/interior_subsample"): + interior = _subsample_mesh(dm.interior, **sub_kw) + with _timing.record("producer/boundary_subsample"): + boundaries = { name: _subsample_mesh(dm.boundaries[name], **sub_kw) for name in dm.boundary_names - }, + } + dm = DomainMesh( + interior=interior, + boundaries=boundaries, global_data=dm.global_data, ) # Load extra boundary meshes (full resolution, no subsampling). if self._extra_boundaries: - dm = self._load_extra_boundaries(dm, index) + with _timing.record("producer/stl_load"): + dm = self._load_extra_boundaries(dm, index) if self.pin_memory: - dm = dm.pin_memory() + with _timing.record("producer/pin_memory"): + dm = dm.pin_memory() metadata: dict[str, Any] = { "source_path": str(self._paths[index]), diff --git a/physicsnemo/datapipes/readers/numpy.py b/physicsnemo/datapipes/readers/numpy.py index 2f546d5916..a82e347a7e 100644 --- a/physicsnemo/datapipes/readers/numpy.py +++ b/physicsnemo/datapipes/readers/numpy.py @@ -112,7 +112,6 @@ def __init__( self.default_values = default_values or {} self.file_pattern = file_pattern self.index_key = index_key - self._subsample_generator: torch.Generator | None = None if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") @@ -168,22 +167,12 @@ def fields(self) -> list[str]: return self._user_fields return self._available_fields - def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling.""" - self._subsample_generator = generator - - def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch.""" - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) - def _select_random_sections_from_slice( self, slice_start: int, slice_stop: int, n_points: int, + generator: Optional[torch.Generator] = None, ) -> slice: """ Select a random contiguous slice from a range. @@ -196,6 +185,9 @@ def _select_random_sections_from_slice( Stop index of the available range (exclusive). n_points : int Number of points to sample. + generator : torch.Generator, optional + Per-sample generator for reproducible, order-independent + selection. ``None`` uses the global default RNG. Returns ------- @@ -219,7 +211,7 @@ def _select_random_sections_from_slice( slice_start, slice_stop - n_points + 1, (1,), - generator=self._subsample_generator, + generator=generator, ).item() return slice(start, start + n_points) @@ -228,6 +220,7 @@ def _load_from_npz( npz: np.lib.npyio.NpzFile, index: Optional[int] = None, file_path: Optional[Path] = None, + generator: Optional[torch.Generator] = None, ) -> dict[str, torch.Tensor]: """ Load data from an npz file. @@ -241,6 +234,8 @@ def _load_from_npz( None for directory mode (load entire arrays). file_path : Path, optional Path to the file (for error messages). + generator : torch.Generator, optional + Per-sample generator for reproducible coordinated subsampling. Returns ------- @@ -272,7 +267,7 @@ def _load_from_npz( if field in npz.files: array_shape = npz[field].shape[0] subsample_slice = self._select_random_sections_from_slice( - 0, array_shape, n_points + 0, array_shape, n_points, generator=generator ) break @@ -301,12 +296,17 @@ def _load_from_npz( def _load_sample(self, index: int) -> dict[str, torch.Tensor]: """Load a single sample.""" + # Derive a per-sample generator so coordinated subsampling is + # reproducible regardless of read order or worker thread. + generator = self._index_generator(index) if self._mode == "directory": file_path = self._files[index] with np.load(file_path) as npz: - return self._load_from_npz(npz, index=None, file_path=file_path) + return self._load_from_npz( + npz, index=None, file_path=file_path, generator=generator + ) else: # single - return self._load_from_npz(self._data, index=index) + return self._load_from_npz(self._data, index=index, generator=generator) def __len__(self) -> int: """Return number of samples.""" diff --git a/physicsnemo/datapipes/readers/tensorstore_zarr.py b/physicsnemo/datapipes/readers/tensorstore_zarr.py index 9bc407aea5..5ce3e6021a 100644 --- a/physicsnemo/datapipes/readers/tensorstore_zarr.py +++ b/physicsnemo/datapipes/readers/tensorstore_zarr.py @@ -155,7 +155,6 @@ def __init__( self._user_fields = fields self.default_values = default_values or {} self.group_pattern = group_pattern - self._subsample_generator: torch.Generator | None = None if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") @@ -235,17 +234,6 @@ def fields(self) -> list[str]: return self._user_fields return self._available_fields - def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling.""" - self._subsample_generator = generator - - def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch.""" - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) - def _read_attributes(self, group_path: Path) -> dict[str, Any]: """Read attributes from a Zarr group (v2 or v3).""" store_spec = {"driver": "file", "path": str(group_path)} @@ -274,6 +262,7 @@ def _select_random_sections_from_slice( slice_start: int, slice_stop: int, n_points: int, + generator: Optional[torch.Generator] = None, ) -> slice: """ Select a random contiguous slice from a range. @@ -286,6 +275,9 @@ def _select_random_sections_from_slice( Stop index of the available range (exclusive). n_points : int Number of points to sample. + generator : torch.Generator, optional + Per-sample generator for reproducible, order-independent + selection. ``None`` uses the global default RNG. Returns ------- @@ -309,13 +301,15 @@ def _select_random_sections_from_slice( slice_start, slice_stop - n_points + 1, (1,), - generator=self._subsample_generator, + generator=generator, ).item() return slice(start, start + n_points) def _load_sample(self, index: int) -> dict[str, torch.Tensor]: """Load a single sample from a Zarr group using TensorStore.""" group_path = self._groups[index] + # Per-sample generator: reproducible regardless of read order/thread. + generator = self._index_generator(index) # Read attributes (stored as tensors in sample) attributes = self._read_attributes(group_path) @@ -368,7 +362,7 @@ def _load_sample(self, index: int) -> dict[str, torch.Tensor]: if key in stores: array_shape = stores[key].shape[0] subsample_slice = self._select_random_sections_from_slice( - 0, array_shape, n_points + 0, array_shape, n_points, generator=generator ) break diff --git a/physicsnemo/datapipes/readers/zarr.py b/physicsnemo/datapipes/readers/zarr.py index a27cee5447..2d2a70cbd2 100644 --- a/physicsnemo/datapipes/readers/zarr.py +++ b/physicsnemo/datapipes/readers/zarr.py @@ -144,7 +144,6 @@ def __init__( self.group_pattern = group_pattern self._cache_stores = cache_stores self._cached_stores: dict[Path, Any] = {} # Cache for opened zarr stores - self._subsample_generator: torch.Generator | None = None if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") @@ -206,17 +205,6 @@ def fields(self) -> list[str]: return self._user_fields return self._available_fields - def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling.""" - self._subsample_generator = generator - - def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch.""" - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) - def _open_zarr_store(self, path: Path) -> Any: """ Open a zarr store, using cache if enabled. @@ -255,6 +243,7 @@ def _select_random_sections_from_slice( slice_start: int, slice_stop: int, n_points: int, + generator: Optional[torch.Generator] = None, ) -> slice: """ Select a random contiguous slice from a range. @@ -267,6 +256,9 @@ def _select_random_sections_from_slice( Stop index of the available range (exclusive). n_points : int Number of points to sample. + generator : torch.Generator, optional + Per-sample generator for reproducible, order-independent + selection. ``None`` uses the global default RNG. Returns ------- @@ -290,12 +282,14 @@ def _select_random_sections_from_slice( slice_start, slice_stop - n_points + 1, (1,), - generator=self._subsample_generator, + generator=generator, ).item() return slice(start, start + n_points) def _load_sample(self, index: int) -> dict[str, torch.Tensor]: """Load a single sample from a Zarr group.""" + # Per-sample generator: reproducible regardless of read order/thread. + generator = self._index_generator(index) if self._single_group_mode: # Single group: index into first dimension of each array group_path = self._groups[0] @@ -339,7 +333,7 @@ def _load_sample(self, index: int) -> dict[str, torch.Tensor]: else: array_shape = root[field].shape[0] subsample_slice = self._select_random_sections_from_slice( - 0, array_shape, n_points + 0, array_shape, n_points, generator=generator ) break diff --git a/physicsnemo/datapipes/transforms/mesh/transforms.py b/physicsnemo/datapipes/transforms/mesh/transforms.py index e64ff609e2..b4479d715e 100644 --- a/physicsnemo/datapipes/transforms/mesh/transforms.py +++ b/physicsnemo/datapipes/transforms/mesh/transforms.py @@ -99,7 +99,8 @@ def __init__( self.vector = vector def __call__(self, mesh: Mesh) -> Mesh: - return mesh.translate(self.vector.to(mesh.points.device)) + with torch.profiler.record_function("TranslateMesh: vector.to(device)"): + return mesh.translate(self.vector.to(mesh.points.device)) def apply_to_domain(self, domain: DomainMesh) -> DomainMesh: """Apply translation to a :class:`DomainMesh`. @@ -453,9 +454,10 @@ def __init__( def __call__(self, mesh: Mesh) -> Mesh: new_gd = mesh.global_data.clone() - new_gd.update( - self._fields.to(device=mesh.points.device, dtype=mesh.points.dtype) - ) + with torch.profiler.record_function("InjectGlobalFields: _fields.to(device)"): + new_gd.update( + self._fields.to(device=mesh.points.device, dtype=mesh.points.dtype) + ) return Mesh( points=mesh.points, cells=mesh.cells, @@ -542,8 +544,11 @@ def __call__(self, mesh: Mesh) -> Mesh: if field_name not in new_td.keys(): continue val = new_td[field_name].float() - mean = stats["mean"].to(dtype=val.dtype, device=val.device) - std = stats["std"].to(dtype=val.dtype, device=val.device) + with torch.profiler.record_function( + "NormalizeMeshFields: stats.to(device)" + ): + mean = stats["mean"].to(dtype=val.dtype, device=val.device) + std = stats["std"].to(dtype=val.dtype, device=val.device) new_td[field_name] = (val - mean) / (std + self._eps) ### `Mesh.copy` is a tensorclass-provided shallow copy: `points`, @@ -589,8 +594,11 @@ def inverse_tensor( dim = 1 if ftype == "scalar" else n_spatial_dims if name in self._stats: stats = self._stats[name] - mean = stats["mean"].to(dtype=tensor.dtype, device=tensor.device) - std = stats["std"].to(dtype=tensor.dtype, device=tensor.device) + with torch.profiler.record_function( + "NormalizeMeshFields.inverse_tensor: stats.to(device)" + ): + mean = stats["mean"].to(dtype=tensor.dtype, device=tensor.device) + std = stats["std"].to(dtype=tensor.dtype, device=tensor.device) out[..., idx : idx + dim] = ( out[..., idx : idx + dim] * (std + self._eps) + mean ) diff --git a/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py b/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py index e3d3c8461c..3a61667136 100644 --- a/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py +++ b/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py @@ -215,10 +215,12 @@ def radius_search_impl( input_dtype = points.dtype # Warp supports only fp32, so we have to cast: - if points.dtype != torch.float32: - points = points.to(torch.float32) - if queries.dtype != torch.float32: - queries = queries.to(torch.float32) + with torch.profiler.record_function("radius_search: cast points to fp32"): + if points.dtype != torch.float32: + points = points.to(torch.float32) + with torch.profiler.record_function("radius_search: cast queries to fp32"): + if queries.dtype != torch.float32: + queries = queries.to(torch.float32) # Compute follows data. wp_launch_device, wp_launch_stream = FunctionSpec.warp_launch_context(points) @@ -229,13 +231,18 @@ def radius_search_impl( wp_points_per_b = [] wp_queries_per_b = [] for b in range(B): - pts_b = points[b].contiguous() - qrs_b = queries[b].contiguous() - wp_pts_b = wp.from_torch(pts_b, dtype=wp.vec3) - wp_qrs_b = wp.from_torch(qrs_b, dtype=wp.vec3, return_ctype=True) - grid = wp.HashGrid(dim_x=128, dim_y=128, dim_z=128, device=wp_pts_b.device) - grid.reserve(N_queries) - grid.build(points=wp_pts_b, radius=0.5 * radius) + with torch.profiler.record_function(f"radius_search: contiguous b={b}"): + pts_b = points[b].contiguous() + qrs_b = queries[b].contiguous() + with torch.profiler.record_function(f"radius_search: from_torch b={b}"): + wp_pts_b = wp.from_torch(pts_b, dtype=wp.vec3) + wp_qrs_b = wp.from_torch(qrs_b, dtype=wp.vec3, return_ctype=True) + with torch.profiler.record_function(f"radius_search: HashGrid build b={b}"): + grid = wp.HashGrid( + dim_x=128, dim_y=128, dim_z=128, device=wp_pts_b.device + ) + grid.reserve(N_queries) + grid.build(points=wp_pts_b, radius=0.5 * radius) grids.append(grid) wp_points_per_b.append(wp_pts_b) wp_queries_per_b.append(wp_qrs_b) @@ -323,21 +330,30 @@ def radius_search_impl( # Deterministic output path: always use batched 2D kernel launch # --------------------------------------------------------------- - # Build warp array of grid IDs - grid_ids_tensor = torch.tensor( - [g.id for g in grids], dtype=torch.int64, device=points.device - ) + # Build warp array of grid IDs. + # Construct in pinned host memory first so the H2D copy is + # stream-ordered (non_blocking) rather than a blocking cudaMemcpy. + with torch.profiler.record_function("radius_search: grid_ids host->device"): + _grid_ids_cpu = torch.tensor( + [g.id for g in grids], + dtype=torch.int64, + pin_memory=torch.cuda.is_available(), + ) + grid_ids_tensor = _grid_ids_cpu.to(points.device, non_blocking=True) wp_grid_ids = wp.from_torch( grid_ids_tensor, dtype=wp.uint64, return_ctype=True ) # Convert batched points/queries to warp 2D arrays - wp_points_2d = wp.from_torch( - points.contiguous(), dtype=wp.vec3, return_ctype=True - ) - wp_queries_2d = wp.from_torch( - queries.contiguous(), dtype=wp.vec3, return_ctype=True - ) + with torch.profiler.record_function( + "radius_search: contiguous points/queries 2D" + ): + wp_points_2d = wp.from_torch( + points.contiguous(), dtype=wp.vec3, return_ctype=True + ) + wp_queries_2d = wp.from_torch( + queries.contiguous(), dtype=wp.vec3, return_ctype=True + ) # Allocate outputs with batch dimension indices = torch.full( @@ -407,8 +423,9 @@ def radius_search_impl( pts_out = pts_out.squeeze(0) # Handle the matrix of return values: - pts_out = pts_out.to(input_dtype) - dists_out = dists_out.to(input_dtype) + with torch.profiler.record_function("radius_search: cast outputs to input_dtype"): + pts_out = pts_out.to(input_dtype) + dists_out = dists_out.to(input_dtype) return indices, pts_out, dists_out, num_neighbors diff --git a/test/datapipes/core/test_dataset.py b/test/datapipes/core/test_dataset.py index c6f382725c..f46424559d 100644 --- a/test/datapipes/core/test_dataset.py +++ b/test/datapipes/core/test_dataset.py @@ -742,3 +742,60 @@ def test_gpu_iteration_with_transforms(self, numpy_data_dir): assert data["positions"].device.type == "cuda" torch.cuda.synchronize() + + +class TestDatasetSubsamplingReproducibility: + """Reader subsampling RNG is reproducible across the threaded path. + + Reader subsampling derives its generator per ``(base_seed, epoch, + index)``, so the multi-worker ``prefetch`` path matches synchronous + loading and two seeded runs agree -- even with ``num_workers > 1``. + """ + + def _make_dataset(self, data_dir, seed: int) -> dp.Dataset: + reader = dp.NumpyReader( + data_dir, + file_pattern="sample_*.npz", + fields=["positions", "features"], + coordinated_subsampling={ + "n_points": 50, + "target_keys": ["positions", "features"], + }, + ) + dataset = dp.Dataset(reader, num_workers=2) + dataset.set_generator(torch.Generator().manual_seed(seed)) + return dataset + + def test_prefetch_matches_synchronous(self, numpy_data_dir): + """Threaded prefetch (num_workers=2) matches synchronous loading.""" + indices = list(range(10)) + + sync_ds = self._make_dataset(numpy_data_dir, seed=2024) + expected = {i: sync_ds._load(i)[0]["positions"] for i in indices} + + async_ds = self._make_dataset(numpy_data_dir, seed=2024) + for i in indices: + async_ds.prefetch(i) + actual = {i: async_ds[i][0]["positions"] for i in indices} + async_ds.close() + + for i in indices: + assert torch.equal(actual[i], expected[i]) + + def test_two_seeded_runs_identical(self, numpy_data_dir): + """Two datasets seeded identically yield identical prefetch results.""" + indices = list(range(10)) + + ds_a = self._make_dataset(numpy_data_dir, seed=99) + ds_b = self._make_dataset(numpy_data_dir, seed=99) + for i in indices: + ds_a.prefetch(i) + ds_b.prefetch(i) + + for i in indices: + a = ds_a[i][0]["positions"] + b = ds_b[i][0]["positions"] + assert torch.equal(a, b) + + ds_a.close() + ds_b.close() diff --git a/test/datapipes/core/test_streaming.py b/test/datapipes/core/test_streaming.py new file mode 100644 index 0000000000..0e211671ab --- /dev/null +++ b/test/datapipes/core/test_streaming.py @@ -0,0 +1,617 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the lazy preload path and the iterable (generator) dataset path. + +Stage 1 covers the lazy, FIFO-handle preload path (IOPump laziness, +``BATCH_BOUNDARY`` reassembly, opaque work items, the ``submit``/``consume`` +primitive). Stage 2 covers iterable datasets driven main-thread-only +(finite/capped-infinite generators, ``drop_last``, self-batching +pass-through, reproducibility, and no worker pool). CUDA-guarded tests +exercise stream-bound preprocessing and Warp-on-a-non-default-stream. +""" + +from __future__ import annotations + +import threading +import time + +import numpy as np +import pytest +import torch +from tensordict import TensorDict + +import physicsnemo.datapipes as dp +from physicsnemo.datapipes.io_pump import BATCH_BOUNDARY, IOPump +from physicsnemo.datapipes.protocols import DatasetBase, IterableDatasetBase + +# ============================================================================ +# Stage 1 -- IOPump (lazy, FIFO, batch boundaries) +# ============================================================================ + + +class TestIOPump: + """Tests for the lazy, self-driving prefetch pump.""" + + def test_lazy_bounded_pull_on_infinite_source(self): + """Pump pulls an unbounded source lazily, bounded by depth.""" + pulled: list[int] = [] + + def source(): + i = 0 + while True: + pulled.append(i) + yield i + i += 1 + + depth = 3 + pump = IOPump(source(), lambda x: x, depth=depth) + out = [] + for item in pump: + out.append(item) + if len(out) == 5: + break + pump.stop() + + assert out == [0, 1, 2, 3, 4] + # The dispatcher must not have run far ahead of what was consumed: + # at most consumed + depth + a small slack for the in-flight pull. + assert len(pulled) <= 5 + depth + 2 + + def test_batch_boundary_reassembly_irregular(self): + """Boundaries delimit dynamically-sized batches without slot use.""" + source = [0, 1, BATCH_BOUNDARY, 2, BATCH_BOUNDARY, 3, 4, 5, BATCH_BOUNDARY] + pump = IOPump(iter(source), lambda x: x, depth=2) + + batches: list[list[int]] = [] + current: list[int] = [] + for item in pump: + if item is BATCH_BOUNDARY: + batches.append(current) + current = [] + else: + current.append(item) + pump.stop() + + assert batches == [[0, 1], [2], [3, 4, 5]] + + def test_fifo_order_preserved(self): + """Handles are yielded in the order work items were pulled.""" + pump = IOPump(iter(range(20)), lambda x: x * 10, depth=4) + out = list(pump) + pump.stop() + assert out == [x * 10 for x in range(20)] + + def test_dispatch_error_surfaces_not_hangs(self): + """A dispatch exception is raised on the consumer, never a hang.""" + + def boom(x): + raise RuntimeError("dispatch failed") + + pump = IOPump(iter(range(5)), boom, depth=2) + with pytest.raises(RuntimeError, match="dispatch failed"): + list(pump) + pump.stop() + + def test_source_error_surfaces_not_hangs(self): + """A failing source is raised on the consumer, never a hang.""" + + def source(): + yield 0 + raise ValueError("source failed") + + pump = IOPump(source(), lambda x: x, depth=2) + with pytest.raises(ValueError, match="source failed"): + list(pump) + pump.stop() + + +# ============================================================================ +# Stage 1 -- submit / consume FIFO primitive with opaque work items +# ============================================================================ + + +class _DescriptorDataset(DatasetBase): + """Map-style dataset keyed by an opaque (non-int) descriptor.""" + + def __init__(self): + super().__init__(num_workers=2) + self._store = {"alpha": 1.0, "beta": 2.0, "gamma": 3.0} + + def _load(self, key): + if key == "explode": + raise KeyError("no such key") + return TensorDict({"x": torch.tensor([self._store[key]])}), {"key": key} + + def __len__(self): + return len(self._store) + + +class _StageLockedDataset(DatasetBase): + """Dataset that records whether worker load overlaps consume.""" + + def __init__(self): + super().__init__(num_workers=1, serialize_load_consume=True) + self.load_entries: list[int] = [] + self.consume_started = threading.Event() + self.release_consume = threading.Event() + + def _load(self, index): + return TensorDict({"x": torch.tensor([float(index)])}), {"index": index} + + def _load_host(self, work_item): + self.load_entries.append(work_item) + return super()._load_host(work_item) + + def _consume(self, payload, stream=None, *, defer_sync=False): + self.consume_started.set() + self.release_consume.wait(timeout=5.0) + return super()._consume(payload, stream, defer_sync=defer_sync) + + def __len__(self): + return 2 + + +class TestSubmitConsume: + """Tests for the FIFO submit/consume primitive.""" + + def test_opaque_descriptor_roundtrip(self): + """submit/consume works with non-int, string work items.""" + ds = _DescriptorDataset() + try: + handle = ds.submit("beta") + data, metadata = ds.consume(handle) + assert metadata["key"] == "beta" + assert data["x"].item() == 2.0 + finally: + ds.close() + + def test_submit_consume_fifo_independent_of_value(self): + """Multiple in-flight handles consume to their own results.""" + ds = _DescriptorDataset() + try: + handles = [ds.submit(k) for k in ("alpha", "beta", "gamma")] + keys = [ds.consume(h)[1]["key"] for h in handles] + assert keys == ["alpha", "beta", "gamma"] + finally: + ds.close() + + def test_producer_error_reraised_on_consume(self): + """An error raised in the producer surfaces on consume.""" + ds = _DescriptorDataset() + try: + handle = ds.submit("explode") + with pytest.raises(KeyError): + ds.consume(handle) + finally: + ds.close() + + def test_stage_lock_prevents_load_consume_overlap(self): + """Opt-in stage lock keeps worker loads out of active consume.""" + ds = _StageLockedDataset() + try: + first = ds.submit(0) + first.future.result(timeout=5.0) + + consumer = threading.Thread(target=ds.consume, args=(first,)) + consumer.start() + assert ds.consume_started.wait(timeout=5.0) + + second = ds.submit(1) + time.sleep(0.1) + assert ds.load_entries == [0] + + ds.release_consume.set() + consumer.join(timeout=5.0) + second.future.result(timeout=5.0) + assert ds.load_entries == [0, 1] + finally: + ds.release_consume.set() + ds.close() + + +# ============================================================================ +# Stage 1 -- DataLoader laziness over the sampler +# ============================================================================ + + +class _CountingSampler: + """Sequential sampler that records how many indices it has yielded.""" + + def __init__(self, n): + self.n = n + self.consumed = 0 + + def __iter__(self): + self.consumed = 0 + for i in range(self.n): + self.consumed += 1 + yield i + + def __len__(self): + return self.n + + +class TestDataLoaderLazyPreload: + """The preload path must not materialize the whole epoch up front.""" + + def test_sampler_not_fully_drained_on_early_break(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + sampler = _CountingSampler(10) + # Keep the prefetch window small (single stream, one batch ahead) so + # the bound is well under the epoch; with the default num_streams the + # in-flight depth would cover this tiny 10-sample dataset entirely. + loader = dp.DataLoader( + dataset, + batch_size=2, + sampler=sampler, + prefetch_factor=1, + num_streams=1, + ) + + first = next(iter(loader)) + assert first["positions"].shape[0] == 2 + # Only a bounded prefix of the sampler should have been consumed, + # never the full epoch, after pulling a single batch. + assert sampler.consumed < 10 + + def test_preload_matches_sequential_order(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader( + dataset, batch_size=3, shuffle=False, collate_metadata=True + ) + indices = [] + for _batch, metadata_list in loader: + indices.extend(m["index"] for m in metadata_list) + assert indices == list(range(10)) + + +# ============================================================================ +# Stage 2 -- iterable datasets +# ============================================================================ + + +class _RangeIterable(IterableDatasetBase): + """Finite per-sample generator yielding (TensorDict, metadata).""" + + def __init__(self, n, dim=4): + self.n = n + self.dim = dim + + def __iter__(self): + for i in range(self.n): + data = TensorDict({"x": torch.full((self.dim,), float(i))}) + yield data, {"index": i} + + +class _BatchIterable(IterableDatasetBase): + """Self-batching generator yielding ready-made batches.""" + + yields_batches = True + + def __init__(self, n_batches, batch=4, dim=3): + self.n_batches = n_batches + self.batch = batch + self.dim = dim + + def __iter__(self): + for b in range(self.n_batches): + yield TensorDict( + {"x": torch.full((self.batch, self.dim), float(b))}, + batch_size=[self.batch], + ) + + +class _SeededIterable(IterableDatasetBase): + """Per-(epoch, position) seeded generator for reproducibility tests.""" + + def __init__(self, n, base_seed=0): + self.n = n + self.base_seed = base_seed + self.epoch = 0 + + def set_epoch(self, epoch): + self.epoch = epoch + + def __iter__(self): + for position in range(self.n): + seed = int( + np.random.SeedSequence( + [self.base_seed, self.epoch, position] + ).generate_state(1)[0] + ) + g = torch.Generator().manual_seed(seed) + yield TensorDict({"x": torch.rand(3, generator=g)}), {"position": position} + + +class _ThreadRecordingIterable(IterableDatasetBase): + """Records which thread the generator runs on.""" + + def __init__(self, n): + self.n = n + self.threads = [] + + def __iter__(self): + for i in range(self.n): + self.threads.append(threading.current_thread()) + yield TensorDict({"x": torch.zeros(2)}), {"index": i} + + +class TestIterableDataLoader: + """Tests for the main-thread-only iterable (generator) path.""" + + def test_per_sample_batching(self): + loader = dp.DataLoader(_RangeIterable(10), batch_size=4) + batches = list(loader) + # 10 samples / 4 -> [4, 4, 2] + assert [b["x"].shape[0] for b in batches] == [4, 4, 2] + + def test_per_sample_drop_last(self): + loader = dp.DataLoader(_RangeIterable(10), batch_size=4, drop_last=True) + batches = list(loader) + # Trailing partial batch dropped -> [4, 4] + assert [b["x"].shape[0] for b in batches] == [4, 4] + + def test_self_batching_passthrough(self): + # The loader batch_size is intentionally different from the generator's + # to prove it is ignored for self-batching datasets. + loader = dp.DataLoader(_BatchIterable(3, batch=5), batch_size=2) + batches = list(loader) + assert len(batches) == 3 + assert all(b["x"].shape[0] == 5 for b in batches) + + def test_len_raises_for_iterable(self): + loader = dp.DataLoader(_RangeIterable(10), batch_size=2) + with pytest.raises(TypeError): + len(loader) + + def test_capped_infinite_consumes_without_length(self): + """A long generator is iterated batch-by-batch; len() is never used.""" + + class _BigIterable(IterableDatasetBase): + def __iter__(self): + i = 0 + while i < 10_000: + yield TensorDict({"x": torch.zeros(2)}), {"index": i} + i += 1 + + loader = dp.DataLoader(_BigIterable(), batch_size=4) + seen = 0 + for _batch in loader: + seen += 1 + if seen == 3: + break + assert seen == 3 + + def test_shuffle_warns_for_iterable(self): + with pytest.warns(UserWarning, match="ignored for iterable"): + dp.DataLoader(_RangeIterable(4), batch_size=2, shuffle=True) + + def test_reproducible_across_runs_distinct_across_epochs(self): + loader = dp.DataLoader(_SeededIterable(6), batch_size=3) + + loader.set_epoch(0) + run_a = torch.cat([b["x"].reshape(-1) for b in loader]) + loader.set_epoch(0) + run_b = torch.cat([b["x"].reshape(-1) for b in loader]) + loader.set_epoch(1) + run_c = torch.cat([b["x"].reshape(-1) for b in loader]) + + assert torch.equal(run_a, run_b) # same epoch -> identical + assert not torch.equal(run_a, run_c) # different epoch -> distinct + + def test_runs_on_main_thread_no_worker_pool(self): + dataset = _ThreadRecordingIterable(4) + + names_before = {t.name for t in threading.enumerate()} + loader = dp.DataLoader(dataset, batch_size=2) + _ = list(loader) + names_after = {t.name for t in threading.enumerate()} + + # Generation happened on the main thread only. + assert dataset.threads, "generator did not run" + assert all(t is threading.main_thread() for t in dataset.threads) + # No prefetch worker pool / pump thread was spawned for this path. + new_threads = names_after - names_before + assert not any( + n.startswith("datapipe_prefetch") or n == "datapipe_pump" + for n in new_threads + ) + + +# ============================================================================ +# CUDA-guarded -- stream-bound consume and Warp-on-non-default-stream +# ============================================================================ + + +class TestStreamBoundConsume: + """Preprocessing on an assigned stream (the default-stream workaround + is gone, so transforms run on the side stream).""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_submit_consume_on_side_stream(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset( + reader, + device="cuda:0", + transforms=dp.SubsamplePoints( + input_keys=["positions", "features"], n_points=50 + ), + ) + try: + stream = torch.cuda.Stream() + handle = dataset.submit(0, stream=stream) + data, _metadata = dataset.consume(handle) + torch.cuda.synchronize() + assert data["positions"].device.type == "cuda" + assert data["positions"].shape[0] == 50 + finally: + dataset.close() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_dataloader_streams_match_synchronous(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + + ref = dp.Dataset(reader, device="cuda:0") + ref_loader = dp.DataLoader(ref, batch_size=2, shuffle=False, prefetch_factor=0) + expected = [b["positions"].sum().item() for b in ref_loader] + + reader2 = dp.NumpyReader(numpy_data_dir, pin_memory=True) + streamed = dp.Dataset(reader2, device="cuda:0") + loader = dp.DataLoader( + streamed, + batch_size=2, + shuffle=False, + prefetch_factor=2, + num_streams=4, + use_streams=True, + ) + got = [b["positions"].sum().item() for b in loader] + torch.cuda.synchronize() + assert got == pytest.approx(expected, rel=1e-5) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_prefetch_defers_wait_until_after_previous_batch( + self, numpy_data_dir, monkeypatch + ): + """The compute-stream wait for batch N+1 must be enqueued *after* + batch N is yielded, not during the lookahead consume of N+1. + + This is the overlap-correctness invariant: if the wait for the next + batch's preprocessing landed on the compute stream before the current + batch's model work, the model would block on the next batch's + preprocessing (no overlap). With the deferred wait, the compute-stream + order is ``..., model_{N-1}, wait(prep_N), model_N, ...`` so a batch's + preprocessing overlaps the previous batch's compute. + + We spy on ``Stream.wait_event`` and record the interleaving of waits + and yields: the wait for batch 0 must precede yielding batch 0, while + the wait for batch 1 must come *after* batch 0 is yielded. + """ + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset(reader, device="cuda:0") + loader = dp.DataLoader( + dataset, + batch_size=1, + shuffle=False, + prefetch_factor=2, + num_streams=4, + use_streams=True, + ) + + order: list[str] = [] + stream_cls = type(torch.cuda.current_stream()) + real_wait = stream_cls.wait_event + + def spy_wait(self, event): + order.append("wait") + return real_wait(self, event) + + monkeypatch.setattr(stream_cls, "wait_event", spy_wait) + + try: + for i, _batch in enumerate(loader): + order.append(f"yield{i}") + if i >= 2: + break + finally: + torch.cuda.synchronize() + + wait_indices = [k for k, ev in enumerate(order) if ev == "wait"] + assert len(wait_indices) >= 2, f"expected >=2 waits, got order={order}" + yield0_idx = order.index("yield0") + # Batch 0's preprocessing is gated before batch 0 is handed out. + assert wait_indices[0] < yield0_idx, order + # Batch 1's preprocessing wait is deferred until after batch 0 is + # yielded (the regression guard: the old code waited inline during the + # lookahead consume of batch 1, before batch 0 was ever yielded). + assert wait_indices[1] > yield0_idx, order + + +class TestWarpIterableOnStream: + """Warp launches on a non-default stream from the main thread are safe.""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_darcy_online_simulation_through_iterable_loader(self): + from physicsnemo.datapipes.benchmarks.darcy import Darcy2D + + class _DarcyIterable(IterableDatasetBase): + yields_batches = True + + def __init__(self, num_batches): + self._sim = Darcy2D(resolution=32, batch_size=2, device="cuda") + self._num_batches = num_batches + + def __iter__(self): + sim_iter = iter(self._sim) + for _ in range(self._num_batches): + yield next(sim_iter) + + loader = dp.DataLoader(_DarcyIterable(2), use_streams=True) + batches = list(loader) + torch.cuda.synchronize() # surfaces any illegal-memory-access + assert len(batches) == 2 + for batch in batches: + assert batch["permeability"].device.type == "cuda" + assert batch["darcy"].device.type == "cuda" + + +class TestWarpFunctionalTransformOnStreams: + """A Warp ``FunctionSpec`` transform driven through the multi-stream + preload path. The functional binds the current torch stream as a Warp + stream internally; the loader binds the same stream around the consume. + Both must reuse one cached wrapper -- otherwise the inner wrapper + unregisters the shared stream on teardown and the next launch faults + (illegal memory access).""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_functional_warp_transform_multi_stream(self, numpy_data_dir): + from physicsnemo.datapipes.transforms.base import Transform + from physicsnemo.nn.functional import signed_distance_field + + class _SDFTransform(Transform): + """Evaluate an SDF (a Warp functional) against the sample points.""" + + def __call__(self, data): + points = data["positions"].reshape(-1, 3).float() + vertices = torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + device=points.device, + ) + faces = torch.tensor([[0, 1, 2]], device=points.device) + sdf, _ = signed_distance_field(vertices, faces, points) + data["sdf"] = sdf.reshape(-1, 1) + return data + + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset(reader, device="cuda:0", transforms=_SDFTransform()) + loader = dp.DataLoader( + dataset, + batch_size=1, + shuffle=False, + prefetch_factor=2, + num_streams=4, + use_streams=True, + ) + # Iterate well past num_streams so every stream is reused at least + # once; a churned registration faults on the second pass. + batches = list(loader) + torch.cuda.synchronize() # surfaces any illegal-memory-access + assert len(batches) == 10 + for batch in batches: + assert batch["sdf"].device.type == "cuda" diff --git a/test/datapipes/readers/test_numpy_consolidated.py b/test/datapipes/readers/test_numpy_consolidated.py index c564ef5534..884712f8ce 100644 --- a/test/datapipes/readers/test_numpy_consolidated.py +++ b/test/datapipes/readers/test_numpy_consolidated.py @@ -239,6 +239,110 @@ def test_supports_coordinated_subsampling(self): assert reader_with_config._coordinated_subsampling_config is not None +class TestNumpyReaderSubsamplingRNG: + """Order- and thread-independent reproducibility of subsampling RNG. + + Reader subsampling derives its generator from ``(base_seed, epoch, + index)``, so a given sample's draw is identical regardless of read + order or worker thread (the threaded producer path), and varies + deterministically per index and per epoch. + """ + + def setup_method(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.temp_path = Path(self.temp_dir) + + self.n_samples = 6 + self.n_points = 1000 + self.subsample_points = 100 + for i in range(self.n_samples): + coords = np.random.randn(self.n_points, 3).astype(np.float32) + features = np.random.randn(self.n_points, 4).astype(np.float32) + np.savez( + self.temp_path / f"sample_{i:03d}.npz", + coords=coords, + features=features, + ) + + def teardown_method(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def _make_reader(self) -> NumpyReader: + return NumpyReader( + self.temp_path, + file_pattern="sample_*.npz", + fields=["coords", "features"], + coordinated_subsampling={ + "n_points": self.subsample_points, + "target_keys": ["coords", "features"], + }, + ) + + def test_subsample_order_independent(self): + """Reading an index gives the same result regardless of read order.""" + reader = self._make_reader() + reader.set_generator(torch.Generator().manual_seed(123)) + + forward = {i: reader[i][0]["coords"] for i in range(self.n_samples)} + + reader.set_generator(torch.Generator().manual_seed(123)) + reverse = {i: reader[i][0]["coords"] for i in reversed(range(self.n_samples))} + + for i in range(self.n_samples): + assert torch.equal(forward[i], reverse[i]) + + def test_subsample_thread_independent(self): + """Concurrent reads match single-threaded reads for each index.""" + from concurrent.futures import ThreadPoolExecutor + + reader = self._make_reader() + reader.set_generator(torch.Generator().manual_seed(7)) + + serial = {i: reader[i][0]["coords"] for i in range(self.n_samples)} + + indices = list(range(self.n_samples)) * 3 + with ThreadPoolExecutor(max_workers=4) as pool: + results = list(pool.map(lambda i: (i, reader[i][0]["coords"]), indices)) + + for i, coords in results: + assert torch.equal(coords, serial[i]) + + def test_subsample_distinct_across_indices(self): + """Different indices yield different subsamples under one seed.""" + reader = self._make_reader() + reader.set_generator(torch.Generator().manual_seed(0)) + + a = reader[0][0]["coords"] + b = reader[1][0]["coords"] + assert not torch.equal(a, b) + + def test_subsample_epoch_changes_output(self): + """Epoch is folded into the seed, changing the draw deterministically.""" + reader = self._make_reader() + + reader.set_generator(torch.Generator().manual_seed(0)) + reader.set_epoch(0) + e0 = reader[0][0]["coords"] + + reader.set_generator(torch.Generator().manual_seed(0)) + reader.set_epoch(1) + e1 = reader[0][0]["coords"] + assert not torch.equal(e0, e1) + + # Re-deriving epoch 0 reproduces the original draw. + reader.set_generator(torch.Generator().manual_seed(0)) + reader.set_epoch(0) + assert torch.equal(reader[0][0]["coords"], e0) + + def test_unseeded_does_not_raise(self): + """Without a seed, subsampling falls back to the global RNG.""" + reader = self._make_reader() + data, _ = reader[0] + assert data["coords"].shape == (self.subsample_points, 3) + + class TestNumpyReaderMemoryManagement: """Test memory management and cleanup.""" diff --git a/test/datapipes/transforms/test_mesh_augmentations.py b/test/datapipes/transforms/test_mesh_augmentations.py index 90ddedf89a..b04b55fa3f 100644 --- a/test/datapipes/transforms/test_mesh_augmentations.py +++ b/test/datapipes/transforms/test_mesh_augmentations.py @@ -732,14 +732,15 @@ def test_mesh_dataset_set_generator_distributes(self, tmp_path): master = torch.Generator().manual_seed(42) ds.set_generator(master) - # Reader and both transforms should have received generators - assert reader._subsample_generator is not None + # Reader (base seed) and both transforms should have received seeds + assert reader._seed_base is not None assert transforms[0]._generator is not None assert transforms[1]._generator is not None - # Generators should have different seeds (independent forks) + # Seeds should differ (independent forks): the reader's base seed + # plus each transform's generator seed. seeds = { - reader._subsample_generator.initial_seed(), + reader._seed_base, transforms[0]._generator.initial_seed(), transforms[1]._generator.initial_seed(), } From 98df143d61010b582661e152d62fb8e5edda96af Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 22 Jun 2026 21:42:15 +0000 Subject: [PATCH 02/10] Update datasets for streaming data, simulation-like datasets, and aggressive IO prefetching --- .../tutorial_5_iterable_online_simulation.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py b/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py index 09c14d9e90..cbd829d85a 100644 --- a/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py +++ b/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py @@ -124,12 +124,14 @@ def __iter__(self): def main() -> None: - """Run the online Darcy simulation through the PhysicsNeMo ``DataLoader``. - - Builds the iterable :class:`DarcyOnlineDataset`, drives it for a few - epochs with stream overlap enabled, and prints per-batch and per-epoch - host/CUDA timing so the generation/compute overlap is visible. No-ops - on machines without CUDA, since the Warp Darcy solver requires a GPU. + """Run the online Darcy simulation over several epochs and report timings. + + Builds an iterable :class:`DarcyOnlineDataset`, wraps it in a stream-overlapped + ``DataLoader``, and iterates for ``num_epochs``. Each epoch is reseeded via + ``set_epoch`` for a distinct, reproducible batch stream. For every batch we + record host wall-clock time and per-step CUDA event timings, then print a + per-epoch summary. Requires a CUDA device for the Warp Darcy solver; the + function returns early with a message if none is available. """ if not torch.cuda.is_available(): print("This tutorial requires a CUDA device (Warp Darcy solver). Skipping.") From 65cd32b168dd9f0316133239c0ca9b5d8443b8a8 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Fri, 26 Jun 2026 21:58:06 -0500 Subject: [PATCH 03/10] Cleaning up datapipes pr further --- .../unified_external_aero_recipe/src/loss.py | 2 +- .../src/nondim.py | 13 +- physicsnemo/datapipes/dataloader.py | 21 ++-- physicsnemo/datapipes/dataset.py | 24 ++-- physicsnemo/datapipes/io_pump.py | 6 +- physicsnemo/datapipes/mesh_dataset.py | 37 ++---- physicsnemo/datapipes/protocols.py | 112 ++---------------- .../datapipes/transforms/mesh/transforms.py | 54 +++++---- physicsnemo/mesh/transformations/geometric.py | 3 +- test/datapipes/core/test_streaming.py | 49 -------- 10 files changed, 88 insertions(+), 233 deletions(-) diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/loss.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/loss.py index b558d9c357..bc5623fc00 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/loss.py +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/loss.py @@ -114,7 +114,7 @@ def _vector_loss( target_sq = torch.mean(target**2, dim=tuple(range(pred.ndim - 1))) return torch.sum(diff_sq / (target_sq + eps)) - total = torch.tensor(0.0, device=pred.device, dtype=pred.dtype) + total = torch.zeros((), device=pred.device, dtype=pred.dtype) for i in range(n_components): p, t = pred[..., i], target[..., i] if loss_type == "huber": diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py index 15624e6a89..7f5f8bd93c 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py @@ -238,19 +238,16 @@ def _transform_mesh( T_inf=T_inf, ) - ### `Mesh.copy` is a tensorclass-provided shallow copy: `points`, - ### `cells`, the untouched associations, and the geometric `_cache` - ### are all shared with `mesh`; only the cloned association is swapped. + # Shallow copy shares everything except the swapped association. new_mesh = mesh.copy() # ty: ignore[unresolved-attribute] setattr(new_mesh, self._association, new_td) - ### Scale geometry into nondim space (`x* = x / L_ref`) on the - ### forward pass, and back to physical units (`x = x* * L_ref`) - ### on the inverse. `Mesh.scale` propagates `_cache` through the - ### linear transform. + # Scale geometry to/from nondim space (x* = x / L_ref). + # assume_invertible=True avoids a per-mesh sync from the det check. if L_ref is not None: + torch._assert_async(L_ref != 0) factor = L_ref if inverse else 1.0 / L_ref - new_mesh = new_mesh.scale(factor) + new_mesh = new_mesh.scale(factor, assume_invertible=True) return new_mesh diff --git a/physicsnemo/datapipes/dataloader.py b/physicsnemo/datapipes/dataloader.py index 914e3f34b1..7ab6931b49 100644 --- a/physicsnemo/datapipes/dataloader.py +++ b/physicsnemo/datapipes/dataloader.py @@ -40,7 +40,6 @@ DatasetBase, IterableDatasetBase, preprocessing_stream, - record_stream, ) from physicsnemo.datapipes.registry import register @@ -86,16 +85,13 @@ class DataLoader: submitting sample loads ahead of consumption, bounded by ``prefetch_factor`` batches worth of in-flight samples. The main thread is the sole consumer: it performs all host-to-device transfers - and GPU transforms (including Warp kernels) on the prefetch streams. - Warp's invariant is the single launching thread, not a single stream, - so transforms run on the assigned preprocessing stream and overlap the - compute stream. + and GPU transforms on the prefetch streams, so transforms run on the + assigned preprocessing stream and overlap the compute stream. For the pipeline to stay primed, the main thread must not block: keep reader output (optionally) pinned (so host-to-device copies are asynchronous) and - avoid host readbacks (``.item()``, ``wp.synchronize()``), data- - dependent shapes, and GIL-bound pure-Python transforms on the launch - path. + avoid host readbacks (``.item()``), data-dependent shapes, and + GIL-bound pure-Python transforms on the launch path. Examples -------- @@ -357,9 +353,9 @@ def _iter_prefetch( pool, keeping a bounded number of samples in flight regardless of the consumer's cadence. The main thread is a pure drain loop: it pulls ready handles in order, runs the per-sample consume step - (host-to-device transfer plus GPU transforms, including Warp, on - the assigned stream), and reassembles batches from the boundary - markers the pump forwards. + (host-to-device transfer plus GPU transforms on the assigned + stream), and reassembles batches from the boundary markers the + pump forwards. Stream assignment is optional and decoupled from the threaded producer: when CUDA streams are enabled a stream is round-robined @@ -536,7 +532,7 @@ def _iter_iterable( Main-thread-only iteration for generator (iterable) datasets. There is no worker pool: the dataset's generator runs on the main - thread, so it may freely launch Warp kernels / use streams. Each + thread, so it may freely launch device kernels / use streams. Each item is generated on a preprocessing stream (when streams are enabled) and handed to the compute stream via a CUDA event, so generation of the next item can overlap training on the current @@ -571,7 +567,6 @@ def _iter_iterable( except StopIteration: break if use_stream: - record_stream(item, compute_stream) event = torch.cuda.Event() event.record(prep_stream) compute_stream.wait_event(event) diff --git a/physicsnemo/datapipes/dataset.py b/physicsnemo/datapipes/dataset.py index c56a767053..5c66c46c5b 100644 --- a/physicsnemo/datapipes/dataset.py +++ b/physicsnemo/datapipes/dataset.py @@ -35,7 +35,6 @@ DatasetBase, HostPayload, preprocessing_stream, - record_stream, ) from physicsnemo.datapipes.readers.base import Reader from physicsnemo.datapipes.registry import register @@ -71,9 +70,8 @@ class Dataset(DatasetBase): host-to-device transfer and the GPU transforms on the assigned CUDA stream. - This keeps all device-kernel launches (notably Warp transforms) on - the consuming thread, which must be the same single thread the model - launches from. + This keeps all device-kernel launches on the consuming thread, which + must be the same single thread the model launches from. >>> # Start prefetching >>> dataset.prefetch(0, stream=stream0) # doctest: +SKIP @@ -296,16 +294,14 @@ def _consume( """ Consumer stage: device transfer + transforms on the calling thread. - Runs on whatever thread calls this (the main thread, so any Warp + Runs on whatever thread calls this (the main thread, so any device kernels in the transforms share the model's launching thread). When a CUDA ``stream`` is assigned, the host-to-device copy *and* the - transforms run on that preprocessing stream -- Warp bound to it via + transforms run on that preprocessing stream via :func:`preprocessing_stream` -- so this sample's preprocessing - overlaps the previous batch's training on the compute stream. The - result is tagged via ``record_stream`` so the caching allocator does - not recycle it while training reads it, and a CUDA event orders the - preprocessing before the compute stream (never a host-side - synchronize). + overlaps the previous batch's training on the compute stream. A CUDA + event orders the preprocessing before the compute stream (never a + host-side synchronize). Parameters ---------- @@ -353,10 +349,8 @@ def _consume( data = self.transforms(data) if use_stream: - # Tag the memory so the allocator keeps it alive for the compute - # stream, then record an event marking the preprocessing's - # completion on the prep stream. - record_stream(data, compute_stream) + # Record an event marking the preprocessing's completion on the + # prep stream. event = torch.cuda.Event() event.record(stream) if defer_sync: diff --git a/physicsnemo/datapipes/io_pump.py b/physicsnemo/datapipes/io_pump.py index 516ea3c8c8..06e3cf33cb 100644 --- a/physicsnemo/datapipes/io_pump.py +++ b/physicsnemo/datapipes/io_pump.py @@ -15,10 +15,10 @@ # limitations under the License. """ -IOPump - A self-driving I/O producer that keeps the pipeline primed. +IOPump - A self-driving I/O producer that keeps the IO pipeline primed. -The pump owns a dedicated dispatcher thread that pulls work items from a -source iterator *lazily* and submits each for background loading, keeping +The pump owns a dedicated dispatcher thread that lazily pulls work items from a +source iterator and submits each for background loading, keeping a *bounded* number of samples in flight at all times. Pulling lazily means the source may be arbitrarily large (or effectively unbounded): the dispatcher only advances it when a backpressure slot is free, so memory diff --git a/physicsnemo/datapipes/mesh_dataset.py b/physicsnemo/datapipes/mesh_dataset.py index 2d2e471f7e..4afa936991 100644 --- a/physicsnemo/datapipes/mesh_dataset.py +++ b/physicsnemo/datapipes/mesh_dataset.py @@ -34,7 +34,6 @@ DatasetBase, HostPayload, preprocessing_stream, - record_stream, ) from physicsnemo.datapipes.readers.mesh import DomainMeshReader, MeshReader from physicsnemo.datapipes.registry import register @@ -96,17 +95,10 @@ def __init__( num_workers : int, default=1 Number of worker threads for the prefetch pool. Worker threads run :meth:`_load_host` (disk read + pin_memory) concurrently; - GPU operations (H2D transfer, transforms, Warp kernels) always - run on the main thread in :meth:`_consume`. + GPU operations (H2D transfer, transforms) always run on the + main thread in :meth:`_consume`. """ - # _load_host (disk read + pin_memory) is host-side only and launches no - # device kernels; all GPU work (H2D transfer, transforms, Triton SDF) - # runs on the main thread in _consume and is ordered via CUDA stream - # events. With the Warp-free (torch/Triton) SDF there is no wp.Mesh - # lifetime race to guard against, so load and consume need not be - # serialized: disabling serialization lets all num_workers threads read - # in parallel and makes prefetch_factor actually scale I/O throughput. - super().__init__(num_workers=num_workers, serialize_load_consume=False) + super().__init__(num_workers=num_workers) self.reader = reader self.transforms = list(transforms) if transforms else [] self._device = torch.device(device) if isinstance(device, str) else device @@ -236,15 +228,14 @@ def _consume( ) -> tuple[Union[Mesh, DomainMesh, TensorDict], dict[str, Any]]: """Consumer stage: device transfer + transforms on the calling thread. - Runs on whatever thread calls this (the main thread, so any Warp - mesh-query kernels in the transforms share the model's launching - thread). When a CUDA ``stream`` is assigned, the host-to-device - copy *and* the transforms run on that preprocessing stream -- Warp - bound to it via :func:`preprocessing_stream` -- so this sample's - preprocessing overlaps the previous batch's training on the compute - stream. The result is tagged via ``record_stream`` and a CUDA event - orders the preprocessing before the compute stream (not a host-side - synchronize). + Runs on whatever thread calls this (the main thread, so any device + kernels in the transforms share the model's launching thread). When + a CUDA ``stream`` is assigned, the host-to-device copy *and* the + transforms run on that preprocessing stream via + :func:`preprocessing_stream` -- so this sample's preprocessing + overlaps the previous batch's training on the compute stream. A CUDA + event orders the preprocessing before the compute stream (not a + host-side synchronize). Parameters ---------- @@ -306,10 +297,8 @@ def _apply_transforms(d: Any) -> Any: data = _apply_transforms(data) if use_stream: - # Tag the memory so the allocator keeps it alive for the compute - # stream, then record an event marking the preprocessing's - # completion on the prep stream. - record_stream(data, compute_stream) + # Record an event marking the preprocessing's completion on the + # prep stream. event = torch.cuda.Event() event.record(stream) if defer_sync: diff --git a/physicsnemo/datapipes/protocols.py b/physicsnemo/datapipes/protocols.py index d5dc61b49e..c810a21e46 100644 --- a/physicsnemo/datapipes/protocols.py +++ b/physicsnemo/datapipes/protocols.py @@ -42,27 +42,15 @@ import torch -from physicsnemo.core.function_spec import warp_stream_from_torch - -try: - import warp as wp - - _HAS_WARP = True -except ImportError: # pragma: no cover - warp is normally installed - wp = None - _HAS_WARP = False - @contextlib.contextmanager def preprocessing_stream(stream: Optional["torch.cuda.Stream"]): - """Bind torch (and Warp) to *stream* for the host-to-device + transforms. + """Bind torch to *stream* for the host-to-device copy + transforms. - Within the block both torch's current stream and -- when Warp is - installed -- Warp's current stream are set to *stream*, so Warp kernels - launched by transforms run on the same stream the data was copied on. - The single launching thread is the real Warp invariant; the stream is - free, but torch and Warp must agree on which one. A ``None`` stream is - a no-op (run on the current stream). + Within the block torch's current stream is set to *stream*, so the + host-to-device copy and any device kernels launched by transforms run + on the same stream the data was copied on. A ``None`` stream is a + no-op (run on the current stream). Parameters ---------- @@ -73,60 +61,7 @@ def preprocessing_stream(stream: Optional["torch.cuda.Stream"]): yield return with torch.cuda.stream(stream): - if _HAS_WARP: - # Use the cached wrapper: a fresh stream_from_torch per call would - # register/unregister the same CUDA handle on every consume and - # collide with the inner wrapper a Warp functional launch creates - # for the same stream, corrupting it (illegal memory access). - with wp.ScopedStream(warp_stream_from_torch(stream)): - yield - else: - yield - - -def record_stream(obj: Any, stream: "torch.cuda.Stream") -> None: - """Tag *obj*'s device tensors with *stream* for the caching allocator. - - Recurses into ``TensorDict``/tensor/mesh objects (which expose - ``record_stream``) and plain containers, so memory allocated on a - preprocessing stream is not recycled while the compute stream reads - it. Only CUDA tensors are tagged; CPU tensors (``record_stream`` is - unimplemented on the CPU backend) and objects without device memory - are skipped. - - Parameters - ---------- - obj : Any - Item to tag (tensor, TensorDict, mesh, dict, or sequence). - stream : torch.cuda.Stream - Stream that will consume the memory. - """ - if isinstance(obj, torch.Tensor): - # record_stream is a CUDA-only caching-allocator hint; it is - # unimplemented on the CPU backend, so only tag CUDA tensors. - if obj.is_cuda: - obj.record_stream(stream) - return - record = getattr(obj, "record_stream", None) - if callable(record): - device = getattr(obj, "device", None) - device_type = getattr(device, "type", None) - if device_type == "cuda": - obj.record_stream(stream) - elif device_type is None: - # Device-less container (e.g. a mixed-device TensorDict): recurse - # so CUDA leaves are tagged and CPU leaves are skipped. - values = getattr(obj, "values", None) - if callable(values): - for value in values(): - record_stream(value, stream) - # device_type == "cpu": no-op (nothing for the allocator to track) - elif isinstance(obj, dict): - for value in obj.values(): - record_stream(value, stream) - elif isinstance(obj, (list, tuple)): - for value in obj: - record_stream(value, stream) + yield @dataclass @@ -137,7 +72,7 @@ class HostPayload: the main-thread consumer. It carries a CPU ``TensorDict`` (ideally pinned, so the subsequent host-to-device copy can be asynchronous) plus metadata. It is produced by a worker thread, which must not - launch device kernels (in particular Warp kernels). + launch device kernels. Parameters ---------- @@ -193,8 +128,8 @@ class DatasetBase(ABC): Producer / consumer split -------------------------- Prefetching is split into two stages so that **no device kernels are - launched off the main thread** (a hard requirement for Warp - transforms, which must share the model's single launching thread): + launched off the main thread** (device kernels must share the model's + single launching thread): - :meth:`_load_host` is the **producer**. It runs on a worker thread and performs only thread-safe work: reading, decoding, and staging @@ -202,7 +137,7 @@ class DatasetBase(ABC): - :meth:`_consume` is the **consumer**. It runs on the thread that calls :meth:`consume` / :meth:`__getitem__` (the main thread, in practice) and performs the host-to-device transfer and device - transforms (including Warp kernels) on the assigned CUDA stream. + transforms on the assigned CUDA stream. :class:`Dataset` and :class:`MeshDataset` override both hooks to perform the real split. The default implementations fall back to @@ -214,13 +149,10 @@ def __init__( self, *, num_workers: int = 2, - serialize_load_consume: bool = False, ) -> None: self._executor: Optional[ThreadPoolExecutor] = None self._num_workers = num_workers self._lock = threading.Lock() - self._stage_lock = threading.Lock() - self._serialize_load_consume = serialize_load_consume # Futures still in flight, tracked so close() can drain them. self._inflight: set[Future] = set() # Index-keyed handles backing the prefetch()/__getitem__ compat API. @@ -295,12 +227,6 @@ def _load_host(self, work_item: Any) -> HostPayload: except Exception as e: # noqa: BLE001 return HostPayload(work_item=work_item, error=e) - def _load_host_guarded(self, work_item: Any) -> HostPayload: - if not self._serialize_load_consume: - return self._load_host(work_item) - with self._stage_lock: - return self._load_host(work_item) - def _consume( self, payload: HostPayload, @@ -338,18 +264,6 @@ def _consume( raise payload.error return payload.data, payload.metadata - def _consume_guarded( - self, - payload: HostPayload, - stream: Optional[torch.cuda.Stream] = None, - *, - defer_sync: bool = False, - ) -> tuple[Any, dict[str, Any]]: - if not self._serialize_load_consume: - return self._consume(payload, stream, defer_sync=defer_sync) - with self._stage_lock: - return self._consume(payload, stream, defer_sync=defer_sync) - # ------------------------------------------------------------------ # FIFO prefetch primitive (used by the DataLoader's pump) # ------------------------------------------------------------------ @@ -379,7 +293,7 @@ def submit( Handle bundling the producer future and the assigned stream. """ executor = self._ensure_executor() - future = executor.submit(self._load_host_guarded, work_item) + future = executor.submit(self._load_host, work_item) with self._lock: self._inflight.add(future) future.add_done_callback(self._discard_inflight) @@ -411,7 +325,7 @@ def consume( The sample data and its metadata. """ payload = handle.future.result() # re-raises producer errors via _consume - return self._consume_guarded(payload, handle.stream, defer_sync=defer_sync) + return self._consume(payload, handle.stream, defer_sync=defer_sync) # ------------------------------------------------------------------ # Concrete interface @@ -518,7 +432,7 @@ class IterableDatasetBase(ABC): indexing: it produces data by iteration only. The :class:`~physicsnemo.datapipes.DataLoader` drives it entirely on the main thread (no worker pool), so :meth:`__iter__` may freely launch - Warp kernels and use CUDA streams -- the property that makes online + device kernels and use CUDA streams -- the property that makes online simulation safe here but unsafe on the worker-pool preload path. Emission modes diff --git a/physicsnemo/datapipes/transforms/mesh/transforms.py b/physicsnemo/datapipes/transforms/mesh/transforms.py index b4479d715e..8dc182c64a 100644 --- a/physicsnemo/datapipes/transforms/mesh/transforms.py +++ b/physicsnemo/datapipes/transforms/mesh/transforms.py @@ -99,8 +99,7 @@ def __init__( self.vector = vector def __call__(self, mesh: Mesh) -> Mesh: - with torch.profiler.record_function("TranslateMesh: vector.to(device)"): - return mesh.translate(self.vector.to(mesh.points.device)) + return mesh.translate(self.vector.to(mesh.points.device)) def apply_to_domain(self, domain: DomainMesh) -> DomainMesh: """Apply translation to a :class:`DomainMesh`. @@ -454,10 +453,9 @@ def __init__( def __call__(self, mesh: Mesh) -> Mesh: new_gd = mesh.global_data.clone() - with torch.profiler.record_function("InjectGlobalFields: _fields.to(device)"): - new_gd.update( - self._fields.to(device=mesh.points.device, dtype=mesh.points.dtype) - ) + new_gd.update( + self._fields.to(device=mesh.points.device, dtype=mesh.points.dtype) + ) return Mesh( points=mesh.points, cells=mesh.cells, @@ -535,6 +533,31 @@ def __init__( else: raise ValueError("Provide one of 'stats_file' or 'fields'") + def to(self, device: torch.device | str) -> "NormalizeMeshFields": + """Move internal tensors and the nested field statistics to *device*. + + Extends :meth:`MeshTransform.to` by also moving the per-field + ``mean`` and ``std`` tensors in ``self._stats`` so the per-sample + ``.to()`` in :meth:`__call__` is a no-op. + + Parameters + ---------- + device : torch.device or str + Target device. + + Returns + ------- + NormalizeMeshFields + ``self``, for chaining. + """ + # Base .to() moves only bare tensor attrs; move the nested stats too + # so the per-sample .to() in __call__ is a no-op (no H2D copy/sync). + super().to(device) + for s in self._stats.values(): + s["mean"] = s["mean"].to(self._device) + s["std"] = s["std"].to(self._device) + return self + def __call__(self, mesh: Mesh) -> Mesh: ### Clone and z-score the targeted association's TensorDict in ### place; fields absent from `_stats` (or absent from the mesh) @@ -544,11 +567,8 @@ def __call__(self, mesh: Mesh) -> Mesh: if field_name not in new_td.keys(): continue val = new_td[field_name].float() - with torch.profiler.record_function( - "NormalizeMeshFields: stats.to(device)" - ): - mean = stats["mean"].to(dtype=val.dtype, device=val.device) - std = stats["std"].to(dtype=val.dtype, device=val.device) + mean = stats["mean"].to(dtype=val.dtype, device=val.device) + std = stats["std"].to(dtype=val.dtype, device=val.device) new_td[field_name] = (val - mean) / (std + self._eps) ### `Mesh.copy` is a tensorclass-provided shallow copy: `points`, @@ -594,13 +614,9 @@ def inverse_tensor( dim = 1 if ftype == "scalar" else n_spatial_dims if name in self._stats: stats = self._stats[name] - with torch.profiler.record_function( - "NormalizeMeshFields.inverse_tensor: stats.to(device)" - ): - mean = stats["mean"].to(dtype=tensor.dtype, device=tensor.device) - std = stats["std"].to(dtype=tensor.dtype, device=tensor.device) out[..., idx : idx + dim] = ( - out[..., idx : idx + dim] * (std + self._eps) + mean + out[..., idx : idx + dim] * (stats["std"] + self._eps) + + stats["mean"] ) idx += dim return out @@ -637,9 +653,7 @@ def _inverse_field(name: str, val: torch.Tensor) -> torch.Tensor: stats = self._stats.get(name) if stats is None: return val - mean = stats["mean"].to(dtype=val.dtype, device=val.device) - std = stats["std"].to(dtype=val.dtype, device=val.device) - return val * (std + self._eps) + mean + return val * (stats["std"] + self._eps) + stats["mean"] ### ``named_apply`` is typed ``TensorDict | None`` for its ### in-place mode; the out-of-place path always returns a TD. diff --git a/physicsnemo/mesh/transformations/geometric.py b/physicsnemo/mesh/transformations/geometric.py index 2b7bddb80f..c824329f05 100644 --- a/physicsnemo/mesh/transformations/geometric.py +++ b/physicsnemo/mesh/transformations/geometric.py @@ -502,10 +502,11 @@ def transform( if matrix.shape[0] == matrix.shape[1]: det = matrix.det() + ### The runtime det test syncs (host readback of a cuda tensor). if assume_invertible is not None: is_invertible = assume_invertible else: - is_invertible = det.abs() > 1e-10 + is_invertible = bool(det.abs() > 1e-10) if is_invertible: det_sign = det.sign() diff --git a/test/datapipes/core/test_streaming.py b/test/datapipes/core/test_streaming.py index 0e211671ab..6154f3ebe5 100644 --- a/test/datapipes/core/test_streaming.py +++ b/test/datapipes/core/test_streaming.py @@ -27,7 +27,6 @@ from __future__ import annotations import threading -import time import numpy as np import pytest @@ -140,31 +139,6 @@ def __len__(self): return len(self._store) -class _StageLockedDataset(DatasetBase): - """Dataset that records whether worker load overlaps consume.""" - - def __init__(self): - super().__init__(num_workers=1, serialize_load_consume=True) - self.load_entries: list[int] = [] - self.consume_started = threading.Event() - self.release_consume = threading.Event() - - def _load(self, index): - return TensorDict({"x": torch.tensor([float(index)])}), {"index": index} - - def _load_host(self, work_item): - self.load_entries.append(work_item) - return super()._load_host(work_item) - - def _consume(self, payload, stream=None, *, defer_sync=False): - self.consume_started.set() - self.release_consume.wait(timeout=5.0) - return super()._consume(payload, stream, defer_sync=defer_sync) - - def __len__(self): - return 2 - - class TestSubmitConsume: """Tests for the FIFO submit/consume primitive.""" @@ -199,29 +173,6 @@ def test_producer_error_reraised_on_consume(self): finally: ds.close() - def test_stage_lock_prevents_load_consume_overlap(self): - """Opt-in stage lock keeps worker loads out of active consume.""" - ds = _StageLockedDataset() - try: - first = ds.submit(0) - first.future.result(timeout=5.0) - - consumer = threading.Thread(target=ds.consume, args=(first,)) - consumer.start() - assert ds.consume_started.wait(timeout=5.0) - - second = ds.submit(1) - time.sleep(0.1) - assert ds.load_entries == [0] - - ds.release_consume.set() - consumer.join(timeout=5.0) - second.future.result(timeout=5.0) - assert ds.load_entries == [0, 1] - finally: - ds.release_consume.set() - ds.close() - # ============================================================================ # Stage 1 -- DataLoader laziness over the sampler From 6de0ab523f09dc2882849ba937f9693b596d11d4 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Sat, 27 Jun 2026 09:47:28 -0500 Subject: [PATCH 04/10] Revert function spec core. --- physicsnemo/core/function_spec.py | 44 ++----------------------------- 1 file changed, 2 insertions(+), 42 deletions(-) diff --git a/physicsnemo/core/function_spec.py b/physicsnemo/core/function_spec.py index d076406a4a..945daa3887 100644 --- a/physicsnemo/core/function_spec.py +++ b/physicsnemo/core/function_spec.py @@ -29,44 +29,6 @@ from physicsnemo.core.version_check import check_version_spec -# Cache of Warp stream wrappers keyed by the underlying CUDA stream handle. -# -# ``warp.stream_from_torch`` wraps a torch-owned (external) CUDA stream, and the -# resulting ``warp.Stream`` unregisters that handle from Warp on ``__del__``. -# Creating a fresh wrapper on every launch therefore churns register/unregister -# on a shared stream; unregistering while another wrapper -- or an in-flight -# kernel -- still uses the stream corrupts it (illegal memory access). Keeping -# one long-lived wrapper per handle registers each stream exactly once. -_WARP_STREAM_CACHE: Dict[int, Any] = {} - - -def warp_stream_from_torch(torch_stream: "torch.cuda.Stream") -> Any: - """Return a cached Warp stream wrapping *torch_stream*. - - Wrapping a torch stream registers it with Warp; the wrapper unregisters it - on garbage collection. Caching one wrapper per CUDA stream handle keeps the - registration stable for the process lifetime, which is required when the - same torch stream is bound by nested Warp scopes (e.g. an outer - preprocessing scope and an inner functional launch). - - Parameters - ---------- - torch_stream : torch.cuda.Stream - Torch CUDA stream to wrap. - - Returns - ------- - warp.Stream - Cached Warp stream sharing ``torch_stream``'s underlying CUDA handle. - """ - wp = importlib.import_module("warp") - handle = torch_stream.cuda_stream - cached = _WARP_STREAM_CACHE.get(handle) - if cached is None: - cached = wp.stream_from_torch(torch_stream) - _WARP_STREAM_CACHE[handle] = cached - return cached - @dataclass(frozen=True) class Implementation: @@ -725,13 +687,11 @@ def warp_launch_context(tensor: torch.Tensor): Warp device and stream. """ try: - importlib.import_module("warp") + wp = importlib.import_module("warp") except ImportError as exc: raise ImportError("warp is not available") from exc if tensor.device.type == "cuda": - # Reuse a cached wrapper so binding the current torch stream does - # not churn Warp's stream registration (see warp_stream_from_torch). - stream = warp_stream_from_torch(torch.cuda.current_stream(tensor.device)) + stream = wp.stream_from_torch(torch.cuda.current_stream(tensor.device)) device = None else: stream = None From bfe196f693bb6a3b6f2168df0b404ab61db3be7e Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Sat, 27 Jun 2026 10:07:09 -0500 Subject: [PATCH 05/10] remove timing class --- physicsnemo/datapipes/mesh_dataset.py | 8 +-- physicsnemo/datapipes/readers/mesh.py | 22 +++----- test/datapipes/core/test_streaming.py | 78 +-------------------------- 3 files changed, 12 insertions(+), 96 deletions(-) diff --git a/physicsnemo/datapipes/mesh_dataset.py b/physicsnemo/datapipes/mesh_dataset.py index 4afa936991..03b9fae8b6 100644 --- a/physicsnemo/datapipes/mesh_dataset.py +++ b/physicsnemo/datapipes/mesh_dataset.py @@ -28,7 +28,6 @@ import torch from tensordict import TensorDict -from physicsnemo.datapipes import _timing from physicsnemo.datapipes._rng import fork_generator from physicsnemo.datapipes.protocols import ( DatasetBase, @@ -288,13 +287,11 @@ def _apply_transforms(d: Any) -> Any: with torch.profiler.record_function( "MeshDataset._consume: data.to(device)" ): - with _timing.record("consume/h2d"): - data = data.to(self._device, non_blocking=True) + data = data.to(self._device, non_blocking=True) with torch.profiler.record_function( "MeshDataset._consume: _apply_transforms" ): - with _timing.record("consume/transforms"): - data = _apply_transforms(data) + data = _apply_transforms(data) if use_stream: # Record an event marking the preprocessing's completion on the @@ -310,7 +307,6 @@ def _apply_transforms(d: Any) -> Any: # insert the wait at the right point). compute_stream.wait_event(event) - _timing.tick() return data, metadata def close(self) -> None: diff --git a/physicsnemo/datapipes/readers/mesh.py b/physicsnemo/datapipes/readers/mesh.py index 9746b19031..ff67b09fbc 100644 --- a/physicsnemo/datapipes/readers/mesh.py +++ b/physicsnemo/datapipes/readers/mesh.py @@ -30,7 +30,6 @@ import torch -from physicsnemo.datapipes import _timing from physicsnemo.datapipes._rng import spawn_generator from physicsnemo.datapipes.registry import register from physicsnemo.mesh import DomainMesh, Mesh @@ -424,8 +423,7 @@ def set_epoch(self, epoch: int) -> None: self._epoch = epoch def __getitem__(self, index: int) -> tuple[DomainMesh, dict[str, Any]]: - with _timing.record("producer/domain_load"): - dm = self._load_sample(index) + dm = self._load_sample(index) # Trim unused data before subsample/pin. Both references are lazy (no # memmap materialization here): @@ -463,13 +461,11 @@ def __getitem__(self, index: int) -> tuple[DomainMesh, dict[str, Any]]: n_points=self.subsample_n_points, generator=generator, ) - with _timing.record("producer/interior_subsample"): - interior = _subsample_mesh(dm.interior, **sub_kw) - with _timing.record("producer/boundary_subsample"): - boundaries = { - name: _subsample_mesh(dm.boundaries[name], **sub_kw) - for name in dm.boundary_names - } + interior = _subsample_mesh(dm.interior, **sub_kw) + boundaries = { + name: _subsample_mesh(dm.boundaries[name], **sub_kw) + for name in dm.boundary_names + } dm = DomainMesh( interior=interior, boundaries=boundaries, @@ -478,12 +474,10 @@ def __getitem__(self, index: int) -> tuple[DomainMesh, dict[str, Any]]: # Load extra boundary meshes (full resolution, no subsampling). if self._extra_boundaries: - with _timing.record("producer/stl_load"): - dm = self._load_extra_boundaries(dm, index) + dm = self._load_extra_boundaries(dm, index) if self.pin_memory: - with _timing.record("producer/pin_memory"): - dm = dm.pin_memory() + dm = dm.pin_memory() metadata: dict[str, Any] = { "source_path": str(self._paths[index]), diff --git a/test/datapipes/core/test_streaming.py b/test/datapipes/core/test_streaming.py index 6154f3ebe5..49c1b9426a 100644 --- a/test/datapipes/core/test_streaming.py +++ b/test/datapipes/core/test_streaming.py @@ -21,7 +21,7 @@ primitive). Stage 2 covers iterable datasets driven main-thread-only (finite/capped-infinite generators, ``drop_last``, self-batching pass-through, reproducibility, and no worker pool). CUDA-guarded tests -exercise stream-bound preprocessing and Warp-on-a-non-default-stream. +exercise stream-bound preprocessing. """ from __future__ import annotations @@ -386,7 +386,7 @@ def test_runs_on_main_thread_no_worker_pool(self): # ============================================================================ -# CUDA-guarded -- stream-bound consume and Warp-on-non-default-stream +# CUDA-guarded -- stream-bound consume # ============================================================================ @@ -492,77 +492,3 @@ def spy_wait(self, event): # yielded (the regression guard: the old code waited inline during the # lookahead consume of batch 1, before batch 0 was ever yielded). assert wait_indices[1] > yield0_idx, order - - -class TestWarpIterableOnStream: - """Warp launches on a non-default stream from the main thread are safe.""" - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_darcy_online_simulation_through_iterable_loader(self): - from physicsnemo.datapipes.benchmarks.darcy import Darcy2D - - class _DarcyIterable(IterableDatasetBase): - yields_batches = True - - def __init__(self, num_batches): - self._sim = Darcy2D(resolution=32, batch_size=2, device="cuda") - self._num_batches = num_batches - - def __iter__(self): - sim_iter = iter(self._sim) - for _ in range(self._num_batches): - yield next(sim_iter) - - loader = dp.DataLoader(_DarcyIterable(2), use_streams=True) - batches = list(loader) - torch.cuda.synchronize() # surfaces any illegal-memory-access - assert len(batches) == 2 - for batch in batches: - assert batch["permeability"].device.type == "cuda" - assert batch["darcy"].device.type == "cuda" - - -class TestWarpFunctionalTransformOnStreams: - """A Warp ``FunctionSpec`` transform driven through the multi-stream - preload path. The functional binds the current torch stream as a Warp - stream internally; the loader binds the same stream around the consume. - Both must reuse one cached wrapper -- otherwise the inner wrapper - unregisters the shared stream on teardown and the next launch faults - (illegal memory access).""" - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_functional_warp_transform_multi_stream(self, numpy_data_dir): - from physicsnemo.datapipes.transforms.base import Transform - from physicsnemo.nn.functional import signed_distance_field - - class _SDFTransform(Transform): - """Evaluate an SDF (a Warp functional) against the sample points.""" - - def __call__(self, data): - points = data["positions"].reshape(-1, 3).float() - vertices = torch.tensor( - [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], - device=points.device, - ) - faces = torch.tensor([[0, 1, 2]], device=points.device) - sdf, _ = signed_distance_field(vertices, faces, points) - data["sdf"] = sdf.reshape(-1, 1) - return data - - reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) - dataset = dp.Dataset(reader, device="cuda:0", transforms=_SDFTransform()) - loader = dp.DataLoader( - dataset, - batch_size=1, - shuffle=False, - prefetch_factor=2, - num_streams=4, - use_streams=True, - ) - # Iterate well past num_streams so every stream is reused at least - # once; a churned registration faults on the second pass. - batches = list(loader) - torch.cuda.synchronize() # surfaces any illegal-memory-access - assert len(batches) == 10 - for batch in batches: - assert batch["sdf"].device.type == "cuda" From 2d1ca0814005af121659ba24d8e3d77ad86bcceb Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Sat, 27 Jun 2026 10:10:20 -0500 Subject: [PATCH 06/10] fully revert function spec. Damn it claude. --- physicsnemo/core/function_spec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/physicsnemo/core/function_spec.py b/physicsnemo/core/function_spec.py index 945daa3887..a6bc1a54f8 100644 --- a/physicsnemo/core/function_spec.py +++ b/physicsnemo/core/function_spec.py @@ -123,7 +123,7 @@ class FunctionSpec: from physicsnemo.core.function_spec import FunctionSpec wp.init() - wp.config.quiet = True + wp.config.log_level = wp.LOG_WARNING @wp.kernel def _identity_kernel( From e85843b1642a04c26b421156746d3fee0ab33541 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 29 Jun 2026 09:52:21 -0500 Subject: [PATCH 07/10] Revert most of the radius search changes, except keep the grid ID transfer non blocking --- .../neighbors/radius_search/_warp_impl.py | 62 ++++++++----------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py b/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py index 3a61667136..a88ba972ac 100644 --- a/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py +++ b/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py @@ -215,12 +215,10 @@ def radius_search_impl( input_dtype = points.dtype # Warp supports only fp32, so we have to cast: - with torch.profiler.record_function("radius_search: cast points to fp32"): - if points.dtype != torch.float32: - points = points.to(torch.float32) - with torch.profiler.record_function("radius_search: cast queries to fp32"): - if queries.dtype != torch.float32: - queries = queries.to(torch.float32) + if points.dtype != torch.float32: + points = points.to(torch.float32) + if queries.dtype != torch.float32: + queries = queries.to(torch.float32) # Compute follows data. wp_launch_device, wp_launch_stream = FunctionSpec.warp_launch_context(points) @@ -231,18 +229,13 @@ def radius_search_impl( wp_points_per_b = [] wp_queries_per_b = [] for b in range(B): - with torch.profiler.record_function(f"radius_search: contiguous b={b}"): - pts_b = points[b].contiguous() - qrs_b = queries[b].contiguous() - with torch.profiler.record_function(f"radius_search: from_torch b={b}"): - wp_pts_b = wp.from_torch(pts_b, dtype=wp.vec3) - wp_qrs_b = wp.from_torch(qrs_b, dtype=wp.vec3, return_ctype=True) - with torch.profiler.record_function(f"radius_search: HashGrid build b={b}"): - grid = wp.HashGrid( - dim_x=128, dim_y=128, dim_z=128, device=wp_pts_b.device - ) - grid.reserve(N_queries) - grid.build(points=wp_pts_b, radius=0.5 * radius) + pts_b = points[b].contiguous() + qrs_b = queries[b].contiguous() + wp_pts_b = wp.from_torch(pts_b, dtype=wp.vec3) + wp_qrs_b = wp.from_torch(qrs_b, dtype=wp.vec3, return_ctype=True) + grid = wp.HashGrid(dim_x=128, dim_y=128, dim_z=128, device=wp_pts_b.device) + grid.reserve(N_queries) + grid.build(points=wp_pts_b, radius=0.5 * radius) grids.append(grid) wp_points_per_b.append(wp_pts_b) wp_queries_per_b.append(wp_qrs_b) @@ -333,27 +326,23 @@ def radius_search_impl( # Build warp array of grid IDs. # Construct in pinned host memory first so the H2D copy is # stream-ordered (non_blocking) rather than a blocking cudaMemcpy. - with torch.profiler.record_function("radius_search: grid_ids host->device"): - _grid_ids_cpu = torch.tensor( - [g.id for g in grids], - dtype=torch.int64, - pin_memory=torch.cuda.is_available(), - ) - grid_ids_tensor = _grid_ids_cpu.to(points.device, non_blocking=True) + _grid_ids_cpu = torch.tensor( + [g.id for g in grids], + dtype=torch.int64, + pin_memory=torch.cuda.is_available(), + ) + grid_ids_tensor = _grid_ids_cpu.to(points.device, non_blocking=True) wp_grid_ids = wp.from_torch( grid_ids_tensor, dtype=wp.uint64, return_ctype=True ) # Convert batched points/queries to warp 2D arrays - with torch.profiler.record_function( - "radius_search: contiguous points/queries 2D" - ): - wp_points_2d = wp.from_torch( - points.contiguous(), dtype=wp.vec3, return_ctype=True - ) - wp_queries_2d = wp.from_torch( - queries.contiguous(), dtype=wp.vec3, return_ctype=True - ) + wp_points_2d = wp.from_torch( + points.contiguous(), dtype=wp.vec3, return_ctype=True + ) + wp_queries_2d = wp.from_torch( + queries.contiguous(), dtype=wp.vec3, return_ctype=True + ) # Allocate outputs with batch dimension indices = torch.full( @@ -423,9 +412,8 @@ def radius_search_impl( pts_out = pts_out.squeeze(0) # Handle the matrix of return values: - with torch.profiler.record_function("radius_search: cast outputs to input_dtype"): - pts_out = pts_out.to(input_dtype) - dists_out = dists_out.to(input_dtype) + pts_out = pts_out.to(input_dtype) + dists_out = dists_out.to(input_dtype) return indices, pts_out, dists_out, num_neighbors From 8e8cb3536f2972a315256db56bf56a4d4da25812 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 29 Jun 2026 10:57:35 -0500 Subject: [PATCH 08/10] Remove record function from train.py --- .../unified_external_aero_recipe/src/train.py | 87 ++++++++----------- 1 file changed, 35 insertions(+), 52 deletions(-) diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py index 00888ac4a7..4c84187b7c 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py @@ -271,23 +271,18 @@ def forward_pass( targets: TensorDict = batch["targets"] ### Inputs keep their native dtype; autocast handles model-internal precision. - with torch.profiler.record_function("forward_pass: model forward"): - with get_autocast_context(precision): - output = model(**forward_kwargs) + with get_autocast_context(precision): + output = model(**forward_kwargs) - with torch.profiler.record_function("forward_pass: normalize_output_to_tensordict"): - pred_td = normalize_output_to_tensordict(output, target_config, output_type) + pred_td = normalize_output_to_tensordict(output, target_config, output_type) ### Loss runs in float32 to avoid bf16 precision loss in the reduction. - with torch.profiler.record_function("forward_pass: .float() cast"): - pred_f32 = pred_td.float() - target_f32 = targets.float() + pred_f32 = pred_td.float() + target_f32 = targets.float() - with torch.profiler.record_function("forward_pass: loss_calculator"): - loss, loss_td = loss_calculator(pred_f32, target_f32) + loss, loss_td = loss_calculator(pred_f32, target_f32) with torch.no_grad(): - with torch.profiler.record_function("forward_pass: metric_calculator"): - metric_td = metric_calculator(pred_f32, target_f32) + metric_td = metric_calculator(pred_f32, target_f32) ### Detach (don't sync) the per-field TDs so the caller controls when ### a D2H copy happens; running ``.item()`` here would serialise the ### forward kernels against the host. ``TensorDict.detach()`` walks @@ -383,19 +378,17 @@ def _run_epoch( with grad_ctx: step_t0 = time.perf_counter() for i, batch in enumerate(dataloader): - with torch.profiler.record_function("_run_epoch: recursive_to_device"): - batch = recursive_to_device(batch, dist_manager.device) - - with torch.profiler.record_function("_run_epoch: forward_pass"): - loss, losses, metrics = forward_pass( - batch, - model, - precision, - loss_calculator, - metric_calculator, - output_type=output_type, - target_config=target_config, - ) + batch = recursive_to_device(batch, dist_manager.device) + + loss, losses, metrics = forward_pass( + batch, + model, + precision, + loss_calculator, + metric_calculator, + output_type=output_type, + target_config=target_config, + ) ### Kick off the async D2H copy of the scalar loss value into the ### pinned buffer. Backward + optimizer.step run while the copy is @@ -405,20 +398,14 @@ def _run_epoch( _loss_pinned.copy_(loss.detach(), non_blocking=True) if is_train: - with torch.profiler.record_function("_run_epoch: optimizer.zero_grad"): - optimizer.zero_grad() + optimizer.zero_grad() if precision == "float16" and scaler is not None: - with torch.profiler.record_function("_run_epoch: scaler.backward"): - scaler.scale(loss).backward() - with torch.profiler.record_function("_run_epoch: scaler.step"): - scaler.step(optimizer) - with torch.profiler.record_function("_run_epoch: scaler.update"): - scaler.update() + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() else: - with torch.profiler.record_function("_run_epoch: loss.backward"): - loss.backward() - with torch.profiler.record_function("_run_epoch: optimizer.step"): - optimizer.step() + loss.backward() + optimizer.step() if cfg.training.get("scheduler_update_mode", "epoch") == "step": scheduler.step() @@ -439,11 +426,10 @@ def _run_epoch( ### Read the loss scalar from the pinned buffer; the async copy ### was issued before backward so it has had the full backward + ### optimizer.step to complete without stalling the host. - with torch.profiler.record_function("_run_epoch: loss.item() D2H readback"): - if _loss_pinned is not None: - this_loss = _loss_pinned.item() - else: - this_loss = loss.detach().item() + if _loss_pinned is not None: + this_loss = _loss_pinned.item() + else: + this_loss = loss.detach().item() total_loss += this_loss step_dt = time.perf_counter() - step_t0 @@ -753,16 +739,13 @@ def benchmark_io_epoch( ) for name, t in named_tensors: v_flat = t.float() if t.is_floating_point() else t.to(torch.float32) - with torch.profiler.record_function( - "benchmark_io: tensor stats .item() D2H" - ): - logger.info( - f" {name:30s} " - f"min={v_flat.min().item(): .6e} " - f"mean={v_flat.mean().item(): .6e} " - f"std={v_flat.std().item(): .6e} " - f"max={v_flat.max().item(): .6e}" - ) + logger.info( + f" {name:30s} " + f"min={v_flat.min().item(): .6e} " + f"mean={v_flat.mean().item(): .6e} " + f"std={v_flat.std().item(): .6e} " + f"max={v_flat.max().item(): .6e}" + ) if max_steps is not None and i + 1 >= max_steps: break From a92e79bc9cb37edec90a07d878aae974722292cd Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 29 Jun 2026 11:03:06 -0500 Subject: [PATCH 09/10] Update changelog. --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c670e3c04..a55f41c9ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -111,6 +111,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 datapipes implementation (`physicsnemo.datapipes.transforms._sdf_torch` / `_sdf_triton`, including its bespoke Triton winding kernel) is superseded and removed; the public datapipes SDF transform delegates here. +- Added an iterable style dataset to physicsnemo datapipes, for on-the-fly gpu simulations. ### Changed @@ -154,6 +155,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 use `torch.no_grad()`, not `torch.inference_mode()`). Also expands CI test coverage and adds an API documentation page for `physicsnemo.diffusion.multi_diffusion`. +- Performance improvements in IO prefetching and GPU preprocessing in physicsnemo datapipes. ### Deprecated From 56933476631c51c7bf756db1c36fbee722298f2c Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 29 Jun 2026 14:28:06 -0500 Subject: [PATCH 10/10] Updating markdown, WIP --- physicsnemo/datapipes/datapipes.md | 171 +++++++++++++++++++---------- 1 file changed, 115 insertions(+), 56 deletions(-) diff --git a/physicsnemo/datapipes/datapipes.md b/physicsnemo/datapipes/datapipes.md index 0aa1dc9929..16533de858 100644 --- a/physicsnemo/datapipes/datapipes.md +++ b/physicsnemo/datapipes/datapipes.md @@ -44,20 +44,28 @@ Three dataset types share this pattern: | `MeshDataset` | `Mesh` / `DomainMesh` tensorclasses | `MeshTransform` | | `MultiDataset` | Union of child `DatasetBase` instances | Delegates to children | -All three inherit from `DatasetBase`, which provides thread-pool -prefetching and a `Future`-based cache (see -[Performance](#performance-threading-and-stream-based-concurrency) below). +`Dataset` and `MeshDataset` inherit from `DatasetBase`, which provides +thread-pool prefetching via a FIFO (First In / First Out) +`submit`/`consume` primitive driven by the `IOPump` (the index-keyed +`prefetch`/`__getitem__` cache is a thin random-access layer on top). +`MultiDataset` implements the same map-style surface by delegating +`submit`, `consume`, `prefetch`, `_pop_events`, and `close` to the child +`DatasetBase` instance that owns each sample; see +[Performance](#performance-threading-and-stream-based-concurrency) below. ## Composability ### Readers -A `Reader` is an ABC with a single contract: +A `Reader` is an ABC with one main loading hook plus a required length: ```python class Reader(ABC): @abstractmethod def _load_sample(self, index: int) -> dict[str, Tensor]: ... + + @abstractmethod + def __len__(self) -> int: ... ``` `__getitem__` wraps the result in a `TensorDict` on CPU (optionally @@ -79,7 +87,9 @@ multi-region consistency. ### Collators -Collators combine per-sample `(TensorDict, metadata)` tuples into batches: +Collators combine per-sample `(data, metadata)` tuples into batches, +where `data` is usually a `TensorDict`, `Mesh`, or `DomainMesh` depending +on the dataset and collator: | Collator | Strategy | |----------|----------| @@ -147,31 +157,36 @@ preprocessing. Threads are a natural fit: ### Producer / consumer split -Prefetching is split into two stages so that **no device kernels are -launched off the main thread** -- a hard requirement for Warp-based -transforms, which must share the model's single launching thread: +For the provided `Dataset` and `MeshDataset` implementations, prefetching +is split into two stages so that **no device kernels are launched off the +main thread** -- device kernels must share the model's single launching +thread: - `_load_host` is the **producer**. It runs on a worker thread and does only thread-safe work: reading, decoding, and staging into pinned host memory. It returns a `HostPayload`. - `_consume` is the **consumer**. It runs on whatever thread calls `__getitem__` (the main thread, in practice) and performs the - host-to-device transfer and device transforms (including Warp kernels). + host-to-device transfer and device transforms. `DatasetBase` owns a `ThreadPoolExecutor` (configurable via `num_workers`) and exposes a FIFO prefetch primitive. `submit(work_item, -stream=...)` runs only the producer on the pool and returns a -`PrefetchHandle` bundling the future with the stream the consumer should -use; `consume(handle)` resolves it on the calling thread: +stream=...)` runs `_load_host` on the pool and returns a `PrefetchHandle` +bundling the future with the stream the consumer should use; +`consume(handle)` resolves it on the calling thread. Subclasses that do +not override `_load_host` and `_consume` fall back to running their full +`_load` pipeline in the worker, so the main-thread launch guarantee +belongs to split implementations such as `Dataset` and `MeshDataset`: ```python def submit(self, work_item, stream=None): future = self._executor.submit(self._load_host, work_item) return PrefetchHandle(future=future, stream=stream) -def consume(self, handle): +def consume(self, handle, *, defer_sync=False): payload = handle.future.result() # re-raises producer errors - return self._consume(payload, handle.stream) # H2D + transforms here + # H2D + transforms here; defer_sync controls who gates the compute stream + return self._consume(payload, handle.stream, defer_sync=defer_sync) ``` Correlation is purely by handle identity (FIFO), so work items need not @@ -193,9 +208,9 @@ them in place without consuming a slot, so the consumer reassembles dynamically-sized batches from the boundaries -- the DataLoader never builds the epoch's batch list in advance. Because dispatch lives off the main thread, the pipeline stays primed even while the main thread is busy -launching kernels or running the model. This path is active whenever -`prefetch_factor > 0`; set `prefetch_factor=0` for fully synchronous -iteration. +launching kernels or running the model. For map-style datasets, this +path is active whenever `prefetch_factor > 0`; set `prefetch_factor=0` +to use synchronous map-style iteration. ### CUDA stream handoff @@ -204,39 +219,69 @@ producer. When `use_streams=True` (and CUDA is available), each sample is round-robined a **preprocessing stream**. The consumer runs *both* the host-to-device copy and the transforms on that stream, then hands the result to the compute stream via a CUDA **event** (never a host -`synchronize()`): +`synchronize()`). Crucially, *who* enqueues the compute-stream wait +depends on the `defer_sync` flag (see +[One-batch lookahead](#one-batch-lookahead-deferred-sync) below): the +DataLoader defers it so the wait lands *after* the previous batch's model +work, while standalone callers wait inline: ```python -def _consume(self, payload, stream=None): +def _consume(self, payload, stream=None, *, defer_sync=False): data = payload.data if device is not None and stream is not None: compute_stream = torch.cuda.current_stream() - # Bind torch AND Warp to the preprocessing stream. - with preprocessing_stream(stream): # torch + wp.ScopedStream + # Bind torch to the preprocessing stream. + with preprocessing_stream(stream): # torch.cuda.stream data = data.to(device, non_blocking=True) # H2D on prep stream data = self.transforms(data) # transforms on SAME stream - data.record_stream(compute_stream) # keep memory alive event = torch.cuda.Event() event.record(stream) - compute_stream.wait_event(event) # order, no host block + if defer_sync: + self._events_pending.append(event) # DataLoader gates later + else: + compute_stream.wait_event(event) # inline order, no host block else: data = self.transforms(data) return data, payload.metadata ``` -**The single launching thread -- not a single stream -- is Warp's real -invariant.** Warp kernels may run on any CUDA stream provided they are -launched from the main thread *and* Warp's current stream matches torch's. -`preprocessing_stream` (in `protocols.py`) binds both via -`wp.ScopedStream(wp.stream_from_torch(stream))`, so transforms (including -Warp mesh-query / BVH kernels) run correctly on the side stream. A -previous `cudaErrorIllegalAddress` here was a torch/Warp stream -*divergence* (data on a side stream, the Warp kernel on Warp's own -stream), not a prohibition on non-default streams; binding both fixes it -and lets GPU preprocessing genuinely overlap training. `record_stream` -keeps the device tensors from being recycled while the compute stream -reads them; the pinned host source is held by the caching host allocator -until the copy completes. +`preprocessing_stream` (in `protocols.py`) binds torch's current stream +to the preprocessing stream via `torch.cuda.stream(stream)`, so the +host-to-device copy and the transforms run on the side stream and GPU +preprocessing genuinely overlaps training. The pinned host source is +held by the caching host allocator until the copy completes. + +### One-batch lookahead (deferred sync) + +The compute-stream wait recorded above is only half the overlap story: +*when* it is enqueued decides whether preprocessing actually overlaps +training. The DataLoader's prefetch loop (`_iter_prefetch`) keeps a +**one-batch lookahead** so the wait lands at the right point: + +- `drain_one_batch()` consumes pump items up to the next `BATCH_BOUNDARY`, + calling `consume(item, defer_sync=True)` on each. With `defer_sync=True` + the consumer enqueues the H2D copy + transforms on the preprocessing + stream and *records* an event, but appends it to `_events_pending` + instead of making the compute stream wait. +- `_pop_events()` (on `DatasetBase`) hands those recorded events back to + the loop. +- Before yielding batch N, the loop eagerly drains batch **N+1**'s items + (launching their preprocessing on the side streams), then calls + `gate_compute_stream(events)` to issue `compute_stream.wait_event` for + batch N's events -- right before the `yield`. + +This ordering is the whole point. Because the wait for batch N is +enqueued *after* the previous iteration's `yield` already enqueued batch +N-1's model kernels, the compute-stream order becomes +`..., model_{N-1}, wait(prep_N), model_N, ...`. Batch N's preprocessing +(already in flight on its own stream) overlaps batch N-1's compute, and +each model only ever blocks on its own batch's preprocessing -- never on +the next batch's. If `_consume` instead waited inline during the +lookahead drain, that wait would be ordered *ahead* of batch N's model +kernels and the model would block on batch N+1's preprocessing -- the +opposite of overlap. Standalone callers (no DataLoader to insert the +gate) leave `defer_sync=False` and get the inline wait so their result is +immediately safe to use. ### Concurrency timeline @@ -244,15 +289,21 @@ With everything launched from the main thread, the worker pool, the preprocessing stream, and the compute stream form a triple buffer: ```text -Worker pool │ load N+2 ─ load N+1 ... (host I/O + thread-safe CPU work) -Preprocess stream │ H2D + Warp transforms for N+1 -Compute stream │ train N +Worker pool │ load N+1 ─ load N ... (host I/O + thread-safe CPU work) +Preprocess stream │ H2D + transforms for N +Compute stream │ train N-1 ─ wait(prep_N) ─ train N ``` -GPU preprocessing of batch N+1 genuinely overlaps training of batch N on +GPU preprocessing of batch N genuinely overlaps training of batch N-1 on a separate stream; the two are ordered by a CUDA event, never a host-side -`synchronize`. A transform (or generator) that forces a host readback -simply serializes itself -- a property of that code, not of the pipeline. +`synchronize`. The ordering is what the one-batch lookahead buys: the +compute stream's `wait(prep_N)` is enqueued by `gate_compute_stream` only +*after* batch N-1's model kernels and only *after* batch N's preprocessing +has been launched on its side stream (see +[One-batch lookahead](#one-batch-lookahead-deferred-sync)), so each model +blocks on its own batch's preprocessing and nothing else. A transform (or +generator) that forces a host readback simply serializes itself -- a +property of that code, not of the pipeline. ### Two data paths: map/descriptor vs iterable @@ -267,10 +318,10 @@ type: - **Generator path (`IterableDatasetBase`)** -- iterable datasets that *produce* data (online simulation, procedural samplers, unbounded streams). Driven **main-thread-only**: no sampler, no pump, no worker - pool. `__iter__` may freely launch Warp kernels and use CUDA streams - (the single-launching-thread invariant holds), and the loader still - drives generation on a preprocessing stream with the same event handoff, - so generation of batch N+1 overlaps training of batch N. + pool. `__iter__` may freely launch device kernels and use CUDA streams, + and the loader still drives generation on a preprocessing stream with a + CUDA event handoff. This path does not use the map-style `IOPump`, + `_events_pending`, or one-batch deferred-sync loop. An iterable dataset yields either per-sample `(data, metadata)` (the loader collates `batch_size` of them, `drop_last` trims the tail) or, when @@ -278,7 +329,7 @@ loader collates `batch_size` of them, `drop_last` trims the tail) or, when unchanged. Iterable datasets have no length: `len(loader)` raises `TypeError`, and `shuffle`/`sampler` are ignored. See `examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py` for -a Warp `Darcy2D` online simulation wired through this path. +a `Darcy2D` online simulation wired through this path. ### Pinned memory @@ -292,13 +343,15 @@ above is most effective when the reader pins its output. Prefetching can be toggled at runtime for debugging: ```python -loader.disable_prefetch() # synchronous, single-stream -- easy to debug -loader.enable_prefetch() # re-enable after debugging +loader.disable_prefetch() # drop CUDA streams (threaded pump still runs) +loader.enable_prefetch() # re-enable streams after debugging ``` `use_streams=False` keeps the threaded producer but drops the CUDA stream handoff (the consumer copies and transforms on the default -stream); `prefetch_factor=0` forces fully synchronous execution. +stream); for map-style datasets, `prefetch_factor=0` forces fully +synchronous execution. Iterable datasets use their separate +main-thread-only generator path regardless of `prefetch_factor`. ## RNG and reproducibility @@ -314,11 +367,17 @@ documented in **[RNG.md](RNG.md)**. Mesh augmentations (`RandomScaleMesh`, `RandomTranslateMesh`, `RandomRotateMesh`) accept any `torch.distributions.Distribution` to -parametrize their random sampling. To preserve reproducibility with -seeded `torch.Generator` objects (which `Distribution.sample()` does not -accept), the augmentations use **inverse CDF sampling**: draw -`U ~ Uniform(0,1)` via `torch.rand(generator=g)`, then compute -`X = distribution.icdf(U)`. This gives exact samples from the target -distribution while keeping all randomness under generator control. +parametrize distribution-backed random sampling. To preserve +reproducibility with seeded `torch.Generator` objects (which +`Distribution.sample()` does not accept), `RandomScaleMesh`, +`RandomTranslateMesh`, and `RandomRotateMesh(mode="axis_aligned")` use +**inverse CDF sampling**: draw `U ~ Uniform(0,1)` via +`torch.rand(generator=g)`, then compute `X = distribution.icdf(U)`. This +gives exact samples from the target distribution while keeping randomness +under generator control for distributions that implement `icdf()`. +Distributions without `icdf()` fall back to `Distribution.sample()` with +a warning and are not generator-reproducible. `RandomRotateMesh` defaults +to `mode="uniform"`, which ignores `axes` and `distribution` and samples +uniform SO(3) rotations via random quaternions. Full usage examples, YAML configuration, and the supported-distribution table are in **[transforms/mesh/DISTRIBUTIONS.md](transforms/mesh/DISTRIBUTIONS.md)**.