Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ model:
slice_num: 256
use_te: false
plus: false
include_local_features: false
include_local_features: true
radii: [0.1, 0.5, 2.0]
neighbors_in_radius: [16, 32, 64]
n_hidden_local: 32
Expand Down
151 changes: 81 additions & 70 deletions physicsnemo/experimental/models/geotransolver/context_projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,69 +50,14 @@

from physicsnemo.nn import ConcreteDropout

from .utils import structured_grid_to_conv_input, tensors_alias

# Check optional dependency availability
TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False)
if TE_AVAILABLE:
import transformer_engine.pytorch as te


def _structured_grid_to_conv_input(
    x: Float[torch.Tensor, "batch tokens channels"],
    batch: int,
    tokens: int,
    channels: int,
    ndim: int,
    spatial_shape: tuple[int, ...],
) -> Float[torch.Tensor, "batch channels ..."]:
    r"""Reshape flat token tensor to spatial layout for Conv2d/Conv3d.

    Converts :math:`(B, N, C)` to :math:`(B, C, H, W)` for 2D or
    :math:`(B, C, H, W, D)` for 3D so that structured context projectors
    can apply spatial convolutions. Validates that ``ndim`` is supported,
    that ``spatial_shape`` has ``ndim`` entries, and that :math:`N` matches
    the grid size.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape :math:`(B, N, C)` (batch, tokens, channels).
    batch : int
        Batch size :math:`B`.
    tokens : int
        Number of tokens :math:`N` (must equal :math:`H \times W` or
        :math:`H \times W \times D`).
    channels : int
        Channel dimension :math:`C`.
    ndim : int
        Number of spatial dimensions; must be 2 or 3.
    spatial_shape : tuple[int, ...]
        :math:`(H, W)` for 2D or :math:`(H, W, D)` for 3D.

    Returns
    -------
    torch.Tensor
        Reshaped tensor of shape :math:`(B, C, H, W)` or
        :math:`(B, C, H, W, D)` for use as conv input. The result is a
        permuted view of ``x``; no data is copied.

    Raises
    ------
    ValueError
        If ``ndim`` is not 2 or 3, if ``spatial_shape`` does not have
        ``ndim`` entries, or if ``tokens`` does not match the product of
        ``spatial_shape``.
    """
    # Reject unsupported ranks explicitly; otherwise every ndim != 2 would
    # silently be treated as a 3D grid.
    if ndim not in (2, 3):
        raise ValueError(f"ndim must be 2 or 3, got ndim={ndim}")
    # Name the real problem instead of failing inside tuple unpacking below.
    if len(spatial_shape) != ndim:
        raise ValueError(
            f"Expected spatial_shape of length {ndim}, "
            f"got {tuple(spatial_shape)}"
        )

    if ndim == 2:
        H, W = spatial_shape
        if tokens != H * W:
            raise ValueError(
                f"Expected N={H * W} tokens for 2D grid, got N={tokens}"
            )
        # (B, N, C) -> (B, H, W, C) -> channels-first (B, C, H, W).
        return x.view(batch, H, W, channels).permute(0, 3, 1, 2)

    H, W, D = spatial_shape
    if tokens != H * W * D:
        raise ValueError(
            f"Expected N={H * W * D} tokens for 3D grid, got N={tokens}"
        )
    # (B, N, C) -> (B, H, W, D, C) -> channels-first (B, C, H, W, D).
    return x.view(batch, H, W, D, channels).permute(0, 4, 1, 2, 3)


