Skip to content

HealDA v2 Architecture#1758

Open
aayushg55 wants to merge 33 commits into
NVIDIA:mainfrom
aayushg55:ag/healda-v2-arch
Open

HealDA v2 Architecture#1758
aayushg55 wants to merge 33 commits into
NVIDIA:mainfrom
aayushg55:ag/healda-v2-arch

Conversation

@aayushg55

Copy link
Copy Markdown
Contributor

PhysicsNeMo Pull Request

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

aayushg55 added 30 commits June 24, 2026 15:34
Bring the ragged grouped-query pixel cross-attention (pixel latents attend
to per-pixel packed observation tokens) into the experimental healda package
as a first step toward a video/observation DiT block.

Layout follows PNM's optional-dependency conventions:
- triton is referenced only via OptionalImport (never a bare import), so the
  modules import without triton and the import-linter external-import contract
  stays satisfied.
- the compiled Triton kernels live in a private _pixel_attn_kernels.py backend
  (mirroring the warp _warp_impl pattern), imported lazily by the public
  pixel_cross_attention.py only when triton is available -- no per-chunk guards.
- triton_autotune_cache.py: standalone autotune-config persistence util,
  likewise OptionalImport-based.

Headers use the current PNM template; ruff + license-header checks pass.

Commit message authored by AI
A 4D (b, t, x, c) analog of physicsnemo.nn.DiTBlock for field sequences: keeps
the DiT template (pluggable spatial-attention backend via get_attention,
adaLN-Zero conditioning, drop-path, gated MLP) and adds two optional gated
sub-layers:
- temporal (video) attention across time, with a pluggable time<->space reshard
  for context parallelism (manual all-to-all or ShardTensor.redistribute);
- observation cross-attention (pixel latents attend to packed per-pixel obs
  tokens via the vendored PixelCrossAttention).

New modules: temporal_attention.py (TemporalAttention + RoPE + causal/window
mask), sharding.py (shard_x/shard_t all-to-all + ShardTensor variants),
obs_packing.py (AttentionPacking/PixelGroupMap, copied as-is). Tests cover the
CPU plain + temporal paths and the CUDA full spatial+obs+temporal path
(forward/backward + grad flow).

Commit message authored by AI
Replace the _pixel_attn_kernels() helper + _k alias with an inline
`from . import _pixel_attn_kernels as kernels` inside the kernel-launch
wrappers, matching PNM's lazy optional-backend pattern (cf.
mesh/visualization/draw_mesh.py importing _matplotlib_impl/_pyvista_impl).

Commit message authored by AI
Bundle the observation tokens and ragged packing metadata into one
ObsCrossAttention object (tokens + cu_seqlens_k + max_seqlen_k + group_map)
so VideoDiTBlock's obs sub-layer takes a single argument instead of separate
obs_tokens + packing args. Drop the redundant data-pipeline AttentionPacking
struct (its counts/npix/hpx_level/pixel_order/is_packed fields are unused by
the model). Add jaxtyping Float/Int shape hints across the block, temporal
attention, sharding, and obs-packing modules.

Commit message authored by AI
Use the released physicsnemo.models.dit shape names: the token feature axis is
'hidden_size' (not 'channels'/'dim') and conditioning is 'condition_embed_dim',
matching DiTBlock's annotations and our own hidden_size constructor arg.

Commit message authored by AI
Add the ragged pixel cross-attention test suite (16 tests) validating the
Triton kernel forward/backward against a readable PyTorch GQA reference
(_ragged_gqa_reference), plus packed-grid, small-pixel grouping (grouped ==
ungrouped bit-for-bit), nn.Module wiring, empty-tokens DDP-safety, and config
validation. Port build_pixel_group_map (the CSR small-pixel grouping helper,
pure function of cu_seqlens_k) into obs_packing.py to support the grouping test.

Commit message authored by AI
A HEALPix field-sequence diffusion transformer over (B, C, T, npix): reuses the
existing HEALPixPatchTokenizer/HEALPixPatchDetokenizer (which already fold the
time axis and add the calendar embedding) and EDM conditioning, reshapes the
flat token sequence to (B, T, X, hidden) for the VideoDiTBlock stack (spatial +
optional factorized temporal + optional observation cross-attention), then back
for the detokenizer. Observations enter as a prebuilt ObsCrossAttention bundle.

