-
Notifications
You must be signed in to change notification settings - Fork 716
HealDA v2 Architecture #1758
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
HealDA v2 Architecture #1758
Changes from all commits
d5b3144
3f5f204
abae3e0
8b73b2a
6395210
98e206b
edbb77b
393cc8d
f564363
74b61da
4214749
8895229
d820ca3
2ffe3ec
ddf94fb
f337346
6951927
2a9c785
4a17617
0301011
88bbd91
b100be0
ed91599
bfaf14f
38d3ad9
099cfdd
4e69b56
eae239b
64c8c7e
ea78849
803409b
4b6af05
1d56d82
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,139 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. | ||
| # SPDX-FileCopyrightText: All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """ndim-agnostic adaptive layer norm zero (adaLN-Zero) modulation.""" | ||
|
|
||
| from typing import Literal, Tuple | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from jaxtyping import Float | ||
|
|
||
| from physicsnemo.nn.module.dit_layers import get_layer_norm | ||
|
|
||
|
|
||
| def _broadcast(param: Float[torch.Tensor, "batch channels"], ndim: int) -> torch.Tensor: | ||
| # (B, C) -> (B, 1, ..., 1, C) so a per-sample modulation broadcasts over a | ||
| # hidden-state tensor of arbitrary rank (3D (B, L, C), 4D (B, T, X, C), ...). | ||
| shape = (param.shape[0],) + (1,) * (ndim - 2) + (param.shape[1],) | ||
| return param.view(shape) | ||
|
|
||
|
|
||
| class AdaLayerNormZero(nn.Module): | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The video block has four gated sub-layers (spatial attention, MLP, temporal attention, obs cross-attention). So rather than repeat DiTBlock's inline adaLN pattern — adaptive_modulation = Sequential(SiLU, Linear(cond, 6*hidden)) + separate pre_attention_norm/pre_mlp_norm + a modulation helper — four times, I felt it would be cleaner to wrap the whole adaLN-Zero operation (SiLU + Linear → shift/scale/gate, the affine-free LayerNorm, and applying them) into a single reusable AdaLayerNormZero module.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I like this consolidation, but I find it a bit odd to have the variadic return from and then some simple reusable helpers, either as classmethods or standalone functinos: Thoughts? |
||
| r"""Adaptive layer norm zero (adaLN-Zero) modulation, agnostic to input rank. | ||
|
|
||
| Emits ``n_blocks`` ``(shift, scale, gate)`` triples from the conditioning | ||
| embedding via ``SiLU + Linear``. The first block's ``shift``/``scale`` are | ||
| applied to the affine-free layer-normed ``x`` here; its ``gate`` is returned | ||
| for the caller's gated residual. For ``n_blocks > 1`` the remaining blocks' | ||
| ``(shift, scale, gate)`` are returned unapplied, so one projection can drive | ||
| several sub-layers (e.g. grouped attention + feed-forward, as in the standard | ||
| DiT block). Modulation vectors are broadcast to match ``x.ndim``, so the same | ||
| module serves 3D :math:`(B, L, C)` and 4D :math:`(B, T, X, C)` states. | ||
|
|
||
| The ``SiLU`` is applied inside this module, so the conditioning embedder must | ||
| emit a pre-activation embedding. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| embedding_dim : int | ||
| Channel dimension :math:`C` of the hidden states. | ||
| condition_embed_dim : int | ||
| Channel dimension of the conditioning embedding. | ||
| n_blocks : int, optional, default=1 | ||
| Number of ``(shift, scale, gate)`` triples to emit. | ||
| zero_init : bool, optional, default=True | ||
| If ``True``, zero the modulation ``Linear`` (adaLN-Zero) at construction | ||
| and in :meth:`initialize_weights`, so each residual branch starts as | ||
| identity. If ``False``, the modulation keeps its default initialization. | ||
| layernorm_backend : Literal["apex", "torch"], optional, default="torch" | ||
| Backend for the affine-free :func:`~physicsnemo.nn.module.dit_layers.get_layer_norm`. | ||
| norm_eps : float, optional, default=1e-6 | ||
| Epsilon for the layer norm. | ||
|
|
||
| Forward | ||
| ------- | ||
| x : torch.Tensor | ||
| Hidden states of shape :math:`(B, \dots, C)` (any rank :math:`\geq 2`). | ||
| c : torch.Tensor | ||
| Conditioning embedding of shape :math:`(B, D_c)`. | ||
|
|
||
| Outputs | ||
| ------- | ||
| Tuple[torch.Tensor, ...] | ||
| ``(normed, gate)`` for ``n_blocks == 1``, where ``normed`` is the | ||
| modulated layer-normed ``x`` and ``gate`` is broadcast to ``x.ndim``. For | ||
| ``n_blocks > 1``, the remaining ``(shift, scale, gate)`` of each later | ||
| block follow, each broadcast to ``x.ndim``. | ||
|
|
||
| Examples | ||
| -------- | ||
| >>> import torch | ||
| >>> from physicsnemo.experimental.models.healda.adaln import AdaLayerNormZero | ||
| >>> norm = AdaLayerNormZero(embedding_dim=64, condition_embed_dim=32) | ||
| >>> x = torch.randn(2, 5, 64) | ||
| >>> c = torch.randn(2, 32) | ||
| >>> normed, gate = norm(x, c) | ||
| >>> normed.shape, gate.shape | ||
| (torch.Size([2, 5, 64]), torch.Size([2, 1, 64])) | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| embedding_dim: int, | ||
| condition_embed_dim: int, | ||
| n_blocks: int = 1, | ||
| zero_init: bool = True, | ||
| layernorm_backend: Literal["apex", "torch"] = "torch", | ||
| norm_eps: float = 1e-6, | ||
| ): | ||
| super().__init__() | ||
| self.n_blocks = n_blocks | ||
| self.zero_init = zero_init | ||
| self.modulation = nn.Sequential( | ||
| nn.SiLU(), | ||
| nn.Linear(condition_embed_dim, 3 * n_blocks * embedding_dim, bias=True), | ||
| ) | ||
| self.norm = get_layer_norm( | ||
| embedding_dim, layernorm_backend, elementwise_affine=False, eps=norm_eps | ||
| ) | ||
| if zero_init: | ||
| self.initialize_weights() | ||
|
|
||
| def initialize_weights(self) -> None: | ||
| r"""Zero the modulation linear when ``zero_init`` is set (adaLN-Zero). | ||
|
|
||
| Returns | ||
| ------- | ||
| None | ||
| Modifies parameters in-place; a no-op when ``zero_init`` is ``False``. | ||
| """ | ||
| if self.zero_init: | ||
| nn.init.zeros_(self.modulation[-1].weight) | ||
| nn.init.zeros_(self.modulation[-1].bias) | ||
|
|
||
| def forward( | ||
| self, | ||
| x: Float[torch.Tensor, "batch ... hidden_size"], | ||
| c: Float[torch.Tensor, "batch condition_embed_dim"], | ||
| ) -> Tuple[torch.Tensor, ...]: | ||
| chunks = self.modulation(c).chunk(3 * self.n_blocks, dim=-1) | ||
| shift, scale, gate = chunks[0], chunks[1], chunks[2] | ||
| normed = self.norm(x) * (1 + _broadcast(scale, x.ndim)) + _broadcast( | ||
| shift, x.ndim | ||
| ) | ||
| outputs = [normed, _broadcast(gate, x.ndim)] | ||
| outputs.extend(_broadcast(extra, x.ndim) for extra in chunks[3:]) | ||
| return tuple(outputs) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. | ||
| # SPDX-FileCopyrightText: All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Pluggable cross-attention sub-layer contract.""" | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| from jaxtyping import Float | ||
|
|
||
| from physicsnemo.core import Module | ||
|
|
||
|
|
||
| class CrossAttentionModuleBase(Module, ABC): | ||
| r"""Abstract base for a cross-attention sub-layer. | ||
|
|
||
| A concrete module attends from ``hidden_states`` to an arbitrary external | ||
| ``context`` that the module fully owns (its type, layout, and any folding / | ||
| packing). The caller treats ``context`` as opaque. | ||
|
|
||
| Forward | ||
| ------- | ||
| hidden_states : torch.Tensor | ||
| Latents of shape :math:`(*B, C)`. | ||
| context : Any | ||
| Module-defined conditioning source, opaque to the caller. | ||
|
|
||
| Outputs | ||
| ------- | ||
| torch.Tensor | ||
| Updated latents of shape :math:`(*B, C)`. | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def forward( | ||
| self, | ||
| hidden_states: Float[torch.Tensor, "*batch hidden_size"], | ||
| context: Any, | ||
| ) -> Float[torch.Tensor, "*batch hidden_size"]: | ||
| pass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suggest trying to consolidate all triton kernels and related utils into one file, there's a lot of different files here