Skip to content

HealDA v2 Architecture#1758

Open
aayushg55 wants to merge 56 commits into
NVIDIA:mainfrom
aayushg55:ag/healda-v2-arch
Open

HealDA v2 Architecture#1758
aayushg55 wants to merge 56 commits into
NVIDIA:mainfrom
aayushg55:ag/healda-v2-arch

Conversation

@aayushg55

Copy link
Copy Markdown
Contributor

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

aayushg55 added 30 commits June 24, 2026 15:34
Bring the ragged grouped-query pixel cross-attention (pixel latents attend
to per-pixel packed observation tokens) into the experimental healda package
as a first step toward a video/observation DiT block.

Layout follows PNM's optional-dependency conventions:
- triton is referenced only via OptionalImport (never a bare import), so the
  modules import without triton and the import-linter external-import contract
  stays satisfied.
- the compiled Triton kernels live in a private _pixel_attn_kernels.py backend
  (mirroring the warp _warp_impl pattern), imported lazily by the public
  pixel_cross_attention.py only when triton is available -- no per-chunk guards.
- triton_autotune_cache.py: standalone autotune-config persistence util,
  likewise OptionalImport-based.

Headers use the current PNM template; ruff + license-header checks pass.

Commit message authored by AI
A 4D (b, t, x, c) analog of physicsnemo.nn.DiTBlock for field sequences: keeps
the DiT template (pluggable spatial-attention backend via get_attention,
adaLN-Zero conditioning, drop-path, gated MLP) and adds two optional gated
sub-layers:
- temporal (video) attention across time, with a pluggable time<->space reshard
  for context parallelism (manual all-to-all or ShardTensor.redistribute);
- observation cross-attention (pixel latents attend to packed per-pixel obs
  tokens via the vendored PixelCrossAttention).

New modules: temporal_attention.py (TemporalAttention + RoPE + causal/window
mask), sharding.py (shard_x/shard_t all-to-all + ShardTensor variants),
obs_packing.py (AttentionPacking/PixelGroupMap, copied as-is). Tests cover the
CPU plain + temporal paths and the CUDA full spatial+obs+temporal path
(forward/backward + grad flow).

Commit message authored by AI
Replace the _pixel_attn_kernels() helper + _k alias with an inline
`from . import _pixel_attn_kernels as kernels` inside the kernel-launch
wrappers, matching PNM's lazy optional-backend pattern (cf.
mesh/visualization/draw_mesh.py importing _matplotlib_impl/_pyvista_impl).

Commit message authored by AI
Bundle the observation tokens and ragged packing metadata into one
ObsCrossAttention object (tokens + cu_seqlens_k + max_seqlen_k + group_map)
so VideoDiTBlock's obs sub-layer takes a single argument instead of separate
obs_tokens + packing args. Drop the redundant data-pipeline AttentionPacking
struct (its counts/npix/hpx_level/pixel_order/is_packed fields are unused by
the model). Add jaxtyping Float/Int shape hints across the block, temporal
attention, sharding, and obs-packing modules.

Commit message authored by AI
Use the released physicsnemo.models.dit shape names: the token feature axis is
'hidden_size' (not 'channels'/'dim') and conditioning is 'condition_embed_dim',
matching DiTBlock's annotations and our own hidden_size constructor arg.

Commit message authored by AI
Add the ragged pixel cross-attention test suite (16 tests) validating the
Triton kernel forward/backward against a readable PyTorch GQA reference
(_ragged_gqa_reference), plus packed-grid, small-pixel grouping (grouped ==
ungrouped bit-for-bit), nn.Module wiring, empty-tokens DDP-safety, and config
validation. Port build_pixel_group_map (the CSR small-pixel grouping helper,
pure function of cu_seqlens_k) into obs_packing.py to support the grouping test.

Commit message authored by AI
A HEALPix field-sequence diffusion transformer over (B, C, T, npix): reuses the
existing HEALPixPatchTokenizer/HEALPixPatchDetokenizer (which already fold the
time axis and add the calendar embedding) and EDM conditioning, reshapes the
flat token sequence to (B, T, X, hidden) for the VideoDiTBlock stack (spatial +
optional factorized temporal + optional observation cross-attention), then back
for the detokenizer. Observations enter as a prebuilt ObsCrossAttention bundle.

