Skip to content

feat: add CosmoDownscaling diagnostic model (ERA5 -> COSMO-REA6/REA2)#939

Open
gertln wants to merge 3 commits into
NVIDIA:mainfrom
gertln:cosmo-downscaling
Open

feat: add CosmoDownscaling diagnostic model (ERA5 -> COSMO-REA6/REA2)#939
gertln wants to merge 3 commits into
NVIDIA:mainfrom
gertln:cosmo-downscaling

Conversation

@gertln

@gertln gertln commented Jun 26, 2026

Copy link
Copy Markdown
Collaborator

Description

Add CosmoDownscaling, a diagnostic model that downscales global ERA5 to high-resolution COSMO-REA regional reanalysis over Europe — COSMO-REA6 (~6 km) and COSMO-REA2 (~2.2 km). Its input is an ERA5 state, so it can downscale an ERA5 analysis directly or run behind a global forecast model (e.g. SFNO → CosmoDownscaling). A runnable example is included.

A single class loads one of four checkpoints via two load_model selectors:

  • mode="mean" — a deterministic regression that predicts the conditional mean (the expected field, E[y | x]): a single, smooth field in one forward pass. Fast and cheap — a good first high-resolution look.
  • mode="diffusion" — a generative diffusion model that samples the conditional distribution p(y | x): an ensemble of realizations that also captures the spread the mean cannot represent (~N forward passes; seeded for reproducibility).
  • resolution="rea6"|"rea2" — selects the per-resolution weights and grid from one package.

Both modes use a PhysicsNeMo DiT (diffusion transformer) at a fixed output resolution; the network is crop-size agnostic, so set_domain(...) returns a new instance for lat/lon sub-region. From a small box up to the full grid. Optional hub-height wind adds derived u{H}m/v{H}m components (wind speed via the stock DerivedWS)

Blocked by:
Closes: #940

Diagnostic details

Property Value
Diagnostic type Generative (EDM/Karras diffusion) + deterministic regression, one class via mode
Framework PyTorch (PhysicsNeMo DiT)
Input variables ERA5 surface + pressure-level fields (package metadata; see input_coords())
Output variables COSMO-REA surface and model-level (3D) fields — surface temperature/wind/precip/cloud/fluxes plus model-level winds, temperature, humidity, and TKE; overlap vars mapped to the Earth2Studio lexicon, COSMO-only kept raw; plus optional derived hub-height u{H}m/v{H}m
Input spatial resolution ERA5 regular lat/lon crop on the native footprint
Output spatial resolution 2D rotated-pole curvilinear grid: COSMO-REA6 ~6 km / COSMO-REA2 ~2.2 km
Samples sample dim always present: 1 for mean, number_of_samples for diffusion (per-sample seed seed + i)
Reference COSMO-REA6 / COSMO-REA2 regional reanalysis (DWD and the Hans-Ertel-Centre for Weather Research, University of Bonn)
GitHub https://github.com/NVIDIA/earth2studio

Dependencies added

New cosmo model extra (pip install earth2studio[cosmo] / uv add earth2studio --extra cosmo), also added to the all extra.

Package Version License License URL Reason
nvidia-physicsnemo >=2.0 Apache-2.0 https://github.com/NVIDIA/physicsnemo/blob/main/LICENSE.txt DiT (RoPE + NATTEN), EDM preconditioner/sampler, zenith-angle util
natten ; python_version < '3.14' MIT https://github.com/SHI-Labs/NATTEN/blob/main/LICENSE Neighborhood attention in the DiT
einops >=0.8.1 MIT https://github.com/arogozhnikov/einops/blob/master/LICENSE Tensor rearranges in the network
nvtx >=0.2.11 Apache-2.0 https://github.com/NVIDIA/NVTX/blob/release-v3/LICENSE.txt Profiling ranges (consistent with other model extras)

All permissive (Apache-2.0 / MIT); no non-permissive licenses introduced.

