diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/model/geotransolver_surface.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/model/geotransolver_surface.yaml index c61c2306e9..1e98b43fa9 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/model/geotransolver_surface.yaml +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/model/geotransolver_surface.yaml @@ -55,7 +55,7 @@ model: slice_num: 256 use_te: false plus: false - include_local_features: false + include_local_features: true radii: [0.1, 0.5, 2.0] neighbors_in_radius: [16, 32, 64] n_hidden_local: 32 diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index 15f7d67ffc..41e3cfbc29 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -50,69 +50,14 @@ from physicsnemo.nn import ConcreteDropout +from .utils import structured_grid_to_conv_input, tensors_alias + # Check optional dependency availability TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) if TE_AVAILABLE: import transformer_engine.pytorch as te -def _structured_grid_to_conv_input( - x: Float[torch.Tensor, "batch tokens channels"], - batch: int, - tokens: int, - channels: int, - ndim: int, - spatial_shape: tuple[int, ...], -) -> Float[torch.Tensor, "batch channels ..."]: - r"""Reshape flat token tensor to spatial layout for Conv2d/Conv3d. - - Converts :math:`(B, N, C)` to :math:`(B, C, H, W)` for 2D or - :math:`(B, C, H, W, D)` for 3D so that structured context projectors - can apply spatial convolutions. Validates that :math:`N` matches the - grid size. - - Parameters - ---------- - x : torch.Tensor - Input tensor of shape :math:`(B, N, C)` (batch, tokens, channels). - batch : int - Batch size :math:`B`. 
@staticmethod
def _same_coords(
    a: Float[torch.Tensor, "batch points spatial_dim"],
    b: Float[torch.Tensor, "batch points spatial_dim"],
) -> bool:
    r"""Report whether two coordinate tensors are aliases of each other.

    Delegates to
    :func:`~physicsnemo.experimental.models.geotransolver.utils.tensors_alias`.
    :meth:`extract_context_and_local` uses the answer to choose between its
    single-pass path and its two-pass fallback.

    Parameters
    ----------
    a, b : torch.Tensor
        Tensors to compare.

    Returns
    -------
    bool
        Whatever ``tensors_alias(a, b)`` returns.
    """
    return tensors_alias(a, b)


def extract_context_and_local(
    self,
    spatial_coords: Float[torch.Tensor, "batch points spatial_dim"],
    geometry: Float[torch.Tensor, "batch points geometry_dim"],
) -> tuple[
    list[Float[torch.Tensor, "batch heads slices dim"]],
    Float[torch.Tensor, "batch points total_hidden"],
]:
    r"""Return tokenized context features and concatenated local features.

    When :meth:`_same_coords` rejects the two inputs, the result is simply
    ``(extract_context_features(...), extract_local_features(...))``,
    evaluated in that order. When it accepts them, every entry of
    ``self.processors`` is called exactly once and its output is used twice:
    passed through the matching entry of ``self.tokenizers`` for the context
    list, and kept as-is for the local-feature concatenation.

    Parameters
    ----------
    spatial_coords : torch.Tensor
        First argument forwarded to each processor / extraction method.
    geometry : torch.Tensor
        Second argument forwarded to each processor / extraction method.

    Returns
    -------
    tuple[list[torch.Tensor], torch.Tensor]
        ``(context_features, local_features)``: one tokenized tensor per
        scale, and the per-scale processor outputs concatenated along the
        last dimension.
    """
    if not self._same_coords(spatial_coords, geometry):
        # Inputs are not aliases: keep the two independent passes, context
        # first and local second.
        context_tokens = self.extract_context_features(spatial_coords, geometry)
        skip_features = self.extract_local_features(spatial_coords, geometry)
        return context_tokens, skip_features

    # Inputs are aliases: run each scale's processor once and share its output.
    tokens_per_scale = []
    feats_per_scale = []
    for scale_processor, scale_tokenizer in zip(self.processors, self.tokenizers):
        scale_feat = scale_processor(spatial_coords, geometry)
        tokens_per_scale.append(scale_tokenizer(scale_feat))
        feats_per_scale.append(scale_feat)
    return tokens_per_scale, torch.cat(feats_per_scale, dim=-1)
def structured_grid_to_conv_input(
    x: Float[torch.Tensor, "batch tokens channels"],
    spatial_shape: tuple[int, ...],
) -> Float[torch.Tensor, "batch channels ..."]:
    r"""Reshape a flat token tensor to spatial layout for Conv2d/Conv3d.

    Converts :math:`(B, N, C)` to :math:`(B, C, H, W)` (2D) or
    :math:`(B, C, H, W, D)` (3D) by splitting the token axis into the grid
    axes and moving the channel axis to position 1.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape :math:`(B, N, C)`.
    spatial_shape : tuple[int, ...]
        :math:`(H, W)` for 2D or :math:`(H, W, D)` for 3D. The product must
        equal :math:`N`.

    Returns
    -------
    torch.Tensor
        Tensor of shape :math:`(B, C, H, W)` or :math:`(B, C, H, W, D)`.

    Raises
    ------
    ValueError
        If ``spatial_shape`` is not length 2 or 3, contains a negative
        extent, ``x`` is not 3-dimensional, or the product of
        ``spatial_shape`` does not match the token dimension :math:`N`.
    """
    grid = tuple(spatial_shape)
    ndim = len(grid)

    # Check the rank first: otherwise a wrong-length ``spatial_shape`` is
    # reported as a (misleading) token-count mismatch.
    if ndim not in (2, 3):
        raise ValueError(f"spatial_shape must have length 2 or 3, got {grid}")
    # Two negative extents can multiply to the right token count and would
    # only fail later inside ``view`` with a RuntimeError.
    if any(extent < 0 for extent in grid):
        raise ValueError(f"spatial_shape entries must be non-negative, got {grid}")
    if x.dim() != 3:
        raise ValueError(f"Expected x of shape (B, N, C), got shape {tuple(x.shape)}")

    batch, tokens, channels = x.shape
    expected = 1
    for extent in grid:
        expected *= extent
    if tokens != expected:
        raise ValueError(
            f"Expected N={expected} tokens for grid {grid}, got N={tokens}"
        )

    # (B, *grid, C) -> (B, C, *grid): (0, 3, 1, 2) in 2D, (0, 4, 1, 2, 3) in 3D.
    channels_first = (0, ndim + 1, *range(1, ndim + 1))
    return x.view(batch, *grid, channels).permute(*channels_first)