Tests: CPU dense+temporal and CUDA dense+temporal+obs forward/backward.

Commit message authored by AI
- VideoDiT no longer hardcodes HEALPix: it takes a pluggable tokenizer /
  detokenizer (the grid lives only in tokenization; the backbone is grid-
  agnostic) plus time_length, threading tokenizer_kwargs and reshaping the flat
  time-major token sequence to (B, T, X, hidden) for the blocks.
- VideoDiTBlock's spatial+MLP path now mirrors physicsnemo.nn.DiTBlock: renamed
  emb_channels -> condition_embed_dim and added the DiTBlock dropout args
  (attn_drop_rate, proj_drop_rate, mlp_drop_rate, final_mlp_dropout), wired into
  get_attention and the MLP. Its args are now a superset of DiTBlock's.

Commit message authored by AI
…pose VideoDiT

- VideoDiTBlock subclasses physicsnemo.nn.DiTBlock, reusing its spatial
  attention, gated MLP, pre-norms, 6-chunk adaLN-Zero and drop-path; adds
  optional gated temporal + observation cross-attention sub-layers.
- is_causal and obs/temporal config moved to __init__; obs args behind
  obs_kwargs, temporal behind temporal_kwargs (explicit Dicts, no **kwargs).
- VideoDiT inherits physicsnemo.Module and composes the conditioning embedder +
  pluggable tokenizer/detokenizer + blocks (no production DiT change).
- sharding: @torch._dynamo.disable on the ShardTensor reshard.

Commit message authored by AI
…-attention

Generalize the experimental video DiT into a grid-agnostic DiT with a time axis:
- add a shared ndim-agnostic AdaLayerNormZero with a zero_init toggle (using
  get_layer_norm) and SiLU inside the module;
- compose VideoDiTBlock from shared building blocks instead of subclassing
  DiTBlock, moving modulation into per-sub-layer AdaLayerNormZero;
- replace the obs-specific cross-attention with a generic pluggable
  CrossAttentionModuleBase slot + opaque context; PixelCrossAttention is the
  reference impl and now owns the fold/ragged-unpack;
- make VideoDiT a kwarg-superset of DiT (drop_path_rates, conditioning_embedder
  choice, attn_kwargs, block_kwargs, dit_initialization).

Commit message authored by AI
…oss-attn, adaLN naming

- VideoDiT/VideoDiTBlock: cross_attention is a per-block factory (no deepcopy);
  forward context typed Optional[Any]; conditioning resolution inlined.
- Time axis first-class: HEALPix tokenizer gains separate_time_axis; detokenizer
  infers flat-vs-time-first from input rank (backwards-compatible with the v1
  flat + time_length path).
- adaLN-Zero attrs renamed role-first: attn_norm / temporal_attn_norm /
  cross_attn_norm / mlp_norm.

Commit message authored by AI
… Triton kernel

Port the v2 FiLM-conditioned observation tokenizer as an initial drop-in:
- obs_film_tokenizer.py: ObsTokenizerFiLM module + pure-PyTorch reference,
  custom-op wrappers, and the fused_film_tokenizer_triton entry point. Forward
  dispatches to the Triton kernel on CUDA when triton is available, else the
  reference path.
- _film_kernels.py: private fused FiLM forward/backward Triton kernels, guarded
  by OptionalImport("triton") (no bare import) and imported lazily.
- test_obs_film_tokenizer.py: CPU reference smoke test plus CUDA Triton-vs-
  reference parity tests.

Settings are kept as-is for now (TODO(polish) markers on unused ones).

Commit message authored by AI
DiT.__init__ indexed input_size[1] unconditionally, which IndexErrors for
non-2D tokenizers (e.g. HEALPix, where HealDA passes a 1-tuple input_size).
The latent grid is only consumed by the NATTEN backends, so guard it.

Commit message authored by AI
Match the public module stem (obs_film_tokenizer.py), mirroring
pixel_cross_attention.py / _pixel_attn_kernels.py.

