Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4f5a707
deduplication
peterdsharpe Jun 25, 2026
66b01b1
cuts the warning
peterdsharpe Jun 25, 2026
ad3d58b
Metrics: removes total_channels and field_dim
peterdsharpe Jun 25, 2026
3291a11
train.py - omegaconf tidying
peterdsharpe Jun 25, 2026
06cb44c
In nondim, removes `inverse_tensor`, which is already superseded by e…
peterdsharpe Jun 25, 2026
193a485
grammar
peterdsharpe Jun 25, 2026
c103fd8
Merge branch 'main' into psharpe/unified-recipe-minor-cleanups
peterdsharpe Jun 25, 2026
c9a1942
Adds caveat on DoMINO
peterdsharpe Jun 25, 2026
a4e0d46
Merge branch 'psharpe/unified-recipe-minor-cleanups' of https://githu…
peterdsharpe Jun 25, 2026
ee24138
Refactor metrics handling in training process
peterdsharpe Jun 25, 2026
6bb1c49
Implement dedicated un-augmented validation dataset for manifest mode
peterdsharpe Jun 25, 2026
7d2f433
Merge branch 'main' into psharpe/unified-recipe-reduce-fix
peterdsharpe Jun 25, 2026
14b8406
Merge branch 'main' into psharpe/unified-recipe-manifest-validation-fix
peterdsharpe Jun 25, 2026
b34736b
better docstring
peterdsharpe Jun 25, 2026
9224727
Synchronize nomenclature between train.py and infer.py
peterdsharpe Jun 25, 2026
4a213b0
Standardizes nomenclature for JSONL logging around clean Phase defini…
peterdsharpe Jun 25, 2026
d09b0f4
Merge branch 'main' into psharpe/unified-recipe-reduce-fix
peterdsharpe Jun 25, 2026
d028742
last name synchronizations between infer and train
peterdsharpe Jun 25, 2026
0e1d8d2
Merge branch 'psharpe/unified-recipe-reduce-fix' into psharpe/unified…
peterdsharpe Jun 25, 2026
7510f88
Merge branch 'main' into psharpe/unified-recipe-reduce-fix
peterdsharpe Jun 26, 2026
3133581
Merge branch 'main' into psharpe/unified-recipe-reduce-fix
peterdsharpe Jun 26, 2026
4fe1bbe
Refactors reduce_and_average_epoch into reduce_and_average, so that a…
peterdsharpe Jun 26, 2026
1c1e864
trim comment a hair
peterdsharpe Jun 26, 2026
bf71c2f
Merge branch 'psharpe/unified-recipe-reduce-fix' into psharpe/unified…
peterdsharpe Jun 26, 2026
3a5c57c
Merge branch 'main' into psharpe/unified-recipe-manifest-validation-fix
peterdsharpe Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,41 @@ def _resolve_manifest_indices_from_spec(
return train_indices, val_indices


def _build_manifest_val_dataset(
    ds_yaml: DictConfig,
    *,
    augment: bool,
    device: str | torch.device | None,
    num_workers: int,
    pin_memory: bool,
) -> MeshDataset | None:
    """Return an un-augmented validation dataset for manifest mode, if needed.

    In manifest mode the train and val splits are served by a single
    reader, with a :class:`ManifestSampler` pair selecting the indices of
    each split. Left alone, validation would therefore pass through the
    *augmented* transform chain whenever ``augment`` is on, whereas
    directory mode always builds its val dataset with ``augment=False``.

    Args:
        ds_yaml: Dataset YAML config, forwarded unchanged to
            :func:`build_dataset`.
        augment: Whether the *training* dataset is built with
            augmentations enabled.
        device: Forwarded to :func:`build_dataset`.
        num_workers: Forwarded to :func:`build_dataset`.
        pin_memory: Forwarded to :func:`build_dataset`.

    Returns:
        ``None`` when *augment* is ``False``: the train and val transform
        chains are then identical, so the caller lets validation share the
        train dataset and no redundant second reader is created.
        Otherwise a second dataset built from the same *ds_yaml* with
        ``augment=False``. Its reader globs the same sorted paths, so
        manifest indices resolved against the train reader address the
        same samples in the returned dataset.
    """
    if augment:
        ### Same config, same loader settings -- only the augmentations
        ### are switched off for evaluation.
        loader_kwargs = {
            "device": device,
            "num_workers": num_workers,
            "pin_memory": pin_memory,
        }
        return build_dataset(ds_yaml, augment=False, **loader_kwargs)
    return None


def _build_collate(
cfg: DictConfig, target_config: dict[str, FieldType]
) -> Callable[[list[tuple[Any, Any]]], dict[str, Any]]:
Expand Down Expand Up @@ -710,7 +745,10 @@ def build_dataloaders(
``cfg.train_split`` / ``cfg.val_split`` keys select which subsets to
use; one reader covers the full directory and
:class:`ManifestSampler` restricts each loader to the matching
indices.
indices. Augmentations are training-only: when ``cfg.augment`` is set,
validation uses a separate un-augmented dataset over the same
directory (mirroring directory mode); otherwise it shares the train
dataset.

NOTE (limitation): only ONE chosen dataset may carry a manifest
today. If both ``cfg.dataset`` and an entry in ``cfg.extra_datasets``
Expand Down Expand Up @@ -764,6 +802,7 @@ def build_dataloaders(
val_datasets: list = []
manifest_train_indices: list[int] | None = None
manifest_val_indices: list[int] | None = None
manifest_val_dataset: MeshDataset | None = None
using_manifests = False
first_targets: dict[str, str] | None = None
first_metrics: list[str] | None = None
Expand Down Expand Up @@ -848,11 +887,24 @@ def build_dataloaders(
pin_memory=pin_memory,
)
train_datasets.append(dataset)
### NOTE: this overwrites any prior dataset's indices; see the
### docstring's multi-dataset limitation note.
### NOTE: this overwrites any prior manifest dataset's indices
### (and the val dataset below); see the docstring's
### multi-dataset limitation note.
manifest_train_indices, manifest_val_indices = (
_resolve_manifest_indices_from_spec(dataset.reader, manifest_spec)
)
### Augmentations are training-only: when enabled, give
### validation its own un-augmented dataset over the same
### directory so eval is never augmented (matching directory
### mode). Stays None when augment is off, so val shares the
### train dataset.
manifest_val_dataset = _build_manifest_val_dataset(
ds_yaml,
augment=augment,
device=device,
num_workers=num_workers,
pin_memory=pin_memory,
)
continue

### Directory mode: separate readers / datasets per split.
Expand Down Expand Up @@ -900,9 +952,15 @@ def build_dataloaders(
train_dataset = _combine_datasets(train_datasets)

if using_manifests:
### Manifest mode: train and val share one underlying dataset;
### the samplers carve out the per-split index sets.
val_dataset = train_dataset
### Manifest mode: train and val share one underlying reader; the
### samplers carve out the per-split index sets. When augmentations
### are enabled, validation uses a dedicated un-augmented dataset
### (built in the loop above) so eval is never augmented -- matching
### directory mode; otherwise the chains are identical and val
### shares the train dataset.
val_dataset = (
manifest_val_dataset if manifest_val_dataset is not None else train_dataset
)
train_sampler, val_sampler = _build_manifest_samplers(
manifest_train_indices,
manifest_val_indices,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from typing import Literal, TypeAlias, cast

import torch
import torch.distributed as dist
from jaxtyping import Float
from omegaconf import DictConfig, OmegaConf
from tensordict import TensorDict
Expand Down Expand Up @@ -218,22 +217,6 @@ def _metrics_for_tensor(
for m in self.metric_names
}

def _all_reduce(self, metrics: TensorDict) -> TensorDict:
    """Average every entry of *metrics* across ``self.process_group``.

    Args:
        metrics: Per-rank metric values keyed by name. The entries are
            stacked into one tensor, so they must share a shape, and every
            rank must hold the same keys in the same order.

    Returns:
        *metrics* itself (same object, untouched) when there is nothing
        to communicate: no process group is configured, *metrics* is
        empty, or the group has a single rank. Otherwise a new
        ``TensorDict`` holding, for each key, the sum over ranks divided
        by the world size.
    """
    keys = list(metrics.keys())
    ### Nothing to reduce. The empty case must be filtered out here:
    ### ``torch.stack`` raises on an empty list. Skipping the collective
    ### is rank-consistent as long as all ranks share one key set, which
    ### the stacked all_reduce below already requires.
    if self.process_group is None or not keys:
        return metrics
    world_size = dist.get_world_size(self.process_group)
    if world_size == 1:
        return metrics
    ### Single all_reduce over a stacked tensor (vs. one comm per leaf)
    ### -- one collective beats N regardless of the container type.
    ### Rebuild a TensorDict from the reduced stack so callers see the
    ### same per-key access pattern.
    stacked = torch.stack([metrics[k] for k in keys])
    dist.all_reduce(stacked, group=self.process_group)
    stacked = stacked / world_size
    return TensorDict({k: stacked[i] for i, k in enumerate(keys)}, batch_size=[])

def __call__(
self,
pred: TensorDict,
Expand Down Expand Up @@ -284,7 +267,7 @@ def __call__(
t_mag = torch.linalg.vector_norm(t, dim=-1)
out.update(self._metrics_for_tensor(p_mag, t_mag, (name,)))

return self._all_reduce(TensorDict(out, batch_size=[]))
return TensorDict(out, batch_size=[])

def __repr__(self) -> str:
fields_str = ", ".join(f"{n}:{t}" for n, t in self.target_config.items())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

import hydra
import torch
import torch.distributed as dist
from datasets import build_dataloaders
from loss import LossCalculator
from metrics import MetricCalculator, resolve_metrics
Expand Down Expand Up @@ -120,6 +121,53 @@ def _to_float_dicts(
)


def _reduce_and_average_epoch(
    total_loss: float,
    losses_td: TensorDict | None,
    metrics_td: TensorDict | None,
    n_local: int,
    *,
    device: torch.device | str,
) -> tuple[float, dict[str, float], dict[str, float]]:
    """Average epoch loss/metric *sums* over the GLOBAL sample count.

    The per-rank loop accumulates *sums* of per-step losses and metrics.
    This packs ``total_loss``, the local sample count, and every
    loss/metric leaf into one float32 tensor, all-reduces it (SUM) when
    running distributed, then divides by the reduced count
    (``global_sum / global_count``), which stays correct for uneven
    per-rank shards.

    A rank whose ``losses_td`` / ``metrics_td`` is ``None`` (e.g. it ran
    zero batches) still takes part in every collective and contributes
    zeros. Returning early on such a rank would leave the other ranks
    blocked inside ``all_reduce``. To make that possible the ranks first
    agree on the key layout with one ``all_gather_object``, so the packed
    tensor has the same length and ordering everywhere.

    Args:
        total_loss: Local sum of the per-step total loss.
        losses_td: Local per-key loss sums (scalar leaves), or ``None``.
        metrics_td: Local per-key metric sums (scalar leaves), or ``None``.
        n_local: Number of local steps that were accumulated.
        device: Device on which the packed tensor is built and reduced.

    Returns:
        ``(avg_loss, avg_losses, avg_metrics)`` as plain Python floats.
        The divisor is clamped to at least 1, so an empty epoch yields
        zeros rather than a division error.
    """
    ### A missing TensorDict is an empty contribution, not a reason to
    ### drop the other one or to skip the reduction.
    loss_sums = cast(
        dict[str, torch.Tensor],
        {} if losses_td is None else dict(losses_td.items()),
    )
    metric_sums = cast(
        dict[str, torch.Tensor],
        {} if metrics_td is None else dict(metrics_td.items()),
    )
    loss_keys = list(loss_sums)
    metric_keys = list(metric_sums)

    is_distributed = (
        dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1
    )
    if is_distributed:
        ### Agree on one key layout: ordered union over ranks. Every rank
        ### sees the same gathered list, so every rank derives the same
        ### layout. NOTE(review): with the NCCL backend
        ### ``all_gather_object`` relies on the current CUDA device
        ### being set per rank -- confirm the launcher does so.
        layouts = cast(
            list[tuple[list[str], list[str]]], [None] * dist.get_world_size()
        )
        dist.all_gather_object(layouts, (loss_keys, metric_keys))
        loss_keys = list(dict.fromkeys(k for rank_loss, _ in layouts for k in rank_loss))
        metric_keys = list(
            dict.fromkeys(k for _, rank_metric in layouts for k in rank_metric)
        )
    elif not (loss_keys or metric_keys):
        ### Single process, nothing accumulated: plain local average.
        return total_loss / max(n_local, 1), {}, {}

    ### Leaves in the agreed order; keys this rank never saw count as 0.
    zero = torch.zeros((), dtype=torch.float32, device=device)
    leaves = [
        sums[key].to(device=device, dtype=torch.float32) if key in sums else zero
        for sums, keys in ((loss_sums, loss_keys), (metric_sums, metric_keys))
        for key in keys
    ]

    ### [total_loss, n_local, *loss_sums, *metric_sums] -> one collective.
    parts = [
        torch.tensor([total_loss, float(n_local)], dtype=torch.float32, device=device)
    ]
    if leaves:  # torch.stack raises on an empty list
        parts.append(torch.stack(leaves))
    packed = torch.cat(parts)
    if is_distributed:
        dist.all_reduce(packed)

    reduced_loss, reduced_n, *leaf_sums = packed.tolist()
    n = max(reduced_n, 1.0)
    n_loss = len(loss_keys)
    averaged = [v / n for v in leaf_sums]
    return (
        reduced_loss / n,
        dict(zip(loss_keys, averaged[:n_loss])),
        dict(zip(metric_keys, averaged[n_loss:])),
    )


def _log_to_tensorboard(
writer: SummaryWriter | None,
values: Mapping[str, float | torch.Tensor],
Expand Down Expand Up @@ -401,8 +449,16 @@ def _run_epoch(

epoch_dt = time.perf_counter() - epoch_t0
n = max(n_batches, 1)
avg_loss = total_loss / n
avg_losses, avg_metrics = _to_float_dicts(total_losses_td, total_metrics_td, n=n)
### Reduce the epoch sums + sample count across ranks once, so logged
### loss/metrics are the GLOBAL averages (not rank-0's shard) under
### DDP. `n` above is kept local for the per-rank step-rate line below.
avg_loss, avg_losses, avg_metrics = _reduce_and_average_epoch(
total_loss,
total_losses_td,
total_metrics_td,
n_batches,
device=dist_manager.device,
)

logger.info(
f"Epoch {epoch} {mode} done in {epoch_dt:.1f}s "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
from types import SimpleNamespace

import pytest
from omegaconf import OmegaConf
from omegaconf import DictConfig, OmegaConf

from datasets import (
ManifestSampler,
_build_manifest_val_dataset,
build_dataset,
load_manifest,
resolve_manifest_indices,
resolve_manifest_spec,
Expand Down Expand Up @@ -388,3 +390,94 @@ def test_directory_mode_unaffected_by_loud_failure(self, tmp_path: Path):
ds_yaml = OmegaConf.create({"train_datadir": str(tmp_path)})
ds_block = OmegaConf.create({})
assert resolve_manifest_spec(ds_yaml, ds_block) is None


### ---------------------------------------------------------------------------
### _build_manifest_val_dataset
### ---------------------------------------------------------------------------


class TestManifestValDataset:
    """Tests for :func:`datasets._build_manifest_val_dataset`.

    In manifest mode a single reader serves both the train and val
    splits, so validation must not pick up the training augmentations.
    Directory mode already guarantees this by always building its val
    dataset with ``augment=False``; these tests pin the same behavior for
    manifest mode.
    """

    ### Loader settings shared by every dataset built in these tests.
    _LOADER_KWARGS = {"device": None, "num_workers": 1, "pin_memory": False}

    @staticmethod
    def _augmented_ds_yaml(datadir: Path) -> DictConfig:
        """Build a minimal volume-dataset YAML that carries augmentations.

        Only what the dataset builder looks at is included: the reader
        globs paths lazily (no file is opened at construction), so
        placeholder files suffice, and the transform chain needs just a
        ``CenterMesh`` anchor plus the augmentations inserted after it.
        """
        reader = {
            "_target_": "${dp:DomainMeshReader}",
            "path": str(datadir),
            "pattern": "run_*/domain_*.pdmsh",
        }
        augmentations = [
            {"_target_": "${dp:RandomRotateMesh}", "axes": ["z"]},
            {"_target_": "${dp:RandomTranslateMesh}"},
        ]
        transforms = [{"_target_": "${dp:CenterMesh}"}]
        pipeline = {
            "reader": reader,
            "augmentations": augmentations,
            "transforms": transforms,
        }
        return OmegaConf.create(
            {"pipeline": pipeline, "targets": {"pressure": "scalar"}}
        )

    @staticmethod
    def _make_datadir(tmp_path: Path) -> Path:
        """Populate *tmp_path* with two empty placeholder runs and return it.

        The reader only globs these files; it never opens them.
        """
        for idx in (0, 1):
            run_dir = tmp_path / f"run_{idx}"
            run_dir.mkdir()
            placeholder = run_dir / f"domain_{idx}.pdmsh"
            placeholder.write_bytes(b"")
        return tmp_path

    @staticmethod
    def _has_stochastic_transform(dataset) -> bool:
        """Return True if any transform on *dataset* is flagged ``stochastic``."""
        flags = [getattr(t, "stochastic", False) for t in dataset.transforms]
        return any(flags)

    def test_augment_off_returns_none(self, tmp_path: Path):
        """``augment=False`` -> val shares the train dataset (None sentinel)."""
        datadir = self._make_datadir(tmp_path)
        ds_yaml = self._augmented_ds_yaml(datadir)

        val_ds = _build_manifest_val_dataset(
            ds_yaml, augment=False, **self._LOADER_KWARGS
        )

        assert val_ds is None

    def test_augment_on_returns_unaugmented_dataset(self, tmp_path: Path):
        """``augment=True`` -> a separate dataset whose chain has no augmentations."""
        datadir = self._make_datadir(tmp_path)
        ds_yaml = self._augmented_ds_yaml(datadir)

        ### Keep the val assertion from being vacuous: the train dataset
        ### has to carry at least one stochastic augmentation.
        train_ds = build_dataset(ds_yaml, augment=True, **self._LOADER_KWARGS)
        assert self._has_stochastic_transform(train_ds)

        val_ds = _build_manifest_val_dataset(
            ds_yaml, augment=True, **self._LOADER_KWARGS
        )
        assert val_ds is not None
        ### A distinct object (own reader), not the train dataset.
        assert val_ds is not train_ds
        ### No stochastic (augmentation) transforms survive on the val chain.
        assert not self._has_stochastic_transform(val_ds)
        ### ...but the deterministic CenterMesh transform is still present.
        transform_names = [type(t).__name__ for t in val_ds.transforms]
        assert any(name == "CenterMesh" for name in transform_names)
Loading
Loading