Tests: CPU dense+temporal and CUDA dense+temporal+obs forward/backward.

Commit message authored by AI
- VideoDiT no longer hardcodes HEALPix: it takes a pluggable tokenizer /
  detokenizer (the grid lives only in tokenization; the backbone is grid-
  agnostic) plus time_length, threading tokenizer_kwargs and reshaping the flat
  time-major token sequence to (B, T, X, hidden) for the blocks.
- VideoDiTBlock's spatial+MLP path now mirrors physicsnemo.nn.DiTBlock: renamed
  emb_channels -> condition_embed_dim and added the DiTBlock dropout args
  (attn_drop_rate, proj_drop_rate, mlp_drop_rate, final_mlp_dropout), wired into
  get_attention and the MLP. Its args are now a superset of DiTBlock's.

Commit message authored by AI
…pose VideoDiT

- VideoDiTBlock subclasses physicsnemo.nn.DiTBlock, reusing its spatial
  attention, gated MLP, pre-norms, 6-chunk adaLN-Zero and drop-path; adds
  optional gated temporal + observation cross-attention sub-layers.
- is_causal and obs/temporal config moved to __init__; obs args behind
  obs_kwargs, temporal behind temporal_kwargs (explicit Dicts, no **kwargs).
- VideoDiT inherits physicsnemo.Module and composes the conditioning embedder +
  pluggable tokenizer/detokenizer + blocks (no production DiT change).
- sharding: @torch._dynamo.disable on the ShardTensor reshard.

Commit message authored by AI
…-attention

Generalize the experimental video DiT into a grid-agnostic DiT with a time axis:
- add a shared ndim-agnostic AdaLayerNormZero with a zero_init toggle (using
  get_layer_norm) and SiLU inside the module;
- compose VideoDiTBlock from shared building blocks instead of subclassing
  DiTBlock, moving modulation into per-sub-layer AdaLayerNormZero;
- replace the obs-specific cross-attention with a generic pluggable
  CrossAttentionModuleBase slot + opaque context; PixelCrossAttention is the
  reference impl and now owns the fold/ragged-unpack;
- make VideoDiT a kwarg-superset of DiT (drop_path_rates, conditioning_embedder
  choice, attn_kwargs, block_kwargs, dit_initialization).

Commit message authored by AI
…oss-attn, adaLN naming

- VideoDiT/VideoDiTBlock: cross_attention is a per-block factory (no deepcopy);
  forward context typed Optional[Any]; conditioning resolution inlined.
- Time axis first-class: HEALPix tokenizer gains separate_time_axis; detokenizer
  infers flat-vs-time-first from input rank (backwards-compatible with the v1
  flat + time_length path).
- adaLN-Zero attrs renamed role-first: attn_norm / temporal_attn_norm /
  cross_attn_norm / mlp_norm.

Commit message authored by AI
… Triton kernel

Port the v2 FiLM-conditioned observation tokenizer as an initial drop-in:
- obs_film_tokenizer.py: ObsTokenizerFiLM module + pure-PyTorch reference,
  custom-op wrappers, and the fused_film_tokenizer_triton entry point. Forward
  dispatches to the Triton kernel on CUDA when triton is available, else the
  reference path.
- _film_kernels.py: private fused FiLM forward/backward Triton kernels, guarded
  by OptionalImport("triton") (no bare import) and imported lazily.
- test_obs_film_tokenizer.py: CPU reference smoke test plus CUDA Triton-vs-
  reference parity tests.

Settings are kept as-is for now (TODO(polish) markers on unused ones).

Commit message authored by AI
DiT.__init__ indexed input_size[1] unconditionally, which IndexErrors for
non-2D tokenizers (e.g. HEALPix, where HealDA passes a 1-tuple input_size).
The latent grid is only consumed by the NATTEN backends, so guard it.

Commit message authored by AI
Match the public module stem (obs_film_tokenizer.py), mirroring
pixel_cross_attention.py / _pixel_attn_kernels.py.

Commit message authored by AI
…kenizer

obs_film_tokenizer.py -> obs_tokenizer.py, _obs_film_kernels.py ->
_obs_tokenizer_kernels.py (class stays ObsTokenizerFiLM; "FiLM" distinguishes
the impl from the existing ObsTokenizer and leaves room to evolve).