Commit message authored by AI
…kenizer

obs_film_tokenizer.py -> obs_tokenizer.py, _obs_film_kernels.py ->
_obs_tokenizer_kernels.py (class stays ObsTokenizerFiLM; "FiLM" distinguishes
the impl from the existing ObsTokenizer and leaves room to evolve).

Commit message authored by AI
Compose the upstreamed VideoDiT backbone with the FiLM ObsTokenizerFiLM and
per-block PixelCrossAttention into the production v2 video+obs data-assimilation
architecture (hidden 1536, 16 heads, 32 layers, time_length 8, linear causal
temporal attention, drop-path 0.0 for the first 4 blocks then 0.1), so the
existing healda checkpoint can be loaded. HealDAv2 hosts the FiLM obs tokenizer
and assembles the per-pixel ObsCrossAttention context the backbone consumes.

Commit message authored by AI
…tidy wording

- VideoDiT.set_context_parallel(mode, target) fans the temporal time<->space
  reshard config out to every block (the per-block setter had no caller, so CP
  was never actually enabled through the model).
- VideoDiT.forward asserts the tokenizer emits 4D (B,T,X,hidden).
- Drop the coined "field sequence(s)" phrasing from docstrings/comments.

Commit message authored by AI
…ntract

Make the base a general cross-attention sub-layer rather than one "injected
into a video DiT block": forward takes (*batch, hidden_size) latents and an
opaque context: Any. Drops the field-sequence wording and trims the docstring.

Commit message authored by AI
…ocks=2)

Replace the separate attn_norm + mlp_norm AdaLayerNormZero(n_blocks=1) pair with
a single norm1 = AdaLayerNormZero(n_blocks=2) that emits the spatial-attention
modulation plus the raw MLP shift/scale/gate, and a parameter-free LayerNorm
MLP pre-norm modulated by those. This matches the DiT/diffusers layout and the
production checkpoint's single norm1.linear. temporal_attn_norm/cross_attn_norm
stay one-block adaLNs; initialize_weights now zeroes norm1.

Commit message authored by AI
Rename the packed-observation container to ObsContext and make tokens optional
(unset until the observation tokenizer fills it). PixelCrossAttention.forward
now consumes an ObsContext, raising if tokens is unset; a None group_map keeps
using the ungrouped ragged path (the model layer does not build it). Updates
all references and tests.

Commit message authored by AI
…orward

ObsContext now also carries the raw per-observation arrays (values,
float_metadata, obs_type, channel, platform) alongside the ragged packing, so
HealDAv2.forward takes one obs: ObsContext instead of loose per-obs and packing
args. The forward runs the tokenizer, fills tokens via dataclasses.replace, and
passes the context through; it no longer builds the pixel group map (a None
group_map uses the ungrouped ragged path). build_pixel_group_map stays for
callers that precompute it.

Commit message authored by AI
…mporal RoPE

Replace the hand-rolled RotaryPositionEmbedding with PhysicsNeMo's
nn.module.rope.RotaryPositionEmbedding1D (math-verified equivalent: same
interleaved-pair rotation and theta^(-2k/d) schedule). The temporal q/k are
transposed so the time axis is the -2 dim the module rotates, then restored.
Its cos/sin are non-persistent buffers recomputed at init, so the old
persistent rope.freqs_cos/sin keys no longer exist.

Commit message authored by AI
…rim docstring