Validation

  • CPU suite (test/models/dx/test_cosmo_downscaling.py): 51 passed, 5 skipped. The 5 skips are intentionally resource-gated tests — the 4 real-weight GPU --package cases and 1 real-DiT-builder smoke test — which require a GPU, a built package, and/or physicsnemo's not-yet-released RoPE op, so they cannot run in the default CI lane (the standard @pytest.mark.package pattern used by every model).
  • GPU lane (those gated tests, run locally on real weights): all pass — the 4 {rea6, rea2} × {mean, diffusion} combos load 0 missing / 0 unexpected keys and run a full dx(x, coords) forward (output shape, finiteness, and every metadata min/max bound asserted in both modes), plus the DiT-builder smoke test.
  • black + ruff clean; SPDX headers present.
  • Example examples/03_downscaling/04_cosmo_rea_downscaling.py verified end-to-end on real weights (full domain, sub-domain, rollout, diffusion ensemble vs regression mean, hub-height wind, COSMO-REA2).

The CPU tests are construction-driven and GPU-free: they pin the wrapper's contract — coords/handshake + lexicon output names, channel-transform round-trips, physical-space min/max bounds and a mode-identical solar-zenith day/night gate, set_domain sub-domain slicing (incl. NATTEN-minimum and patch-snap edges), the diffusion sampler's schedule termination + ensemble seeding, and the hub-wind derivation + DerivedWS naming handshake.

Notes for reviewers

Two known, bounded items, each guarded so CI stays green today:

  1. PhysicsNeMo RoPE not yet released. The DiT uses natten2d_rope, merged on physicsnemo main but not in a tagged release. The cosmo extra therefore pins nvidia-physicsnemo>=2.0, and the real-DiT-builder + --package tests are skip-guarded when the op is unavailable (they skip, not fail). A follow-up bumps the pin and drops the guards once a release includes it.
  2. Weights not hosted yet. load_default_package raises NotImplementedError; the example is in expected_failing_examples in docs/conf.py so the gallery build stays green. Flipping DEFAULT_PACKAGE_URI on and removing the whitelist entry is the only change needed once hosting lands. (The model was validated against real weights locally.)

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.
  • Assess and address Greptile feedback (AI code review bot).

Add CosmoDownscaling, a diagnostic model that downscales ERA5 to high-resolution
COSMO-REA regional reanalysis -- COSMO-REA6 (~6 km) and COSMO-REA2 (~2.2 km) --
each with a deterministic regression (mean) and a generative EDM/Karras diffusion
mode, selected via `mode` on `load_model`. Supports movable sub-domains via
`set_domain` and optional hub-height wind components for wind-energy use.

- earth2studio/models/dx/cosmo_downscaling.py: the wrapper (physicsnemo DiT with
  axial RoPE + NATTEN attention, EDM diffusion sampler, channel transforms +
  physical-space constraints incl. a solar-zenith day/night gate, sub-domain
  cropping, derived hub-height wind components)
- test/models/dx/test_cosmo_downscaling.py: construction + behavior tests with
  mocked nets, plus a GPU `--package` test against a real package
- examples/03_downscaling/04_cosmo_rea_downscaling.py: end-to-end example
  (SFNO -> downscale full domain, sub-domain, rollout, diffusion ensemble vs
  regression mean, hub-height wind, COSMO-REA2)
- register the model; add the `cosmo` extra; install/CHANGELOG/API-doc entries

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: gertl <gertl@nvidia.com>
@gertln gertln requested a review from pzharrington June 26, 2026 08:32
@gertln

gertln commented Jun 26, 2026

Copy link
Copy Markdown
Collaborator Author
cosmo_domains

Output domains: native and extended grid footprints for both resolutions. REA6 spans broad Europe at ~6 km; REA2 is central-European natively at ~2.2 km, with an extended grid reaching a comparable broad domain.

Domain Resolution Grid size
REA6 native ~6 km 824 × 848
REA6 extended ~6 km 1006 × 1030
REA2 native ~2.2 km 780 × 724
REA2 extended ~2.2 km 2780 × 2724

@gertln gertln marked this pull request as ready for review June 26, 2026 10:10
@greptile-apps

greptile-apps Bot commented Jun 26, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

Adds CosmoDownscaling, a new DiagnosticModel that downscales ERA5 to high-resolution COSMO-REA6 (~6 km) or COSMO-REA2 (~2.2 km) over Europe via either a deterministic regression (mean) or generative EDM diffusion DiT-RoPE network. A new cosmo package extra is introduced, and the model is gated behind load_default_package → NotImplementedError until weights are hosted.

  • Core model (cosmo_downscaling.py, 1995 lines): handles ERA5 z-scoring → bilinear regrid → rotated-pole DiT forward → inverse transforms → physical constraints → optional hub-height wind; set_domain slices the native/extended grid to a lat/lon bbox without retraining.
  • Tests (test_cosmo_downscaling.py, 1324 lines): CPU-only suite covering coord contracts, constraint parsing, halo/patch-snap edge cases, diffusion schedule shape and seeding, and a constructor-signature guard that forces new args through set_domain.
  • Docs/example/packaging: docs/conf.py whitelists the example in expected_failing_examples pending weight hosting; install guide and RST index are updated; pyproject.toml adds the cosmo extra with all permissive licenses.

