def _build_manifest_val_dataset(
    ds_yaml: DictConfig,
    *,
    augment: bool,
    device: str | torch.device | None,
    num_workers: int,
    pin_memory: bool,
) -> MeshDataset | None:
    """Return the un-augmented validation dataset for manifest mode, if needed.

    In manifest mode the train and val splits are carved out of one shared
    reader by a pair of :class:`ManifestSampler` objects. Left alone,
    validation would therefore run through the same transform chain as
    training -- including the augmentations whenever ``augment`` is on.
    Directory mode does not have this problem because it always builds its
    val dataset with ``augment=False``; this helper brings manifest mode in
    line with that.

    Args:
        ds_yaml: Dataset YAML config, forwarded unchanged to
            :func:`build_dataset`.
        augment: Whether the *training* dataset is built with augmentations.
        device: Forwarded to :func:`build_dataset`.
        num_workers: Forwarded to :func:`build_dataset`.
        pin_memory: Forwarded to :func:`build_dataset`.

    Returns:
        A second dataset built from the same ``ds_yaml`` with
        ``augment=False`` when *augment* is ``True``. Its reader is expected
        to glob the same sorted paths as the train reader, so manifest
        indices resolved against the train reader address the same samples
        here. ``None`` when *augment* is ``False``: the train and val chains
        are then identical, and the caller lets validation reuse the train
        dataset rather than paying for a redundant second reader.
    """
    if augment:
        ### Training is augmented, so validation needs its own chain built
        ### from the same config with the augmentations switched off.
        unaugmented_val = build_dataset(
            ds_yaml,
            augment=False,
            device=device,
            num_workers=num_workers,
            pin_memory=pin_memory,
        )
        return unaugmented_val

    ### No augmentations anywhere: signal "share the train dataset".
    return None
+ manifest_val_dataset = _build_manifest_val_dataset( + ds_yaml, + augment=augment, + device=device, + num_workers=num_workers, + pin_memory=pin_memory, + ) continue ### Directory mode: separate readers / datasets per split. @@ -900,9 +952,15 @@ def build_dataloaders( train_dataset = _combine_datasets(train_datasets) if using_manifests: - ### Manifest mode: train and val share one underlying dataset; - ### the samplers carve out the per-split index sets. - val_dataset = train_dataset + ### Manifest mode: train and val share one underlying reader; the + ### samplers carve out the per-split index sets. When augmentations + ### are enabled, validation uses a dedicated un-augmented dataset + ### (built in the loop above) so eval is never augmented -- matching + ### directory mode; otherwise the chains are identical and val + ### shares the train dataset. + val_dataset = ( + manifest_val_dataset if manifest_val_dataset is not None else train_dataset + ) train_sampler, val_sampler = _build_manifest_samplers( manifest_train_indices, manifest_val_indices, diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/tests/test_manifest.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/tests/test_manifest.py index b2b0b5280d..b4ee4b2001 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/tests/test_manifest.py +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/tests/test_manifest.py @@ -31,10 +31,12 @@ from types import SimpleNamespace import pytest -from omegaconf import OmegaConf +from omegaconf import DictConfig, OmegaConf from datasets import ( ManifestSampler, + _build_manifest_val_dataset, + build_dataset, load_manifest, resolve_manifest_indices, resolve_manifest_spec, @@ -388,3 +390,94 @@ def test_directory_mode_unaffected_by_loud_failure(self, tmp_path: Path): ds_yaml = OmegaConf.create({"train_datadir": str(tmp_path)}) ds_block = OmegaConf.create({}) assert resolve_manifest_spec(ds_yaml, 
class TestManifestValDataset:
    """Tests for :func:`datasets._build_manifest_val_dataset`.

    In manifest mode a single reader backs both the train and val splits,
    so validation must be kept from inheriting the train augmentations.
    Directory mode already guarantees this by always building its val
    dataset with ``augment=False``; these tests pin the same guarantee for
    manifest mode.
    """

    @staticmethod
    def _augmented_ds_yaml(datadir: Path) -> DictConfig:
        """Build a minimal manifest-style dataset YAML that carries augmentations.

        Only what the dataset builder inspects is included. The reader globs
        paths lazily (no file is opened at construction), so *datadir* only
        needs placeholder files. The transform chain needs just a
        ``CenterMesh`` anchor, after which the augmentations get inserted.
        """
        reader_cfg = {
            "_target_": "${dp:DomainMeshReader}",
            "path": str(datadir),
            "pattern": "run_*/domain_*.pdmsh",
        }
        augmentation_cfgs = [
            {"_target_": "${dp:RandomRotateMesh}", "axes": ["z"]},
            {"_target_": "${dp:RandomTranslateMesh}"},
        ]
        transform_cfgs = [
            {"_target_": "${dp:CenterMesh}"},
        ]
        pipeline_cfg = {
            "reader": reader_cfg,
            "augmentations": augmentation_cfgs,
            "transforms": transform_cfgs,
        }
        return OmegaConf.create(
            {
                "pipeline": pipeline_cfg,
                "targets": {"pressure": "scalar"},
            }
        )

    @staticmethod
    def _make_datadir(tmp_path: Path) -> Path:
        """Populate *tmp_path* with two empty placeholder runs and return it.

        The files only need to exist so the reader's glob matches them; they
        are never opened.
        """
        for idx in (0, 1):
            run_dir = tmp_path / f"run_{idx}"
            run_dir.mkdir()
            placeholder = run_dir / f"domain_{idx}.pdmsh"
            placeholder.write_bytes(b"")
        return tmp_path

    @staticmethod
    def _has_stochastic_transform(dataset) -> bool:
        """Report whether any transform on *dataset* is flagged ``stochastic``."""
        return any(getattr(step, "stochastic", False) for step in dataset.transforms)

    def test_augment_off_returns_none(self, tmp_path: Path):
        """``augment=False`` -> val shares the train dataset (None sentinel)."""
        datadir = self._make_datadir(tmp_path)
        val_ds = _build_manifest_val_dataset(
            self._augmented_ds_yaml(datadir),
            augment=False,
            device=None,
            num_workers=1,
            pin_memory=False,
        )
        assert val_ds is None

    def test_augment_on_returns_unaugmented_dataset(self, tmp_path: Path):
        """``augment=True`` -> a separate dataset whose chain has no augmentations."""
        datadir = self._make_datadir(tmp_path)
        ds_yaml = self._augmented_ds_yaml(datadir)

        ### Precondition: the train chain really does contain a stochastic
        ### augmentation, otherwise the val-side check below proves nothing.
        train_ds = build_dataset(
            ds_yaml,
            augment=True,
            device=None,
            num_workers=1,
            pin_memory=False,
        )
        assert self._has_stochastic_transform(train_ds)

        val_ds = _build_manifest_val_dataset(
            ds_yaml,
            augment=True,
            device=None,
            num_workers=1,
            pin_memory=False,
        )
        assert val_ds is not None
        ### The val dataset is its own object (own reader), not the train one.
        assert val_ds is not train_ds
        ### Nothing stochastic (i.e. no augmentation) is left on the val chain.
        assert not self._has_stochastic_transform(val_ds)
        ### The deterministic CenterMesh transform must survive.
        val_transform_names = {type(step).__name__ for step in val_ds.transforms}
        assert "CenterMesh" in val_transform_names