def tensors_alias(
    a: Float[torch.Tensor, "..."],
    b: Float[torch.Tensor, "..."],
) -> bool:
    r"""Return ``True`` when ``a`` and ``b`` are guaranteed to hold identical data.

    This is a *sufficient* test only: it never compares values, so two
    tensors with equal contents in different storage yield ``False``.

    Parameters
    ----------
    a, b : torch.Tensor
        Candidate tensors to compare.

    Returns
    -------
    bool
        ``True`` if ``a`` and ``b`` are the same object, or if they have the
        same dtype and ``a.is_set_to(b)`` holds.
    """
    # The same Python object trivially holds the same data.
    if a is b:
        return True
    # ``is_set_to`` covers storage, offset, size, and stride but ignores dtype,
    # so the dtype is compared separately.
    return a.dtype == b.dtype and a.is_set_to(b)
def _make_extractor(device):
    """Build a small MultiScaleFeatureExtractor for the consolidation tests.

    The extractor has two scales, is moved to ``device``, and is switched to
    eval mode before being returned.
    """
    extractor = MultiScaleFeatureExtractor(
        geometry_dim=3,
        radii=[0.5, 1.0],
        neighbors_in_radius=[4, 8],
        hidden_dim=16,
        n_head=4,
        dim_head=8,
        dropout=0.0,
        slice_num=8,
        use_te=False,
        plus=False,
    ).to(device)
    extractor.eval()
    return extractor


def test_same_coords_guard(device):
    """``_same_coords`` detects aliasing views but not equal-valued copies."""
    base = torch.randn(64, 3, device=device)

    # Identical object.
    assert MultiScaleFeatureExtractor._same_coords(base, base)

    # Distinct view objects over the same storage (mirrors the recipe collate,
    # which unsqueezes geometry and local_positions separately).
    a = base.unsqueeze(0)
    b = base.unsqueeze(0)
    assert a is not b
    assert MultiScaleFeatureExtractor._same_coords(a, b)

    # Equal values but distinct storage -> must NOT be treated as aliased.
    assert not MultiScaleFeatureExtractor._same_coords(base, base.clone())


def test_extract_context_and_local_aliased_matches_two_pass(device):
    """Fast path (aliased inputs) equals the separate context/local methods."""
    torch.manual_seed(0)
    extractor = _make_extractor(device)
    x = torch.randn(2, 64, 3, device=device)

    with torch.no_grad():
        context, local = extractor.extract_context_and_local(x, x)
        context_ref = extractor.extract_context_features(x, x)
        local_ref = extractor.extract_local_features(x, x)

    assert len(context) == len(context_ref) == extractor.num_scales
    for got, ref in zip(context, context_ref):
        assert torch.equal(got, ref)
    assert torch.equal(local, local_ref)


def test_extract_context_and_local_distinct_matches_two_pass(device):
    """Fallback path (distinct inputs) preserves the asymmetric semantics."""
    torch.manual_seed(1)
    extractor = _make_extractor(device)
    spatial = torch.randn(2, 64, 3, device=device)
    geometry = torch.randn(2, 64, 3, device=device)

    # Sanity: the guard must reject these so the fallback path is taken.
    assert not extractor._same_coords(spatial, geometry)

    with torch.no_grad():
        context, local = extractor.extract_context_and_local(spatial, geometry)
        context_ref = extractor.extract_context_features(spatial, geometry)
        local_ref = extractor.extract_local_features(spatial, geometry)

    # ``zip`` stops at the shorter list, so check the lengths explicitly;
    # otherwise a truncated context list would pass unnoticed.
    assert len(context) == len(context_ref) == extractor.num_scales
    for got, ref in zip(context, context_ref):
        assert torch.equal(got, ref)
    assert torch.equal(local, local_ref)


def test_extract_context_and_local_reuses_radius_search(device, monkeypatch):
    """Aliased inputs issue one radius_search per scale; distinct inputs issue two."""
    # Imported lazily: importing ball_query before the model package is fully
    # initialized trips a pre-existing nn.functional <-> datapipes import cycle.
    import physicsnemo.nn.module.ball_query as ball_query_mod

    extractor = _make_extractor(device)

    call_count = {"n": 0}
    original = ball_query_mod.radius_search

    def counting_radius_search(*args, **kwargs):
        # Count the call, then defer to the real implementation.
        call_count["n"] += 1
        return original(*args, **kwargs)

    monkeypatch.setattr(ball_query_mod, "radius_search", counting_radius_search)

    x = torch.randn(2, 64, 3, device=device)
    with torch.no_grad():
        extractor.extract_context_and_local(x, x)
    # One ball query per scale (consolidated), not two.
    assert call_count["n"] == extractor.num_scales

    call_count["n"] = 0
    spatial = torch.randn(2, 64, 3, device=device)
    geometry = torch.randn(2, 64, 3, device=device)
    with torch.no_grad():
        extractor.extract_context_and_local(spatial, geometry)
    # Fallback runs context + local separately: two ball queries per scale.
    assert call_count["n"] == 2 * extractor.num_scales