Confidence Score: 4/5

Safe to merge for the primary sequential use case; the shared-state mutation and hardcoded ERA5 spacing are edge cases that do not affect the documented workflow.

The implementation is thorough and well-tested. All findings are non-blocking: the _rebind_latent shared-state mutation only becomes a problem under concurrent access, the hardcoded 0.25 degree ERA5 step in set_domain surfaces as a ValueError rather than silent wrong output, and the extended-grid propagation gap is a usability limitation rather than a correctness defect.

earth2studio/models/dx/cosmo_downscaling.py — specifically the _rebind_latent / set_domain interaction and the _reg ERA5-resolution assumption.

Important Files Changed

Filename Overview
earth2studio/models/dx/cosmo_downscaling.py New 1995-line CosmoDownscaling diagnostic model with regression and diffusion modes, rotated-pole output grid, set_domain sub-region slicing, hub-height wind derivation, and full package-loading logic; well-engineered but shares mutable DiT state across sub-domain instances and hardcodes 0.25° ERA5 spacing in set_domain
test/models/dx/test_cosmo_downscaling.py Comprehensive 1324-line CPU test suite covering coordinate contracts, constraint parsing, domain slicing, halo/patch-snap edge cases, hub-wind derivation, diffusion schedule, and a constructor-signature guard that forces new args to be threaded through set_domain
examples/03_downscaling/04_cosmo_rea_downscaling.py End-to-end example covering SFNO to CosmoDownscaling rollout, sub-domain, diffusion ensemble vs regression mean, hub-height wind, and COSMO-REA2; correctly listed in expected_failing_examples since weights are not yet hosted
pyproject.toml Adds cosmo extra (einops, natten, nvidia-physicsnemo>=2.0, nvtx) and includes it in the all extra; all permissive licenses
earth2studio/models/dx/init.py Adds CosmoDownscaling import and all entry
docs/conf.py Adds the COSMO-REA example to expected_failing_examples with a clear comment; appropriate short-term measure until weights are hosted
test/conftest.py Adds the new test file to _TEST_DEPENDENCIES under the cosmo extra, consistent with other model test registrations

Reviews (1): Last reviewed commit: "feat: add CosmoDownscaling diagnostic mo..." | Re-trigger Greptile

Comment on lines +1242 to +1260
sig_n = t_next.to(torch.float32).repeat(x.shape[0])
d_prime = (
x_next - net(x_next.float(), sig_n, condition=cond).double()
) / t_next
x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_prime)
x = x_next
return x.float()

