Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d5b3144
feat(healda): vendor Triton pixel/observation cross-attention
aayushg55 Jun 24, 2026
3f5f204
feat(healda): add VideoDiTBlock with temporal + observation attention
aayushg55 Jun 24, 2026
abae3e0
refactor(healda): lazy-import pixel-attn kernels via plain local import
aayushg55 Jun 24, 2026
8b73b2a
refactor(healda): single ObsCrossAttention obs input + jaxtyping hints
aayushg55 Jun 24, 2026
6395210
style(healda): match released DiT jaxtyping vocabulary
aayushg55 Jun 24, 2026
98e206b
test(healda): port pixel cross-attention tests + build_pixel_group_map
aayushg55 Jun 24, 2026
edbb77b
feat(healda): add VideoDiT model composing the video DiT blocks
aayushg55 Jun 24, 2026
393cc8d
refactor(healda): grid-agnostic VideoDiT + DiTBlock-aligned block args
aayushg55 Jun 24, 2026
f564363
refactor(experimental/healda): VideoDiTBlock subclasses DiTBlock; com…
aayushg55 Jun 26, 2026
74b61da
refactor(experimental/healda): compose VideoDiT block + generic cross…
aayushg55 Jun 26, 2026
4214749
Merge branch 'main' of github.com:NVIDIA/physicsnemo into pnm-integra…
aayushg55 Jun 26, 2026
8895229
refactor(experimental/healda): time-first-class tokenizer, factory cr…
aayushg55 Jun 26, 2026
d820ca3
feat(experimental/healda): add FiLM obs tokenizer with optional fused…
aayushg55 Jun 26, 2026
2ffe3ec
fix(models/dit): only compute NATTEN latent grid for NATTEN backends
aayushg55 Jun 26, 2026
ddf94fb
refactor(experimental/healda): rename _film_kernels -> _obs_film_kernels
aayushg55 Jun 26, 2026
f337346
refactor(experimental/healda): rename FiLM tokenizer module to obs_to…
aayushg55 Jun 26, 2026
6951927
feat(experimental/healda): add HealDAv2 video+obs DA model
aayushg55 Jun 26, 2026
2a9c785
refactor(experimental/healda): expose VideoDiT.set_context_parallel; …
aayushg55 Jun 26, 2026
4a17617
refactor(experimental/healda): generalize CrossAttentionModuleBase co…
aayushg55 Jun 26, 2026
0301011
refactor(experimental/healda): group block adaLN into one norm1 (n_bl…
aayushg55 Jun 26, 2026
88bbd91
refactor(experimental/healda): rename ObsCrossAttention to ObsContext
aayushg55 Jun 26, 2026
b100be0
refactor(experimental/healda): take a single ObsContext in HealDAv2.f…
aayushg55 Jun 26, 2026
ed91599
refactor(experimental/healda): reuse RotaryPositionEmbedding1D for te…
aayushg55 Jun 26, 2026
bfaf14f
refactor(experimental/healda): drop unused ObsTokenizerFiLM params, t…
aayushg55 Jun 26, 2026
38d3ad9
test(experimental/healda): assert spatial qk-norm is engaged and para…
aayushg55 Jun 26, 2026
099cfdd
refactor(experimental/healda): audit typing + message fixes
aayushg55 Jun 26, 2026
4e69b56
refactor(experimental/healda): move obs-context checks to constructio…
aayushg55 Jun 26, 2026
eae239b
refactor(experimental/healda): add pixel_attention_utils; rename obs_…
aayushg55 Jun 26, 2026
64c8c7e
update pixel attn utils docstrings
aayushg55 Jun 26, 2026
ea78849
lint, various docstring cleanup
aayushg55 Jun 26, 2026
803409b
Merge branch 'main' of github.com:NVIDIA/physicsnemo into ag/healda-v…
aayushg55 Jun 26, 2026
4b6af05
feat(experimental/healda): pure-PyTorch reference path for PixelCross…
aayushg55 Jun 26, 2026
1d56d82
test(experimental/healda): rename cryptic pcak alias to kernels
aayushg55 Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
493 changes: 493 additions & 0 deletions physicsnemo/experimental/models/healda/_obs_tokenizer_kernels.py

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest trying to consolidate all triton kernels and related utils into one file, there's a lot of different files here

Large diffs are not rendered by default.

679 changes: 679 additions & 0 deletions physicsnemo/experimental/models/healda/_pixel_attn_kernels.py

Large diffs are not rendered by default.

139 changes: 139 additions & 0 deletions physicsnemo/experimental/models/healda/adaln.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ndim-agnostic adaptive layer norm zero (adaLN-Zero) modulation."""

from typing import Literal, Tuple

import torch
import torch.nn as nn
from jaxtyping import Float

from physicsnemo.nn.module.dit_layers import get_layer_norm


def _broadcast(param: Float[torch.Tensor, "batch channels"], ndim: int) -> torch.Tensor:
# (B, C) -> (B, 1, ..., 1, C) so a per-sample modulation broadcasts over a
# hidden-state tensor of arbitrary rank (3D (B, L, C), 4D (B, T, X, C), ...).
shape = (param.shape[0],) + (1,) * (ndim - 2) + (param.shape[1],)
return param.view(shape)


class AdaLayerNormZero(nn.Module):

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The video block has four gated sub-layers (spatial attention, MLP, temporal attention, obs cross-attention). So rather than repeat DiTBlock's inline adaLN pattern — adaptive_modulation = Sequential(SiLU, Linear(cond, 6*hidden)) + separate pre_attention_norm/pre_mlp_norm + a modulation helper — four times, I felt it would be cleaner to wrap the whole adaLN-Zero operation (SiLU + Linear → shift/scale/gate, the affine-free LayerNorm, and applying them) into a single reusable AdaLayerNormZero module.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this consolidation, but I find it a bit odd to have the variadic return from forward and the fact that e norm/application behavior is applied differently for n_blocks=1 vs n_blocks > 1. What about having a simpler class that just returns the shift, scale, gate for n_blocks:

class AdaLNModulation(nn.Module):
    """adaLN-Zero projection: c -> 3*n_blocks modulation tensors, in regular
    [shift, scale, gate] * n_blocks order. Norms live at the call sites."""
    def forward(self, c) -> tuple[torch.Tensor, ...]:
        return self.proj(c).chunk(3 * self.n_blocks, dim=-1)

and then some simple reusable helpers, either as classmethods or standalone functinos:

def modulate(x_normed, shift, scale):                 # x * (1 + scale) + shift
    return torch.addcmul(_broadcast(shift, x_normed.ndim),
                         x_normed, 1 + _broadcast(scale, x_normed.ndim))

def gated_residual(residual, branch_out, gate, drop_path=None):
    gate = _broadcast(gate, residual.ndim)
    if drop_path is not None:
        gate = drop_path(gate)
    return torch.addcmul(residual, gate, branch_out)

Thoughts?

r"""Adaptive layer norm zero (adaLN-Zero) modulation, agnostic to input rank.

Emits ``n_blocks`` ``(shift, scale, gate)`` triples from the conditioning
embedding via ``SiLU + Linear``. The first block's ``shift``/``scale`` are
applied to the affine-free layer-normed ``x`` here; its ``gate`` is returned
for the caller's gated residual. For ``n_blocks > 1`` the remaining blocks'
``(shift, scale, gate)`` are returned unapplied, so one projection can drive
several sub-layers (e.g. grouped attention + feed-forward, as in the standard
DiT block). Modulation vectors are broadcast to match ``x.ndim``, so the same
module serves 3D :math:`(B, L, C)` and 4D :math:`(B, T, X, C)` states.

The ``SiLU`` is applied inside this module, so the conditioning embedder must
emit a pre-activation embedding.

Parameters
----------
embedding_dim : int
Channel dimension :math:`C` of the hidden states.
condition_embed_dim : int
Channel dimension of the conditioning embedding.
n_blocks : int, optional, default=1
Number of ``(shift, scale, gate)`` triples to emit.
zero_init : bool, optional, default=True
If ``True``, zero the modulation ``Linear`` (adaLN-Zero) at construction
and in :meth:`initialize_weights`, so each residual branch starts as
identity. If ``False``, the modulation keeps its default initialization.
layernorm_backend : Literal["apex", "torch"], optional, default="torch"
Backend for the affine-free :func:`~physicsnemo.nn.module.dit_layers.get_layer_norm`.
norm_eps : float, optional, default=1e-6
Epsilon for the layer norm.

Forward
-------
x : torch.Tensor
Hidden states of shape :math:`(B, \dots, C)` (any rank :math:`\geq 2`).
c : torch.Tensor
Conditioning embedding of shape :math:`(B, D_c)`.

Outputs
-------
Tuple[torch.Tensor, ...]
``(normed, gate)`` for ``n_blocks == 1``, where ``normed`` is the
modulated layer-normed ``x`` and ``gate`` is broadcast to ``x.ndim``. For
``n_blocks > 1``, the remaining ``(shift, scale, gate)`` of each later
block follow, each broadcast to ``x.ndim``.

Examples
--------
>>> import torch
>>> from physicsnemo.experimental.models.healda.adaln import AdaLayerNormZero
>>> norm = AdaLayerNormZero(embedding_dim=64, condition_embed_dim=32)
>>> x = torch.randn(2, 5, 64)
>>> c = torch.randn(2, 32)
>>> normed, gate = norm(x, c)
>>> normed.shape, gate.shape
(torch.Size([2, 5, 64]), torch.Size([2, 1, 64]))
"""

def __init__(
self,
embedding_dim: int,
condition_embed_dim: int,
n_blocks: int = 1,
zero_init: bool = True,
layernorm_backend: Literal["apex", "torch"] = "torch",
norm_eps: float = 1e-6,
):
super().__init__()
self.n_blocks = n_blocks
self.zero_init = zero_init
self.modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(condition_embed_dim, 3 * n_blocks * embedding_dim, bias=True),
)
self.norm = get_layer_norm(
embedding_dim, layernorm_backend, elementwise_affine=False, eps=norm_eps
)
if zero_init:
self.initialize_weights()

def initialize_weights(self) -> None:
r"""Zero the modulation linear when ``zero_init`` is set (adaLN-Zero).

Returns
-------
None
Modifies parameters in-place; a no-op when ``zero_init`` is ``False``.
"""
if self.zero_init:
nn.init.zeros_(self.modulation[-1].weight)
nn.init.zeros_(self.modulation[-1].bias)

def forward(
self,
x: Float[torch.Tensor, "batch ... hidden_size"],
c: Float[torch.Tensor, "batch condition_embed_dim"],
) -> Tuple[torch.Tensor, ...]:
chunks = self.modulation(c).chunk(3 * self.n_blocks, dim=-1)
shift, scale, gate = chunks[0], chunks[1], chunks[2]
normed = self.norm(x) * (1 + _broadcast(scale, x.ndim)) + _broadcast(
shift, x.ndim
)
outputs = [normed, _broadcast(gate, x.ndim)]
outputs.extend(_broadcast(extra, x.ndim) for extra in chunks[3:])
return tuple(outputs)
53 changes: 53 additions & 0 deletions physicsnemo/experimental/models/healda/cross_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pluggable cross-attention sub-layer contract."""

from abc import ABC, abstractmethod
from typing import Any

import torch
from jaxtyping import Float

from physicsnemo.core import Module


class CrossAttentionModuleBase(Module, ABC):
r"""Abstract base for a cross-attention sub-layer.

A concrete module attends from ``hidden_states`` to an arbitrary external
``context`` that the module fully owns (its type, layout, and any folding /
packing). The caller treats ``context`` as opaque.

Forward
-------
hidden_states : torch.Tensor
Latents of shape :math:`(*B, C)`.
context : Any
Module-defined conditioning source, opaque to the caller.

Outputs
-------
torch.Tensor
Updated latents of shape :math:`(*B, C)`.
"""

@abstractmethod
def forward(
self,
hidden_states: Float[torch.Tensor, "*batch hidden_size"],
context: Any,
) -> Float[torch.Tensor, "*batch hidden_size"]:
pass
Loading
Loading