diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c670e3c04..a55f41c9ac 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -111,6 +111,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 datapipes implementation (`physicsnemo.datapipes.transforms._sdf_torch` / `_sdf_triton`, including its bespoke Triton winding kernel) is superseded and removed; the public datapipes SDF transform delegates here. +- Added an iterable style dataset to physicsnemo datapipes, for on-the-fly gpu simulations. ### Changed @@ -154,6 +155,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 use `torch.no_grad()`, not `torch.inference_mode()`). Also expands CI test coverage and adds an API documentation page for `physicsnemo.diffusion.multi_diffusion`. +- Performance improvements in IO prefetching and GPU preprocessing in physicsnemo datapipes. ### Deprecated diff --git a/docs/api/datapipes/physicsnemo.datapipes.rst b/docs/api/datapipes/physicsnemo.datapipes.rst index 82b438f25b..047360e65b 100644 --- a/docs/api/datapipes/physicsnemo.datapipes.rst +++ b/docs/api/datapipes/physicsnemo.datapipes.rst @@ -139,6 +139,29 @@ of ``pin_memory`` from the ``DataLoader`` class to the ``Reader`` classes. This is because of the much earlier GPU data transfer in the PhysicsNeMo datapipe compared to PyTorch. +The ``DataLoader`` drives one of two mutually-exclusive paths, selected by +dataset type: + +- **Map-style preload path** (:class:`~physicsnemo.datapipes.DatasetBase`, + e.g. ``Dataset``, ``MeshDataset``): a dedicated dispatcher thread keeps a + *bounded* number of samples in flight by pulling the index stream + **lazily** under backpressure and submitting host-only loads to a worker + pool. The main thread consumes the resulting samples in order + (host-to-device transfer plus transforms on a preprocessing stream) and + reassembles batches from boundary markers, so the full epoch is never + materialized up front and irregular batch sizes are supported. +- **Iterable generator path** (:class:`~physicsnemo.datapipes.IterableDatasetBase`): + a generator dataset driven entirely on the main thread (no sampler, no + worker pool); see `Iterable Datasets`_ below. + +In both paths, **all device-kernel launches happen on the single main +thread**. This is the real constraint for Warp-based transforms: Warp may +launch on any CUDA stream as long as the launch comes from the main thread +and Warp's current stream is bound to the torch stream in use. +Preprocessing therefore runs on a side (preprocessing) stream and is +ordered against the compute stream with a CUDA event, so it overlaps +training without ever blocking the host. + .. autoclass:: physicsnemo.datapipes.dataloader.DataLoader :members: :show-inheritance: @@ -179,6 +202,43 @@ consistent keys. Because the exact collation details differ by dataset, the :show-inheritance: +Iterable Datasets +^^^^^^^^^^^^^^^^^ + +Map-style datasets (``Dataset``, ``MeshDataset``) assume a fixed length and a +sampler that hands out indices. Some workloads have neither: an online +simulation, a procedural generator, or any source that produces samples on +the fly with no meaningful ``__len__``. For those, subclass +:class:`~physicsnemo.datapipes.IterableDatasetBase` and yield samples from +``__iter__``. The ``DataLoader`` detects an iterable dataset automatically +and switches to the main-thread-only generator path; ``shuffle`` and +``sampler`` are ignored (a warning is issued) and ``len(loader)`` raises, +since the length is unknown. + +An iterable dataset chooses one of two emission modes via the +``yields_batches`` attribute: + +- ``yields_batches = False`` (default): ``__iter__`` yields individual + samples and the ``DataLoader`` collates them into batches of + ``batch_size`` (honoring ``drop_last``). +- ``yields_batches = True``: ``__iter__`` yields fully-formed batches and the + ``DataLoader`` passes them through without further collation, which is the + natural fit for a generator that already produces a batch per step. + +Reproducibility follows a per-``(epoch, position)`` scheme rather than the +map-style per-``(epoch, index)`` scheme: implement ``set_epoch`` and/or +``set_generator`` to seed deterministically from the iteration position. +Because the generator runs on the main thread, Warp kernels inside it are +safe on any preprocessing stream the ``DataLoader`` binds. See the online +simulation tutorial in the +`examples directory `_ +for a runnable Warp ``Darcy2D`` generator wired through this path. + +.. autoclass:: physicsnemo.datapipes.IterableDatasetBase + :members: + :show-inheritance: + + Readers ^^^^^^^ diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/drivaer_ml_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/drivaer_ml_volume.yaml index d6a27d5335..244899cea9 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/drivaer_ml_volume.yaml +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/drivaer_ml_volume.yaml @@ -32,6 +32,16 @@ pipeline: pattern: "run_*/domain_*.pdmsh" subsample_n_points: ${sampling_resolution} subsample_n_cells: ${sampling_resolution} + # The volume model consumes the interior as a point cloud (interior.points + + # interior.point_data); dropping interior tet topology makes the point + # subsample a cheap contiguous block read instead of a full slice_points + # remap (the dominant per-sample IO cost otherwise). + drop_interior_cells: true + # The in-file `vehicle` surface boundary is unused by the volume pipeline + # (SDF comes from the injected stl_geometry below; model/collate use only + # the interior). Drop it so we don't subsample (expensive, GIL-held) or pin + # it every sample. + drop_in_file_boundaries: true extra_boundaries: stl_geometry: pattern: "*_single_solid.stl.pmsh" diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/highlift_volume.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/highlift_volume.yaml index b0c6f8f602..c4026ec01f 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/highlift_volume.yaml +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/datasets/highlift_volume.yaml @@ -34,6 +34,12 @@ pipeline: pattern: "geo_LHC*/*.pdmsh" subsample_n_points: ${sampling_resolution} subsample_n_cells: ${sampling_resolution} + # Interior consumed as a point cloud; drop tet topology so the point + # subsample is a cheap contiguous block read (see drivaer_ml_volume.yaml). + drop_interior_cells: true + # In-file boundaries unused by the volume pipeline (SDF uses stl_geometry); + # drop them to skip per-sample subsample + pin. + drop_in_file_boundaries: true extra_boundaries: stl_geometry: pattern: "*.stl.pmsh" diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/loss.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/loss.py index b558d9c357..bc5623fc00 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/loss.py +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/loss.py @@ -114,7 +114,7 @@ def _vector_loss( target_sq = torch.mean(target**2, dim=tuple(range(pred.ndim - 1))) return torch.sum(diff_sq / (target_sq + eps)) - total = torch.tensor(0.0, device=pred.device, dtype=pred.dtype) + total = torch.zeros((), device=pred.device, dtype=pred.dtype) for i in range(n_components): p, t = pred[..., i], target[..., i] if loss_type == "huber": diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py index 15624e6a89..7f5f8bd93c 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/nondim.py @@ -238,19 +238,16 @@ def _transform_mesh( T_inf=T_inf, ) - ### `Mesh.copy` is a tensorclass-provided shallow copy: `points`, - ### `cells`, the untouched associations, and the geometric `_cache` - ### are all shared with `mesh`; only the cloned association is swapped. + # Shallow copy shares everything except the swapped association. new_mesh = mesh.copy() # ty: ignore[unresolved-attribute] setattr(new_mesh, self._association, new_td) - ### Scale geometry into nondim space (`x* = x / L_ref`) on the - ### forward pass, and back to physical units (`x = x* * L_ref`) - ### on the inverse. `Mesh.scale` propagates `_cache` through the - ### linear transform. + # Scale geometry to/from nondim space (x* = x / L_ref). + # assume_invertible=True avoids a per-mesh sync from the det check. if L_ref is not None: + torch._assert_async(L_ref != 0) factor = L_ref if inverse else 1.0 / L_ref - new_mesh = new_mesh.scale(factor) + new_mesh = new_mesh.scale(factor, assume_invertible=True) return new_mesh diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py index 7a48ea8957..4c84187b7c 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/train.py @@ -365,6 +365,15 @@ def _run_epoch( n_local = 0 num_steps = len(dataloader) epoch_t0 = time.perf_counter() + ### Single pinned scalar buffer reused every step so the loss D2H + ### transfer is async (non_blocking=True from device to pinned host + ### memory). The copy is issued right after forward_pass and read + ### just before the logger line; by then backward + optimizer.step + ### have run, giving the GPU time to complete the copy without + ### blocking the host. + _loss_pinned = ( + torch.zeros(1, pin_memory=True) if torch.cuda.is_available() else None + ) with grad_ctx: step_t0 = time.perf_counter() @@ -381,6 +390,13 @@ def _run_epoch( target_config=target_config, ) + ### Kick off the async D2H copy of the scalar loss value into the + ### pinned buffer. Backward + optimizer.step run while the copy is + ### in flight, so by the time we call .item() below the transfer + ### is already done and there is no host stall. + if _loss_pinned is not None: + _loss_pinned.copy_(loss.detach(), non_blocking=True) + if is_train: optimizer.zero_grad() if precision == "float16" and scaler is not None: @@ -407,9 +423,13 @@ def _run_epoch( total_metrics_td.add_(metrics) n_local += 1 - ### Per-step sync for the print line; lands after backward + - ### optimizer.step so it overlaps with queued GPU work. - this_loss = loss.detach().item() + ### Read the loss scalar from the pinned buffer; the async copy + ### was issued before backward so it has had the full backward + + ### optimizer.step to complete without stalling the host. + if _loss_pinned is not None: + this_loss = _loss_pinned.item() + else: + this_loss = loss.detach().item() total_loss += this_loss step_dt = time.perf_counter() - step_t0 diff --git a/examples/minimal/datapipes/README.md b/examples/minimal/datapipes/README.md index 6934467232..949d615e44 100644 --- a/examples/minimal/datapipes/README.md +++ b/examples/minimal/datapipes/README.md @@ -215,3 +215,41 @@ python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud \ subsample.n_points=5000 ``` + +### Tutorial 5: Iterable Datasets for Online Simulation + +**File:** `tutorial_5_iterable_online_simulation.py` + +Not every dataset is map-style. When data is *generated* on the fly -- +an online physics simulation, a procedural sampler, an unbounded stream -- +there is no fixed length and no index to address. PhysicsNeMo models these +with `IterableDatasetBase`: + +- Iterable datasets only support iteration: no `__len__`, no `__getitem__`, + no sampler, and no prefetch worker pool. +- They run entirely on the **main thread**, so they may freely launch Warp + kernels and use CUDA streams. Warp's only requirement is a single + launching thread, which the main thread satisfies -- this is what makes + an online GPU simulation safe on this path. +- The `DataLoader` still drives generation on a preprocessing stream (when + `use_streams=True`) and hands each item to the compute stream via a CUDA + event, so generating the next batch can overlap training on the current + one. + +This tutorial wraps the built-in Warp `Darcy2D` flow generator (a +multigrid Jacobi solver) as an iterable dataset and iterates it through the +`DataLoader`. Because `Darcy2D` produces a full batch per step, the wrapper +is *self-batching* (`yields_batches = True`) and the loader passes each +batch through unchanged. + +**When to use the iterable path vs the map/descriptor path:** use the +map-style `Dataset` when you have a fixed corpus on disk addressable by +index (storage-backed, benefits from threaded prefetch). Use an +`IterableDatasetBase` when samples are produced by a generator/simulation, +the length is unbounded or unknown, or the producer itself must launch +device kernels. + +```bash +# Requires a CUDA device (the Darcy solver runs Warp kernels on the GPU) +python tutorial_5_iterable_online_simulation.py +``` diff --git a/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py b/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py new file mode 100644 index 0000000000..cbd829d85a --- /dev/null +++ b/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tutorial 5: Iterable datasets for online simulation. + +Most datasets are *map-style*: a fixed number of samples addressed by +index, read from storage. Some workloads instead *generate* data on the +fly -- an online physics simulation, a procedural sampler, a streaming +source with no fixed length. These are *iterable* datasets. + +PhysicsNeMo models this with :class:`IterableDatasetBase`. Unlike a +map-style dataset, an iterable dataset: + +- has no length and no indexing -- it only supports iteration; +- is driven entirely on the **main thread** (no worker pool), so it may + freely launch Warp kernels and use CUDA streams. This is exactly the + property that makes an online GPU simulation safe here: Warp's + constraint is a single launching thread, which the main thread + satisfies. + +This tutorial wraps the built-in Warp ``Darcy2D`` flow generator -- which +solves the 2D Darcy equation with a multigrid Jacobi solver and yields a +ready-made batch each step -- as an iterable dataset and drives it through +the PhysicsNeMo :class:`DataLoader`. + +Run with:: + + python tutorial_5_iterable_online_simulation.py + +Requires a CUDA device (the Darcy solver runs Warp kernels on the GPU). +""" + +from __future__ import annotations + +import time + +import numpy as np +import torch + +from physicsnemo.datapipes import DataLoader, IterableDatasetBase +from physicsnemo.datapipes.benchmarks.darcy import Darcy2D + + +class DarcyOnlineDataset(IterableDatasetBase): + """Online 2D Darcy-flow simulation as an iterable dataset. + + Wraps :class:`~physicsnemo.datapipes.benchmarks.darcy.Darcy2D`, whose + iterator runs the solver and yields a full ``{"permeability", "darcy"}`` + batch per step. The underlying generator is infinite, so this wrapper + caps it at ``num_batches`` per epoch to give the loader a finite stream. + + Because ``Darcy2D`` already produces a complete batch, this is a + *self-batching* dataset: we set :attr:`yields_batches` so the loader + passes each batch through unchanged instead of re-collating. + + Parameters + ---------- + num_batches : int + Number of batches to emit per epoch. + resolution : int, default=64 + Simulation grid resolution. + batch_size : int, default=8 + Number of simulations per batch. + device : str, default="cuda" + Device the Warp solver runs on. + base_seed : int, default=0 + Base seed for reproducible permeability sampling. + """ + + # Darcy2D emits a full batch per step; do not re-collate. + yields_batches = True + + def __init__( + self, + num_batches: int, + *, + resolution: int = 64, + batch_size: int = 8, + device: str = "cuda", + base_seed: int = 0, + ) -> None: + self._sim = Darcy2D( + resolution=resolution, + batch_size=batch_size, + device=device, + normaliser={"permeability": (1.25, 0.75), "darcy": (4.52e-2, 2.79e-2)}, + ) + self._num_batches = num_batches + self._base_seed = base_seed + self._epoch = 0 + + def set_epoch(self, epoch: int) -> None: + """Select the epoch so each epoch draws a distinct, reproducible stream.""" + self._epoch = epoch + + def __iter__(self): + # One solver iterator drives the simulation; we pull a bounded + # number of steps from the otherwise-infinite generator. + sim_iter = iter(self._sim) + for position in range(self._num_batches): + # Per-(epoch, position) seeding: the stream is reproducible + # across runs and distinct across epochs and positions. There is + # no stable sample index for a generator, so we key on the + # monotonic emission position instead. + seed = np.random.SeedSequence( + [self._base_seed, self._epoch, position] + ).generate_state(1)[0] + np.random.seed(int(seed)) + yield next(sim_iter) + + +def main() -> None: + """Run the online Darcy simulation over several epochs and report timings. + + Builds an iterable :class:`DarcyOnlineDataset`, wraps it in a stream-overlapped + ``DataLoader``, and iterates for ``num_epochs``. Each epoch is reseeded via + ``set_epoch`` for a distinct, reproducible batch stream. For every batch we + record host wall-clock time and per-step CUDA event timings, then print a + per-epoch summary. Requires a CUDA device for the Warp Darcy solver; the + function returns early with a message if none is available. + """ + if not torch.cuda.is_available(): + print("This tutorial requires a CUDA device (Warp Darcy solver). Skipping.") + return + + num_epochs = 5 + num_batches = 16 + dataset = DarcyOnlineDataset(num_batches=num_batches, resolution=64, batch_size=8) + + # use_streams=True runs each simulation step on a preprocessing stream + # and hands the result to the compute stream via a CUDA event, so + # generation of the next batch can overlap training on the current one. + loader = DataLoader(dataset, use_streams=True, seed=0) + + # Iterable datasets have no length: this will take the exception path.oOOh, + try: + len(loader) + except TypeError as exc: + print(f"len(loader) is undefined for iterable datasets: {exc}") + + for epoch in range(num_epochs): + loader.set_epoch(epoch) + print(f"\nEpoch {epoch}") + host_times = [] + cuda_events = [] + epoch_start = time.perf_counter() + prev_host = epoch_start + cuda_start = torch.cuda.Event(enable_timing=True) + cuda_start.record(torch.cuda.current_stream()) + for i, batch in enumerate(loader): + host_now = time.perf_counter() + permeability = batch["permeability"] + darcy = batch["darcy"] + cuda_end = torch.cuda.Event(enable_timing=True) + cuda_end.record(torch.cuda.current_stream()) + cuda_events.append((cuda_start, cuda_end)) + cuda_start = cuda_end + + host_times.append(host_now - prev_host) + prev_host = host_now + print( + f" batch {i}: permeability {tuple(permeability.shape)} " + f"on {permeability.device}, darcy {tuple(darcy.shape)}, " + f"host_dt={host_times[-1]:.4f}s" + ) + + torch.cuda.synchronize() + cuda_times_ms = [start.elapsed_time(end) for start, end in cuda_events] + epoch_wall = time.perf_counter() - epoch_start + mean_host = sum(host_times) / len(host_times) + mean_cuda = sum(cuda_times_ms) / len(cuda_times_ms) + print( + f" epoch summary: batches={len(host_times)}, wall={epoch_wall:.3f}s, " + f"host_mean={mean_host:.4f}s, cuda_mean={mean_cuda:.2f}ms, " + f"cuda_min={min(cuda_times_ms):.2f}ms, cuda_max={max(cuda_times_ms):.2f}ms" + ) + + # Train as usual; the batches are ordinary device tensors. + + +if __name__ == "__main__": + main() diff --git a/physicsnemo/datapipes/RNG.md b/physicsnemo/datapipes/RNG.md index f390d544de..da23c2dbea 100644 --- a/physicsnemo/datapipes/RNG.md +++ b/physicsnemo/datapipes/RNG.md @@ -32,6 +32,21 @@ master seed using `fork_generator(parent, n)`. Each child is seeded with and stable across runs. Children are created on the **same device** as the parent. +For RNG that must be reproducible regardless of *execution order* (e.g. +reader subsampling, which runs on a pool of worker threads), `_rng.py` +also provides coordinate-based seeding: + +- **`derive_seed(base_seed, *coords)`** — mixes a base seed with integer + coordinates (typically `epoch` and sample `index`) into a single + well-mixed 64-bit seed via `numpy.random.SeedSequence`. The result + depends only on the inputs, not on call order or thread. +- **`spawn_generator(base_seed, *coords, device=...)`** — returns a fresh + `torch.Generator` seeded with `derive_seed(base_seed, *coords)`. + +Because each call returns an independent generator seeded purely from its +coordinates, draws are reproducible irrespective of order and safe to +compute concurrently from multiple threads (no shared mutable state). + ### DataLoader When `seed` is set the DataLoader: @@ -70,9 +85,12 @@ sub-dataset. ### Epoch reseeding `DataLoader.set_epoch(epoch)` propagates to the sampler and dataset. -Each component with a generator reseeds it with +The sampler and stochastic transforms reseed their generators with `initial_seed() + epoch`, producing a different but deterministic -random sequence every epoch. +random sequence every epoch. Readers instead store the epoch and fold +it into each sample's derived seed (see [Readers](#readers)), so their +per-sample RNG also varies deterministically per epoch without relying +on a shared, sequentially-drawn generator. ## Generator tree @@ -152,18 +170,63 @@ standalone use. ## Readers -The `Reader` base class defines no-op `set_generator` / `set_epoch`. -Readers that use randomness override them: - -| Reader | Randomness | Generator support | +Reader subsampling runs on the dataset's worker-thread pool (the threaded +`prefetch` producer path; `Dataset` defaults to `num_workers=2`), so the +*order* in which samples are drawn is non-deterministic. A single shared, +sequentially-drawn generator would therefore not be reproducible with +`num_workers > 1`. To avoid this, readers derive RNG **per +`(base_seed, epoch, index)`** instead of from one shared stream: + +- **`set_generator(g)`** stores `g.initial_seed()` as the reader's base + seed (it does *not* keep the generator itself). +- **`set_epoch(e)`** stores the epoch. +- Each `reader[index]` then calls + `spawn_generator(base_seed, epoch, index)` to obtain a fresh generator + for that sample's draws (the `Reader` base class exposes this as + `_index_generator(index)`). + +The draw for a given sample depends only on `(base_seed, epoch, index)`, +so it is **identical regardless of read order or worker thread** — +reproducible for any `num_workers` — while still differing across indices +and across epochs. When no seed has been set, the per-sample generator is +`None` and draws fall back to the global default RNG. + +Transforms remain reproducible because they run on the main thread in +sampler order (via the consume stage), so their sequentially-drawn +generators are unaffected by the threaded producer. + +| Reader | Randomness | Per-`(seed, epoch, index)` RNG | |---|---|---| | `MeshReader` | `torch.randint` (contiguous block selection) | Yes | | `DomainMeshReader` | `torch.randint` | Yes | | `NumpyReader` | `torch.randint` (coordinated subsampling) | Yes | | `ZarrReader` | `torch.randint` | Yes | | `TensorStoreZarrReader` | `torch.randint` | Yes | -| `HDF5Reader` | None | No-op (inherited) | -| `VTKReader` | None | No-op (inherited) | +| `HDF5Reader` | None | n/a (inherited base) | +| `VTKReader` | None | n/a (inherited base) | + +## Iterable & descriptor paths: per-`(epoch, position)` seeding + +Map-style datasets have a stable sample `index`, so readers key their +per-sample RNG on `(base_seed, epoch, index)` (see [Readers](#readers)). +Generator-style (`IterableDatasetBase`) and future descriptor-keyed +sources have **no stable index**: samples are produced in sequence with no +addressable position in a corpus. They therefore key on the **monotonic +emission position** within the epoch instead: + +- **map-style:** `derive_seed(base_seed, epoch, index)` — reproducible for + any read order / `num_workers`, since the index is intrinsic to the + sample. +- **iterable / descriptor:** `derive_seed(base_seed, epoch, position)` — + where `position` is a 0-based counter of emissions in the current epoch. + Reproducible across runs and distinct across epochs and positions. + +Both schemes use the same `derive_seed`/`spawn_generator` primitives; only +the coordinate that stands in for "which sample" differs. The iterable +path runs entirely on the main thread in emission order, so the position +counter is unambiguous (there is no worker-thread reordering to defend +against). See `tutorial_5_iterable_online_simulation.py` for a worked +example seeding an online Darcy-flow simulation per `(epoch, position)`. ## Current limitations diff --git a/physicsnemo/datapipes/__init__.py b/physicsnemo/datapipes/__init__.py index d3902c5459..50e4e2aa4a 100644 --- a/physicsnemo/datapipes/__init__.py +++ b/physicsnemo/datapipes/__init__.py @@ -42,7 +42,7 @@ from physicsnemo.datapipes.dataset import Dataset from physicsnemo.datapipes.mesh_dataset import MeshDataset from physicsnemo.datapipes.multi_dataset import MultiDataset -from physicsnemo.datapipes.protocols import DatasetBase +from physicsnemo.datapipes.protocols import DatasetBase, IterableDatasetBase from physicsnemo.datapipes.readers import ( DomainMeshReader, HDF5Reader, @@ -105,6 +105,7 @@ # "TensorDict", # Re-export from tensordict "DatasetBase", + "IterableDatasetBase", "Dataset", "MeshDataset", "DataLoader", diff --git a/physicsnemo/datapipes/_rng.py b/physicsnemo/datapipes/_rng.py index d76a374df0..5e20c2715a 100644 --- a/physicsnemo/datapipes/_rng.py +++ b/physicsnemo/datapipes/_rng.py @@ -23,9 +23,68 @@ from __future__ import annotations +import numpy as np import torch +def derive_seed(base_seed: int, *coords: int) -> int: + """Deterministically mix a base seed with integer coordinates. + + Combines ``base_seed`` with arbitrary integer ``coords`` (typically + ``epoch`` and sample ``index``) into a single well-mixed 64-bit seed + using :class:`numpy.random.SeedSequence`. The result depends only on + the inputs, not on call order or thread, so per-sample RNG derived + from it is reproducible and safe to compute concurrently. + + Parameters + ---------- + base_seed : int + Base seed (e.g. a generator's ``initial_seed()``). + *coords : int + Additional non-negative integer coordinates to fold in, such as + ``(epoch, index)``. + + Returns + ------- + int + A deterministic 64-bit seed. + """ + seq = np.random.SeedSequence([int(base_seed), *(int(c) for c in coords)]) + return int(seq.generate_state(1, dtype=np.uint64)[0]) + + +def spawn_generator( + base_seed: int, + *coords: int, + device: torch.device | str = "cpu", +) -> torch.Generator: + """Create a fresh :class:`torch.Generator` seeded from mixed coordinates. + + Returns an independent generator whose seed is + :func:`derive_seed(base_seed, *coords) `. Because each + call returns a new generator seeded purely from its inputs, draws are + reproducible regardless of execution order and can be made + concurrently from multiple threads without sharing mutable state. + + Parameters + ---------- + base_seed : int + Base seed (e.g. a generator's ``initial_seed()``). + *coords : int + Additional integer coordinates to fold in, such as ``(epoch, index)``. + device : torch.device or str, default="cpu" + Device the generator is created on. + + Returns + ------- + torch.Generator + A new generator seeded deterministically from the inputs. + """ + generator = torch.Generator(device=device) + generator.manual_seed(derive_seed(base_seed, *coords)) + return generator + + def fork_generator( parent: torch.Generator, n: int, diff --git a/physicsnemo/datapipes/dataloader.py b/physicsnemo/datapipes/dataloader.py index 4e6b7bc61f..7ab6931b49 100644 --- a/physicsnemo/datapipes/dataloader.py +++ b/physicsnemo/datapipes/dataloader.py @@ -25,6 +25,8 @@ from __future__ import annotations +import itertools +import warnings from typing import Any, Callable, Iterator, Optional, Sequence import torch @@ -33,7 +35,12 @@ from physicsnemo.datapipes._rng import fork_generator from physicsnemo.datapipes.collate import Collator, get_collator -from physicsnemo.datapipes.protocols import DatasetBase +from physicsnemo.datapipes.io_pump import BATCH_BOUNDARY, IOPump +from physicsnemo.datapipes.protocols import ( + DatasetBase, + IterableDatasetBase, + preprocessing_stream, +) from physicsnemo.datapipes.registry import register @@ -57,6 +64,35 @@ class DataLoader: - Compatible with PyTorch samplers (DistributedSampler, etc.) - Familiar torch DataLoader interface + Two data paths + -------------- + The path is selected by dataset type: + + - **Map-style** (:class:`~physicsnemo.datapipes.protocols.DatasetBase`): + a dispatcher thread (:class:`~physicsnemo.datapipes.io_pump.IOPump`) + lazily submits sample loads to a worker pool and forwards batch + boundaries, while the main thread consumes handles (host-to-device + transfer + transforms on a preprocessing stream). + - **Iterable** (:class:`~physicsnemo.datapipes.protocols.IterableDatasetBase`): + a generator dataset driven main-thread-only (no sampler, no pump, no + worker pool). ``len()`` is undefined and ``shuffle``/``sampler`` are + ignored; generation runs on a preprocessing stream with the same + event handoff so it overlaps training. + + Concurrency model + ----------------- + A dedicated dispatcher thread keeps the I/O pipeline primed by + submitting sample loads ahead of consumption, bounded by + ``prefetch_factor`` batches worth of in-flight samples. The main + thread is the sole consumer: it performs all host-to-device transfers + and GPU transforms on the prefetch streams, so transforms run on the + assigned preprocessing stream and overlap the compute stream. + + For the pipeline to stay primed, the main thread must not block: keep + reader output (optionally) pinned (so host-to-device copies are asynchronous) and + avoid host readbacks (``.item()``), data-dependent shapes, and + GIL-bound pure-Python transforms on the launch path. + Examples -------- >>> from physicsnemo.datapipes import DataLoader, Dataset, HDF5Reader, Normalize @@ -80,7 +116,7 @@ class DataLoader: def __init__( self, - dataset: DatasetBase, + dataset: DatasetBase | IterableDatasetBase, *, batch_size: int = 1, shuffle: bool = False, @@ -104,9 +140,10 @@ def __init__( Parameters ---------- - dataset : DatasetBase - Dataset to load from. Any subclass of :class:`DatasetBase` - (e.g. :class:`Dataset`, :class:`MeshDataset`). + dataset : DatasetBase or IterableDatasetBase + Dataset to load from. A map-style :class:`DatasetBase` + (e.g. :class:`Dataset`, :class:`MeshDataset`) or an + :class:`IterableDatasetBase` generator dataset. batch_size : int, default=1 Number of samples per batch. shuffle : bool, default=False @@ -151,6 +188,9 @@ def __init__( self.num_streams = num_streams self.use_streams = use_streams and torch.cuda.is_available() self._seed = seed + # Iterable (generator) datasets are driven main-thread-only: no + # sampler, no worker-pool prefetch (see _iter_iterable). + self._iterable = isinstance(dataset, IterableDatasetBase) # Build master generator and fork for sampler + dataset sampler_generator: torch.Generator | None = None @@ -163,14 +203,18 @@ def __init__( if hasattr(dataset, "set_generator"): dataset.set_generator(forks[1]) - # Handle sampler - if sampler is not None: + # Handle sampler. Iterable datasets have no indices, so they carry + # no sampler and ignore shuffle. + if self._iterable: + if sampler is not None or shuffle: + warnings.warn( + "shuffle/sampler are ignored for iterable datasets; " + "the generator controls sample order.", + stacklevel=2, + ) + self.sampler = None + elif sampler is not None: self.sampler = sampler - # For DistributedSampler, propagate seed if available - if seed is not None and hasattr(sampler, "seed"): - # DistributedSampler exposes seed as a constructor arg - # but it's read-only; users should pass seed at construction. - pass elif shuffle: self.sampler = RandomSampler(dataset, generator=sampler_generator) else: @@ -179,7 +223,8 @@ def __init__( # Handle collation self.collate_fn = get_collator(collate_fn, collate_metadata=collate_metadata) - # Create CUDA streams for prefetching + # Create CUDA streams: prefetch uses several round-robin streams; the + # iterable path uses the first as its preprocessing stream. self._streams: list[torch.cuda.Stream] = [] if self.use_streams: for _ in range(num_streams): @@ -193,7 +238,15 @@ def __len__(self) -> int: ------- int Number of batches in the dataloader. + + Raises + ------ + TypeError + If the dataset is iterable (generator-style), which has no + defined length. """ + if self._iterable: + raise TypeError("len() is undefined for an iterable (generator) dataset") n_samples = ( len(self.sampler) if hasattr(self.sampler, "__len__") else len(self.dataset) ) @@ -226,8 +279,15 @@ def __iter__( """ Iterate over batches. - Uses stream-based prefetching when enabled to overlap IO, - GPU transfers, and computation. + Uses the self-priming :class:`IOPump` to overlap host-side I/O + (on the dataset's worker threads) with main-thread consumption + whenever ``prefetch_factor > 0``. This threaded producer path is + independent of CUDA streams: when streams are enabled (and CUDA + is available) each prefetched sample is also assigned a stream so + the host-to-device copy and GPU transforms overlap; otherwise the + same path runs with ``stream=None`` (still overlapping disk I/O + with the main thread). Set ``prefetch_factor=0`` for fully + synchronous iteration. Yields ------ @@ -236,7 +296,9 @@ def __iter__( or tuple of (batched TensorDict, list of metadata dicts) if collate_metadata=True. """ - if self.prefetch_factor > 0 and self.use_streams: + if self._iterable: + yield from self._iter_iterable() + elif self.prefetch_factor > 0: yield from self._iter_prefetch() else: yield from self._iter_simple() @@ -256,59 +318,270 @@ def _iter_simple( samples = [self.dataset[idx] for idx in batch_indices] yield self.collate_fn(samples) + def _work_stream(self) -> Iterator[Any]: + """Lazily yield sampler indices delimited by :data:`BATCH_BOUNDARY`. + + Buffers at most one batch of indices (never the whole epoch), so + arbitrarily long samplers stream without up-front materialization. + A boundary is emitted after each full batch; a trailing partial + batch is emitted only when ``drop_last`` is False. + + Yields + ------ + int or object + Sample indices, with :data:`BATCH_BOUNDARY` after each batch. + """ + batch: list[int] = [] + for index in self.sampler: + batch.append(index) + if len(batch) == self.batch_size: + yield from batch + yield BATCH_BOUNDARY + batch = [] + if batch and not self.drop_last: + yield from batch + yield BATCH_BOUNDARY + def _iter_prefetch( self, ) -> Iterator[TensorDict | tuple[TensorDict, list[dict[str, Any]]]]: """ - Iteration with stream-based prefetching. - - Strategy: - - 1. Prefetch `prefetch_factor` batches worth of samples - 2. As we yield batches, prefetch more to keep the pipeline full - 3. Each sample in a batch uses a different stream for overlap + Iteration driven by a self-priming prefetch pump with one-batch lookahead. + + A dedicated dispatcher thread (the :class:`IOPump`) lazily pulls + the index stream and submits sample loads to the dataset's worker + pool, keeping a bounded number of samples in flight regardless of + the consumer's cadence. The main thread is a pure drain loop: it + pulls ready handles in order, runs the per-sample consume step + (host-to-device transfer plus GPU transforms on the assigned + stream), and reassembles batches from the boundary markers the + pump forwards. + + Stream assignment is optional and decoupled from the threaded + producer: when CUDA streams are enabled a stream is round-robined + per sample (so preprocessing overlaps the previous batch's compute); + otherwise dispatch passes ``stream=None`` and the path still + overlaps host-side I/O with main-thread consumption. + + Because dispatch lives off the main thread, the pipeline stays + primed even while the main thread is blocked launching kernels or + running the model. All device-kernel launches happen here, on the + single main thread. + + One-batch lookahead for preprocessing stream overlap + ---------------------------------------------------- + Before yielding batch N, this method eagerly drains the pump for + batch N+1's items and calls ``consume(..., defer_sync=True)`` on + each. ``consume`` enqueues the host-to-device transfer and GPU + transforms on a preprocessing stream asynchronously and *records* a + CUDA event, but -- with ``defer_sync=True`` -- does **not** make the + compute stream wait on it yet. The wait is inserted here, by + :func:`gate_compute_stream`, immediately before the owning batch is + yielded. + + This ordering is the whole point of the lookahead. If ``_consume`` + made the compute stream wait on batch N+1's event during the + lookahead drain (i.e. before yielding batch N), that wait would be + enqueued on the compute stream *ahead* of batch N's model kernels, + so batch N's model would block on batch N+1's preprocessing -- the + opposite of overlap. By deferring the wait, the compute-stream order + becomes ``..., model_{N-1}, wait(prep_N), model_N, ...``: batch N's + preprocessing (already in flight on its own stream) overlaps batch + N-1's compute, and each model only blocks on its own batch's + preprocessing. + + With ``prefetch_factor >= 2`` the IOPump has already dispatched + batch N+1's disk-I/O ahead of time, so the ``future.result()`` + inside ``consume()`` returns immediately and the lookahead drain + adds negligible host-side latency. Yields ------ TensorDict or tuple[TensorDict, list[dict[str, Any]]] Collated batch. """ - # Collect all batches upfront for prefetch planning - all_batches = list(self._generate_batches()) - if not all_batches: - return - - num_prefetch_batches = min(self.prefetch_factor, len(all_batches)) - stream_idx = 0 - - # Start initial prefetch - prefetched_up_to = 0 - for batch_idx in range(num_prefetch_batches): - for sample_idx in all_batches[batch_idx]: - stream = self._streams[stream_idx % self.num_streams] - self.dataset.prefetch(sample_idx, stream=stream) - stream_idx += 1 - prefetched_up_to = batch_idx + 1 - - # Yield batches and prefetch more - for batch_idx, batch_indices in enumerate(all_batches): - # Collect samples (uses prefetched if available) - samples = [self.dataset[idx] for idx in batch_indices] - batch = self.collate_fn(samples) - - # Prefetch next batch if available - next_prefetch_idx = prefetched_up_to - if next_prefetch_idx < len(all_batches): - for sample_idx in all_batches[next_prefetch_idx]: - stream = self._streams[stream_idx % self.num_streams] - self.dataset.prefetch(sample_idx, stream=stream) - stream_idx += 1 - prefetched_up_to += 1 - - yield batch + # Streams are an optional accelerator on top of the threaded + # producer; only assign them when actually available. + use_streams = self.use_streams and len(self._streams) > 0 + + # The compute stream the training loop runs on. Captured once on the + # main thread; the deferred per-batch preprocessing waits are + # enqueued onto it right before each batch is yielded. + compute_stream = torch.cuda.current_stream() if use_streams else None + + # Round-robin a stream per sample at dispatch time (when enabled); + # submit returns a handle the consumer resolves in order. + stream_counter = itertools.count() + + def dispatch(index: int) -> Any: + stream = ( + self._streams[next(stream_counter) % self.num_streams] + if use_streams + else None + ) + return self.dataset.submit(index, stream=stream) + + # Depth = prefetch_factor batches worth of samples kept in flight + # (at least the stream count when streams drive the overlap). + depth = max(self.prefetch_factor * self.batch_size, 1) + if use_streams: + depth = max(depth, self.num_streams) + + pump = IOPump(self._work_stream(), dispatch, depth=depth) + pump_iter = iter(pump) + + def drain_one_batch() -> tuple[list[Any], bool]: + """Consume pump items up to the next BATCH_BOUNDARY. + + Each non-boundary item is passed through + :meth:`~physicsnemo.datapipes.protocols.DatasetBase.consume` + with ``defer_sync=True``, which enqueues the host-to-device copy + and GPU transforms on a preprocessing stream (asynchronously) + and records a CUDA event *without* making the compute stream wait + on it. The wait is inserted later by :func:`gate_compute_stream`. + + Returns + ------- + tuple[list, bool] + ``(samples, has_batch)`` where *has_batch* is ``True`` when + a ``BATCH_BOUNDARY`` was found and ``False`` when the pump + was exhausted without one (only possible for an empty or + fully-consumed source). + """ + samples: list[Any] = [] + for item in pump_iter: + if item is BATCH_BOUNDARY: + return samples, True + samples.append(self.dataset.consume(item, defer_sync=True)) + return samples, False + + def gate_compute_stream(events: list) -> None: + """Make the compute stream wait on a batch's preprocessing events. + + Called immediately before a batch is yielded, so the wait is + ordered *after* the previous batch's model kernels (already + enqueued by the prior iteration's ``yield``). The batch's + preprocessing -- launched on its side stream during the lookahead + drain -- thus overlaps the previous batch's compute, and the + model only blocks on its own batch's preprocessing. + + A no-op when streams are disabled (``events`` is empty). + """ + if compute_stream is not None: + for event in events: + compute_stream.wait_event(event) + + try: + # Prime: consume the first batch's items, enqueueing their + # preprocessing stream work (event recorded, compute-stream wait + # deferred to gate_compute_stream below). + current_samples, has_first = drain_one_batch() + # Per-batch preprocessing events whose compute-stream wait was + # deferred; gate_compute_stream consumes these right before yield. + current_events = self.dataset._pop_events() + if not has_first: + # Source was empty or had no boundary; yield whatever arrived. + if current_samples: + gate_compute_stream(current_events) + yield self.collate_fn(current_samples) + return + + while True: + # Eagerly drain the NEXT batch before yielding the current + # one. This enqueues batch N+1's H2D transfer and GPU + # transforms on the preprocessing streams so they run + # concurrently with the training loop's compute-stream work + # for batch N (forward / backward / optimizer). + next_samples, has_next = drain_one_batch() + next_events = self.dataset._pop_events() + + # Gate the compute stream on the CURRENT batch's preprocessing + # *now* -- after the previous iteration's yield enqueued the + # previous batch's model work, and after the NEXT batch's + # preprocessing was launched above on its own stream. This is + # what lets the next batch's preprocessing overlap this + # batch's compute instead of blocking it. + gate_compute_stream(current_events) + + # Yield the current batch. While the training loop runs, + # the preprocessing streams are already working on the next + # batch. + yield self.collate_fn(current_samples) + + if not has_next: + # Pump exhausted after the lookahead drain; yield the + # final partial batch (empty when drop_last trimmed it) + # then stop. + if next_samples: + gate_compute_stream(next_events) + yield self.collate_fn(next_samples) + break + + current_samples = next_samples + current_events = next_events + finally: + # Stop the dispatcher (handles early break / exhaustion) and + # drop any prefetched-but-unconsumed handles. + pump.stop() + self.dataset.cancel_prefetch() + + def _iter_iterable( + self, + ) -> Iterator[Any]: + """ + Main-thread-only iteration for generator (iterable) datasets. + + There is no worker pool: the dataset's generator runs on the main + thread, so it may freely launch device kernels / use streams. Each + item is generated on a preprocessing stream (when streams are + enabled) and handed to the compute stream via a CUDA event, so + generation of the next item can overlap training on the current + one. A generator that forces a host readback simply serializes + itself. + + Two emission modes are supported (see :class:`IterableDatasetBase`): + per-sample items are collated into ``batch_size`` batches (with + ``drop_last`` trimming the trailing partial batch); a self-batching + generator (``yields_batches = True``) has each batch passed through + unchanged. - # Clean up any remaining prefetch state - self.dataset.cancel_prefetch() + Yields + ------ + Any + Collated batches, or generator-produced batches when the + dataset is self-batching. + """ + use_stream = self.use_streams and len(self._streams) > 0 + prep_stream = self._streams[0] if use_stream else None + compute_stream = torch.cuda.current_stream() if use_stream else None + self_batching = getattr(self.dataset, "yields_batches", False) + + iterator = iter(self.dataset) + samples: list[Any] = [] + while True: + # Generate the next item on the preprocessing stream, then order + # it before the compute stream without blocking the host. + with preprocessing_stream(prep_stream): + try: + item = next(iterator) + except StopIteration: + break + if use_stream: + event = torch.cuda.Event() + event.record(prep_stream) + compute_stream.wait_event(event) + + if self_batching: + yield item + continue + + samples.append(item) + if len(samples) == self.batch_size: + yield self.collate_fn(samples) + samples = [] + + if not self_batching and samples and not self.drop_last: + yield self.collate_fn(samples) def set_epoch(self, epoch: int) -> None: """ diff --git a/physicsnemo/datapipes/datapipes.md b/physicsnemo/datapipes/datapipes.md index b41ca1d846..16533de858 100644 --- a/physicsnemo/datapipes/datapipes.md +++ b/physicsnemo/datapipes/datapipes.md @@ -44,20 +44,28 @@ Three dataset types share this pattern: | `MeshDataset` | `Mesh` / `DomainMesh` tensorclasses | `MeshTransform` | | `MultiDataset` | Union of child `DatasetBase` instances | Delegates to children | -All three inherit from `DatasetBase`, which provides thread-pool -prefetching and a `Future`-based cache (see -[Performance](#performance-threading-and-stream-based-concurrency) below). +`Dataset` and `MeshDataset` inherit from `DatasetBase`, which provides +thread-pool prefetching via a FIFO (First In / First Out) +`submit`/`consume` primitive driven by the `IOPump` (the index-keyed +`prefetch`/`__getitem__` cache is a thin random-access layer on top). +`MultiDataset` implements the same map-style surface by delegating +`submit`, `consume`, `prefetch`, `_pop_events`, and `close` to the child +`DatasetBase` instance that owns each sample; see +[Performance](#performance-threading-and-stream-based-concurrency) below. ## Composability ### Readers -A `Reader` is an ABC with a single contract: +A `Reader` is an ABC with one main loading hook plus a required length: ```python class Reader(ABC): @abstractmethod def _load_sample(self, index: int) -> dict[str, Tensor]: ... + + @abstractmethod + def __len__(self) -> int: ... ``` `__getitem__` wraps the result in a `TensorDict` on CPU (optionally @@ -79,7 +87,9 @@ multi-region consistency. ### Collators -Collators combine per-sample `(TensorDict, metadata)` tuples into batches: +Collators combine per-sample `(data, metadata)` tuples into batches, +where `data` is usually a `TensorDict`, `Mesh`, or `DomainMesh` depending +on the dataset and collator: | Collator | Strategy | |----------|----------| @@ -141,106 +151,185 @@ preprocessing. Threads are a natural fit: duplication overhead. - **I/O concurrency** -- the GIL is released during disk reads and CUDA kernel launches, so multiple threads usefully overlap I/O with GPU work. -- **Stream parallelism** -- each prefetched sample is assigned its own - CUDA stream, allowing host-to-device transfers and GPU transforms to - run concurrently with the main training computation. +- **Stream parallelism** -- when enabled, each prefetched sample is + assigned a CUDA stream so its host-to-device transfer can overlap with + the main training computation. -### Thread-pool prefetch - -`DatasetBase` owns a `ThreadPoolExecutor` (configurable via -`num_workers`, default 2). Calling `prefetch(index)` submits the -load-and-transform pipeline to the pool and stashes the `Future`: +### Producer / consumer split -```python -def prefetch(self, index, stream=None): - if index in self._prefetch_futures: - return - executor = self._ensure_executor() - self._prefetch_futures[index] = executor.submit(self._load, index) -``` +For the provided `Dataset` and `MeshDataset` implementations, prefetching +is split into two stages so that **no device kernels are launched off the +main thread** -- device kernels must share the model's single launching +thread: -`__getitem__` pops the `Future` if one exists, otherwise loads -synchronously: +- `_load_host` is the **producer**. It runs on a worker thread and does + only thread-safe work: reading, decoding, and staging into pinned host + memory. It returns a `HostPayload`. +- `_consume` is the **consumer**. It runs on whatever thread calls + `__getitem__` (the main thread, in practice) and performs the + host-to-device transfer and device transforms. -```python -def __getitem__(self, index): - future = self._prefetch_futures.pop(index, None) - if future is not None: - return future.result() - return self._load(index) -``` - -This means the DataLoader can keep the next batch loading in background -threads while the current batch is being consumed by the model. - -### CUDA stream overlap - -When GPU execution is available, `Dataset` (and `MeshDataset`) override -`prefetch` to run device transfer and transforms on a caller-supplied -CUDA stream, then record an event for later synchronization: +`DatasetBase` owns a `ThreadPoolExecutor` (configurable via +`num_workers`) and exposes a FIFO prefetch primitive. `submit(work_item, +stream=...)` runs `_load_host` on the pool and returns a `PrefetchHandle` +bundling the future with the stream the consumer should use; +`consume(handle)` resolves it on the calling thread. Subclasses that do +not override `_load_host` and `_consume` fall back to running their full +`_load` pipeline in the worker, so the main-thread launch guarantee +belongs to split implementations such as `Dataset` and `MeshDataset`: ```python -def _load_and_transform(self, index, stream=None): - result = _PrefetchResult(index=index) - data, metadata = self.reader[index] # CPU I/O in worker thread - - if stream is not None: - with torch.cuda.stream(stream): - data = data.to(device, non_blocking=True) # H2D on stream - data = self.transforms(data) # GPU transforms on stream - result.event = torch.cuda.Event() - result.event.record(stream) # mark completion - - result.data, result.metadata = data, metadata - return result +def submit(self, work_item, stream=None): + future = self._executor.submit(self._load_host, work_item) + return PrefetchHandle(future=future, stream=stream) + +def consume(self, handle, *, defer_sync=False): + payload = handle.future.result() # re-raises producer errors + # H2D + transforms here; defer_sync controls who gates the compute stream + return self._consume(payload, handle.stream, defer_sync=defer_sync) ``` -On retrieval, `__getitem__` synchronizes the event before returning: +Correlation is purely by handle identity (FIFO), so work items need not +be hashable, unique, or even integers -- an `int` index is just the +common case. The index-keyed `prefetch(index)` / `__getitem__(index)` +convenience API is a thin layer over `submit`/`consume` for random +access, and is what map-style tests and `MultiDataset` use. + +### Self-priming dispatch (IOPump) + +The threaded producer is driven by `IOPump`, a dedicated dispatcher +thread that keeps a *bounded* number of samples in flight regardless of +the consumer's cadence. It pulls a work-item stream **lazily** (one item +per free backpressure slot, so an arbitrarily long or unbounded source +never materializes up front), calls `submit` for each, and hands the +returned handles back to the main thread in FIFO order. The source +interleaves `BATCH_BOUNDARY` markers between work items; the pump forwards +them in place without consuming a slot, so the consumer reassembles +dynamically-sized batches from the boundaries -- the DataLoader never +builds the epoch's batch list in advance. Because dispatch lives off the +main thread, the pipeline stays primed even while the main thread is busy +launching kernels or running the model. For map-style datasets, this +path is active whenever `prefetch_factor > 0`; set `prefetch_factor=0` +to use synchronous map-style iteration. + +### CUDA stream handoff + +CUDA streams are an *optional* accelerator layered on top of the threaded +producer. When `use_streams=True` (and CUDA is available), each sample is +round-robined a **preprocessing stream**. The consumer runs *both* the +host-to-device copy and the transforms on that stream, then hands the +result to the compute stream via a CUDA **event** (never a host +`synchronize()`). Crucially, *who* enqueues the compute-stream wait +depends on the `defer_sync` flag (see +[One-batch lookahead](#one-batch-lookahead-deferred-sync) below): the +DataLoader defers it so the wait lands *after* the previous batch's model +work, while standalone callers wait inline: ```python -if result.event is not None: - result.event.synchronize() -return result.data, result.metadata +def _consume(self, payload, stream=None, *, defer_sync=False): + data = payload.data + if device is not None and stream is not None: + compute_stream = torch.cuda.current_stream() + # Bind torch to the preprocessing stream. + with preprocessing_stream(stream): # torch.cuda.stream + data = data.to(device, non_blocking=True) # H2D on prep stream + data = self.transforms(data) # transforms on SAME stream + event = torch.cuda.Event() + event.record(stream) + if defer_sync: + self._events_pending.append(event) # DataLoader gates later + else: + compute_stream.wait_event(event) # inline order, no host block + else: + data = self.transforms(data) + return data, payload.metadata ``` -The `DataLoader` owns a pool of `num_streams` CUDA streams (default 4) -and round-robins them across samples. It also maintains a sliding -prefetch window of `prefetch_factor` batches (default 2) ahead of the -current yield position: - -```python -# Prefetch the next batch as we yield the current one -for sample_idx in all_batches[next_prefetch_idx]: - stream = self._streams[stream_idx % self.num_streams] - self.dataset.prefetch(sample_idx, stream=stream) - stream_idx += 1 -``` +`preprocessing_stream` (in `protocols.py`) binds torch's current stream +to the preprocessing stream via `torch.cuda.stream(stream)`, so the +host-to-device copy and the transforms run on the side stream and GPU +preprocessing genuinely overlaps training. The pinned host source is +held by the caching host allocator until the copy completes. + +### One-batch lookahead (deferred sync) + +The compute-stream wait recorded above is only half the overlap story: +*when* it is enqueued decides whether preprocessing actually overlaps +training. The DataLoader's prefetch loop (`_iter_prefetch`) keeps a +**one-batch lookahead** so the wait lands at the right point: + +- `drain_one_batch()` consumes pump items up to the next `BATCH_BOUNDARY`, + calling `consume(item, defer_sync=True)` on each. With `defer_sync=True` + the consumer enqueues the H2D copy + transforms on the preprocessing + stream and *records* an event, but appends it to `_events_pending` + instead of making the compute stream wait. +- `_pop_events()` (on `DatasetBase`) hands those recorded events back to + the loop. +- Before yielding batch N, the loop eagerly drains batch **N+1**'s items + (launching their preprocessing on the side streams), then calls + `gate_compute_stream(events)` to issue `compute_stream.wait_event` for + batch N's events -- right before the `yield`. + +This ordering is the whole point. Because the wait for batch N is +enqueued *after* the previous iteration's `yield` already enqueued batch +N-1's model kernels, the compute-stream order becomes +`..., model_{N-1}, wait(prep_N), model_N, ...`. Batch N's preprocessing +(already in flight on its own stream) overlaps batch N-1's compute, and +each model only ever blocks on its own batch's preprocessing -- never on +the next batch's. If `_consume` instead waited inline during the +lookahead drain, that wait would be ordered *ahead* of batch N's model +kernels and the model would block on batch N+1's preprocessing -- the +opposite of overlap. Standalone callers (no DataLoader to insert the +gate) leave `defer_sync=False` and get the inline wait so their result is +immediately safe to use. ### Concurrency timeline -The diagram below shows how threads and streams overlap for a two-sample -batch with `prefetch_factor=1`: +With everything launched from the main thread, the worker pool, the +preprocessing stream, and the compute stream form a triple buffer: ```text -Main thread Worker 1 Worker 2 Stream 1 Stream 2 - │ │ │ │ │ - ├─prefetch(0,S1)─►│ │ │ │ - ├─prefetch(1,S2)─────────────────────►│ │ │ - │ ├─ Read (I/O) │ │ │ - │ │ ├─ Read (I/O) │ │ - │ ├─ to(device) ─────────────────────────►│ │ - │ ├─ transforms ─────────────────────────►│ │ - │ ├─ event.record() ─────────────────────►│ │ - │ │ ├─ to(device) ─────────────────►│ - │ │ ├─ transforms ─────────────────►│ - │ │ ├─ event.record() ─────────────►│ - ├─ event.synchronize() ×2 │ │ │ - ├─ collate + yield batch │ │ │ - │ │ │ │ │ +Worker pool │ load N+1 ─ load N ... (host I/O + thread-safe CPU work) +Preprocess stream │ H2D + transforms for N +Compute stream │ train N-1 ─ wait(prep_N) ─ train N ``` -While the main thread consumes batch N, worker threads are already -loading batch N+1 on different streams. +GPU preprocessing of batch N genuinely overlaps training of batch N-1 on +a separate stream; the two are ordered by a CUDA event, never a host-side +`synchronize`. The ordering is what the one-batch lookahead buys: the +compute stream's `wait(prep_N)` is enqueued by `gate_compute_stream` only +*after* batch N-1's model kernels and only *after* batch N's preprocessing +has been launched on its side stream (see +[One-batch lookahead](#one-batch-lookahead-deferred-sync)), so each model +blocks on its own batch's preprocessing and nothing else. A transform (or +generator) that forces a host readback simply serializes itself -- a +property of that code, not of the pipeline. + +### Two data paths: map/descriptor vs iterable + +The DataLoader selects one of two mutually-exclusive paths by dataset +type: + +- **Preload path (`DatasetBase`)** -- map-style and descriptor-keyed + datasets. Uses the worker pool + `IOPump` described above: workers do + thread-safe host I/O, the main thread consumes handles (H2D + transforms + on the preprocessing stream). This is the path for storage-backed data + addressable by index. +- **Generator path (`IterableDatasetBase`)** -- iterable datasets that + *produce* data (online simulation, procedural samplers, unbounded + streams). Driven **main-thread-only**: no sampler, no pump, no worker + pool. `__iter__` may freely launch device kernels and use CUDA streams, + and the loader still drives generation on a preprocessing stream with a + CUDA event handoff. This path does not use the map-style `IOPump`, + `_events_pending`, or one-batch deferred-sync loop. + +An iterable dataset yields either per-sample `(data, metadata)` (the +loader collates `batch_size` of them, `drop_last` trims the tail) or, when +`yields_batches = True`, ready-made batches that the loader passes through +unchanged. Iterable datasets have no length: `len(loader)` raises +`TypeError`, and `shuffle`/`sampler` are ignored. See +`examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py` for +a `Darcy2D` online simulation wired through this path. ### Pinned memory @@ -254,12 +343,15 @@ above is most effective when the reader pins its output. Prefetching can be toggled at runtime for debugging: ```python -loader.disable_prefetch() # synchronous, single-stream -- easy to debug -loader.enable_prefetch() # re-enable after debugging +loader.disable_prefetch() # drop CUDA streams (threaded pump still runs) +loader.enable_prefetch() # re-enable streams after debugging ``` -Setting `use_streams=False` or `prefetch_factor=0` at construction time -also forces synchronous execution. +`use_streams=False` keeps the threaded producer but drops the CUDA +stream handoff (the consumer copies and transforms on the default +stream); for map-style datasets, `prefetch_factor=0` forces fully +synchronous execution. Iterable datasets use their separate +main-thread-only generator path regardless of `prefetch_factor`. ## RNG and reproducibility @@ -275,11 +367,17 @@ documented in **[RNG.md](RNG.md)**. Mesh augmentations (`RandomScaleMesh`, `RandomTranslateMesh`, `RandomRotateMesh`) accept any `torch.distributions.Distribution` to -parametrize their random sampling. To preserve reproducibility with -seeded `torch.Generator` objects (which `Distribution.sample()` does not -accept), the augmentations use **inverse CDF sampling**: draw -`U ~ Uniform(0,1)` via `torch.rand(generator=g)`, then compute -`X = distribution.icdf(U)`. This gives exact samples from the target -distribution while keeping all randomness under generator control. +parametrize distribution-backed random sampling. To preserve +reproducibility with seeded `torch.Generator` objects (which +`Distribution.sample()` does not accept), `RandomScaleMesh`, +`RandomTranslateMesh`, and `RandomRotateMesh(mode="axis_aligned")` use +**inverse CDF sampling**: draw `U ~ Uniform(0,1)` via +`torch.rand(generator=g)`, then compute `X = distribution.icdf(U)`. This +gives exact samples from the target distribution while keeping randomness +under generator control for distributions that implement `icdf()`. +Distributions without `icdf()` fall back to `Distribution.sample()` with +a warning and are not generator-reproducible. `RandomRotateMesh` defaults +to `mode="uniform"`, which ignores `axes` and `distribution` and samples +uniform SO(3) rotations via random quaternions. Full usage examples, YAML configuration, and the supported-distribution table are in **[transforms/mesh/DISTRIBUTIONS.md](transforms/mesh/DISTRIBUTIONS.md)**. diff --git a/physicsnemo/datapipes/dataset.py b/physicsnemo/datapipes/dataset.py index 727503e800..5c66c46c5b 100644 --- a/physicsnemo/datapipes/dataset.py +++ b/physicsnemo/datapipes/dataset.py @@ -31,7 +31,11 @@ from tensordict import TensorDict from physicsnemo.datapipes._rng import fork_generator -from physicsnemo.datapipes.protocols import DatasetBase, _PrefetchResult +from physicsnemo.datapipes.protocols import ( + DatasetBase, + HostPayload, + preprocessing_stream, +) from physicsnemo.datapipes.readers.base import Reader from physicsnemo.datapipes.registry import register from physicsnemo.datapipes.transforms.base import Transform @@ -55,10 +59,19 @@ class Dataset(DatasetBase): Prefetching Model ----------------- - The dataset supports prefetching samples using a thread pool. - When a CUDA stream is provided, GPU operations (device transfer, - GPU transforms) happen on that stream, allowing overlap with - other computation. + The dataset supports prefetching samples using a thread pool. The + work is split into a thread-safe *producer* stage and a main-thread + *consumer* stage: + + - :meth:`_load_host` (producer, worker thread) reads the sample into + host memory (pinned when the reader is configured with + ``pin_memory=True``). It launches no device kernels. + - :meth:`_consume` (consumer, calling thread) performs the + host-to-device transfer and the GPU transforms on the assigned + CUDA stream. + + This keeps all device-kernel launches on the consuming thread, which + must be the same single thread the model launches from. >>> # Start prefetching >>> dataset.prefetch(0, stream=stream0) # doctest: +SKIP @@ -242,94 +255,69 @@ def set_epoch(self, epoch: int) -> None: t.set_epoch(epoch) # ------------------------------------------------------------------ - # Stream-aware prefetch (overrides DatasetBase defaults) + # Producer / consumer split (overrides DatasetBase defaults) # ------------------------------------------------------------------ - def _load_and_transform( - self, - index: int, - stream: Optional[torch.cuda.Stream] = None, - ) -> _PrefetchResult: + def _load_host(self, work_item: int) -> HostPayload: """ - Load a sample and apply transforms with optional CUDA stream. + Producer stage: read a sample into host memory. + + Runs on a worker thread and launches no device kernels. Pinning + is the reader's responsibility: construct the reader with + ``pin_memory=True`` to stage tensors in pinned memory so the + consumer's host-to-device copy can be asynchronous. Parameters ---------- - index : int - Sample index. - stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. + work_item : int + Sample index to read from the reader. Returns ------- - _PrefetchResult - PrefetchResult with data, metadata, or error. + HostPayload + Payload carrying the host data and metadata, or a captured + error. """ - result = _PrefetchResult(index=index) - try: - data, metadata = self.reader[index] - - if self.target_device is not None: - if stream is not None: - with torch.cuda.stream(stream): - data = data.to(self.target_device, non_blocking=True) - else: - data = data.to(self.target_device, non_blocking=True) - - if self.transforms is not None: - if stream is not None: - with torch.cuda.stream(stream): - data = self.transforms(data) - result.event = torch.cuda.Event() - result.event.record(stream) - else: - data = self.transforms(data) - - result.data = data - result.metadata = metadata - - except Exception as e: - result.error = e + data, metadata = self.reader[work_item] + return HostPayload(work_item=work_item, data=data, metadata=metadata) + except Exception as e: # noqa: BLE001 + return HostPayload(work_item=work_item, error=e) - return result - - def prefetch( + def _consume( self, - index: int, + payload: HostPayload, stream: Optional[torch.cuda.Stream] = None, - ) -> None: + *, + defer_sync: bool = False, + ) -> tuple[TensorDict, dict[str, Any]]: """ - Start prefetching a sample asynchronously. + Consumer stage: device transfer + transforms on the calling thread. - When a CUDA stream is provided, GPU operations (device transfer - and transforms) run on that stream for overlap with computation. + Runs on whatever thread calls this (the main thread, so any device + kernels in the transforms share the model's launching thread). When + a CUDA ``stream`` is assigned, the host-to-device copy *and* the + transforms run on that preprocessing stream via + :func:`preprocessing_stream` -- so this sample's preprocessing + overlaps the previous batch's training on the compute stream. A CUDA + event orders the preprocessing before the compute stream (never a + host-side synchronize). Parameters ---------- - index : int - Sample index to prefetch. + payload : HostPayload + Producer payload from :meth:`_load_host`. stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. - """ - if index in self._prefetch_futures: - return - - executor = self._ensure_executor() - future = executor.submit(self._load_and_transform, index, stream) - self._prefetch_futures[index] = future - - def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: - """ - Get a transformed sample by index. - - If the index was prefetched, returns the prefetched result - (waiting for completion if necessary). Otherwise loads synchronously. - - Parameters - ---------- - index : int - Sample index. + Preprocessing stream for the host-to-device transfer and + transforms. ``None`` runs on the current stream. + defer_sync : bool, default=False + When False, ``compute_stream.wait_event`` is enqueued here so the + result is immediately safe to use on the current stream. When + True, the recorded event is appended to :attr:`_events_pending` + and the DataLoader enqueues the wait just before the batch is + yielded -- after the previous batch's model work -- so this + batch's preprocessing overlaps the previous batch's compute + instead of blocking it. Returns ------- @@ -338,26 +326,43 @@ def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: Raises ------ - IndexError - If index is out of range. Exception - If prefetch failed, re-raises the error. + If the producer captured an error, re-raises it. """ - future = self._prefetch_futures.pop(index, None) + if payload.error is not None: + raise payload.error - if future is not None: - result = future.result() + data = payload.data + metadata = payload.metadata - if isinstance(result, _PrefetchResult): - if result.error is not None: - raise result.error - if result.event is not None: - torch.cuda.current_stream().wait_event(result.event) - return result.data, result.metadata + device_is_cuda = ( + self.target_device is not None + and torch.device(self.target_device).type == "cuda" + ) + use_stream = stream is not None and device_is_cuda + compute_stream = torch.cuda.current_stream() if use_stream else None - return result + with preprocessing_stream(stream if use_stream else None): + if self.target_device is not None: + data = data.to(self.target_device, non_blocking=True) + if self.transforms is not None: + data = self.transforms(data) + + if use_stream: + # Record an event marking the preprocessing's completion on the + # prep stream. + event = torch.cuda.Event() + event.record(stream) + if defer_sync: + # Defer the compute-stream wait to the DataLoader so it lands + # after the previous batch's model work (real overlap). + self._events_pending.append(event) + else: + # Inline ordering for standalone callers (no DataLoader to + # insert the wait at the right point). + compute_stream.wait_event(event) - return self._load(index) + return data, metadata @property def field_names(self) -> list[str]: diff --git a/physicsnemo/datapipes/io_pump.py b/physicsnemo/datapipes/io_pump.py new file mode 100644 index 0000000000..06e3cf33cb --- /dev/null +++ b/physicsnemo/datapipes/io_pump.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +IOPump - A self-driving I/O producer that keeps the IO pipeline primed. + +The pump owns a dedicated dispatcher thread that lazily pulls work items from a +source iterator and submits each for background loading, keeping +a *bounded* number of samples in flight at all times. Pulling lazily means +the source may be arbitrarily large (or effectively unbounded): the +dispatcher only advances it when a backpressure slot is free, so memory +stays bounded regardless of source length. + +Dispatch is decoupled from the consumer's cadence: the dispatcher keeps +topping the pipeline off while the main thread is busy launching kernels or +running the model, so a ready sample is (almost) always waiting when the +consumer asks for the next one. + +The pump is agnostic to *how* a sample is produced and consumed: it drives +opaque work items through a user-provided ``dispatch_fn`` (which starts the +background load and returns a handle) and hands those handles back to the +consumer in the exact order they were pulled. Correlation is purely +positional (FIFO), so work items need not be hashable or unique. + +A source may interleave :data:`BATCH_BOUNDARY` markers between work items; +the pump forwards them to the consumer in order without consuming a +backpressure slot, letting the consumer reassemble dynamically-sized +batches without knowing the batch layout up front. +""" + +from __future__ import annotations + +import queue +import threading +from typing import Any, Callable, Iterable, Iterator + +# Public marker a source yields to delimit the end of one batch. A distinct +# sentinel object so it can never collide with a real work item. +BATCH_BOUNDARY = object() + +# Internal marker the dispatcher pushes once the source is exhausted (or the +# pump is stopped), telling the consumer to finish iterating. +_DONE = object() + + +class _PumpError: + """Wraps an exception raised on the dispatcher thread. + + Forwarded through the ready queue so the consumer re-raises it on the + main thread instead of blocking forever waiting for items that will + never arrive. + """ + + __slots__ = ("exc",) + + def __init__(self, exc: BaseException) -> None: + self.exc = exc + + +class IOPump: + """Bounded, self-driving prefetch dispatcher. + + A dedicated dispatcher thread pulls work items from ``source``, + acquires a backpressure slot, calls ``dispatch_fn(work_item)`` to start + the background load, and makes the returned handle available to the + consumer in FIFO order via iteration. Slots are released as the + consumer advances, keeping at most ``depth`` samples in flight. + + Parameters + ---------- + source : Iterable + Work items to load, optionally interleaved with + :data:`BATCH_BOUNDARY` markers. Consumed lazily, one item at a + time, only as backpressure slots free up. + dispatch_fn : Callable[[Any], Any] + Called on the dispatcher thread to start loading a work item (for + example ``dataset.submit(work_item, stream=...)``). It must be + non-blocking and thread-safe and must not launch device kernels; + it returns an opaque handle that the consumer later turns into a + sample. + depth : int + Maximum number of samples dispatched but not yet consumed. Acts as + both the backpressure valve and the jitter buffer that hides + consumer stalls. Clamped to at least 1. + + Notes + ----- + A pump instance is single-consumer. Iterate it with a single thread + (the main/launcher thread). Call :meth:`stop` (or use it as a context + manager) to tear down the dispatcher thread; already-submitted loads + are left to complete and are reaped by the owning dataset. + """ + + def __init__( + self, + source: Iterable[Any], + dispatch_fn: Callable[[Any], Any], + depth: int, + ) -> None: + self._source = source + self._dispatch_fn = dispatch_fn + self._depth = max(1, int(depth)) + self._slots = threading.Semaphore(self._depth) + self._ready_queue: queue.Queue = queue.Queue() + self._stop = threading.Event() + self._thread = threading.Thread( + target=self._run, + name="datapipe_pump", + daemon=True, + ) + self._thread.start() + + # ------------------------------------------------------------------ + # Dispatcher thread + # ------------------------------------------------------------------ + + def _run(self) -> None: + """Dispatcher loop: keep ``depth`` samples in flight, in order.""" + source = iter(self._source) + while not self._stop.is_set(): + try: + item = next(source) + except StopIteration: + break + except BaseException as exc: # noqa: BLE001 + # A failing source must surface to the consumer, not hang it. + self._ready_queue.put(_PumpError(exc)) + return + + if item is BATCH_BOUNDARY: + # Boundaries are bookkeeping, not work: forward without + # consuming a slot. + self._ready_queue.put(BATCH_BOUNDARY) + continue + + # Backpressure: block until the consumer frees a slot. This is + # also where lazy pulling is enforced -- the source is not + # advanced again until there is room in flight. + self._slots.acquire() + if self._stop.is_set(): + break + try: + handle = self._dispatch_fn(item) + except BaseException as exc: # noqa: BLE001 + # A failing dispatch must surface to the consumer, not hang it. + self._ready_queue.put(_PumpError(exc)) + return + self._ready_queue.put(handle) + + self._ready_queue.put(_DONE) + + # ------------------------------------------------------------------ + # Consumer side (single consumer / the main thread) + # ------------------------------------------------------------------ + + def __iter__(self) -> Iterator[Any]: + """Yield ready handles (and batch boundaries) in FIFO order. + + Yields each loaded sample's handle in the order its work item was + pulled, and forwards :data:`BATCH_BOUNDARY` markers in place. + Releases a backpressure slot after each handle is consumed (i.e. + on the next iteration), so the dispatcher can refill the pipeline + as the consumer advances. Returns once the source is exhausted. + + Yields + ------ + object + Either a handle returned by ``dispatch_fn`` or + :data:`BATCH_BOUNDARY`. + """ + while True: + item = self._ready_queue.get() + if item is _DONE: + return + if isinstance(item, _PumpError): + raise item.exc + if item is BATCH_BOUNDARY: + yield BATCH_BOUNDARY + continue + yield item + # Consumer has finished with this sample; free a slot. + self._slots.release() + + # ------------------------------------------------------------------ + # Teardown + # ------------------------------------------------------------------ + + def stop(self) -> None: + """Stop the dispatcher thread and release its resources. + + Idempotent. Unblocks the dispatcher if it is waiting on a slot, + then joins it briefly. In-flight background loads already submitted + via ``dispatch_fn`` are not cancelled; the owning dataset reaps + them. + """ + if self._stop.is_set(): + return + self._stop.set() + # Unblock the dispatcher if it is parked acquiring a slot. + self._slots.release() + self._thread.join(timeout=5.0) + + def __enter__(self) -> "IOPump": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.stop() diff --git a/physicsnemo/datapipes/mesh_dataset.py b/physicsnemo/datapipes/mesh_dataset.py index 6ca740bb4b..03b9fae8b6 100644 --- a/physicsnemo/datapipes/mesh_dataset.py +++ b/physicsnemo/datapipes/mesh_dataset.py @@ -29,7 +29,11 @@ from tensordict import TensorDict from physicsnemo.datapipes._rng import fork_generator -from physicsnemo.datapipes.protocols import DatasetBase, _PrefetchResult +from physicsnemo.datapipes.protocols import ( + DatasetBase, + HostPayload, + preprocessing_stream, +) from physicsnemo.datapipes.readers.mesh import DomainMeshReader, MeshReader from physicsnemo.datapipes.registry import register from physicsnemo.datapipes.transforms.mesh.base import MeshTransform @@ -88,10 +92,10 @@ def __init__( device : str or torch.device, optional If set, move mesh data to this device after loading (before transforms). num_workers : int, default=1 - Number of worker threads for prefetching. Defaults to 1 - because mesh transforms construct new Mesh objects internally - and tensordict's ``_device_recorder`` is not safe for - concurrent TensorDict construction across threads. + Number of worker threads for the prefetch pool. Worker threads + run :meth:`_load_host` (disk read + pin_memory) concurrently; + GPU operations (H2D transfer, transforms) always run on the + main thread in :meth:`_consume`. """ super().__init__(num_workers=num_workers) self.reader = reader @@ -165,16 +169,21 @@ def _load( self, index: int ) -> tuple[Union[Mesh, DomainMesh, TensorDict], dict[str, Any]]: """Synchronous load: reader -> device transfer -> transforms.""" - data, metadata = self.reader[index] + with torch.profiler.record_function("MeshDataset._load: reader[index]"): + data, metadata = self.reader[index] if self._device is not None: - data = data.to(self._device) + with torch.profiler.record_function("MeshDataset._load: data.to(device)"): + data = data.to(self._device) for t in self.transforms: - if isinstance(data, DomainMesh): - data = t.apply_to_domain(data) - else: - data = t(data) + with torch.profiler.record_function( + f"MeshDataset._load: transform {type(t).__name__}" + ): + if isinstance(data, DomainMesh): + data = t.apply_to_domain(data) + else: + data = t(data) return data, metadata @@ -182,101 +191,66 @@ def __len__(self) -> int: return len(self.reader) # ------------------------------------------------------------------ - # Stream-aware prefetch (overrides DatasetBase defaults) + # Producer / consumer split (overrides DatasetBase defaults) # ------------------------------------------------------------------ - def _load_and_transform( - self, - index: int, - stream: Optional[torch.cuda.Stream] = None, - ) -> _PrefetchResult: - """Load a sample and apply transforms with optional CUDA stream. + def _load_host(self, work_item: int) -> HostPayload: + """Producer stage: read a mesh sample on a worker thread. + + Launches no device kernels: it only reads the raw sample. Device + transfer and mesh transforms happen later in :meth:`_consume` on + the consuming thread. Parameters ---------- - index : int - Sample index. - stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. + work_item : int + Sample index to read from the reader. Returns ------- - _PrefetchResult - Result with data, metadata, or error. + HostPayload + Payload carrying the host data and metadata, or a captured + error. """ - result = _PrefetchResult(index=index) - try: - data, metadata = self.reader[index] - - if self._device is not None: - if stream is not None: - with torch.cuda.stream(stream): - data = data.to(self._device, non_blocking=True) - else: - data = data.to(self._device, non_blocking=True) - - for t in self.transforms: - if stream is not None: - with torch.cuda.stream(stream): - if isinstance(data, DomainMesh): - data = t.apply_to_domain(data) - else: - data = t(data) - else: - if isinstance(data, DomainMesh): - data = t.apply_to_domain(data) - else: - data = t(data) - - if stream is not None: - result.event = torch.cuda.Event() - result.event.record(stream) - - result.data = data - result.metadata = metadata - - except Exception as e: - result.error = e - - return result + data, metadata = self.reader[work_item] + return HostPayload(work_item=work_item, data=data, metadata=metadata) + except Exception as e: # noqa: BLE001 + return HostPayload(work_item=work_item, error=e) - def prefetch( + def _consume( self, - index: int, + payload: HostPayload, stream: Optional[torch.cuda.Stream] = None, - ) -> None: - """Start prefetching a sample asynchronously. - - When a CUDA stream is provided, GPU operations (device transfer - and transforms) run on that stream for overlap with computation. - - Parameters - ---------- - index : int - Sample index to prefetch. - stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. - """ - if index in self._prefetch_futures: - return - - executor = self._ensure_executor() - future = executor.submit(self._load_and_transform, index, stream) - self._prefetch_futures[index] = future - - def __getitem__( - self, index: int + *, + defer_sync: bool = False, ) -> tuple[Union[Mesh, DomainMesh, TensorDict], dict[str, Any]]: - """Get a transformed sample by index. + """Consumer stage: device transfer + transforms on the calling thread. - If the index was prefetched, returns the prefetched result - (waiting for completion if necessary). Otherwise loads synchronously. + Runs on whatever thread calls this (the main thread, so any device + kernels in the transforms share the model's launching thread). When + a CUDA ``stream`` is assigned, the host-to-device copy *and* the + transforms run on that preprocessing stream via + :func:`preprocessing_stream` -- so this sample's preprocessing + overlaps the previous batch's training on the compute stream. A CUDA + event orders the preprocessing before the compute stream (not a + host-side synchronize). Parameters ---------- - index : int - Sample index. + payload : HostPayload + Producer payload from :meth:`_load_host`. + stream : torch.cuda.Stream, optional + Preprocessing stream for the host-to-device transfer and + transforms. ``None`` runs on the current stream. + defer_sync : bool, default=False + When False, ``compute_stream.wait_event`` is enqueued here so the + result is immediately safe to use on the current stream. When + True, the recorded event is appended to :attr:`_events_pending` + and the DataLoader enqueues the wait just before the batch is + yielded -- after the previous batch's model work -- so this + batch's preprocessing overlaps the previous batch's compute + instead of blocking it. Returns ------- @@ -286,23 +260,54 @@ def __getitem__( Raises ------ Exception - If prefetch failed, re-raises the error. + If the producer captured an error, re-raises it. """ - future = self._prefetch_futures.pop(index, None) + if payload.error is not None: + raise payload.error - if future is not None: - result = future.result() + data = payload.data + metadata = payload.metadata - if isinstance(result, _PrefetchResult): - if result.error is not None: - raise result.error - if result.event is not None: - torch.cuda.current_stream().wait_event(result.event) - return result.data, result.metadata + def _apply_transforms(d: Any) -> Any: + for t in self.transforms: + if isinstance(d, DomainMesh): + d = t.apply_to_domain(d) + else: + d = t(d) + return d - return result + device_is_cuda = ( + self._device is not None and torch.device(self._device).type == "cuda" + ) + use_stream = stream is not None and device_is_cuda + compute_stream = torch.cuda.current_stream() if use_stream else None - return self._load(index) + with preprocessing_stream(stream if use_stream else None): + if self._device is not None: + with torch.profiler.record_function( + "MeshDataset._consume: data.to(device)" + ): + data = data.to(self._device, non_blocking=True) + with torch.profiler.record_function( + "MeshDataset._consume: _apply_transforms" + ): + data = _apply_transforms(data) + + if use_stream: + # Record an event marking the preprocessing's completion on the + # prep stream. + event = torch.cuda.Event() + event.record(stream) + if defer_sync: + # Defer the compute-stream wait to the DataLoader so it lands + # after the previous batch's model work (real overlap). + self._events_pending.append(event) + else: + # Inline ordering for standalone callers (no DataLoader to + # insert the wait at the right point). + compute_stream.wait_event(event) + + return data, metadata def close(self) -> None: """Close the dataset and stop prefetching. diff --git a/physicsnemo/datapipes/multi_dataset.py b/physicsnemo/datapipes/multi_dataset.py index 2547bf5f95..6fc6a7470a 100644 --- a/physicsnemo/datapipes/multi_dataset.py +++ b/physicsnemo/datapipes/multi_dataset.py @@ -304,6 +304,76 @@ def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: metadata[DATASET_INDEX_METADATA_KEY] = ds_id return data, metadata + def submit(self, index: int, stream: Optional[Any] = None) -> tuple[int, Any]: + """ + Submit a global index for background loading (FIFO prefetch primitive). + + Maps the global index to its owning sub-dataset and delegates to + that dataset's :meth:`~DatasetBase.submit`. The returned handle is + wrapped with the owning dataset id so :meth:`consume` can restore + the ``dataset_index`` metadata. + + Parameters + ---------- + index : int + Global sample index to load. + stream : object, optional + CUDA stream for the consume step. + + Returns + ------- + tuple[int, PrefetchHandle] + ``(dataset_index, handle)`` to pass to :meth:`consume`. + """ + ds_id, local_i = self._index_to_dataset_and_local(index) + handle = self._datasets[ds_id].submit(local_i, stream=stream) + return ds_id, handle + + def consume( + self, handle: tuple[int, Any], *, defer_sync: bool = False + ) -> tuple[TensorDict, dict[str, Any]]: + """ + Resolve a :meth:`submit` handle into ``(data, metadata)``. + + Parameters + ---------- + handle : tuple[int, PrefetchHandle] + The ``(dataset_index, handle)`` returned by :meth:`submit`. + defer_sync : bool, default=False + Forwarded to the owning sub-dataset's + :meth:`~DatasetBase.consume`. When True, the sub-dataset records + its preprocessing event into its own ``_events_pending`` instead + of making the compute stream wait on it; :meth:`_pop_events` + collects those events for the DataLoader. + + Returns + ------- + tuple[TensorDict, dict[str, Any]] + Sample and metadata, enriched with ``dataset_index``. + """ + ds_id, inner = handle + data, metadata = self._datasets[ds_id].consume(inner, defer_sync=defer_sync) + metadata = dict(metadata) + metadata[DATASET_INDEX_METADATA_KEY] = ds_id + return data, metadata + + def _pop_events(self) -> list: + """Aggregate and clear pending preprocessing events across sub-datasets. + + Each sub-dataset records its deferred preprocessing CUDA events on + itself, so the DataLoader retrieves them through the + :class:`MultiDataset` by gathering from every constituent. + + Returns + ------- + list + CUDA events recorded by the sub-datasets since the last pop. + """ + collected: list = [] + for ds in self._datasets: + collected.extend(ds._pop_events()) + return collected + def prefetch( self, index: int, diff --git a/physicsnemo/datapipes/protocols.py b/physicsnemo/datapipes/protocols.py index 69d4363f3b..c810a21e46 100644 --- a/physicsnemo/datapipes/protocols.py +++ b/physicsnemo/datapipes/protocols.py @@ -15,16 +15,26 @@ # limitations under the License. """ -Base class for dataset components. +Base classes for dataset components. -Provides :class:`DatasetBase`, an ABC that owns the thread-based prefetch -infrastructure shared by :class:`Dataset`, :class:`MeshDataset`, and any -future dataset implementations. The user-facing extension points are -**Readers** and **Transforms**, not dataset subclasses. +Provides two abstractions consumed by :class:`~physicsnemo.datapipes.DataLoader`: + +- :class:`DatasetBase` -- map-style datasets. Owns the thread-based + prefetch infrastructure (a producer/consumer split plus a FIFO + ``submit``/``consume`` primitive) shared by :class:`Dataset`, + :class:`MeshDataset`, and any future implementation. +- :class:`IterableDatasetBase` -- generator-style datasets that produce + data directly on the main thread (online simulation and other + stream-sensitive workloads). No prefetch, no length, no indexing. + +The user-facing extension points are **Readers** and **Transforms**, not +dataset subclasses. """ from __future__ import annotations +import contextlib +import threading from abc import ABC, abstractmethod from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass, field @@ -33,61 +43,306 @@ import torch +@contextlib.contextmanager +def preprocessing_stream(stream: Optional["torch.cuda.Stream"]): + """Bind torch to *stream* for the host-to-device copy + transforms. + + Within the block torch's current stream is set to *stream*, so the + host-to-device copy and any device kernels launched by transforms run + on the same stream the data was copied on. A ``None`` stream is a + no-op (run on the current stream). + + Parameters + ---------- + stream : torch.cuda.Stream, optional + Stream to bind, or ``None`` to run on the current stream. + """ + if stream is None: + yield + return + with torch.cuda.stream(stream): + yield + + @dataclass -class _PrefetchResult: - """Result of a stream-aware prefetch operation. +class HostPayload: + """A sample produced by the (thread-safe) I/O stage, staged on the host. + + A ``HostPayload`` is the boundary object between the I/O producer and + the main-thread consumer. It carries a CPU ``TensorDict`` (ideally + pinned, so the subsequent host-to-device copy can be asynchronous) + plus metadata. It is produced by a worker thread, which must not + launch device kernels. - Used by :class:`Dataset` and :class:`MeshDataset` to carry data, - metadata, and an optional CUDA event through the prefetch pipeline. + Parameters + ---------- + work_item : Any + The work item this payload was produced from -- an ``int`` index + for map-style datasets, or any opaque descriptor for + descriptor-driven sources. + data : Any, optional + Host ``TensorDict`` (or mesh) payload. ``None`` on error. + metadata : dict, optional + Per-sample metadata produced by the reader. + error : Exception, optional + Exception captured during production, re-raised on consumption. """ - index: int + work_item: Any data: Any = None metadata: Optional[dict[str, Any]] = field(default=None) error: Optional[Exception] = field(default=None) - event: Optional[torch.cuda.Event] = field(default=None) + + +@dataclass +class PrefetchHandle: + """Handle to one in-flight prefetch returned by :meth:`DatasetBase.submit`. + + Bundles the producer ``Future`` with the CUDA stream the consumer + should use for the host-to-device copy and transforms. The handle is + correlated to its sample purely by identity/order, so opaque work + items need not be hashable or unique. + + Parameters + ---------- + future : concurrent.futures.Future + Future resolving to the producer's :class:`HostPayload`. + stream : torch.cuda.Stream, optional + Stream assigned for the consume step, or ``None`` for the current + stream. + """ + + future: Future + stream: Optional[torch.cuda.Stream] = None class DatasetBase(ABC): - """Abstract base for datasets compatible with :class:`DataLoader`. + """Abstract base for map-style datasets compatible with :class:`DataLoader`. + + Subclasses implement :meth:`_load` (the synchronous data-loading + pipeline) and :meth:`__len__`. Everything else -- indexing, the + ``submit``/``consume`` prefetch primitive, the index-keyed + ``prefetch``/``__getitem__`` convenience API, cancellation, and + cleanup -- is provided here. - Subclasses implement :meth:`_load` (the actual data-loading pipeline) - and :meth:`__len__`. Everything else — ``__getitem__`` with prefetch - cache lookup, thread-pool prefetching, cancellation, cleanup — is - provided here. + Producer / consumer split + -------------------------- + Prefetching is split into two stages so that **no device kernels are + launched off the main thread** (device kernels must share the model's + single launching thread): - Both :class:`Dataset` and :class:`MeshDataset` override - :meth:`prefetch` and :meth:`__getitem__` to add CUDA-stream - support via :class:`_PrefetchResult`. + - :meth:`_load_host` is the **producer**. It runs on a worker thread + and performs only thread-safe work: reading, decoding, and staging + into host memory. It returns a :class:`HostPayload`. + - :meth:`_consume` is the **consumer**. It runs on the thread that + calls :meth:`consume` / :meth:`__getitem__` (the main thread, in + practice) and performs the host-to-device transfer and device + transforms on the assigned CUDA stream. + + :class:`Dataset` and :class:`MeshDataset` override both hooks to + perform the real split. The default implementations fall back to + running the full :meth:`_load` on the worker for any subclass that + does not override them. """ - def __init__(self, *, num_workers: int = 2) -> None: - self._prefetch_futures: dict[int, Future] = {} + def __init__( + self, + *, + num_workers: int = 2, + ) -> None: self._executor: Optional[ThreadPoolExecutor] = None self._num_workers = num_workers + self._lock = threading.Lock() + # Futures still in flight, tracked so close() can drain them. + self._inflight: set[Future] = set() + # Index-keyed handles backing the prefetch()/__getitem__ compat API. + self._prefetch_handles: dict[int, PrefetchHandle] = {} + # Per-sample preprocessing CUDA events recorded by _consume when + # invoked with defer_sync=True. The compute-stream wait on these is + # deferred to the DataLoader, which inserts it right before the batch + # is yielded (i.e. after the previous batch's model work is enqueued) + # so preprocessing genuinely overlaps the prior batch's compute. + # Always accessed on the main thread so no locking is needed. + self._events_pending: list = [] + + def _pop_events(self) -> list: + """Return and clear the pending preprocessing-event list. + + Called by the DataLoader after each ``consume()`` (or batch of + consumes) made with ``defer_sync=True`` to retrieve the CUDA events + that ``_consume`` recorded but did not yet wait on. The DataLoader + inserts ``compute_stream.wait_event`` for each returned event right + before the corresponding batch is yielded. + + Returns + ------- + list + CUDA events recorded during the most recent deferred + ``consume()`` call(s). Empty when none were recorded (e.g. no + stream, CPU target, or ``defer_sync=False``). + """ + lst = self._events_pending + self._events_pending = [] + return lst @abstractmethod def _load(self, index: int) -> tuple[Any, dict[str, Any]]: """Load and return a single sample ``(data, metadata)``. - This is the hook that subclasses must implement. It is called - both synchronously (from ``__getitem__``) and asynchronously - (from the prefetch thread pool). + This is the synchronous, full-pipeline hook that subclasses must + implement. It is called directly from :meth:`__getitem__` when the + index was not prefetched. """ ... @abstractmethod def __len__(self) -> int: ... + # ------------------------------------------------------------------ + # Producer / consumer hooks (overridden by Dataset / MeshDataset) + # ------------------------------------------------------------------ + + def _load_host(self, work_item: Any) -> HostPayload: + """Producer stage: load *work_item* into a :class:`HostPayload`. + + Runs on a worker thread and must not launch device kernels. The + default implementation runs the full :meth:`_load` pipeline for + backward compatibility; subclasses that use device transforms + override this to stop at host staging. + + Parameters + ---------- + work_item : Any + Work item to load (an ``int`` index by default). + + Returns + ------- + HostPayload + Payload carrying the host data and metadata, or a captured + error. + """ + try: + data, metadata = self._load(work_item) + return HostPayload(work_item=work_item, data=data, metadata=metadata) + except Exception as e: # noqa: BLE001 + return HostPayload(work_item=work_item, error=e) + + def _consume( + self, + payload: HostPayload, + stream: Optional[torch.cuda.Stream] = None, + *, + defer_sync: bool = False, + ) -> tuple[Any, dict[str, Any]]: + """Consumer stage: turn a :class:`HostPayload` into ``(data, metadata)``. + + Runs on the calling (main) thread. The default implementation + unwraps the payload and re-raises any captured error; the + ``stream`` and ``defer_sync`` arguments are ignored. Subclasses + override this to perform the host-to-device transfer and device + transforms. + + Parameters + ---------- + payload : HostPayload + Producer payload from :meth:`_load_host`. + stream : torch.cuda.Stream, optional + Stream for the consume step. + defer_sync : bool, default=False + When True (and a stream is used), record the preprocessing CUDA + event into :attr:`_events_pending` instead of making the compute + stream wait on it here. The DataLoader then inserts the wait + just before the batch is yielded so preprocessing overlaps the + previous batch's compute. Ignored by this default implementation. + + Returns + ------- + tuple[Any, dict[str, Any]] + The sample data and its metadata. + """ + if payload.error is not None: + raise payload.error + return payload.data, payload.metadata + + # ------------------------------------------------------------------ + # FIFO prefetch primitive (used by the DataLoader's pump) + # ------------------------------------------------------------------ + + def submit( + self, + work_item: Any, + stream: Optional[torch.cuda.Stream] = None, + ) -> PrefetchHandle: + """Submit *work_item* for background loading and return its handle. + + Only the (thread-safe) producer stage runs on the worker pool. The + returned :class:`PrefetchHandle` is later passed to :meth:`consume` + on the main thread. Safe to call from a dispatcher thread distinct + from the consumer. + + Parameters + ---------- + work_item : Any + Work item to load. + stream : torch.cuda.Stream, optional + Stream the consumer should use for this sample. + + Returns + ------- + PrefetchHandle + Handle bundling the producer future and the assigned stream. + """ + executor = self._ensure_executor() + future = executor.submit(self._load_host, work_item) + with self._lock: + self._inflight.add(future) + future.add_done_callback(self._discard_inflight) + return PrefetchHandle(future=future, stream=stream) + + def consume( + self, handle: PrefetchHandle, *, defer_sync: bool = False + ) -> tuple[Any, dict[str, Any]]: + """Resolve a :meth:`submit` handle into ``(data, metadata)``. + + Blocks until the producer future is ready, then runs the consumer + stage on the calling thread (host-to-device transfer and device + transforms on the handle's stream). + + Parameters + ---------- + handle : PrefetchHandle + Handle returned by :meth:`submit`. + defer_sync : bool, default=False + Forwarded to :meth:`_consume`. When True, the compute-stream + wait on this sample's preprocessing event is deferred to the + caller (the DataLoader) via :attr:`_events_pending`. Standalone + callers leave this False so the wait is enqueued inline and the + result is immediately safe to use on the current stream. + + Returns + ------- + tuple[Any, dict[str, Any]] + The sample data and its metadata. + """ + payload = handle.future.result() # re-raises producer errors via _consume + return self._consume(payload, handle.stream, defer_sync=defer_sync) + # ------------------------------------------------------------------ # Concrete interface # ------------------------------------------------------------------ def __getitem__(self, index: int) -> tuple[Any, dict[str, Any]]: - """Return sample *index*, using a prefetched result when available.""" - future = self._prefetch_futures.pop(index, None) - if future is not None: - return future.result() # re-raises on error + """Return sample *index*, using a prefetched result when available. + + When the index was prefetched via :meth:`prefetch`, the pending + handle is consumed on the calling thread (so device transforms run + here, not on the worker). Otherwise the sample is loaded + synchronously. + """ + with self._lock: + handle = self._prefetch_handles.pop(index, None) + if handle is not None: + return self.consume(handle) return self._load(index) def prefetch( @@ -95,32 +350,47 @@ def prefetch( index: int, stream: Optional[torch.cuda.Stream] = None, ) -> None: - """Submit *index* for background loading in a worker thread. + """Start prefetching *index*, retrievable later via :meth:`__getitem__`. - The ``stream`` parameter is accepted for interface compatibility - but ignored by the default implementation. :class:`Dataset` - overrides this to run GPU transfers on the given stream. + Index-keyed convenience wrapper around :meth:`submit`. A repeated + prefetch of an in-flight index is a no-op. + + Parameters + ---------- + index : int + Sample index to prefetch. + stream : torch.cuda.Stream, optional + Stream the consumer should use for this sample. """ - if index in self._prefetch_futures: - return - executor = self._ensure_executor() - self._prefetch_futures[index] = executor.submit(self._load, index) + with self._lock: + if index in self._prefetch_handles: + return + handle = self.submit(index, stream) + with self._lock: + if index in self._prefetch_handles: + return # raced; the extra handle's future is reaped on completion + self._prefetch_handles[index] = handle def cancel_prefetch(self, index: Optional[int] = None) -> None: - """Discard prefetch results (already-running tasks still complete).""" - if index is None: - self._prefetch_futures.clear() - else: - self._prefetch_futures.pop(index, None) + """Discard prefetch handles (already-running tasks still complete).""" + with self._lock: + if index is None: + self._prefetch_handles.clear() + else: + self._prefetch_handles.pop(index, None) def close(self) -> None: """Drain in-flight prefetches and shut down the thread pool.""" - for future in self._prefetch_futures.values(): + with self._lock: + futures = list(self._inflight) + self._prefetch_handles.clear() + for future in futures: try: future.result(timeout=30.0) except Exception: # noqa: BLE001, S110 pass - self._prefetch_futures.clear() + with self._lock: + self._inflight.clear() if self._executor is not None: self._executor.shutdown(wait=True) @@ -130,13 +400,19 @@ def close(self) -> None: # Helpers # ------------------------------------------------------------------ + def _discard_inflight(self, future: Future) -> None: + """Done-callback: drop a finished future from the in-flight set.""" + with self._lock: + self._inflight.discard(future) + def _ensure_executor(self) -> ThreadPoolExecutor: - if self._executor is None: - self._executor = ThreadPoolExecutor( - max_workers=self._num_workers, - thread_name_prefix="datapipe_prefetch", - ) - return self._executor + with self._lock: + if self._executor is None: + self._executor = ThreadPoolExecutor( + max_workers=self._num_workers, + thread_name_prefix="datapipe_prefetch", + ) + return self._executor def __iter__(self) -> Iterator[tuple[Any, dict[str, Any]]]: for i in range(len(self)): @@ -147,3 +423,70 @@ def __enter__(self) -> "DatasetBase": def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() + + +class IterableDatasetBase(ABC): + """Abstract base for generator-style datasets driven on the main thread. + + Unlike :class:`DatasetBase`, an iterable dataset has no length and no + indexing: it produces data by iteration only. The + :class:`~physicsnemo.datapipes.DataLoader` drives it entirely on the + main thread (no worker pool), so :meth:`__iter__` may freely launch + device kernels and use CUDA streams -- the property that makes online + simulation safe here but unsafe on the worker-pool preload path. + + Emission modes + -------------- + - **Per-sample** (default, ``yields_batches = False``): :meth:`__iter__` + yields ``(data, metadata)`` for one sample at a time and the loader + collates ``batch_size`` of them. + - **Self-batching** (``yields_batches = True``): :meth:`__iter__` yields + a fully-formed batch already and the loader passes it through + unchanged (``batch_size``/``drop_last`` do not apply). + + Subclasses may optionally implement :meth:`set_epoch` and + :meth:`set_generator` for reproducible seeding. + """ + + # When True, __iter__ yields ready-made batches and the loader does not + # re-collate (e.g. an online simulator that produces a batch per step). + yields_batches: bool = False + + @abstractmethod + def __iter__(self) -> Iterator[Any]: + """Yield samples ``(data, metadata)`` or ready-made batches. + + Returns + ------- + Iterator + Per-sample ``(data, metadata)`` tuples, or full batches when + :attr:`yields_batches` is True. + """ + ... + + def set_epoch(self, epoch: int) -> None: + """Reseed for *epoch* (no-op by default). + + Parameters + ---------- + epoch : int + Current epoch number. + """ + + def set_generator(self, generator: torch.Generator) -> None: + """Seed the dataset's randomness from *generator* (no-op by default). + + Parameters + ---------- + generator : torch.Generator + Parent generator supplied by the DataLoader. + """ + + def __enter__(self) -> "IterableDatasetBase": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def close(self) -> None: + """Release any resources held by the dataset (no-op by default).""" diff --git a/physicsnemo/datapipes/readers/base.py b/physicsnemo/datapipes/readers/base.py index 17ce9b0b86..9ec8f2abe3 100644 --- a/physicsnemo/datapipes/readers/base.py +++ b/physicsnemo/datapipes/readers/base.py @@ -31,6 +31,8 @@ import torch from tensordict import TensorDict +from physicsnemo.datapipes._rng import spawn_generator + logger = logging.getLogger(__name__) @@ -110,6 +112,11 @@ def __init__( self.pin_memory = pin_memory self.include_index_in_metadata = include_index_in_metadata self._coordinated_subsampling_config = coordinated_subsampling + # Base seed + epoch for deterministic per-index RNG. See + # :meth:`_index_generator`. ``None`` means no seed was provided + # (random draws fall back to the global default RNG). + self._seed_base: int | None = None + self._epoch: int = 0 @abstractmethod def _load_sample(self, index: int) -> dict[str, torch.Tensor]: @@ -279,28 +286,61 @@ def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: raise RuntimeError(error_msg) from e def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible random sampling. + """Assign a base seed for reproducible, order-independent sampling. + + Stores ``generator.initial_seed()`` as the base seed. Per-sample + generators are then derived deterministically from + ``(base_seed, epoch, index)`` via :meth:`_index_generator`, so + random draws are reproducible regardless of the order in which + samples are read or which worker thread reads them. - Override in subclasses that use randomness (e.g. subsampling). - The default implementation is a no-op. + Subclasses that use randomness should draw from + :meth:`_index_generator` rather than a shared generator. Readers + with no randomness inherit this harmlessly. Parameters ---------- generator : torch.Generator - Generator to use for random draws. + Generator whose ``initial_seed()`` seeds all per-sample RNG. """ + self._seed_base = generator.initial_seed() def set_epoch(self, epoch: int) -> None: - """Reseed the reader's RNG for a new epoch. + """Set the epoch used to vary per-sample RNG deterministically. - Override in subclasses that use randomness. - The default implementation is a no-op. + The epoch is folded into each sample's derived seed (see + :meth:`_index_generator`), so each epoch produces a different but + reproducible sequence. Parameters ---------- epoch : int Current epoch number. """ + self._epoch = epoch + + def _index_generator(self, index: int) -> torch.Generator | None: + """Return a fresh generator seeded for sample *index* this epoch. + + Derives an independent :class:`torch.Generator` from + ``(base_seed, epoch, index)``. Because the seed depends only on + those values, the draw for a given sample is identical regardless + of read order or worker thread. Returns ``None`` when no base + seed has been set (the unseeded fallback). + + Parameters + ---------- + index : int + Sample index. + + Returns + ------- + torch.Generator or None + A per-sample generator, or ``None`` if no seed was provided. + """ + if self._seed_base is None: + return None + return spawn_generator(self._seed_base, self._epoch, index) def close(self) -> None: """ diff --git a/physicsnemo/datapipes/readers/mesh.py b/physicsnemo/datapipes/readers/mesh.py index 603dcf13bd..ff67b09fbc 100644 --- a/physicsnemo/datapipes/readers/mesh.py +++ b/physicsnemo/datapipes/readers/mesh.py @@ -30,6 +30,7 @@ import torch +from physicsnemo.datapipes._rng import spawn_generator from physicsnemo.datapipes.registry import register from physicsnemo.mesh import DomainMesh, Mesh @@ -55,7 +56,8 @@ def _contiguous_block_slice( """ if total <= k: return slice(0, total) - start = torch.randint(0, total - k + 1, (1,), generator=generator).item() + with torch.profiler.record_function("mesh_reader: randint.item() scalar readback"): + start = torch.randint(0, total - k + 1, (1,), generator=generator).item() return slice(start, start + k) @@ -188,7 +190,10 @@ def __init__( self.include_index_in_metadata = include_index_in_metadata self.subsample_n_points = subsample_n_points self.subsample_n_cells = subsample_n_cells - self._subsample_generator: torch.Generator | None = None + # Base seed + epoch for deterministic per-index RNG (see + # :meth:`set_generator`). ``None`` means unseeded. + self._seed_base: int | None = None + self._epoch: int = 0 if not self._root.exists(): raise FileNotFoundError(f"Path not found: {self._root}") @@ -212,38 +217,43 @@ def __len__(self) -> int: return len(self._paths) def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling. + """Assign a base seed for reproducible, order-independent subsampling. Called by :class:`MeshDataset` when the DataLoader provides a - seed. Replaces any previously assigned generator. + seed. Stores ``generator.initial_seed()`` as the base seed; each + sample then derives its own generator from + ``(base_seed, epoch, index)``, so subsampling is reproducible + regardless of read order or worker thread. Parameters ---------- generator : torch.Generator - Generator to use for contiguous block selection. + Generator whose ``initial_seed()`` seeds all per-sample RNG. """ - self._subsample_generator = generator + self._seed_base = generator.initial_seed() def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch. + """Set the epoch used to vary per-sample RNG deterministically. - Produces a different (but deterministic) sequence of contiguous - blocks each epoch when a generator has been assigned via - :meth:`set_generator`. + The epoch is folded into each sample's derived seed, producing a + different (but deterministic) sequence of contiguous blocks each + epoch when a base seed has been assigned via :meth:`set_generator`. """ - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) + self._epoch = epoch def __getitem__(self, index: int) -> tuple[Mesh, dict[str, Any]]: mesh = self._load_sample(index) + generator = ( + None + if self._seed_base is None + else spawn_generator(self._seed_base, self._epoch, index) + ) mesh = _subsample_mesh( mesh, self.subsample_n_cells, self.subsample_n_points, - generator=self._subsample_generator, + generator=generator, ) if self.pin_memory: @@ -286,6 +296,8 @@ def __init__( subsample_n_points: int | None = None, subsample_n_cells: int | None = None, extra_boundaries: dict[str, dict] | None = None, + drop_interior_cells: bool = False, + drop_in_file_boundaries: bool = False, ) -> None: """ Initialize the domain mesh reader. @@ -332,14 +344,41 @@ def __init__( extra_boundaries: stl_geometry: pattern: "*_single_solid.stl.pmsh" + drop_interior_cells : bool, default=False + If True, discard the interior mesh's cell connectivity (and + cell_data) immediately after load, turning it into a point + cloud. This makes ``subsample_n_points`` take the cheap + contiguous-block path instead of the expensive + ``slice_points`` remap (which allocates an ``n_points`` map + and scatter-reads the full cell array from the memmap). Use + for point-based models that consume only ``interior.points`` + and ``interior.point_data`` (e.g. GeoTransolver volume) and + never the interior tet/cell topology. Boundaries are + unaffected, so surface normals etc. still work. + drop_in_file_boundaries : bool, default=False + If True, discard the boundaries stored *in* the DomainMesh + file immediately after load (before subsampling and pinning). + ``extra_boundaries`` are added afterwards and are therefore + unaffected. Use when the model consumes only the interior + (plus any ``extra_boundaries``) and never the in-file + boundaries -- e.g. a volume pipeline whose SDF comes from an + injected STL, where the in-file car-surface boundary would + otherwise be subsampled (an expensive ``slice_points`` remap, + GIL-held, that blocks worker-thread overlap) and pinned every + sample for nothing. """ self._root = Path(path) self._pattern = pattern self.pin_memory = pin_memory self.include_index_in_metadata = include_index_in_metadata + self.drop_interior_cells = drop_interior_cells + self.drop_in_file_boundaries = drop_in_file_boundaries self.subsample_n_points = subsample_n_points self.subsample_n_cells = subsample_n_cells - self._subsample_generator: torch.Generator | None = None + # Base seed + epoch for deterministic per-index RNG (see + # :meth:`set_generator`). ``None`` means unseeded. + self._seed_base: int | None = None + self._epoch: int = 0 self._extra_boundaries = extra_boundaries or {} if not self._root.exists(): @@ -359,45 +398,77 @@ def __len__(self) -> int: return len(self._paths) def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling. + """Assign a base seed for reproducible, order-independent subsampling. Called by :class:`MeshDataset` when the DataLoader provides a - seed. Replaces any previously assigned generator. + seed. Stores ``generator.initial_seed()`` as the base seed; each + sample then derives its own generator from + ``(base_seed, epoch, index)``, so subsampling is reproducible + regardless of read order or worker thread. Parameters ---------- generator : torch.Generator - Generator to use for contiguous block selection. + Generator whose ``initial_seed()`` seeds all per-sample RNG. """ - self._subsample_generator = generator + self._seed_base = generator.initial_seed() def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch. + """Set the epoch used to vary per-sample RNG deterministically. - Produces a different (but deterministic) sequence of contiguous - blocks each epoch when a generator has been assigned via - :meth:`set_generator`. + The epoch is folded into each sample's derived seed, producing a + different (but deterministic) sequence of contiguous blocks each + epoch when a base seed has been assigned via :meth:`set_generator`. """ - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) + self._epoch = epoch def __getitem__(self, index: int) -> tuple[DomainMesh, dict[str, Any]]: dm = self._load_sample(index) + # Trim unused data before subsample/pin. Both references are lazy (no + # memmap materialization here): + # - drop_interior_cells: turn the interior into a point cloud so its + # point subsample takes the cheap contiguous-block path instead of a + # full slice_points remap + scattered reads. + # - drop_in_file_boundaries: skip the in-file boundaries entirely so we + # don't subsample (an expensive, GIL-held slice_points remap that + # starves worker-thread overlap) or pin a surface the model ignores. + if (self.drop_interior_cells and dm.interior.n_cells > 0) or ( + self.drop_in_file_boundaries and len(dm.boundary_names) > 0 + ): + interior = dm.interior + if self.drop_interior_cells and interior.n_cells > 0: + interior = Mesh( + points=interior.points, + point_data=interior.point_data, + global_data=interior.global_data, + ) + boundaries = {} if self.drop_in_file_boundaries else dm.boundaries + dm = DomainMesh( + interior=interior, + boundaries=boundaries, + global_data=dm.global_data, + ) + if self.subsample_n_cells is not None or self.subsample_n_points is not None: + generator = ( + None + if self._seed_base is None + else spawn_generator(self._seed_base, self._epoch, index) + ) sub_kw = dict( n_cells=self.subsample_n_cells, n_points=self.subsample_n_points, - generator=self._subsample_generator, + generator=generator, ) + interior = _subsample_mesh(dm.interior, **sub_kw) + boundaries = { + name: _subsample_mesh(dm.boundaries[name], **sub_kw) + for name in dm.boundary_names + } dm = DomainMesh( - interior=_subsample_mesh(dm.interior, **sub_kw), - boundaries={ - name: _subsample_mesh(dm.boundaries[name], **sub_kw) - for name in dm.boundary_names - }, + interior=interior, + boundaries=boundaries, global_data=dm.global_data, ) diff --git a/physicsnemo/datapipes/readers/numpy.py b/physicsnemo/datapipes/readers/numpy.py index 2f546d5916..a82e347a7e 100644 --- a/physicsnemo/datapipes/readers/numpy.py +++ b/physicsnemo/datapipes/readers/numpy.py @@ -112,7 +112,6 @@ def __init__( self.default_values = default_values or {} self.file_pattern = file_pattern self.index_key = index_key - self._subsample_generator: torch.Generator | None = None if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") @@ -168,22 +167,12 @@ def fields(self) -> list[str]: return self._user_fields return self._available_fields - def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling.""" - self._subsample_generator = generator - - def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch.""" - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) - def _select_random_sections_from_slice( self, slice_start: int, slice_stop: int, n_points: int, + generator: Optional[torch.Generator] = None, ) -> slice: """ Select a random contiguous slice from a range. @@ -196,6 +185,9 @@ def _select_random_sections_from_slice( Stop index of the available range (exclusive). n_points : int Number of points to sample. + generator : torch.Generator, optional + Per-sample generator for reproducible, order-independent + selection. ``None`` uses the global default RNG. Returns ------- @@ -219,7 +211,7 @@ def _select_random_sections_from_slice( slice_start, slice_stop - n_points + 1, (1,), - generator=self._subsample_generator, + generator=generator, ).item() return slice(start, start + n_points) @@ -228,6 +220,7 @@ def _load_from_npz( npz: np.lib.npyio.NpzFile, index: Optional[int] = None, file_path: Optional[Path] = None, + generator: Optional[torch.Generator] = None, ) -> dict[str, torch.Tensor]: """ Load data from an npz file. @@ -241,6 +234,8 @@ def _load_from_npz( None for directory mode (load entire arrays). file_path : Path, optional Path to the file (for error messages). + generator : torch.Generator, optional + Per-sample generator for reproducible coordinated subsampling. Returns ------- @@ -272,7 +267,7 @@ def _load_from_npz( if field in npz.files: array_shape = npz[field].shape[0] subsample_slice = self._select_random_sections_from_slice( - 0, array_shape, n_points + 0, array_shape, n_points, generator=generator ) break @@ -301,12 +296,17 @@ def _load_from_npz( def _load_sample(self, index: int) -> dict[str, torch.Tensor]: """Load a single sample.""" + # Derive a per-sample generator so coordinated subsampling is + # reproducible regardless of read order or worker thread. + generator = self._index_generator(index) if self._mode == "directory": file_path = self._files[index] with np.load(file_path) as npz: - return self._load_from_npz(npz, index=None, file_path=file_path) + return self._load_from_npz( + npz, index=None, file_path=file_path, generator=generator + ) else: # single - return self._load_from_npz(self._data, index=index) + return self._load_from_npz(self._data, index=index, generator=generator) def __len__(self) -> int: """Return number of samples.""" diff --git a/physicsnemo/datapipes/readers/tensorstore_zarr.py b/physicsnemo/datapipes/readers/tensorstore_zarr.py index 9bc407aea5..5ce3e6021a 100644 --- a/physicsnemo/datapipes/readers/tensorstore_zarr.py +++ b/physicsnemo/datapipes/readers/tensorstore_zarr.py @@ -155,7 +155,6 @@ def __init__( self._user_fields = fields self.default_values = default_values or {} self.group_pattern = group_pattern - self._subsample_generator: torch.Generator | None = None if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") @@ -235,17 +234,6 @@ def fields(self) -> list[str]: return self._user_fields return self._available_fields - def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling.""" - self._subsample_generator = generator - - def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch.""" - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) - def _read_attributes(self, group_path: Path) -> dict[str, Any]: """Read attributes from a Zarr group (v2 or v3).""" store_spec = {"driver": "file", "path": str(group_path)} @@ -274,6 +262,7 @@ def _select_random_sections_from_slice( slice_start: int, slice_stop: int, n_points: int, + generator: Optional[torch.Generator] = None, ) -> slice: """ Select a random contiguous slice from a range. @@ -286,6 +275,9 @@ def _select_random_sections_from_slice( Stop index of the available range (exclusive). n_points : int Number of points to sample. + generator : torch.Generator, optional + Per-sample generator for reproducible, order-independent + selection. ``None`` uses the global default RNG. Returns ------- @@ -309,13 +301,15 @@ def _select_random_sections_from_slice( slice_start, slice_stop - n_points + 1, (1,), - generator=self._subsample_generator, + generator=generator, ).item() return slice(start, start + n_points) def _load_sample(self, index: int) -> dict[str, torch.Tensor]: """Load a single sample from a Zarr group using TensorStore.""" group_path = self._groups[index] + # Per-sample generator: reproducible regardless of read order/thread. + generator = self._index_generator(index) # Read attributes (stored as tensors in sample) attributes = self._read_attributes(group_path) @@ -368,7 +362,7 @@ def _load_sample(self, index: int) -> dict[str, torch.Tensor]: if key in stores: array_shape = stores[key].shape[0] subsample_slice = self._select_random_sections_from_slice( - 0, array_shape, n_points + 0, array_shape, n_points, generator=generator ) break diff --git a/physicsnemo/datapipes/readers/zarr.py b/physicsnemo/datapipes/readers/zarr.py index a27cee5447..2d2a70cbd2 100644 --- a/physicsnemo/datapipes/readers/zarr.py +++ b/physicsnemo/datapipes/readers/zarr.py @@ -144,7 +144,6 @@ def __init__( self.group_pattern = group_pattern self._cache_stores = cache_stores self._cached_stores: dict[Path, Any] = {} # Cache for opened zarr stores - self._subsample_generator: torch.Generator | None = None if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") @@ -206,17 +205,6 @@ def fields(self) -> list[str]: return self._user_fields return self._available_fields - def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling.""" - self._subsample_generator = generator - - def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch.""" - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) - def _open_zarr_store(self, path: Path) -> Any: """ Open a zarr store, using cache if enabled. @@ -255,6 +243,7 @@ def _select_random_sections_from_slice( slice_start: int, slice_stop: int, n_points: int, + generator: Optional[torch.Generator] = None, ) -> slice: """ Select a random contiguous slice from a range. @@ -267,6 +256,9 @@ def _select_random_sections_from_slice( Stop index of the available range (exclusive). n_points : int Number of points to sample. + generator : torch.Generator, optional + Per-sample generator for reproducible, order-independent + selection. ``None`` uses the global default RNG. Returns ------- @@ -290,12 +282,14 @@ def _select_random_sections_from_slice( slice_start, slice_stop - n_points + 1, (1,), - generator=self._subsample_generator, + generator=generator, ).item() return slice(start, start + n_points) def _load_sample(self, index: int) -> dict[str, torch.Tensor]: """Load a single sample from a Zarr group.""" + # Per-sample generator: reproducible regardless of read order/thread. + generator = self._index_generator(index) if self._single_group_mode: # Single group: index into first dimension of each array group_path = self._groups[0] @@ -339,7 +333,7 @@ def _load_sample(self, index: int) -> dict[str, torch.Tensor]: else: array_shape = root[field].shape[0] subsample_slice = self._select_random_sections_from_slice( - 0, array_shape, n_points + 0, array_shape, n_points, generator=generator ) break diff --git a/physicsnemo/datapipes/transforms/mesh/transforms.py b/physicsnemo/datapipes/transforms/mesh/transforms.py index e64ff609e2..8dc182c64a 100644 --- a/physicsnemo/datapipes/transforms/mesh/transforms.py +++ b/physicsnemo/datapipes/transforms/mesh/transforms.py @@ -533,6 +533,31 @@ def __init__( else: raise ValueError("Provide one of 'stats_file' or 'fields'") + def to(self, device: torch.device | str) -> "NormalizeMeshFields": + """Move internal tensors and the nested field statistics to *device*. + + Extends :meth:`MeshTransform.to` by also moving the per-field + ``mean`` and ``std`` tensors in ``self._stats`` so the per-sample + ``.to()`` in :meth:`__call__` is a no-op. + + Parameters + ---------- + device : torch.device or str + Target device. + + Returns + ------- + NormalizeMeshFields + ``self``, for chaining. + """ + # Base .to() moves only bare tensor attrs; move the nested stats too + # so the per-sample .to() in __call__ is a no-op (no H2D copy/sync). + super().to(device) + for s in self._stats.values(): + s["mean"] = s["mean"].to(self._device) + s["std"] = s["std"].to(self._device) + return self + def __call__(self, mesh: Mesh) -> Mesh: ### Clone and z-score the targeted association's TensorDict in ### place; fields absent from `_stats` (or absent from the mesh) @@ -589,10 +614,9 @@ def inverse_tensor( dim = 1 if ftype == "scalar" else n_spatial_dims if name in self._stats: stats = self._stats[name] - mean = stats["mean"].to(dtype=tensor.dtype, device=tensor.device) - std = stats["std"].to(dtype=tensor.dtype, device=tensor.device) out[..., idx : idx + dim] = ( - out[..., idx : idx + dim] * (std + self._eps) + mean + out[..., idx : idx + dim] * (stats["std"] + self._eps) + + stats["mean"] ) idx += dim return out @@ -629,9 +653,7 @@ def _inverse_field(name: str, val: torch.Tensor) -> torch.Tensor: stats = self._stats.get(name) if stats is None: return val - mean = stats["mean"].to(dtype=val.dtype, device=val.device) - std = stats["std"].to(dtype=val.dtype, device=val.device) - return val * (std + self._eps) + mean + return val * (stats["std"] + self._eps) + stats["mean"] ### ``named_apply`` is typed ``TensorDict | None`` for its ### in-place mode; the out-of-place path always returns a TD. diff --git a/physicsnemo/mesh/transformations/geometric.py b/physicsnemo/mesh/transformations/geometric.py index 2b7bddb80f..c824329f05 100644 --- a/physicsnemo/mesh/transformations/geometric.py +++ b/physicsnemo/mesh/transformations/geometric.py @@ -502,10 +502,11 @@ def transform( if matrix.shape[0] == matrix.shape[1]: det = matrix.det() + ### The runtime det test syncs (host readback of a cuda tensor). if assume_invertible is not None: is_invertible = assume_invertible else: - is_invertible = det.abs() > 1e-10 + is_invertible = bool(det.abs() > 1e-10) if is_invertible: det_sign = det.sign() diff --git a/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py b/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py index e3d3c8461c..a88ba972ac 100644 --- a/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py +++ b/physicsnemo/nn/functional/neighbors/radius_search/_warp_impl.py @@ -323,10 +323,15 @@ def radius_search_impl( # Deterministic output path: always use batched 2D kernel launch # --------------------------------------------------------------- - # Build warp array of grid IDs - grid_ids_tensor = torch.tensor( - [g.id for g in grids], dtype=torch.int64, device=points.device + # Build warp array of grid IDs. + # Construct in pinned host memory first so the H2D copy is + # stream-ordered (non_blocking) rather than a blocking cudaMemcpy. + _grid_ids_cpu = torch.tensor( + [g.id for g in grids], + dtype=torch.int64, + pin_memory=torch.cuda.is_available(), ) + grid_ids_tensor = _grid_ids_cpu.to(points.device, non_blocking=True) wp_grid_ids = wp.from_torch( grid_ids_tensor, dtype=wp.uint64, return_ctype=True ) diff --git a/test/datapipes/core/test_dataset.py b/test/datapipes/core/test_dataset.py index c6f382725c..f46424559d 100644 --- a/test/datapipes/core/test_dataset.py +++ b/test/datapipes/core/test_dataset.py @@ -742,3 +742,60 @@ def test_gpu_iteration_with_transforms(self, numpy_data_dir): assert data["positions"].device.type == "cuda" torch.cuda.synchronize() + + +class TestDatasetSubsamplingReproducibility: + """Reader subsampling RNG is reproducible across the threaded path. + + Reader subsampling derives its generator per ``(base_seed, epoch, + index)``, so the multi-worker ``prefetch`` path matches synchronous + loading and two seeded runs agree -- even with ``num_workers > 1``. + """ + + def _make_dataset(self, data_dir, seed: int) -> dp.Dataset: + reader = dp.NumpyReader( + data_dir, + file_pattern="sample_*.npz", + fields=["positions", "features"], + coordinated_subsampling={ + "n_points": 50, + "target_keys": ["positions", "features"], + }, + ) + dataset = dp.Dataset(reader, num_workers=2) + dataset.set_generator(torch.Generator().manual_seed(seed)) + return dataset + + def test_prefetch_matches_synchronous(self, numpy_data_dir): + """Threaded prefetch (num_workers=2) matches synchronous loading.""" + indices = list(range(10)) + + sync_ds = self._make_dataset(numpy_data_dir, seed=2024) + expected = {i: sync_ds._load(i)[0]["positions"] for i in indices} + + async_ds = self._make_dataset(numpy_data_dir, seed=2024) + for i in indices: + async_ds.prefetch(i) + actual = {i: async_ds[i][0]["positions"] for i in indices} + async_ds.close() + + for i in indices: + assert torch.equal(actual[i], expected[i]) + + def test_two_seeded_runs_identical(self, numpy_data_dir): + """Two datasets seeded identically yield identical prefetch results.""" + indices = list(range(10)) + + ds_a = self._make_dataset(numpy_data_dir, seed=99) + ds_b = self._make_dataset(numpy_data_dir, seed=99) + for i in indices: + ds_a.prefetch(i) + ds_b.prefetch(i) + + for i in indices: + a = ds_a[i][0]["positions"] + b = ds_b[i][0]["positions"] + assert torch.equal(a, b) + + ds_a.close() + ds_b.close() diff --git a/test/datapipes/core/test_streaming.py b/test/datapipes/core/test_streaming.py new file mode 100644 index 0000000000..49c1b9426a --- /dev/null +++ b/test/datapipes/core/test_streaming.py @@ -0,0 +1,494 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the lazy preload path and the iterable (generator) dataset path. + +Stage 1 covers the lazy, FIFO-handle preload path (IOPump laziness, +``BATCH_BOUNDARY`` reassembly, opaque work items, the ``submit``/``consume`` +primitive). Stage 2 covers iterable datasets driven main-thread-only +(finite/capped-infinite generators, ``drop_last``, self-batching +pass-through, reproducibility, and no worker pool). CUDA-guarded tests +exercise stream-bound preprocessing. +""" + +from __future__ import annotations + +import threading + +import numpy as np +import pytest +import torch +from tensordict import TensorDict + +import physicsnemo.datapipes as dp +from physicsnemo.datapipes.io_pump import BATCH_BOUNDARY, IOPump +from physicsnemo.datapipes.protocols import DatasetBase, IterableDatasetBase + +# ============================================================================ +# Stage 1 -- IOPump (lazy, FIFO, batch boundaries) +# ============================================================================ + + +class TestIOPump: + """Tests for the lazy, self-driving prefetch pump.""" + + def test_lazy_bounded_pull_on_infinite_source(self): + """Pump pulls an unbounded source lazily, bounded by depth.""" + pulled: list[int] = [] + + def source(): + i = 0 + while True: + pulled.append(i) + yield i + i += 1 + + depth = 3 + pump = IOPump(source(), lambda x: x, depth=depth) + out = [] + for item in pump: + out.append(item) + if len(out) == 5: + break + pump.stop() + + assert out == [0, 1, 2, 3, 4] + # The dispatcher must not have run far ahead of what was consumed: + # at most consumed + depth + a small slack for the in-flight pull. + assert len(pulled) <= 5 + depth + 2 + + def test_batch_boundary_reassembly_irregular(self): + """Boundaries delimit dynamically-sized batches without slot use.""" + source = [0, 1, BATCH_BOUNDARY, 2, BATCH_BOUNDARY, 3, 4, 5, BATCH_BOUNDARY] + pump = IOPump(iter(source), lambda x: x, depth=2) + + batches: list[list[int]] = [] + current: list[int] = [] + for item in pump: + if item is BATCH_BOUNDARY: + batches.append(current) + current = [] + else: + current.append(item) + pump.stop() + + assert batches == [[0, 1], [2], [3, 4, 5]] + + def test_fifo_order_preserved(self): + """Handles are yielded in the order work items were pulled.""" + pump = IOPump(iter(range(20)), lambda x: x * 10, depth=4) + out = list(pump) + pump.stop() + assert out == [x * 10 for x in range(20)] + + def test_dispatch_error_surfaces_not_hangs(self): + """A dispatch exception is raised on the consumer, never a hang.""" + + def boom(x): + raise RuntimeError("dispatch failed") + + pump = IOPump(iter(range(5)), boom, depth=2) + with pytest.raises(RuntimeError, match="dispatch failed"): + list(pump) + pump.stop() + + def test_source_error_surfaces_not_hangs(self): + """A failing source is raised on the consumer, never a hang.""" + + def source(): + yield 0 + raise ValueError("source failed") + + pump = IOPump(source(), lambda x: x, depth=2) + with pytest.raises(ValueError, match="source failed"): + list(pump) + pump.stop() + + +# ============================================================================ +# Stage 1 -- submit / consume FIFO primitive with opaque work items +# ============================================================================ + + +class _DescriptorDataset(DatasetBase): + """Map-style dataset keyed by an opaque (non-int) descriptor.""" + + def __init__(self): + super().__init__(num_workers=2) + self._store = {"alpha": 1.0, "beta": 2.0, "gamma": 3.0} + + def _load(self, key): + if key == "explode": + raise KeyError("no such key") + return TensorDict({"x": torch.tensor([self._store[key]])}), {"key": key} + + def __len__(self): + return len(self._store) + + +class TestSubmitConsume: + """Tests for the FIFO submit/consume primitive.""" + + def test_opaque_descriptor_roundtrip(self): + """submit/consume works with non-int, string work items.""" + ds = _DescriptorDataset() + try: + handle = ds.submit("beta") + data, metadata = ds.consume(handle) + assert metadata["key"] == "beta" + assert data["x"].item() == 2.0 + finally: + ds.close() + + def test_submit_consume_fifo_independent_of_value(self): + """Multiple in-flight handles consume to their own results.""" + ds = _DescriptorDataset() + try: + handles = [ds.submit(k) for k in ("alpha", "beta", "gamma")] + keys = [ds.consume(h)[1]["key"] for h in handles] + assert keys == ["alpha", "beta", "gamma"] + finally: + ds.close() + + def test_producer_error_reraised_on_consume(self): + """An error raised in the producer surfaces on consume.""" + ds = _DescriptorDataset() + try: + handle = ds.submit("explode") + with pytest.raises(KeyError): + ds.consume(handle) + finally: + ds.close() + + +# ============================================================================ +# Stage 1 -- DataLoader laziness over the sampler +# ============================================================================ + + +class _CountingSampler: + """Sequential sampler that records how many indices it has yielded.""" + + def __init__(self, n): + self.n = n + self.consumed = 0 + + def __iter__(self): + self.consumed = 0 + for i in range(self.n): + self.consumed += 1 + yield i + + def __len__(self): + return self.n + + +class TestDataLoaderLazyPreload: + """The preload path must not materialize the whole epoch up front.""" + + def test_sampler_not_fully_drained_on_early_break(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + sampler = _CountingSampler(10) + # Keep the prefetch window small (single stream, one batch ahead) so + # the bound is well under the epoch; with the default num_streams the + # in-flight depth would cover this tiny 10-sample dataset entirely. + loader = dp.DataLoader( + dataset, + batch_size=2, + sampler=sampler, + prefetch_factor=1, + num_streams=1, + ) + + first = next(iter(loader)) + assert first["positions"].shape[0] == 2 + # Only a bounded prefix of the sampler should have been consumed, + # never the full epoch, after pulling a single batch. + assert sampler.consumed < 10 + + def test_preload_matches_sequential_order(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader( + dataset, batch_size=3, shuffle=False, collate_metadata=True + ) + indices = [] + for _batch, metadata_list in loader: + indices.extend(m["index"] for m in metadata_list) + assert indices == list(range(10)) + + +# ============================================================================ +# Stage 2 -- iterable datasets +# ============================================================================ + + +class _RangeIterable(IterableDatasetBase): + """Finite per-sample generator yielding (TensorDict, metadata).""" + + def __init__(self, n, dim=4): + self.n = n + self.dim = dim + + def __iter__(self): + for i in range(self.n): + data = TensorDict({"x": torch.full((self.dim,), float(i))}) + yield data, {"index": i} + + +class _BatchIterable(IterableDatasetBase): + """Self-batching generator yielding ready-made batches.""" + + yields_batches = True + + def __init__(self, n_batches, batch=4, dim=3): + self.n_batches = n_batches + self.batch = batch + self.dim = dim + + def __iter__(self): + for b in range(self.n_batches): + yield TensorDict( + {"x": torch.full((self.batch, self.dim), float(b))}, + batch_size=[self.batch], + ) + + +class _SeededIterable(IterableDatasetBase): + """Per-(epoch, position) seeded generator for reproducibility tests.""" + + def __init__(self, n, base_seed=0): + self.n = n + self.base_seed = base_seed + self.epoch = 0 + + def set_epoch(self, epoch): + self.epoch = epoch + + def __iter__(self): + for position in range(self.n): + seed = int( + np.random.SeedSequence( + [self.base_seed, self.epoch, position] + ).generate_state(1)[0] + ) + g = torch.Generator().manual_seed(seed) + yield TensorDict({"x": torch.rand(3, generator=g)}), {"position": position} + + +class _ThreadRecordingIterable(IterableDatasetBase): + """Records which thread the generator runs on.""" + + def __init__(self, n): + self.n = n + self.threads = [] + + def __iter__(self): + for i in range(self.n): + self.threads.append(threading.current_thread()) + yield TensorDict({"x": torch.zeros(2)}), {"index": i} + + +class TestIterableDataLoader: + """Tests for the main-thread-only iterable (generator) path.""" + + def test_per_sample_batching(self): + loader = dp.DataLoader(_RangeIterable(10), batch_size=4) + batches = list(loader) + # 10 samples / 4 -> [4, 4, 2] + assert [b["x"].shape[0] for b in batches] == [4, 4, 2] + + def test_per_sample_drop_last(self): + loader = dp.DataLoader(_RangeIterable(10), batch_size=4, drop_last=True) + batches = list(loader) + # Trailing partial batch dropped -> [4, 4] + assert [b["x"].shape[0] for b in batches] == [4, 4] + + def test_self_batching_passthrough(self): + # The loader batch_size is intentionally different from the generator's + # to prove it is ignored for self-batching datasets. + loader = dp.DataLoader(_BatchIterable(3, batch=5), batch_size=2) + batches = list(loader) + assert len(batches) == 3 + assert all(b["x"].shape[0] == 5 for b in batches) + + def test_len_raises_for_iterable(self): + loader = dp.DataLoader(_RangeIterable(10), batch_size=2) + with pytest.raises(TypeError): + len(loader) + + def test_capped_infinite_consumes_without_length(self): + """A long generator is iterated batch-by-batch; len() is never used.""" + + class _BigIterable(IterableDatasetBase): + def __iter__(self): + i = 0 + while i < 10_000: + yield TensorDict({"x": torch.zeros(2)}), {"index": i} + i += 1 + + loader = dp.DataLoader(_BigIterable(), batch_size=4) + seen = 0 + for _batch in loader: + seen += 1 + if seen == 3: + break + assert seen == 3 + + def test_shuffle_warns_for_iterable(self): + with pytest.warns(UserWarning, match="ignored for iterable"): + dp.DataLoader(_RangeIterable(4), batch_size=2, shuffle=True) + + def test_reproducible_across_runs_distinct_across_epochs(self): + loader = dp.DataLoader(_SeededIterable(6), batch_size=3) + + loader.set_epoch(0) + run_a = torch.cat([b["x"].reshape(-1) for b in loader]) + loader.set_epoch(0) + run_b = torch.cat([b["x"].reshape(-1) for b in loader]) + loader.set_epoch(1) + run_c = torch.cat([b["x"].reshape(-1) for b in loader]) + + assert torch.equal(run_a, run_b) # same epoch -> identical + assert not torch.equal(run_a, run_c) # different epoch -> distinct + + def test_runs_on_main_thread_no_worker_pool(self): + dataset = _ThreadRecordingIterable(4) + + names_before = {t.name for t in threading.enumerate()} + loader = dp.DataLoader(dataset, batch_size=2) + _ = list(loader) + names_after = {t.name for t in threading.enumerate()} + + # Generation happened on the main thread only. + assert dataset.threads, "generator did not run" + assert all(t is threading.main_thread() for t in dataset.threads) + # No prefetch worker pool / pump thread was spawned for this path. + new_threads = names_after - names_before + assert not any( + n.startswith("datapipe_prefetch") or n == "datapipe_pump" + for n in new_threads + ) + + +# ============================================================================ +# CUDA-guarded -- stream-bound consume +# ============================================================================ + + +class TestStreamBoundConsume: + """Preprocessing on an assigned stream (the default-stream workaround + is gone, so transforms run on the side stream).""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_submit_consume_on_side_stream(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset( + reader, + device="cuda:0", + transforms=dp.SubsamplePoints( + input_keys=["positions", "features"], n_points=50 + ), + ) + try: + stream = torch.cuda.Stream() + handle = dataset.submit(0, stream=stream) + data, _metadata = dataset.consume(handle) + torch.cuda.synchronize() + assert data["positions"].device.type == "cuda" + assert data["positions"].shape[0] == 50 + finally: + dataset.close() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_dataloader_streams_match_synchronous(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + + ref = dp.Dataset(reader, device="cuda:0") + ref_loader = dp.DataLoader(ref, batch_size=2, shuffle=False, prefetch_factor=0) + expected = [b["positions"].sum().item() for b in ref_loader] + + reader2 = dp.NumpyReader(numpy_data_dir, pin_memory=True) + streamed = dp.Dataset(reader2, device="cuda:0") + loader = dp.DataLoader( + streamed, + batch_size=2, + shuffle=False, + prefetch_factor=2, + num_streams=4, + use_streams=True, + ) + got = [b["positions"].sum().item() for b in loader] + torch.cuda.synchronize() + assert got == pytest.approx(expected, rel=1e-5) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_prefetch_defers_wait_until_after_previous_batch( + self, numpy_data_dir, monkeypatch + ): + """The compute-stream wait for batch N+1 must be enqueued *after* + batch N is yielded, not during the lookahead consume of N+1. + + This is the overlap-correctness invariant: if the wait for the next + batch's preprocessing landed on the compute stream before the current + batch's model work, the model would block on the next batch's + preprocessing (no overlap). With the deferred wait, the compute-stream + order is ``..., model_{N-1}, wait(prep_N), model_N, ...`` so a batch's + preprocessing overlaps the previous batch's compute. + + We spy on ``Stream.wait_event`` and record the interleaving of waits + and yields: the wait for batch 0 must precede yielding batch 0, while + the wait for batch 1 must come *after* batch 0 is yielded. + """ + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset(reader, device="cuda:0") + loader = dp.DataLoader( + dataset, + batch_size=1, + shuffle=False, + prefetch_factor=2, + num_streams=4, + use_streams=True, + ) + + order: list[str] = [] + stream_cls = type(torch.cuda.current_stream()) + real_wait = stream_cls.wait_event + + def spy_wait(self, event): + order.append("wait") + return real_wait(self, event) + + monkeypatch.setattr(stream_cls, "wait_event", spy_wait) + + try: + for i, _batch in enumerate(loader): + order.append(f"yield{i}") + if i >= 2: + break + finally: + torch.cuda.synchronize() + + wait_indices = [k for k, ev in enumerate(order) if ev == "wait"] + assert len(wait_indices) >= 2, f"expected >=2 waits, got order={order}" + yield0_idx = order.index("yield0") + # Batch 0's preprocessing is gated before batch 0 is handed out. + assert wait_indices[0] < yield0_idx, order + # Batch 1's preprocessing wait is deferred until after batch 0 is + # yielded (the regression guard: the old code waited inline during the + # lookahead consume of batch 1, before batch 0 was ever yielded). + assert wait_indices[1] > yield0_idx, order diff --git a/test/datapipes/readers/test_numpy_consolidated.py b/test/datapipes/readers/test_numpy_consolidated.py index c564ef5534..884712f8ce 100644 --- a/test/datapipes/readers/test_numpy_consolidated.py +++ b/test/datapipes/readers/test_numpy_consolidated.py @@ -239,6 +239,110 @@ def test_supports_coordinated_subsampling(self): assert reader_with_config._coordinated_subsampling_config is not None +class TestNumpyReaderSubsamplingRNG: + """Order- and thread-independent reproducibility of subsampling RNG. + + Reader subsampling derives its generator from ``(base_seed, epoch, + index)``, so a given sample's draw is identical regardless of read + order or worker thread (the threaded producer path), and varies + deterministically per index and per epoch. + """ + + def setup_method(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.temp_path = Path(self.temp_dir) + + self.n_samples = 6 + self.n_points = 1000 + self.subsample_points = 100 + for i in range(self.n_samples): + coords = np.random.randn(self.n_points, 3).astype(np.float32) + features = np.random.randn(self.n_points, 4).astype(np.float32) + np.savez( + self.temp_path / f"sample_{i:03d}.npz", + coords=coords, + features=features, + ) + + def teardown_method(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def _make_reader(self) -> NumpyReader: + return NumpyReader( + self.temp_path, + file_pattern="sample_*.npz", + fields=["coords", "features"], + coordinated_subsampling={ + "n_points": self.subsample_points, + "target_keys": ["coords", "features"], + }, + ) + + def test_subsample_order_independent(self): + """Reading an index gives the same result regardless of read order.""" + reader = self._make_reader() + reader.set_generator(torch.Generator().manual_seed(123)) + + forward = {i: reader[i][0]["coords"] for i in range(self.n_samples)} + + reader.set_generator(torch.Generator().manual_seed(123)) + reverse = {i: reader[i][0]["coords"] for i in reversed(range(self.n_samples))} + + for i in range(self.n_samples): + assert torch.equal(forward[i], reverse[i]) + + def test_subsample_thread_independent(self): + """Concurrent reads match single-threaded reads for each index.""" + from concurrent.futures import ThreadPoolExecutor + + reader = self._make_reader() + reader.set_generator(torch.Generator().manual_seed(7)) + + serial = {i: reader[i][0]["coords"] for i in range(self.n_samples)} + + indices = list(range(self.n_samples)) * 3 + with ThreadPoolExecutor(max_workers=4) as pool: + results = list(pool.map(lambda i: (i, reader[i][0]["coords"]), indices)) + + for i, coords in results: + assert torch.equal(coords, serial[i]) + + def test_subsample_distinct_across_indices(self): + """Different indices yield different subsamples under one seed.""" + reader = self._make_reader() + reader.set_generator(torch.Generator().manual_seed(0)) + + a = reader[0][0]["coords"] + b = reader[1][0]["coords"] + assert not torch.equal(a, b) + + def test_subsample_epoch_changes_output(self): + """Epoch is folded into the seed, changing the draw deterministically.""" + reader = self._make_reader() + + reader.set_generator(torch.Generator().manual_seed(0)) + reader.set_epoch(0) + e0 = reader[0][0]["coords"] + + reader.set_generator(torch.Generator().manual_seed(0)) + reader.set_epoch(1) + e1 = reader[0][0]["coords"] + assert not torch.equal(e0, e1) + + # Re-deriving epoch 0 reproduces the original draw. + reader.set_generator(torch.Generator().manual_seed(0)) + reader.set_epoch(0) + assert torch.equal(reader[0][0]["coords"], e0) + + def test_unseeded_does_not_raise(self): + """Without a seed, subsampling falls back to the global RNG.""" + reader = self._make_reader() + data, _ = reader[0] + assert data["coords"].shape == (self.subsample_points, 3) + + class TestNumpyReaderMemoryManagement: """Test memory management and cleanup.""" diff --git a/test/datapipes/transforms/test_mesh_augmentations.py b/test/datapipes/transforms/test_mesh_augmentations.py index 90ddedf89a..b04b55fa3f 100644 --- a/test/datapipes/transforms/test_mesh_augmentations.py +++ b/test/datapipes/transforms/test_mesh_augmentations.py @@ -732,14 +732,15 @@ def test_mesh_dataset_set_generator_distributes(self, tmp_path): master = torch.Generator().manual_seed(42) ds.set_generator(master) - # Reader and both transforms should have received generators - assert reader._subsample_generator is not None + # Reader (base seed) and both transforms should have received seeds + assert reader._seed_base is not None assert transforms[0]._generator is not None assert transforms[1]._generator is not None - # Generators should have different seeds (independent forks) + # Seeds should differ (independent forks): the reader's base seed + # plus each transform's generator seed. seeds = { - reader._subsample_generator.initial_seed(), + reader._seed_base, transforms[0]._generator.initial_seed(), transforms[1]._generator.initial_seed(), }