@torch.inference_mode()
def _forward(
self,
era5: torch.Tensor,
valid_time: datetime,
lat2d: torch.Tensor,
lon2d: torch.Tensor,
) -> torch.Tensor:
"""Forward for one ERA5 frame -> [number_of_samples, C_out, H, W] physical
(C_out = the trained outputs, plus the derived hub-height wind components
when ``hub_heights`` is set).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Shared mutable state mutated per-forward — unsafe for concurrent use

_rebind_latent writes directly into the attn_kwargs_forward dict and the detokenizer's h_patches/w_patches on the live DiT object. Because set_domain propagates the same network reference (cr.regression_model is dx.regression_model), two instances sharing one network will overwrite each other's latent-grid bindings if called concurrently (e.g., two threads, two asyncio tasks, or a DataLoader with num_workers > 0).

Sequential single-threaded calls are safe (the rebind happens before each forward), but nothing in the API surface or docstrings warns users about this constraint. Adding a note to set_domain and load_model — or copying the latent-state on sub-domain creation — would prevent silent wrong-grid outputs under concurrent use.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The network is shared across set_domain sub-domains and each forward sets its grid on it in place. So sub-domains are fine to run sequentially but two running at the same time would lead to conflicts. I documented this in set_domain and _rebind_latent with the advice to run concurrent domains in separate processes. (d498ccf)

Comment on lines +1602 to +1611
"qk_norm": rp.get("qk_norm", False),
}
backend = rp.get("natten_backend", "cutlass-fna")
if backend is not None:
attn["na2d_kwargs"] = {"backend": backend}
conv_layers = rp.get("detok_conv_layers", 0)
return DiT(
input_size=list(rp.get("img_resolution", [256, 256])),
in_channels=cond_channels,
out_channels=out_channels,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Hardcoded 0.25° ERA5 resolution in set_domain

_reg always constructs the sub-domain's ERA5 input grid at 0.25° steps regardless of the parent model's actual input resolution. If the parent package ships a coarser or finer ERA5 grid (e.g., 0.5°), the sub-domain's lat_input_grid / lon_input_grid will not match the user's data, and output_coords will reject the input via _is_native_input with the relatively opaque message "requires the native input grid". The step size should be derived from the parent's lat_input_numpy spacing rather than hardcoded.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in d498ccf. set_domain now reads the input spacing from the package's own grid instead of assuming 0.25.

Comment on lines +1644 to +1649
attn["na2d_kwargs"] = {"backend": backend}
dit = DiT(
input_size=list(dp.get("img_resolution", [256, 256])),
in_channels=out_channels + cond_channels,
out_channels=out_channels,
patch_size=dp["patch_size"],

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Extended-grid state not propagated to sub-domains — chained set_domain silently loses OOD-margin access

After set_domain returns a sub-domain instance, its _ext_lat_numpy, _ext_lon_numpy, and _ext_static_numpy remain None (the constructor default). If a user calls set_domain again on the returned sub-domain, the call is silently restricted to the sub-domain's lat_output_numpy only — the extended margin invariants are no longer reachable. There is no warning or documentation note about this; users who want two differently-shaped sub-regions that both access the OOD margin must call set_domain from the original loaded model, not chain it.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By design, now documented (d498ccf). A sub-domain only keeps its own cropped grid, so calling set_domain again on it can't go back to the full footprint.

Comment on lines +353 to +356
regression_model: torch.nn.Module | None,
diffusion_model: torch.nn.Module | None,
resolution: str,
mode: str,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 interp module shadowed by a local variable inside interp_levels_to_height

The module-level from earth2studio.utils import ... interp is shadowed by the local interp tensor assignment inside the loop. This is not a bug today (the module is not called inside this function), but any future edit adding a call to interp.latlon_interpolation_regular within the loop would silently call the tensor instead of the module. Renaming the local to val avoids the ambiguity.

Suggested change
regression_model: torch.nn.Module | None,
diffusion_model: torch.nn.Module | None,
resolution: str,
mode: str,
val = (
values[..., j, :, :] + (values[..., j + 1, :, :] - values[..., j, :, :]) * w
)
out = torch.where((t >= lo) & (t < hi), val, out)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

renamed the local to interp_val so it no longer shadows the interp import (d498ccf)

…currency

- set_domain: derive the ERA5 input-axis spacing from the package's own
  validated grid instead of a hardcoded 0.25 deg (behavior-preserving for
  the shipped 0.25 deg package; supports a non-0.25 deg package).
- interp_levels_to_height: rename local `interp` -> `interp_val` so it no
  longer shadows the module-level `interp` import.
- set_domain / _rebind_latent: document that the network is shared by
  reference and rebinds its latent-grid state in place, so sub-domains are
  safe sequentially but not concurrently; and that a second set_domain on a
  returned sub-domain cannot reach back to the full/extended footprint.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Comment thread docs/conf.py Outdated
"log_level": {"backreference_missing": "warning", "gallery_examples": "debug"},
# COSMO-REA weights are not yet publicly hosted, so this example raises without
# a local $COSMO_REA_PACKAGE. Whitelist it so the full gallery build does not
# fail; remove this entry once the package is hosted (load_default_package).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a TODO so we don't forget this

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done (bab2e90)

Comment thread earth2studio/models/dx/__init__.py Outdated
CorrDiffTaiwan,
)
from earth2studio.models.dx.corrdiff_cmip6 import CorrDiffCMIP6
from earth2studio.models.dx.cosmo_downscaling import CosmoDownscaling # noqa

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why # noqa?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove it - not needed.

Comment thread pyproject.toml Outdated
cosmo = [
"einops>=0.8.1",
"natten ; python_version < '3.14'",
"nvidia-physicsnemo>=2.0",

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we need a stricter constraint on this if we need RoPE stuff from what I recently merged

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pinned nvidia-physicsnemo>=2.2.0. natten2d_rope is only on main (2.2.0a0) and the latest release (2.1.1) doesn't have it. So it is intentionally uninstallable until 2.2.0 ships. (Also removed cosmo from the all extra because >=2.2.0 conflicts with da-healda <=2.1.1 )

modes are DiT-RoPE (a diffusion transformer with rotary position embedding) and
resolution-agnostic (single forward at any grid size).

Domain handling (see the integration handoff + design memory):

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't reference implementation notes or agent memory context

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

native trained footprint, or reach into the extended invariant margin -- the
latter proceeds with a one-time out-of-distribution warning; beyond the extended
extent raises. Both the mean and diffusion models are DiT-RoPE
(resolution-agnostic), so a sub-domain runs at any size in a single forward.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's actually fixed resolution, and more domain/crop size agnostic, right?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it is fixed grid resolution and flexible domain size. Reworded "resolution-agnostic" to "crop-size agnostic" here and everywhere it appeared.

cos_zenith_angle = None


# Fixed background channel layout (see cosmo_rea2_era5.py::_load_full_frame).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't reference missing/external files

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably will pop up in several other places -- have your agents make a pass to remove the parts of comments like this that reference stuff which won't end up getting merged

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"CLCT": "tcc", # total cloud cover; % -> 0-1 fraction (COSMO_OUTPUT_UNIT_SCALE)
"PS": "sp",
"PMSL": "msl",
}

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we open and merge a COSMO-REA data source before this PR so that the model wrapper can simply reuse the COSMO lexicon? Adding a lexicon to the model wrapper file is not standard

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And, I don't see a particular for the model wrapper to be doing any lexicon interaction in the first place. Wrappers operate based on input_coords and output_coords, which should specify variables that exist in the E2S lexicon. If there are things in the model package using these other names, ideally we rename them to avoid having to deal with lexicon conversions in model code. But, if absolutely necessary, then we should pull from a COSMO lexicon as I suggested above

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a CosmoLexicon in earth2studio/lexicon/ and the wrapper uses it

# z0-free (roughness cancels between the bracketing levels).


def interp_levels_to_height(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I appreciate the design note to clarify intent/choices. I agree given the dependence on elevation it makes sense to attach this to the model. Now that we have it in a commit I think we can drop the lengthy comment explaining rationale.

As for the method itself, is there a reason to make it a standalone, separate public helper rather than a method of the class (possible just an internal helper method too, i.e. _interp_levels_to_height)?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trimmed the comment and made it a private module helper function _interp_levels_to_height


@check_optional_dependencies()
class CosmoDownscaling(torch.nn.Module, AutoModelMixin):
"""COSMO-REA downscaling model: ERA5 -> high-resolution COSMO-REA.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should eventually get filled out with details of the model, use-cases, intended usage, etc

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added details about the two modes, use cases and intended usage