Commit message authored by AI
Compose the upstreamed VideoDiT backbone with the FiLM ObsTokenizerFiLM and
per-block PixelCrossAttention into the production v2 video+obs data-assimilation
architecture (hidden 1536, 16 heads, 32 layers, time_length 8, linear causal
temporal attention, drop-path 0.0 for the first 4 blocks then 0.1), so the
existing healda checkpoint can be loaded. HealDAv2 hosts the FiLM obs tokenizer
and assembles the per-pixel ObsCrossAttention context the backbone consumes.

Commit message authored by AI
…tidy wording

- VideoDiT.set_context_parallel(mode, target) fans the temporal time<->space
  reshard config out to every block (the per-block setter had no caller, so CP
  was never actually enabled through the model).
- VideoDiT.forward asserts the tokenizer emits 4D (B,T,X,hidden).
- Drop the coined "field sequence(s)" phrasing from docstrings/comments.

Commit message authored by AI
…ntract

Make the base a general cross-attention sub-layer rather than one "injected
into a video DiT block": forward takes (*batch, hidden_size) latents and an
opaque context: Any. Drops the field-sequence wording and trims the docstring.

Commit message authored by AI
…ocks=2)

Replace the separate attn_norm + mlp_norm AdaLayerNormZero(n_blocks=1) pair with
a single norm1 = AdaLayerNormZero(n_blocks=2) that emits the spatial-attention
modulation plus the raw MLP shift/scale/gate, and a parameter-free LayerNorm
MLP pre-norm modulated by those. This matches the DiT/diffusers layout and the
production checkpoint's single norm1.linear. temporal_attn_norm/cross_attn_norm
stay one-block adaLNs; initialize_weights now zeroes norm1.

Commit message authored by AI
Rename the packed-observation container to ObsContext and make tokens optional
(unset until the observation tokenizer fills it). PixelCrossAttention.forward
now consumes an ObsContext, raising if tokens is unset; a None group_map keeps
using the ungrouped ragged path (the model layer does not build it). Updates
all references and tests.

Commit message authored by AI
…orward

ObsContext now also carries the raw per-observation arrays (values,
float_metadata, obs_type, channel, platform) alongside the ragged packing, so
HealDAv2.forward takes one obs: ObsContext instead of loose per-obs and packing
args. The forward runs the tokenizer, fills tokens via dataclasses.replace, and
passes the context through; it no longer builds the pixel group map (a None
group_map uses the ungrouped ragged path). build_pixel_group_map stays for
callers that precompute it.

Commit message authored by AI
…mporal RoPE

Replace the hand-rolled RotaryPositionEmbedding with PhysicsNeMo's
nn.module.rope.RotaryPositionEmbedding1D (math-verified equivalent: same
interleaved-pair rotation and theta^(-2k/d) schedule). The temporal q/k are
transposed so the time axis is the -2 dim the module rotates, then restored.
Its cos/sin are non-persistent buffers recomputed at init, so the old
persistent rope.freqs_cos/sin keys no longer exist.

Commit message authored by AI
…rim docstring