class _SliceToContextMixin:
r"""Internal mixin providing shared slice-to-context init and slice aggregation.

Expand Down Expand Up @@ -454,10 +399,7 @@ def _grid_project(
Float[torch.Tensor, "batch tokens heads dim"],
]
):
B, N, C = x.shape
grid = _structured_grid_to_conv_input(
x, B, N, C, self._nd, self.spatial_shape
)
grid = structured_grid_to_conv_input(x, self.spatial_shape)
pattern = (
"B (H D) h w -> B (h w) H D"
if self._nd == 2
Expand Down Expand Up @@ -767,6 +709,76 @@ def extract_local_features(
dim=-1,
)

@staticmethod
def _same_coords(
    a: Float[torch.Tensor, "batch points spatial_dim"],
    b: Float[torch.Tensor, "batch points spatial_dim"],
) -> bool:
    r"""Report whether ``a`` and ``b`` can be treated as the same coordinates.

    Delegates the decision to
    :func:`~physicsnemo.experimental.models.geotransolver.utils.tensors_alias`
    and returns its answer unchanged. :meth:`extract_context_and_local`
    uses the result to choose between its single-pass path and its
    two-pass fallback.

    Parameters
    ----------
    a, b : torch.Tensor
        The two candidate tensors to compare.

    Returns
    -------
    bool
        Whatever ``tensors_alias(a, b)`` reports for the pair.
    """
    # No extra logic here on purpose: the aliasing rule lives in one place.
    is_aliased = tensors_alias(a, b)
    return is_aliased

def extract_context_and_local(
    self,
    spatial_coords: Float[torch.Tensor, "batch points spatial_dim"],
    geometry: Float[torch.Tensor, "batch points geometry_dim"],
) -> tuple[
    list[Float[torch.Tensor, "batch heads slices dim"]],
    Float[torch.Tensor, "batch points total_hidden"],
]:
    r"""Produce tokenized context features and concatenated local features together.

    When :meth:`_same_coords` reports that ``spatial_coords`` and
    ``geometry`` alias each other, every entry of ``self.processors`` is
    called exactly once and its output is reused twice: it is passed
    through the paired entry of ``self.tokenizers`` for the context list,
    and kept as-is for the local features, which are concatenated along
    the last dimension. When the inputs do not alias, the work is
    delegated to :meth:`extract_context_features` and
    :meth:`extract_local_features`, called in that order.

    Parameters
    ----------
    spatial_coords : torch.Tensor
        Spatial coordinates, passed as the first argument to each processor.
    geometry : torch.Tensor
        Geometry features, passed as the second argument to each processor.

    Returns
    -------
    tuple[list[torch.Tensor], torch.Tensor]
        ``(context_features, local_features)``: one tokenized tensor per
        processor/tokenizer pair, and the per-scale processor outputs
        concatenated on the last dimension (or, on the fallback path, the
        results of the two dedicated methods).
    """
    if not self._same_coords(spatial_coords, geometry):
        # Distinct inputs: keep the two independent passes so each path
        # sees exactly what its dedicated method would compute.
        tokenized = self.extract_context_features(spatial_coords, geometry)
        concatenated = self.extract_local_features(spatial_coords, geometry)
        return (tokenized, concatenated)

    # Aliased inputs: run each scale's processor once and share its output.
    per_scale_tokens = []
    per_scale_feats = []
    for scale_processor, scale_tokenizer in zip(self.processors, self.tokenizers):
        scale_feat = scale_processor(spatial_coords, geometry)
        per_scale_tokens.append(scale_tokenizer(scale_feat))
        per_scale_feats.append(scale_feat)

    local = torch.cat(per_scale_feats, dim=-1)
    return per_scale_tokens, local


class GlobalContextBuilder(nn.Module):
r"""Orchestrates all context construction with a clean, simple interface.
Expand Down Expand Up @@ -1029,16 +1041,15 @@ def build_context(
for i, embedding in enumerate(local_embeddings):
spatial_coords = local_positions[i] # Extract coordinates

# Get tokenized context features from multi-scale extractor
context_feats = self.local_extractors[i].extract_context_features(
spatial_coords, geometry
)
# Get tokenized context features and concatenated local
# features in one pass. When spatial_coords and geometry alias
# the same coordinates (the common case), the extractor reuses a
# single ball query per scale instead of computing the same
# radius_search twice; otherwise it falls back to two passes.
context_feats, local_feats = self.local_extractors[
i
].extract_context_and_local(spatial_coords, geometry)
context_parts.extend(context_feats)

# Get concatenated local features for skip connection
local_feats = self.local_extractors[i].extract_local_features(
spatial_coords, geometry
)
local_features.append(local_feats)

# Tokenize geometry features
Expand Down
97 changes: 97 additions & 0 deletions physicsnemo/experimental/models/geotransolver/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Small tensor helpers shared across GeoTransolver context projectors."""

from __future__ import annotations

import torch
from jaxtyping import Float


def structured_grid_to_conv_input(
    x: Float[torch.Tensor, "batch tokens channels"],
    spatial_shape: tuple[int, ...],
) -> Float[torch.Tensor, "batch channels ..."]:
    r"""Lay out a flat token tensor as a channels-first grid for Conv2d/Conv3d.

    Takes :math:`(B, N, C)` and returns a view shaped :math:`(B, C, H, W)`
    for a 2D ``spatial_shape`` or :math:`(B, C, H, W, D)` for a 3D one, so
    that structured projectors can run spatial convolutions on it.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape :math:`(B, N, C)`.
    spatial_shape : tuple[int, ...]
        Grid extents: :math:`(H, W)` for 2D or :math:`(H, W, D)` for 3D.
        Their product must equal :math:`N`.

    Returns
    -------
    torch.Tensor
        Tensor of shape :math:`(B, C, H, W)` or :math:`(B, C, H, W, D)`.

    Raises
    ------
    ValueError
        If the product of ``spatial_shape`` differs from the token
        dimension :math:`N`, or if ``spatial_shape`` does not have length
        2 or 3. The token-count check runs first.
    """
    n_batch, n_tokens, n_channels = x.shape

    # Number of grid cells implied by the requested spatial layout.
    grid_size = 1
    for extent in spatial_shape:
        grid_size = grid_size * extent
    if n_tokens != grid_size:
        raise ValueError(
            f"Expected N={grid_size} tokens for grid {tuple(spatial_shape)}, "
            f"got N={n_tokens}"
        )

    rank = len(spatial_shape)
    if rank not in (2, 3):
        raise ValueError(
            f"spatial_shape must have length 2 or 3, got {tuple(spatial_shape)}"
        )

    # Unflatten tokens into a channels-last grid, then move the channel axis
    # to position 1: (0, 3, 1, 2) for 2D and (0, 4, 1, 2, 3) for 3D.
    channels_last = x.view(n_batch, *spatial_shape, n_channels)
    return channels_last.permute(0, rank + 1, *range(1, rank + 1))


def tensors_alias(
    a: Float[torch.Tensor, "..."],
    b: Float[torch.Tensor, "..."],
) -> bool:
    r"""Tell whether ``a`` and ``b`` are the same view of the same storage.

    This is a metadata-only test: it consults ``Tensor.is_set_to`` and the
    two dtypes and never compares element values. A ``True`` result is
    therefore a *sufficient* condition for the tensors holding identical
    data; a ``False`` result says nothing about whether the values happen
    to be equal.

    Parameters
    ----------
    a, b : torch.Tensor
        Candidate tensors to compare.

    Returns
    -------
    bool
        ``True`` only when ``a.is_set_to(b)`` holds and both tensors have
        the same dtype; ``False`` otherwise.
    """
    # Storage, offset, size and stride must all match first.
    if not a.is_set_to(b):
        return False
    # ``is_set_to`` does not look at dtype, so a reinterpreting view (e.g.
    # float32 viewed as int32) still has to be ruled out here.
    return a.dtype == b.dtype
Loading
Loading