background channel order.
channel_transforms : dict | None
Per-output-channel nonlinear transform spec (from the training
``channel_transforms`` zarr attr), used to invert after de-normalizing.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

training zarr attrs shouldn't be referenced

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed this reference

solar gate), applied in physical space in postprocess for both modes.
number_of_samples : int
Number of samples (diffusion); the ``sample`` dim is kept (size 1) for
``mode="mean"`` so both modes share an output contract.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this a constructor arg? Shouldn't it be dynamic?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constructor arg just provides the default but it is dynamic and can be set between calls. This matches CorrDiff behavior and the diagnostic __call__(x, coords) signature has no per-call kwargs, so sample count is carried as instance state. I added a note to the docstring making this behavior more explicit.

Diffusion noise-schedule bounds. Default to 0.002 / 800.0.
rho : float
Karras noise-schedule exponent. Defaults to 7.0.
solver : str

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do the rho and solver need to be configurable via constructor args? Can they just be fixed or do they change?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They're kept configurable on purpose. We expose them because these settings (sigma_max, rho, solver, steps) can shape the sampled output distribution e.g. ensemble spread and how extremes are represented. So leaving them tunable is useful for inference experimentation and optimization. I've added a docstring note on this.


Badges
------
region:eu class:ds product:wind product:precip product:temp product:atmos year:2024 gpu:80gb

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Release year 2026

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed -> 2026