Remove the unused nchannel, nplatform, and use_global_channel_platform_ids
constructor params (the embedding tables are always sized GLOBAL_MAX_* and id
spaces are the caller's responsibility). Trim the editorial rationale and
conv/sat first-linear essays in the class docstring to terse facts.

Commit message authored by AI
…meter-free

HealDAv2 passes attn_kwargs={"qk_norm_type":"RMSNorm","qk_norm_affine":False}.
PhysicsNeMo's TimmSelfAttention declares those exact kwargs (translating them to
timm's qk_norm + RmsNorm norm_layer), so affine-free RMSNorm actually engages and
the names are not silently swallowed. Add a regression test that the spatial q/k
norm modules exist (not Identity) and have no learnable parameters.

Commit message authored by AI
Lightweight audit cleanups: jaxtyping annotations on the two HEALPix hpx
tokenizer forwards, a descriptive CalendarEmbedding ValueError message, an
accurate uneven-shards comment in sharding.py, a Literal["apex","torch"] hint on
AdaLayerNormZero.layernorm_backend, and removal of the dead VideoDiTBlock
self.hidden_size attribute. The hpx __init__ **kwargs safety nets are kept.

Commit message authored by AI
…n; tidy HealDAv2 wiring

Drop the per-block hot-path validation from PixelCrossAttention.forward (tokens-set
and cu_seqlens-vs-pixel-count) and instead validate the packing's structural shape
once in ObsContext.__post_init__; tokens-set is guaranteed by HealDAv2.forward.
Build the HEALPix (de)tokenizers as locals instead of inline in the VideoDiT call,
and note the drop-path zero-first stability rationale in the docstring.

Commit message authored by AI
…packing->obs_context

Extract the kernel-companion packing primitives into a single
pixel_attention_utils.py: sort_and_pack (inlined Triton counting-sort kernel,
guarded by triton availability, argsort fallback), counts_to_cu_seqlens, and
build_pixel_group_map. These operate on plain index/count tensors, so the data
pipeline builds the packing the model only consumes. obs_packing.py is now purely
the ObsContext + PixelGroupMap contract, renamed to obs_context.py. Also build
HealDAv2's cross-attention via functools.partial and keep the cross-attention
base docstring generic. Adds CPU parity tests for the packing utils.

Commit message authored by AI
@copy-pr-bot

copy-pr-bot Bot commented Jun 26, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

…Attention

Add pixel_attention_reference (readable per-pixel ragged GQA) and dispatch to it
in PixelCrossAttention when triton/CUDA is unavailable, so the module runs on CPU.
Merge the CPU reference tests into test_pixel_cross_attention.py (dedupe the
reference, gate only the Triton-kernel tests via OptionalImport, drop the
module-level importorskip).

Commit message authored by AI
@aayushg55 aayushg55 marked this pull request as ready for review June 26, 2026 17:06
@greptile-apps

greptile-apps Bot commented Jun 26, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR introduces the HealDAv2 v2 architecture for video-based weather data assimilation. It adds a full transformer pipeline: a FiLM observation tokenizer, ragged pixel cross-attention, factorized temporal attention, adaLN-Zero modulation, and context-parallel resharding — all backed by hand-written Triton kernels with autograd custom_op wrappers. Two existing files are updated: HEALPixPatchTokenizer gains a separate_time_axis flag and DiT guards NATTEN latent-grid computation.

  • New model stack (healda_v2.pyvideo_dit.pyvideo_dit_block.py): HealDAv2 composes a FiLM tokenizer, a per-block PixelCrossAttention, optional temporal attention, and context-parallel resharding inside VideoDiTBlock; the backbone uses adaLN-Zero with zero-initialization for training stability.
  • Triton kernels (_obs_tokenizer_kernels.py, _pixel_attn_kernels.py): fused forward/backward kernels for the FiLM tokenizer and ragged GQA pixel cross-attention, registered as torch.library.custom_op with fake-tensor implementations for torch.compile compatibility.
  • Autotune cache (triton_autotune_cache.py, pixel_cross_attention.py): per-GPU/rank JSON cache that persists Triton autotune configs across runs; the cache-invalidation hash needs to cover _pixel_attn_kernels.py (not just pixel_cross_attention.py) to catch kernel edits.

Important Files Changed

Filename Overview
physicsnemo/experimental/models/healda/healda_v2.py Top-level HealDAv2 model; HealDAv2MetaData sets the deprecated name field, triggering a DeprecationWarning on every instantiation.
physicsnemo/experimental/models/healda/pixel_cross_attention.py Triton ragged GQA attention + autotune cache; autotune hash only covers this file, not _pixel_attn_kernels.py, so kernel changes silently reuse stale configs; dKV is passed to two kernel argument slots without in-code documentation.
physicsnemo/experimental/models/healda/video_dit.py VideoDiT backbone; MetaData name is too generic and a blank line before class VideoDiT is missing per PEP 8.
physicsnemo/experimental/models/healda/obs_tokenizer.py FiLM observation tokenizer with Triton fused kernel + autograd wrappers; logic is sound, custom-op/fake registration follows correct patterns.
physicsnemo/experimental/models/healda/pixel_attention_utils.py Ragged packing utilities (counting sort, cu_seqlens, group map); int32 position cast in counting_sort_and_pack silently overflows for N > ~2 billion observations.
physicsnemo/experimental/models/healda/video_dit_block.py DiT block with spatial/temporal/cross-attention; adaLN-Zero wiring, drop-path, and context-parallel resharding look correct.
physicsnemo/experimental/models/healda/obs_context.py ObsContext/PixelGroupMap dataclasses with device-transfer helpers; straightforward and correct.
physicsnemo/experimental/models/healda/adaln.py adaLN-Zero modulation, ndim-agnostic broadcast; logic is correct for both n_blocks=1 and n_blocks=2 cases.
physicsnemo/nn/module/hpx/tokenizer.py HEALPixPatchTokenizer gains separate_time_axis flag; detokenizer handles both 3D and 4D inputs; CalendarEmbedding error message improved. Changes are clean.
physicsnemo/models/dit/dit.py DiT guards latent-grid computation behind is_natten; safe and minimal change.
physicsnemo/experimental/models/healda/temporal_attention.py Temporal self-attention with RoPE, linear/softmax variants, and causal windowing; @torch.compile on forward looks correct.
physicsnemo/experimental/models/healda/sharding.py Context-parallel t<->x resharding via all_to_all and ShardTensor; both paths are autograd-aware and logically consistent.

Comments Outside Diff (3)

  1. physicsnemo/experimental/models/healda/pixel_cross_attention.py, line 1465-1481 (link)

    P1 Cache hash reads pixel_cross_attention.py, not the actual Triton kernel file

    The cache-busting digest is computed over __file__ (i.e., pixel_cross_attention.py), but the kernels that actually get tuned live in _pixel_attn_kernels.py. A change to those kernels without touching this file will silently reuse stale autotune configs that were optimized for the old kernel, potentially applying wrong launch parameters to new code.

  2. physicsnemo/experimental/models/healda/video_dit.py, line 1349-1365 (link)

    P2 Generic class name MetaData risks import collisions and lacks a blank line

    MetaData is a very common identifier — any from physicsnemo.experimental.models.healda.video_dit import MetaData (or star import) will shadow other MetaData symbols in the same namespace. A more specific name like VideoDiTMetaData avoids this. Also, PEP 8 requires two blank lines before a top-level class definition, but there is only one blank line separating MetaData from class VideoDiT.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

  3. physicsnemo/experimental/models/healda/pixel_attention_utils.py, line 273-307 (link)

    P2 counting_sort_and_pack: atomic-claimed position cast to int32 can overflow for very large observation counts

    The Triton kernel computes pos = tl.atomic_add(bucket_offsets_ptr + key, 1, ...) and then writes sorted_order_ptr + pos.to(tl.int32). If the total number of observations N exceeds ~2.1 billion, the int32 cast truncates the position, writing two source indices to the same slot and silently corrupting the permutation. The sorted_order output tensor is also torch.int32. For typical atmospheric observation batch sizes this is unlikely, but it is an undocumented limit worth noting in the docstring.

Reviews (1): Last reviewed commit: "test(experimental/healda): rename crypti..." | Re-trigger Greptile

Comment on lines +47 to +60
cuda_graphs: bool = False
amp_cpu: bool = False
amp_gpu: bool = True
torch_fx: bool = False
bf16: bool = True
onnx: bool = False
func_torch: bool = False
auto_grad: bool = False


class HealDAv2(Module):
r"""
Video data-assimilation model combining HEALPix tokenizers, a FiLM observation
tokenizer, and a VideoDiT backbone.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Deprecated name field triggers DeprecationWarning on every instantiation

ModelMetaData.__post_init__ fires a DeprecationWarning whenever name is not _DEPRECATED_SENTINEL. Because HealDAv2MetaData sets name: str = "HealDAv2", every HealDAv2(...) call emits this warning. The name field was removed from the base class and is replaced by the model's class name automatically — the field should be dropped from HealDAv2MetaData entirely.

Suggested change
cuda_graphs: bool = False
amp_cpu: bool = False
amp_gpu: bool = True
torch_fx: bool = False
bf16: bool = True
onnx: bool = False
func_torch: bool = False
auto_grad: bool = False
class HealDAv2(Module):
r"""
Video data-assimilation model combining HEALPix tokenizers, a FiLM observation
tokenizer, and a VideoDiT backbone.
@dataclass
class HealDAv2MetaData(ModelMetaData):
"""Metadata for HealDAv2 model."""
jit: bool = False

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Comment on lines +750 to +781
wk_slice,
wv_slice,
None,
bv_slice,
cu_seqlens_k,
prog_ptr,
prog_pix,
max_seqlen_k,
2,
scale,
force_fp32=force_fp32,
)
)
return torch.cat(outs, dim=1)


