diff --git a/benchmarks/physicsnemo/nn/functional/registry.py b/benchmarks/physicsnemo/nn/functional/registry.py index 3e62c775c1..6c170851c7 100644 --- a/benchmarks/physicsnemo/nn/functional/registry.py +++ b/benchmarks/physicsnemo/nn/functional/registry.py @@ -23,7 +23,10 @@ MeshLSQGradient, RectilinearGridGradient, SpectralGridGradient, + UniformGridCurl, + UniformGridDivergence, UniformGridGradient, + UniformGridLaplacian, ) from physicsnemo.nn.functional.fourier_spectral import ( IRFFT, @@ -36,6 +39,8 @@ ) from physicsnemo.nn.functional.geometry import ( FarthestPointSampling, + MeshPoissonDiskSample, + MeshToVoxelFraction, SignedDistanceField, ) from physicsnemo.nn.functional.interpolation import ( @@ -64,8 +69,13 @@ MeshGreenGaussGradient, SpectralGridGradient, MeshlessFDDerivatives, + UniformGridDivergence, + UniformGridCurl, + UniformGridLaplacian, # Geometry. FarthestPointSampling, + MeshPoissonDiskSample, + MeshToVoxelFraction, SignedDistanceField, # Interpolation. GridToPointInterpolation, diff --git a/docs/api/nn/functionals/derivatives.rst b/docs/api/nn/functionals/derivatives.rst index 0e132c76a7..ef543934cc 100644 --- a/docs/api/nn/functionals/derivatives.rst +++ b/docs/api/nn/functionals/derivatives.rst @@ -32,3 +32,24 @@ Derivative Functionals :width: 100% .. autofunction:: physicsnemo.nn.functional.meshless_fd_derivatives + +Uniform Grid Vector Calculus +---------------------------- + +.. autofunction:: physicsnemo.nn.functional.uniform_grid_divergence + +.. figure:: /img/nn/functional/derivatives/uniform_grid_divergence.png + :alt: Uniform grid divergence example + :width: 100% + +.. autofunction:: physicsnemo.nn.functional.uniform_grid_curl + +.. figure:: /img/nn/functional/derivatives/uniform_grid_curl.png + :alt: Uniform grid curl example + :width: 100% + +.. autofunction:: physicsnemo.nn.functional.uniform_grid_laplacian + +.. figure:: /img/nn/functional/derivatives/uniform_grid_laplacian.png + :alt: Uniform grid Laplacian example + :width: 100% diff --git a/docs/img/nn/functional/derivatives/uniform_grid_curl.png b/docs/img/nn/functional/derivatives/uniform_grid_curl.png new file mode 100644 index 0000000000..78db06b1e3 Binary files /dev/null and b/docs/img/nn/functional/derivatives/uniform_grid_curl.png differ diff --git a/docs/img/nn/functional/derivatives/uniform_grid_divergence.png b/docs/img/nn/functional/derivatives/uniform_grid_divergence.png new file mode 100644 index 0000000000..dd39c22354 Binary files /dev/null and b/docs/img/nn/functional/derivatives/uniform_grid_divergence.png differ diff --git a/docs/img/nn/functional/derivatives/uniform_grid_laplacian.png b/docs/img/nn/functional/derivatives/uniform_grid_laplacian.png new file mode 100644 index 0000000000..6e554d651f Binary files /dev/null and b/docs/img/nn/functional/derivatives/uniform_grid_laplacian.png differ diff --git a/physicsnemo/nn/functional/__init__.py b/physicsnemo/nn/functional/__init__.py index 4d952d8fba..00affd78ae 100644 --- a/physicsnemo/nn/functional/__init__.py +++ b/physicsnemo/nn/functional/__init__.py @@ -20,7 +20,10 @@ meshless_fd_derivatives, rectilinear_grid_gradient, spectral_grid_gradient, + uniform_grid_curl, + uniform_grid_divergence, uniform_grid_gradient, + uniform_grid_laplacian, ) from .equivariant_ops import ( legendre_polynomials, @@ -60,6 +63,9 @@ "irfft2", "drop_path", "farthest_point_sampling", + "uniform_grid_curl", + "uniform_grid_divergence", + "uniform_grid_laplacian", "grid_to_point_interpolation", "imag", "interpolation", diff --git a/physicsnemo/nn/functional/derivatives/__init__.py b/physicsnemo/nn/functional/derivatives/__init__.py index 0306f0e6c8..de046588e0 100644 --- a/physicsnemo/nn/functional/derivatives/__init__.py +++ b/physicsnemo/nn/functional/derivatives/__init__.py @@ -25,15 +25,24 @@ rectilinear_grid_gradient, ) from .spectral_grid_gradient import SpectralGridGradient, spectral_grid_gradient +from .uniform_grid_curl import UniformGridCurl, uniform_grid_curl +from .uniform_grid_divergence import UniformGridDivergence, uniform_grid_divergence from .uniform_grid_gradient import UniformGridGradient, uniform_grid_gradient +from .uniform_grid_laplacian import UniformGridLaplacian, uniform_grid_laplacian __all__ = [ + "UniformGridCurl", + "UniformGridDivergence", + "UniformGridLaplacian", "MeshGreenGaussGradient", "MeshlessFDDerivatives", "MeshLSQGradient", "RectilinearGridGradient", "SpectralGridGradient", "UniformGridGradient", + "uniform_grid_curl", + "uniform_grid_divergence", + "uniform_grid_laplacian", "mesh_green_gauss_gradient", "meshless_fd_derivatives", "mesh_lsq_gradient", diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_curl/__init__.py b/physicsnemo/nn/functional/derivatives/uniform_grid_curl/__init__.py new file mode 100644 index 0000000000..5dc6907ce8 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_curl/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .uniform_grid_curl import UniformGridCurl, uniform_grid_curl + +__all__ = ["UniformGridCurl", "uniform_grid_curl"] diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_curl/_torch_impl.py b/physicsnemo/nn/functional/derivatives/uniform_grid_curl/_torch_impl.py new file mode 100644 index 0000000000..9c7ccbd22b --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_curl/_torch_impl.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +from .utils import validate_vector_field + +_SUPPORTED_ORDERS = (2, 4) + + +def _normalize_spacing( + spacing: float | Sequence[float], ndim: int +) -> tuple[float, ...]: + if isinstance(spacing, (float, int)): + return tuple(float(spacing) for _ in range(ndim)) + spacing_tuple = tuple(float(x) for x in spacing) + if len(spacing_tuple) != ndim: + raise ValueError( + f"spacing must have {ndim} entries for a {ndim}D field, got {len(spacing_tuple)}" + ) + return spacing_tuple + + +def _validate_order(order: int) -> int: + if not isinstance(order, int): + raise TypeError(f"order must be an integer, got {type(order)}") + if order not in _SUPPORTED_ORDERS: + raise ValueError( + "uniform_grid_curl supports central orders " + f"{list(_SUPPORTED_ORDERS)}, got order={order}" + ) + return order + + +def _central_derivative_order2( + field: torch.Tensor, axis: int, dx: float +) -> torch.Tensor: + return ( + torch.roll(field, shifts=-1, dims=axis) - torch.roll(field, shifts=1, dims=axis) + ) / (2.0 * dx) + + +def _central_derivative_order4( + field: torch.Tensor, axis: int, dx: float +) -> torch.Tensor: + return ( + -torch.roll(field, shifts=-2, dims=axis) + + 8.0 * torch.roll(field, shifts=-1, dims=axis) + - 8.0 * torch.roll(field, shifts=1, dims=axis) + + torch.roll(field, shifts=2, dims=axis) + ) / (12.0 * dx) + + +_DERIVATIVE_DISPATCH = { + 2: _central_derivative_order2, + 4: _central_derivative_order4, +} + + +def uniform_grid_curl_torch( + vector_field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, +) -> torch.Tensor: + """Compute periodic uniform-grid curl with PyTorch tensor ops.""" + grid_ndim = validate_vector_field(vector_field) + spacing_tuple = _normalize_spacing(spacing, grid_ndim) + for dx in spacing_tuple: + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + derivative_fn = _DERIVATIVE_DISPATCH[_validate_order(order)] + + def derivative(component: int, axis: int) -> torch.Tensor: + return derivative_fn(vector_field[component], axis, spacing_tuple[axis]) + + if grid_ndim == 2: + return derivative(1, 0) - derivative(0, 1) + + curl_x = derivative(2, 1) - derivative(1, 2) + curl_y = derivative(0, 2) - derivative(2, 0) + curl_z = derivative(1, 0) - derivative(0, 1) + return torch.stack((curl_x, curl_y, curl_z), dim=0) diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_curl/_warp_impl.py b/physicsnemo/nn/functional/derivatives/uniform_grid_curl/_warp_impl.py new file mode 100644 index 0000000000..fca86813b0 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_curl/_warp_impl.py @@ -0,0 +1,426 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import warp as wp + +from ..uniform_grid_gradient._warp_impl.utils import ( + _launch_dim, + _normalize_spacing, + _to_wp_tensor, + _warp_launch_context, + _wp_launch, + _wrap_minus1, + _wrap_minus2, + _wrap_plus1, + _wrap_plus2, +) +from .utils import validate_vector_field + +_SUPPORTED_ORDERS = (2, 4) + + +def _validate_order(order: int) -> int: + if not isinstance(order, int): + raise TypeError(f"order must be an integer, got {type(order)}") + if order not in _SUPPORTED_ORDERS: + raise ValueError( + "uniform_grid_curl supports central orders " + f"{list(_SUPPORTED_ORDERS)}, got order={order}" + ) + return order + + +def _validate_positive_spacing(spacing_tuple: tuple[float, ...]) -> None: + for dx in spacing_tuple: + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + + +def _to_fp32_contiguous(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.float32 and tensor.is_contiguous(): + return tensor + return tensor.to(dtype=torch.float32).contiguous() + + +def _restore_dtype(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: + if tensor.dtype == target_dtype: + return tensor + return tensor.to(dtype=target_dtype) + + +@wp.kernel +def _curl_2d_order2_kernel( + vector_field: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + output: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = output.shape[0] + n1 = output.shape[1] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + d_v1_dx = (vector_field[1, ip, j] - vector_field[1, im, j]) * (0.5 * inv_dx0) + d_v0_dy = (vector_field[0, i, jp] - vector_field[0, i, jm]) * (0.5 * inv_dx1) + output[i, j] = d_v1_dx - d_v0_dy + + +@wp.kernel +def _curl_2d_order4_kernel( + vector_field: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + output: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = output.shape[0] + n1 = output.shape[1] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + d_v1_dx = ( + -vector_field[1, ip2, j] + + 8.0 * vector_field[1, ip1, j] + - 8.0 * vector_field[1, im1, j] + + vector_field[1, im2, j] + ) * (inv_dx0 / 12.0) + d_v0_dy = ( + -vector_field[0, i, jp2] + + 8.0 * vector_field[0, i, jp1] + - 8.0 * vector_field[0, i, jm1] + + vector_field[0, i, jm2] + ) * (inv_dx1 / 12.0) + output[i, j] = d_v1_dx - d_v0_dy + + +@wp.kernel +def _curl_3d_order2_kernel( + vector_field: wp.array4d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + output: wp.array4d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = output.shape[1] + n1 = output.shape[2] + n2 = output.shape[3] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + output[0, i, j, k] = (vector_field[2, i, jp, k] - vector_field[2, i, jm, k]) * ( + 0.5 * inv_dx1 + ) - (vector_field[1, i, j, kp] - vector_field[1, i, j, km]) * (0.5 * inv_dx2) + output[1, i, j, k] = (vector_field[0, i, j, kp] - vector_field[0, i, j, km]) * ( + 0.5 * inv_dx2 + ) - (vector_field[2, ip, j, k] - vector_field[2, im, j, k]) * (0.5 * inv_dx0) + output[2, i, j, k] = (vector_field[1, ip, j, k] - vector_field[1, im, j, k]) * ( + 0.5 * inv_dx0 + ) - (vector_field[0, i, jp, k] - vector_field[0, i, jm, k]) * (0.5 * inv_dx1) + + +@wp.kernel +def _curl_3d_order4_kernel( + vector_field: wp.array4d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + output: wp.array4d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = output.shape[1] + n1 = output.shape[2] + n2 = output.shape[3] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + km1 = _wrap_minus1(k, n2) + kp1 = _wrap_plus1(k, n2) + km2 = _wrap_minus2(k, n2) + kp2 = _wrap_plus2(k, n2) + d_v2_dy = ( + -vector_field[2, i, jp2, k] + + 8.0 * vector_field[2, i, jp1, k] + - 8.0 * vector_field[2, i, jm1, k] + + vector_field[2, i, jm2, k] + ) * (inv_dx1 / 12.0) + d_v1_dz = ( + -vector_field[1, i, j, kp2] + + 8.0 * vector_field[1, i, j, kp1] + - 8.0 * vector_field[1, i, j, km1] + + vector_field[1, i, j, km2] + ) * (inv_dx2 / 12.0) + d_v0_dz = ( + -vector_field[0, i, j, kp2] + + 8.0 * vector_field[0, i, j, kp1] + - 8.0 * vector_field[0, i, j, km1] + + vector_field[0, i, j, km2] + ) * (inv_dx2 / 12.0) + d_v2_dx = ( + -vector_field[2, ip2, j, k] + + 8.0 * vector_field[2, ip1, j, k] + - 8.0 * vector_field[2, im1, j, k] + + vector_field[2, im2, j, k] + ) * (inv_dx0 / 12.0) + d_v1_dx = ( + -vector_field[1, ip2, j, k] + + 8.0 * vector_field[1, ip1, j, k] + - 8.0 * vector_field[1, im1, j, k] + + vector_field[1, im2, j, k] + ) * (inv_dx0 / 12.0) + d_v0_dy = ( + -vector_field[0, i, jp2, k] + + 8.0 * vector_field[0, i, jp1, k] + - 8.0 * vector_field[0, i, jm1, k] + + vector_field[0, i, jm2, k] + ) * (inv_dx1 / 12.0) + output[0, i, j, k] = d_v2_dy - d_v1_dz + output[1, i, j, k] = d_v0_dz - d_v2_dx + output[2, i, j, k] = d_v1_dx - d_v0_dy + + +@wp.kernel +def _curl_backward_2d_order2_kernel( + grad_output: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + grad_vector: wp.array3d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad_output.shape[0] + n1 = grad_output.shape[1] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + grad_vector[0, i, j] = (grad_output[i, jp] - grad_output[i, jm]) * (0.5 * inv_dx1) + grad_vector[1, i, j] = (grad_output[im, j] - grad_output[ip, j]) * (0.5 * inv_dx0) + + +@wp.kernel +def _curl_backward_2d_order4_kernel( + grad_output: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + grad_vector: wp.array3d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad_output.shape[0] + n1 = grad_output.shape[1] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + grad_vector[0, i, j] = ( + -grad_output[i, jp2] + + 8.0 * grad_output[i, jp1] + - 8.0 * grad_output[i, jm1] + + grad_output[i, jm2] + ) * (inv_dx1 / 12.0) + grad_vector[1, i, j] = ( + grad_output[ip2, j] + - 8.0 * grad_output[ip1, j] + + 8.0 * grad_output[im1, j] + - grad_output[im2, j] + ) * (inv_dx0 / 12.0) + + +_FORWARD_KERNELS = { + (2, 2): _curl_2d_order2_kernel, + (2, 4): _curl_2d_order4_kernel, + (3, 2): _curl_3d_order2_kernel, + (3, 4): _curl_3d_order4_kernel, +} +_BACKWARD_2D_KERNELS = { + 2: _curl_backward_2d_order2_kernel, + 4: _curl_backward_2d_order4_kernel, +} + + +def _launch_curl_forward( + *, + vector_field_fp32: torch.Tensor, + spacing_tuple: tuple[float, ...], + order: int, + output_fp32: torch.Tensor, +) -> None: + wp_device, wp_stream = _warp_launch_context(vector_field_fp32) + launch_shape = output_fp32.shape if output_fp32.ndim == 2 else output_fp32.shape[1:] + _wp_launch( + kernel=_FORWARD_KERNELS[(vector_field_fp32.ndim - 1, order)], + dim=_launch_dim(launch_shape), + inputs=[ + _to_wp_tensor(vector_field_fp32), + *[1.0 / float(dx) for dx in spacing_tuple], + _to_wp_tensor(output_fp32), + ], + device=wp_device, + stream=wp_stream, + ) + + +def _launch_curl_backward_2d( + *, + grad_output_fp32: torch.Tensor, + spacing_tuple: tuple[float, ...], + order: int, + grad_vector_fp32: torch.Tensor, +) -> None: + wp_device, wp_stream = _warp_launch_context(grad_output_fp32) + _wp_launch( + kernel=_BACKWARD_2D_KERNELS[order], + dim=_launch_dim(grad_output_fp32.shape), + inputs=[ + _to_wp_tensor(grad_output_fp32), + *[1.0 / float(dx) for dx in spacing_tuple], + _to_wp_tensor(grad_vector_fp32), + ], + device=wp_device, + stream=wp_stream, + ) + + +@torch.library.custom_op("physicsnemo::uniform_grid_curl_warp_impl", mutates_args=()) +def uniform_grid_curl_impl( + vector_field: torch.Tensor, + spacing_meta: torch.Tensor, + order: int, +) -> torch.Tensor: + """Evaluate uniform-grid curl with fused Warp kernels.""" + grid_ndim = validate_vector_field(vector_field) + spacing_tuple = tuple(float(v) for v in spacing_meta.tolist()) + _validate_positive_spacing(spacing_tuple) + order = _validate_order(int(order)) + orig_dtype = vector_field.dtype + vector_field_fp32 = _to_fp32_contiguous(vector_field) + output_shape = ( + vector_field_fp32.shape[1:] if grid_ndim == 2 else vector_field_fp32.shape + ) + output_fp32 = torch.empty( + output_shape, + device=vector_field_fp32.device, + dtype=torch.float32, + ) + _launch_curl_forward( + vector_field_fp32=vector_field_fp32, + spacing_tuple=spacing_tuple[:grid_ndim], + order=order, + output_fp32=output_fp32, + ) + return _restore_dtype(output_fp32, orig_dtype) + + +@uniform_grid_curl_impl.register_fake +def _uniform_grid_curl_impl_fake( + vector_field: torch.Tensor, + spacing_meta: torch.Tensor, + order: int, +) -> torch.Tensor: + _ = (spacing_meta, order) + output_shape = ( + vector_field.shape[1:] if vector_field.ndim == 3 else vector_field.shape + ) + return torch.empty( + output_shape, + device=vector_field.device, + dtype=vector_field.dtype, + ) + + +def setup_uniform_grid_curl_context( + ctx: torch.autograd.function.FunctionCtx, + inputs: tuple, + output: torch.Tensor, +) -> None: + """Save uniform-grid curl metadata for the backward pass.""" + vector_field, spacing_meta, order = inputs + _ = output + ctx.spacing_tuple = tuple(float(v) for v in spacing_meta.tolist()) + ctx.order = int(order) + ctx.orig_dtype = vector_field.dtype + ctx.grid_ndim = vector_field.ndim - 1 + + +def backward_uniform_grid_curl( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None]: + if grad_output is None or not ctx.needs_input_grad[0]: + return None, None, None + grad_output_fp32 = _to_fp32_contiguous(grad_output) + if ctx.grid_ndim == 2: + grad_vector_fp32 = torch.empty( + (2, *grad_output_fp32.shape), + device=grad_output_fp32.device, + dtype=torch.float32, + ) + _launch_curl_backward_2d( + grad_output_fp32=grad_output_fp32, + spacing_tuple=ctx.spacing_tuple[:2], + order=ctx.order, + grad_vector_fp32=grad_vector_fp32, + ) + else: + grad_vector_fp32 = torch.empty_like(grad_output_fp32) + _launch_curl_forward( + vector_field_fp32=grad_output_fp32, + spacing_tuple=ctx.spacing_tuple[:3], + order=ctx.order, + output_fp32=grad_vector_fp32, + ) + return _restore_dtype(grad_vector_fp32, ctx.orig_dtype), None, None + + +uniform_grid_curl_impl.register_autograd( + backward_uniform_grid_curl, + setup_context=setup_uniform_grid_curl_context, +) + + +def uniform_grid_curl_warp( + vector_field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, +) -> torch.Tensor: + """Compute periodic uniform-grid curl with a fused Warp custom op.""" + grid_ndim = vector_field.ndim - 1 + spacing_tuple = _normalize_spacing(spacing, grid_ndim) + spacing_meta = torch.tensor(spacing_tuple, dtype=torch.float32, device="cpu") + return uniform_grid_curl_impl(vector_field, spacing_meta, int(order)) diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_curl/uniform_grid_curl.py b/physicsnemo/nn/functional/derivatives/uniform_grid_curl/uniform_grid_curl.py new file mode 100644 index 0000000000..48d6b2a923 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_curl/uniform_grid_curl.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +from physicsnemo.core.function_spec import FunctionSpec + +from ._torch_impl import uniform_grid_curl_torch +from ._warp_impl import uniform_grid_curl_warp + + +class UniformGridCurl(FunctionSpec): + r"""Compute periodic curl on a uniform grid. + + This functional accepts channel-first vector fields with shape + ``(dim, *grid_shape)`` for 2D or 3D uniform grids. For 2D inputs, it + returns scalar vorticity. For 3D inputs, it returns the channel-first + vector curl. + + Parameters + ---------- + vector_field : torch.Tensor + Channel-first vector field with shape ``(2, n0, n1)`` or + ``(3, n0, n1, n2)``. + spacing : float | Sequence[float], optional + Uniform spacing per grid axis. A scalar applies the same spacing to + every axis. + order : int, optional + Central-difference accuracy order. Supported values match + :func:`physicsnemo.nn.functional.uniform_grid_gradient`. + implementation : {"warp", "torch"} or None + Explicit backend selection. When ``None``, rank-based backend dispatch + is used. + + Returns + ------- + torch.Tensor + Scalar curl with shape ``grid_shape`` for 2D inputs, or vector curl + with shape ``(3, *grid_shape)`` for 3D inputs. + """ + + _BENCHMARK_CASES = ( + ("2d-512x512-o2", (512, 512), (0.01, 0.02), 2), + ("2d-384x384-o4", (384, 384), (0.01, 0.02), 4), + ("3d-96x96x96-o2", (96, 96, 96), 0.02, 2), + ("3d-80x80x80-o4", (80, 80, 80), 0.02, 4), + ) + + _BACKWARD_CASES = ( + ("2d-grad-256x256-o2", (256, 256), (0.01, 0.02), 2), + ("2d-grad-192x192-o4", (192, 192), (0.01, 0.02), 4), + ("3d-grad-64x64x64-o2", (64, 64, 64), 0.02, 2), + ) + + _COMPARE_ATOL = 1e-5 + _COMPARE_RTOL = 1e-5 + + @FunctionSpec.register(name="warp", required_imports=("warp>=0.6.0",), rank=0) + def warp_forward( + vector_field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + ) -> torch.Tensor: + """Dispatch uniform-grid curl to the Warp backend.""" + return uniform_grid_curl_warp( + vector_field=vector_field, + spacing=spacing, + order=order, + ) + + @FunctionSpec.register(name="torch", rank=1, baseline=True) + def torch_forward( + vector_field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + ) -> torch.Tensor: + """Dispatch uniform-grid curl to eager PyTorch.""" + return uniform_grid_curl_torch( + vector_field=vector_field, + spacing=spacing, + order=order, + ) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield representative forward benchmark and parity input cases.""" + device = torch.device(device) + for label, shape, spacing, order in cls._BENCHMARK_CASES: + vector_field = _make_periodic_vector_field(shape, device=device) + yield label, (vector_field,), {"spacing": spacing, "order": order} + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield representative backward benchmark and parity input cases.""" + device = torch.device(device) + for label, shape, spacing, order in cls._BACKWARD_CASES: + vector_field = ( + _make_periodic_vector_field(shape, device=device) + .detach() + .clone() + .requires_grad_(True) + ) + yield label, (vector_field,), {"spacing": spacing, "order": order} + + @classmethod + def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare forward outputs across implementations.""" + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_ATOL, + rtol=cls._COMPARE_RTOL, + ) + + @classmethod + def compare_backward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare backward gradients across implementations.""" + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_ATOL, + rtol=cls._COMPARE_RTOL, + ) + + +def _make_periodic_vector_field( + shape: tuple[int, ...], + *, + device: torch.device, +) -> torch.Tensor: + """Construct smooth periodic vector fields for benchmark cases.""" + axes = tuple( + torch.arange(n, device=device, dtype=torch.float32) / float(n) for n in shape + ) + if len(shape) == 2: + x0, x1 = axes + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + return torch.stack( + ( + torch.sin(2.0 * torch.pi * yy), + torch.cos(2.0 * torch.pi * xx), + ), + dim=0, + ) + + x0, x1, x2 = axes + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + return torch.stack( + ( + torch.sin(2.0 * torch.pi * yy), + torch.cos(2.0 * torch.pi * zz), + torch.sin(2.0 * torch.pi * xx), + ), + dim=0, + ) + + +uniform_grid_curl = UniformGridCurl.make_function("uniform_grid_curl") + +__all__ = ["UniformGridCurl", "uniform_grid_curl"] diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_curl/utils.py b/physicsnemo/nn/functional/derivatives/uniform_grid_curl/utils.py new file mode 100644 index 0000000000..464eceaac4 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_curl/utils.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def validate_vector_field(vector_field: torch.Tensor) -> int: + """Validate a channel-first 2D/3D vector field.""" + if not torch.is_floating_point(vector_field): + raise TypeError("vector_field must be a floating-point tensor") + if vector_field.ndim not in (3, 4): + raise ValueError( + "uniform_grid_curl expects a channel-first 2D or 3D vector field with " + "shape (dim, *grid_shape)" + ) + + grid_ndim = vector_field.ndim - 1 + if vector_field.shape[0] != grid_ndim: + raise ValueError( + "vector_field.shape[0] must match grid dimensionality " + f"({grid_ndim}), got {vector_field.shape[0]}" + ) + return grid_ndim diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/__init__.py b/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/__init__.py new file mode 100644 index 0000000000..75f417e589 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .uniform_grid_divergence import UniformGridDivergence, uniform_grid_divergence + +__all__ = ["UniformGridDivergence", "uniform_grid_divergence"] diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/_torch_impl.py b/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/_torch_impl.py new file mode 100644 index 0000000000..a3f2176a6f --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/_torch_impl.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +from .utils import validate_vector_field + +_SUPPORTED_ORDERS = (2, 4) + + +def _normalize_spacing( + spacing: float | Sequence[float], ndim: int +) -> tuple[float, ...]: + if isinstance(spacing, (float, int)): + return tuple(float(spacing) for _ in range(ndim)) + spacing_tuple = tuple(float(x) for x in spacing) + if len(spacing_tuple) != ndim: + raise ValueError( + f"spacing must have {ndim} entries for a {ndim}D field, got {len(spacing_tuple)}" + ) + return spacing_tuple + + +def _validate_order(order: int) -> int: + if not isinstance(order, int): + raise TypeError(f"order must be an integer, got {type(order)}") + if order not in _SUPPORTED_ORDERS: + raise ValueError( + "uniform_grid_divergence supports central orders " + f"{list(_SUPPORTED_ORDERS)}, got order={order}" + ) + return order + + +def _central_derivative_order2( + field: torch.Tensor, axis: int, dx: float +) -> torch.Tensor: + return ( + torch.roll(field, shifts=-1, dims=axis) - torch.roll(field, shifts=1, dims=axis) + ) / (2.0 * dx) + + +def _central_derivative_order4( + field: torch.Tensor, axis: int, dx: float +) -> torch.Tensor: + return ( + -torch.roll(field, shifts=-2, dims=axis) + + 8.0 * torch.roll(field, shifts=-1, dims=axis) + - 8.0 * torch.roll(field, shifts=1, dims=axis) + + torch.roll(field, shifts=2, dims=axis) + ) / (12.0 * dx) + + +_DERIVATIVE_DISPATCH = { + 2: _central_derivative_order2, + 4: _central_derivative_order4, +} + + +def uniform_grid_divergence_torch( + vector_field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, +) -> torch.Tensor: + """Compute periodic uniform-grid divergence with PyTorch tensor ops.""" + grid_ndim = validate_vector_field(vector_field) + spacing_tuple = _normalize_spacing(spacing, grid_ndim) + for dx in spacing_tuple: + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + derivative_fn = _DERIVATIVE_DISPATCH[_validate_order(order)] + + divergence = torch.zeros_like(vector_field[0]) + for axis, dx in enumerate(spacing_tuple): + divergence = divergence + derivative_fn(vector_field[axis], axis, dx) + return divergence diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/_warp_impl.py b/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/_warp_impl.py new file mode 100644 index 0000000000..d42b7d9d46 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/_warp_impl.py @@ -0,0 +1,531 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import warp as wp + +from ..uniform_grid_gradient._warp_impl.utils import ( + _launch_dim, + _normalize_spacing, + _to_wp_tensor, + _warp_launch_context, + _wp_launch, + _wrap_minus1, + _wrap_minus2, + _wrap_plus1, + _wrap_plus2, +) +from .utils import validate_vector_field + +_SUPPORTED_ORDERS = (2, 4) + + +def _validate_order(order: int) -> int: + if not isinstance(order, int): + raise TypeError(f"order must be an integer, got {type(order)}") + if order not in _SUPPORTED_ORDERS: + raise ValueError( + "uniform_grid_divergence supports central orders " + f"{list(_SUPPORTED_ORDERS)}, got order={order}" + ) + return order + + +def _validate_positive_spacing(spacing_tuple: tuple[float, ...]) -> None: + for dx in spacing_tuple: + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + + +def _to_fp32_contiguous(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.float32 and tensor.is_contiguous(): + return tensor + return tensor.to(dtype=torch.float32).contiguous() + + +def _restore_dtype(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: + if tensor.dtype == target_dtype: + return tensor + return tensor.to(dtype=target_dtype) + + +@wp.kernel +def _divergence_1d_order2_kernel( + vector_field: wp.array2d(dtype=wp.float32), + inv_dx0: float, + output: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = output.shape[0] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + output[i] = (vector_field[0, ip] - vector_field[0, im]) * (0.5 * inv_dx0) + + +@wp.kernel +def _divergence_1d_order4_kernel( + vector_field: wp.array2d(dtype=wp.float32), + inv_dx0: float, + output: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = output.shape[0] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + output[i] = ( + -vector_field[0, ip2] + + 8.0 * vector_field[0, ip1] + - 8.0 * vector_field[0, im1] + + vector_field[0, im2] + ) * (inv_dx0 / 12.0) + + +@wp.kernel +def _divergence_2d_order2_kernel( + vector_field: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + output: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = output.shape[0] + n1 = output.shape[1] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + div_x = (vector_field[0, ip, j] - vector_field[0, im, j]) * (0.5 * inv_dx0) + div_y = (vector_field[1, i, jp] - vector_field[1, i, jm]) * (0.5 * inv_dx1) + output[i, j] = div_x + div_y + + +@wp.kernel +def _divergence_2d_order4_kernel( + vector_field: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + output: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = output.shape[0] + n1 = output.shape[1] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + div_x = ( + -vector_field[0, ip2, j] + + 8.0 * vector_field[0, ip1, j] + - 8.0 * vector_field[0, im1, j] + + vector_field[0, im2, j] + ) * (inv_dx0 / 12.0) + div_y = ( + -vector_field[1, i, jp2] + + 8.0 * vector_field[1, i, jp1] + - 8.0 * vector_field[1, i, jm1] + + vector_field[1, i, jm2] + ) * (inv_dx1 / 12.0) + output[i, j] = div_x + div_y + + +@wp.kernel +def _divergence_3d_order2_kernel( + vector_field: wp.array4d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + output: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = output.shape[0] + n1 = output.shape[1] + n2 = output.shape[2] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + div_x = (vector_field[0, ip, j, k] - vector_field[0, im, j, k]) * (0.5 * inv_dx0) + div_y = (vector_field[1, i, jp, k] - vector_field[1, i, jm, k]) * (0.5 * inv_dx1) + div_z = (vector_field[2, i, j, kp] - vector_field[2, i, j, km]) * (0.5 * inv_dx2) + output[i, j, k] = div_x + div_y + div_z + + +@wp.kernel +def _divergence_3d_order4_kernel( + vector_field: wp.array4d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + output: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = output.shape[0] + n1 = output.shape[1] + n2 = output.shape[2] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + km1 = _wrap_minus1(k, n2) + kp1 = _wrap_plus1(k, n2) + km2 = _wrap_minus2(k, n2) + kp2 = _wrap_plus2(k, n2) + div_x = ( + -vector_field[0, ip2, j, k] + + 8.0 * vector_field[0, ip1, j, k] + - 8.0 * vector_field[0, im1, j, k] + + vector_field[0, im2, j, k] + ) * (inv_dx0 / 12.0) + div_y = ( + -vector_field[1, i, jp2, k] + + 8.0 * vector_field[1, i, jp1, k] + - 8.0 * vector_field[1, i, jm1, k] + + vector_field[1, i, jm2, k] + ) * (inv_dx1 / 12.0) + div_z = ( + -vector_field[2, i, j, kp2] + + 8.0 * vector_field[2, i, j, kp1] + - 8.0 * vector_field[2, i, j, km1] + + vector_field[2, i, j, km2] + ) * (inv_dx2 / 12.0) + output[i, j, k] = div_x + div_y + div_z + + +@wp.kernel +def _divergence_backward_1d_order2_kernel( + grad_output: wp.array(dtype=wp.float32), + inv_dx0: float, + grad_vector: wp.array2d(dtype=wp.float32), +): + i = wp.tid() + n0 = grad_output.shape[0] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + grad_vector[0, i] = (grad_output[im] - grad_output[ip]) * (0.5 * inv_dx0) + + +@wp.kernel +def _divergence_backward_1d_order4_kernel( + grad_output: wp.array(dtype=wp.float32), + inv_dx0: float, + grad_vector: wp.array2d(dtype=wp.float32), +): + i = wp.tid() + n0 = grad_output.shape[0] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + grad_vector[0, i] = ( + grad_output[ip2] + - 8.0 * grad_output[ip1] + + 8.0 * grad_output[im1] + - grad_output[im2] + ) * (inv_dx0 / 12.0) + + +@wp.kernel +def _divergence_backward_2d_order2_kernel( + grad_output: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + grad_vector: wp.array3d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad_output.shape[0] + n1 = grad_output.shape[1] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + grad_vector[0, i, j] = (grad_output[im, j] - grad_output[ip, j]) * (0.5 * inv_dx0) + grad_vector[1, i, j] = (grad_output[i, jm] - grad_output[i, jp]) * (0.5 * inv_dx1) + + +@wp.kernel +def _divergence_backward_2d_order4_kernel( + grad_output: wp.array2d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + grad_vector: wp.array3d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = grad_output.shape[0] + n1 = grad_output.shape[1] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + grad_vector[0, i, j] = ( + grad_output[ip2, j] + - 8.0 * grad_output[ip1, j] + + 8.0 * grad_output[im1, j] + - grad_output[im2, j] + ) * (inv_dx0 / 12.0) + grad_vector[1, i, j] = ( + grad_output[i, jp2] + - 8.0 * grad_output[i, jp1] + + 8.0 * grad_output[i, jm1] + - grad_output[i, jm2] + ) * (inv_dx1 / 12.0) + + +@wp.kernel +def _divergence_backward_3d_order2_kernel( + grad_output: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + grad_vector: wp.array4d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad_output.shape[0] + n1 = grad_output.shape[1] + n2 = grad_output.shape[2] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + grad_vector[0, i, j, k] = (grad_output[im, j, k] - grad_output[ip, j, k]) * ( + 0.5 * inv_dx0 + ) + grad_vector[1, i, j, k] = (grad_output[i, jm, k] - grad_output[i, jp, k]) * ( + 0.5 * inv_dx1 + ) + grad_vector[2, i, j, k] = (grad_output[i, j, km] - grad_output[i, j, kp]) * ( + 0.5 * inv_dx2 + ) + + +@wp.kernel +def _divergence_backward_3d_order4_kernel( + grad_output: wp.array3d(dtype=wp.float32), + inv_dx0: float, + inv_dx1: float, + inv_dx2: float, + grad_vector: wp.array4d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = grad_output.shape[0] + n1 = grad_output.shape[1] + n2 = grad_output.shape[2] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + km1 = _wrap_minus1(k, n2) + kp1 = _wrap_plus1(k, n2) + km2 = _wrap_minus2(k, n2) + kp2 = _wrap_plus2(k, n2) + grad_vector[0, i, j, k] = ( + grad_output[ip2, j, k] + - 8.0 * grad_output[ip1, j, k] + + 8.0 * grad_output[im1, j, k] + - grad_output[im2, j, k] + ) * (inv_dx0 / 12.0) + grad_vector[1, i, j, k] = ( + grad_output[i, jp2, k] + - 8.0 * grad_output[i, jp1, k] + + 8.0 * grad_output[i, jm1, k] + - grad_output[i, jm2, k] + ) * (inv_dx1 / 12.0) + grad_vector[2, i, j, k] = ( + grad_output[i, j, kp2] + - 8.0 * grad_output[i, j, kp1] + + 8.0 * grad_output[i, j, km1] + - grad_output[i, j, km2] + ) * (inv_dx2 / 12.0) + + +_FORWARD_KERNELS = { + (1, 2): _divergence_1d_order2_kernel, + (1, 4): _divergence_1d_order4_kernel, + (2, 2): _divergence_2d_order2_kernel, + (2, 4): _divergence_2d_order4_kernel, + (3, 2): _divergence_3d_order2_kernel, + (3, 4): _divergence_3d_order4_kernel, +} +_BACKWARD_KERNELS = { + (1, 2): _divergence_backward_1d_order2_kernel, + (1, 4): _divergence_backward_1d_order4_kernel, + (2, 2): _divergence_backward_2d_order2_kernel, + (2, 4): _divergence_backward_2d_order4_kernel, + (3, 2): _divergence_backward_3d_order2_kernel, + (3, 4): _divergence_backward_3d_order4_kernel, +} + + +def _launch_divergence_forward( + *, + vector_field_fp32: torch.Tensor, + spacing_tuple: tuple[float, ...], + order: int, + output_fp32: torch.Tensor, +) -> None: + wp_device, wp_stream = _warp_launch_context(vector_field_fp32) + _wp_launch( + kernel=_FORWARD_KERNELS[(vector_field_fp32.ndim - 1, order)], + dim=_launch_dim(output_fp32.shape), + inputs=[ + _to_wp_tensor(vector_field_fp32), + *[1.0 / float(dx) for dx in spacing_tuple], + _to_wp_tensor(output_fp32), + ], + device=wp_device, + stream=wp_stream, + ) + + +def _launch_divergence_backward( + *, + grad_output_fp32: torch.Tensor, + spacing_tuple: tuple[float, ...], + order: int, + grad_vector_fp32: torch.Tensor, +) -> None: + wp_device, wp_stream = _warp_launch_context(grad_output_fp32) + _wp_launch( + kernel=_BACKWARD_KERNELS[(grad_output_fp32.ndim, order)], + dim=_launch_dim(grad_output_fp32.shape), + inputs=[ + _to_wp_tensor(grad_output_fp32), + *[1.0 / float(dx) for dx in spacing_tuple], + _to_wp_tensor(grad_vector_fp32), + ], + device=wp_device, + stream=wp_stream, + ) + + +@torch.library.custom_op( + "physicsnemo::uniform_grid_divergence_warp_impl", mutates_args=() +) +def uniform_grid_divergence_impl( + vector_field: torch.Tensor, + spacing_meta: torch.Tensor, + order: int, +) -> torch.Tensor: + """Evaluate uniform-grid divergence with fused Warp kernels.""" + grid_ndim = validate_vector_field(vector_field) + spacing_tuple = tuple(float(v) for v in spacing_meta.tolist()) + _validate_positive_spacing(spacing_tuple) + order = _validate_order(int(order)) + orig_dtype = vector_field.dtype + vector_field_fp32 = _to_fp32_contiguous(vector_field) + output_fp32 = torch.empty( + vector_field_fp32.shape[1:], + device=vector_field_fp32.device, + dtype=torch.float32, + ) + _launch_divergence_forward( + vector_field_fp32=vector_field_fp32, + spacing_tuple=spacing_tuple[:grid_ndim], + order=order, + output_fp32=output_fp32, + ) + return _restore_dtype(output_fp32, orig_dtype) + + +@uniform_grid_divergence_impl.register_fake +def _uniform_grid_divergence_impl_fake( + vector_field: torch.Tensor, + spacing_meta: torch.Tensor, + order: int, +) -> torch.Tensor: + _ = (spacing_meta, order) + return torch.empty( + vector_field.shape[1:], + device=vector_field.device, + dtype=vector_field.dtype, + ) + + +def setup_uniform_grid_divergence_context( + ctx: torch.autograd.function.FunctionCtx, + inputs: tuple, + output: torch.Tensor, +) -> None: + """Save uniform-grid divergence metadata for the backward pass.""" + vector_field, spacing_meta, order = inputs + _ = output + ctx.spacing_tuple = tuple(float(v) for v in spacing_meta.tolist()) + ctx.order = int(order) + ctx.orig_dtype = vector_field.dtype + + +def backward_uniform_grid_divergence( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None]: + if grad_output is None or not ctx.needs_input_grad[0]: + return None, None, None + grad_output_fp32 = _to_fp32_contiguous(grad_output) + grad_vector_fp32 = torch.empty( + (grad_output_fp32.ndim, *grad_output_fp32.shape), + device=grad_output_fp32.device, + dtype=torch.float32, + ) + _launch_divergence_backward( + grad_output_fp32=grad_output_fp32, + spacing_tuple=ctx.spacing_tuple[: grad_output_fp32.ndim], + order=ctx.order, + grad_vector_fp32=grad_vector_fp32, + ) + return _restore_dtype(grad_vector_fp32, ctx.orig_dtype), None, None + + +uniform_grid_divergence_impl.register_autograd( + backward_uniform_grid_divergence, + setup_context=setup_uniform_grid_divergence_context, +) + + +def uniform_grid_divergence_warp( + vector_field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, +) -> torch.Tensor: + """Compute periodic uniform-grid divergence with a fused Warp custom op.""" + grid_ndim = vector_field.ndim - 1 + spacing_tuple = _normalize_spacing(spacing, grid_ndim) + spacing_meta = torch.tensor(spacing_tuple, dtype=torch.float32, device="cpu") + return uniform_grid_divergence_impl(vector_field, spacing_meta, int(order)) diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/uniform_grid_divergence.py b/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/uniform_grid_divergence.py new file mode 100644 index 0000000000..08c21f21cd --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/uniform_grid_divergence.py @@ -0,0 +1,184 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +from physicsnemo.core.function_spec import FunctionSpec + +from ._torch_impl import uniform_grid_divergence_torch +from ._warp_impl import uniform_grid_divergence_warp + + +class UniformGridDivergence(FunctionSpec): + r"""Compute periodic divergence on a uniform grid. + + This functional accepts channel-first vector fields with shape + ``(dim, *grid_shape)`` where ``dim`` matches the 1D/2D/3D grid + dimensionality. Divergence is computed as the trace of the Jacobian, + + .. math:: + + \nabla \cdot u = \sum_i \partial_i u_i. + + Parameters + ---------- + vector_field : torch.Tensor + Channel-first vector field with shape ``(dim, *grid_shape)``. + spacing : float | Sequence[float], optional + Uniform spacing per grid axis. A scalar applies the same spacing to + every axis. + order : int, optional + Central-difference accuracy order. Supported values match + :func:`physicsnemo.nn.functional.uniform_grid_gradient`. + implementation : {"warp", "torch"} or None + Explicit backend selection. When ``None``, rank-based backend dispatch + is used. + + Returns + ------- + torch.Tensor + Scalar divergence field with shape ``grid_shape``. + """ + + _BENCHMARK_CASES = ( + ("1d-n8192-o2", (8192,), 0.01, 2), + ("1d-n8192-o4", (8192,), 0.01, 4), + ("2d-512x512-o2", (512, 512), (0.01, 0.02), 2), + ("2d-384x384-o4", (384, 384), (0.01, 0.02), 4), + ("3d-96x96x96-o2", (96, 96, 96), 0.02, 2), + ) + + _BACKWARD_CASES = ( + ("1d-grad-n4096-o2", (4096,), 0.01, 2), + ("2d-grad-256x256-o2", (256, 256), (0.01, 0.02), 2), + ("2d-grad-192x192-o4", (192, 192), (0.01, 0.02), 4), + ("3d-grad-64x64x64-o2", (64, 64, 64), 0.02, 2), + ) + + _COMPARE_ATOL = 1e-5 + _COMPARE_RTOL = 1e-5 + + @FunctionSpec.register(name="warp", required_imports=("warp>=0.6.0",), rank=0) + def warp_forward( + vector_field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + ) -> torch.Tensor: + """Dispatch uniform-grid divergence to the Warp backend.""" + return uniform_grid_divergence_warp( + vector_field=vector_field, + spacing=spacing, + order=order, + ) + + @FunctionSpec.register(name="torch", rank=1, baseline=True) + def torch_forward( + vector_field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + ) -> torch.Tensor: + """Dispatch uniform-grid divergence to eager PyTorch.""" + return uniform_grid_divergence_torch( + vector_field=vector_field, + spacing=spacing, + order=order, + ) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield representative forward benchmark and parity input cases.""" + device = torch.device(device) + for label, shape, spacing, order in cls._BENCHMARK_CASES: + vector_field = _make_periodic_vector_field(shape, device=device) + yield label, (vector_field,), {"spacing": spacing, "order": order} + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield representative backward benchmark and parity input cases.""" + device = torch.device(device) + for label, shape, spacing, order in cls._BACKWARD_CASES: + vector_field = ( + _make_periodic_vector_field(shape, device=device) + .detach() + .clone() + .requires_grad_(True) + ) + yield label, (vector_field,), {"spacing": spacing, "order": order} + + @classmethod + def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare forward outputs across implementations.""" + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_ATOL, + rtol=cls._COMPARE_RTOL, + ) + + @classmethod + def compare_backward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare backward gradients across implementations.""" + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_ATOL, + rtol=cls._COMPARE_RTOL, + ) + + +def _make_periodic_vector_field( + shape: tuple[int, ...], + *, + device: torch.device, +) -> torch.Tensor: + """Construct smooth periodic vector fields for benchmark cases.""" + axes = tuple( + torch.arange(n, device=device, dtype=torch.float32) / float(n) for n in shape + ) + if len(shape) == 1: + (x0,) = axes + return torch.sin(2.0 * torch.pi * x0).unsqueeze(0) + + if len(shape) == 2: + x0, x1 = axes + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + return torch.stack( + ( + torch.sin(2.0 * torch.pi * xx), + torch.cos(2.0 * torch.pi * yy), + ), + dim=0, + ) + + x0, x1, x2 = axes + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + return torch.stack( + ( + torch.sin(2.0 * torch.pi * xx), + torch.cos(2.0 * torch.pi * yy), + 0.5 * torch.sin(2.0 * torch.pi * zz), + ), + dim=0, + ) + + +uniform_grid_divergence = UniformGridDivergence.make_function("uniform_grid_divergence") + +__all__ = ["UniformGridDivergence", "uniform_grid_divergence"] diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/utils.py b/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/utils.py new file mode 100644 index 0000000000..6d1065fbbe --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_divergence/utils.py @@ -0,0 +1,36 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def validate_vector_field(vector_field: torch.Tensor) -> int: + """Validate a channel-first vector field and return grid dimensionality.""" + if not torch.is_floating_point(vector_field): + raise TypeError("vector_field must be a floating-point tensor") + if vector_field.ndim < 2 or vector_field.ndim > 4: + raise ValueError( + "uniform_grid_divergence expects a channel-first 1D-3D vector field with " + "shape (dim, *grid_shape)" + ) + + grid_ndim = vector_field.ndim - 1 + if vector_field.shape[0] != grid_ndim: + raise ValueError( + "vector_field.shape[0] must match grid dimensionality " + f"({grid_ndim}), got {vector_field.shape[0]}" + ) + return grid_ndim diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/utils.py b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/utils.py index b92dc2ebc4..bd814b8550 100644 --- a/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/utils.py +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_gradient/_warp_impl/utils.py @@ -138,6 +138,30 @@ def _launch_dim(shape: torch.Size) -> int | tuple[int, ...]: return shape[0] if len(shape) == 1 else tuple(shape) +@wp.func +def _wrap_plus1(i: int, n: int) -> int: + """Wrap a grid index one cell forward for periodic stencils.""" + return (i + 1) % n + + +@wp.func +def _wrap_minus1(i: int, n: int) -> int: + """Wrap a grid index one cell backward for periodic stencils.""" + return (i + n - 1) % n + + +@wp.func +def _wrap_plus2(i: int, n: int) -> int: + """Wrap a grid index two cells forward for periodic stencils.""" + return (i + 2) % n + + +@wp.func +def _wrap_minus2(i: int, n: int) -> int: + """Wrap a grid index two cells backward for periodic stencils.""" + return (i + n - 2) % n + + def _inverse_spacings( spacing_tuple: tuple[float, ...], *, diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/__init__.py b/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/__init__.py new file mode 100644 index 0000000000..e909faf0ae --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/__init__.py @@ -0,0 +1,19 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .uniform_grid_laplacian import UniformGridLaplacian, uniform_grid_laplacian + +__all__ = ["UniformGridLaplacian", "uniform_grid_laplacian"] diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/_torch_impl.py b/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/_torch_impl.py new file mode 100644 index 0000000000..a86acc45f1 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/_torch_impl.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +from .utils import validate_scalar_field + +_SUPPORTED_ORDERS = (2, 4) + + +def _normalize_spacing( + spacing: float | Sequence[float], ndim: int +) -> tuple[float, ...]: + if isinstance(spacing, (float, int)): + return tuple(float(spacing) for _ in range(ndim)) + spacing_tuple = tuple(float(x) for x in spacing) + if len(spacing_tuple) != ndim: + raise ValueError( + f"spacing must have {ndim} entries for a {ndim}D field, got {len(spacing_tuple)}" + ) + return spacing_tuple + + +def _validate_order(order: int) -> int: + if not isinstance(order, int): + raise TypeError(f"order must be an integer, got {type(order)}") + if order not in _SUPPORTED_ORDERS: + raise ValueError( + "uniform_grid_laplacian supports central orders " + f"{list(_SUPPORTED_ORDERS)}, got order={order}" + ) + return order + + +def _second_derivative_order2( + field: torch.Tensor, axis: int, dx: float +) -> torch.Tensor: + return ( + torch.roll(field, shifts=-1, dims=axis) + - 2.0 * field + + torch.roll(field, shifts=1, dims=axis) + ) / (dx * dx) + + +def _second_derivative_order4( + field: torch.Tensor, axis: int, dx: float +) -> torch.Tensor: + return ( + -torch.roll(field, shifts=-2, dims=axis) + + 16.0 * torch.roll(field, shifts=-1, dims=axis) + - 30.0 * field + + 16.0 * torch.roll(field, shifts=1, dims=axis) + - torch.roll(field, shifts=2, dims=axis) + ) / (12.0 * dx * dx) + + +_DERIVATIVE_DISPATCH = { + 2: _second_derivative_order2, + 4: _second_derivative_order4, +} + + +def uniform_grid_laplacian_torch( + field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, +) -> torch.Tensor: + """Compute periodic uniform-grid Laplacian with PyTorch tensor ops.""" + validate_scalar_field(field) + spacing_tuple = _normalize_spacing(spacing, field.ndim) + for dx in spacing_tuple: + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + derivative_fn = _DERIVATIVE_DISPATCH[_validate_order(order)] + + laplacian = torch.zeros_like(field) + for axis, dx in enumerate(spacing_tuple): + laplacian = laplacian + derivative_fn(field, axis, dx) + return laplacian diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/_warp_impl.py b/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/_warp_impl.py new file mode 100644 index 0000000000..215e8ba524 --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/_warp_impl.py @@ -0,0 +1,340 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch +import warp as wp + +from ..uniform_grid_gradient._warp_impl.utils import ( + _launch_dim, + _normalize_spacing, + _to_wp_tensor, + _warp_launch_context, + _wp_launch, + _wrap_minus1, + _wrap_minus2, + _wrap_plus1, + _wrap_plus2, +) +from .utils import validate_scalar_field + +_SUPPORTED_ORDERS = (2, 4) + + +def _validate_order(order: int) -> int: + if not isinstance(order, int): + raise TypeError(f"order must be an integer, got {type(order)}") + if order not in _SUPPORTED_ORDERS: + raise ValueError( + "uniform_grid_laplacian supports central orders " + f"{list(_SUPPORTED_ORDERS)}, got order={order}" + ) + return order + + +def _validate_positive_spacing(spacing_tuple: tuple[float, ...]) -> None: + for dx in spacing_tuple: + if dx <= 0.0: + raise ValueError("all spacing entries must be strictly positive") + + +def _to_fp32_contiguous(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype == torch.float32 and tensor.is_contiguous(): + return tensor + return tensor.to(dtype=torch.float32).contiguous() + + +def _restore_dtype(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: + if tensor.dtype == target_dtype: + return tensor + return tensor.to(dtype=target_dtype) + + +@wp.kernel +def _laplacian_1d_order2_kernel( + field: wp.array(dtype=wp.float32), + inv_dx0_sq: float, + output: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = field.shape[0] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + output[i] = (field[ip] - 2.0 * field[i] + field[im]) * inv_dx0_sq + + +@wp.kernel +def _laplacian_1d_order4_kernel( + field: wp.array(dtype=wp.float32), + inv_dx0_sq: float, + output: wp.array(dtype=wp.float32), +): + i = wp.tid() + n0 = field.shape[0] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + output[i] = ( + -field[ip2] + + 16.0 * field[ip1] + - 30.0 * field[i] + + 16.0 * field[im1] + - field[im2] + ) * (inv_dx0_sq / 12.0) + + +@wp.kernel +def _laplacian_2d_order2_kernel( + field: wp.array2d(dtype=wp.float32), + inv_dx0_sq: float, + inv_dx1_sq: float, + output: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + d2x = (field[ip, j] - 2.0 * field[i, j] + field[im, j]) * inv_dx0_sq + d2y = (field[i, jp] - 2.0 * field[i, j] + field[i, jm]) * inv_dx1_sq + output[i, j] = d2x + d2y + + +@wp.kernel +def _laplacian_2d_order4_kernel( + field: wp.array2d(dtype=wp.float32), + inv_dx0_sq: float, + inv_dx1_sq: float, + output: wp.array2d(dtype=wp.float32), +): + i, j = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + d2x = ( + -field[ip2, j] + + 16.0 * field[ip1, j] + - 30.0 * field[i, j] + + 16.0 * field[im1, j] + - field[im2, j] + ) * (inv_dx0_sq / 12.0) + d2y = ( + -field[i, jp2] + + 16.0 * field[i, jp1] + - 30.0 * field[i, j] + + 16.0 * field[i, jm1] + - field[i, jm2] + ) * (inv_dx1_sq / 12.0) + output[i, j] = d2x + d2y + + +@wp.kernel +def _laplacian_3d_order2_kernel( + field: wp.array3d(dtype=wp.float32), + inv_dx0_sq: float, + inv_dx1_sq: float, + inv_dx2_sq: float, + output: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + im = _wrap_minus1(i, n0) + ip = _wrap_plus1(i, n0) + jm = _wrap_minus1(j, n1) + jp = _wrap_plus1(j, n1) + km = _wrap_minus1(k, n2) + kp = _wrap_plus1(k, n2) + d2x = (field[ip, j, k] - 2.0 * field[i, j, k] + field[im, j, k]) * inv_dx0_sq + d2y = (field[i, jp, k] - 2.0 * field[i, j, k] + field[i, jm, k]) * inv_dx1_sq + d2z = (field[i, j, kp] - 2.0 * field[i, j, k] + field[i, j, km]) * inv_dx2_sq + output[i, j, k] = d2x + d2y + d2z + + +@wp.kernel +def _laplacian_3d_order4_kernel( + field: wp.array3d(dtype=wp.float32), + inv_dx0_sq: float, + inv_dx1_sq: float, + inv_dx2_sq: float, + output: wp.array3d(dtype=wp.float32), +): + i, j, k = wp.tid() + n0 = field.shape[0] + n1 = field.shape[1] + n2 = field.shape[2] + im1 = _wrap_minus1(i, n0) + ip1 = _wrap_plus1(i, n0) + im2 = _wrap_minus2(i, n0) + ip2 = _wrap_plus2(i, n0) + jm1 = _wrap_minus1(j, n1) + jp1 = _wrap_plus1(j, n1) + jm2 = _wrap_minus2(j, n1) + jp2 = _wrap_plus2(j, n1) + km1 = _wrap_minus1(k, n2) + kp1 = _wrap_plus1(k, n2) + km2 = _wrap_minus2(k, n2) + kp2 = _wrap_plus2(k, n2) + d2x = ( + -field[ip2, j, k] + + 16.0 * field[ip1, j, k] + - 30.0 * field[i, j, k] + + 16.0 * field[im1, j, k] + - field[im2, j, k] + ) * (inv_dx0_sq / 12.0) + d2y = ( + -field[i, jp2, k] + + 16.0 * field[i, jp1, k] + - 30.0 * field[i, j, k] + + 16.0 * field[i, jm1, k] + - field[i, jm2, k] + ) * (inv_dx1_sq / 12.0) + d2z = ( + -field[i, j, kp2] + + 16.0 * field[i, j, kp1] + - 30.0 * field[i, j, k] + + 16.0 * field[i, j, km1] + - field[i, j, km2] + ) * (inv_dx2_sq / 12.0) + output[i, j, k] = d2x + d2y + d2z + + +_LAPLACIAN_KERNELS = { + (1, 2): _laplacian_1d_order2_kernel, + (1, 4): _laplacian_1d_order4_kernel, + (2, 2): _laplacian_2d_order2_kernel, + (2, 4): _laplacian_2d_order4_kernel, + (3, 2): _laplacian_3d_order2_kernel, + (3, 4): _laplacian_3d_order4_kernel, +} + + +def _launch_laplacian( + *, + field_fp32: torch.Tensor, + spacing_tuple: tuple[float, ...], + order: int, + output_fp32: torch.Tensor, +) -> None: + inv_sq = [1.0 / float(dx * dx) for dx in spacing_tuple] + wp_device, wp_stream = _warp_launch_context(field_fp32) + _wp_launch( + kernel=_LAPLACIAN_KERNELS[(field_fp32.ndim, order)], + dim=_launch_dim(field_fp32.shape), + inputs=[ + _to_wp_tensor(field_fp32), + *inv_sq, + _to_wp_tensor(output_fp32), + ], + device=wp_device, + stream=wp_stream, + ) + + +@torch.library.custom_op( + "physicsnemo::uniform_grid_laplacian_warp_impl", mutates_args=() +) +def uniform_grid_laplacian_impl( + field: torch.Tensor, + spacing_meta: torch.Tensor, + order: int, +) -> torch.Tensor: + """Evaluate uniform-grid Laplacian with fused Warp kernels.""" + validate_scalar_field(field) + spacing_tuple = tuple(float(v) for v in spacing_meta.tolist()) + _validate_positive_spacing(spacing_tuple) + order = _validate_order(int(order)) + orig_dtype = field.dtype + field_fp32 = _to_fp32_contiguous(field) + output_fp32 = torch.empty_like(field_fp32) + _launch_laplacian( + field_fp32=field_fp32, + spacing_tuple=spacing_tuple, + order=order, + output_fp32=output_fp32, + ) + return _restore_dtype(output_fp32, orig_dtype) + + +@uniform_grid_laplacian_impl.register_fake +def _uniform_grid_laplacian_impl_fake( + field: torch.Tensor, + spacing_meta: torch.Tensor, + order: int, +) -> torch.Tensor: + _ = (spacing_meta, order) + return torch.empty_like(field) + + +def setup_uniform_grid_laplacian_context( + ctx: torch.autograd.function.FunctionCtx, + inputs: tuple, + output: torch.Tensor, +) -> None: + """Save uniform-grid Laplacian metadata for the backward pass.""" + field, spacing_meta, order = inputs + _ = output + ctx.spacing_tuple = tuple(float(v) for v in spacing_meta.tolist()) + ctx.order = int(order) + ctx.orig_dtype = field.dtype + + +def backward_uniform_grid_laplacian( + ctx: torch.autograd.function.FunctionCtx, + grad_output: torch.Tensor, +) -> tuple[torch.Tensor | None, None, None]: + if grad_output is None or not ctx.needs_input_grad[0]: + return None, None, None + grad_output_fp32 = _to_fp32_contiguous(grad_output) + grad_field = torch.empty_like(grad_output_fp32) + _launch_laplacian( + field_fp32=grad_output_fp32, + spacing_tuple=ctx.spacing_tuple, + order=ctx.order, + output_fp32=grad_field, + ) + return _restore_dtype(grad_field, ctx.orig_dtype), None, None + + +uniform_grid_laplacian_impl.register_autograd( + backward_uniform_grid_laplacian, + setup_context=setup_uniform_grid_laplacian_context, +) + + +def uniform_grid_laplacian_warp( + field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, +) -> torch.Tensor: + """Compute periodic uniform-grid Laplacian with a fused Warp custom op.""" + spacing_tuple = _normalize_spacing(spacing, field.ndim) + spacing_meta = torch.tensor(spacing_tuple, dtype=torch.float32, device="cpu") + return uniform_grid_laplacian_impl(field, spacing_meta, int(order)) diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/uniform_grid_laplacian.py b/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/uniform_grid_laplacian.py new file mode 100644 index 0000000000..13c3da6eab --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/uniform_grid_laplacian.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections.abc import Sequence + +import torch + +from physicsnemo.core.function_spec import FunctionSpec + +from ._torch_impl import uniform_grid_laplacian_torch +from ._warp_impl import uniform_grid_laplacian_warp + + +class UniformGridLaplacian(FunctionSpec): + r"""Compute periodic Laplacians on a uniform grid. + + This functional accepts scalar fields defined on a 1D/2D/3D uniform + Cartesian grid and computes the trace of the Hessian, + + .. math:: + + \nabla^2 f = \sum_i \partial_{ii} f. + + Parameters + ---------- + field : torch.Tensor + Scalar grid field with shape ``(n0,)``, ``(n0,n1)``, or ``(n0,n1,n2)``. + spacing : float | Sequence[float], optional + Uniform spacing per grid axis. A scalar applies the same spacing to + every axis. + order : int, optional + Central-difference accuracy order. Supported values match + :func:`physicsnemo.nn.functional.uniform_grid_gradient`. + implementation : {"warp", "torch"} or None + Explicit backend selection. When ``None``, rank-based backend dispatch + is used. + + Returns + ------- + torch.Tensor + Scalar Laplacian field with the same shape as ``field``. + """ + + _BENCHMARK_CASES = ( + ("1d-n8192-o2", (8192,), 0.01, 2), + ("1d-n8192-o4", (8192,), 0.01, 4), + ("2d-512x512-o2", (512, 512), (0.01, 0.02), 2), + ("2d-384x384-o4", (384, 384), (0.01, 0.02), 4), + ("3d-96x96x96-o2", (96, 96, 96), 0.02, 2), + ) + + _BACKWARD_CASES = ( + ("1d-grad-n4096-o2", (4096,), 0.01, 2), + ("2d-grad-256x256-o2", (256, 256), (0.01, 0.02), 2), + ("2d-grad-192x192-o4", (192, 192), (0.01, 0.02), 4), + ("3d-grad-64x64x64-o2", (64, 64, 64), 0.02, 2), + ) + + # Fourth-order second-derivative stencils subtract nearly equal values. + # The fused Warp kernel and composed Torch reference take different + # float32 rounding paths on CUDA, so Laplacian parity needs a small + # absolute tolerance while the analytic tests keep physical accuracy + # coverage separate. + _COMPARE_ATOL = 5e-3 + _COMPARE_RTOL = 1e-5 + _COMPARE_BACKWARD_ATOL = 1e-2 + _COMPARE_BACKWARD_RTOL = 1e-5 + + @FunctionSpec.register(name="warp", required_imports=("warp>=0.6.0",), rank=0) + def warp_forward( + field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + ) -> torch.Tensor: + """Dispatch uniform-grid Laplacian to the Warp backend.""" + return uniform_grid_laplacian_warp( + field=field, + spacing=spacing, + order=order, + ) + + @FunctionSpec.register(name="torch", rank=1, baseline=True) + def torch_forward( + field: torch.Tensor, + spacing: float | Sequence[float] = 1.0, + order: int = 2, + ) -> torch.Tensor: + """Dispatch uniform-grid Laplacian to eager PyTorch.""" + return uniform_grid_laplacian_torch( + field=field, + spacing=spacing, + order=order, + ) + + @classmethod + def make_inputs_forward(cls, device: torch.device | str = "cpu"): + """Yield representative forward benchmark and parity input cases.""" + device = torch.device(device) + for label, shape, spacing, order in cls._BENCHMARK_CASES: + field = _make_periodic_scalar_field(shape, device=device) + yield label, (field,), {"spacing": spacing, "order": order} + + @classmethod + def make_inputs_backward(cls, device: torch.device | str = "cpu"): + """Yield representative backward benchmark and parity input cases.""" + device = torch.device(device) + for label, shape, spacing, order in cls._BACKWARD_CASES: + field = ( + _make_periodic_scalar_field(shape, device=device) + .detach() + .clone() + .requires_grad_(True) + ) + yield label, (field,), {"spacing": spacing, "order": order} + + @classmethod + def compare_forward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare forward outputs across implementations.""" + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_ATOL, + rtol=cls._COMPARE_RTOL, + ) + + @classmethod + def compare_backward(cls, output: torch.Tensor, reference: torch.Tensor) -> None: + """Compare backward gradients across implementations.""" + torch.testing.assert_close( + output, + reference, + atol=cls._COMPARE_BACKWARD_ATOL, + rtol=cls._COMPARE_BACKWARD_RTOL, + ) + + +def _make_periodic_scalar_field( + shape: tuple[int, ...], + *, + device: torch.device, +) -> torch.Tensor: + """Construct smooth periodic scalar fields for benchmark cases.""" + axes = tuple( + torch.arange(n, device=device, dtype=torch.float32) / float(n) for n in shape + ) + if len(shape) == 1: + (x0,) = axes + return torch.sin(2.0 * torch.pi * x0) + + if len(shape) == 2: + x0, x1 = axes + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + return torch.sin(2.0 * torch.pi * xx) + 0.5 * torch.cos(2.0 * torch.pi * yy) + + x0, x1, x2 = axes + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + return ( + torch.sin(2.0 * torch.pi * xx) + + 0.5 * torch.cos(2.0 * torch.pi * yy) + + 0.25 * torch.sin(2.0 * torch.pi * zz) + ) + + +uniform_grid_laplacian = UniformGridLaplacian.make_function("uniform_grid_laplacian") + +__all__ = ["UniformGridLaplacian", "uniform_grid_laplacian"] diff --git a/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/utils.py b/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/utils.py new file mode 100644 index 0000000000..00e9770b1e --- /dev/null +++ b/physicsnemo/nn/functional/derivatives/uniform_grid_laplacian/utils.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def validate_scalar_field(field: torch.Tensor) -> None: + """Validate a scalar 1D/2D/3D grid field.""" + if not torch.is_floating_point(field): + raise TypeError("field must be a floating-point tensor") + if field.ndim < 1 or field.ndim > 3: + raise ValueError( + "uniform_grid_laplacian expects a scalar 1D-3D grid field with shape " + "(n0,), (n0,n1), or (n0,n1,n2)" + ) diff --git a/test/nn/functional/derivatives/test_uniform_grid_curl.py b/test/nn/functional/derivatives/test_uniform_grid_curl.py new file mode 100644 index 0000000000..3b7791febd --- /dev/null +++ b/test/nn/functional/derivatives/test_uniform_grid_curl.py @@ -0,0 +1,225 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from physicsnemo.nn.functional import uniform_grid_curl +from physicsnemo.nn.functional.derivatives import UniformGridCurl +from test.conftest import requires_module +from test.nn.functional._parity_utils import clone_case + + +def _make_periodic_vector_field(device: str, dims: int): + torch_device = torch.device(device) + wave_number = 2.0 * torch.pi + + if dims == 2: + n0, n1 = 160, 144 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + x1 = torch.arange(n1, device=torch_device, dtype=torch.float32) / float(n1) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + vector_field = torch.stack( + ( + torch.sin(wave_number * yy), + torch.cos(wave_number * xx), + ), + dim=0, + ) + spacing = (1.0 / float(n0), 1.0 / float(n1)) + expected = -wave_number * torch.sin(wave_number * xx) - wave_number * torch.cos( + wave_number * yy + ) + return vector_field, spacing, expected + + n0, n1, n2 = 64, 56, 48 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + x1 = torch.arange(n1, device=torch_device, dtype=torch.float32) / float(n1) + x2 = torch.arange(n2, device=torch_device, dtype=torch.float32) / float(n2) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + vector_field = torch.stack( + ( + torch.sin(wave_number * yy), + torch.cos(wave_number * zz), + torch.sin(wave_number * xx), + ), + dim=0, + ) + spacing = (1.0 / float(n0), 1.0 / float(n1), 1.0 / float(n2)) + expected = torch.stack( + ( + wave_number * torch.sin(wave_number * zz), + -wave_number * torch.cos(wave_number * xx), + -wave_number * torch.cos(wave_number * yy), + ), + dim=0, + ) + return vector_field, spacing, expected + + +@pytest.mark.parametrize("dims", [2, 3]) +@pytest.mark.parametrize("order", [2, 4]) +def test_uniform_grid_curl_torch(device: str, dims: int, order: int): + vector_field, spacing, expected = _make_periodic_vector_field(device, dims) + output = UniformGridCurl.dispatch( + vector_field, + spacing=spacing, + order=order, + implementation="torch", + ) + torch.testing.assert_close(output, expected, atol=8e-2, rtol=8e-2) + + +def test_uniform_grid_curl_public_function(device: str): + vector_field, spacing, expected = _make_periodic_vector_field(device, dims=2) + output = uniform_grid_curl( + vector_field, + spacing=spacing, + order=2, + implementation="torch", + ) + torch.testing.assert_close(output, expected, atol=8e-2, rtol=8e-2) + + +@requires_module("warp") +def test_uniform_grid_curl_backend_forward_parity(device: str): + for _label, args, kwargs in UniformGridCurl.make_inputs_forward(device=device): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = UniformGridCurl.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_warp = UniformGridCurl.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + UniformGridCurl.compare_forward(out_warp, out_torch) + + +def test_uniform_grid_curl_compare_forward_contract(device: str): + vector_field, spacing, _expected = _make_periodic_vector_field(device, dims=2) + output = UniformGridCurl.dispatch( + vector_field, + spacing=spacing, + order=2, + implementation="torch", + ) + UniformGridCurl.compare_forward(output, output.detach().clone()) + + +@requires_module("warp") +def test_uniform_grid_curl_backend_backward_parity(device: str): + for _label, args, kwargs in UniformGridCurl.make_inputs_backward(device=device): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = UniformGridCurl.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_torch.square().mean().backward() + grad_torch = args_torch[0].grad + assert grad_torch is not None + + out_warp = UniformGridCurl.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + out_warp.square().mean().backward() + grad_warp = args_warp[0].grad + assert grad_warp is not None + + UniformGridCurl.compare_backward(grad_warp, grad_torch) + + +def test_uniform_grid_curl_compare_backward_contract(device: str): + vector_field, spacing, _expected = _make_periodic_vector_field(device, dims=2) + vector_field = vector_field.detach().clone().requires_grad_(True) + output = UniformGridCurl.dispatch( + vector_field, + spacing=spacing, + order=2, + implementation="torch", + ) + output.square().mean().backward() + assert vector_field.grad is not None + UniformGridCurl.compare_backward( + vector_field.grad, + vector_field.grad.detach().clone(), + ) + + +def test_uniform_grid_curl_error_handling(device: str): + with pytest.raises(TypeError, match="floating-point"): + UniformGridCurl.dispatch( + torch.ones((2, 8, 8), device=device, dtype=torch.int64), + implementation="torch", + ) + + with pytest.raises(ValueError, match="2D or 3D"): + UniformGridCurl.dispatch( + torch.ones((1, 8), device=device), + implementation="torch", + ) + + with pytest.raises(ValueError, match="shape\\[0\\]"): + UniformGridCurl.dispatch( + torch.ones((3, 8, 8), device=device), + implementation="torch", + ) + + +def test_uniform_grid_curl_make_inputs_forward(device: str): + label, args, kwargs = next(iter(UniformGridCurl.make_inputs_forward(device=device))) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + vector_field = args[0] + assert vector_field.ndim in (3, 4) + assert vector_field.shape[0] == vector_field.ndim - 1 + + output = UniformGridCurl.dispatch( + *args, + implementation="torch", + **kwargs, + ) + if vector_field.ndim == 3: + assert output.shape == vector_field.shape[1:] + else: + assert output.shape == vector_field.shape + + +def test_uniform_grid_curl_make_inputs_backward(device: str): + _label, args, kwargs = next( + iter(UniformGridCurl.make_inputs_backward(device=device)) + ) + vector_field = args[0] + assert vector_field.requires_grad + + output = UniformGridCurl.dispatch( + *args, + implementation="torch", + **kwargs, + ) + output.square().mean().backward() + assert vector_field.grad is not None diff --git a/test/nn/functional/derivatives/test_uniform_grid_divergence.py b/test/nn/functional/derivatives/test_uniform_grid_divergence.py new file mode 100644 index 0000000000..e4c770b983 --- /dev/null +++ b/test/nn/functional/derivatives/test_uniform_grid_divergence.py @@ -0,0 +1,227 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from physicsnemo.nn.functional import uniform_grid_divergence +from physicsnemo.nn.functional.derivatives import UniformGridDivergence +from test.conftest import requires_module +from test.nn.functional._parity_utils import clone_case + + +def _make_periodic_vector_field(device: str, dims: int): + torch_device = torch.device(device) + wave_number = 2.0 * torch.pi + + if dims == 1: + n0 = 384 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + vector_field = torch.sin(wave_number * x0).unsqueeze(0) + spacing = (1.0 / float(n0),) + expected = wave_number * torch.cos(wave_number * x0) + return vector_field, spacing, expected + + if dims == 2: + n0, n1 = 160, 144 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + x1 = torch.arange(n1, device=torch_device, dtype=torch.float32) / float(n1) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + vector_field = torch.stack( + ( + torch.sin(wave_number * xx), + torch.cos(wave_number * yy), + ), + dim=0, + ) + spacing = (1.0 / float(n0), 1.0 / float(n1)) + expected = wave_number * torch.cos(wave_number * xx) - wave_number * torch.sin( + wave_number * yy + ) + return vector_field, spacing, expected + + n0, n1, n2 = 64, 56, 48 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + x1 = torch.arange(n1, device=torch_device, dtype=torch.float32) / float(n1) + x2 = torch.arange(n2, device=torch_device, dtype=torch.float32) / float(n2) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + vector_field = torch.stack( + ( + torch.sin(wave_number * xx), + torch.cos(wave_number * yy), + 0.5 * torch.sin(wave_number * zz), + ), + dim=0, + ) + spacing = (1.0 / float(n0), 1.0 / float(n1), 1.0 / float(n2)) + expected = ( + wave_number * torch.cos(wave_number * xx) + - wave_number * torch.sin(wave_number * yy) + + 0.5 * wave_number * torch.cos(wave_number * zz) + ) + return vector_field, spacing, expected + + +@pytest.mark.parametrize("dims", [1, 2, 3]) +@pytest.mark.parametrize("order", [2, 4]) +def test_uniform_grid_divergence_torch(device: str, dims: int, order: int): + vector_field, spacing, expected = _make_periodic_vector_field(device, dims) + output = UniformGridDivergence.dispatch( + vector_field, + spacing=spacing, + order=order, + implementation="torch", + ) + torch.testing.assert_close(output, expected, atol=8e-2, rtol=8e-2) + + +def test_uniform_grid_divergence_public_function(device: str): + vector_field, spacing, expected = _make_periodic_vector_field(device, dims=2) + output = uniform_grid_divergence( + vector_field, + spacing=spacing, + order=2, + implementation="torch", + ) + torch.testing.assert_close(output, expected, atol=8e-2, rtol=8e-2) + + +@requires_module("warp") +def test_uniform_grid_divergence_backend_forward_parity(device: str): + for _label, args, kwargs in UniformGridDivergence.make_inputs_forward( + device=device + ): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = UniformGridDivergence.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_warp = UniformGridDivergence.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + UniformGridDivergence.compare_forward(out_warp, out_torch) + + +def test_uniform_grid_divergence_compare_forward_contract(device: str): + vector_field, spacing, _expected = _make_periodic_vector_field(device, dims=2) + output = UniformGridDivergence.dispatch( + vector_field, + spacing=spacing, + order=2, + implementation="torch", + ) + UniformGridDivergence.compare_forward(output, output.detach().clone()) + + +@requires_module("warp") +def test_uniform_grid_divergence_backend_backward_parity(device: str): + for _label, args, kwargs in UniformGridDivergence.make_inputs_backward( + device=device + ): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = UniformGridDivergence.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_torch.square().mean().backward() + grad_torch = args_torch[0].grad + assert grad_torch is not None + + out_warp = UniformGridDivergence.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + out_warp.square().mean().backward() + grad_warp = args_warp[0].grad + assert grad_warp is not None + + UniformGridDivergence.compare_backward(grad_warp, grad_torch) + + +def test_uniform_grid_divergence_compare_backward_contract(device: str): + vector_field, spacing, _expected = _make_periodic_vector_field(device, dims=2) + vector_field = vector_field.detach().clone().requires_grad_(True) + output = UniformGridDivergence.dispatch( + vector_field, + spacing=spacing, + order=2, + implementation="torch", + ) + output.square().mean().backward() + assert vector_field.grad is not None + UniformGridDivergence.compare_backward( + vector_field.grad, + vector_field.grad.detach().clone(), + ) + + +def test_uniform_grid_divergence_error_handling(device: str): + with pytest.raises(TypeError, match="floating-point"): + UniformGridDivergence.dispatch( + torch.ones((2, 8, 8), device=device, dtype=torch.int64), + implementation="torch", + ) + + with pytest.raises(ValueError, match="shape\\[0\\]"): + UniformGridDivergence.dispatch( + torch.ones((3, 8, 8), device=device), + implementation="torch", + ) + + +def test_uniform_grid_divergence_make_inputs_forward(device: str): + label, args, kwargs = next( + iter(UniformGridDivergence.make_inputs_forward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + vector_field = args[0] + assert vector_field.ndim in (2, 3, 4) + assert vector_field.shape[0] == vector_field.ndim - 1 + + output = UniformGridDivergence.dispatch( + *args, + implementation="torch", + **kwargs, + ) + assert output.shape == vector_field.shape[1:] + + +def test_uniform_grid_divergence_make_inputs_backward(device: str): + _label, args, kwargs = next( + iter(UniformGridDivergence.make_inputs_backward(device=device)) + ) + vector_field = args[0] + assert vector_field.requires_grad + + output = UniformGridDivergence.dispatch( + *args, + implementation="torch", + **kwargs, + ) + output.square().mean().backward() + assert vector_field.grad is not None diff --git a/test/nn/functional/derivatives/test_uniform_grid_laplacian.py b/test/nn/functional/derivatives/test_uniform_grid_laplacian.py new file mode 100644 index 0000000000..5a3d3fa75b --- /dev/null +++ b/test/nn/functional/derivatives/test_uniform_grid_laplacian.py @@ -0,0 +1,211 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from physicsnemo.nn.functional import uniform_grid_laplacian +from physicsnemo.nn.functional.derivatives import UniformGridLaplacian +from test.conftest import requires_module +from test.nn.functional._parity_utils import clone_case + + +def _make_periodic_scalar_field(device: str, dims: int): + torch_device = torch.device(device) + wave_number = 2.0 * torch.pi + + if dims == 1: + n0 = 384 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + field = torch.sin(wave_number * x0) + spacing = (1.0 / float(n0),) + expected = -(wave_number**2) * torch.sin(wave_number * x0) + return field, spacing, expected + + if dims == 2: + n0, n1 = 128, 112 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + x1 = torch.arange(n1, device=torch_device, dtype=torch.float32) / float(n1) + xx, yy = torch.meshgrid(x0, x1, indexing="ij") + field = torch.sin(wave_number * xx) + 0.5 * torch.cos(wave_number * yy) + spacing = (1.0 / float(n0), 1.0 / float(n1)) + expected = -(wave_number**2) * ( + torch.sin(wave_number * xx) + 0.5 * torch.cos(wave_number * yy) + ) + return field, spacing, expected + + n0, n1, n2 = 56, 48, 40 + x0 = torch.arange(n0, device=torch_device, dtype=torch.float32) / float(n0) + x1 = torch.arange(n1, device=torch_device, dtype=torch.float32) / float(n1) + x2 = torch.arange(n2, device=torch_device, dtype=torch.float32) / float(n2) + xx, yy, zz = torch.meshgrid(x0, x1, x2, indexing="ij") + field = ( + torch.sin(wave_number * xx) + + 0.5 * torch.cos(wave_number * yy) + + 0.25 * torch.sin(wave_number * zz) + ) + spacing = (1.0 / float(n0), 1.0 / float(n1), 1.0 / float(n2)) + expected = -(wave_number**2) * field + return field, spacing, expected + + +@pytest.mark.parametrize("dims", [1, 2, 3]) +@pytest.mark.parametrize("order", [2, 4]) +def test_uniform_grid_laplacian_torch(device: str, dims: int, order: int): + field, spacing, expected = _make_periodic_scalar_field(device, dims) + output = UniformGridLaplacian.dispatch( + field, + spacing=spacing, + order=order, + implementation="torch", + ) + torch.testing.assert_close(output, expected, atol=2e-1, rtol=8e-2) + + +def test_uniform_grid_laplacian_public_function(device: str): + field, spacing, expected = _make_periodic_scalar_field(device, dims=2) + output = uniform_grid_laplacian( + field, + spacing=spacing, + order=2, + implementation="torch", + ) + torch.testing.assert_close(output, expected, atol=2e-1, rtol=8e-2) + + +@requires_module("warp") +def test_uniform_grid_laplacian_backend_forward_parity(device: str): + for _label, args, kwargs in UniformGridLaplacian.make_inputs_forward(device=device): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = UniformGridLaplacian.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_warp = UniformGridLaplacian.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + UniformGridLaplacian.compare_forward(out_warp, out_torch) + + +def test_uniform_grid_laplacian_compare_forward_contract(device: str): + field, spacing, _expected = _make_periodic_scalar_field(device, dims=2) + output = UniformGridLaplacian.dispatch( + field, + spacing=spacing, + order=2, + implementation="torch", + ) + UniformGridLaplacian.compare_forward(output, output.detach().clone()) + + +@requires_module("warp") +def test_uniform_grid_laplacian_backend_backward_parity(device: str): + for _label, args, kwargs in UniformGridLaplacian.make_inputs_backward( + device=device + ): + args_torch, kwargs_torch = clone_case(args, kwargs) + args_warp, kwargs_warp = clone_case(args, kwargs) + + out_torch = UniformGridLaplacian.dispatch( + *args_torch, + implementation="torch", + **kwargs_torch, + ) + out_torch.square().mean().backward() + grad_torch = args_torch[0].grad + assert grad_torch is not None + + out_warp = UniformGridLaplacian.dispatch( + *args_warp, + implementation="warp", + **kwargs_warp, + ) + out_warp.square().mean().backward() + grad_warp = args_warp[0].grad + assert grad_warp is not None + + UniformGridLaplacian.compare_backward(grad_warp, grad_torch) + + +def test_uniform_grid_laplacian_compare_backward_contract(device: str): + field, spacing, _expected = _make_periodic_scalar_field(device, dims=2) + field = field.detach().clone().requires_grad_(True) + output = UniformGridLaplacian.dispatch( + field, + spacing=spacing, + order=2, + implementation="torch", + ) + output.square().mean().backward() + assert field.grad is not None + UniformGridLaplacian.compare_backward( + field.grad, + field.grad.detach().clone(), + ) + + +def test_uniform_grid_laplacian_error_handling(device: str): + with pytest.raises(TypeError, match="floating-point"): + UniformGridLaplacian.dispatch( + torch.ones((8, 8), device=device, dtype=torch.int64), + implementation="torch", + ) + + with pytest.raises(ValueError, match="1D-3D"): + UniformGridLaplacian.dispatch( + torch.ones((2, 2, 2, 2), device=device), + implementation="torch", + ) + + +def test_uniform_grid_laplacian_make_inputs_forward(device: str): + label, args, kwargs = next( + iter(UniformGridLaplacian.make_inputs_forward(device=device)) + ) + assert isinstance(label, str) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + + field = args[0] + assert field.ndim in (1, 2, 3) + + output = UniformGridLaplacian.dispatch( + *args, + implementation="torch", + **kwargs, + ) + assert output.shape == field.shape + + +def test_uniform_grid_laplacian_make_inputs_backward(device: str): + _label, args, kwargs = next( + iter(UniformGridLaplacian.make_inputs_backward(device=device)) + ) + field = args[0] + assert field.requires_grad + + output = UniformGridLaplacian.dispatch( + *args, + implementation="torch", + **kwargs, + ) + output.square().mean().backward() + assert field.grad is not None