Remove the unused nchannel, nplatform, and use_global_channel_platform_ids
constructor params (the embedding tables are always sized GLOBAL_MAX_* and id
spaces are the caller's responsibility). Trim the editorial rationale and
conv/sat first-linear essays in the class docstring to terse facts.

Commit message authored by AI
…meter-free

HealDAv2 passes attn_kwargs={"qk_norm_type":"RMSNorm","qk_norm_affine":False}.
PhysicsNeMo's TimmSelfAttention declares those exact kwargs (translating them to
timm's qk_norm + RmsNorm norm_layer), so affine-free RMSNorm actually engages and
the names are not silently swallowed. Add a regression test that the spatial q/k
norm modules exist (not Identity) and have no learnable parameters.

Commit message authored by AI
Lightweight audit cleanups: jaxtyping annotations on the two HEALPix hpx
tokenizer forwards, a descriptive CalendarEmbedding ValueError message, an
accurate uneven-shards comment in sharding.py, a Literal["apex","torch"] hint on
AdaLayerNormZero.layernorm_backend, and removal of the dead VideoDiTBlock
self.hidden_size attribute. The hpx __init__ **kwargs safety nets are kept.

Commit message authored by AI
…n; tidy HealDAv2 wiring

Drop the per-block hot-path validation from PixelCrossAttention.forward (tokens-set
and cu_seqlens-vs-pixel-count) and instead validate the packing's structural shape
once in ObsContext.__post_init__; tokens-set is guaranteed by HealDAv2.forward.
Build the HEALPix (de)tokenizers as locals instead of inline in the VideoDiT call,
and note the drop-path zero-first stability rationale in the docstring.

Commit message authored by AI
…packing->obs_context

Extract the kernel-companion packing primitives into a single
pixel_attention_utils.py: sort_and_pack (inlined Triton counting-sort kernel,
guarded by triton availability, argsort fallback), counts_to_cu_seqlens, and
build_pixel_group_map. These operate on plain index/count tensors, so the data
pipeline builds the packing the model only consumes. obs_packing.py is now purely
the ObsContext + PixelGroupMap contract, renamed to obs_context.py. Also build
HealDAv2's cross-attention via functools.partial and keep the cross-attention
base docstring generic. Adds CPU parity tests for the packing utils.

Commit message authored by AI
aayushg55 and others added 5 commits June 30, 2026 17:05
…pers

Split the adaLN-Zero op into a projection module and two apply helpers so
every sub-layer composes the same way and the pieces can be reused across
DiT/video blocks (3D and 4D states):

- adaln.py: AdaLNModulation (c -> 3*n_blocks shift/scale/gate chunks) plus
  standalone modulate() and gated_residual(); the affine-free pre-norm now
  lives at the call site instead of inside the module.
- video_dit_block.py: rewire to the uniform norm -> modulate -> gated_residual
  pattern (no block-0 special case); projections named norm1_modulation /
  temporal_attn_modulation / cross_attn_modulation, each attention sub-layer
  owning its own parameter-free pre-norm (attn_norm / mlp_norm /
  temporal_attn_norm / cross_attn_norm).
- video_dit.py: update the adaLN docstring reference.
- test: track the norm1_modulation rename.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
…te _kernels modules

Consolidate each op's Triton device code and its host-side glue (launch
dispatch, autograd Function, custom-op registration) into the op's private
_*_kernels module, leaving the model files as just the nn.Module + public API:

- _pixel_attn_kernels.py / _obs_tokenizer_kernels.py: now own the full Triton
  op stack instead of only the @triton.jit kernels.
- pixel_cross_attention.py: slimmed to the PixelCrossAttention module;
  pixel_attention_reference -> _pixel_attention_reference (now private).
- obs_tokenizer.py: slimmed to the ObsTokenizerFiLM module.
- __init__.py: add package docstring + __all__; export PixelCrossAttention and
  ObsTokenizerFiLM alongside the existing public classes.
- pixel_attention_utils.py: docstring trim.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The packing helpers (sort_and_pack, counts_to_cu_seqlens, build_pixel_group_map,
counting_sort_and_pack + its counting-sort Triton kernel) build the ragged
layout that pixel_attention consumes and are used only by pixel cross-attention,
so co-locate them in pixel_cross_attention.py and drop the separate module.
Repoint the two tests and update the obs_context doc reference.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Comment thread physicsnemo/experimental/models/healda/pixel_cross_attention.py Outdated
root and others added 2 commits July 1, 2026 09:54
- video_dit_block: set_context_parallel now type-checks target (ProcessGroup for
  "all_to_all", DeviceMesh for "shardtensor") with a clear TypeError.
- healda_v2: add a set_context_parallel passthrough to the backbone; expand the
  class docstring (data-flow stages, grid-agnostic boundary, context-parallel
  constraints) and Notes (obs packing, and that only the timm backend supports
  the RMSNorm / qk_norm_affine=False QK-norm used for stable training).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