def pixel_attention_reference(
Q: torch.Tensor,
tokens: torch.Tensor,
W_k: torch.Tensor,
W_v: torch.Tensor,
cu_seqlens_k: torch.Tensor,
n_kv_heads: int,
scale: float,
B_v: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""Pure-PyTorch ragged grouped-query attention; reference for :func:`pixel_attention`.

Projects each pixel's token slice to keys/values, applies softmax attention
from that pixel's query heads, and writes the per-pixel output. The key bias is
omitted (softmax cancels it); ``group_map`` does not apply (grouping is a
kernel-launch optimization that does not change the result).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Same dKV tensor passed to two separate kernel pointer slots

dKV appears at two consecutive positions in the _pixel_attn_gqa_bwd call. With COMBINED_DKV=True the kernel presumably writes the combined [dK | dV] rows to only the first argument and ignores the second, but this aliased pointer pattern is not documented in this call site. If a future kernel change removes the COMBINED_DKV guard and begins writing to both pointers independently, the two writes will race on the same backing storage, silently corrupting gradients. At a minimum, add a comment here explaining why the same tensor is safe to pass twice.

@aayushg55

Copy link
Copy Markdown
Contributor Author

@pzharrington Let me know what you think about the VideoDiTBlock/VideoDiT classes. I can restructure the code accordingly.

In particular, would something like this work better?

class VideoDiTBlock(DiTBlock):
    def __init__(self, ..., temporal_attention=False, cross_attention=None, is_causal=False, ...):
        super().__init__(hidden_size, num_heads, attention_backend=..., condition_embed_dim=..., mlp_ratio=..., drop_path=..., **attn_kwargs)
        self._is_causal = is_causal
        if temporal_attention: self.temporal_attention = ...; self.temporal_attn_norm = ...
        if cross_attention:    self.cross_attention = cross_attention(); self.cross_attn_norm = ...
    def forward(self, hidden_states, c, cross_attention_context=None, attn_kwargs=None):
        # reuse self.adaptive_modulation(c).chunk(6), self.pre_attention_norm/self.pre_mlp_norm,
        # self.attention, self.linear, self.drop_path; insert obs + temporal sub-layers between.

return param.view(shape)


class AdaLayerNormZero(nn.Module):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The video block has four gated sub-layers (spatial attention, MLP, temporal attention, obs cross-attention). So rather than repeat DiTBlock's inline adaLN pattern — adaptive_modulation = Sequential(SiLU, Linear(cond, 6*hidden)) + separate pre_attention_norm/pre_mlp_norm + a modulation helper — four times, I felt it would be cleaner to wrap the whole adaLN-Zero operation (SiLU + Linear → shift/scale/gate, the affine-free LayerNorm, and applying them) into a single reusable AdaLayerNormZero module.

MODEL_DIM = 1


def shard_x(

@aayushg55 aayushg55 Jun 26, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Context parallelism: manual all-to-all vs ShardTensor.redistribute

Factorized temporal attention needs a time↔space reshard each block (2 per block:
enter t-sharded → spatial attn → reshard to x-sharded → temporal attn → reshard
back). We implemented it both ways: a manual all_to_all_single, and
ShardTensor.redistribute (Shard(time)Shard(space) on a 1D mesh).

Per-block — 8×H100, dit-5B (C=1536 (16×96), X=12288, T=8, bf16 autocast, 2 reshards/block:

path eager ms/block compiled ms/block
manual all-to-all 16.5 14.5
ShardTensor.redistribute 20.5 19.3

Just the reshard primitives (t↔x round trip = 2 reshards, same shape, eager):

path fwd fwd+bwd
manual all-to-all 0.29 ms 0.97 ms
ShardTensor.redistribute 2.48 ms 5.45 ms

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark code — reshard primitives (reshard_microbench.py)

Run on 8×H100, dit-5B shape:
T_FULL=8 X_FULL=12288 EMBED_DIM=1536 torchrun --nproc_per_node=8 reshard_microbench.py

import os
import time

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.placement_types import Shard

from physicsnemo.domain_parallel import ShardTensor
from physicsnemo.experimental.models.healda.sharding import (
    shard_t,
    shard_t_shardtensor,
    shard_x,
    shard_x_shardtensor,
)

dist.init_process_group("nccl")
rank, world = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)
dev = torch.device("cuda", rank)
mesh = init_device_mesh("cuda", (world,), mesh_dim_names=["cp"])
group = mesh.get_group()

b = 1
t_full = int(os.environ.get("T_FULL", 16))
x_full = int(os.environ.get("X_FULL", 12288))
c = int(os.environ.get("EMBED_DIM", 1536))
t_local = t_full // world
dtype = torch.bfloat16


def manual(local):
    return shard_t(shard_x(local, group), group)


def st(local):
    return shard_t_shardtensor(shard_x_shardtensor(local, mesh), mesh)


def bench(fn, name, backward, iters=30, warmup=10):
    local = torch.randn(
        b, t_local, x_full, c, device=dev, dtype=dtype, requires_grad=backward
    )
    for _ in range(warmup):
        out = fn(local)
        if backward:
            out.float().pow(2).mean().backward()
            local.grad = None
    torch.cuda.synchronize()
    dist.barrier()
    t0 = time.perf_counter()
    for _ in range(iters):
        out = fn(local)
        if backward:
            out.float().pow(2).mean().backward()
            local.grad = None
    torch.cuda.synchronize()
    dist.barrier()
    ms = (time.perf_counter() - t0) / iters * 1e3
    if rank == 0:
        print(f"  {name:12s} {'fwd+bwd' if backward else 'fwd':8s}: {ms:8.3f} ms/iter")


def bench_stage(name, fn, iters=30, warmup=10):
    local = torch.randn(b, t_local, x_full, c, device=dev, dtype=dtype)
    for _ in range(warmup):
        fn(local)
    torch.cuda.synchronize()
    dist.barrier()
    t0 = time.perf_counter()
    for _ in range(iters):
        fn(local)
    torch.cuda.synchronize()
    dist.barrier()
    if rank == 0:
        print(f"    stage {name:24s}: {(time.perf_counter() - t0) / iters * 1e3:8.3f} ms")


if rank == 0:
    print(f"world={world} local=({b},{t_local},{x_full},{c}) {dtype}")
for bwd in (False, True):
    bench(manual, "manual a2a", bwd)
    bench(st, "shardtensor", bwd)

_st = ShardTensor.from_local(
    torch.randn(b, t_local, x_full, c, device=dev, dtype=dtype), mesh, [Shard(1)]
)
bench_stage("from_local", lambda x: ShardTensor.from_local(x, mesh, [Shard(1)]))
bench_stage("redistribute", lambda x: _st.redistribute(placements=[Shard(2)]))
dist.destroy_process_group()

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@negin513 for viz

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect the ShardTensor dispatch overhead is responsible to the differences in eager mode. ShardTensor does not yet support torch compilation, but will very very soon (like a week or two). Could be interesting to look at this again when it lands. It'd be nice to have a fully ShardTensor-based implementation but if this PR is looking ready before then, perhaps we can merge it with both implementations and defer the decision (which should not change the user-facing APIs/functionality

@aayushg55

aayushg55 commented Jun 26, 2026

Copy link
Copy Markdown
Contributor Author

Our model needs QK-normalization with affine=False in the spatial attention, but TE's MultiheadAttention only supports affine=True. Options I see:

  • (a) TE: Keep the affine params but exclude them from the optimizer.
  • (b) Build a small attention module: affine-free QK-norm + TE DotProductAttention.
  • (c) Use the existing timm backend, which already supports affine=False.

Currently, I went with c (this is also what the v1 ckpt uses), but just wanted to flag this incompatibility with the TE backend as set up now.

self._latent_w = self.input_size[1] // self.patch_size[1]
latent_hw = (self._latent_h, self._latent_w)
# Only NATTEN uses the latent grid; other backends may pass a non-2D input_size
if is_natten:

@aayushg55 aayushg55 Jun 26, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a patch to fix a recent DiT / HealDA v1 bug (unrelated to my PR introducing v2)

Comment thread physicsnemo/nn/module/hpx/tokenizer.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants