HealDA v2 Architecture#1758
Conversation
Bring the ragged grouped-query pixel cross-attention (pixel latents attend to per-pixel packed observation tokens) into the experimental healda package as a first step toward a video/observation DiT block. Layout follows PNM's optional-dependency conventions: - triton is referenced only via OptionalImport (never a bare import), so the modules import without triton and the import-linter external-import contract stays satisfied. - the compiled Triton kernels live in a private _pixel_attn_kernels.py backend (mirroring the warp _warp_impl pattern), imported lazily by the public pixel_cross_attention.py only when triton is available -- no per-chunk guards. - triton_autotune_cache.py: standalone autotune-config persistence util, likewise OptionalImport-based. Headers use the current PNM template; ruff + license-header checks pass. Commit message authored by AI
A 4D (b, t, x, c) analog of physicsnemo.nn.DiTBlock for field sequences: keeps the DiT template (pluggable spatial-attention backend via get_attention, adaLN-Zero conditioning, drop-path, gated MLP) and adds two optional gated sub-layers: - temporal (video) attention across time, with a pluggable time<->space reshard for context parallelism (manual all-to-all or ShardTensor.redistribute); - observation cross-attention (pixel latents attend to packed per-pixel obs tokens via the vendored PixelCrossAttention). New modules: temporal_attention.py (TemporalAttention + RoPE + causal/window mask), sharding.py (shard_x/shard_t all-to-all + ShardTensor variants), obs_packing.py (AttentionPacking/PixelGroupMap, copied as-is). Tests cover the CPU plain + temporal paths and the CUDA full spatial+obs+temporal path (forward/backward + grad flow). Commit message authored by AI
Replace the _pixel_attn_kernels() helper + _k alias with an inline `from . import _pixel_attn_kernels as kernels` inside the kernel-launch wrappers, matching PNM's lazy optional-backend pattern (cf. mesh/visualization/draw_mesh.py importing _matplotlib_impl/_pyvista_impl). Commit message authored by AI
Bundle the observation tokens and ragged packing metadata into one ObsCrossAttention object (tokens + cu_seqlens_k + max_seqlen_k + group_map) so VideoDiTBlock's obs sub-layer takes a single argument instead of separate obs_tokens + packing args. Drop the redundant data-pipeline AttentionPacking struct (its counts/npix/hpx_level/pixel_order/is_packed fields are unused by the model). Add jaxtyping Float/Int shape hints across the block, temporal attention, sharding, and obs-packing modules. Commit message authored by AI
Use the released physicsnemo.models.dit shape names: the token feature axis is 'hidden_size' (not 'channels'/'dim') and conditioning is 'condition_embed_dim', matching DiTBlock's annotations and our own hidden_size constructor arg. Commit message authored by AI
Add the ragged pixel cross-attention test suite (16 tests) validating the Triton kernel forward/backward against a readable PyTorch GQA reference (_ragged_gqa_reference), plus packed-grid, small-pixel grouping (grouped == ungrouped bit-for-bit), nn.Module wiring, empty-tokens DDP-safety, and config validation. Port build_pixel_group_map (the CSR small-pixel grouping helper, pure function of cu_seqlens_k) into obs_packing.py to support the grouping test. Commit message authored by AI
A HEALPix field-sequence diffusion transformer over (B, C, T, npix): reuses the existing HEALPixPatchTokenizer/HEALPixPatchDetokenizer (which already fold the time axis and add the calendar embedding) and EDM conditioning, reshapes the flat token sequence to (B, T, X, hidden) for the VideoDiTBlock stack (spatial + optional factorized temporal + optional observation cross-attention), then back for the detokenizer. Observations enter as a prebuilt ObsCrossAttention bundle. Tests: CPU dense+temporal and CUDA dense+temporal+obs forward/backward. Commit message authored by AI
- VideoDiT no longer hardcodes HEALPix: it takes a pluggable tokenizer / detokenizer (the grid lives only in tokenization; the backbone is grid- agnostic) plus time_length, threading tokenizer_kwargs and reshaping the flat time-major token sequence to (B, T, X, hidden) for the blocks. - VideoDiTBlock's spatial+MLP path now mirrors physicsnemo.nn.DiTBlock: renamed emb_channels -> condition_embed_dim and added the DiTBlock dropout args (attn_drop_rate, proj_drop_rate, mlp_drop_rate, final_mlp_dropout), wired into get_attention and the MLP. Its args are now a superset of DiTBlock's. Commit message authored by AI
…pose VideoDiT - VideoDiTBlock subclasses physicsnemo.nn.DiTBlock, reusing its spatial attention, gated MLP, pre-norms, 6-chunk adaLN-Zero and drop-path; adds optional gated temporal + observation cross-attention sub-layers. - is_causal and obs/temporal config moved to __init__; obs args behind obs_kwargs, temporal behind temporal_kwargs (explicit Dicts, no **kwargs). - VideoDiT inherits physicsnemo.Module and composes the conditioning embedder + pluggable tokenizer/detokenizer + blocks (no production DiT change). - sharding: @torch._dynamo.disable on the ShardTensor reshard. Commit message authored by AI
…-attention Generalize the experimental video DiT into a grid-agnostic DiT with a time axis: - add a shared ndim-agnostic AdaLayerNormZero with a zero_init toggle (using get_layer_norm) and SiLU inside the module; - compose VideoDiTBlock from shared building blocks instead of subclassing DiTBlock, moving modulation into per-sub-layer AdaLayerNormZero; - replace the obs-specific cross-attention with a generic pluggable CrossAttentionModuleBase slot + opaque context; PixelCrossAttention is the reference impl and now owns the fold/ragged-unpack; - make VideoDiT a kwarg-superset of DiT (drop_path_rates, conditioning_embedder choice, attn_kwargs, block_kwargs, dit_initialization). Commit message authored by AI
…oss-attn, adaLN naming - VideoDiT/VideoDiTBlock: cross_attention is a per-block factory (no deepcopy); forward context typed Optional[Any]; conditioning resolution inlined. - Time axis first-class: HEALPix tokenizer gains separate_time_axis; detokenizer infers flat-vs-time-first from input rank (backwards-compatible with the v1 flat + time_length path). - adaLN-Zero attrs renamed role-first: attn_norm / temporal_attn_norm / cross_attn_norm / mlp_norm. Commit message authored by AI
… Triton kernel
Port the v2 FiLM-conditioned observation tokenizer as an initial drop-in:
- obs_film_tokenizer.py: ObsTokenizerFiLM module + pure-PyTorch reference,
custom-op wrappers, and the fused_film_tokenizer_triton entry point. Forward
dispatches to the Triton kernel on CUDA when triton is available, else the
reference path.
- _film_kernels.py: private fused FiLM forward/backward Triton kernels, guarded
by OptionalImport("triton") (no bare import) and imported lazily.
- test_obs_film_tokenizer.py: CPU reference smoke test plus CUDA Triton-vs-
reference parity tests.
Settings are kept as-is for now (TODO(polish) markers on unused ones).
Commit message authored by AI
DiT.__init__ indexed input_size[1] unconditionally, which IndexErrors for non-2D tokenizers (e.g. HEALPix, where HealDA passes a 1-tuple input_size). The latent grid is only consumed by the NATTEN backends, so guard it. Commit message authored by AI
Match the public module stem (obs_film_tokenizer.py), mirroring pixel_cross_attention.py / _pixel_attn_kernels.py. Commit message authored by AI
…kenizer obs_film_tokenizer.py -> obs_tokenizer.py, _obs_film_kernels.py -> _obs_tokenizer_kernels.py (class stays ObsTokenizerFiLM; "FiLM" distinguishes the impl from the existing ObsTokenizer and leaves room to evolve). Commit message authored by AI
Compose the upstreamed VideoDiT backbone with the FiLM ObsTokenizerFiLM and per-block PixelCrossAttention into the production v2 video+obs data-assimilation architecture (hidden 1536, 16 heads, 32 layers, time_length 8, linear causal temporal attention, drop-path 0.0 for the first 4 blocks then 0.1), so the existing healda checkpoint can be loaded. HealDAv2 hosts the FiLM obs tokenizer and assembles the per-pixel ObsCrossAttention context the backbone consumes. Commit message authored by AI
…tidy wording - VideoDiT.set_context_parallel(mode, target) fans the temporal time<->space reshard config out to every block (the per-block setter had no caller, so CP was never actually enabled through the model). - VideoDiT.forward asserts the tokenizer emits 4D (B,T,X,hidden). - Drop the coined "field sequence(s)" phrasing from docstrings/comments. Commit message authored by AI
…ntract Make the base a general cross-attention sub-layer rather than one "injected into a video DiT block": forward takes (*batch, hidden_size) latents and an opaque context: Any. Drops the field-sequence wording and trims the docstring. Commit message authored by AI
…ocks=2) Replace the separate attn_norm + mlp_norm AdaLayerNormZero(n_blocks=1) pair with a single norm1 = AdaLayerNormZero(n_blocks=2) that emits the spatial-attention modulation plus the raw MLP shift/scale/gate, and a parameter-free LayerNorm MLP pre-norm modulated by those. This matches the DiT/diffusers layout and the production checkpoint's single norm1.linear. temporal_attn_norm/cross_attn_norm stay one-block adaLNs; initialize_weights now zeroes norm1. Commit message authored by AI
Rename the packed-observation container to ObsContext and make tokens optional (unset until the observation tokenizer fills it). PixelCrossAttention.forward now consumes an ObsContext, raising if tokens is unset; a None group_map keeps using the ungrouped ragged path (the model layer does not build it). Updates all references and tests. Commit message authored by AI
…orward ObsContext now also carries the raw per-observation arrays (values, float_metadata, obs_type, channel, platform) alongside the ragged packing, so HealDAv2.forward takes one obs: ObsContext instead of loose per-obs and packing args. The forward runs the tokenizer, fills tokens via dataclasses.replace, and passes the context through; it no longer builds the pixel group map (a None group_map uses the ungrouped ragged path). build_pixel_group_map stays for callers that precompute it. Commit message authored by AI
…mporal RoPE Replace the hand-rolled RotaryPositionEmbedding with PhysicsNeMo's nn.module.rope.RotaryPositionEmbedding1D (math-verified equivalent: same interleaved-pair rotation and theta^(-2k/d) schedule). The temporal q/k are transposed so the time axis is the -2 dim the module rotates, then restored. Its cos/sin are non-persistent buffers recomputed at init, so the old persistent rope.freqs_cos/sin keys no longer exist. Commit message authored by AI
…rim docstring Remove the unused nchannel, nplatform, and use_global_channel_platform_ids constructor params (the embedding tables are always sized GLOBAL_MAX_* and id spaces are the caller's responsibility). Trim the editorial rationale and conv/sat first-linear essays in the class docstring to terse facts. Commit message authored by AI
…meter-free
HealDAv2 passes attn_kwargs={"qk_norm_type":"RMSNorm","qk_norm_affine":False}.
PhysicsNeMo's TimmSelfAttention declares those exact kwargs (translating them to
timm's qk_norm + RmsNorm norm_layer), so affine-free RMSNorm actually engages and
the names are not silently swallowed. Add a regression test that the spatial q/k
norm modules exist (not Identity) and have no learnable parameters.
Commit message authored by AI
Lightweight audit cleanups: jaxtyping annotations on the two HEALPix hpx tokenizer forwards, a descriptive CalendarEmbedding ValueError message, an accurate uneven-shards comment in sharding.py, a Literal["apex","torch"] hint on AdaLayerNormZero.layernorm_backend, and removal of the dead VideoDiTBlock self.hidden_size attribute. The hpx __init__ **kwargs safety nets are kept. Commit message authored by AI
…n; tidy HealDAv2 wiring Drop the per-block hot-path validation from PixelCrossAttention.forward (tokens-set and cu_seqlens-vs-pixel-count) and instead validate the packing's structural shape once in ObsContext.__post_init__; tokens-set is guaranteed by HealDAv2.forward. Build the HEALPix (de)tokenizers as locals instead of inline in the VideoDiT call, and note the drop-path zero-first stability rationale in the docstring. Commit message authored by AI
…packing->obs_context Extract the kernel-companion packing primitives into a single pixel_attention_utils.py: sort_and_pack (inlined Triton counting-sort kernel, guarded by triton availability, argsort fallback), counts_to_cu_seqlens, and build_pixel_group_map. These operate on plain index/count tensors, so the data pipeline builds the packing the model only consumes. obs_packing.py is now purely the ObsContext + PixelGroupMap contract, renamed to obs_context.py. Also build HealDAv2's cross-attention via functools.partial and keep the cross-attention base docstring generic. Adds CPU parity tests for the packing utils. Commit message authored by AI
…Attention Add pixel_attention_reference (readable per-pixel ragged GQA) and dispatch to it in PixelCrossAttention when triton/CUDA is unavailable, so the module runs on CPU. Merge the CPU reference tests into test_pixel_cross_attention.py (dedupe the reference, gate only the Triton-kernel tests via OptionalImport, drop the module-level importorskip). Commit message authored by AI
Commit message authored by AI
Greptile SummaryThis PR introduces the HealDAv2 v2 architecture for video-based weather data assimilation. It adds a full transformer pipeline: a FiLM observation tokenizer, ragged pixel cross-attention, factorized temporal attention, adaLN-Zero modulation, and context-parallel resharding — all backed by hand-written Triton kernels with autograd
Important Files Changed
|
| cuda_graphs: bool = False | ||
| amp_cpu: bool = False | ||
| amp_gpu: bool = True | ||
| torch_fx: bool = False | ||
| bf16: bool = True | ||
| onnx: bool = False | ||
| func_torch: bool = False | ||
| auto_grad: bool = False | ||
|
|
||
|
|
||
| class HealDAv2(Module): | ||
| r""" | ||
| Video data-assimilation model combining HEALPix tokenizers, a FiLM observation | ||
| tokenizer, and a VideoDiT backbone. |
There was a problem hiding this comment.
Deprecated
name field triggers DeprecationWarning on every instantiation
ModelMetaData.__post_init__ fires a DeprecationWarning whenever name is not _DEPRECATED_SENTINEL. Because HealDAv2MetaData sets name: str = "HealDAv2", every HealDAv2(...) call emits this warning. The name field was removed from the base class and is replaced by the model's class name automatically — the field should be dropped from HealDAv2MetaData entirely.
| cuda_graphs: bool = False | |
| amp_cpu: bool = False | |
| amp_gpu: bool = True | |
| torch_fx: bool = False | |
| bf16: bool = True | |
| onnx: bool = False | |
| func_torch: bool = False | |
| auto_grad: bool = False | |
| class HealDAv2(Module): | |
| r""" | |
| Video data-assimilation model combining HEALPix tokenizers, a FiLM observation | |
| tokenizer, and a VideoDiT backbone. | |
| @dataclass | |
| class HealDAv2MetaData(ModelMetaData): | |
| """Metadata for HealDAv2 model.""" | |
| jit: bool = False |
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
| wk_slice, | ||
| wv_slice, | ||
| None, | ||
| bv_slice, | ||
| cu_seqlens_k, | ||
| prog_ptr, | ||
| prog_pix, | ||
| max_seqlen_k, | ||
| 2, | ||
| scale, | ||
| force_fp32=force_fp32, | ||
| ) | ||
| ) | ||
| return torch.cat(outs, dim=1) | ||
|
|
||
|
|
||
| def pixel_attention_reference( | ||
| Q: torch.Tensor, | ||
| tokens: torch.Tensor, | ||
| W_k: torch.Tensor, | ||
| W_v: torch.Tensor, | ||
| cu_seqlens_k: torch.Tensor, | ||
| n_kv_heads: int, | ||
| scale: float, | ||
| B_v: Optional[torch.Tensor] = None, | ||
| ) -> torch.Tensor: | ||
| r"""Pure-PyTorch ragged grouped-query attention; reference for :func:`pixel_attention`. | ||
|
|
||
| Projects each pixel's token slice to keys/values, applies softmax attention | ||
| from that pixel's query heads, and writes the per-pixel output. The key bias is | ||
| omitted (softmax cancels it); ``group_map`` does not apply (grouping is a | ||
| kernel-launch optimization that does not change the result). |
There was a problem hiding this comment.
Same
dKV tensor passed to two separate kernel pointer slots
dKV appears at two consecutive positions in the _pixel_attn_gqa_bwd call. With COMBINED_DKV=True the kernel presumably writes the combined [dK | dV] rows to only the first argument and ignores the second, but this aliased pointer pattern is not documented in this call site. If a future kernel change removes the COMBINED_DKV guard and begins writing to both pointers independently, the two writes will race on the same backing storage, silently corrupting gradients. At a minimum, add a comment here explaining why the same tensor is safe to pass twice.
|
@pzharrington Let me know what you think about the VideoDiTBlock/VideoDiT classes. I can restructure the code accordingly. In particular, would something like this work better? |
| return param.view(shape) | ||
|
|
||
|
|
||
| class AdaLayerNormZero(nn.Module): |
There was a problem hiding this comment.
The video block has four gated sub-layers (spatial attention, MLP, temporal attention, obs cross-attention). So rather than repeat DiTBlock's inline adaLN pattern — adaptive_modulation = Sequential(SiLU, Linear(cond, 6*hidden)) + separate pre_attention_norm/pre_mlp_norm + a modulation helper — four times, I felt it would be cleaner to wrap the whole adaLN-Zero operation (SiLU + Linear → shift/scale/gate, the affine-free LayerNorm, and applying them) into a single reusable AdaLayerNormZero module.
| MODEL_DIM = 1 | ||
|
|
||
|
|
||
| def shard_x( |
There was a problem hiding this comment.
Context parallelism: manual all-to-all vs ShardTensor.redistribute
Factorized temporal attention needs a time↔space reshard each block (2 per block:
enter t-sharded → spatial attn → reshard to x-sharded → temporal attn → reshard
back). We implemented it both ways: a manual all_to_all_single, and
ShardTensor.redistribute (Shard(time) ↔ Shard(space) on a 1D mesh).
Per-block — 8×H100, dit-5B (C=1536 (16×96), X=12288, T=8, bf16 autocast, 2 reshards/block:
| path | eager ms/block | compiled ms/block |
|---|---|---|
| manual all-to-all | 16.5 | 14.5 |
ShardTensor.redistribute |
20.5 | 19.3 |
Just the reshard primitives (t↔x round trip = 2 reshards, same shape, eager):
| path | fwd | fwd+bwd |
|---|---|---|
| manual all-to-all | 0.29 ms | 0.97 ms |
ShardTensor.redistribute |
2.48 ms | 5.45 ms |
There was a problem hiding this comment.
Benchmark code — reshard primitives (reshard_microbench.py)
Run on 8×H100, dit-5B shape:
T_FULL=8 X_FULL=12288 EMBED_DIM=1536 torchrun --nproc_per_node=8 reshard_microbench.py
import os
import time
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.placement_types import Shard
from physicsnemo.domain_parallel import ShardTensor
from physicsnemo.experimental.models.healda.sharding import (
shard_t,
shard_t_shardtensor,
shard_x,
shard_x_shardtensor,
)
dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
dev = torch.device("cuda", rank)
mesh = init_device_mesh("cuda", (world,), mesh_dim_names=["cp"])
group = mesh.get_group()
b = 1
t_full = int(os.environ.get("T_FULL", 16))
x_full = int(os.environ.get("X_FULL", 12288))
c = int(os.environ.get("EMBED_DIM", 1536))
t_local = t_full // world
dtype = torch.bfloat16
def manual(local):
return shard_t(shard_x(local, group), group)
def st(local):
return shard_t_shardtensor(shard_x_shardtensor(local, mesh), mesh)
def bench(fn, name, backward, iters=30, warmup=10):
local = torch.randn(
b, t_local, x_full, c, device=dev, dtype=dtype, requires_grad=backward
)
for _ in range(warmup):
out = fn(local)
if backward:
out.float().pow(2).mean().backward()
local.grad = None
torch.cuda.synchronize()
dist.barrier()
t0 = time.perf_counter()
for _ in range(iters):
out = fn(local)
if backward:
out.float().pow(2).mean().backward()
local.grad = None
torch.cuda.synchronize()
dist.barrier()
ms = (time.perf_counter() - t0) / iters * 1e3
if rank == 0:
print(f" {name:12s} {'fwd+bwd' if backward else 'fwd':8s}: {ms:8.3f} ms/iter")
def bench_stage(name, fn, iters=30, warmup=10):
local = torch.randn(b, t_local, x_full, c, device=dev, dtype=dtype)
for _ in range(warmup):
fn(local)
torch.cuda.synchronize()
dist.barrier()
t0 = time.perf_counter()
for _ in range(iters):
fn(local)
torch.cuda.synchronize()
dist.barrier()
if rank == 0:
print(f" stage {name:24s}: {(time.perf_counter() - t0) / iters * 1e3:8.3f} ms")
if rank == 0:
print(f"world={world} local=({b},{t_local},{x_full},{c}) {dtype}")
for bwd in (False, True):
bench(manual, "manual a2a", bwd)
bench(st, "shardtensor", bwd)
_st = ShardTensor.from_local(
torch.randn(b, t_local, x_full, c, device=dev, dtype=dtype), mesh, [Shard(1)]
)
bench_stage("from_local", lambda x: ShardTensor.from_local(x, mesh, [Shard(1)]))
bench_stage("redistribute", lambda x: _st.redistribute(placements=[Shard(2)]))
dist.destroy_process_group()There was a problem hiding this comment.
I suspect the ShardTensor dispatch overhead is responsible to the differences in eager mode. ShardTensor does not yet support torch compilation, but will very very soon (like a week or two). Could be interesting to look at this again when it lands. It'd be nice to have a fully ShardTensor-based implementation but if this PR is looking ready before then, perhaps we can merge it with both implementations and defer the decision (which should not change the user-facing APIs/functionality
|
Our model needs QK-normalization with
Currently, I went with c (this is also what the v1 ckpt uses), but just wanted to flag this incompatibility with the TE backend as set up now. |
| self._latent_w = self.input_size[1] // self.patch_size[1] | ||
| latent_hw = (self._latent_h, self._latent_w) | ||
| # Only NATTEN uses the latent grid; other backends may pass a non-2D input_size | ||
| if is_natten: |
There was a problem hiding this comment.
This is a patch to fix a recent DiT / HealDA v1 bug (unrelated to my PR introducing v2)
PhysicsNeMo Pull Request
Description
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.