regression_model: torch.nn.Module | None,
diffusion_model: torch.nn.Module | None,
resolution: str,
mode: str,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Literal[...] for string args with a fixed set of options

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

) / t_next
x_next = x + (t_next - t_cur) * 0.5 * (d_cur + d_prime)
x = x_next
return x.float()

@pzharrington pzharrington Jun 29, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason to use a custom sampler over the standard physicsnemo.diffusion offerings?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched _denoise to the standard EDMNoiseScheduler + get_denoiser + sample()

``halo`` (px, default 0 = off): run on a block expanded by ``halo`` real
cells per side and trim it off the output, keeping the returned bbox
interior clear of the DiT's boundary artifact (~32 px); clamps + warns at
the grid edge. Both models are DiT-RoPE (resolution-agnostic), so any size

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same resolution-agnostic comment applies here

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed, renamed

# Build the DiT with upstream physicsnemo's axial-2D-RoPE NATTEN backend
# (attention_backend="natten2d_rope" in _build_*_dit). RoPE adds no
# parameters, so a non-RoPE DiT would load the same weights 0/0 yet run a
# wrong (un-RoPE'd) forward -- building with the backend avoids that.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This approach of loading checkpoints with manual helpers extracting out of a zip file seems brittle and error-prone. If anything in upstream PNM changes about the internal zip format, it could break here, for example. I strongly prefer the standard DiT.from_checkpoint(...) approach.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched to the standard from_checkpoint

checkpoint keys are ``model.model.*`` plus ``sigma_data``."""
from physicsnemo.diffusion.preconditioners import EDMPreconditioner
from physicsnemo.diffusion.utils import ConcatConditionWrapper
from physicsnemo.models.dit import DiT

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move imports to the top in a standard protected import like other models (see e.g. StormScope)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"rope_theta": dp.get("rope_theta", 10000.0),
"qk_norm": dp.get("qk_norm", False),
}
backend = dp.get("natten_backend", "cutlass-fna")

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why hardcode the backend to cutlass-fna (or manually impose it)? This could block e.g. the Hopper bf16 kernels getting used on H100. Natten should be able to select the appropriate one.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am no longer forcing a specifc backend

resolution: str = "rea6",
hub_heights: Sequence[float] | None = None,
hub_interp: str = "linear",
) -> DiagnosticModel:

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Literal[...] for all string args with a fixed set of options

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


@classmethod
@check_optional_dependencies()
def load_model(

@pzharrington pzharrington Jun 29, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

General comment. Add a minimal config.json in the top level of the model package directory (can just be super simple like this) if there aren't any files already named that (or config.yaml) . Then, add a line in the load_model to resolve/load that file from the package. That will enable automatic download tracking in HuggingFace (I know it's a weird system but it comes from HF rules, not ours)

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a minimal config.json at the package root

…, standard sampler

- Add CosmoLexicon (earth2studio/lexicon/cosmo.py); the wrapper maps COSMO output
  names to Earth2Studio names via it, replacing the in-module mapping dicts.
- Load checkpoints with the standard DiT.from_checkpoint /
  EDMPreconditioner.from_checkpoint; drop the manual .mdlus zip extraction, the
  hand-built DiT, and the forced NATTEN backend (native modules let NATTEN
  auto-select the kernel per GPU).
- Replace the inline EDM/Karras sampler loop with physicsnemo's EDMNoiseScheduler
  + sample().
- Literal[...] for fixed-option string args; move physicsnemo/json/xarray imports
  to the top (physicsnemo in the protected optional-dependency block); resolve a
  package-root config.json in load_model for HuggingFace download tracking.
- Pin nvidia-physicsnemo>=2.2.0 (RoPE); drop cosmo from the `all` extra (it
  conflicts with da-healda's <=2.1.1 cap until 2.2.0 ships).
- Docstrings/comments: "crop-size agnostic at the fixed resolution" wording,
  remove internal/dangling references, expand the class docstring, TODO(cosmo)
  markers, badge year 2026.
- Tests: add CosmoLexicon tests + an euler-vs-heun solver test; route the
  output-name test through CosmoLexicon; drop the obsolete DiT-builder test.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Signed-off-by: gertl <gertl@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

🚀[FEA]: Add CosmoDownscaling diagnostic model (ERA5 → COSMO-REA6/REA2)

2 participants