feat: add CosmoDownscaling diagnostic model (ERA5 -> COSMO-REA6/REA2)#939
feat: add CosmoDownscaling diagnostic model (ERA5 -> COSMO-REA6/REA2)#939gertln wants to merge 3 commits into
Conversation
Add CosmoDownscaling, a diagnostic model that downscales ERA5 to high-resolution COSMO-REA regional reanalysis -- COSMO-REA6 (~6 km) and COSMO-REA2 (~2.2 km) -- each with a deterministic regression (mean) and a generative EDM/Karras diffusion mode, selected via `mode` on `load_model`. Supports movable sub-domains via `set_domain` and optional hub-height wind components for wind-energy use. - earth2studio/models/dx/cosmo_downscaling.py: the wrapper (physicsnemo DiT with axial RoPE + NATTEN attention, EDM diffusion sampler, channel transforms + physical-space constraints incl. a solar-zenith day/night gate, sub-domain cropping, derived hub-height wind components) - test/models/dx/test_cosmo_downscaling.py: construction + behavior tests with mocked nets, plus a GPU `--package` test against a real package - examples/03_downscaling/04_cosmo_rea_downscaling.py: end-to-end example (SFNO -> downscale full domain, sub-domain, rollout, diffusion ensemble vs regression mean, hub-height wind, COSMO-REA2) - register the model; add the `cosmo` extra; install/CHANGELOG/API-doc entries Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: gertl <gertl@nvidia.com>
Greptile SummaryAdds
|
| Filename | Overview |
|---|---|
| earth2studio/models/dx/cosmo_downscaling.py | New 1995-line CosmoDownscaling diagnostic model with regression and diffusion modes, rotated-pole output grid, set_domain sub-region slicing, hub-height wind derivation, and full package-loading logic; well-engineered but shares mutable DiT state across sub-domain instances and hardcodes 0.25° ERA5 spacing in set_domain |
| test/models/dx/test_cosmo_downscaling.py | Comprehensive 1324-line CPU test suite covering coordinate contracts, constraint parsing, domain slicing, halo/patch-snap edge cases, hub-wind derivation, diffusion schedule, and a constructor-signature guard that forces new args to be threaded through set_domain |
| examples/03_downscaling/04_cosmo_rea_downscaling.py | End-to-end example covering SFNO to CosmoDownscaling rollout, sub-domain, diffusion ensemble vs regression mean, hub-height wind, and COSMO-REA2; correctly listed in expected_failing_examples since weights are not yet hosted |
| pyproject.toml | Adds cosmo extra (einops, natten, nvidia-physicsnemo>=2.0, nvtx) and includes it in the all extra; all permissive licenses |
| earth2studio/models/dx/init.py | Adds CosmoDownscaling import and all entry |
| docs/conf.py | Adds the COSMO-REA example to expected_failing_examples with a clear comment; appropriate short-term measure until weights are hosted |
| test/conftest.py | Adds the new test file to _TEST_DEPENDENCIES under the cosmo extra, consistent with other model test registrations |
Reviews (1): Last reviewed commit: "feat: add CosmoDownscaling diagnostic mo..." | Re-trigger Greptile
| sig_n = t_next.to(torch.float32).repeat(x.shape[0]) | ||
| d_prime = ( | ||
| x_next - net(x_next.float(), sig_n, condition=cond).double() | ||
| ) / t_next | ||
| x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_prime) | ||
| x = x_next | ||
| return x.float() | ||
|
|
||
| @torch.inference_mode() | ||
| def _forward( | ||
| self, | ||
| era5: torch.Tensor, | ||
| valid_time: datetime, | ||
| lat2d: torch.Tensor, | ||
| lon2d: torch.Tensor, | ||
| ) -> torch.Tensor: | ||
| """Forward for one ERA5 frame -> [number_of_samples, C_out, H, W] physical | ||
| (C_out = the trained outputs, plus the derived hub-height wind components | ||
| when ``hub_heights`` is set). |
There was a problem hiding this comment.
Shared mutable state mutated per-forward — unsafe for concurrent use
_rebind_latent writes directly into the attn_kwargs_forward dict and the detokenizer's h_patches/w_patches on the live DiT object. Because set_domain propagates the same network reference (cr.regression_model is dx.regression_model), two instances sharing one network will overwrite each other's latent-grid bindings if called concurrently (e.g., two threads, two asyncio tasks, or a DataLoader with num_workers > 0).
Sequential single-threaded calls are safe (the rebind happens before each forward), but nothing in the API surface or docstrings warns users about this constraint. Adding a note to set_domain and load_model — or copying the latent-state on sub-domain creation — would prevent silent wrong-grid outputs under concurrent use.
There was a problem hiding this comment.
The network is shared across set_domain sub-domains and each forward sets its grid on it in place. So sub-domains are fine to run sequentially but two running at the same time would lead to conflicts. I documented this in set_domain and _rebind_latent with the advice to run concurrent domains in separate processes. (d498ccf)
| "qk_norm": rp.get("qk_norm", False), | ||
| } | ||
| backend = rp.get("natten_backend", "cutlass-fna") | ||
| if backend is not None: | ||
| attn["na2d_kwargs"] = {"backend": backend} | ||
| conv_layers = rp.get("detok_conv_layers", 0) | ||
| return DiT( | ||
| input_size=list(rp.get("img_resolution", [256, 256])), | ||
| in_channels=cond_channels, | ||
| out_channels=out_channels, |
There was a problem hiding this comment.
Hardcoded 0.25° ERA5 resolution in
set_domain
_reg always constructs the sub-domain's ERA5 input grid at 0.25° steps regardless of the parent model's actual input resolution. If the parent package ships a coarser or finer ERA5 grid (e.g., 0.5°), the sub-domain's lat_input_grid / lon_input_grid will not match the user's data, and output_coords will reject the input via _is_native_input with the relatively opaque message "requires the native input grid". The step size should be derived from the parent's lat_input_numpy spacing rather than hardcoded.
There was a problem hiding this comment.
Fixed in d498ccf. set_domain now reads the input spacing from the package's own grid instead of assuming 0.25.
| attn["na2d_kwargs"] = {"backend": backend} | ||
| dit = DiT( | ||
| input_size=list(dp.get("img_resolution", [256, 256])), | ||
| in_channels=out_channels + cond_channels, | ||
| out_channels=out_channels, | ||
| patch_size=dp["patch_size"], |
There was a problem hiding this comment.
Extended-grid state not propagated to sub-domains — chained
set_domain silently loses OOD-margin access
After set_domain returns a sub-domain instance, its _ext_lat_numpy, _ext_lon_numpy, and _ext_static_numpy remain None (the constructor default). If a user calls set_domain again on the returned sub-domain, the call is silently restricted to the sub-domain's lat_output_numpy only — the extended margin invariants are no longer reachable. There is no warning or documentation note about this; users who want two differently-shaped sub-regions that both access the OOD margin must call set_domain from the original loaded model, not chain it.
There was a problem hiding this comment.
By design, now documented (d498ccf). A sub-domain only keeps its own cropped grid, so calling set_domain again on it can't go back to the full footprint.
| regression_model: torch.nn.Module | None, | ||
| diffusion_model: torch.nn.Module | None, | ||
| resolution: str, | ||
| mode: str, |
There was a problem hiding this comment.
interp module shadowed by a local variable inside interp_levels_to_height
The module-level from earth2studio.utils import ... interp is shadowed by the local interp tensor assignment inside the loop. This is not a bug today (the module is not called inside this function), but any future edit adding a call to interp.latlon_interpolation_regular within the loop would silently call the tensor instead of the module. Renaming the local to val avoids the ambiguity.
| regression_model: torch.nn.Module | None, | |
| diffusion_model: torch.nn.Module | None, | |
| resolution: str, | |
| mode: str, | |
| val = ( | |
| values[..., j, :, :] + (values[..., j + 1, :, :] - values[..., j, :, :]) * w | |
| ) | |
| out = torch.where((t >= lo) & (t < hi), val, out) |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
There was a problem hiding this comment.
renamed the local to interp_val so it no longer shadows the interp import (d498ccf)
…currency - set_domain: derive the ERA5 input-axis spacing from the package's own validated grid instead of a hardcoded 0.25 deg (behavior-preserving for the shipped 0.25 deg package; supports a non-0.25 deg package). - interp_levels_to_height: rename local `interp` -> `interp_val` so it no longer shadows the module-level `interp` import. - set_domain / _rebind_latent: document that the network is shared by reference and rebinds its latent-grid state in place, so sub-domains are safe sequentially but not concurrently; and that a second set_domain on a returned sub-domain cannot reach back to the full/extended footprint. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
| "log_level": {"backreference_missing": "warning", "gallery_examples": "debug"}, | ||
| # COSMO-REA weights are not yet publicly hosted, so this example raises without | ||
| # a local $COSMO_REA_PACKAGE. Whitelist it so the full gallery build does not | ||
| # fail; remove this entry once the package is hosted (load_default_package). |
There was a problem hiding this comment.
Add a TODO so we don't forget this
| CorrDiffTaiwan, | ||
| ) | ||
| from earth2studio.models.dx.corrdiff_cmip6 import CorrDiffCMIP6 | ||
| from earth2studio.models.dx.cosmo_downscaling import CosmoDownscaling # noqa |
There was a problem hiding this comment.
Remove it - not needed.
| cosmo = [ | ||
| "einops>=0.8.1", | ||
| "natten ; python_version < '3.14'", | ||
| "nvidia-physicsnemo>=2.0", |
There was a problem hiding this comment.
I believe we need a stricter constraint on this if we need RoPE stuff from what I recently merged
There was a problem hiding this comment.
Pinned nvidia-physicsnemo>=2.2.0. natten2d_rope is only on main (2.2.0a0) and the latest release (2.1.1) doesn't have it. So it is intentionally uninstallable until 2.2.0 ships. (Also removed cosmo from the all extra because >=2.2.0 conflicts with da-healda <=2.1.1 )
| modes are DiT-RoPE (a diffusion transformer with rotary position embedding) and | ||
| resolution-agnostic (single forward at any grid size). | ||
|
|
||
| Domain handling (see the integration handoff + design memory): |
There was a problem hiding this comment.
Don't reference implementation notes or agent memory context
| native trained footprint, or reach into the extended invariant margin -- the | ||
| latter proceeds with a one-time out-of-distribution warning; beyond the extended | ||
| extent raises. Both the mean and diffusion models are DiT-RoPE | ||
| (resolution-agnostic), so a sub-domain runs at any size in a single forward. |
There was a problem hiding this comment.
It's actually fixed resolution, and more domain/crop size agnostic, right?
There was a problem hiding this comment.
yes, it is fixed grid resolution and flexible domain size. Reworded "resolution-agnostic" to "crop-size agnostic" here and everywhere it appeared.
| cos_zenith_angle = None | ||
|
|
||
|
|
||
| # Fixed background channel layout (see cosmo_rea2_era5.py::_load_full_frame). |
There was a problem hiding this comment.
Don't reference missing/external files
There was a problem hiding this comment.
Probably will pop up in several other places -- have your agents make a pass to remove the parts of comments like this that reference stuff which won't end up getting merged
| "CLCT": "tcc", # total cloud cover; % -> 0-1 fraction (COSMO_OUTPUT_UNIT_SCALE) | ||
| "PS": "sp", | ||
| "PMSL": "msl", | ||
| } |
There was a problem hiding this comment.
Can we open and merge a COSMO-REA data source before this PR so that the model wrapper can simply reuse the COSMO lexicon? Adding a lexicon to the model wrapper file is not standard
There was a problem hiding this comment.
And, I don't see a particular for the model wrapper to be doing any lexicon interaction in the first place. Wrappers operate based on input_coords and output_coords, which should specify variables that exist in the E2S lexicon. If there are things in the model package using these other names, ideally we rename them to avoid having to deal with lexicon conversions in model code. But, if absolutely necessary, then we should pull from a COSMO lexicon as I suggested above
There was a problem hiding this comment.
Added a CosmoLexicon in earth2studio/lexicon/ and the wrapper uses it
| # z0-free (roughness cancels between the bracketing levels). | ||
|
|
||
|
|
||
| def interp_levels_to_height( |
There was a problem hiding this comment.
I appreciate the design note to clarify intent/choices. I agree given the dependence on elevation it makes sense to attach this to the model. Now that we have it in a commit I think we can drop the lengthy comment explaining rationale.
As for the method itself, is there a reason to make it a standalone, separate public helper rather than a method of the class (possible just an internal helper method too, i.e. _interp_levels_to_height)?
There was a problem hiding this comment.
Trimmed the comment and made it a private module helper function _interp_levels_to_height
|
|
||
| @check_optional_dependencies() | ||
| class CosmoDownscaling(torch.nn.Module, AutoModelMixin): | ||
| """COSMO-REA downscaling model: ERA5 -> high-resolution COSMO-REA. |
There was a problem hiding this comment.
This should eventually get filled out with details of the model, use-cases, intended usage, etc
There was a problem hiding this comment.
added details about the two modes, use cases and intended usage
| background channel order. | ||
| channel_transforms : dict | None | ||
| Per-output-channel nonlinear transform spec (from the training | ||
| ``channel_transforms`` zarr attr), used to invert after de-normalizing. |
There was a problem hiding this comment.
training zarr attrs shouldn't be referenced
There was a problem hiding this comment.
removed this reference
| solar gate), applied in physical space in postprocess for both modes. | ||
| number_of_samples : int | ||
| Number of samples (diffusion); the ``sample`` dim is kept (size 1) for | ||
| ``mode="mean"`` so both modes share an output contract. |
There was a problem hiding this comment.
Why is this a constructor arg? Shouldn't it be dynamic?
There was a problem hiding this comment.
The constructor arg just provides the default but it is dynamic and can be set between calls. This matches CorrDiff behavior and the diagnostic __call__(x, coords) signature has no per-call kwargs, so sample count is carried as instance state. I added a note to the docstring making this behavior more explicit.
| Diffusion noise-schedule bounds. Default to 0.002 / 800.0. | ||
| rho : float | ||
| Karras noise-schedule exponent. Defaults to 7.0. | ||
| solver : str |
There was a problem hiding this comment.
Do the rho and solver need to be configurable via constructor args? Can they just be fixed or do they change?
There was a problem hiding this comment.
They're kept configurable on purpose. We expose them because these settings (sigma_max, rho, solver, steps) can shape the sampled output distribution e.g. ensemble spread and how extremes are represented. So leaving them tunable is useful for inference experimentation and optimization. I've added a docstring note on this.
|
|
||
| Badges | ||
| ------ | ||
| region:eu class:ds product:wind product:precip product:temp product:atmos year:2024 gpu:80gb |
| regression_model: torch.nn.Module | None, | ||
| diffusion_model: torch.nn.Module | None, | ||
| resolution: str, | ||
| mode: str, |
There was a problem hiding this comment.
Use Literal[...] for string args with a fixed set of options
| ) / t_next | ||
| x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_prime) | ||
| x = x_next | ||
| return x.float() |
There was a problem hiding this comment.
Any reason to use a custom sampler over the standard physicsnemo.diffusion offerings?
There was a problem hiding this comment.
Switched _denoise to the standard EDMNoiseScheduler + get_denoiser + sample()
| ``halo`` (px, default 0 = off): run on a block expanded by ``halo`` real | ||
| cells per side and trim it off the output, keeping the returned bbox | ||
| interior clear of the DiT's boundary artifact (~32 px); clamps + warns at | ||
| the grid edge. Both models are DiT-RoPE (resolution-agnostic), so any size |
There was a problem hiding this comment.
Same resolution-agnostic comment applies here
| # Build the DiT with upstream physicsnemo's axial-2D-RoPE NATTEN backend | ||
| # (attention_backend="natten2d_rope" in _build_*_dit). RoPE adds no | ||
| # parameters, so a non-RoPE DiT would load the same weights 0/0 yet run a | ||
| # wrong (un-RoPE'd) forward -- building with the backend avoids that. |
There was a problem hiding this comment.
This approach of loading checkpoints with manual helpers extracting out of a zip file seems brittle and error-prone. If anything in upstream PNM changes about the internal zip format, it could break here, for example. I strongly prefer the standard DiT.from_checkpoint(...) approach.
There was a problem hiding this comment.
Switched to the standard from_checkpoint
| checkpoint keys are ``model.model.*`` plus ``sigma_data``.""" | ||
| from physicsnemo.diffusion.preconditioners import EDMPreconditioner | ||
| from physicsnemo.diffusion.utils import ConcatConditionWrapper | ||
| from physicsnemo.models.dit import DiT |
There was a problem hiding this comment.
Move imports to the top in a standard protected import like other models (see e.g. StormScope)
| "rope_theta": dp.get("rope_theta", 10000.0), | ||
| "qk_norm": dp.get("qk_norm", False), | ||
| } | ||
| backend = dp.get("natten_backend", "cutlass-fna") |
There was a problem hiding this comment.
Why hardcode the backend to cutlass-fna (or manually impose it)? This could block e.g. the Hopper bf16 kernels getting used on H100. Natten should be able to select the appropriate one.
There was a problem hiding this comment.
I am no longer forcing a specifc backend
| resolution: str = "rea6", | ||
| hub_heights: Sequence[float] | None = None, | ||
| hub_interp: str = "linear", | ||
| ) -> DiagnosticModel: |
There was a problem hiding this comment.
Use Literal[...] for all string args with a fixed set of options
|
|
||
| @classmethod | ||
| @check_optional_dependencies() | ||
| def load_model( |
There was a problem hiding this comment.
General comment. Add a minimal config.json in the top level of the model package directory (can just be super simple like this) if there aren't any files already named that (or config.yaml) . Then, add a line in the load_model to resolve/load that file from the package. That will enable automatic download tracking in HuggingFace (I know it's a weird system but it comes from HF rules, not ours)
There was a problem hiding this comment.
Added a minimal config.json at the package root
…, standard sampler - Add CosmoLexicon (earth2studio/lexicon/cosmo.py); the wrapper maps COSMO output names to Earth2Studio names via it, replacing the in-module mapping dicts. - Load checkpoints with the standard DiT.from_checkpoint / EDMPreconditioner.from_checkpoint; drop the manual .mdlus zip extraction, the hand-built DiT, and the forced NATTEN backend (native modules let NATTEN auto-select the kernel per GPU). - Replace the inline EDM/Karras sampler loop with physicsnemo's EDMNoiseScheduler + sample(). - Literal[...] for fixed-option string args; move physicsnemo/json/xarray imports to the top (physicsnemo in the protected optional-dependency block); resolve a package-root config.json in load_model for HuggingFace download tracking. - Pin nvidia-physicsnemo>=2.2.0 (RoPE); drop cosmo from the `all` extra (it conflicts with da-healda's <=2.1.1 cap until 2.2.0 ships). - Docstrings/comments: "crop-size agnostic at the fixed resolution" wording, remove internal/dangling references, expand the class docstring, TODO(cosmo) markers, badge year 2026. - Tests: add CosmoLexicon tests + an euler-vs-heun solver test; route the output-name test through CosmoLexicon; drop the obsolete DiT-builder test. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Signed-off-by: gertl <gertl@nvidia.com>

Description
Add
CosmoDownscaling, a diagnostic model that downscales global ERA5 to high-resolution COSMO-REA regional reanalysis over Europe — COSMO-REA6 (~6 km) and COSMO-REA2 (~2.2 km). Its input is an ERA5 state, so it can downscale an ERA5 analysis directly or run behind a global forecast model (e.g. SFNO → CosmoDownscaling). A runnable example is included.A single class loads one of four checkpoints via two
load_modelselectors:mode="mean"— a deterministic regression that predicts the conditional mean (the expected field,E[y | x]): a single, smooth field in one forward pass. Fast and cheap — a good first high-resolution look.mode="diffusion"— a generative diffusion model that samples the conditional distributionp(y | x): an ensemble of realizations that also captures the spread the mean cannot represent (~N forward passes; seeded for reproducibility).resolution="rea6"|"rea2"— selects the per-resolution weights and grid from one package.Both modes use a PhysicsNeMo DiT (diffusion transformer) at a fixed output resolution; the network is crop-size agnostic, so
set_domain(...)returns a new instance for lat/lon sub-region. From a small box up to the full grid. Optional hub-height wind adds derivedu{H}m/v{H}mcomponents (wind speed via the stockDerivedWS)Blocked by:
Closes: #940
Diagnostic details
modeinput_coords())u{H}m/v{H}msampledim always present: 1 formean,number_of_samplesfordiffusion(per-sample seedseed + i)Dependencies added
New
cosmomodel extra (pip install earth2studio[cosmo]/uv add earth2studio --extra cosmo), also added to theallextra.nvidia-physicsnemo>=2.0natten; python_version < '3.14'einops>=0.8.1nvtx>=0.2.11All permissive (Apache-2.0 / MIT); no non-permissive licenses introduced.
Validation
test/models/dx/test_cosmo_downscaling.py): 51 passed, 5 skipped. The 5 skips are intentionally resource-gated tests — the 4 real-weight GPU--packagecases and 1 real-DiT-builder smoke test — which require a GPU, a built package, and/or physicsnemo's not-yet-released RoPE op, so they cannot run in the default CI lane (the standard@pytest.mark.packagepattern used by every model).{rea6, rea2} × {mean, diffusion}combos load 0 missing / 0 unexpected keys and run a fulldx(x, coords)forward (output shape, finiteness, and every metadata min/max bound asserted in both modes), plus the DiT-builder smoke test.black+ruffclean; SPDX headers present.examples/03_downscaling/04_cosmo_rea_downscaling.pyverified end-to-end on real weights (full domain, sub-domain, rollout, diffusion ensemble vs regression mean, hub-height wind, COSMO-REA2).The CPU tests are construction-driven and GPU-free: they pin the wrapper's contract — coords/handshake + lexicon output names, channel-transform round-trips, physical-space min/max bounds and a mode-identical solar-zenith day/night gate,
set_domainsub-domain slicing (incl. NATTEN-minimum and patch-snap edges), the diffusion sampler's schedule termination + ensemble seeding, and the hub-wind derivation +DerivedWSnaming handshake.Notes for reviewers
Two known, bounded items, each guarded so CI stays green today:
natten2d_rope, merged on physicsnemomainbut not in a tagged release. Thecosmoextra therefore pinsnvidia-physicsnemo>=2.0, and the real-DiT-builder +--packagetests are skip-guarded when the op is unavailable (they skip, not fail). A follow-up bumps the pin and drops the guards once a release includes it.load_default_packageraisesNotImplementedError; the example is inexpected_failing_examplesindocs/conf.pyso the gallery build stays green. FlippingDEFAULT_PACKAGE_URIon and removing the whitelist entry is the only change needed once hosting lands. (The model was validated against real weights locally.)Checklist