From 501512f041918c7de8ba6236367c2b15c51f76d9 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 22 Jun 2026 17:26:03 +0000 Subject: [PATCH 1/4] GeoTransolver: single-pass context+local extraction Add a sync-free aliasing check (_same_coords) so the context projector detects when spatial_coords and geometry are views of the same storage (common after collate unsqueeze) and runs each per-scale processor once via extract_context_and_local instead of separately in extract_context_features and extract_local_features, halving radius_search calls. Falls back to the two-pass path when inputs differ. Enable include_local_features in the surface recipe config and add coverage. --- .../conf/model/geotransolver_surface.yaml | 2 +- .../models/geotransolver/context_projector.py | 110 +++++++++++++++-- .../geotransolver/test_context_projector.py | 115 ++++++++++++++++++ 3 files changed, 217 insertions(+), 10 deletions(-) diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/model/geotransolver_surface.yaml b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/model/geotransolver_surface.yaml index c61c2306e9..1e98b43fa9 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/model/geotransolver_surface.yaml +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/conf/model/geotransolver_surface.yaml @@ -55,7 +55,7 @@ model: slice_num: 256 use_te: false plus: false - include_local_features: false + include_local_features: true radii: [0.1, 0.5, 2.0] neighbors_in_radius: [16, 32, 64] n_hidden_local: 32 diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index 15f7d67ffc..6bdd683269 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -767,6 +767,99 @@ def extract_local_features( dim=-1, ) + @staticmethod + def _same_coords( + a: Float[torch.Tensor, "batch points spatial_dim"], + b: Float[torch.Tensor, "batch points spatial_dim"], + ) -> bool: + r"""Whether two tensors are the same data (so a ball query over either is identical). + + :meth:`extract_context_features` and :meth:`extract_local_features` + pass ``spatial_coords`` and ``geometry`` to the per-scale processor in + **swapped** order (``processor(spatial_coords, geometry)`` vs + ``processor(geometry, spatial_coords)``). Those two calls compute the + same ball query and the same processor output **iff the two inputs hold + the same coordinates**. When they do, :meth:`extract_context_and_local` + runs each processor once instead of twice. + + This is a *sufficient*, sync-free aliasing test: an ``is`` check is not + enough because the recipe's collate ``unsqueeze``-es ``geometry`` and + ``local_positions`` into distinct view objects that still alias the same + storage. A value comparison would force a host sync, so we instead + confirm the two tensors are identical views over identical storage + (shape, dtype, stride, offset, and base pointer all match). When this + returns ``False`` the caller falls back to the exact two-pass behavior, + so correctness never depends on it. + + Parameters + ---------- + a, b : torch.Tensor + Candidate coordinate tensors of shape :math:`(B, N, 3)`. + + Returns + ------- + bool + ``True`` when *a* and *b* are guaranteed element-for-element equal. + """ + return a is b or ( + a.shape == b.shape + and a.dtype == b.dtype + and a.stride() == b.stride() + and a.storage_offset() == b.storage_offset() + and a.data_ptr() == b.data_ptr() + ) + + def extract_context_and_local( + self, + spatial_coords: Float[torch.Tensor, "batch points spatial_dim"], + geometry: Float[torch.Tensor, "batch points geometry_dim"], + ) -> tuple[ + list[Float[torch.Tensor, "batch heads slices dim"]], + Float[torch.Tensor, "batch points total_hidden"], + ]: + r"""Extract context and local features in one pass, reusing the ball query when possible. + + Combines :meth:`extract_context_features` and + :meth:`extract_local_features`. When ``spatial_coords`` and ``geometry`` + are the same coordinates (see :meth:`_same_coords`), the swapped-argument + processor calls in those two methods are identical, so each per-scale + processor (ball query + MLP, the model's dominant ``radius_search`` + kernel) is evaluated **once** and fed to both the context tokenizer and + the concatenated local features. Otherwise this falls back to the exact + two-pass behavior, preserving the deliberate asymmetry for configs where + geometry and positions differ. + + Parameters + ---------- + spatial_coords : torch.Tensor + Spatial coordinates of shape :math:`(B, N, 3)`. + geometry : torch.Tensor + Geometry features of shape :math:`(B, N, C_{geo})`. + + Returns + ------- + tuple[list[torch.Tensor], torch.Tensor] + ``(context_features, local_features)`` matching the outputs of + :meth:`extract_context_features` and :meth:`extract_local_features` + respectively. + """ + if self._same_coords(spatial_coords, geometry): + context_features = [] + local_parts = [] + for processor, tokenizer in zip(self.processors, self.tokenizers): + # processor(spatial_coords, geometry) == processor(geometry, spatial_coords) + # when the two inputs alias, so one pass feeds both paths. + feat = processor(spatial_coords, geometry) + context_features.append(tokenizer(feat)) + local_parts.append(feat) + return context_features, torch.cat(local_parts, dim=-1) + + # General (asymmetric) fallback: distinct query/search sets per path. + return ( + self.extract_context_features(spatial_coords, geometry), + self.extract_local_features(spatial_coords, geometry), + ) + class GlobalContextBuilder(nn.Module): r"""Orchestrates all context construction with a clean, simple interface. @@ -1029,16 +1122,15 @@ def build_context( for i, embedding in enumerate(local_embeddings): spatial_coords = local_positions[i] # Extract coordinates - # Get tokenized context features from multi-scale extractor - context_feats = self.local_extractors[i].extract_context_features( - spatial_coords, geometry - ) + # Get tokenized context features and concatenated local + # features in one pass. When spatial_coords and geometry alias + # the same coordinates (the common case), the extractor reuses a + # single ball query per scale instead of computing the same + # radius_search twice; otherwise it falls back to two passes. + context_feats, local_feats = self.local_extractors[ + i + ].extract_context_and_local(spatial_coords, geometry) context_parts.extend(context_feats) - - # Get concatenated local features for skip connection - local_feats = self.local_extractors[i].extract_local_features( - spatial_coords, geometry - ) local_features.append(local_feats) # Tokenize geometry features diff --git a/test/models/geotransolver/test_context_projector.py b/test/models/geotransolver/test_context_projector.py index 49affc1a13..def0f4a5d5 100644 --- a/test/models/geotransolver/test_context_projector.py +++ b/test/models/geotransolver/test_context_projector.py @@ -16,8 +16,13 @@ import torch +# Import datapipes first: it resolves a pre-existing nn.functional <-> datapipes +# circular import that otherwise fails when context_projector is the first +# physicsnemo import in the process (e.g. running this file standalone). +import physicsnemo.datapipes # noqa: F401 from physicsnemo.experimental.models.geotransolver.context_projector import ( ContextProjector, + MultiScaleFeatureExtractor, ) # ============================================================================= @@ -53,3 +58,113 @@ def test_context_projector_forward(device): # Output shape: [Batch, Heads, Slice_num, dim_head] assert slice_tokens.shape == (batch_size, heads, slice_num, dim_head) assert not torch.isnan(slice_tokens).any() + + +# ============================================================================= +# MultiScaleFeatureExtractor consolidation tests +# ============================================================================= + + +def _make_extractor(device): + """Build a small MultiScaleFeatureExtractor for the consolidation tests.""" + extractor = MultiScaleFeatureExtractor( + geometry_dim=3, + radii=[0.5, 1.0], + neighbors_in_radius=[4, 8], + hidden_dim=16, + n_head=4, + dim_head=8, + dropout=0.0, + slice_num=8, + use_te=False, + plus=False, + ).to(device) + extractor.eval() + return extractor + + +def test_same_coords_guard(device): + """``_same_coords`` detects aliasing views but not equal-valued copies.""" + base = torch.randn(64, 3, device=device) + + # Identical object. + assert MultiScaleFeatureExtractor._same_coords(base, base) + + # Distinct view objects over the same storage (mirrors the recipe collate, + # which unsqueezes geometry and local_positions separately). + a = base.unsqueeze(0) + b = base.unsqueeze(0) + assert a is not b + assert MultiScaleFeatureExtractor._same_coords(a, b) + + # Equal values but distinct storage -> must NOT be treated as aliased. + assert not MultiScaleFeatureExtractor._same_coords(base, base.clone()) + + +def test_extract_context_and_local_aliased_matches_two_pass(device): + """Fast path (aliased inputs) equals the separate context/local methods.""" + torch.manual_seed(0) + extractor = _make_extractor(device) + x = torch.randn(2, 64, 3, device=device) + + with torch.no_grad(): + context, local = extractor.extract_context_and_local(x, x) + context_ref = extractor.extract_context_features(x, x) + local_ref = extractor.extract_local_features(x, x) + + assert len(context) == len(context_ref) == extractor.num_scales + for got, ref in zip(context, context_ref): + assert torch.equal(got, ref) + assert torch.equal(local, local_ref) + + +def test_extract_context_and_local_distinct_matches_two_pass(device): + """Fallback path (distinct inputs) preserves the asymmetric semantics.""" + torch.manual_seed(1) + extractor = _make_extractor(device) + spatial = torch.randn(2, 64, 3, device=device) + geometry = torch.randn(2, 64, 3, device=device) + + # Sanity: the guard must reject these so the fallback path is taken. + assert not extractor._same_coords(spatial, geometry) + + with torch.no_grad(): + context, local = extractor.extract_context_and_local(spatial, geometry) + context_ref = extractor.extract_context_features(spatial, geometry) + local_ref = extractor.extract_local_features(spatial, geometry) + + for got, ref in zip(context, context_ref): + assert torch.equal(got, ref) + assert torch.equal(local, local_ref) + + +def test_extract_context_and_local_reuses_radius_search(device, monkeypatch): + """Aliased inputs issue one radius_search per scale; distinct inputs issue two.""" + # Imported lazily: importing ball_query before the model package is fully + # initialized trips a pre-existing nn.functional <-> datapipes import cycle. + import physicsnemo.nn.module.ball_query as ball_query_mod + + extractor = _make_extractor(device) + + call_count = {"n": 0} + original = ball_query_mod.radius_search + + def counting_radius_search(*args, **kwargs): + call_count["n"] += 1 + return original(*args, **kwargs) + + monkeypatch.setattr(ball_query_mod, "radius_search", counting_radius_search) + + x = torch.randn(2, 64, 3, device=device) + with torch.no_grad(): + extractor.extract_context_and_local(x, x) + # One ball query per scale (consolidated), not two. + assert call_count["n"] == extractor.num_scales + + call_count["n"] = 0 + spatial = torch.randn(2, 64, 3, device=device) + geometry = torch.randn(2, 64, 3, device=device) + with torch.no_grad(): + extractor.extract_context_and_local(spatial, geometry) + # Fallback runs context + local separately: two ball queries per scale. + assert call_count["n"] == 2 * extractor.num_scales From 12948328f2681aeaaba8f55bc40435f3e106b5e1 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 29 Jun 2026 15:37:48 -0500 Subject: [PATCH 2/4] Cleaning up the fast path radius search --- .../models/geotransolver/context_projector.py | 111 +++--------------- 1 file changed, 15 insertions(+), 96 deletions(-) diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py index 6bdd683269..41e3cfbc29 100644 --- a/physicsnemo/experimental/models/geotransolver/context_projector.py +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -50,69 +50,14 @@ from physicsnemo.nn import ConcreteDropout +from .utils import structured_grid_to_conv_input, tensors_alias + # Check optional dependency availability TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) if TE_AVAILABLE: import transformer_engine.pytorch as te -def _structured_grid_to_conv_input( - x: Float[torch.Tensor, "batch tokens channels"], - batch: int, - tokens: int, - channels: int, - ndim: int, - spatial_shape: tuple[int, ...], -) -> Float[torch.Tensor, "batch channels ..."]: - r"""Reshape flat token tensor to spatial layout for Conv2d/Conv3d. - - Converts :math:`(B, N, C)` to :math:`(B, C, H, W)` for 2D or - :math:`(B, C, H, W, D)` for 3D so that structured context projectors - can apply spatial convolutions. Validates that :math:`N` matches the - grid size. - - Parameters - ---------- - x : torch.Tensor - Input tensor of shape :math:`(B, N, C)` (batch, tokens, channels). - batch : int - Batch size :math:`B`. - tokens : int - Number of tokens :math:`N` (must equal :math:`H \\times W` or - :math:`H \\times W \\times D`). - channels : int - Channel dimension :math:`C`. - ndim : int - Number of spatial dimensions; must be 2 or 3. - spatial_shape : tuple[int, ...] - :math:`(H, W)` for 2D or :math:`(H, W, D)` for 3D. - - Returns - ------- - torch.Tensor - Reshaped tensor of shape :math:`(B, C, H, W)` or - :math:`(B, C, H, W, D)` for use as conv input. - - Raises - ------ - ValueError - If ``tokens`` does not match the product of ``spatial_shape``. - """ - if ndim == 2: - H, W = spatial_shape - if tokens != H * W: - raise ValueError( - f"Expected N={H * W} tokens for 2D grid, got N={tokens}" - ) - return x.view(batch, H, W, channels).permute(0, 3, 1, 2) - H, W, D = spatial_shape - if tokens != H * W * D: - raise ValueError( - f"Expected N={H * W * D} tokens for 3D grid, got N={tokens}" - ) - return x.view(batch, H, W, D, channels).permute(0, 4, 1, 2, 3) - - class _SliceToContextMixin: r"""Internal mixin providing shared slice-to-context init and slice aggregation. @@ -454,10 +399,7 @@ def _grid_project( Float[torch.Tensor, "batch tokens heads dim"], ] ): - B, N, C = x.shape - grid = _structured_grid_to_conv_input( - x, B, N, C, self._nd, self.spatial_shape - ) + grid = structured_grid_to_conv_input(x, self.spatial_shape) pattern = ( "B (H D) h w -> B (h w) H D" if self._nd == 2 @@ -772,24 +714,11 @@ def _same_coords( a: Float[torch.Tensor, "batch points spatial_dim"], b: Float[torch.Tensor, "batch points spatial_dim"], ) -> bool: - r"""Whether two tensors are the same data (so a ball query over either is identical). - - :meth:`extract_context_features` and :meth:`extract_local_features` - pass ``spatial_coords`` and ``geometry`` to the per-scale processor in - **swapped** order (``processor(spatial_coords, geometry)`` vs - ``processor(geometry, spatial_coords)``). Those two calls compute the - same ball query and the same processor output **iff the two inputs hold - the same coordinates**. When they do, :meth:`extract_context_and_local` - runs each processor once instead of twice. - - This is a *sufficient*, sync-free aliasing test: an ``is`` check is not - enough because the recipe's collate ``unsqueeze``-es ``geometry`` and - ``local_positions`` into distinct view objects that still alias the same - storage. A value comparison would force a host sync, so we instead - confirm the two tensors are identical views over identical storage - (shape, dtype, stride, offset, and base pointer all match). When this - returns ``False`` the caller falls back to the exact two-pass behavior, - so correctness never depends on it. + r"""Whether ``a`` and ``b`` alias the same coordinates. + + Thin wrapper around :func:`~physicsnemo.experimental.models.geotransolver.utils.tensors_alias`. + When ``True``, a ball query over either tensor is identical, so + :meth:`extract_context_and_local` can take its single-pass fast path. Parameters ---------- @@ -801,13 +730,7 @@ def _same_coords( bool ``True`` when *a* and *b* are guaranteed element-for-element equal. """ - return a is b or ( - a.shape == b.shape - and a.dtype == b.dtype - and a.stride() == b.stride() - and a.storage_offset() == b.storage_offset() - and a.data_ptr() == b.data_ptr() - ) + return tensors_alias(a, b) def extract_context_and_local( self, @@ -821,13 +744,10 @@ def extract_context_and_local( Combines :meth:`extract_context_features` and :meth:`extract_local_features`. When ``spatial_coords`` and ``geometry`` - are the same coordinates (see :meth:`_same_coords`), the swapped-argument - processor calls in those two methods are identical, so each per-scale - processor (ball query + MLP, the model's dominant ``radius_search`` - kernel) is evaluated **once** and fed to both the context tokenizer and - the concatenated local features. Otherwise this falls back to the exact - two-pass behavior, preserving the deliberate asymmetry for configs where - geometry and positions differ. + alias the same coordinates (see :meth:`_same_coords`), each per-scale + processor (ball query + MLP) runs once and feeds both paths. Otherwise + it falls back to the two-pass behavior, preserving the asymmetry for + configs where geometry and positions differ. Parameters ---------- @@ -844,17 +764,16 @@ def extract_context_and_local( respectively. """ if self._same_coords(spatial_coords, geometry): + # Aliased inputs: one processor pass per scale feeds both paths. context_features = [] local_parts = [] for processor, tokenizer in zip(self.processors, self.tokenizers): - # processor(spatial_coords, geometry) == processor(geometry, spatial_coords) - # when the two inputs alias, so one pass feeds both paths. feat = processor(spatial_coords, geometry) context_features.append(tokenizer(feat)) local_parts.append(feat) return context_features, torch.cat(local_parts, dim=-1) - # General (asymmetric) fallback: distinct query/search sets per path. + # Asymmetric fallback: distinct query/search sets per path. return ( self.extract_context_features(spatial_coords, geometry), self.extract_local_features(spatial_coords, geometry), From 33cabb74a360dcaa4b5b9b77ea5456c7c76ad50b Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 29 Jun 2026 15:38:07 -0500 Subject: [PATCH 3/4] Cleaning up the fast path radius search --- .../models/geotransolver/utils.py | 103 ++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 physicsnemo/experimental/models/geotransolver/utils.py diff --git a/physicsnemo/experimental/models/geotransolver/utils.py b/physicsnemo/experimental/models/geotransolver/utils.py new file mode 100644 index 0000000000..da10be1064 --- /dev/null +++ b/physicsnemo/experimental/models/geotransolver/utils.py @@ -0,0 +1,103 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Small tensor helpers shared across GeoTransolver context projectors.""" + +from __future__ import annotations + +import torch +from jaxtyping import Float + + +def structured_grid_to_conv_input( + x: Float[torch.Tensor, "batch tokens channels"], + spatial_shape: tuple[int, ...], +) -> Float[torch.Tensor, "batch channels ..."]: + r"""Reshape a flat token tensor to spatial layout for Conv2d/Conv3d. + + Converts :math:`(B, N, C)` to :math:`(B, C, H, W)` (2D) or + :math:`(B, C, H, W, D)` (3D) so structured projectors can apply spatial + convolutions. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)`. + spatial_shape : tuple[int, ...] + :math:`(H, W)` for 2D or :math:`(H, W, D)` for 3D. The product must + equal :math:`N`. + + Returns + ------- + torch.Tensor + Tensor of shape :math:`(B, C, H, W)` or :math:`(B, C, H, W, D)`. + + Raises + ------ + ValueError + If ``spatial_shape`` is not length 2 or 3, or its product does not + match the token dimension :math:`N`. + """ + batch, tokens, channels = x.shape + expected = 1 + for s in spatial_shape: + expected *= s + if tokens != expected: + raise ValueError( + f"Expected N={expected} tokens for grid {tuple(spatial_shape)}, " + f"got N={tokens}" + ) + + if len(spatial_shape) == 2: + H, W = spatial_shape + return x.view(batch, H, W, channels).permute(0, 3, 1, 2) + if len(spatial_shape) == 3: + H, W, D = spatial_shape + return x.view(batch, H, W, D, channels).permute(0, 4, 1, 2, 3) + raise ValueError( + f"spatial_shape must have length 2 or 3, got {tuple(spatial_shape)}" + ) + + +def tensors_alias( + a: Float[torch.Tensor, "..."], + b: Float[torch.Tensor, "..."], +) -> bool: + r"""Return ``True`` when ``a`` and ``b`` are guaranteed to hold identical data. + + This is a sync-free, *sufficient* aliasing test: it confirms the two tensors + are the same object, or distinct views over the same storage with matching + shape, dtype, stride, and offset. A plain ``is`` check is not enough because + callers may pass separately-created views of the same storage; a value + comparison is avoided because it would force a host sync. + + Parameters + ---------- + a, b : torch.Tensor + Candidate tensors to compare. + + Returns + ------- + bool + ``True`` if ``a`` and ``b`` are element-for-element equal. + """ + return a is b or ( + a.shape == b.shape + and a.dtype == b.dtype + and a.stride() == b.stride() + and a.storage_offset() == b.storage_offset() + and a.data_ptr() == b.data_ptr() + ) From b0b9b80f37fc941deb69dcacad67dc7a6c0a1e57 Mon Sep 17 00:00:00 2001 From: Corey Adams <6619961+coreyjadams@users.noreply.github.com> Date: Mon, 29 Jun 2026 15:46:36 -0500 Subject: [PATCH 4/4] Use is_set_to instead of hand-rolled comparison --- .../models/geotransolver/utils.py | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/physicsnemo/experimental/models/geotransolver/utils.py b/physicsnemo/experimental/models/geotransolver/utils.py index da10be1064..cd70b29eae 100644 --- a/physicsnemo/experimental/models/geotransolver/utils.py +++ b/physicsnemo/experimental/models/geotransolver/utils.py @@ -78,11 +78,9 @@ def tensors_alias( ) -> bool: r"""Return ``True`` when ``a`` and ``b`` are guaranteed to hold identical data. - This is a sync-free, *sufficient* aliasing test: it confirms the two tensors - are the same object, or distinct views over the same storage with matching - shape, dtype, stride, and offset. A plain ``is`` check is not enough because - callers may pass separately-created views of the same storage; a value - comparison is avoided because it would force a host sync. + Sync-free, *sufficient* aliasing test: ``True`` iff the tensors alias the + same storage with matching shape, stride, offset, and dtype. Avoids a value + comparison, which would force a host sync. Parameters ---------- @@ -92,12 +90,8 @@ def tensors_alias( Returns ------- bool - ``True`` if ``a`` and ``b`` are element-for-element equal. + ``True`` if ``a`` and ``b`` alias the same storage with matching shape, + stride, offset, and dtype. """ - return a is b or ( - a.shape == b.shape - and a.dtype == b.dtype - and a.stride() == b.stride() - and a.storage_offset() == b.storage_offset() - and a.data_ptr() == b.data_ptr() - ) + # ``is_set_to`` covers storage, offset, size, and stride but ignores dtype. + return a.is_set_to(b) and a.dtype == b.dtype