@negin513 negin513 left a comment

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great job @aayushg55!

I’ll add a few comments inline.

One higher-level thing I noticed: this PR does not add an end-to-end training recipe or example showing how the new HealDAv2, VideoDiT, ObsContext, and observation/pixel-attention pieces are intended to be wired together...

Does it make sense to add a minimal example in this PR under examples/weather/healda/ ?

@aayushg55

Copy link
Copy Markdown
Contributor Author

One higher-level thing I noticed: this PR does not add an end-to-end training recipe or example showing how the new HealDAv2, VideoDiT, ObsContext, and observation/pixel-attention pieces are intended to be wired together...

Does it make sense to add a minimal example in this PR under examples/weather/healda/ ?

I think the idea was for this PR to only introduce the architecture to get it in earlier rather than later and enable adding inference capability in Earth2Studio before the next release. Further integrating the training loop and making sure it is up-to-standard would likely take significant engineering effort (and closer collaboration with the PNM team), as our internal training loops/dataloaders are not using PNM in any form at the moment and do not follow how existing PNM training loops are set up. Once the dataloader pieces are also added, we can follow up with the training loop.

Collapse PixelCrossAttention's input_dim/output_dim into one hidden_size (it
is a residual sub-layer that always used equal widths) and fix the output
reshape. Rename TemporalAttention embed_dim -> hidden_size and update callers.

Commit message authored by AI
@negin513

negin513 commented Jul 1, 2026

Copy link
Copy Markdown
Member

This PR adds a substantial new v2 model stack but doesn't surface any of the building blocks through the package's public API. I think it might worth exposing these similar to v1 through the public API.

What would this look like?

I was thinking we could update physicsnemo/experimental/models/healda/__init__.py to export the main user-facing v2 pieces, similar to how v1 exports HealDA.

At minimum, probably:

from .healda_v2 import HealDAv2, HealDAv2MetaData
from .obs_context import ObsContext, PixelGroupMap

Similar to HealDA-v1 stuff. But I would keep the Triton kernel modules private. The main goal is that a recipe or user can do from physicsnemo.experimental.models.healda import HealDAv2, ObsContext without importing deep module paths.

@aayushg55

Copy link
Copy Markdown
Contributor Author

This PR adds a substantial new v2 model stack but doesn't surface any of the building blocks through the package's public API. I think it might worth exposing these similar to v1 through the public API.

What would this look like?

I was thinking we could update physicsnemo/experimental/models/healda/__init__.py to export the main user-facing v2 pieces, similar to how v1 exports HealDA.

At minimum, probably:

from .healda_v2 import HealDAv2, HealDAv2MetaData
from .obs_context import ObsContext, PixelGroupMap

Similar to HealDA-v1 stuff. But I would keep the Triton kernel modules private. The main goal is that a recipe or user can do from physicsnemo.experimental.models.healda import HealDAv2, ObsContext without importing deep module paths.

Thanks, I have updated the physicsnemo/experimental/models/healda/init.py to export the v2 pieces

"UniformFusion",
"ScatterAggregator",
"scatter_mean",
]

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We definitely don't need to publicly export all of these (e.g. MetaData classes are not typically exported), and I would actually push for the majority of these to not be exported. Advanced users who want to go in and access them can use a direct import but otherwise we only need to export the HealDA architecture and maybe any components that could be useful for people working with custom obs data.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And a more major change, but one I think is worth doing, is to outright remove the healda v1 architecture and replace it with this newer/better version. We decided we are not going to support the legacy v1 model (users can still access it by installing an older version of phyiscsnemo+earth2studio), and planned for it by version-capping the physicsnemo source for healda in last earth2studio release. So we should be free to remove/replace it here.

Ultimately the "v2" is not very meaningful for most external users and I think it's more clear to have one class for the architecture. We don't need to wait a release cycle to deprecate the older one since it's all in experimental

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 @pzharrington, thanks for the context. The MetaData and HealDAv2 was my suggestion. I was mostly thinking about discoverability from the recipe/user side on useful utilities such as ObsContext, but I agree we should keep the public API tighter and not export internals unnecessarily.

