diff --git a/docs/api/datapipes/physicsnemo.datapipes.rst b/docs/api/datapipes/physicsnemo.datapipes.rst index 82b438f25b..047360e65b 100644 --- a/docs/api/datapipes/physicsnemo.datapipes.rst +++ b/docs/api/datapipes/physicsnemo.datapipes.rst @@ -139,6 +139,29 @@ of ``pin_memory`` from the ``DataLoader`` class to the ``Reader`` classes. This is because of the much earlier GPU data transfer in the PhysicsNeMo datapipe compared to PyTorch. +The ``DataLoader`` drives one of two mutually-exclusive paths, selected by +dataset type: + +- **Map-style preload path** (:class:`~physicsnemo.datapipes.DatasetBase`, + e.g. ``Dataset``, ``MeshDataset``): a dedicated dispatcher thread keeps a + *bounded* number of samples in flight by pulling the index stream + **lazily** under backpressure and submitting host-only loads to a worker + pool. The main thread consumes the resulting samples in order + (host-to-device transfer plus transforms on a preprocessing stream) and + reassembles batches from boundary markers, so the full epoch is never + materialized up front and irregular batch sizes are supported. +- **Iterable generator path** (:class:`~physicsnemo.datapipes.IterableDatasetBase`): + a generator dataset driven entirely on the main thread (no sampler, no + worker pool); see `Iterable Datasets`_ below. + +In both paths, **all device-kernel launches happen on the single main +thread**. This is the real constraint for Warp-based transforms: Warp may +launch on any CUDA stream as long as the launch comes from the main thread +and Warp's current stream is bound to the torch stream in use. +Preprocessing therefore runs on a side (preprocessing) stream and is +ordered against the compute stream with a CUDA event, so it overlaps +training without ever blocking the host. + .. autoclass:: physicsnemo.datapipes.dataloader.DataLoader :members: :show-inheritance: @@ -179,6 +202,43 @@ consistent keys. Because the exact collation details differ by dataset, the :show-inheritance: +Iterable Datasets +^^^^^^^^^^^^^^^^^ + +Map-style datasets (``Dataset``, ``MeshDataset``) assume a fixed length and a +sampler that hands out indices. Some workloads have neither: an online +simulation, a procedural generator, or any source that produces samples on +the fly with no meaningful ``__len__``. For those, subclass +:class:`~physicsnemo.datapipes.IterableDatasetBase` and yield samples from +``__iter__``. The ``DataLoader`` detects an iterable dataset automatically +and switches to the main-thread-only generator path; ``shuffle`` and +``sampler`` are ignored (a warning is issued) and ``len(loader)`` raises, +since the length is unknown. + +An iterable dataset chooses one of two emission modes via the +``yields_batches`` attribute: + +- ``yields_batches = False`` (default): ``__iter__`` yields individual + samples and the ``DataLoader`` collates them into batches of + ``batch_size`` (honoring ``drop_last``). +- ``yields_batches = True``: ``__iter__`` yields fully-formed batches and the + ``DataLoader`` passes them through without further collation, which is the + natural fit for a generator that already produces a batch per step. + +Reproducibility follows a per-``(epoch, position)`` scheme rather than the +map-style per-``(epoch, index)`` scheme: implement ``set_epoch`` and/or +``set_generator`` to seed deterministically from the iteration position. +Because the generator runs on the main thread, Warp kernels inside it are +safe on any preprocessing stream the ``DataLoader`` binds. See the online +simulation tutorial in the +`examples directory `_ +for a runnable Warp ``Darcy2D`` generator wired through this path. + +.. autoclass:: physicsnemo.datapipes.IterableDatasetBase + :members: + :show-inheritance: + + Readers ^^^^^^^ diff --git a/examples/minimal/datapipes/README.md b/examples/minimal/datapipes/README.md index 6934467232..949d615e44 100644 --- a/examples/minimal/datapipes/README.md +++ b/examples/minimal/datapipes/README.md @@ -215,3 +215,41 @@ python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud python tutorial_04_hydra_config.py --config-name tutorial_04_pointcloud \ subsample.n_points=5000 ``` + +### Tutorial 5: Iterable Datasets for Online Simulation + +**File:** `tutorial_5_iterable_online_simulation.py` + +Not every dataset is map-style. When data is *generated* on the fly -- +an online physics simulation, a procedural sampler, an unbounded stream -- +there is no fixed length and no index to address. PhysicsNeMo models these +with `IterableDatasetBase`: + +- Iterable datasets only support iteration: no `__len__`, no `__getitem__`, + no sampler, and no prefetch worker pool. +- They run entirely on the **main thread**, so they may freely launch Warp + kernels and use CUDA streams. Warp's only requirement is a single + launching thread, which the main thread satisfies -- this is what makes + an online GPU simulation safe on this path. +- The `DataLoader` still drives generation on a preprocessing stream (when + `use_streams=True`) and hands each item to the compute stream via a CUDA + event, so generating the next batch can overlap training on the current + one. + +This tutorial wraps the built-in Warp `Darcy2D` flow generator (a +multigrid Jacobi solver) as an iterable dataset and iterates it through the +`DataLoader`. Because `Darcy2D` produces a full batch per step, the wrapper +is *self-batching* (`yields_batches = True`) and the loader passes each +batch through unchanged. + +**When to use the iterable path vs the map/descriptor path:** use the +map-style `Dataset` when you have a fixed corpus on disk addressable by +index (storage-backed, benefits from threaded prefetch). Use an +`IterableDatasetBase` when samples are produced by a generator/simulation, +the length is unbounded or unknown, or the producer itself must launch +device kernels. + +```bash +# Requires a CUDA device (the Darcy solver runs Warp kernels on the GPU) +python tutorial_5_iterable_online_simulation.py +``` diff --git a/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py b/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py new file mode 100644 index 0000000000..27dd567024 --- /dev/null +++ b/examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tutorial 5: Iterable datasets for online simulation. + +Most datasets are *map-style*: a fixed number of samples addressed by +index, read from storage. Some workloads instead *generate* data on the +fly -- an online physics simulation, a procedural sampler, a streaming +source with no fixed length. These are *iterable* datasets. + +PhysicsNeMo models this with :class:`IterableDatasetBase`. Unlike a +map-style dataset, an iterable dataset: + +- has no length and no indexing -- it only supports iteration; +- is driven entirely on the **main thread** (no worker pool), so it may + freely launch Warp kernels and use CUDA streams. This is exactly the + property that makes an online GPU simulation safe here: Warp's + constraint is a single launching thread, which the main thread + satisfies. + +This tutorial wraps the built-in Warp ``Darcy2D`` flow generator -- which +solves the 2D Darcy equation with a multigrid Jacobi solver and yields a +ready-made batch each step -- as an iterable dataset and drives it through +the PhysicsNeMo :class:`DataLoader`. + +Run with:: + + python tutorial_5_iterable_online_simulation.py + +Requires a CUDA device (the Darcy solver runs Warp kernels on the GPU). +""" + +from __future__ import annotations + +import time + +import numpy as np +import torch + +from physicsnemo.datapipes import DataLoader, IterableDatasetBase +from physicsnemo.datapipes.benchmarks.darcy import Darcy2D + + +class DarcyOnlineDataset(IterableDatasetBase): + """Online 2D Darcy-flow simulation as an iterable dataset. + + Wraps :class:`~physicsnemo.datapipes.benchmarks.darcy.Darcy2D`, whose + iterator runs the solver and yields a full ``{"permeability", "darcy"}`` + batch per step. The underlying generator is infinite, so this wrapper + caps it at ``num_batches`` per epoch to give the loader a finite stream. + + Because ``Darcy2D`` already produces a complete batch, this is a + *self-batching* dataset: we set :attr:`yields_batches` so the loader + passes each batch through unchanged instead of re-collating. + + Parameters + ---------- + num_batches : int + Number of batches to emit per epoch. + resolution : int, default=64 + Simulation grid resolution. + batch_size : int, default=8 + Number of simulations per batch. + device : str, default="cuda" + Device the Warp solver runs on. + base_seed : int, default=0 + Base seed for reproducible permeability sampling. + """ + + # Darcy2D emits a full batch per step; do not re-collate. + yields_batches = True + + def __init__( + self, + num_batches: int, + *, + resolution: int = 64, + batch_size: int = 8, + device: str = "cuda", + base_seed: int = 0, + ) -> None: + self._sim = Darcy2D( + resolution=resolution, + batch_size=batch_size, + device=device, + normaliser={"permeability": (1.25, 0.75), "darcy": (4.52e-2, 2.79e-2)}, + ) + self._num_batches = num_batches + self._base_seed = base_seed + self._epoch = 0 + + def set_epoch(self, epoch: int) -> None: + """Select the epoch so each epoch draws a distinct, reproducible stream.""" + self._epoch = epoch + + def __iter__(self): + # One solver iterator drives the simulation; we pull a bounded + # number of steps from the otherwise-infinite generator. + sim_iter = iter(self._sim) + for position in range(self._num_batches): + # Per-(epoch, position) seeding: the stream is reproducible + # across runs and distinct across epochs and positions. There is + # no stable sample index for a generator, so we key on the + # monotonic emission position instead. + seed = np.random.SeedSequence( + [self._base_seed, self._epoch, position] + ).generate_state(1)[0] + np.random.seed(int(seed)) + yield next(sim_iter) + + +def main() -> None: + if not torch.cuda.is_available(): + print("This tutorial requires a CUDA device (Warp Darcy solver). Skipping.") + return + + num_epochs = 5 + num_batches = 16 + dataset = DarcyOnlineDataset(num_batches=num_batches, resolution=64, batch_size=8) + + # use_streams=True runs each simulation step on a preprocessing stream + # and hands the result to the compute stream via a CUDA event, so + # generation of the next batch can overlap training on the current one. + loader = DataLoader(dataset, use_streams=True, seed=0) + + # Iterable datasets have no length: this will take the exception path.oOOh, + try: + len(loader) + except TypeError as exc: + print(f"len(loader) is undefined for iterable datasets: {exc}") + + for epoch in range(num_epochs): + loader.set_epoch(epoch) + print(f"\nEpoch {epoch}") + host_times = [] + cuda_events = [] + epoch_start = time.perf_counter() + prev_host = epoch_start + cuda_start = torch.cuda.Event(enable_timing=True) + cuda_start.record(torch.cuda.current_stream()) + for i, batch in enumerate(loader): + host_now = time.perf_counter() + permeability = batch["permeability"] + darcy = batch["darcy"] + cuda_end = torch.cuda.Event(enable_timing=True) + cuda_end.record(torch.cuda.current_stream()) + cuda_events.append((cuda_start, cuda_end)) + cuda_start = cuda_end + + host_times.append(host_now - prev_host) + prev_host = host_now + print( + f" batch {i}: permeability {tuple(permeability.shape)} " + f"on {permeability.device}, darcy {tuple(darcy.shape)}, " + f"host_dt={host_times[-1]:.4f}s" + ) + + torch.cuda.synchronize() + cuda_times_ms = [start.elapsed_time(end) for start, end in cuda_events] + epoch_wall = time.perf_counter() - epoch_start + mean_host = sum(host_times) / len(host_times) + mean_cuda = sum(cuda_times_ms) / len(cuda_times_ms) + print( + f" epoch summary: batches={len(host_times)}, wall={epoch_wall:.3f}s, " + f"host_mean={mean_host:.4f}s, cuda_mean={mean_cuda:.2f}ms, " + f"cuda_min={min(cuda_times_ms):.2f}ms, cuda_max={max(cuda_times_ms):.2f}ms" + ) + + # Train as usual; the batches are ordinary device tensors. + + +if __name__ == "__main__": + main() diff --git a/physicsnemo/core/function_spec.py b/physicsnemo/core/function_spec.py index 945daa3887..d076406a4a 100644 --- a/physicsnemo/core/function_spec.py +++ b/physicsnemo/core/function_spec.py @@ -29,6 +29,44 @@ from physicsnemo.core.version_check import check_version_spec +# Cache of Warp stream wrappers keyed by the underlying CUDA stream handle. +# +# ``warp.stream_from_torch`` wraps a torch-owned (external) CUDA stream, and the +# resulting ``warp.Stream`` unregisters that handle from Warp on ``__del__``. +# Creating a fresh wrapper on every launch therefore churns register/unregister +# on a shared stream; unregistering while another wrapper -- or an in-flight +# kernel -- still uses the stream corrupts it (illegal memory access). Keeping +# one long-lived wrapper per handle registers each stream exactly once. +_WARP_STREAM_CACHE: Dict[int, Any] = {} + + +def warp_stream_from_torch(torch_stream: "torch.cuda.Stream") -> Any: + """Return a cached Warp stream wrapping *torch_stream*. + + Wrapping a torch stream registers it with Warp; the wrapper unregisters it + on garbage collection. Caching one wrapper per CUDA stream handle keeps the + registration stable for the process lifetime, which is required when the + same torch stream is bound by nested Warp scopes (e.g. an outer + preprocessing scope and an inner functional launch). + + Parameters + ---------- + torch_stream : torch.cuda.Stream + Torch CUDA stream to wrap. + + Returns + ------- + warp.Stream + Cached Warp stream sharing ``torch_stream``'s underlying CUDA handle. + """ + wp = importlib.import_module("warp") + handle = torch_stream.cuda_stream + cached = _WARP_STREAM_CACHE.get(handle) + if cached is None: + cached = wp.stream_from_torch(torch_stream) + _WARP_STREAM_CACHE[handle] = cached + return cached + @dataclass(frozen=True) class Implementation: @@ -687,11 +725,13 @@ def warp_launch_context(tensor: torch.Tensor): Warp device and stream. """ try: - wp = importlib.import_module("warp") + importlib.import_module("warp") except ImportError as exc: raise ImportError("warp is not available") from exc if tensor.device.type == "cuda": - stream = wp.stream_from_torch(torch.cuda.current_stream(tensor.device)) + # Reuse a cached wrapper so binding the current torch stream does + # not churn Warp's stream registration (see warp_stream_from_torch). + stream = warp_stream_from_torch(torch.cuda.current_stream(tensor.device)) device = None else: stream = None diff --git a/physicsnemo/datapipes/RNG.md b/physicsnemo/datapipes/RNG.md index f390d544de..da23c2dbea 100644 --- a/physicsnemo/datapipes/RNG.md +++ b/physicsnemo/datapipes/RNG.md @@ -32,6 +32,21 @@ master seed using `fork_generator(parent, n)`. Each child is seeded with and stable across runs. Children are created on the **same device** as the parent. +For RNG that must be reproducible regardless of *execution order* (e.g. +reader subsampling, which runs on a pool of worker threads), `_rng.py` +also provides coordinate-based seeding: + +- **`derive_seed(base_seed, *coords)`** — mixes a base seed with integer + coordinates (typically `epoch` and sample `index`) into a single + well-mixed 64-bit seed via `numpy.random.SeedSequence`. The result + depends only on the inputs, not on call order or thread. +- **`spawn_generator(base_seed, *coords, device=...)`** — returns a fresh + `torch.Generator` seeded with `derive_seed(base_seed, *coords)`. + +Because each call returns an independent generator seeded purely from its +coordinates, draws are reproducible irrespective of order and safe to +compute concurrently from multiple threads (no shared mutable state). + ### DataLoader When `seed` is set the DataLoader: @@ -70,9 +85,12 @@ sub-dataset. ### Epoch reseeding `DataLoader.set_epoch(epoch)` propagates to the sampler and dataset. -Each component with a generator reseeds it with +The sampler and stochastic transforms reseed their generators with `initial_seed() + epoch`, producing a different but deterministic -random sequence every epoch. +random sequence every epoch. Readers instead store the epoch and fold +it into each sample's derived seed (see [Readers](#readers)), so their +per-sample RNG also varies deterministically per epoch without relying +on a shared, sequentially-drawn generator. ## Generator tree @@ -152,18 +170,63 @@ standalone use. ## Readers -The `Reader` base class defines no-op `set_generator` / `set_epoch`. -Readers that use randomness override them: - -| Reader | Randomness | Generator support | +Reader subsampling runs on the dataset's worker-thread pool (the threaded +`prefetch` producer path; `Dataset` defaults to `num_workers=2`), so the +*order* in which samples are drawn is non-deterministic. A single shared, +sequentially-drawn generator would therefore not be reproducible with +`num_workers > 1`. To avoid this, readers derive RNG **per +`(base_seed, epoch, index)`** instead of from one shared stream: + +- **`set_generator(g)`** stores `g.initial_seed()` as the reader's base + seed (it does *not* keep the generator itself). +- **`set_epoch(e)`** stores the epoch. +- Each `reader[index]` then calls + `spawn_generator(base_seed, epoch, index)` to obtain a fresh generator + for that sample's draws (the `Reader` base class exposes this as + `_index_generator(index)`). + +The draw for a given sample depends only on `(base_seed, epoch, index)`, +so it is **identical regardless of read order or worker thread** — +reproducible for any `num_workers` — while still differing across indices +and across epochs. When no seed has been set, the per-sample generator is +`None` and draws fall back to the global default RNG. + +Transforms remain reproducible because they run on the main thread in +sampler order (via the consume stage), so their sequentially-drawn +generators are unaffected by the threaded producer. + +| Reader | Randomness | Per-`(seed, epoch, index)` RNG | |---|---|---| | `MeshReader` | `torch.randint` (contiguous block selection) | Yes | | `DomainMeshReader` | `torch.randint` | Yes | | `NumpyReader` | `torch.randint` (coordinated subsampling) | Yes | | `ZarrReader` | `torch.randint` | Yes | | `TensorStoreZarrReader` | `torch.randint` | Yes | -| `HDF5Reader` | None | No-op (inherited) | -| `VTKReader` | None | No-op (inherited) | +| `HDF5Reader` | None | n/a (inherited base) | +| `VTKReader` | None | n/a (inherited base) | + +## Iterable & descriptor paths: per-`(epoch, position)` seeding + +Map-style datasets have a stable sample `index`, so readers key their +per-sample RNG on `(base_seed, epoch, index)` (see [Readers](#readers)). +Generator-style (`IterableDatasetBase`) and future descriptor-keyed +sources have **no stable index**: samples are produced in sequence with no +addressable position in a corpus. They therefore key on the **monotonic +emission position** within the epoch instead: + +- **map-style:** `derive_seed(base_seed, epoch, index)` — reproducible for + any read order / `num_workers`, since the index is intrinsic to the + sample. +- **iterable / descriptor:** `derive_seed(base_seed, epoch, position)` — + where `position` is a 0-based counter of emissions in the current epoch. + Reproducible across runs and distinct across epochs and positions. + +Both schemes use the same `derive_seed`/`spawn_generator` primitives; only +the coordinate that stands in for "which sample" differs. The iterable +path runs entirely on the main thread in emission order, so the position +counter is unambiguous (there is no worker-thread reordering to defend +against). See `tutorial_5_iterable_online_simulation.py` for a worked +example seeding an online Darcy-flow simulation per `(epoch, position)`. ## Current limitations diff --git a/physicsnemo/datapipes/__init__.py b/physicsnemo/datapipes/__init__.py index d3902c5459..50e4e2aa4a 100644 --- a/physicsnemo/datapipes/__init__.py +++ b/physicsnemo/datapipes/__init__.py @@ -42,7 +42,7 @@ from physicsnemo.datapipes.dataset import Dataset from physicsnemo.datapipes.mesh_dataset import MeshDataset from physicsnemo.datapipes.multi_dataset import MultiDataset -from physicsnemo.datapipes.protocols import DatasetBase +from physicsnemo.datapipes.protocols import DatasetBase, IterableDatasetBase from physicsnemo.datapipes.readers import ( DomainMeshReader, HDF5Reader, @@ -105,6 +105,7 @@ # "TensorDict", # Re-export from tensordict "DatasetBase", + "IterableDatasetBase", "Dataset", "MeshDataset", "DataLoader", diff --git a/physicsnemo/datapipes/_rng.py b/physicsnemo/datapipes/_rng.py index d76a374df0..5e20c2715a 100644 --- a/physicsnemo/datapipes/_rng.py +++ b/physicsnemo/datapipes/_rng.py @@ -23,9 +23,68 @@ from __future__ import annotations +import numpy as np import torch +def derive_seed(base_seed: int, *coords: int) -> int: + """Deterministically mix a base seed with integer coordinates. + + Combines ``base_seed`` with arbitrary integer ``coords`` (typically + ``epoch`` and sample ``index``) into a single well-mixed 64-bit seed + using :class:`numpy.random.SeedSequence`. The result depends only on + the inputs, not on call order or thread, so per-sample RNG derived + from it is reproducible and safe to compute concurrently. + + Parameters + ---------- + base_seed : int + Base seed (e.g. a generator's ``initial_seed()``). + *coords : int + Additional non-negative integer coordinates to fold in, such as + ``(epoch, index)``. + + Returns + ------- + int + A deterministic 64-bit seed. + """ + seq = np.random.SeedSequence([int(base_seed), *(int(c) for c in coords)]) + return int(seq.generate_state(1, dtype=np.uint64)[0]) + + +def spawn_generator( + base_seed: int, + *coords: int, + device: torch.device | str = "cpu", +) -> torch.Generator: + """Create a fresh :class:`torch.Generator` seeded from mixed coordinates. + + Returns an independent generator whose seed is + :func:`derive_seed(base_seed, *coords) `. Because each + call returns a new generator seeded purely from its inputs, draws are + reproducible regardless of execution order and can be made + concurrently from multiple threads without sharing mutable state. + + Parameters + ---------- + base_seed : int + Base seed (e.g. a generator's ``initial_seed()``). + *coords : int + Additional integer coordinates to fold in, such as ``(epoch, index)``. + device : torch.device or str, default="cpu" + Device the generator is created on. + + Returns + ------- + torch.Generator + A new generator seeded deterministically from the inputs. + """ + generator = torch.Generator(device=device) + generator.manual_seed(derive_seed(base_seed, *coords)) + return generator + + def fork_generator( parent: torch.Generator, n: int, diff --git a/physicsnemo/datapipes/dataloader.py b/physicsnemo/datapipes/dataloader.py index 4e6b7bc61f..3c1a4ead39 100644 --- a/physicsnemo/datapipes/dataloader.py +++ b/physicsnemo/datapipes/dataloader.py @@ -25,6 +25,8 @@ from __future__ import annotations +import itertools +import warnings from typing import Any, Callable, Iterator, Optional, Sequence import torch @@ -33,7 +35,13 @@ from physicsnemo.datapipes._rng import fork_generator from physicsnemo.datapipes.collate import Collator, get_collator -from physicsnemo.datapipes.protocols import DatasetBase +from physicsnemo.datapipes.io_pump import BATCH_BOUNDARY, IOPump +from physicsnemo.datapipes.protocols import ( + DatasetBase, + IterableDatasetBase, + preprocessing_stream, + record_stream, +) from physicsnemo.datapipes.registry import register @@ -57,6 +65,38 @@ class DataLoader: - Compatible with PyTorch samplers (DistributedSampler, etc.) - Familiar torch DataLoader interface + Two data paths + -------------- + The path is selected by dataset type: + + - **Map-style** (:class:`~physicsnemo.datapipes.protocols.DatasetBase`): + a dispatcher thread (:class:`~physicsnemo.datapipes.io_pump.IOPump`) + lazily submits sample loads to a worker pool and forwards batch + boundaries, while the main thread consumes handles (host-to-device + transfer + transforms on a preprocessing stream). + - **Iterable** (:class:`~physicsnemo.datapipes.protocols.IterableDatasetBase`): + a generator dataset driven main-thread-only (no sampler, no pump, no + worker pool). ``len()`` is undefined and ``shuffle``/``sampler`` are + ignored; generation runs on a preprocessing stream with the same + event handoff so it overlaps training. + + Concurrency model + ----------------- + A dedicated dispatcher thread keeps the I/O pipeline primed by + submitting sample loads ahead of consumption, bounded by + ``prefetch_factor`` batches worth of in-flight samples. The main + thread is the sole consumer: it performs all host-to-device transfers + and GPU transforms (including Warp kernels) on the prefetch streams. + Warp's invariant is the single launching thread, not a single stream, + so transforms run on the assigned preprocessing stream and overlap the + compute stream. + + For the pipeline to stay primed, the main thread must not block: keep + reader output (optionally) pinned (so host-to-device copies are asynchronous) and + avoid host readbacks (``.item()``, ``wp.synchronize()``), data- + dependent shapes, and GIL-bound pure-Python transforms on the launch + path. + Examples -------- >>> from physicsnemo.datapipes import DataLoader, Dataset, HDF5Reader, Normalize @@ -80,7 +120,7 @@ class DataLoader: def __init__( self, - dataset: DatasetBase, + dataset: DatasetBase | IterableDatasetBase, *, batch_size: int = 1, shuffle: bool = False, @@ -104,9 +144,10 @@ def __init__( Parameters ---------- - dataset : DatasetBase - Dataset to load from. Any subclass of :class:`DatasetBase` - (e.g. :class:`Dataset`, :class:`MeshDataset`). + dataset : DatasetBase or IterableDatasetBase + Dataset to load from. A map-style :class:`DatasetBase` + (e.g. :class:`Dataset`, :class:`MeshDataset`) or an + :class:`IterableDatasetBase` generator dataset. batch_size : int, default=1 Number of samples per batch. shuffle : bool, default=False @@ -151,6 +192,9 @@ def __init__( self.num_streams = num_streams self.use_streams = use_streams and torch.cuda.is_available() self._seed = seed + # Iterable (generator) datasets are driven main-thread-only: no + # sampler, no worker-pool prefetch (see _iter_iterable). + self._iterable = isinstance(dataset, IterableDatasetBase) # Build master generator and fork for sampler + dataset sampler_generator: torch.Generator | None = None @@ -163,14 +207,18 @@ def __init__( if hasattr(dataset, "set_generator"): dataset.set_generator(forks[1]) - # Handle sampler - if sampler is not None: + # Handle sampler. Iterable datasets have no indices, so they carry + # no sampler and ignore shuffle. + if self._iterable: + if sampler is not None or shuffle: + warnings.warn( + "shuffle/sampler are ignored for iterable datasets; " + "the generator controls sample order.", + stacklevel=2, + ) + self.sampler = None + elif sampler is not None: self.sampler = sampler - # For DistributedSampler, propagate seed if available - if seed is not None and hasattr(sampler, "seed"): - # DistributedSampler exposes seed as a constructor arg - # but it's read-only; users should pass seed at construction. - pass elif shuffle: self.sampler = RandomSampler(dataset, generator=sampler_generator) else: @@ -179,7 +227,8 @@ def __init__( # Handle collation self.collate_fn = get_collator(collate_fn, collate_metadata=collate_metadata) - # Create CUDA streams for prefetching + # Create CUDA streams: prefetch uses several round-robin streams; the + # iterable path uses the first as its preprocessing stream. self._streams: list[torch.cuda.Stream] = [] if self.use_streams: for _ in range(num_streams): @@ -193,7 +242,15 @@ def __len__(self) -> int: ------- int Number of batches in the dataloader. + + Raises + ------ + TypeError + If the dataset is iterable (generator-style), which has no + defined length. """ + if self._iterable: + raise TypeError("len() is undefined for an iterable (generator) dataset") n_samples = ( len(self.sampler) if hasattr(self.sampler, "__len__") else len(self.dataset) ) @@ -226,8 +283,15 @@ def __iter__( """ Iterate over batches. - Uses stream-based prefetching when enabled to overlap IO, - GPU transfers, and computation. + Uses the self-priming :class:`IOPump` to overlap host-side I/O + (on the dataset's worker threads) with main-thread consumption + whenever ``prefetch_factor > 0``. This threaded producer path is + independent of CUDA streams: when streams are enabled (and CUDA + is available) each prefetched sample is also assigned a stream so + the host-to-device copy and GPU transforms overlap; otherwise the + same path runs with ``stream=None`` (still overlapping disk I/O + with the main thread). Set ``prefetch_factor=0`` for fully + synchronous iteration. Yields ------ @@ -236,7 +300,9 @@ def __iter__( or tuple of (batched TensorDict, list of metadata dicts) if collate_metadata=True. """ - if self.prefetch_factor > 0 and self.use_streams: + if self._iterable: + yield from self._iter_iterable() + elif self.prefetch_factor > 0: yield from self._iter_prefetch() else: yield from self._iter_simple() @@ -256,59 +322,156 @@ def _iter_simple( samples = [self.dataset[idx] for idx in batch_indices] yield self.collate_fn(samples) + def _work_stream(self) -> Iterator[Any]: + """Lazily yield sampler indices delimited by :data:`BATCH_BOUNDARY`. + + Buffers at most one batch of indices (never the whole epoch), so + arbitrarily long samplers stream without up-front materialization. + A boundary is emitted after each full batch; a trailing partial + batch is emitted only when ``drop_last`` is False. + + Yields + ------ + int or object + Sample indices, with :data:`BATCH_BOUNDARY` after each batch. + """ + batch: list[int] = [] + for index in self.sampler: + batch.append(index) + if len(batch) == self.batch_size: + yield from batch + yield BATCH_BOUNDARY + batch = [] + if batch and not self.drop_last: + yield from batch + yield BATCH_BOUNDARY + def _iter_prefetch( self, ) -> Iterator[TensorDict | tuple[TensorDict, list[dict[str, Any]]]]: """ - Iteration with stream-based prefetching. - - Strategy: - - 1. Prefetch `prefetch_factor` batches worth of samples - 2. As we yield batches, prefetch more to keep the pipeline full - 3. Each sample in a batch uses a different stream for overlap + Iteration driven by a self-priming prefetch pump. + + A dedicated dispatcher thread (the :class:`IOPump`) lazily pulls + the index stream and submits sample loads to the dataset's worker + pool, keeping a bounded number of samples in flight regardless of + the consumer's cadence. The main thread is a pure drain loop: it + pulls ready handles in order, runs the per-sample consume step + (host-to-device transfer plus GPU transforms, including Warp, on + the assigned stream), and reassembles batches from the boundary + markers the pump forwards. + + Stream assignment is optional and decoupled from the threaded + producer: when CUDA streams are enabled a stream is round-robined + per sample (so preprocessing overlaps the previous batch's compute); + otherwise dispatch passes ``stream=None`` and the path still + overlaps host-side I/O with main-thread consumption. + + Because dispatch lives off the main thread, the pipeline stays + primed even while the main thread is blocked launching kernels or + running the model. All device-kernel launches happen here, on the + single main thread. Yields ------ TensorDict or tuple[TensorDict, list[dict[str, Any]]] Collated batch. """ - # Collect all batches upfront for prefetch planning - all_batches = list(self._generate_batches()) - if not all_batches: - return - - num_prefetch_batches = min(self.prefetch_factor, len(all_batches)) - stream_idx = 0 - - # Start initial prefetch - prefetched_up_to = 0 - for batch_idx in range(num_prefetch_batches): - for sample_idx in all_batches[batch_idx]: - stream = self._streams[stream_idx % self.num_streams] - self.dataset.prefetch(sample_idx, stream=stream) - stream_idx += 1 - prefetched_up_to = batch_idx + 1 - - # Yield batches and prefetch more - for batch_idx, batch_indices in enumerate(all_batches): - # Collect samples (uses prefetched if available) - samples = [self.dataset[idx] for idx in batch_indices] - batch = self.collate_fn(samples) - - # Prefetch next batch if available - next_prefetch_idx = prefetched_up_to - if next_prefetch_idx < len(all_batches): - for sample_idx in all_batches[next_prefetch_idx]: - stream = self._streams[stream_idx % self.num_streams] - self.dataset.prefetch(sample_idx, stream=stream) - stream_idx += 1 - prefetched_up_to += 1 - - yield batch + # Streams are an optional accelerator on top of the threaded + # producer; only assign them when actually available. + use_streams = self.use_streams and len(self._streams) > 0 + + # Round-robin a stream per sample at dispatch time (when enabled); + # submit returns a handle the consumer resolves in order. + stream_counter = itertools.count() + + def dispatch(index: int) -> Any: + stream = ( + self._streams[next(stream_counter) % self.num_streams] + if use_streams + else None + ) + return self.dataset.submit(index, stream=stream) + + # Depth = prefetch_factor batches worth of samples kept in flight + # (at least the stream count when streams drive the overlap). + depth = max(self.prefetch_factor * self.batch_size, 1) + if use_streams: + depth = max(depth, self.num_streams) + + pump = IOPump(self._work_stream(), dispatch, depth=depth) + try: + samples: list[Any] = [] + for item in pump: + if item is BATCH_BOUNDARY: + yield self.collate_fn(samples) + samples = [] + else: + samples.append(self.dataset.consume(item)) + finally: + # Stop the dispatcher (handles early break / exhaustion) and + # drop any prefetched-but-unconsumed handles. + pump.stop() + self.dataset.cancel_prefetch() + + def _iter_iterable( + self, + ) -> Iterator[Any]: + """ + Main-thread-only iteration for generator (iterable) datasets. + + There is no worker pool: the dataset's generator runs on the main + thread, so it may freely launch Warp kernels / use streams. Each + item is generated on a preprocessing stream (when streams are + enabled) and handed to the compute stream via a CUDA event, so + generation of the next item can overlap training on the current + one. A generator that forces a host readback simply serializes + itself. + + Two emission modes are supported (see :class:`IterableDatasetBase`): + per-sample items are collated into ``batch_size`` batches (with + ``drop_last`` trimming the trailing partial batch); a self-batching + generator (``yields_batches = True``) has each batch passed through + unchanged. - # Clean up any remaining prefetch state - self.dataset.cancel_prefetch() + Yields + ------ + Any + Collated batches, or generator-produced batches when the + dataset is self-batching. + """ + use_stream = self.use_streams and len(self._streams) > 0 + prep_stream = self._streams[0] if use_stream else None + compute_stream = torch.cuda.current_stream() if use_stream else None + self_batching = getattr(self.dataset, "yields_batches", False) + + iterator = iter(self.dataset) + samples: list[Any] = [] + while True: + # Generate the next item on the preprocessing stream, then order + # it before the compute stream without blocking the host. + with preprocessing_stream(prep_stream): + try: + item = next(iterator) + except StopIteration: + break + if use_stream: + record_stream(item, compute_stream) + event = torch.cuda.Event() + event.record(prep_stream) + compute_stream.wait_event(event) + + if self_batching: + yield item + continue + + samples.append(item) + if len(samples) == self.batch_size: + yield self.collate_fn(samples) + samples = [] + + if not self_batching and samples and not self.drop_last: + yield self.collate_fn(samples) def set_epoch(self, epoch: int) -> None: """ diff --git a/physicsnemo/datapipes/datapipes.md b/physicsnemo/datapipes/datapipes.md index b41ca1d846..0aa1dc9929 100644 --- a/physicsnemo/datapipes/datapipes.md +++ b/physicsnemo/datapipes/datapipes.md @@ -141,106 +141,144 @@ preprocessing. Threads are a natural fit: duplication overhead. - **I/O concurrency** -- the GIL is released during disk reads and CUDA kernel launches, so multiple threads usefully overlap I/O with GPU work. -- **Stream parallelism** -- each prefetched sample is assigned its own - CUDA stream, allowing host-to-device transfers and GPU transforms to - run concurrently with the main training computation. +- **Stream parallelism** -- when enabled, each prefetched sample is + assigned a CUDA stream so its host-to-device transfer can overlap with + the main training computation. -### Thread-pool prefetch +### Producer / consumer split -`DatasetBase` owns a `ThreadPoolExecutor` (configurable via -`num_workers`, default 2). Calling `prefetch(index)` submits the -load-and-transform pipeline to the pool and stashes the `Future`: +Prefetching is split into two stages so that **no device kernels are +launched off the main thread** -- a hard requirement for Warp-based +transforms, which must share the model's single launching thread: -```python -def prefetch(self, index, stream=None): - if index in self._prefetch_futures: - return - executor = self._ensure_executor() - self._prefetch_futures[index] = executor.submit(self._load, index) -``` +- `_load_host` is the **producer**. It runs on a worker thread and does + only thread-safe work: reading, decoding, and staging into pinned host + memory. It returns a `HostPayload`. +- `_consume` is the **consumer**. It runs on whatever thread calls + `__getitem__` (the main thread, in practice) and performs the + host-to-device transfer and device transforms (including Warp kernels). -`__getitem__` pops the `Future` if one exists, otherwise loads -synchronously: +`DatasetBase` owns a `ThreadPoolExecutor` (configurable via +`num_workers`) and exposes a FIFO prefetch primitive. `submit(work_item, +stream=...)` runs only the producer on the pool and returns a +`PrefetchHandle` bundling the future with the stream the consumer should +use; `consume(handle)` resolves it on the calling thread: ```python -def __getitem__(self, index): - future = self._prefetch_futures.pop(index, None) - if future is not None: - return future.result() - return self._load(index) -``` - -This means the DataLoader can keep the next batch loading in background -threads while the current batch is being consumed by the model. +def submit(self, work_item, stream=None): + future = self._executor.submit(self._load_host, work_item) + return PrefetchHandle(future=future, stream=stream) -### CUDA stream overlap - -When GPU execution is available, `Dataset` (and `MeshDataset`) override -`prefetch` to run device transfer and transforms on a caller-supplied -CUDA stream, then record an event for later synchronization: - -```python -def _load_and_transform(self, index, stream=None): - result = _PrefetchResult(index=index) - data, metadata = self.reader[index] # CPU I/O in worker thread - - if stream is not None: - with torch.cuda.stream(stream): - data = data.to(device, non_blocking=True) # H2D on stream - data = self.transforms(data) # GPU transforms on stream - result.event = torch.cuda.Event() - result.event.record(stream) # mark completion - - result.data, result.metadata = data, metadata - return result +def consume(self, handle): + payload = handle.future.result() # re-raises producer errors + return self._consume(payload, handle.stream) # H2D + transforms here ``` -On retrieval, `__getitem__` synchronizes the event before returning: +Correlation is purely by handle identity (FIFO), so work items need not +be hashable, unique, or even integers -- an `int` index is just the +common case. The index-keyed `prefetch(index)` / `__getitem__(index)` +convenience API is a thin layer over `submit`/`consume` for random +access, and is what map-style tests and `MultiDataset` use. + +### Self-priming dispatch (IOPump) + +The threaded producer is driven by `IOPump`, a dedicated dispatcher +thread that keeps a *bounded* number of samples in flight regardless of +the consumer's cadence. It pulls a work-item stream **lazily** (one item +per free backpressure slot, so an arbitrarily long or unbounded source +never materializes up front), calls `submit` for each, and hands the +returned handles back to the main thread in FIFO order. The source +interleaves `BATCH_BOUNDARY` markers between work items; the pump forwards +them in place without consuming a slot, so the consumer reassembles +dynamically-sized batches from the boundaries -- the DataLoader never +builds the epoch's batch list in advance. Because dispatch lives off the +main thread, the pipeline stays primed even while the main thread is busy +launching kernels or running the model. This path is active whenever +`prefetch_factor > 0`; set `prefetch_factor=0` for fully synchronous +iteration. + +### CUDA stream handoff + +CUDA streams are an *optional* accelerator layered on top of the threaded +producer. When `use_streams=True` (and CUDA is available), each sample is +round-robined a **preprocessing stream**. The consumer runs *both* the +host-to-device copy and the transforms on that stream, then hands the +result to the compute stream via a CUDA **event** (never a host +`synchronize()`): ```python -if result.event is not None: - result.event.synchronize() -return result.data, result.metadata +def _consume(self, payload, stream=None): + data = payload.data + if device is not None and stream is not None: + compute_stream = torch.cuda.current_stream() + # Bind torch AND Warp to the preprocessing stream. + with preprocessing_stream(stream): # torch + wp.ScopedStream + data = data.to(device, non_blocking=True) # H2D on prep stream + data = self.transforms(data) # transforms on SAME stream + data.record_stream(compute_stream) # keep memory alive + event = torch.cuda.Event() + event.record(stream) + compute_stream.wait_event(event) # order, no host block + else: + data = self.transforms(data) + return data, payload.metadata ``` -The `DataLoader` owns a pool of `num_streams` CUDA streams (default 4) -and round-robins them across samples. It also maintains a sliding -prefetch window of `prefetch_factor` batches (default 2) ahead of the -current yield position: - -```python -# Prefetch the next batch as we yield the current one -for sample_idx in all_batches[next_prefetch_idx]: - stream = self._streams[stream_idx % self.num_streams] - self.dataset.prefetch(sample_idx, stream=stream) - stream_idx += 1 -``` +**The single launching thread -- not a single stream -- is Warp's real +invariant.** Warp kernels may run on any CUDA stream provided they are +launched from the main thread *and* Warp's current stream matches torch's. +`preprocessing_stream` (in `protocols.py`) binds both via +`wp.ScopedStream(wp.stream_from_torch(stream))`, so transforms (including +Warp mesh-query / BVH kernels) run correctly on the side stream. A +previous `cudaErrorIllegalAddress` here was a torch/Warp stream +*divergence* (data on a side stream, the Warp kernel on Warp's own +stream), not a prohibition on non-default streams; binding both fixes it +and lets GPU preprocessing genuinely overlap training. `record_stream` +keeps the device tensors from being recycled while the compute stream +reads them; the pinned host source is held by the caching host allocator +until the copy completes. ### Concurrency timeline -The diagram below shows how threads and streams overlap for a two-sample -batch with `prefetch_factor=1`: +With everything launched from the main thread, the worker pool, the +preprocessing stream, and the compute stream form a triple buffer: ```text -Main thread Worker 1 Worker 2 Stream 1 Stream 2 - │ │ │ │ │ - ├─prefetch(0,S1)─►│ │ │ │ - ├─prefetch(1,S2)─────────────────────►│ │ │ - │ ├─ Read (I/O) │ │ │ - │ │ ├─ Read (I/O) │ │ - │ ├─ to(device) ─────────────────────────►│ │ - │ ├─ transforms ─────────────────────────►│ │ - │ ├─ event.record() ─────────────────────►│ │ - │ │ ├─ to(device) ─────────────────►│ - │ │ ├─ transforms ─────────────────►│ - │ │ ├─ event.record() ─────────────►│ - ├─ event.synchronize() ×2 │ │ │ - ├─ collate + yield batch │ │ │ - │ │ │ │ │ +Worker pool │ load N+2 ─ load N+1 ... (host I/O + thread-safe CPU work) +Preprocess stream │ H2D + Warp transforms for N+1 +Compute stream │ train N ``` -While the main thread consumes batch N, worker threads are already -loading batch N+1 on different streams. +GPU preprocessing of batch N+1 genuinely overlaps training of batch N on +a separate stream; the two are ordered by a CUDA event, never a host-side +`synchronize`. A transform (or generator) that forces a host readback +simply serializes itself -- a property of that code, not of the pipeline. + +### Two data paths: map/descriptor vs iterable + +The DataLoader selects one of two mutually-exclusive paths by dataset +type: + +- **Preload path (`DatasetBase`)** -- map-style and descriptor-keyed + datasets. Uses the worker pool + `IOPump` described above: workers do + thread-safe host I/O, the main thread consumes handles (H2D + transforms + on the preprocessing stream). This is the path for storage-backed data + addressable by index. +- **Generator path (`IterableDatasetBase`)** -- iterable datasets that + *produce* data (online simulation, procedural samplers, unbounded + streams). Driven **main-thread-only**: no sampler, no pump, no worker + pool. `__iter__` may freely launch Warp kernels and use CUDA streams + (the single-launching-thread invariant holds), and the loader still + drives generation on a preprocessing stream with the same event handoff, + so generation of batch N+1 overlaps training of batch N. + +An iterable dataset yields either per-sample `(data, metadata)` (the +loader collates `batch_size` of them, `drop_last` trims the tail) or, when +`yields_batches = True`, ready-made batches that the loader passes through +unchanged. Iterable datasets have no length: `len(loader)` raises +`TypeError`, and `shuffle`/`sampler` are ignored. See +`examples/minimal/datapipes/tutorial_5_iterable_online_simulation.py` for +a Warp `Darcy2D` online simulation wired through this path. ### Pinned memory @@ -258,8 +296,9 @@ loader.disable_prefetch() # synchronous, single-stream -- easy to debug loader.enable_prefetch() # re-enable after debugging ``` -Setting `use_streams=False` or `prefetch_factor=0` at construction time -also forces synchronous execution. +`use_streams=False` keeps the threaded producer but drops the CUDA +stream handoff (the consumer copies and transforms on the default +stream); `prefetch_factor=0` forces fully synchronous execution. ## RNG and reproducibility diff --git a/physicsnemo/datapipes/dataset.py b/physicsnemo/datapipes/dataset.py index 727503e800..7a65be8731 100644 --- a/physicsnemo/datapipes/dataset.py +++ b/physicsnemo/datapipes/dataset.py @@ -31,7 +31,12 @@ from tensordict import TensorDict from physicsnemo.datapipes._rng import fork_generator -from physicsnemo.datapipes.protocols import DatasetBase, _PrefetchResult +from physicsnemo.datapipes.protocols import ( + DatasetBase, + HostPayload, + preprocessing_stream, + record_stream, +) from physicsnemo.datapipes.readers.base import Reader from physicsnemo.datapipes.registry import register from physicsnemo.datapipes.transforms.base import Transform @@ -55,10 +60,20 @@ class Dataset(DatasetBase): Prefetching Model ----------------- - The dataset supports prefetching samples using a thread pool. - When a CUDA stream is provided, GPU operations (device transfer, - GPU transforms) happen on that stream, allowing overlap with - other computation. + The dataset supports prefetching samples using a thread pool. The + work is split into a thread-safe *producer* stage and a main-thread + *consumer* stage: + + - :meth:`_load_host` (producer, worker thread) reads the sample into + host memory (pinned when the reader is configured with + ``pin_memory=True``). It launches no device kernels. + - :meth:`_consume` (consumer, calling thread) performs the + host-to-device transfer and the GPU transforms on the assigned + CUDA stream. + + This keeps all device-kernel launches (notably Warp transforms) on + the consuming thread, which must be the same single thread the model + launches from. >>> # Start prefetching >>> dataset.prefetch(0, stream=stream0) # doctest: +SKIP @@ -242,94 +257,60 @@ def set_epoch(self, epoch: int) -> None: t.set_epoch(epoch) # ------------------------------------------------------------------ - # Stream-aware prefetch (overrides DatasetBase defaults) + # Producer / consumer split (overrides DatasetBase defaults) # ------------------------------------------------------------------ - def _load_and_transform( - self, - index: int, - stream: Optional[torch.cuda.Stream] = None, - ) -> _PrefetchResult: + def _load_host(self, work_item: int) -> HostPayload: """ - Load a sample and apply transforms with optional CUDA stream. + Producer stage: read a sample into host memory. + + Runs on a worker thread and launches no device kernels. Pinning + is the reader's responsibility: construct the reader with + ``pin_memory=True`` to stage tensors in pinned memory so the + consumer's host-to-device copy can be asynchronous. Parameters ---------- - index : int - Sample index. - stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. + work_item : int + Sample index to read from the reader. Returns ------- - _PrefetchResult - PrefetchResult with data, metadata, or error. + HostPayload + Payload carrying the host data and metadata, or a captured + error. """ - result = _PrefetchResult(index=index) - try: - data, metadata = self.reader[index] - - if self.target_device is not None: - if stream is not None: - with torch.cuda.stream(stream): - data = data.to(self.target_device, non_blocking=True) - else: - data = data.to(self.target_device, non_blocking=True) - - if self.transforms is not None: - if stream is not None: - with torch.cuda.stream(stream): - data = self.transforms(data) - result.event = torch.cuda.Event() - result.event.record(stream) - else: - data = self.transforms(data) - - result.data = data - result.metadata = metadata - - except Exception as e: - result.error = e + data, metadata = self.reader[work_item] + return HostPayload(work_item=work_item, data=data, metadata=metadata) + except Exception as e: # noqa: BLE001 + return HostPayload(work_item=work_item, error=e) - return result - - def prefetch( + def _consume( self, - index: int, + payload: HostPayload, stream: Optional[torch.cuda.Stream] = None, - ) -> None: + ) -> tuple[TensorDict, dict[str, Any]]: """ - Start prefetching a sample asynchronously. - - When a CUDA stream is provided, GPU operations (device transfer - and transforms) run on that stream for overlap with computation. + Consumer stage: device transfer + transforms on the calling thread. + + Runs on whatever thread calls this (the main thread, so any Warp + kernels in the transforms share the model's launching thread). When + a CUDA ``stream`` is assigned, the host-to-device copy *and* the + transforms run on that preprocessing stream -- Warp bound to it via + :func:`preprocessing_stream` -- so this sample's preprocessing + overlaps the previous batch's training on the compute stream. The + result is handed back to the compute stream with a CUDA event (not + a host-side synchronize), and tagged via ``record_stream`` so the + caching allocator does not recycle it while training reads it. Parameters ---------- - index : int - Sample index to prefetch. + payload : HostPayload + Producer payload from :meth:`_load_host`. stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. - """ - if index in self._prefetch_futures: - return - - executor = self._ensure_executor() - future = executor.submit(self._load_and_transform, index, stream) - self._prefetch_futures[index] = future - - def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: - """ - Get a transformed sample by index. - - If the index was prefetched, returns the prefetched result - (waiting for completion if necessary). Otherwise loads synchronously. - - Parameters - ---------- - index : int - Sample index. + Preprocessing stream for the host-to-device transfer and + transforms. ``None`` runs on the current stream. Returns ------- @@ -338,26 +319,38 @@ def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: Raises ------ - IndexError - If index is out of range. Exception - If prefetch failed, re-raises the error. + If the producer captured an error, re-raises it. """ - future = self._prefetch_futures.pop(index, None) + if payload.error is not None: + raise payload.error - if future is not None: - result = future.result() + data = payload.data + metadata = payload.metadata - if isinstance(result, _PrefetchResult): - if result.error is not None: - raise result.error - if result.event is not None: - torch.cuda.current_stream().wait_event(result.event) - return result.data, result.metadata + device_is_cuda = ( + self.target_device is not None + and torch.device(self.target_device).type == "cuda" + ) + use_stream = stream is not None and device_is_cuda + compute_stream = torch.cuda.current_stream() if use_stream else None - return result + with preprocessing_stream(stream if use_stream else None): + if self.target_device is not None: + data = data.to(self.target_device, non_blocking=True) + if self.transforms is not None: + data = self.transforms(data) + + if use_stream: + # Order the preprocessing stream's result before the compute + # stream consumes it, without blocking the host: tag the memory + # so the allocator keeps it alive, then gate on a CUDA event. + record_stream(data, compute_stream) + event = torch.cuda.Event() + event.record(stream) + compute_stream.wait_event(event) - return self._load(index) + return data, metadata @property def field_names(self) -> list[str]: diff --git a/physicsnemo/datapipes/io_pump.py b/physicsnemo/datapipes/io_pump.py new file mode 100644 index 0000000000..516ea3c8c8 --- /dev/null +++ b/physicsnemo/datapipes/io_pump.py @@ -0,0 +1,220 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +IOPump - A self-driving I/O producer that keeps the pipeline primed. + +The pump owns a dedicated dispatcher thread that pulls work items from a +source iterator *lazily* and submits each for background loading, keeping +a *bounded* number of samples in flight at all times. Pulling lazily means +the source may be arbitrarily large (or effectively unbounded): the +dispatcher only advances it when a backpressure slot is free, so memory +stays bounded regardless of source length. + +Dispatch is decoupled from the consumer's cadence: the dispatcher keeps +topping the pipeline off while the main thread is busy launching kernels or +running the model, so a ready sample is (almost) always waiting when the +consumer asks for the next one. + +The pump is agnostic to *how* a sample is produced and consumed: it drives +opaque work items through a user-provided ``dispatch_fn`` (which starts the +background load and returns a handle) and hands those handles back to the +consumer in the exact order they were pulled. Correlation is purely +positional (FIFO), so work items need not be hashable or unique. + +A source may interleave :data:`BATCH_BOUNDARY` markers between work items; +the pump forwards them to the consumer in order without consuming a +backpressure slot, letting the consumer reassemble dynamically-sized +batches without knowing the batch layout up front. +""" + +from __future__ import annotations + +import queue +import threading +from typing import Any, Callable, Iterable, Iterator + +# Public marker a source yields to delimit the end of one batch. A distinct +# sentinel object so it can never collide with a real work item. +BATCH_BOUNDARY = object() + +# Internal marker the dispatcher pushes once the source is exhausted (or the +# pump is stopped), telling the consumer to finish iterating. +_DONE = object() + + +class _PumpError: + """Wraps an exception raised on the dispatcher thread. + + Forwarded through the ready queue so the consumer re-raises it on the + main thread instead of blocking forever waiting for items that will + never arrive. + """ + + __slots__ = ("exc",) + + def __init__(self, exc: BaseException) -> None: + self.exc = exc + + +class IOPump: + """Bounded, self-driving prefetch dispatcher. + + A dedicated dispatcher thread pulls work items from ``source``, + acquires a backpressure slot, calls ``dispatch_fn(work_item)`` to start + the background load, and makes the returned handle available to the + consumer in FIFO order via iteration. Slots are released as the + consumer advances, keeping at most ``depth`` samples in flight. + + Parameters + ---------- + source : Iterable + Work items to load, optionally interleaved with + :data:`BATCH_BOUNDARY` markers. Consumed lazily, one item at a + time, only as backpressure slots free up. + dispatch_fn : Callable[[Any], Any] + Called on the dispatcher thread to start loading a work item (for + example ``dataset.submit(work_item, stream=...)``). It must be + non-blocking and thread-safe and must not launch device kernels; + it returns an opaque handle that the consumer later turns into a + sample. + depth : int + Maximum number of samples dispatched but not yet consumed. Acts as + both the backpressure valve and the jitter buffer that hides + consumer stalls. Clamped to at least 1. + + Notes + ----- + A pump instance is single-consumer. Iterate it with a single thread + (the main/launcher thread). Call :meth:`stop` (or use it as a context + manager) to tear down the dispatcher thread; already-submitted loads + are left to complete and are reaped by the owning dataset. + """ + + def __init__( + self, + source: Iterable[Any], + dispatch_fn: Callable[[Any], Any], + depth: int, + ) -> None: + self._source = source + self._dispatch_fn = dispatch_fn + self._depth = max(1, int(depth)) + self._slots = threading.Semaphore(self._depth) + self._ready_queue: queue.Queue = queue.Queue() + self._stop = threading.Event() + self._thread = threading.Thread( + target=self._run, + name="datapipe_pump", + daemon=True, + ) + self._thread.start() + + # ------------------------------------------------------------------ + # Dispatcher thread + # ------------------------------------------------------------------ + + def _run(self) -> None: + """Dispatcher loop: keep ``depth`` samples in flight, in order.""" + source = iter(self._source) + while not self._stop.is_set(): + try: + item = next(source) + except StopIteration: + break + except BaseException as exc: # noqa: BLE001 + # A failing source must surface to the consumer, not hang it. + self._ready_queue.put(_PumpError(exc)) + return + + if item is BATCH_BOUNDARY: + # Boundaries are bookkeeping, not work: forward without + # consuming a slot. + self._ready_queue.put(BATCH_BOUNDARY) + continue + + # Backpressure: block until the consumer frees a slot. This is + # also where lazy pulling is enforced -- the source is not + # advanced again until there is room in flight. + self._slots.acquire() + if self._stop.is_set(): + break + try: + handle = self._dispatch_fn(item) + except BaseException as exc: # noqa: BLE001 + # A failing dispatch must surface to the consumer, not hang it. + self._ready_queue.put(_PumpError(exc)) + return + self._ready_queue.put(handle) + + self._ready_queue.put(_DONE) + + # ------------------------------------------------------------------ + # Consumer side (single consumer / the main thread) + # ------------------------------------------------------------------ + + def __iter__(self) -> Iterator[Any]: + """Yield ready handles (and batch boundaries) in FIFO order. + + Yields each loaded sample's handle in the order its work item was + pulled, and forwards :data:`BATCH_BOUNDARY` markers in place. + Releases a backpressure slot after each handle is consumed (i.e. + on the next iteration), so the dispatcher can refill the pipeline + as the consumer advances. Returns once the source is exhausted. + + Yields + ------ + object + Either a handle returned by ``dispatch_fn`` or + :data:`BATCH_BOUNDARY`. + """ + while True: + item = self._ready_queue.get() + if item is _DONE: + return + if isinstance(item, _PumpError): + raise item.exc + if item is BATCH_BOUNDARY: + yield BATCH_BOUNDARY + continue + yield item + # Consumer has finished with this sample; free a slot. + self._slots.release() + + # ------------------------------------------------------------------ + # Teardown + # ------------------------------------------------------------------ + + def stop(self) -> None: + """Stop the dispatcher thread and release its resources. + + Idempotent. Unblocks the dispatcher if it is waiting on a slot, + then joins it briefly. In-flight background loads already submitted + via ``dispatch_fn`` are not cancelled; the owning dataset reaps + them. + """ + if self._stop.is_set(): + return + self._stop.set() + # Unblock the dispatcher if it is parked acquiring a slot. + self._slots.release() + self._thread.join(timeout=5.0) + + def __enter__(self) -> "IOPump": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.stop() diff --git a/physicsnemo/datapipes/mesh_dataset.py b/physicsnemo/datapipes/mesh_dataset.py index 6ca740bb4b..112c231f53 100644 --- a/physicsnemo/datapipes/mesh_dataset.py +++ b/physicsnemo/datapipes/mesh_dataset.py @@ -29,7 +29,12 @@ from tensordict import TensorDict from physicsnemo.datapipes._rng import fork_generator -from physicsnemo.datapipes.protocols import DatasetBase, _PrefetchResult +from physicsnemo.datapipes.protocols import ( + DatasetBase, + HostPayload, + preprocessing_stream, + record_stream, +) from physicsnemo.datapipes.readers.mesh import DomainMeshReader, MeshReader from physicsnemo.datapipes.registry import register from physicsnemo.datapipes.transforms.mesh.base import MeshTransform @@ -93,7 +98,10 @@ def __init__( and tensordict's ``_device_recorder`` is not safe for concurrent TensorDict construction across threads. """ - super().__init__(num_workers=num_workers) + # Real mesh readers/tensorclasses can expose a CUDA illegal-access race + # when active host-side refill overlaps Warp SDF transforms. Keep the + # producer/consumer split, but serialize those two stages for meshes. + super().__init__(num_workers=num_workers, serialize_load_consume=True) self.reader = reader self.transforms = list(transforms) if transforms else [] self._device = torch.device(device) if isinstance(device, str) else device @@ -182,101 +190,56 @@ def __len__(self) -> int: return len(self.reader) # ------------------------------------------------------------------ - # Stream-aware prefetch (overrides DatasetBase defaults) + # Producer / consumer split (overrides DatasetBase defaults) # ------------------------------------------------------------------ - def _load_and_transform( - self, - index: int, - stream: Optional[torch.cuda.Stream] = None, - ) -> _PrefetchResult: - """Load a sample and apply transforms with optional CUDA stream. + def _load_host(self, work_item: int) -> HostPayload: + """Producer stage: read a mesh sample on a worker thread. + + Launches no device kernels: it only reads the raw sample. Device + transfer and mesh transforms happen later in :meth:`_consume` on + the consuming thread. Parameters ---------- - index : int - Sample index. - stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. + work_item : int + Sample index to read from the reader. Returns ------- - _PrefetchResult - Result with data, metadata, or error. + HostPayload + Payload carrying the host data and metadata, or a captured + error. """ - result = _PrefetchResult(index=index) - try: - data, metadata = self.reader[index] - - if self._device is not None: - if stream is not None: - with torch.cuda.stream(stream): - data = data.to(self._device, non_blocking=True) - else: - data = data.to(self._device, non_blocking=True) + data, metadata = self.reader[work_item] + return HostPayload(work_item=work_item, data=data, metadata=metadata) + except Exception as e: # noqa: BLE001 + return HostPayload(work_item=work_item, error=e) - for t in self.transforms: - if stream is not None: - with torch.cuda.stream(stream): - if isinstance(data, DomainMesh): - data = t.apply_to_domain(data) - else: - data = t(data) - else: - if isinstance(data, DomainMesh): - data = t.apply_to_domain(data) - else: - data = t(data) - - if stream is not None: - result.event = torch.cuda.Event() - result.event.record(stream) - - result.data = data - result.metadata = metadata - - except Exception as e: - result.error = e - - return result - - def prefetch( + def _consume( self, - index: int, + payload: HostPayload, stream: Optional[torch.cuda.Stream] = None, - ) -> None: - """Start prefetching a sample asynchronously. - - When a CUDA stream is provided, GPU operations (device transfer - and transforms) run on that stream for overlap with computation. - - Parameters - ---------- - index : int - Sample index to prefetch. - stream : torch.cuda.Stream, optional - Optional CUDA stream for GPU operations. - """ - if index in self._prefetch_futures: - return - - executor = self._ensure_executor() - future = executor.submit(self._load_and_transform, index, stream) - self._prefetch_futures[index] = future - - def __getitem__( - self, index: int ) -> tuple[Union[Mesh, DomainMesh, TensorDict], dict[str, Any]]: - """Get a transformed sample by index. + """Consumer stage: device transfer + transforms on the calling thread. - If the index was prefetched, returns the prefetched result - (waiting for completion if necessary). Otherwise loads synchronously. + Runs on whatever thread calls this (the main thread, so any Warp + mesh-query kernels in the transforms share the model's launching + thread). When a CUDA ``stream`` is assigned, the host-to-device + copy *and* the transforms run on that preprocessing stream -- Warp + bound to it via :func:`preprocessing_stream` -- so this sample's + preprocessing overlaps the previous batch's training on the compute + stream. The result is handed back with a CUDA event (not a + host-side synchronize) and tagged via ``record_stream``. Parameters ---------- - index : int - Sample index. + payload : HostPayload + Producer payload from :meth:`_load_host`. + stream : torch.cuda.Stream, optional + Preprocessing stream for the host-to-device transfer and + transforms. ``None`` runs on the current stream. Returns ------- @@ -286,23 +249,43 @@ def __getitem__( Raises ------ Exception - If prefetch failed, re-raises the error. + If the producer captured an error, re-raises it. """ - future = self._prefetch_futures.pop(index, None) + if payload.error is not None: + raise payload.error - if future is not None: - result = future.result() + data = payload.data + metadata = payload.metadata - if isinstance(result, _PrefetchResult): - if result.error is not None: - raise result.error - if result.event is not None: - torch.cuda.current_stream().wait_event(result.event) - return result.data, result.metadata + def _apply_transforms(d: Any) -> Any: + for t in self.transforms: + if isinstance(d, DomainMesh): + d = t.apply_to_domain(d) + else: + d = t(d) + return d - return result + device_is_cuda = ( + self._device is not None and torch.device(self._device).type == "cuda" + ) + use_stream = stream is not None and device_is_cuda + compute_stream = torch.cuda.current_stream() if use_stream else None - return self._load(index) + with preprocessing_stream(stream if use_stream else None): + if self._device is not None: + data = data.to(self._device, non_blocking=True) + data = _apply_transforms(data) + + if use_stream: + # Order the preprocessing stream's result before the compute + # stream consumes it, without blocking the host: tag the memory + # so the allocator keeps it alive, then gate on a CUDA event. + record_stream(data, compute_stream) + event = torch.cuda.Event() + event.record(stream) + compute_stream.wait_event(event) + + return data, metadata def close(self) -> None: """Close the dataset and stop prefetching. diff --git a/physicsnemo/datapipes/multi_dataset.py b/physicsnemo/datapipes/multi_dataset.py index 2547bf5f95..8f8d7b9cb6 100644 --- a/physicsnemo/datapipes/multi_dataset.py +++ b/physicsnemo/datapipes/multi_dataset.py @@ -304,6 +304,51 @@ def __getitem__(self, index: int) -> tuple[TensorDict, dict[str, Any]]: metadata[DATASET_INDEX_METADATA_KEY] = ds_id return data, metadata + def submit(self, index: int, stream: Optional[Any] = None) -> tuple[int, Any]: + """ + Submit a global index for background loading (FIFO prefetch primitive). + + Maps the global index to its owning sub-dataset and delegates to + that dataset's :meth:`~DatasetBase.submit`. The returned handle is + wrapped with the owning dataset id so :meth:`consume` can restore + the ``dataset_index`` metadata. + + Parameters + ---------- + index : int + Global sample index to load. + stream : object, optional + CUDA stream for the consume step. + + Returns + ------- + tuple[int, PrefetchHandle] + ``(dataset_index, handle)`` to pass to :meth:`consume`. + """ + ds_id, local_i = self._index_to_dataset_and_local(index) + handle = self._datasets[ds_id].submit(local_i, stream=stream) + return ds_id, handle + + def consume(self, handle: tuple[int, Any]) -> tuple[TensorDict, dict[str, Any]]: + """ + Resolve a :meth:`submit` handle into ``(data, metadata)``. + + Parameters + ---------- + handle : tuple[int, PrefetchHandle] + The ``(dataset_index, handle)`` returned by :meth:`submit`. + + Returns + ------- + tuple[TensorDict, dict[str, Any]] + Sample and metadata, enriched with ``dataset_index``. + """ + ds_id, inner = handle + data, metadata = self._datasets[ds_id].consume(inner) + metadata = dict(metadata) + metadata[DATASET_INDEX_METADATA_KEY] = ds_id + return data, metadata + def prefetch( self, index: int, diff --git a/physicsnemo/datapipes/protocols.py b/physicsnemo/datapipes/protocols.py index 69d4363f3b..5ce61cf23f 100644 --- a/physicsnemo/datapipes/protocols.py +++ b/physicsnemo/datapipes/protocols.py @@ -15,16 +15,26 @@ # limitations under the License. """ -Base class for dataset components. +Base classes for dataset components. -Provides :class:`DatasetBase`, an ABC that owns the thread-based prefetch -infrastructure shared by :class:`Dataset`, :class:`MeshDataset`, and any -future dataset implementations. The user-facing extension points are -**Readers** and **Transforms**, not dataset subclasses. +Provides two abstractions consumed by :class:`~physicsnemo.datapipes.DataLoader`: + +- :class:`DatasetBase` -- map-style datasets. Owns the thread-based + prefetch infrastructure (a producer/consumer split plus a FIFO + ``submit``/``consume`` primitive) shared by :class:`Dataset`, + :class:`MeshDataset`, and any future implementation. +- :class:`IterableDatasetBase` -- generator-style datasets that produce + data directly on the main thread (online simulation and other + stream-sensitive workloads). No prefetch, no length, no indexing. + +The user-facing extension points are **Readers** and **Transforms**, not +dataset subclasses. """ from __future__ import annotations +import contextlib +import threading from abc import ABC, abstractmethod from concurrent.futures import Future, ThreadPoolExecutor from dataclasses import dataclass, field @@ -32,62 +42,328 @@ import torch +from physicsnemo.core.function_spec import warp_stream_from_torch + +try: + import warp as wp + + _HAS_WARP = True +except ImportError: # pragma: no cover - warp is normally installed + wp = None + _HAS_WARP = False + + +@contextlib.contextmanager +def preprocessing_stream(stream: Optional["torch.cuda.Stream"]): + """Bind torch (and Warp) to *stream* for the host-to-device + transforms. + + Within the block both torch's current stream and -- when Warp is + installed -- Warp's current stream are set to *stream*, so Warp kernels + launched by transforms run on the same stream the data was copied on. + The single launching thread is the real Warp invariant; the stream is + free, but torch and Warp must agree on which one. A ``None`` stream is + a no-op (run on the current stream). + + Parameters + ---------- + stream : torch.cuda.Stream, optional + Stream to bind, or ``None`` to run on the current stream. + """ + if stream is None: + yield + return + with torch.cuda.stream(stream): + if _HAS_WARP: + # Use the cached wrapper: a fresh stream_from_torch per call would + # register/unregister the same CUDA handle on every consume and + # collide with the inner wrapper a Warp functional launch creates + # for the same stream, corrupting it (illegal memory access). + with wp.ScopedStream(warp_stream_from_torch(stream)): + yield + else: + yield + + +def record_stream(obj: Any, stream: "torch.cuda.Stream") -> None: + """Tag *obj*'s device tensors with *stream* for the caching allocator. + + Recurses into ``TensorDict``/tensor/mesh objects (which expose + ``record_stream``) and plain containers, so memory allocated on a + preprocessing stream is not recycled while the compute stream reads + it. Objects without device memory are ignored. + + Parameters + ---------- + obj : Any + Item to tag (tensor, TensorDict, mesh, dict, or sequence). + stream : torch.cuda.Stream + Stream that will consume the memory. + """ + record = getattr(obj, "record_stream", None) + if callable(record): + obj.record_stream(stream) + elif isinstance(obj, dict): + for value in obj.values(): + record_stream(value, stream) + elif isinstance(obj, (list, tuple)): + for value in obj: + record_stream(value, stream) + @dataclass -class _PrefetchResult: - """Result of a stream-aware prefetch operation. +class HostPayload: + """A sample produced by the (thread-safe) I/O stage, staged on the host. + + A ``HostPayload`` is the boundary object between the I/O producer and + the main-thread consumer. It carries a CPU ``TensorDict`` (ideally + pinned, so the subsequent host-to-device copy can be asynchronous) + plus metadata. It is produced by a worker thread, which must not + launch device kernels (in particular Warp kernels). - Used by :class:`Dataset` and :class:`MeshDataset` to carry data, - metadata, and an optional CUDA event through the prefetch pipeline. + Parameters + ---------- + work_item : Any + The work item this payload was produced from -- an ``int`` index + for map-style datasets, or any opaque descriptor for + descriptor-driven sources. + data : Any, optional + Host ``TensorDict`` (or mesh) payload. ``None`` on error. + metadata : dict, optional + Per-sample metadata produced by the reader. + error : Exception, optional + Exception captured during production, re-raised on consumption. """ - index: int + work_item: Any data: Any = None metadata: Optional[dict[str, Any]] = field(default=None) error: Optional[Exception] = field(default=None) - event: Optional[torch.cuda.Event] = field(default=None) + + +@dataclass +class PrefetchHandle: + """Handle to one in-flight prefetch returned by :meth:`DatasetBase.submit`. + + Bundles the producer ``Future`` with the CUDA stream the consumer + should use for the host-to-device copy and transforms. The handle is + correlated to its sample purely by identity/order, so opaque work + items need not be hashable or unique. + + Parameters + ---------- + future : concurrent.futures.Future + Future resolving to the producer's :class:`HostPayload`. + stream : torch.cuda.Stream, optional + Stream assigned for the consume step, or ``None`` for the current + stream. + """ + + future: Future + stream: Optional[torch.cuda.Stream] = None class DatasetBase(ABC): - """Abstract base for datasets compatible with :class:`DataLoader`. + """Abstract base for map-style datasets compatible with :class:`DataLoader`. - Subclasses implement :meth:`_load` (the actual data-loading pipeline) - and :meth:`__len__`. Everything else — ``__getitem__`` with prefetch - cache lookup, thread-pool prefetching, cancellation, cleanup — is - provided here. + Subclasses implement :meth:`_load` (the synchronous data-loading + pipeline) and :meth:`__len__`. Everything else -- indexing, the + ``submit``/``consume`` prefetch primitive, the index-keyed + ``prefetch``/``__getitem__`` convenience API, cancellation, and + cleanup -- is provided here. - Both :class:`Dataset` and :class:`MeshDataset` override - :meth:`prefetch` and :meth:`__getitem__` to add CUDA-stream - support via :class:`_PrefetchResult`. + Producer / consumer split + -------------------------- + Prefetching is split into two stages so that **no device kernels are + launched off the main thread** (a hard requirement for Warp + transforms, which must share the model's single launching thread): + + - :meth:`_load_host` is the **producer**. It runs on a worker thread + and performs only thread-safe work: reading, decoding, and staging + into host memory. It returns a :class:`HostPayload`. + - :meth:`_consume` is the **consumer**. It runs on the thread that + calls :meth:`consume` / :meth:`__getitem__` (the main thread, in + practice) and performs the host-to-device transfer and device + transforms (including Warp kernels) on the assigned CUDA stream. + + :class:`Dataset` and :class:`MeshDataset` override both hooks to + perform the real split. The default implementations fall back to + running the full :meth:`_load` on the worker for any subclass that + does not override them. """ - def __init__(self, *, num_workers: int = 2) -> None: - self._prefetch_futures: dict[int, Future] = {} + def __init__( + self, + *, + num_workers: int = 2, + serialize_load_consume: bool = False, + ) -> None: self._executor: Optional[ThreadPoolExecutor] = None self._num_workers = num_workers + self._lock = threading.Lock() + self._stage_lock = threading.Lock() + self._serialize_load_consume = serialize_load_consume + # Futures still in flight, tracked so close() can drain them. + self._inflight: set[Future] = set() + # Index-keyed handles backing the prefetch()/__getitem__ compat API. + self._prefetch_handles: dict[int, PrefetchHandle] = {} @abstractmethod def _load(self, index: int) -> tuple[Any, dict[str, Any]]: """Load and return a single sample ``(data, metadata)``. - This is the hook that subclasses must implement. It is called - both synchronously (from ``__getitem__``) and asynchronously - (from the prefetch thread pool). + This is the synchronous, full-pipeline hook that subclasses must + implement. It is called directly from :meth:`__getitem__` when the + index was not prefetched. """ ... @abstractmethod def __len__(self) -> int: ... + # ------------------------------------------------------------------ + # Producer / consumer hooks (overridden by Dataset / MeshDataset) + # ------------------------------------------------------------------ + + def _load_host(self, work_item: Any) -> HostPayload: + """Producer stage: load *work_item* into a :class:`HostPayload`. + + Runs on a worker thread and must not launch device kernels. The + default implementation runs the full :meth:`_load` pipeline for + backward compatibility; subclasses that use device transforms + override this to stop at host staging. + + Parameters + ---------- + work_item : Any + Work item to load (an ``int`` index by default). + + Returns + ------- + HostPayload + Payload carrying the host data and metadata, or a captured + error. + """ + try: + data, metadata = self._load(work_item) + return HostPayload(work_item=work_item, data=data, metadata=metadata) + except Exception as e: # noqa: BLE001 + return HostPayload(work_item=work_item, error=e) + + def _load_host_guarded(self, work_item: Any) -> HostPayload: + if not self._serialize_load_consume: + return self._load_host(work_item) + with self._stage_lock: + return self._load_host(work_item) + + def _consume( + self, + payload: HostPayload, + stream: Optional[torch.cuda.Stream] = None, + ) -> tuple[Any, dict[str, Any]]: + """Consumer stage: turn a :class:`HostPayload` into ``(data, metadata)``. + + Runs on the calling (main) thread. The default implementation + unwraps the payload and re-raises any captured error; the + ``stream`` argument is ignored. Subclasses override this to + perform the host-to-device transfer and device transforms. + + Parameters + ---------- + payload : HostPayload + Producer payload from :meth:`_load_host`. + stream : torch.cuda.Stream, optional + Stream for the consume step. + + Returns + ------- + tuple[Any, dict[str, Any]] + The sample data and its metadata. + """ + if payload.error is not None: + raise payload.error + return payload.data, payload.metadata + + def _consume_guarded( + self, + payload: HostPayload, + stream: Optional[torch.cuda.Stream] = None, + ) -> tuple[Any, dict[str, Any]]: + if not self._serialize_load_consume: + return self._consume(payload, stream) + with self._stage_lock: + return self._consume(payload, stream) + + # ------------------------------------------------------------------ + # FIFO prefetch primitive (used by the DataLoader's pump) + # ------------------------------------------------------------------ + + def submit( + self, + work_item: Any, + stream: Optional[torch.cuda.Stream] = None, + ) -> PrefetchHandle: + """Submit *work_item* for background loading and return its handle. + + Only the (thread-safe) producer stage runs on the worker pool. The + returned :class:`PrefetchHandle` is later passed to :meth:`consume` + on the main thread. Safe to call from a dispatcher thread distinct + from the consumer. + + Parameters + ---------- + work_item : Any + Work item to load. + stream : torch.cuda.Stream, optional + Stream the consumer should use for this sample. + + Returns + ------- + PrefetchHandle + Handle bundling the producer future and the assigned stream. + """ + executor = self._ensure_executor() + future = executor.submit(self._load_host_guarded, work_item) + with self._lock: + self._inflight.add(future) + future.add_done_callback(self._discard_inflight) + return PrefetchHandle(future=future, stream=stream) + + def consume(self, handle: PrefetchHandle) -> tuple[Any, dict[str, Any]]: + """Resolve a :meth:`submit` handle into ``(data, metadata)``. + + Blocks until the producer future is ready, then runs the consumer + stage on the calling thread (host-to-device transfer and device + transforms on the handle's stream). + + Parameters + ---------- + handle : PrefetchHandle + Handle returned by :meth:`submit`. + + Returns + ------- + tuple[Any, dict[str, Any]] + The sample data and its metadata. + """ + payload = handle.future.result() # re-raises producer errors via _consume + return self._consume_guarded(payload, handle.stream) + # ------------------------------------------------------------------ # Concrete interface # ------------------------------------------------------------------ def __getitem__(self, index: int) -> tuple[Any, dict[str, Any]]: - """Return sample *index*, using a prefetched result when available.""" - future = self._prefetch_futures.pop(index, None) - if future is not None: - return future.result() # re-raises on error + """Return sample *index*, using a prefetched result when available. + + When the index was prefetched via :meth:`prefetch`, the pending + handle is consumed on the calling thread (so device transforms run + here, not on the worker). Otherwise the sample is loaded + synchronously. + """ + with self._lock: + handle = self._prefetch_handles.pop(index, None) + if handle is not None: + return self.consume(handle) return self._load(index) def prefetch( @@ -95,32 +371,47 @@ def prefetch( index: int, stream: Optional[torch.cuda.Stream] = None, ) -> None: - """Submit *index* for background loading in a worker thread. + """Start prefetching *index*, retrievable later via :meth:`__getitem__`. - The ``stream`` parameter is accepted for interface compatibility - but ignored by the default implementation. :class:`Dataset` - overrides this to run GPU transfers on the given stream. + Index-keyed convenience wrapper around :meth:`submit`. A repeated + prefetch of an in-flight index is a no-op. + + Parameters + ---------- + index : int + Sample index to prefetch. + stream : torch.cuda.Stream, optional + Stream the consumer should use for this sample. """ - if index in self._prefetch_futures: - return - executor = self._ensure_executor() - self._prefetch_futures[index] = executor.submit(self._load, index) + with self._lock: + if index in self._prefetch_handles: + return + handle = self.submit(index, stream) + with self._lock: + if index in self._prefetch_handles: + return # raced; the extra handle's future is reaped on completion + self._prefetch_handles[index] = handle def cancel_prefetch(self, index: Optional[int] = None) -> None: - """Discard prefetch results (already-running tasks still complete).""" - if index is None: - self._prefetch_futures.clear() - else: - self._prefetch_futures.pop(index, None) + """Discard prefetch handles (already-running tasks still complete).""" + with self._lock: + if index is None: + self._prefetch_handles.clear() + else: + self._prefetch_handles.pop(index, None) def close(self) -> None: """Drain in-flight prefetches and shut down the thread pool.""" - for future in self._prefetch_futures.values(): + with self._lock: + futures = list(self._inflight) + self._prefetch_handles.clear() + for future in futures: try: future.result(timeout=30.0) except Exception: # noqa: BLE001, S110 pass - self._prefetch_futures.clear() + with self._lock: + self._inflight.clear() if self._executor is not None: self._executor.shutdown(wait=True) @@ -130,13 +421,19 @@ def close(self) -> None: # Helpers # ------------------------------------------------------------------ + def _discard_inflight(self, future: Future) -> None: + """Done-callback: drop a finished future from the in-flight set.""" + with self._lock: + self._inflight.discard(future) + def _ensure_executor(self) -> ThreadPoolExecutor: - if self._executor is None: - self._executor = ThreadPoolExecutor( - max_workers=self._num_workers, - thread_name_prefix="datapipe_prefetch", - ) - return self._executor + with self._lock: + if self._executor is None: + self._executor = ThreadPoolExecutor( + max_workers=self._num_workers, + thread_name_prefix="datapipe_prefetch", + ) + return self._executor def __iter__(self) -> Iterator[tuple[Any, dict[str, Any]]]: for i in range(len(self)): @@ -147,3 +444,70 @@ def __enter__(self) -> "DatasetBase": def __exit__(self, exc_type, exc_val, exc_tb) -> None: self.close() + + +class IterableDatasetBase(ABC): + """Abstract base for generator-style datasets driven on the main thread. + + Unlike :class:`DatasetBase`, an iterable dataset has no length and no + indexing: it produces data by iteration only. The + :class:`~physicsnemo.datapipes.DataLoader` drives it entirely on the + main thread (no worker pool), so :meth:`__iter__` may freely launch + Warp kernels and use CUDA streams -- the property that makes online + simulation safe here but unsafe on the worker-pool preload path. + + Emission modes + -------------- + - **Per-sample** (default, ``yields_batches = False``): :meth:`__iter__` + yields ``(data, metadata)`` for one sample at a time and the loader + collates ``batch_size`` of them. + - **Self-batching** (``yields_batches = True``): :meth:`__iter__` yields + a fully-formed batch already and the loader passes it through + unchanged (``batch_size``/``drop_last`` do not apply). + + Subclasses may optionally implement :meth:`set_epoch` and + :meth:`set_generator` for reproducible seeding. + """ + + # When True, __iter__ yields ready-made batches and the loader does not + # re-collate (e.g. an online simulator that produces a batch per step). + yields_batches: bool = False + + @abstractmethod + def __iter__(self) -> Iterator[Any]: + """Yield samples ``(data, metadata)`` or ready-made batches. + + Returns + ------- + Iterator + Per-sample ``(data, metadata)`` tuples, or full batches when + :attr:`yields_batches` is True. + """ + ... + + def set_epoch(self, epoch: int) -> None: + """Reseed for *epoch* (no-op by default). + + Parameters + ---------- + epoch : int + Current epoch number. + """ + + def set_generator(self, generator: torch.Generator) -> None: + """Seed the dataset's randomness from *generator* (no-op by default). + + Parameters + ---------- + generator : torch.Generator + Parent generator supplied by the DataLoader. + """ + + def __enter__(self) -> "IterableDatasetBase": + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.close() + + def close(self) -> None: + """Release any resources held by the dataset (no-op by default).""" diff --git a/physicsnemo/datapipes/readers/base.py b/physicsnemo/datapipes/readers/base.py index 17ce9b0b86..9ec8f2abe3 100644 --- a/physicsnemo/datapipes/readers/base.py +++ b/physicsnemo/datapipes/readers/base.py @@ -31,6 +31,8 @@ import torch from tensordict import TensorDict +from physicsnemo.datapipes._rng import spawn_generator + logger = logging.getLogger(__name__) @@ -110,6 +112,11 @@ def __init__( self.pin_memory = pin_memory self.include_index_in_metadata = include_index_in_metadata self._coordinated_subsampling_config = coordinated_subsampling + # Base seed + epoch for deterministic per-index RNG. See + # :meth:`_index_generator`. ``None`` means no seed was provided + # (random draws fall back to the global default RNG). + self._seed_base: int | None = None + self._epoch: int = 0 @abstractmethod def _load_sample(self, index: int) -> dict[str, torch.Tensor]: @@ -279,28 +286,61 @@ def __iter__(self) -> Iterator[tuple[TensorDict, dict[str, Any]]]: raise RuntimeError(error_msg) from e def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible random sampling. + """Assign a base seed for reproducible, order-independent sampling. + + Stores ``generator.initial_seed()`` as the base seed. Per-sample + generators are then derived deterministically from + ``(base_seed, epoch, index)`` via :meth:`_index_generator`, so + random draws are reproducible regardless of the order in which + samples are read or which worker thread reads them. - Override in subclasses that use randomness (e.g. subsampling). - The default implementation is a no-op. + Subclasses that use randomness should draw from + :meth:`_index_generator` rather than a shared generator. Readers + with no randomness inherit this harmlessly. Parameters ---------- generator : torch.Generator - Generator to use for random draws. + Generator whose ``initial_seed()`` seeds all per-sample RNG. """ + self._seed_base = generator.initial_seed() def set_epoch(self, epoch: int) -> None: - """Reseed the reader's RNG for a new epoch. + """Set the epoch used to vary per-sample RNG deterministically. - Override in subclasses that use randomness. - The default implementation is a no-op. + The epoch is folded into each sample's derived seed (see + :meth:`_index_generator`), so each epoch produces a different but + reproducible sequence. Parameters ---------- epoch : int Current epoch number. """ + self._epoch = epoch + + def _index_generator(self, index: int) -> torch.Generator | None: + """Return a fresh generator seeded for sample *index* this epoch. + + Derives an independent :class:`torch.Generator` from + ``(base_seed, epoch, index)``. Because the seed depends only on + those values, the draw for a given sample is identical regardless + of read order or worker thread. Returns ``None`` when no base + seed has been set (the unseeded fallback). + + Parameters + ---------- + index : int + Sample index. + + Returns + ------- + torch.Generator or None + A per-sample generator, or ``None`` if no seed was provided. + """ + if self._seed_base is None: + return None + return spawn_generator(self._seed_base, self._epoch, index) def close(self) -> None: """ diff --git a/physicsnemo/datapipes/readers/mesh.py b/physicsnemo/datapipes/readers/mesh.py index 603dcf13bd..479f5da6a4 100644 --- a/physicsnemo/datapipes/readers/mesh.py +++ b/physicsnemo/datapipes/readers/mesh.py @@ -30,6 +30,7 @@ import torch +from physicsnemo.datapipes._rng import spawn_generator from physicsnemo.datapipes.registry import register from physicsnemo.mesh import DomainMesh, Mesh @@ -188,7 +189,10 @@ def __init__( self.include_index_in_metadata = include_index_in_metadata self.subsample_n_points = subsample_n_points self.subsample_n_cells = subsample_n_cells - self._subsample_generator: torch.Generator | None = None + # Base seed + epoch for deterministic per-index RNG (see + # :meth:`set_generator`). ``None`` means unseeded. + self._seed_base: int | None = None + self._epoch: int = 0 if not self._root.exists(): raise FileNotFoundError(f"Path not found: {self._root}") @@ -212,38 +216,43 @@ def __len__(self) -> int: return len(self._paths) def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling. + """Assign a base seed for reproducible, order-independent subsampling. Called by :class:`MeshDataset` when the DataLoader provides a - seed. Replaces any previously assigned generator. + seed. Stores ``generator.initial_seed()`` as the base seed; each + sample then derives its own generator from + ``(base_seed, epoch, index)``, so subsampling is reproducible + regardless of read order or worker thread. Parameters ---------- generator : torch.Generator - Generator to use for contiguous block selection. + Generator whose ``initial_seed()`` seeds all per-sample RNG. """ - self._subsample_generator = generator + self._seed_base = generator.initial_seed() def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch. + """Set the epoch used to vary per-sample RNG deterministically. - Produces a different (but deterministic) sequence of contiguous - blocks each epoch when a generator has been assigned via - :meth:`set_generator`. + The epoch is folded into each sample's derived seed, producing a + different (but deterministic) sequence of contiguous blocks each + epoch when a base seed has been assigned via :meth:`set_generator`. """ - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) + self._epoch = epoch def __getitem__(self, index: int) -> tuple[Mesh, dict[str, Any]]: mesh = self._load_sample(index) + generator = ( + None + if self._seed_base is None + else spawn_generator(self._seed_base, self._epoch, index) + ) mesh = _subsample_mesh( mesh, self.subsample_n_cells, self.subsample_n_points, - generator=self._subsample_generator, + generator=generator, ) if self.pin_memory: @@ -339,7 +348,10 @@ def __init__( self.include_index_in_metadata = include_index_in_metadata self.subsample_n_points = subsample_n_points self.subsample_n_cells = subsample_n_cells - self._subsample_generator: torch.Generator | None = None + # Base seed + epoch for deterministic per-index RNG (see + # :meth:`set_generator`). ``None`` means unseeded. + self._seed_base: int | None = None + self._epoch: int = 0 self._extra_boundaries = extra_boundaries or {} if not self._root.exists(): @@ -359,38 +371,43 @@ def __len__(self) -> int: return len(self._paths) def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling. + """Assign a base seed for reproducible, order-independent subsampling. Called by :class:`MeshDataset` when the DataLoader provides a - seed. Replaces any previously assigned generator. + seed. Stores ``generator.initial_seed()`` as the base seed; each + sample then derives its own generator from + ``(base_seed, epoch, index)``, so subsampling is reproducible + regardless of read order or worker thread. Parameters ---------- generator : torch.Generator - Generator to use for contiguous block selection. + Generator whose ``initial_seed()`` seeds all per-sample RNG. """ - self._subsample_generator = generator + self._seed_base = generator.initial_seed() def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch. + """Set the epoch used to vary per-sample RNG deterministically. - Produces a different (but deterministic) sequence of contiguous - blocks each epoch when a generator has been assigned via - :meth:`set_generator`. + The epoch is folded into each sample's derived seed, producing a + different (but deterministic) sequence of contiguous blocks each + epoch when a base seed has been assigned via :meth:`set_generator`. """ - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) + self._epoch = epoch def __getitem__(self, index: int) -> tuple[DomainMesh, dict[str, Any]]: dm = self._load_sample(index) if self.subsample_n_cells is not None or self.subsample_n_points is not None: + generator = ( + None + if self._seed_base is None + else spawn_generator(self._seed_base, self._epoch, index) + ) sub_kw = dict( n_cells=self.subsample_n_cells, n_points=self.subsample_n_points, - generator=self._subsample_generator, + generator=generator, ) dm = DomainMesh( interior=_subsample_mesh(dm.interior, **sub_kw), diff --git a/physicsnemo/datapipes/readers/numpy.py b/physicsnemo/datapipes/readers/numpy.py index 2f546d5916..a82e347a7e 100644 --- a/physicsnemo/datapipes/readers/numpy.py +++ b/physicsnemo/datapipes/readers/numpy.py @@ -112,7 +112,6 @@ def __init__( self.default_values = default_values or {} self.file_pattern = file_pattern self.index_key = index_key - self._subsample_generator: torch.Generator | None = None if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") @@ -168,22 +167,12 @@ def fields(self) -> list[str]: return self._user_fields return self._available_fields - def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling.""" - self._subsample_generator = generator - - def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch.""" - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) - def _select_random_sections_from_slice( self, slice_start: int, slice_stop: int, n_points: int, + generator: Optional[torch.Generator] = None, ) -> slice: """ Select a random contiguous slice from a range. @@ -196,6 +185,9 @@ def _select_random_sections_from_slice( Stop index of the available range (exclusive). n_points : int Number of points to sample. + generator : torch.Generator, optional + Per-sample generator for reproducible, order-independent + selection. ``None`` uses the global default RNG. Returns ------- @@ -219,7 +211,7 @@ def _select_random_sections_from_slice( slice_start, slice_stop - n_points + 1, (1,), - generator=self._subsample_generator, + generator=generator, ).item() return slice(start, start + n_points) @@ -228,6 +220,7 @@ def _load_from_npz( npz: np.lib.npyio.NpzFile, index: Optional[int] = None, file_path: Optional[Path] = None, + generator: Optional[torch.Generator] = None, ) -> dict[str, torch.Tensor]: """ Load data from an npz file. @@ -241,6 +234,8 @@ def _load_from_npz( None for directory mode (load entire arrays). file_path : Path, optional Path to the file (for error messages). + generator : torch.Generator, optional + Per-sample generator for reproducible coordinated subsampling. Returns ------- @@ -272,7 +267,7 @@ def _load_from_npz( if field in npz.files: array_shape = npz[field].shape[0] subsample_slice = self._select_random_sections_from_slice( - 0, array_shape, n_points + 0, array_shape, n_points, generator=generator ) break @@ -301,12 +296,17 @@ def _load_from_npz( def _load_sample(self, index: int) -> dict[str, torch.Tensor]: """Load a single sample.""" + # Derive a per-sample generator so coordinated subsampling is + # reproducible regardless of read order or worker thread. + generator = self._index_generator(index) if self._mode == "directory": file_path = self._files[index] with np.load(file_path) as npz: - return self._load_from_npz(npz, index=None, file_path=file_path) + return self._load_from_npz( + npz, index=None, file_path=file_path, generator=generator + ) else: # single - return self._load_from_npz(self._data, index=index) + return self._load_from_npz(self._data, index=index, generator=generator) def __len__(self) -> int: """Return number of samples.""" diff --git a/physicsnemo/datapipes/readers/tensorstore_zarr.py b/physicsnemo/datapipes/readers/tensorstore_zarr.py index 9bc407aea5..5ce3e6021a 100644 --- a/physicsnemo/datapipes/readers/tensorstore_zarr.py +++ b/physicsnemo/datapipes/readers/tensorstore_zarr.py @@ -155,7 +155,6 @@ def __init__( self._user_fields = fields self.default_values = default_values or {} self.group_pattern = group_pattern - self._subsample_generator: torch.Generator | None = None if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") @@ -235,17 +234,6 @@ def fields(self) -> list[str]: return self._user_fields return self._available_fields - def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling.""" - self._subsample_generator = generator - - def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch.""" - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) - def _read_attributes(self, group_path: Path) -> dict[str, Any]: """Read attributes from a Zarr group (v2 or v3).""" store_spec = {"driver": "file", "path": str(group_path)} @@ -274,6 +262,7 @@ def _select_random_sections_from_slice( slice_start: int, slice_stop: int, n_points: int, + generator: Optional[torch.Generator] = None, ) -> slice: """ Select a random contiguous slice from a range. @@ -286,6 +275,9 @@ def _select_random_sections_from_slice( Stop index of the available range (exclusive). n_points : int Number of points to sample. + generator : torch.Generator, optional + Per-sample generator for reproducible, order-independent + selection. ``None`` uses the global default RNG. Returns ------- @@ -309,13 +301,15 @@ def _select_random_sections_from_slice( slice_start, slice_stop - n_points + 1, (1,), - generator=self._subsample_generator, + generator=generator, ).item() return slice(start, start + n_points) def _load_sample(self, index: int) -> dict[str, torch.Tensor]: """Load a single sample from a Zarr group using TensorStore.""" group_path = self._groups[index] + # Per-sample generator: reproducible regardless of read order/thread. + generator = self._index_generator(index) # Read attributes (stored as tensors in sample) attributes = self._read_attributes(group_path) @@ -368,7 +362,7 @@ def _load_sample(self, index: int) -> dict[str, torch.Tensor]: if key in stores: array_shape = stores[key].shape[0] subsample_slice = self._select_random_sections_from_slice( - 0, array_shape, n_points + 0, array_shape, n_points, generator=generator ) break diff --git a/physicsnemo/datapipes/readers/zarr.py b/physicsnemo/datapipes/readers/zarr.py index a27cee5447..2d2a70cbd2 100644 --- a/physicsnemo/datapipes/readers/zarr.py +++ b/physicsnemo/datapipes/readers/zarr.py @@ -144,7 +144,6 @@ def __init__( self.group_pattern = group_pattern self._cache_stores = cache_stores self._cached_stores: dict[Path, Any] = {} # Cache for opened zarr stores - self._subsample_generator: torch.Generator | None = None if not self.path.exists(): raise FileNotFoundError(f"Path not found: {self.path}") @@ -206,17 +205,6 @@ def fields(self) -> list[str]: return self._user_fields return self._available_fields - def set_generator(self, generator: torch.Generator) -> None: - """Assign a ``torch.Generator`` for reproducible subsampling.""" - self._subsample_generator = generator - - def set_epoch(self, epoch: int) -> None: - """Reseed the subsample RNG for a new epoch.""" - if self._subsample_generator is not None: - self._subsample_generator.manual_seed( - self._subsample_generator.initial_seed() + epoch - ) - def _open_zarr_store(self, path: Path) -> Any: """ Open a zarr store, using cache if enabled. @@ -255,6 +243,7 @@ def _select_random_sections_from_slice( slice_start: int, slice_stop: int, n_points: int, + generator: Optional[torch.Generator] = None, ) -> slice: """ Select a random contiguous slice from a range. @@ -267,6 +256,9 @@ def _select_random_sections_from_slice( Stop index of the available range (exclusive). n_points : int Number of points to sample. + generator : torch.Generator, optional + Per-sample generator for reproducible, order-independent + selection. ``None`` uses the global default RNG. Returns ------- @@ -290,12 +282,14 @@ def _select_random_sections_from_slice( slice_start, slice_stop - n_points + 1, (1,), - generator=self._subsample_generator, + generator=generator, ).item() return slice(start, start + n_points) def _load_sample(self, index: int) -> dict[str, torch.Tensor]: """Load a single sample from a Zarr group.""" + # Per-sample generator: reproducible regardless of read order/thread. + generator = self._index_generator(index) if self._single_group_mode: # Single group: index into first dimension of each array group_path = self._groups[0] @@ -339,7 +333,7 @@ def _load_sample(self, index: int) -> dict[str, torch.Tensor]: else: array_shape = root[field].shape[0] subsample_slice = self._select_random_sections_from_slice( - 0, array_shape, n_points + 0, array_shape, n_points, generator=generator ) break diff --git a/physicsnemo/nn/functional/geometry/sdf.py b/physicsnemo/nn/functional/geometry/sdf.py index 4a9e8ea06b..79c519b575 100644 --- a/physicsnemo/nn/functional/geometry/sdf.py +++ b/physicsnemo/nn/functional/geometry/sdf.py @@ -122,7 +122,6 @@ def signed_distance_field_impl( >>> signed_distance_field(mesh_vertices, mesh_indices, input_points) (tensor([0.5]), tensor([0.5, 0.5, 0.5])) """ - if input_points.shape[-1] != 3: raise ValueError("input_points must have last dimension of size 3") @@ -154,12 +153,23 @@ def signed_distance_field_impl( with wp.ScopedStream(wp_launch_stream): wp.init() - # zero copy the vertices, indices, and input points to warp: - wp_vertices = wp.from_torch(mesh_vertices.to(torch.float32), dtype=wp.vec3) - wp_indices = wp.from_torch( - mesh_indices.to(torch.int32).contiguous(), dtype=wp.int32 + mesh_vertices_f32 = mesh_vertices.to(torch.float32) + mesh_indices_i32 = mesh_indices.to(torch.int32).contiguous() + input_points_f32 = input_points.to(torch.float32) + torch_launch_stream = ( + torch.cuda.current_stream(input_points.device) + if input_points.device.type == "cuda" + else None ) - wp_input_points = wp.from_torch(input_points.to(torch.float32), dtype=wp.vec3) + if torch_launch_stream is not None: + mesh_vertices_f32.record_stream(torch_launch_stream) + mesh_indices_i32.record_stream(torch_launch_stream) + input_points_f32.record_stream(torch_launch_stream) + + # zero copy the vertices, indices, and input points to warp: + wp_vertices = wp.from_torch(mesh_vertices_f32, dtype=wp.vec3) + wp_indices = wp.from_torch(mesh_indices_i32, dtype=wp.int32) + wp_input_points = wp.from_torch(input_points_f32, dtype=wp.vec3) # Convert output points: wp_sdf = wp.from_torch(sdf, dtype=wp.float32) @@ -190,7 +200,9 @@ def signed_distance_field_impl( sdf = sdf.reshape(input_shape[:-1]) sdf_hit_point = sdf_hit_point.reshape(input_shape) - return sdf.to(input_points.dtype), sdf_hit_point.to(input_points.dtype) + sdf_out = sdf.to(input_points.dtype) + sdf_hit_point_out = sdf_hit_point.to(input_points.dtype) + return sdf_out, sdf_hit_point_out @signed_distance_field_impl.register_fake diff --git a/test/datapipes/core/test_dataset.py b/test/datapipes/core/test_dataset.py index c6f382725c..f46424559d 100644 --- a/test/datapipes/core/test_dataset.py +++ b/test/datapipes/core/test_dataset.py @@ -742,3 +742,60 @@ def test_gpu_iteration_with_transforms(self, numpy_data_dir): assert data["positions"].device.type == "cuda" torch.cuda.synchronize() + + +class TestDatasetSubsamplingReproducibility: + """Reader subsampling RNG is reproducible across the threaded path. + + Reader subsampling derives its generator per ``(base_seed, epoch, + index)``, so the multi-worker ``prefetch`` path matches synchronous + loading and two seeded runs agree -- even with ``num_workers > 1``. + """ + + def _make_dataset(self, data_dir, seed: int) -> dp.Dataset: + reader = dp.NumpyReader( + data_dir, + file_pattern="sample_*.npz", + fields=["positions", "features"], + coordinated_subsampling={ + "n_points": 50, + "target_keys": ["positions", "features"], + }, + ) + dataset = dp.Dataset(reader, num_workers=2) + dataset.set_generator(torch.Generator().manual_seed(seed)) + return dataset + + def test_prefetch_matches_synchronous(self, numpy_data_dir): + """Threaded prefetch (num_workers=2) matches synchronous loading.""" + indices = list(range(10)) + + sync_ds = self._make_dataset(numpy_data_dir, seed=2024) + expected = {i: sync_ds._load(i)[0]["positions"] for i in indices} + + async_ds = self._make_dataset(numpy_data_dir, seed=2024) + for i in indices: + async_ds.prefetch(i) + actual = {i: async_ds[i][0]["positions"] for i in indices} + async_ds.close() + + for i in indices: + assert torch.equal(actual[i], expected[i]) + + def test_two_seeded_runs_identical(self, numpy_data_dir): + """Two datasets seeded identically yield identical prefetch results.""" + indices = list(range(10)) + + ds_a = self._make_dataset(numpy_data_dir, seed=99) + ds_b = self._make_dataset(numpy_data_dir, seed=99) + for i in indices: + ds_a.prefetch(i) + ds_b.prefetch(i) + + for i in indices: + a = ds_a[i][0]["positions"] + b = ds_b[i][0]["positions"] + assert torch.equal(a, b) + + ds_a.close() + ds_b.close() diff --git a/test/datapipes/core/test_streaming.py b/test/datapipes/core/test_streaming.py new file mode 100644 index 0000000000..1fbfa3c841 --- /dev/null +++ b/test/datapipes/core/test_streaming.py @@ -0,0 +1,553 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the lazy preload path and the iterable (generator) dataset path. + +Stage 1 covers the lazy, FIFO-handle preload path (IOPump laziness, +``BATCH_BOUNDARY`` reassembly, opaque work items, the ``submit``/``consume`` +primitive). Stage 2 covers iterable datasets driven main-thread-only +(finite/capped-infinite generators, ``drop_last``, self-batching +pass-through, reproducibility, and no worker pool). CUDA-guarded tests +exercise stream-bound preprocessing and Warp-on-a-non-default-stream. +""" + +from __future__ import annotations + +import threading +import time + +import numpy as np +import pytest +import torch +from tensordict import TensorDict + +import physicsnemo.datapipes as dp +from physicsnemo.datapipes.io_pump import BATCH_BOUNDARY, IOPump +from physicsnemo.datapipes.protocols import DatasetBase, IterableDatasetBase + +# ============================================================================ +# Stage 1 -- IOPump (lazy, FIFO, batch boundaries) +# ============================================================================ + + +class TestIOPump: + """Tests for the lazy, self-driving prefetch pump.""" + + def test_lazy_bounded_pull_on_infinite_source(self): + """Pump pulls an unbounded source lazily, bounded by depth.""" + pulled: list[int] = [] + + def source(): + i = 0 + while True: + pulled.append(i) + yield i + i += 1 + + depth = 3 + pump = IOPump(source(), lambda x: x, depth=depth) + out = [] + for item in pump: + out.append(item) + if len(out) == 5: + break + pump.stop() + + assert out == [0, 1, 2, 3, 4] + # The dispatcher must not have run far ahead of what was consumed: + # at most consumed + depth + a small slack for the in-flight pull. + assert len(pulled) <= 5 + depth + 2 + + def test_batch_boundary_reassembly_irregular(self): + """Boundaries delimit dynamically-sized batches without slot use.""" + source = [0, 1, BATCH_BOUNDARY, 2, BATCH_BOUNDARY, 3, 4, 5, BATCH_BOUNDARY] + pump = IOPump(iter(source), lambda x: x, depth=2) + + batches: list[list[int]] = [] + current: list[int] = [] + for item in pump: + if item is BATCH_BOUNDARY: + batches.append(current) + current = [] + else: + current.append(item) + pump.stop() + + assert batches == [[0, 1], [2], [3, 4, 5]] + + def test_fifo_order_preserved(self): + """Handles are yielded in the order work items were pulled.""" + pump = IOPump(iter(range(20)), lambda x: x * 10, depth=4) + out = list(pump) + pump.stop() + assert out == [x * 10 for x in range(20)] + + def test_dispatch_error_surfaces_not_hangs(self): + """A dispatch exception is raised on the consumer, never a hang.""" + + def boom(x): + raise RuntimeError("dispatch failed") + + pump = IOPump(iter(range(5)), boom, depth=2) + with pytest.raises(RuntimeError, match="dispatch failed"): + list(pump) + pump.stop() + + def test_source_error_surfaces_not_hangs(self): + """A failing source is raised on the consumer, never a hang.""" + + def source(): + yield 0 + raise ValueError("source failed") + + pump = IOPump(source(), lambda x: x, depth=2) + with pytest.raises(ValueError, match="source failed"): + list(pump) + pump.stop() + + +# ============================================================================ +# Stage 1 -- submit / consume FIFO primitive with opaque work items +# ============================================================================ + + +class _DescriptorDataset(DatasetBase): + """Map-style dataset keyed by an opaque (non-int) descriptor.""" + + def __init__(self): + super().__init__(num_workers=2) + self._store = {"alpha": 1.0, "beta": 2.0, "gamma": 3.0} + + def _load(self, key): + if key == "explode": + raise KeyError("no such key") + return TensorDict({"x": torch.tensor([self._store[key]])}), {"key": key} + + def __len__(self): + return len(self._store) + + +class _StageLockedDataset(DatasetBase): + """Dataset that records whether worker load overlaps consume.""" + + def __init__(self): + super().__init__(num_workers=1, serialize_load_consume=True) + self.load_entries: list[int] = [] + self.consume_started = threading.Event() + self.release_consume = threading.Event() + + def _load(self, index): + return TensorDict({"x": torch.tensor([float(index)])}), {"index": index} + + def _load_host(self, work_item): + self.load_entries.append(work_item) + return super()._load_host(work_item) + + def _consume(self, payload, stream=None): + self.consume_started.set() + self.release_consume.wait(timeout=5.0) + return super()._consume(payload, stream) + + def __len__(self): + return 2 + + +class TestSubmitConsume: + """Tests for the FIFO submit/consume primitive.""" + + def test_opaque_descriptor_roundtrip(self): + """submit/consume works with non-int, string work items.""" + ds = _DescriptorDataset() + try: + handle = ds.submit("beta") + data, metadata = ds.consume(handle) + assert metadata["key"] == "beta" + assert data["x"].item() == 2.0 + finally: + ds.close() + + def test_submit_consume_fifo_independent_of_value(self): + """Multiple in-flight handles consume to their own results.""" + ds = _DescriptorDataset() + try: + handles = [ds.submit(k) for k in ("alpha", "beta", "gamma")] + keys = [ds.consume(h)[1]["key"] for h in handles] + assert keys == ["alpha", "beta", "gamma"] + finally: + ds.close() + + def test_producer_error_reraised_on_consume(self): + """An error raised in the producer surfaces on consume.""" + ds = _DescriptorDataset() + try: + handle = ds.submit("explode") + with pytest.raises(KeyError): + ds.consume(handle) + finally: + ds.close() + + def test_stage_lock_prevents_load_consume_overlap(self): + """Opt-in stage lock keeps worker loads out of active consume.""" + ds = _StageLockedDataset() + try: + first = ds.submit(0) + first.future.result(timeout=5.0) + + consumer = threading.Thread(target=ds.consume, args=(first,)) + consumer.start() + assert ds.consume_started.wait(timeout=5.0) + + second = ds.submit(1) + time.sleep(0.1) + assert ds.load_entries == [0] + + ds.release_consume.set() + consumer.join(timeout=5.0) + second.future.result(timeout=5.0) + assert ds.load_entries == [0, 1] + finally: + ds.release_consume.set() + ds.close() + + +# ============================================================================ +# Stage 1 -- DataLoader laziness over the sampler +# ============================================================================ + + +class _CountingSampler: + """Sequential sampler that records how many indices it has yielded.""" + + def __init__(self, n): + self.n = n + self.consumed = 0 + + def __iter__(self): + self.consumed = 0 + for i in range(self.n): + self.consumed += 1 + yield i + + def __len__(self): + return self.n + + +class TestDataLoaderLazyPreload: + """The preload path must not materialize the whole epoch up front.""" + + def test_sampler_not_fully_drained_on_early_break(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + sampler = _CountingSampler(10) + loader = dp.DataLoader( + dataset, batch_size=2, sampler=sampler, prefetch_factor=1 + ) + + first = next(iter(loader)) + assert first["positions"].shape[0] == 2 + # Only a bounded prefix of the sampler should have been consumed, + # never the full epoch, after pulling a single batch. + assert sampler.consumed < 10 + + def test_preload_matches_sequential_order(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir) + dataset = dp.Dataset(reader) + loader = dp.DataLoader( + dataset, batch_size=3, shuffle=False, collate_metadata=True + ) + indices = [] + for _batch, metadata_list in loader: + indices.extend(m["index"] for m in metadata_list) + assert indices == list(range(10)) + + +# ============================================================================ +# Stage 2 -- iterable datasets +# ============================================================================ + + +class _RangeIterable(IterableDatasetBase): + """Finite per-sample generator yielding (TensorDict, metadata).""" + + def __init__(self, n, dim=4): + self.n = n + self.dim = dim + + def __iter__(self): + for i in range(self.n): + data = TensorDict({"x": torch.full((self.dim,), float(i))}) + yield data, {"index": i} + + +class _BatchIterable(IterableDatasetBase): + """Self-batching generator yielding ready-made batches.""" + + yields_batches = True + + def __init__(self, n_batches, batch=4, dim=3): + self.n_batches = n_batches + self.batch = batch + self.dim = dim + + def __iter__(self): + for b in range(self.n_batches): + yield TensorDict( + {"x": torch.full((self.batch, self.dim), float(b))}, + batch_size=[self.batch], + ) + + +class _SeededIterable(IterableDatasetBase): + """Per-(epoch, position) seeded generator for reproducibility tests.""" + + def __init__(self, n, base_seed=0): + self.n = n + self.base_seed = base_seed + self.epoch = 0 + + def set_epoch(self, epoch): + self.epoch = epoch + + def __iter__(self): + for position in range(self.n): + seed = int( + np.random.SeedSequence( + [self.base_seed, self.epoch, position] + ).generate_state(1)[0] + ) + g = torch.Generator().manual_seed(seed) + yield TensorDict({"x": torch.rand(3, generator=g)}), {"position": position} + + +class _ThreadRecordingIterable(IterableDatasetBase): + """Records which thread the generator runs on.""" + + def __init__(self, n): + self.n = n + self.threads = [] + + def __iter__(self): + for i in range(self.n): + self.threads.append(threading.current_thread()) + yield TensorDict({"x": torch.zeros(2)}), {"index": i} + + +class TestIterableDataLoader: + """Tests for the main-thread-only iterable (generator) path.""" + + def test_per_sample_batching(self): + loader = dp.DataLoader(_RangeIterable(10), batch_size=4) + batches = list(loader) + # 10 samples / 4 -> [4, 4, 2] + assert [b["x"].shape[0] for b in batches] == [4, 4, 2] + + def test_per_sample_drop_last(self): + loader = dp.DataLoader(_RangeIterable(10), batch_size=4, drop_last=True) + batches = list(loader) + # Trailing partial batch dropped -> [4, 4] + assert [b["x"].shape[0] for b in batches] == [4, 4] + + def test_self_batching_passthrough(self): + # The loader batch_size is intentionally different from the generator's + # to prove it is ignored for self-batching datasets. + loader = dp.DataLoader(_BatchIterable(3, batch=5), batch_size=2) + batches = list(loader) + assert len(batches) == 3 + assert all(b["x"].shape[0] == 5 for b in batches) + + def test_len_raises_for_iterable(self): + loader = dp.DataLoader(_RangeIterable(10), batch_size=2) + with pytest.raises(TypeError): + len(loader) + + def test_capped_infinite_consumes_without_length(self): + """A long generator is iterated batch-by-batch; len() is never used.""" + + class _BigIterable(IterableDatasetBase): + def __iter__(self): + i = 0 + while i < 10_000: + yield TensorDict({"x": torch.zeros(2)}), {"index": i} + i += 1 + + loader = dp.DataLoader(_BigIterable(), batch_size=4) + seen = 0 + for _batch in loader: + seen += 1 + if seen == 3: + break + assert seen == 3 + + def test_shuffle_warns_for_iterable(self): + with pytest.warns(UserWarning, match="ignored for iterable"): + dp.DataLoader(_RangeIterable(4), batch_size=2, shuffle=True) + + def test_reproducible_across_runs_distinct_across_epochs(self): + loader = dp.DataLoader(_SeededIterable(6), batch_size=3) + + loader.set_epoch(0) + run_a = torch.cat([b["x"].reshape(-1) for b in loader]) + loader.set_epoch(0) + run_b = torch.cat([b["x"].reshape(-1) for b in loader]) + loader.set_epoch(1) + run_c = torch.cat([b["x"].reshape(-1) for b in loader]) + + assert torch.equal(run_a, run_b) # same epoch -> identical + assert not torch.equal(run_a, run_c) # different epoch -> distinct + + def test_runs_on_main_thread_no_worker_pool(self): + dataset = _ThreadRecordingIterable(4) + + names_before = {t.name for t in threading.enumerate()} + loader = dp.DataLoader(dataset, batch_size=2) + _ = list(loader) + names_after = {t.name for t in threading.enumerate()} + + # Generation happened on the main thread only. + assert dataset.threads, "generator did not run" + assert all(t is threading.main_thread() for t in dataset.threads) + # No prefetch worker pool / pump thread was spawned for this path. + new_threads = names_after - names_before + assert not any( + n.startswith("datapipe_prefetch") or n == "datapipe_pump" + for n in new_threads + ) + + +# ============================================================================ +# CUDA-guarded -- stream-bound consume and Warp-on-non-default-stream +# ============================================================================ + + +class TestStreamBoundConsume: + """Preprocessing on an assigned stream (the default-stream workaround + is gone, so transforms run on the side stream).""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_submit_consume_on_side_stream(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset( + reader, + device="cuda:0", + transforms=dp.SubsamplePoints( + input_keys=["positions", "features"], n_points=50 + ), + ) + try: + stream = torch.cuda.Stream() + handle = dataset.submit(0, stream=stream) + data, _metadata = dataset.consume(handle) + torch.cuda.synchronize() + assert data["positions"].device.type == "cuda" + assert data["positions"].shape[0] == 50 + finally: + dataset.close() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_dataloader_streams_match_synchronous(self, numpy_data_dir): + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + + ref = dp.Dataset(reader, device="cuda:0") + ref_loader = dp.DataLoader(ref, batch_size=2, shuffle=False, prefetch_factor=0) + expected = [b["positions"].sum().item() for b in ref_loader] + + reader2 = dp.NumpyReader(numpy_data_dir, pin_memory=True) + streamed = dp.Dataset(reader2, device="cuda:0") + loader = dp.DataLoader( + streamed, + batch_size=2, + shuffle=False, + prefetch_factor=2, + num_streams=4, + use_streams=True, + ) + got = [b["positions"].sum().item() for b in loader] + torch.cuda.synchronize() + assert got == pytest.approx(expected, rel=1e-5) + + +class TestWarpIterableOnStream: + """Warp launches on a non-default stream from the main thread are safe.""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_darcy_online_simulation_through_iterable_loader(self): + from physicsnemo.datapipes.benchmarks.darcy import Darcy2D + + class _DarcyIterable(IterableDatasetBase): + yields_batches = True + + def __init__(self, num_batches): + self._sim = Darcy2D(resolution=32, batch_size=2, device="cuda") + self._num_batches = num_batches + + def __iter__(self): + sim_iter = iter(self._sim) + for _ in range(self._num_batches): + yield next(sim_iter) + + loader = dp.DataLoader(_DarcyIterable(2), use_streams=True) + batches = list(loader) + torch.cuda.synchronize() # surfaces any illegal-memory-access + assert len(batches) == 2 + for batch in batches: + assert batch["permeability"].device.type == "cuda" + assert batch["darcy"].device.type == "cuda" + + +class TestWarpFunctionalTransformOnStreams: + """A Warp ``FunctionSpec`` transform driven through the multi-stream + preload path. The functional binds the current torch stream as a Warp + stream internally; the loader binds the same stream around the consume. + Both must reuse one cached wrapper -- otherwise the inner wrapper + unregisters the shared stream on teardown and the next launch faults + (illegal memory access).""" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + def test_functional_warp_transform_multi_stream(self, numpy_data_dir): + from physicsnemo.datapipes.transforms.base import Transform + from physicsnemo.nn.functional import signed_distance_field + + class _SDFTransform(Transform): + """Evaluate an SDF (a Warp functional) against the sample points.""" + + def __call__(self, data): + points = data["positions"].reshape(-1, 3).float() + vertices = torch.tensor( + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + device=points.device, + ) + faces = torch.tensor([[0, 1, 2]], device=points.device) + sdf, _ = signed_distance_field(vertices, faces, points) + data["sdf"] = sdf.reshape(-1, 1) + return data + + reader = dp.NumpyReader(numpy_data_dir, pin_memory=True) + dataset = dp.Dataset(reader, device="cuda:0", transforms=_SDFTransform()) + loader = dp.DataLoader( + dataset, + batch_size=1, + shuffle=False, + prefetch_factor=2, + num_streams=4, + use_streams=True, + ) + # Iterate well past num_streams so every stream is reused at least + # once; a churned registration faults on the second pass. + batches = list(loader) + torch.cuda.synchronize() # surfaces any illegal-memory-access + assert len(batches) == 10 + for batch in batches: + assert batch["sdf"].device.type == "cuda" diff --git a/test/datapipes/readers/test_numpy_consolidated.py b/test/datapipes/readers/test_numpy_consolidated.py index c564ef5534..63b4779b0c 100644 --- a/test/datapipes/readers/test_numpy_consolidated.py +++ b/test/datapipes/readers/test_numpy_consolidated.py @@ -239,6 +239,112 @@ def test_supports_coordinated_subsampling(self): assert reader_with_config._coordinated_subsampling_config is not None +class TestNumpyReaderSubsamplingRNG: + """Order- and thread-independent reproducibility of subsampling RNG. + + Reader subsampling derives its generator from ``(base_seed, epoch, + index)``, so a given sample's draw is identical regardless of read + order or worker thread (the threaded producer path), and varies + deterministically per index and per epoch. + """ + + def setup_method(self): + """Set up test fixtures.""" + self.temp_dir = tempfile.mkdtemp() + self.temp_path = Path(self.temp_dir) + + self.n_samples = 6 + self.n_points = 1000 + self.subsample_points = 100 + for i in range(self.n_samples): + coords = np.random.randn(self.n_points, 3).astype(np.float32) + features = np.random.randn(self.n_points, 4).astype(np.float32) + np.savez( + self.temp_path / f"sample_{i:03d}.npz", + coords=coords, + features=features, + ) + + def teardown_method(self): + """Clean up temporary files.""" + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def _make_reader(self) -> NumpyReader: + return NumpyReader( + self.temp_path, + file_pattern="sample_*.npz", + fields=["coords", "features"], + coordinated_subsampling={ + "n_points": self.subsample_points, + "target_keys": ["coords", "features"], + }, + ) + + def test_subsample_order_independent(self): + """Reading an index gives the same result regardless of read order.""" + reader = self._make_reader() + reader.set_generator(torch.Generator().manual_seed(123)) + + forward = {i: reader[i][0]["coords"] for i in range(self.n_samples)} + + reader.set_generator(torch.Generator().manual_seed(123)) + reverse = { + i: reader[i][0]["coords"] for i in reversed(range(self.n_samples)) + } + + for i in range(self.n_samples): + assert torch.equal(forward[i], reverse[i]) + + def test_subsample_thread_independent(self): + """Concurrent reads match single-threaded reads for each index.""" + from concurrent.futures import ThreadPoolExecutor + + reader = self._make_reader() + reader.set_generator(torch.Generator().manual_seed(7)) + + serial = {i: reader[i][0]["coords"] for i in range(self.n_samples)} + + indices = list(range(self.n_samples)) * 3 + with ThreadPoolExecutor(max_workers=4) as pool: + results = list(pool.map(lambda i: (i, reader[i][0]["coords"]), indices)) + + for i, coords in results: + assert torch.equal(coords, serial[i]) + + def test_subsample_distinct_across_indices(self): + """Different indices yield different subsamples under one seed.""" + reader = self._make_reader() + reader.set_generator(torch.Generator().manual_seed(0)) + + a = reader[0][0]["coords"] + b = reader[1][0]["coords"] + assert not torch.equal(a, b) + + def test_subsample_epoch_changes_output(self): + """Epoch is folded into the seed, changing the draw deterministically.""" + reader = self._make_reader() + + reader.set_generator(torch.Generator().manual_seed(0)) + reader.set_epoch(0) + e0 = reader[0][0]["coords"] + + reader.set_generator(torch.Generator().manual_seed(0)) + reader.set_epoch(1) + e1 = reader[0][0]["coords"] + assert not torch.equal(e0, e1) + + # Re-deriving epoch 0 reproduces the original draw. + reader.set_generator(torch.Generator().manual_seed(0)) + reader.set_epoch(0) + assert torch.equal(reader[0][0]["coords"], e0) + + def test_unseeded_does_not_raise(self): + """Without a seed, subsampling falls back to the global RNG.""" + reader = self._make_reader() + data, _ = reader[0] + assert data["coords"].shape == (self.subsample_points, 3) + + class TestNumpyReaderMemoryManagement: """Test memory management and cleanup.""" diff --git a/test/datapipes/transforms/test_mesh_augmentations.py b/test/datapipes/transforms/test_mesh_augmentations.py index 90ddedf89a..b04b55fa3f 100644 --- a/test/datapipes/transforms/test_mesh_augmentations.py +++ b/test/datapipes/transforms/test_mesh_augmentations.py @@ -732,14 +732,15 @@ def test_mesh_dataset_set_generator_distributes(self, tmp_path): master = torch.Generator().manual_seed(42) ds.set_generator(master) - # Reader and both transforms should have received generators - assert reader._subsample_generator is not None + # Reader (base seed) and both transforms should have received seeds + assert reader._seed_base is not None assert transforms[0]._generator is not None assert transforms[1]._generator is not None - # Generators should have different seeds (independent forks) + # Seeds should differ (independent forks): the reader's base seed + # plus each transform's generator seed. seeds = { - reader._subsample_generator.initial_seed(), + reader._seed_base, transforms[0]._generator.initial_seed(), transforms[1]._generator.initial_seed(), } diff --git a/test/nn/functional/geometry/test_sdf.py b/test/nn/functional/geometry/test_sdf.py index a2f3261f56..c7be9c402f 100644 --- a/test/nn/functional/geometry/test_sdf.py +++ b/test/nn/functional/geometry/test_sdf.py @@ -108,6 +108,47 @@ def test_signed_distance_field_index_layout_compatibility(device: str): torch.testing.assert_close(hit_flat, hit_faces) +@requires_module("warp") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_signed_distance_field_repeated_non_default_streams(): + """Repeated Warp SDF launches can retire resources without host sync.""" + device = torch.device("cuda") + mesh_vertices = _tetrahedron_vertices().to(device=device, dtype=torch.float32) + mesh_indices_flat = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + device=device, + dtype=torch.int32, + ) + query_points = torch.tensor( + [[1.0, 1.0, 1.0], [0.05, 0.1, 0.1]], + device=device, + dtype=torch.float32, + ) + streams = [torch.cuda.Stream(device=device) for _ in range(2)] + default_stream = torch.cuda.current_stream(device) + + outputs = [] + for i in range(12): + stream = streams[i % len(streams)] + with torch.cuda.stream(stream): + sdf_out, _hit_points = signed_distance_field( + mesh_vertices, + mesh_indices_flat, + query_points, + use_sign_winding_number=False, + ) + outputs.append((stream, sdf_out)) + + for stream, sdf_out in outputs: + default_stream.wait_stream(stream) + torch.testing.assert_close( + sdf_out, + torch.tensor([1.1547, -0.05], device=device), + atol=1e-6, + rtol=1e-6, + ) + + # Validate benchmark input generation contract for SDF. @requires_module("warp") def test_signed_distance_field_make_inputs_forward(device: str):