Given your point that we do not plan to support the legacy v1 model, I agree that replacing the public HealDA export with the newer implementation is cleaner than exposing a parallel HealDAv2. Then we can keep the API focused on the architecture itself, plus only the obs-data pieces that recipes/custom data users actually need, e.g. ObsContext (?) if users are expected to construct those directly. @pzharrington, I 100% follow your judgment on this. + @aayushg55 maybe you can comment around this what utils from HEALDAv2 will be helpful for future to expose through API.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the cross attention module (PixelCrossAttention), tokenizer (ObsTokenizerFiLM), the obs wrapper (ObsContext), the combined architecture (HealDAv2 renamed to HealDA), and probably the VideoDiT would be things users might be interested in using.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, let's go with that then

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway, just asking for a deferral on the cleanup so that we can at least test the two implementations against one another in the same code base. I don't see why it has to be an immediate change if the concern isn't maintenance.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be clear, users of earth2studio will not experience any breaking changes. The previous release and current version on github main cap the physicsnemo dependency for healda so users can still run the v1 checkpoint for inference (that's where I assume the downloads are coming from) backed by the previous/existing architecture in PNM. We also use versioning in our HF packages so the e2studio model wrapper wont be broken when we upload newer checkpoints (and later on a user could still choose to download the older checkpoints if desired by pulling down a previous commit). When we do update the package, we'll update the model wrapper to use the newer tag and users would get seamlessly upgraded to the new version.

For context, we had a call a week or two ago with product folks and Mike/Ayush and decided this in-place upgrade approach of v1 was the move, I'm not just raising this out of nowhere 🙂 I do think there is a need for E2 to come up with a more robust model versioning practice/strategy in general though

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And also for context on this:

I think this insistence on "one DiT" may be the issue.

PNM has gotten feedback from multiple sources (external users, devtech, SAs, ...) and across domains/applications that duplicated or multiple highly similar model architecture implementations degrades the package quality and user experience. Hence the pivot towards more lower-level reusable ops/layers/blocks, and reusing larger pieces where possible -- it allows existing test coverage and feature completeness on those components to spread better across the package. I would not describe our approach as "everything in one DiT" but rather "use existing components where possible, extend and modify as needed".

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aayushg55 after some more iteration on Slack, Noah and I have arrived at the following action items:

  • You or your agents come up with a more descriptive identifier than "v2" for the HealDAv2 architecture class and related items in this PR here
  • Add an explicit deprecation warning to the v1 HealDA architecture and a TODO to remove it, but otherwise keep it in the repo for now
  • Remove all public exports of old components except the top-level architecture
  • Make sure those older components related to v1 have some reasonable documentation that they will be dropped in the future.

This should unblock progress on this particular PR while other questions on the release are ironed out

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, will go with VideoHealDA

:class:`physicsnemo.nn.module.hpx.tokenizer.HEALPixPatchTokenizer`.
2. A :class:`.video_dit.VideoDiT` backbone processes the token sequence with
spatial attention, factorized temporal attention, and adaLN-Zero
conditioning built from the EDM noise embedding and the calendar

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EDM noise embedding? This isn't a diffusion model so that's a bit confusing

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The architecture uses all of the DiT architecture conditioning components (AdaLN + noise/condition label). We trained it as a regression model (huber loss) setting noise to always be 0, but it could be trained as a diffusion model too.

@pzharrington pzharrington Jul 1, 2026

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I suppose that's true. I'm just not sure about advertising that here quite yet -- partly since no one has evaluated/experimented with the model in that realm as far as I know, but mainly because as of now making it a diffusion model would have to be a manual implementation. The forward signature expected by all physicsnemo.diffusion components is (x, t, condition: torch.Tensor | TensorDict) (see here) so it would take a bit of work to massage the the current Healdav2 into compliance, or one would have to write their own diffusion loop from scratch (which we don't necessarily want to encourage for obvious reasons)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense, I can revise the docstring

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated

aayushg55 added 12 commits July 2, 2026 14:21
…ntext fields

- Consolidate Triton backends (pixel-attention GQA + counting sort, FiLM
  tokenizer, autotune cache) into a new kernels/ subpackage. Move
  pixel_attention() dispatch and the counting-sort kernel out of
  pixel_cross_attention.py into kernels/pixel_attention.py; move the
  sort_and_pack/counts_to_cu_seqlens/build_pixel_group_map packing utilities
  into obs_context.py.
- Fix ObsContext's raw observation fields (obs, float_metadata, obs_type,
  channel, platform) to be required, matching the original healda
  UnifiedObservation contract, instead of incorrectly Optional; rename
  values -> obs to match upstream. Update HealDAv2.forward's obs param to
  obs_ctx to avoid an obs.obs collision.
- Add full named_parameters() gradient-coverage checks to the HealDAv2 tests
  (previously only spot-checked a single parameter), matching the coverage
  already used for PixelCrossAttention/ObsTokenizerFiLM.
- Simplify test ObsContext builders: drop unnecessary dataclasses.replace
  indirection (tokens can be passed directly, per ObsContext's own contract)
  and inline single-use cross-attention factories.

Commit message authored by AI
Guard HealDAv2, VideoDiT, ObsTokenizerFiLM, PixelCrossAttention,
VideoDiTBlock, and TemporalAttention forwards with torch.compiler
checks. Rename fused Triton helper to _fused_film_tokenizer_triton.

Commit message authored by AI
Add CI-tested Examples to HealDAv2 and VideoDiT (correct obs packing
order and group_map), fix healda cross-refs, export VideoDiT, store
meta_dim on ObsTokenizerFiLM, and restore the fused Triton helper doc.

Commit message authored by AI
…d DiT naming

- Rename forward's noise_labels -> t and add the missing class_labels
  parameter (threaded through to VideoDiT/DiT as condition), matching
  v1 HealDA.forward and physicsnemo.models.dit.DiT.forward naming.
- Rename HealDAv2's emb_channels -> condition_embed_dim to match v1's
  naming and stop colliding in meaning with the detokenizer's own
  condition_dim.
- Add prepare_obs_context() to build a packed ObsContext directly from
  raw per-observation arrays, and export it (replacing the internal
  PixelGroupMap export) from the healda package.
- Tidy HealDAv2/VideoDiT docstrings for conciseness and accuracy.

Commit message authored by AI
Consolidate cross_attention, temporal_attention, and pixel_cross_attention
into one module and update imports/doc refs. Tidy obs_context module docs.

Commit message authored by AI
Reorder so TemporalAttention leads, then a section separator, then the
cross-attention contract and PixelCrossAttention.

Commit message authored by AI
…ts tests

- Move VideoDiTBlock into video_dit.py and delete video_dit_block.py, per
  review feedback to reduce file count.
- Merge test_video_dit_block.py's tests into test_video_dit.py and delete
  the standalone file.
- Replace the real HEALPixPatchTokenizer/Detokenizer in test_video_dit.py
  with minimal linear-projection stubs, dropping the earth2grid test
  dependency entirely so VideoDiT's attention/temporal/cross-attention/
  conditioning wiring can be tested without it.

Commit message authored by AI
Persist the winning Triton @autotune tile config only when
HEALDA_PIXEL_ATTN_AUTOTUNE_CACHE_DIR points at a directory (unset = off, no
disk I/O). Drop the PHYSICSNEMO_CACHE_DIR default and HEALDA_PIXEL_ATTN_PREWARM
opt-out; the cache is orthogonal to Triton's compiled-kernel cache. Document the
behavior at the model/layer level so users need not read the kernel module.

Commit message authored by AI

from .healda import HealDA, HealDAMetaData
from .healda_v2 import HealDAv2, HealDAv2MetaData
from .obs_context import ObsContext, prepare_obs_context

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While making the HealDAv2 example, I realized constructing ObsContext directly required callers to manually call sort_and_pack, counts_to_cu_seqlens, per-observation tensor reordering, and build_pixel_group_map. So I instead added prepare_obs_context as the exported public helper for the required preprocessing, so callers can pass raw observation tensors and get a valid ObsContext with consistent packing metadata.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants