diff --git a/python/metatomic_ase/tests/symmetrized.py b/python/metatomic_ase/tests/symmetrized.py index 79aac2eca..90239a236 100644 --- a/python/metatomic_ase/tests/symmetrized.py +++ b/python/metatomic_ase/tests/symmetrized.py @@ -1,4 +1,6 @@ +from pathlib import Path from typing import Dict, List, Optional, Tuple, Union +import warnings import numpy as np import pytest @@ -26,6 +28,13 @@ ) +REAL_CHECKPOINT = ( + Path(__file__).resolve().parents[3] + / "SYMMOD_EXAMPLE" + / "pet-mad-xs-v1.5.0.ckpt" +) + + def _body_axis_from_system(system: System) -> torch.Tensor: """ Return the normalized vector connecting the two farthest atoms. @@ -593,3 +602,142 @@ def test_space_group_average_non_periodic(): # Forces must be unchanged F_pg = out["forces"] assert np.allclose(F_pg, forces) + + +@pytest.mark.skipif( + not REAL_CHECKPOINT.is_file(), + reason="requires local SYMMOD_EXAMPLE checkpoint", +) +def test_real_checkpoint_matches_symmetrized_model_with_aligned_quadrature(capfd): + pytest.importorskip("metatrain.utils.io") + pytest.importorskip("metatrain.utils.neighbor_lists") + + from metatrain.utils.io import load_model + from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists + from metatomic.torch import systems_to_torch + from metatomic.torch.symmetrized_model import SymmetrizedModel + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=r"the 'features' output name is deprecated.*", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"the 'non_conservative_forces' output name is deprecated.*", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"`per_atom` is deprecated, please use `sample_kind` instead.*", + category=DeprecationWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"Lebedev order may be insufficient for character projections\.", + category=UserWarning, + ) + + model = load_model(REAL_CHECKPOINT) + model.eval() + model = model.to(dtype=torch.float64, device="cpu") + exported = model.export() + + atoms = bulk("Si", cubic=True) + atoms.rattle(0.1, seed=0) + + symm_model = SymmetrizedModel( + model, + max_o3_lambda_grid=3, + max_o3_lambda_target=2, + max_o3_lambda_character=2, + batch_size=1, + ).to(device="cpu", dtype=torch.float64) + outputs = { + "energy": ModelOutput(sample_kind="system"), + "non_conservative_forces": ModelOutput(sample_kind="atom"), + "non_conservative_stress": ModelOutput(sample_kind="system"), + } + systems = systems_to_torch([atoms], device="cpu", dtype=torch.float64) + systems = [ + get_system_with_neighbor_lists( + system, model.model.requested_neighbor_lists() + ) + for system in systems + ] + low_level = symm_model(systems, outputs) + + base_calculator = MetatomicCalculator( + exported, + non_conservative=True, + do_gradients_with_energy=False, + ) + ase_calculator = SymmetrizedCalculator( + base_calculator, + l_max=3, + batch_size=1, + include_inversion=True, + store_rotational_std=True, + ) + ase_calculator.quadrature_rotations = np.concatenate( + [symm_model.so3_rotations.numpy(), (-symm_model.so3_rotations).numpy()], + axis=0, + ) + ase_calculator.quadrature_weights = np.concatenate( + [ + 0.5 * symm_model.so3_weights.numpy(), + 0.5 * symm_model.so3_weights.numpy(), + ], + axis=0, + ) + ase_calculator.batch_size = 1 + + atoms.calc = ase_calculator + ase_energy = atoms.get_potential_energy() + ase_forces = atoms.get_forces() + ase_stress = atoms.get_stress(voigt=False) + + captured = capfd.readouterr() + assert captured.out == "" + allowed_stderr_fragments = [ + "`per_atom` is deprecated, please use `sample_kind` instead", + "output 'energy' has an empty unit. Consider adding a unit to ensure correct unit conversion.", + "ModelOutput.quantity is deprecated and will be removed in a future version", + "the 'features' quantity is deprecated, please update this code to use 'feature' instead.", + ] + unexpected_stderr = [ + line + for line in captured.err.splitlines() + if line != "" + and not any(fragment in line for fragment in allowed_stderr_fragments) + ] + assert unexpected_stderr == [] + + low_energy = low_level["energy_l0_mean"].block().values.squeeze().item() + low_forces = ( + low_level["non_conservative_forces_l1_mean"] + .block() + .values.roll(1, 1) + .squeeze(-1) + .numpy() + ) + low_forces = low_forces - low_forces.mean(axis=0, keepdims=True) + + l0 = low_level["non_conservative_stress_l0_mean"].block().values.squeeze().item() + l2 = low_level["non_conservative_stress_l2_mean"].block().values.squeeze().numpy() + low_stress = np.zeros((3, 3), dtype=np.float64) + low_stress[0, 1] = l2[0] + low_stress[1, 0] = l2[0] + low_stress[1, 2] = l2[1] + low_stress[2, 1] = l2[1] + low_stress[0, 2] = l2[3] + low_stress[2, 0] = l2[3] + low_stress[0, 0] = l2[4] - l2[2] * np.sqrt(3.0) / 3.0 + low_stress[1, 1] = -l2[4] - l2[2] * np.sqrt(3.0) / 3.0 + low_stress[2, 2] = l2[2] * 2.0 * np.sqrt(3.0) / 3.0 + low_stress += np.eye(3) * (l0 / 3.0) + + assert np.isclose(ase_energy, low_energy, atol=1e-12) + assert np.allclose(ase_forces, low_forces, atol=1e-12) + assert np.allclose(ase_stress, low_stress, atol=1e-12) diff --git a/python/metatomic_torch/metatomic/torch/__init__.py b/python/metatomic_torch/metatomic/torch/__init__.py index a8bf363aa..f41c4e5ee 100644 --- a/python/metatomic_torch/metatomic/torch/__init__.py +++ b/python/metatomic_torch/metatomic/torch/__init__.py @@ -65,3 +65,15 @@ save_buffer, ) from .systems_to_torch import systems_to_torch # noqa: F401 + + +def __getattr__(name): + # lazy import for ase_calculator, making it accessible as + # ``metatomic.torch.ase_calculator`` without requiring a separate import from + # ``metatomic.torch``, but only importing the code when actually required. + if name == "ase_calculator": + import metatomic.torch.ase_calculator + + return metatomic.torch.ase_calculator + else: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/python/metatomic_torch/metatomic/torch/_augmentation.py b/python/metatomic_torch/metatomic/torch/_augmentation.py new file mode 100644 index 000000000..bdb9a2fce --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/_augmentation.py @@ -0,0 +1,313 @@ +from typing import Dict, List, Optional, Tuple + +import metatensor.torch as mts +import torch +from metatensor.torch import TensorBlock, TensorMap + +from . import System, register_autograd_neighbors + + +def _block_row_indices_by_system( + block: TensorBlock, + n_systems: int, +) -> List[torch.Tensor]: + if "system" not in block.samples.names: + if n_systems == 1: + return [torch.arange(block.values.shape[0], device=block.values.device)] + raise ValueError( + "Rotational augmentation expects output samples to include a 'system' " + "dimension when transforming multiple systems." + ) + + system_ids = block.samples.column("system").to(dtype=torch.long) + if len(system_ids) != 0: + min_system_id = int(torch.min(system_ids).item()) + max_system_id = int(torch.max(system_ids).item()) + if min_system_id < 0 or max_system_id >= n_systems: + raise ValueError( + "Encountered output samples with out-of-range system indices." + ) + + return [ + torch.nonzero(system_ids == system_index, as_tuple=False).reshape(-1) + for system_index in range(n_systems) + ] + + +def _apply_wigner_D_matrices( + systems: List[System], + target_tmap: TensorMap, + transformations: List[torch.Tensor], + wigner_D_matrices: Dict[int, List[torch.Tensor]], +) -> TensorMap: + new_blocks: List[TensorBlock] = [] + for key, block in target_tmap.items(): + values = block.values + row_indices = _block_row_indices_by_system(block, len(systems)) + + new_values = values.clone() + rank = len(block.components) + if rank == 1: + ell, sigma = int(key["o3_lambda"]), int(key["o3_sigma"]) + for rows, transformation, wigner_D_matrix in zip( + row_indices, transformations, wigner_D_matrices[ell], strict=True + ): + is_inverted = torch.det(transformation) < 0 + new_v = values[rows].clone() + if is_inverted: + new_v = new_v * (-1) ** ell * sigma + new_v = new_v.transpose(1, 2) + new_v = new_v @ wigner_D_matrix.T + new_v = new_v.transpose(1, 2) + new_values[rows] = new_v + elif rank == 2: + ell1, ell2, sigma1, sigma2 = ( + int(key["o3_lambda_1"]), + int(key["o3_lambda_2"]), + int(key["o3_sigma_1"]), + int(key["o3_sigma_2"]), + ) + for rows, transformation, wigner_D_matrix1, wigner_D_matrix2 in zip( + row_indices, + transformations, + wigner_D_matrices[ell1], + wigner_D_matrices[ell2], + strict=True, + ): + is_inverted = torch.det(transformation) < 0 + new_v = values[rows].clone() + if is_inverted: + new_v = new_v * (-1) ** ell1 * sigma1 * (-1) ** ell2 * sigma2 + new_v = torch.einsum( + "Aa,iabp,bB->iABp", wigner_D_matrix1, new_v, wigner_D_matrix2.T + ) + new_values[rows] = new_v + else: + raise ValueError( + f"Unsupported spherical tensor rank {rank} in augmentation helper." + ) + new_blocks.append( + TensorBlock( + values=new_values, + samples=block.samples, + components=block.components, + properties=block.properties, + ) + ) + + return TensorMap(keys=target_tmap.keys, blocks=new_blocks) + + +def _apply_augmentations( + systems: List[System], + targets: Dict[str, TensorMap], + transformations: List[torch.Tensor], + wigner_D_matrices: Dict[int, List[torch.Tensor]], + extra_data: Optional[Dict[str, TensorMap]] = None, +) -> Tuple[List[System], Dict[str, TensorMap], Dict[str, TensorMap]]: + new_systems: List[System] = [] + for system, transformation in zip(systems, transformations, strict=True): + new_system = System( + positions=system.positions @ transformation.T, + types=system.types, + cell=system.cell @ transformation.T, + pbc=system.pbc, + ) + for data_name in system.known_data(): + data = system.get_data(data_name) + if len(data) != 1: + raise ValueError( + f"System data '{data_name}' has {len(data)} blocks, which is not " + "supported. Only scalar and vector data are supported." + ) + if len(data.block().components) == 0: + new_system.add_data(data_name, data) + elif len(data.block().components) == 1 and data.block().components[ + 0 + ].names == ["xyz"]: + new_system.add_data( + data_name, + TensorMap( + keys=data.keys, + blocks=[ + TensorBlock( + values=( + data.block().values.swapaxes(-1, -2) + @ transformation.T + ).swapaxes(-1, -2), + samples=data.block().samples, + components=data.block().components, + properties=data.block().properties, + ) + ], + ), + ) + else: + raise ValueError( + f"System data '{data_name}' has components {data.block().components}, " + "which are not supported. Only scalar and vector data are supported." + ) + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + neighbors.values[:] = ( + neighbors.values.squeeze(-1) @ transformation.T + ).unsqueeze(-1) + register_autograd_neighbors(new_system, neighbors) + new_system.add_neighbor_list(options, neighbors) + new_systems.append(new_system) + + new_targets: Dict[str, TensorMap] = {} + new_extra_data: Dict[str, TensorMap] = {} + + # Build a non-mask view of extra_data without mutating the caller's dict; + # mask entries are passed through unchanged. + remaining_extra_data: Optional[Dict[str, TensorMap]] = None + if extra_data is not None: + remaining_extra_data = {} + for key, value in extra_data.items(): + if key.endswith("_mask"): + new_extra_data[key] = value + else: + remaining_extra_data[key] = value + + for tensormap_dict, new_dict in zip( + [targets, remaining_extra_data], [new_targets, new_extra_data], strict=True + ): + if tensormap_dict is None: + continue + for name, original_tmap in tensormap_dict.items(): + is_scalar = False + if ( + len(original_tmap.blocks()) == 1 + and len(original_tmap.block().components) == 0 + ): + is_scalar = True + + cartesian_component_names = {"xyz", "xyz_1", "xyz_2"} + is_cartesian = len(original_tmap.blocks()) > 0 and all( + len(block.components) > 0 + and block.components[0].names[0] in cartesian_component_names + for block in original_tmap.blocks() + ) + + is_spherical = len(original_tmap.blocks()) > 0 and ( + all( + len(block.components) == 1 + and block.components[0].names == ["o3_mu"] + for block in original_tmap.blocks() + ) + or all( + len(block.components) == 2 + and block.components[0].names == ["o3_mu_1"] + and block.components[1].names == ["o3_mu_2"] + for block in original_tmap.blocks() + ) + ) + + if is_scalar: + energy_block = TensorBlock( + values=original_tmap.block().values, + samples=original_tmap.block().samples, + components=original_tmap.block().components, + properties=original_tmap.block().properties, + ) + if original_tmap.block().has_gradient("positions"): + block = original_tmap.block().gradient("positions") + position_gradients = block.values.squeeze(-1) + split_sizes = [system.positions.shape[0] for system in systems] + split_position_gradients = torch.split( + position_gradients, split_sizes + ) + position_gradients = torch.cat( + [ + split_position_gradients[i] @ transformations[i].T + for i in range(len(systems)) + ] + ) + energy_block.add_gradient( + "positions", + TensorBlock( + values=position_gradients.unsqueeze(-1), + samples=block.samples, + components=block.components, + properties=block.properties, + ), + ) + if original_tmap.block().has_gradient("strain"): + block = original_tmap.block().gradient("strain") + strain_values = block.values.squeeze(-1) # (n_rows, 3, 3) + row_indices = _block_row_indices_by_system(block, len(systems)) + new_strain_values = strain_values.clone() + for i, rows in enumerate(row_indices): + new_strain_values[rows] = torch.einsum( + "Aa,iab,bB->iAB", + transformations[i], + strain_values[rows], + transformations[i].T, + ) + energy_block.add_gradient( + "strain", + TensorBlock( + values=new_strain_values.unsqueeze(-1), + samples=block.samples, + components=block.components, + properties=block.properties, + ), + ) + new_dict[name] = TensorMap( + keys=original_tmap.keys, blocks=[energy_block] + ) + + elif is_spherical: + new_dict[name] = _apply_wigner_D_matrices( + systems, original_tmap, transformations, wigner_D_matrices + ) + + elif is_cartesian: + new_blocks = [] + for block in original_tmap.blocks(): + rank = len(block.components) + row_indices = _block_row_indices_by_system(block, len(systems)) + if rank == 1: + new_values = block.values.clone() + for rows, transformation in zip( + row_indices, transformations, strict=True + ): + v = block.values[rows].clone() + new_v = v.transpose(1, 2) + new_v = new_v @ transformation.T + new_v = new_v.transpose(1, 2) + new_values[rows] = new_v + elif rank == 2: + new_values = block.values.clone() + for rows, transformation in zip( + row_indices, transformations, strict=True + ): + tensor_i = block.values[rows].clone() + new_values[rows] = torch.einsum( + "Aa,iabp,bB->iABp", + transformation, + tensor_i, + transformation.T, + ) + else: + raise ValueError( + f"Unsupported Cartesian tensor rank {rank} in augmentation helper." + ) + new_blocks.append( + TensorBlock( + values=new_values, + samples=block.samples, + components=block.components, + properties=block.properties, + ) + ) + new_dict[name] = TensorMap( + keys=original_tmap.keys, blocks=new_blocks + ) + else: + raise ValueError( + f"TensorMap '{name}' is neither scalar, Cartesian, nor spherical in the supported format." + ) + + return new_systems, new_targets, new_extra_data diff --git a/python/metatomic_torch/metatomic/torch/_jit_compat.py b/python/metatomic_torch/metatomic/torch/_jit_compat.py new file mode 100644 index 000000000..d6466203f --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/_jit_compat.py @@ -0,0 +1,13 @@ +import functools + + +def _identity_decorator(func): + return func + + +try: + import numba as _numba +except ImportError: # pragma: no cover + jit = _identity_decorator +else: + jit = functools.partial(_numba.njit, cache=True) diff --git a/python/metatomic_torch/metatomic/torch/_wigner.py b/python/metatomic_torch/metatomic/torch/_wigner.py new file mode 100644 index 000000000..f4c842413 --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/_wigner.py @@ -0,0 +1,484 @@ +"""Private Wigner-d/Wigner-D helpers for symmetry operations. + +Adapted from the `spherical` project (MIT license), primarily from +`spherical/recursions/wignerH.py`, `spherical/utilities/indexing.py`, and the +Wigner-D assembly logic in `spherical/wigner.py`. + +This reduced metatomic copy keeps only the recurrence-based pieces needed to build +real Wigner-D matrices from ZYZ Euler angles. It intentionally does not depend on +`spinsfast`, `quaternionic`, or the public `spherical` package. +""" + +import functools +from typing import Dict, Tuple + +import numpy as np +import torch + +from ._jit_compat import jit + + +@jit +def _epsilon(m: int) -> int: + if m <= 0: + return 1 + if m % 2: + return -1 + return 1 + + +@jit +def _nm_index(n: int, m: int) -> int: + return m + n * (n + 1) + + +@jit +def _nabsm_index(n: int, absm: int) -> int: + return absm + (n * (n + 1)) // 2 + + +@jit +def _wigner_h_size(mp_max: int, ell_max: int) -> int: + if ell_max < 0: + return 0 + if mp_max >= ell_max: + return (ell_max + 1) * (ell_max + 2) * (2 * ell_max + 3) // 6 + + return ( + (ell_max + 1) * (ell_max + 2) * (2 * ell_max + 3) + - 2 * (ell_max - mp_max) * (ell_max - mp_max + 1) * (ell_max - mp_max + 2) + ) // 6 + + +@jit +def _wigner_d_size(ell_min: int, mp_max: int, ell_max: int) -> int: + if mp_max >= ell_max: + return ( + ell_max * (ell_max * (4 * ell_max + 12) + 11) + + ell_min * (1 - 4 * ell_min**2) + + 3 + ) // 3 + if mp_max > ell_min: + return ( + 3 * ell_max * (ell_max + 2) + + ell_min * (1 - 4 * ell_min**2) + + mp_max + * (3 * ell_max * (2 * ell_max + 4) + mp_max * (-2 * mp_max - 3) + 5) + + 3 + ) // 3 + + return (ell_max * (ell_max + 2) - ell_min**2) * (1 + 2 * mp_max) + 2 * mp_max + 1 + + +@jit +def _wigner_h_index_base(ell: int, mp: int, m: int, mp_max: int) -> int: + local_mp_max = mp_max + if local_mp_max > ell: + local_mp_max = ell + idx = _wigner_h_size(local_mp_max, ell - 1) + if mp < 1: + idx += (local_mp_max + mp) * (2 * ell - local_mp_max + mp + 1) // 2 + else: + idx += (local_mp_max + 1) * (2 * ell - local_mp_max + 2) // 2 + idx += (mp - 1) * (2 * ell - mp + 2) // 2 + idx += m - abs(mp) + return idx + + +@jit +def _wigner_h_index(ell: int, mp: int, m: int, mp_max: int) -> int: + if ell == 0: + return 0 + + local_mp_max = mp_max + if local_mp_max > ell: + local_mp_max = ell + + if m < -mp: + if m < mp: + return _wigner_h_index_base(ell, -mp, -m, local_mp_max) + return _wigner_h_index_base(ell, -m, -mp, local_mp_max) + + if m < mp: + return _wigner_h_index_base(ell, m, mp, local_mp_max) + return _wigner_h_index_base(ell, mp, m, local_mp_max) + + +@jit +def _wigner_d_index(ell: int, mp: int, m: int, ell_min: int, mp_max: int) -> int: + idx = 0 + for ell_prev in range(ell_min, ell): + local_mp_max = mp_max if mp_max < ell_prev else ell_prev + idx += (2 * local_mp_max + 1) * (2 * ell_prev + 1) + + local_mp_max = mp_max if mp_max < ell else ell + idx += (mp + local_mp_max) * (2 * ell + 1) + idx += m + ell + return idx + + +@jit +def _step_1(hwedge): + hwedge[0] = 1.0 + + +@jit +def _step_2(g, h, n_max, mp_max, hwedge, hextra, hv, expi_beta): + cos_beta = expi_beta.real + sin_beta = expi_beta.imag + sqrt3 = np.sqrt(3.0) + inverse_sqrt2 = 1.0 / np.sqrt(2.0) + if n_max > 0: + n0n_index = _wigner_h_index(1, 0, 1, mp_max) + nn_index = _nm_index(1, 1) + hwedge[n0n_index] = sqrt3 + hwedge[n0n_index - 1] = (g[nn_index - 1] * cos_beta) * inverse_sqrt2 + for n in range(2, n_max + 2): + if n <= n_max: + n0n_index = _wigner_h_index(n, 0, n, mp_max) + out = hwedge + else: + n0n_index = n + out = hextra + prev_index = _wigner_h_index(n - 1, 0, n - 1, mp_max) + nn_index = _nm_index(n, n) + const = np.sqrt(1.0 + 0.5 / n) + g_i = g[nn_index - 1] + out[n0n_index] = const * hwedge[prev_index] + out[n0n_index - 1] = g_i * cos_beta * out[n0n_index] + for i in range(2, n): + g_i = g[nn_index - i] + h_i = h[nn_index - i] + out[n0n_index - i] = ( + g_i * cos_beta * out[n0n_index - i + 1] + - h_i * sin_beta**2 * out[n0n_index - i + 2] + ) + const = 1.0 / np.sqrt(4 * n + 2) + g_i = g[nn_index - n] + h_i = h[nn_index - n] + out[n0n_index - n] = ( + g_i * cos_beta * out[n0n_index - n + 1] + - h_i * sin_beta**2 * out[n0n_index - n + 2] + ) * const + prefactor = const + for i in range(1, n): + prefactor *= sin_beta + out[n0n_index - n + i] *= prefactor + if n <= n_max: + hv[_nm_index(n, 1)] = hwedge[_wigner_h_index(n, 0, 1, mp_max)] + hv[_nm_index(n, 0)] = hwedge[_wigner_h_index(n, 0, 1, mp_max)] + prefactor = 1.0 + for n in range(1, n_max + 1): + prefactor *= sin_beta + hwedge[_wigner_h_index(n, 0, n, mp_max)] *= prefactor / np.sqrt(4 * n + 2) + prefactor *= sin_beta + hextra[n_max + 1] *= prefactor / np.sqrt(4 * (n_max + 1) + 2) + hv[_nm_index(1, 1)] = hwedge[_wigner_h_index(1, 0, 1, mp_max)] + hv[_nm_index(1, 0)] = hwedge[_wigner_h_index(1, 0, 1, mp_max)] + + +@jit +def _step_3(a, b, n_max, mp_max, hwedge, hextra, expi_beta): + cos_beta = expi_beta.real + sin_beta = expi_beta.imag + if n_max > 0 and mp_max > 0: + for n in range(1, n_max + 1): + i1 = _wigner_h_index(n, 1, 1, mp_max) + if n + 1 <= n_max: + i2 = _wigner_h_index(n + 1, 0, 0, mp_max) + h2 = hwedge + else: + i2 = 0 + h2 = hextra + i3 = _nm_index(n + 1, 0) + i4 = _nabsm_index(n, 1) + inverse_b5 = 1.0 / b[i3] + for i in range(n): + b6 = b[-i + i3 - 2] + b7 = b[i + i3] + a8 = a[i + i4] + hwedge[i + i1] = inverse_b5 * ( + 0.5 + * ( + b6 * (1 - cos_beta) * h2[i + i2 + 2] + - b7 * (1 + cos_beta) * h2[i + i2] + ) + - a8 * sin_beta * h2[i + i2 + 1] + ) + + +@jit +def _step_4(d, n_max, mp_max, hwedge, hv): + if n_max > 0 and mp_max > 0: + for n in range(2, n_max + 1): + for mp in range(1, min(n, mp_max)): + i1 = _wigner_h_index(n, mp + 1, mp + 1, mp_max) - 1 + i2 = _wigner_h_index(n, mp - 1, mp, mp_max) + i3 = _wigner_h_index(n, mp, mp, mp_max) - 1 + i4 = _wigner_h_index(n, mp, mp + 1, mp_max) + i5 = _nm_index(n, mp) + i6 = _nm_index(n, mp - 1) + inverse_d5 = 1.0 / d[i5] + d6 = d[i6] + hv[_nm_index(n, mp + 1)] = inverse_d5 * ( + d6 * hwedge[i2] - d[i6] * hv[_nm_index(n, mp)] + d[i5] * hwedge[i4] + ) + for i in range(1, n - mp): + d7 = d[i + i6] + d8 = d[i + i5] + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] - d7 * hwedge[i + i3] + d8 * hwedge[i + i4] + ) + i = n - mp + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] - d[i + i6] * hwedge[i + i3] + ) + + +@jit +def _step_5(d, n_max, mp_max, hwedge, hv): + if n_max > 0 and mp_max > 0: + for n in range(0, n_max + 1): + for mp in range(0, -min(n, mp_max), -1): + i1 = _wigner_h_index(n, mp - 1, -mp + 1, mp_max) - 1 + i2 = _wigner_h_index(n, mp + 1, -mp + 1, mp_max) - 1 + i3 = _wigner_h_index(n, mp, -mp, mp_max) - 1 + i4 = _wigner_h_index(n, mp, -mp + 1, mp_max) + i5 = _nm_index(n, mp - 1) + i6 = _nm_index(n, mp) + i7 = _nm_index(n, -mp - 1) + i8 = _nm_index(n, -mp) + inverse_d5 = 1.0 / d[i5] + d6 = d[i6] + d7 = d[i7] + d8 = d[i8] + if mp == 0: + hv[_nm_index(n, mp - 1)] = inverse_d5 * ( + d6 * hv[_nm_index(n, mp + 1)] + + d7 * hv[_nm_index(n, mp)] + - d8 * hwedge[i4] + ) + else: + hv[_nm_index(n, mp - 1)] = inverse_d5 * ( + d6 * hwedge[i2] + d7 * hv[_nm_index(n, mp)] - d8 * hwedge[i4] + ) + for i in range(1, n + mp): + d7 = d[i + i7] + d8 = d[i + i8] + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] + d7 * hwedge[i + i3] - d8 * hwedge[i + i4] + ) + i = n + mp + hwedge[i + i1] = inverse_d5 * ( + d6 * hwedge[i + i2] + d[i + i7] * hwedge[i + i3] + ) + + +def _create_wigner_coefficients(ell_max: int): + n = np.array([n for n in range(ell_max + 2) for _ in range(-n, n + 1)]) + m = np.array([m for n in range(ell_max + 2) for m in range(-n, n + 1)]) + absn = np.array([n for n in range(ell_max + 2) for _ in range(n + 1)]) + absm = np.array([m for n in range(ell_max + 2) for m in range(n + 1)]) + + a = np.sqrt( + (absn + 1 + absm) * (absn + 1 - absm) / ((2 * absn + 1) * (2 * absn + 3)) + ) + b = np.sqrt((n - m - 1) * (n - m) / ((2 * n - 1) * (2 * n + 1))) + b[m < 0] *= -1 + d = 0.5 * np.sqrt((n - m) * (n + m + 1)) + d[m < 0] *= -1 + with np.errstate(divide="ignore", invalid="ignore"): + g = 2 * (m + 1) / np.sqrt((n - m) * (n + m + 1)) + h = np.sqrt((n + m + 2) * (n - m - 1) / ((n - m) * (n + m + 1))) + return a, b, d, g, h + + +def _complex_powers(z: complex, ell_max: int) -> np.ndarray: + powers = np.empty(ell_max + 1, dtype=np.complex128) + powers[0] = 1.0 + 0.0j + for idx in range(1, ell_max + 1): + powers[idx] = powers[idx - 1] * z + return powers + + +def _to_euler_phases( + alpha: float, beta: float, gamma: float +) -> Tuple[complex, complex, complex]: + # Match spherical.Wigner's convention after converting scipy's ZYZ Euler angles + # into the phases used by the recurrence. + z_alpha = np.exp(-1j * alpha) + expi_beta = np.exp(1j * beta) + z_gamma = np.exp(-1j * gamma) + return z_alpha, expi_beta, z_gamma + + +def _compute_wigner_d_complex( + ell_max: int, alpha: np.ndarray, beta: np.ndarray, gamma: np.ndarray +) -> np.ndarray: + if not (alpha.shape == beta.shape == gamma.shape): + raise ValueError("alpha, beta, and gamma must have identical shapes") + + mp_max = ell_max + a, b, d, g, h = _create_wigner_coefficients(ell_max) + hsize = _wigner_h_size(mp_max, ell_max) + dsize = _wigner_d_size(0, mp_max, ell_max) + result = np.zeros(alpha.shape + (dsize,), dtype=np.complex128) + + for index in np.ndindex(alpha.shape): + z_alpha, expi_beta, z_gamma = _to_euler_phases( + float(alpha[index]), float(beta[index]), float(gamma[index]) + ) + hwedge = np.zeros(hsize, dtype=np.float64) + hv = np.zeros((ell_max + 1) ** 2, dtype=np.float64) + hextra = np.zeros(ell_max + 2, dtype=np.float64) + + _step_1(hwedge) + _step_2(g, h, ell_max, mp_max, hwedge, hextra, hv, expi_beta) + _step_3(a, b, ell_max, mp_max, hwedge, hextra, expi_beta) + _step_4(d, ell_max, mp_max, hwedge, hv) + _step_5(d, ell_max, mp_max, hwedge, hv) + + z_alpha_powers = _complex_powers(z_alpha, ell_max) + z_gamma_powers = _complex_powers(z_gamma, ell_max) + out = result[index] + for ell in range(0, ell_max + 1): + for mp in range(-ell, 0): + i_d = _wigner_d_index(ell, mp, -ell, 0, mp_max) + for m in range(-ell, 0): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[-m].conjugate() + * z_alpha_powers[-mp].conjugate() + ) + i_d += 1 + for m in range(0, ell + 1): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[m] + * z_alpha_powers[-mp].conjugate() + ) + i_d += 1 + for mp in range(0, ell + 1): + i_d = _wigner_d_index(ell, mp, -ell, 0, mp_max) + for m in range(-ell, 0): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[-m].conjugate() + * z_alpha_powers[mp] + ) + i_d += 1 + for m in range(0, ell + 1): + i_h = _wigner_h_index(ell, mp, m, mp_max) + out[i_d] = ( + _epsilon(mp) + * _epsilon(-m) + * hwedge[i_h] + * z_gamma_powers[m] + * z_alpha_powers[mp] + ) + i_d += 1 + + return result + + +def compute_complex_wigner_d_matrices( + ell_max: int, + angles: Tuple[np.ndarray, np.ndarray, np.ndarray], +) -> Dict[int, np.ndarray]: + alpha, beta, gamma = angles + raw = _compute_wigner_d_complex(ell_max, alpha, beta, gamma) + matrices: Dict[int, np.ndarray] = {} + for ell in range(ell_max + 1): + shape = alpha.shape + (2 * ell + 1, 2 * ell + 1) + block = np.zeros(shape, dtype=np.complex128) + for mp in range(-ell, ell + 1): + for m in range(-ell, ell + 1): + block[..., mp + ell, m + ell] = raw[ + ..., _wigner_d_index(ell, mp, m, 0, ell_max) + ] + matrices[ell] = block + return matrices + + +def compute_real_wigner_d_matrices( + ell_max: int, + angles: Tuple[np.ndarray, np.ndarray, np.ndarray], + complex_to_real: Dict[int, np.ndarray], +) -> Dict[int, torch.Tensor]: + complex_matrices = compute_complex_wigner_d_matrices(ell_max, angles) + real_matrices: Dict[int, torch.Tensor] = {} + for ell, matrix in complex_matrices.items(): + transform = complex_to_real[ell] + matrix = np.einsum("ij,...jk,kl->...il", transform.conj(), matrix, transform.T) + # Recursion accumulates floating-point noise that grows with ell, so scale + # the tolerance by the magnitude of the matrix instead of using a fixed atol. + scale = float(np.max(np.abs(matrix.real))) if matrix.size else 1.0 + atol = max(1e-9, scale * 1e-10) + if not np.allclose(matrix.imag, 0.0, atol=atol): + raise ValueError("real Wigner matrix conversion produced complex values") + real_matrices[ell] = torch.from_numpy(matrix.real) + return real_matrices + + +@functools.lru_cache(maxsize=None) +def _complex_to_real_spherical_harmonics_transform(ell: int) -> np.ndarray: + """ + Generate the transformation matrix from complex spherical harmonics + to real spherical harmonics for a given l. + Returns a transformation matrix of shape ``(2l+1, 2l+1)``. + """ + if ell < 0 or not isinstance(ell, int): + raise ValueError("l must be a non-negative integer.") + + size = 2 * ell + 1 + T = np.zeros((size, size), dtype=complex) + + for m in range(-ell, ell + 1): + m_index = m + ell + if m > 0: + T[m_index, ell + m] = 1 / np.sqrt(2) * (-1) ** m + T[m_index, ell - m] = 1 / np.sqrt(2) + elif m < 0: + T[m_index, ell + abs(m)] = -1j / np.sqrt(2) * (-1) ** m + T[m_index, ell - abs(m)] = 1j / np.sqrt(2) + else: + T[m_index, ell] = 1 + + return T + + +def compute_real_wigner_matrices( + o3_lambda_max: int, + angles: Tuple[np.ndarray, np.ndarray, np.ndarray], +) -> Dict[int, torch.Tensor]: + """Build the real Wigner-D matrices for ``ell = 0..o3_lambda_max`` at the given + ZYZ Euler angles, using the cached complex-to-real transform per ell.""" + complex_to_real = { + ell: _complex_to_real_spherical_harmonics_transform(ell) + for ell in range(o3_lambda_max + 1) + } + return compute_real_wigner_d_matrices(o3_lambda_max, angles, complex_to_real) + + +def compute_wigner_batch( + ell_max: int, + angles: Tuple[np.ndarray, np.ndarray, np.ndarray], + *, + device: torch.device, + dtype: torch.dtype, +) -> Dict[int, torch.Tensor]: + """Real Wigner-D matrices for ``ell = 0..ell_max`` at the given angles, cast to + the requested device and dtype.""" + return { + ell: tensor.to(device=device, dtype=dtype) + for ell, tensor in compute_real_wigner_matrices(ell_max, angles).items() + } diff --git a/python/metatomic_torch/metatomic/torch/symmetrized_model.py b/python/metatomic_torch/metatomic/torch/symmetrized_model.py new file mode 100644 index 000000000..711a7649d --- /dev/null +++ b/python/metatomic_torch/metatomic/torch/symmetrized_model.py @@ -0,0 +1,1470 @@ +import logging +import time +import warnings +from typing import Dict, List, Optional, Tuple + +import metatensor.torch as mts +import numpy as np +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + +from metatomic.torch import ( + ModelInterface, + ModelOutput, + System, + register_autograd_neighbors, +) +from metatomic.torch._augmentation import _apply_augmentations +from metatomic.torch._wigner import compute_wigner_batch as _compute_wigner_batch + + +LOGGER = logging.getLogger(__name__) + + +def _import_scipy(): + # deferred: scipy is only needed for quadrature/rotation construction at + # __init__ time; loading symmetrized_model.py for other helpers should not + # require it. + try: + from scipy.integrate import lebedev_rule + from scipy.spatial.transform import Rotation + except ImportError as e: + raise ImportError( + "scipy is required for SymmetrizedModel quadrature construction; " + "install it with `pip install scipy`." + ) from e + return lebedev_rule, Rotation + + +def _choose_quadrature(L_max: int) -> Tuple[int, int]: + """ + Choose a Lebedev quadrature order and number of in-plane rotations to integrate + spherical harmonics up to degree ``L_max``. + + :param L_max: maximum spherical harmonic degree + :return: (lebedev_order, n_inplane_rotations) + """ + available = [ + 3, + 5, + 7, + 9, + 11, + 13, + 15, + 17, + 19, + 21, + 23, + 25, + 27, + 29, + 31, + 35, + 41, + 47, + 53, + 59, + 65, + 71, + 77, + 83, + 89, + 95, + 101, + 107, + 113, + 119, + 125, + 131, + ] + # pick smallest order >= L_max + n = min(o for o in available if o >= L_max) + # minimal gamma count + K = L_max + 1 + return n, K + + +def get_euler_angles_quadrature(lebedev_order: int, n_rotations: int): + """ + Get the Euler angles and weights for a Lebedev quadrature combined with in-plane + rotations for SO(3) integration. + + :param lebedev_order: order of the Lebedev quadrature on the unit sphere + :param n_rotations: number of in-plane rotations per Lebedev node + :return: alpha, beta, gamma, w arrays of shape (M,), (M,), (K,), (M,) + respectively, where M is the number of Lebedev nodes and K is the number of + in-plane rotations. + """ + + lebedev_rule, _ = _import_scipy() + # Lebedev nodes (X: (3, M)) + X, w = lebedev_rule(lebedev_order) # w sums to 4*pi + x, y, z = X + alpha = np.arctan2(y, x) # (M,) + beta = np.arccos(np.clip(z, -1.0, 1.0)) # (M,) + gamma = np.linspace(0.0, 2 * np.pi, n_rotations, endpoint=False) # (K,) + + w_so3 = np.repeat(w / (4 * np.pi * n_rotations), repeats=gamma.size) # (M*K,) + + A = np.repeat(alpha, gamma.size) # (N,) + B = np.repeat(beta, gamma.size) # (N,) + G = np.tile(gamma, alpha.size) # (N,) + + return A, B, G, w_so3 + + +def _rotations_from_angles( + alpha: np.ndarray, beta: np.ndarray, gamma: np.ndarray +) -> "Rotation": # noqa: F821 — scipy is imported lazily + """ + Compose rotations from ZYZ Euler angles. + + :param alpha: array of alpha angles (M,) + :param beta: array of beta angles (M,) + :param gamma: array of gamma angles (K,) + :return: Rotation object containing all (M*K,) rotations + """ + + _, Rotation = _import_scipy() + # Compose ZYZ rotations in SO(3) + Rot = ( + Rotation.from_euler("z", alpha.reshape(-1, 1)) + * Rotation.from_euler("y", beta.reshape(-1, 1)) + * Rotation.from_euler("z", gamma.reshape(-1, 1)) + ) + + return Rot + + +def _transform_system(system: System, transformation: torch.Tensor) -> System: + transformed_system = System( + positions=system.positions @ transformation.T, + types=system.types, + cell=system.cell @ transformation.T, + pbc=system.pbc, + ) + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + + neighbors.values[:] = ( + neighbors.values.squeeze(-1) @ transformation.T + ).unsqueeze(-1) + + register_autograd_neighbors(transformed_system, neighbors) + transformed_system.add_neighbor_list(options, neighbors) + return transformed_system + + +def _l0_components_from_matrices(A: torch.Tensor) -> torch.Tensor: + """ + Extract the L=0 component (trace) from rank-2 Cartesian blocks. + + Expects ``A`` with shape ``(a, 3, 3, b)``; returns ``(a, 1, b)``. + """ + # move (3, 3) axes to the end for the assert and indexing below + A = A.permute(0, 3, 1, 2) + assert A.shape[-2:] == (3, 3), "The last two dimensions of A must be (3, 3)." + + # Trace as L=0 component; unsqueeze preserves the autograd graph + l0_A = (A[..., 0, 0] + A[..., 1, 1] + A[..., 2, 2]).unsqueeze(-1) + + l0_A = l0_A.permute(0, 2, 1) + return l0_A + + +def _l2_components_from_matrices(A: torch.Tensor) -> torch.Tensor: + """ + Extract the L=2 components (symmetric traceless part) from rank-2 Cartesian blocks. + + Expects ``A`` with shape ``(a, 3, 3, b)``; returns ``(a, 5, b)``. + """ + # move (3, 3) axes to the end for the assert and indexing below + A = A.permute(0, 3, 1, 2) + assert A.shape[-2:] == (3, 3), "The last two dimensions of A must be (3, 3)." + + # Use torch.stack to preserve the autograd graph + l2_A = torch.stack( + [ + (A[..., 0, 1] + A[..., 1, 0]) / 2.0, + (A[..., 1, 2] + A[..., 2, 1]) / 2.0, + (2.0 * A[..., 2, 2] - A[..., 0, 0] - A[..., 1, 1]) / (2.0 * np.sqrt(3.0)), + (A[..., 0, 2] + A[..., 2, 0]) / 2.0, + (A[..., 0, 0] - A[..., 1, 1]) / 2.0, + ], + dim=-1, + ) + + l2_A = l2_A.permute(0, 2, 1) + + return l2_A + + +def decompose_energy_tensor( + tensor_dict: Dict[str, TensorMap], +) -> Dict[str, TensorMap]: + """ + Decompose energy tensor into its L=0 irreducible representation. + + Energy is a scalar, so it lives entirely in the L=0 sector. This function + adds an ``o3_mu`` component axis with a single m=0 entry to make the format + consistent with higher-order decompositions. + + :param tensor_dict: dictionary of TensorMaps (modified in place) + :return: the same dictionary with ``"energy"`` replaced by ``"energy_l0"`` + """ + for key in ["energy", "energy_total"]: + if key not in tensor_dict: + continue + + tensor = tensor_dict[key] + tensor_dict[key + "_l0"] = TensorMap( + tensor.keys, + [ + TensorBlock( + values=block.values.unsqueeze(1), + samples=block.samples, + components=[ + Labels( + names=["o3_mu"], + values=torch.tensor( + [[0]], device=block.values.device, dtype=torch.int32 + ), + ) + ], + properties=block.properties, + ) + for block in tensor + ], + ) + tensor_dict.pop(key) + + return tensor_dict + + +def decompose_forces_tensor( + tensor_dict: Dict[str, TensorMap], +) -> Dict[str, TensorMap]: + """ + Decompose forces tensors into L=1 irreducible representations. + + Forces are Cartesian vectors (x, y, z). This reorders them to spherical + component order (y, z, x) → (m=-1, m=0, m=1) via a cyclic roll, and + labels the component axis as ``o3_mu``. + + Handles both ``"forces"`` (conservative) and ``"non_conservative_forces"`` keys. + + :param tensor_dict: dictionary of TensorMaps (modified in place) + :return: the same dictionary with forces keys replaced by ``"..._l1"`` variants + """ + for key in ["forces", "non_conservative_forces"]: + if key not in tensor_dict: + continue + + tensor = tensor_dict[key] + tensor_dict[key + "_l1"] = TensorMap( + tensor.keys, + [ + TensorBlock( + values=block.values.roll(-1, 1), + samples=block.samples, + components=[ + Labels( + names="o3_mu", + values=torch.tensor( + [[mu] for mu in range(-1, 2)], + device=block.values.device, + dtype=torch.int32, + ), + ) + ], + properties=block.properties, + ) + for block in tensor + ], + ) + tensor_dict.pop(key) + return tensor_dict + + +def decompose_stress_tensor( + tensor_dict: Dict[str, TensorMap], +) -> Dict[str, TensorMap]: + """ + Decompose stress tensors into L=0 (trace) and L=2 (symmetric traceless) parts. + + The 3x3 stress tensor decomposes as: trace (L=0 scalar) + symmetric traceless + (L=2, 5 components). The antisymmetric part (L=1) is zero for physical stress. + + Handles both ``"stress"`` (conservative) and ``"non_conservative_stress"`` keys. + + :param tensor_dict: dictionary of TensorMaps (modified in place) + :return: the same dictionary with stress keys replaced by ``"..._l0"`` and + ``"..._l2"`` variants + """ + for key in ["stress", "non_conservative_stress"]: + if key not in tensor_dict: + continue + + tensor = tensor_dict[key] + blocks_l0 = [] + blocks_l2 = [] + for block in tensor.blocks(): + trace_values = _l0_components_from_matrices(block.values) + block_l0 = TensorBlock( + values=trace_values, + samples=block.samples, + components=[ + Labels( + names=["o3_mu"], + values=torch.tensor( + [[0]], device=block.values.device, dtype=torch.int32 + ), + ) + ], + properties=block.properties, + ) + blocks_l0.append(block_l0) + + block_l2 = TensorBlock( + values=_l2_components_from_matrices(block.values), + samples=block.samples, + components=[ + Labels( + names="o3_mu", + values=torch.tensor( + [[mu] for mu in range(-2, 3)], + device=block.values.device, + dtype=torch.int32, + ), + ) + ], + properties=block.properties, + ) + blocks_l2.append(block_l2) + + tensor_dict[key + "_l0"] = TensorMap(tensor.keys, blocks_l0) + tensor_dict[key + "_l2"] = TensorMap(tensor.keys, blocks_l2) + tensor_dict.pop(key) + + return tensor_dict + + +def decompose_tensors( + tensor_dict: Dict[str, TensorMap], +) -> Dict[str, TensorMap]: + """ + Decompose all tensors in the dictionary into irreducible representations of O(3). + + :param tensor_dict: dictionary of TensorMaps to decompose + :return: dictionary of TensorMaps with decomposed tensors + """ + tensor_dict = decompose_energy_tensor(tensor_dict) + tensor_dict = decompose_forces_tensor(tensor_dict) + tensor_dict = decompose_stress_tensor(tensor_dict) + return tensor_dict + + +def _maybe_add_energy_total(tensor_dict: Dict[str, TensorMap]) -> Dict[str, TensorMap]: + tensor_dict = dict(tensor_dict) + if ( + "energy" in tensor_dict + and "atom" in tensor_dict["energy"].block().samples.names + ): + tensor_dict["energy_total"] = mts.sum_over_samples( + tensor_dict["energy"], ["atom"] + ) + return tensor_dict + + +def _normalize_output_tensors( + name: str, + tensor: TensorMap, +) -> Dict[str, TensorMap]: + return decompose_tensors(_maybe_add_energy_total({name: tensor})) + + +def _tensor_map_dtype(tensor: TensorMap) -> torch.dtype: + return tensor.block().values.dtype + + +def _key_to_tuple(key_entry) -> Tuple[int, ...]: + return tuple(int(v) for v in key_entry.values.tolist()) + + +def _prepend_system_to_samples( + sample_names: List[str], + sample_values: torch.Tensor, + system_index: int, + *, + device: torch.device, +) -> Labels: + system_values = torch.full( + (sample_values.shape[0], 1), + system_index, + dtype=torch.int32, + device=device, + ) + if len(sample_names) == 0: + return Labels(["system"], system_values) + + return Labels( + ["system"] + sample_names, + torch.cat( + [system_values, sample_values.to(device=device, dtype=torch.int32)], dim=1 + ), + ) + + +def _selected_atoms_for_local_systems( + selected_atoms: Optional[Labels], + system_index: int, + n_local_systems: int, +) -> Optional[Labels]: + if selected_atoms is None: + return None + + system_mask = selected_atoms.column("system").to(dtype=torch.long) == system_index + system_selected_atoms = selected_atoms.values[system_mask] + if system_selected_atoms.shape[0] == 0: + return Labels( + list(selected_atoms.names), + selected_atoms.values.new_empty((0, len(selected_atoms.names))), + ) + + local_selected_atoms: List[torch.Tensor] = [] + for local_system_index in range(n_local_systems): + local_values = system_selected_atoms.clone() + local_values[:, 0] = local_system_index + local_selected_atoms.append(local_values) + + return Labels( + list(selected_atoms.names), + torch.cat(local_selected_atoms, dim=0), + ) + + +def _reshape_block_by_local_system( + block: TensorBlock, n_local_systems: int +) -> Tuple[torch.Tensor, List[str], torch.Tensor]: + local_ids = block.samples.column("system").to(dtype=torch.long) + if len(local_ids) != 0: + min_local_id = int(torch.min(local_ids).item()) + max_local_id = int(torch.max(local_ids).item()) + if min_local_id < 0 or max_local_id >= n_local_systems: + raise ValueError( + "Encountered output samples with out-of-range system indices." + ) + + split_values: List[torch.Tensor] = [] + base_sample_values: Optional[torch.Tensor] = None + for local_system_index in range(n_local_systems): + local_mask = local_ids == local_system_index + local_values = block.values[local_mask] + local_sample_values = block.samples.values[local_mask][:, 1:] + if base_sample_values is None: + base_sample_values = local_sample_values + elif not torch.equal(local_sample_values, base_sample_values): + raise ValueError( + "Streaming SymmetrizedModel expects each rotated copy of a system to " + "produce the same sample labels in the same order." + ) + split_values.append(local_values) + + assert base_sample_values is not None + stacked_values = torch.stack(split_values, dim=0) + return stacked_values, list(block.samples.names[1:]), base_sample_values + + +def _reduce_weighted_batch_tensor( + tensor: TensorMap, + weights: torch.Tensor, + system_index: int, + *, + component_norm: bool = False, +) -> TensorMap: + # Weights are halved because the outer pipeline averages over the two-element + # inversion subgroup by summing the +1 and -1 contributions. + n_local_systems = weights.numel() + reduced_blocks: List[TensorBlock] = [] + for block in tensor.blocks(): + values, sample_names, sample_values = _reshape_block_by_local_system( + block, n_local_systems + ) + + components = block.components + if component_norm: + component_dims = tuple(range(2, 2 + len(block.components))) + if len(component_dims) == 0: + values = values**2 + else: + values = torch.sum(values**2, dim=component_dims) + components = [] + + weight = weights.to(dtype=values.dtype, device=values.device) + view = [values.shape[0]] + [1] * (values.ndim - 1) + reduced_values = torch.sum(0.5 * weight.view(view) * values, dim=0) + if reduced_values.ndim == 1: + reduced_values = reduced_values.unsqueeze(0) + + reduced_blocks.append( + TensorBlock( + values=reduced_values, + samples=_prepend_system_to_samples( + sample_names, + sample_values, + system_index, + device=block.samples.values.device, + ), + components=components, + properties=block.properties, + ) + ) + + return TensorMap(tensor.keys, reduced_blocks) + + +def _accumulate_tensormap( + accumulators: Dict[str, TensorMap], name: str, contribution: TensorMap +) -> None: + if name in accumulators: + accumulators[name] = mts.add(accumulators[name], contribution) + else: + accumulators[name] = contribution + + +def _append_tensormap( + accumulators: Dict[str, List[TensorMap]], name: str, contribution: TensorMap +) -> None: + accumulators.setdefault(name, []).append(contribution) + + +def _join_tensormap_list(tensors: List[TensorMap]) -> TensorMap: + if len(tensors) == 1: + return tensors[0] + return mts.join(tensors, "samples", different_keys="union") + + +def _mean_norm_squared_tensor(tensor: TensorMap) -> TensorMap: + blocks: List[TensorBlock] = [] + for block in tensor.blocks(): + if block.values.ndim > 2: + values = torch.sum( + block.values**2, dim=tuple(range(1, block.values.ndim - 1)) + ) + else: + values = block.values**2 + if values.ndim == 1: + values = values.unsqueeze(0) + blocks.append( + TensorBlock( + values=values, + samples=block.samples, + components=[], + properties=block.properties, + ) + ) + return TensorMap(tensor.keys, blocks) + + +def _finalize_variance( + second_moment: TensorMap, + mean: TensorMap, +) -> TensorMap: + mean_norm_sq = _mean_norm_squared_tensor(mean) + return mts.subtract(second_moment, mean_norm_sq) + + +def _compute_batch_projection_contributions( + tensor: TensorMap, + weights: torch.Tensor, + wigner_matrices: Dict[int, torch.Tensor], + max_o3_lambda_character: int, + *, + storage_device: Optional[torch.device] = None, +) -> Dict[Tuple[int, ...], Dict[str, object]]: + n_local_systems = weights.numel() + block_contributions: Dict[Tuple[int, ...], Dict[str, object]] = {} + for key, block in tensor.items(): + key_tuple = _key_to_tuple(key) + values, sample_names, sample_values = _reshape_block_by_local_system( + block, n_local_systems + ) + weight = weights.to(dtype=values.dtype, device=values.device) + weighted_values = ( + weight.view([weight.shape[0]] + [1] * (values.ndim - 1)) * values + ) + + coefficients: Dict[int, torch.Tensor] = {} + for ell in range(max_o3_lambda_character + 1): + D = wigner_matrices[ell].to(dtype=values.dtype, device=values.device) + coefficient = torch.einsum("imn,i...->mn...", D, weighted_values) + if storage_device is not None and coefficient.device != storage_device: + coefficient = coefficient.to(device=storage_device) + coefficients[ell] = coefficient + + key_values = key.values.clone() + sample_values_out = sample_values.clone() + components = list(block.components) + properties = block.properties + if storage_device is not None: + key_values = key_values.to(device=storage_device) + sample_values_out = sample_values_out.to(device=storage_device) + components = [component.to(device=storage_device) for component in block.components] + properties = block.properties.to(device=storage_device) + + block_contributions[key_tuple] = { + "key_names": list(tensor.keys.names), + "key_values": key_values, + "sample_names": sample_names, + "sample_values": sample_values_out, + "components": components, + "properties": properties, + "coefficients": coefficients, + } + + return block_contributions + + +def _merge_projection_contributions( + accumulator: Dict[Tuple[int, ...], Dict[str, object]], + contribution: Dict[Tuple[int, ...], Dict[str, object]], +) -> None: + for key_tuple, entry in contribution.items(): + if key_tuple not in accumulator: + accumulator[key_tuple] = entry + continue + existing = accumulator[key_tuple] + existing_coefficients = existing["coefficients"] + contribution_coefficients = entry["coefficients"] + assert isinstance(existing_coefficients, dict) + assert isinstance(contribution_coefficients, dict) + for ell, tensor in contribution_coefficients.items(): + if ell in existing_coefficients: + existing_coefficients[ell] = existing_coefficients[ell] + tensor + else: + existing_coefficients[ell] = tensor + + +def _finalize_projection_tensor( + positive: Dict[Tuple[int, ...], Dict[str, object]], + negative: Dict[Tuple[int, ...], Dict[str, object]], + system_index: int, + max_o3_lambda_character: int, +) -> Optional[TensorMap]: + all_keys = list(positive.keys()) + for key in negative.keys(): + if key not in positive: + all_keys.append(key) + + if len(all_keys) == 0: + return None + + blocks: List[TensorBlock] = [] + key_values: List[torch.Tensor] = [] + key_names: Optional[List[str]] = None + for key_tuple in all_keys: + plus_entry = positive.get(key_tuple) + minus_entry = negative.get(key_tuple) + meta = plus_entry if plus_entry is not None else minus_entry + assert meta is not None + + key_names = list(meta["key_names"]) + key_tensor = meta["key_values"] + sample_names = meta["sample_names"] + sample_values = meta["sample_values"] + components = meta["components"] + properties = meta["properties"] + plus_coeffs = plus_entry["coefficients"] if plus_entry is not None else {} + minus_coeffs = minus_entry["coefficients"] if minus_entry is not None else {} + + for ell in range(max_o3_lambda_character + 1): + plus_tensor = plus_coeffs.get(ell) + minus_tensor = minus_coeffs.get(ell) + if plus_tensor is None and minus_tensor is None: + continue + if plus_tensor is None: + plus_tensor = torch.zeros_like(minus_tensor) + if minus_tensor is None: + minus_tensor = torch.zeros_like(plus_tensor) + + parity = (-1) ** ell + for sigma in [1, -1]: + combined = plus_tensor + sigma * parity * minus_tensor + values = ( + 0.25 * (2 * ell + 1) * torch.sum(combined * combined, dim=(0, 1)) + ) + blocks.append( + TensorBlock( + values=values, + samples=_prepend_system_to_samples( + sample_names, + sample_values, + system_index, + device=values.device, + ), + components=components, + properties=properties, + ) + ) + key_values.append( + torch.cat( + [ + key_tensor, + torch.tensor( + [ell, sigma], + dtype=key_tensor.dtype, + device=key_tensor.device, + ), + ] + ) + ) + + assert key_names is not None + tensor = TensorMap( + Labels(key_names + ["chi_lambda", "chi_sigma"], torch.stack(key_values)), + blocks, + ) + if "_" in tensor.keys.names: + tensor = mts.remove_dimension(tensor, "keys", "_") + return tensor + + +def _slice_angles( + angles: Tuple[np.ndarray, np.ndarray, np.ndarray], + start: int, + stop: int, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + return tuple(angle[start:stop] for angle in angles) + + +class SymmetrizedModel(torch.nn.Module): + r""" + Wrapper around an atomistic model that symmetrizes its outputs over :math:`O(3)` + and computes equivariance metrics. + + The model is evaluated over a quadrature grid on :math:`O(3)`, constructed from a + Lebedev grid supplemented by in-plane rotations. For each sampled group element, the + model outputs are "back-rotated" according to the known :math:`O(3)` action + appropriate for their tensorial type (scalar, vector, tensor, etc.). Averaging these + back-rotated predictions over the quadrature grid yields fully + :math:`O(3)`-symmetrized outputs. In addition, two complementary equivariance + metrics are computed: + + 1. Variance under :math:`O(3)` of the back-rotated outputs. + + For a perfectly equivariant model, the back-rotated output :math:`x(g)` is + independent of the group element :math:`g`. Deviations from perfect equivariance + are quantified by the difference between the average squared norm over + :math:`O(3)` and the squared norm of the :math:`O(3)`-averaged output: + + .. math:: + + \mathrm{Var}_{O(3)}[x] + = + \left\langle \,\| x(g) \|^{2} \,\right\rangle_{O(3)} + - + \left\| \left\langle x(g) \right\rangle_{O(3)} \right\|^{2} . + + Here, :math:`\|\cdot\|` denotes the Euclidean norm over the ``component`` axis, + and :math:`\langle \cdot \rangle_{O(3)}` denotes averaging over the quadrature + grid. This quantity is the squared norm of the component orthogonal to the + perfectly equivariant subspace and therefore provides a scalar measure of the + deviation from exact equivariance. + + 2. Decomposition into isotypical components of :math:`O(3)`. + + Each output component may be viewed as a scalar function on :math:`O(3)`, + which can be decomposed into isotypical components labeled by the irreducible + representations :math:`\ell,\sigma` of :math:`O(3)`. The projection onto the + :math:`(\ell,\sigma)`-th isotypical subspace is computed as a convolution with + the corresponding character :math:`\chi_{\ell}`: + + .. math:: + + (P_{\ell,\sigma} x)(g) + = + \int_{O(3)} \chi_{\ell,\sigma}(h^{-1} g)\, x(h)\, \mathrm{d}\mu(h). + + Its squared :math:`L^{2}` norm over :math:`O(3)` is + + .. math:: + + \| P_{\ell,\sigma} x \|^{2} + = + \left\langle \, | (P_{\ell,\sigma} x)(g) |^{2} \, \right\rangle_{O(3)} . + + These quantities describe how the model output is distributed across the + different :math:`O(3)` irreducible sectors. The complementary component, + orthogonal to all isotypical subspaces, is given by + + .. math:: + + \| x \|^{2} + - + \sum_{\ell,\sigma} \| P_{\ell,\sigma} x \|^{2} , + + and provides a refined measure of the deviation from lying entirely within any + prescribed set of :math:`O(3)` irreducible representations. + + :param base_model: atomistic model to symmetrize + :param max_o3_lambda: maximum O(3) angular momentum the grid integrates exactly + :param batch_size: number of rotations to evaluate in a single batch + :param max_o3_lambda_character: maximum O(3) angular momentum for character + projections. If None, set to ``max_o3_lambda``. + :param offload_to_cpu: controls whether base-model outputs are moved to CPU + immediately after each forward pass. If ``True``, outputs are moved to CPU + before back-rotation and accumulation, trading bandwidth for GPU memory. + If ``None`` (default), the policy is chosen automatically per forward call: + offload when ``compute_gradients=False`` (no need to keep the graph alive, + memory savings dominate), do not offload when ``compute_gradients=True`` + (the gradient graph must stay on the model device). + :param storage_device: device on which long-lived accumulators are kept and on + which the returned TensorMaps are placed. If ``None``, follow the device + used during back-rotation (CPU when offloaded, model device otherwise). + When set explicitly, this takes precedence over ``offload_to_cpu`` for the + final output placement. + """ + + def __init__( + self, + base_model, + max_o3_lambda_character: int, + max_o3_lambda_target: int, + batch_size: int = 32, + max_o3_lambda_grid: Optional[int] = None, + offload_to_cpu: Optional[bool] = None, + storage_device: Optional[str] = None, + ): + super().__init__() + self.base_model = base_model + + try: + ref_param = next(base_model.parameters()) + device = ref_param.device + dtype = ref_param.dtype + except StopIteration: + device = torch.device("cpu") + dtype = torch.get_default_dtype() + + self.max_o3_lambda_target = max_o3_lambda_target + self.batch_size = batch_size + self.offload_to_cpu = offload_to_cpu + self.storage_device = ( + None if storage_device is None else torch.device(storage_device) + ) + if max_o3_lambda_grid is None: + max_o3_lambda_grid = int(2 * max_o3_lambda_character + 1) + self.max_o3_lambda_grid = max_o3_lambda_grid + self.max_o3_lambda_character = max_o3_lambda_character + + # Compute grid (unchanged) + lebedev_order, n_inplane_rotations = _choose_quadrature(self.max_o3_lambda_grid) + if lebedev_order < 2 * self.max_o3_lambda_character: + warnings.warn( + "Lebedev order may be insufficient for character projections.", + stacklevel=2, + ) + alpha, beta, gamma, w_so3 = get_euler_angles_quadrature( + lebedev_order, n_inplane_rotations + ) + so3_weights = torch.from_numpy(w_so3).to(device=device, dtype=dtype) + self.register_buffer("so3_weights", so3_weights) + + so3_rotations = torch.from_numpy( + _rotations_from_angles(alpha, beta, gamma).as_matrix() + ).to(device=device, dtype=dtype) + self.register_buffer("so3_rotations", so3_rotations) + + angles_inverse_rotations = (np.pi - gamma, beta, np.pi - alpha) + so3_inverse_rotations = torch.from_numpy( + _rotations_from_angles(*angles_inverse_rotations).as_matrix() + ).to(device=device, dtype=dtype) + self.register_buffer("so3_inverse_rotations", so3_inverse_rotations) + self._inverse_quadrature_angles = angles_inverse_rotations + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + project_tokens: bool = False, + compute_gradients: bool = False, + ) -> Dict[str, TensorMap]: + """ + Symmetrize the model outputs over :math:`O(3)` and compute equivariance + metrics. + + :param systems: list of systems to evaluate + :param outputs: dictionary of model outputs to symmetrize + :param selected_atoms: optional Labels specifying which atoms to consider + :param project_tokens: if True, also compute character projections + :param compute_gradients: if True, compute conservative forces and stress + via autograd. When False (default), the grid evaluation runs under + ``torch.no_grad()`` to save memory. + :return: dictionary with symmetrized outputs and equivariance metrics + """ + device = self.so3_weights.device + if self.offload_to_cpu is None: + offload = not compute_gradients + else: + offload = self.offload_to_cpu + if self.storage_device is not None: + result_device = self.storage_device + else: + result_device = torch.device("cpu") if offload else device + + with torch.enable_grad() if compute_gradients else torch.no_grad(): + transformed_outputs, backtransformed_outputs = self._eval_over_grid( + systems, + outputs, + selected_atoms, + return_transformed=project_tokens, + compute_gradients=compute_gradients, + offload_to_cpu=offload, + ) + + out_dict: Dict[str, TensorMap] = { + name: tensor.to(device=result_device) + for name, tensor in backtransformed_outputs.items() + } + + if not project_tokens: + return out_dict + + for name, tensor in transformed_outputs.items(): + out_dict[name] = tensor.to(device=result_device) + + return out_dict + + def _eval_over_grid( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + return_transformed: bool, + compute_gradients: bool = False, + offload_to_cpu: bool = False, + ) -> Tuple[Dict[str, TensorMap], Dict[str, TensorMap]]: + """ + Stream the model over the O(3) quadrature, accumulating mean, variance, and + character projections without ever materializing the full-grid TensorMap. + + :param systems: list of systems to evaluate + :param outputs: dictionary of model outputs to symmetrize + :param selected_atoms: optional Labels specifying which atoms to consider + :param return_transformed: if True, also return un-back-rotated outputs + :param compute_gradients: if True, compute forces/stress via autograd + :return: (transformed_outputs, backtransformed_outputs) dictionaries + """ + eval_start = time.perf_counter() + accumulator_device = self.storage_device + n_rotations = self.so3_rotations.size(0) + requested_output_names = list(outputs.keys()) + if compute_gradients: + if "energy" not in outputs: + raise ValueError("compute_gradients=True requires 'energy' in outputs") + + requested_output_names = list( + dict.fromkeys(requested_output_names + ["forces"]) + ) + if any(bool(torch.any(s.pbc).item()) for s in systems): + requested_output_names = list( + dict.fromkeys(requested_output_names + ["stress"]) + ) + + mean_accumulators: Dict[str, List[TensorMap]] = {} + second_moment_accumulators: Dict[str, List[TensorMap]] = {} + character_projection_accumulators: Dict[str, List[TensorMap]] = {} + proj_pos_accumulators: Dict[str, List[TensorMap]] = {} + proj_neg_accumulators: Dict[str, List[TensorMap]] = {} + + for i_sys, system in enumerate(systems): + system_start = time.perf_counter() + system_mean_accumulators: Dict[str, TensorMap] = {} + system_second_moment_accumulators: Dict[str, TensorMap] = {} + system_proj_pos_accumulators: Dict[ + str, Dict[Tuple[int, ...], Dict[str, object]] + ] = {} + system_proj_neg_accumulators: Dict[ + str, Dict[Tuple[int, ...], Dict[str, object]] + ] = {} + + work_device = system.positions.device + work_dtype = system.positions.dtype + + batch_starts = range(0, n_rotations, self.batch_size) + + # Loop order: outer = batch_start, inner = inversion. This lets us + # compute the (expensive) batch Wigner matrices exactly once per + # batch and reuse them across both inversion passes — Wigner is + # parity-invariant; inversion enters only via _apply_wigner_D_matrices. + for batch_start in batch_starts: + batch_stop = min(batch_start + self.batch_size, n_rotations) + n_local_systems = batch_stop - batch_start + weights = self.so3_weights[batch_start:batch_stop] + batch_rotations = self.so3_rotations[batch_start:batch_stop].to( + device=work_device, dtype=work_dtype + ) + local_selected_atoms = _selected_atoms_for_local_systems( + selected_atoms, + i_sys, + n_local_systems, + ) + + # cache for parity-invariant batch data (Wigner + augmentation system), + # keyed by output device/dtype; survives both inversion passes for this + # batch_start, then is discarded — peak memory unchanged + batch_wigner_cache: Dict[ + Tuple[str, str], + Tuple[System, Dict[int, torch.Tensor], Dict[int, List[torch.Tensor]]], + ] = {} + + for inversion in (1, -1): + inversion_batch = inversion * batch_rotations + + if compute_gradients: + out = _evaluate_with_gradients( + self.base_model, + system, + inversion_batch, + outputs, + local_selected_atoms, + work_device, + system.positions.dtype, + ) + else: + transformed_systems = [ + _transform_system( + system, + R.to(device=work_device), + ) + for R in inversion_batch + ] + out = self.base_model( + transformed_systems, + outputs, + local_selected_atoms, + ) + if offload_to_cpu: + out = {k: v.to(device="cpu") for k, v in out.items()} + + present_output_names = [ + name for name in requested_output_names if name in out + ] + if len(present_output_names) == 0: + continue + + # per-inversion cache for inverse_rotations (parity-dependent) + batch_inverse_rotations_cache: Dict[ + Tuple[str, str], List[torch.Tensor] + ] = {} + + for name in present_output_names: + tensor = out[name] + tensor_dtype = _tensor_map_dtype(tensor) + tensor_device = tensor.block().values.device + cache_key = (str(tensor_device), str(tensor_dtype)) + + if cache_key not in batch_wigner_cache: + augmentation_system = ( + system + if system.positions.device == tensor_device + else system.to( + device=tensor_device, + dtype=system.positions.dtype, + ) + ) + batch_wigner = _compute_wigner_batch( + max( + self.max_o3_lambda_target, + self.max_o3_lambda_character, + ), + _slice_angles( + self._inverse_quadrature_angles, + batch_start, + batch_stop, + ), + device=tensor_device, + dtype=tensor_dtype, + ) + wigner_dict: Dict[int, List[torch.Tensor]] = { + ell: list(mat.unbind(0)) + for ell, mat in batch_wigner.items() + } + batch_wigner_cache[cache_key] = ( + augmentation_system, + batch_wigner, + wigner_dict, + ) + + ( + augmentation_system, + batch_wigner, + wigner_dict, + ) = batch_wigner_cache[cache_key] + + if cache_key not in batch_inverse_rotations_cache: + inverse_mats = ( + inversion + * self.so3_inverse_rotations[batch_start:batch_stop] + ).to(device=tensor_device, dtype=tensor_dtype) + batch_inverse_rotations_cache[cache_key] = list( + inverse_mats.unbind(0) + ) + inverse_rotations = batch_inverse_rotations_cache[cache_key] + + _, backtransformed_batch, _ = _apply_augmentations( + [augmentation_system] * n_local_systems, + {name: tensor}, + inverse_rotations, + wigner_dict, + ) + # only needed for character projections, so skip the + # allocation entirely when return_transformed is off + direct_normalized: Optional[Dict[str, TensorMap]] = ( + _normalize_output_tensors(name, tensor) + if return_transformed + else None + ) + for ( + final_name, + backtransformed_tensor, + ) in _normalize_output_tensors( + name, + backtransformed_batch[name], + ).items(): + mean_batch = _reduce_weighted_batch_tensor( + backtransformed_tensor, + weights, + i_sys, + component_norm=False, + ) + if ( + accumulator_device is not None + and mean_batch.block().values.device != accumulator_device + ): + mean_batch = mean_batch.to(device=accumulator_device) + _accumulate_tensormap( + system_mean_accumulators, final_name, mean_batch + ) + + second_moment_batch = _reduce_weighted_batch_tensor( + backtransformed_tensor, + weights, + i_sys, + component_norm=True, + ) + if ( + accumulator_device is not None + and second_moment_batch.block().values.device + != accumulator_device + ): + second_moment_batch = second_moment_batch.to( + device=accumulator_device + ) + _accumulate_tensormap( + system_second_moment_accumulators, + final_name, + second_moment_batch, + ) + + if return_transformed: + assert direct_normalized is not None + projection_storage_device = accumulator_device + # direct_normalized and the iteration above are + # both built from `name`, so the key sets must + # match — KeyError here means a real bug + direct_tensor = direct_normalized[final_name] + block_contribution = ( + _compute_batch_projection_contributions( + direct_tensor, + weights, + batch_wigner, + self.max_o3_lambda_character, + storage_device=projection_storage_device, + ) + ) + accumulators = ( + system_proj_pos_accumulators + if inversion == 1 + else system_proj_neg_accumulators + ) + accumulators.setdefault(final_name, {}) + _merge_projection_contributions( + accumulators[final_name], block_contribution + ) + + if return_transformed: + projection_names = set(system_proj_pos_accumulators) | set( + system_proj_neg_accumulators + ) + for name in projection_names: + char_proj = _finalize_projection_tensor( + system_proj_pos_accumulators.get(name, {}), + system_proj_neg_accumulators.get(name, {}), + i_sys, + self.max_o3_lambda_character, + ) + if char_proj is not None: + character_projection_accumulators.setdefault(name, []).append( + char_proj + ) + + proj_pos_final = _finalize_projection_tensor( + system_proj_pos_accumulators.get(name, {}), + {}, + i_sys, + self.max_o3_lambda_character, + ) + if proj_pos_final is not None: + proj_pos_accumulators.setdefault(name, []).append( + proj_pos_final + ) + + proj_neg_final = _finalize_projection_tensor( + {}, + system_proj_neg_accumulators.get(name, {}), + i_sys, + self.max_o3_lambda_character, + ) + if proj_neg_final is not None: + proj_neg_accumulators.setdefault(name, []).append( + proj_neg_final + ) + + for name, tensor in system_mean_accumulators.items(): + _append_tensormap(mean_accumulators, name, tensor) + + for name, tensor in system_second_moment_accumulators.items(): + _append_tensormap(second_moment_accumulators, name, tensor) + + LOGGER.debug( + "SymmetrizedModel progress: system %s/%s finished in %.2fs " + "(project_tokens=%s, outputs=%s, batch_size=%s, offload_to_cpu=%s)", + i_sys + 1, + len(systems), + time.perf_counter() - system_start, + return_transformed, + len(requested_output_names), + self.batch_size, + offload_to_cpu, + ) + + mean_results: Dict[str, TensorMap] = {} + for name, mean_tensors in mean_accumulators.items(): + mean_tensor = _join_tensormap_list(mean_tensors) + # expose under both keys for downstream compat: callers that look up + # the symmetrized output by its bare name and callers that expect the + # "_mean" suffix both get the same TensorMap object. + mean_results[name] = mean_tensor + mean_results[name + "_mean"] = mean_tensor + + norm_squared = _join_tensormap_list(second_moment_accumulators[name]) + mean_results[name + "_var"] = _finalize_variance( + norm_squared, + mean_tensor, + ) + mean_results[name + "_norm_squared"] = norm_squared + + if not return_transformed: + return {}, mean_results + + transformed_results: Dict[str, TensorMap] = {} + backtransformed_results = dict(mean_results) + + for name, tensors in character_projection_accumulators.items(): + if tensors: + backtransformed_results[name + "_character_projection"] = ( + _join_tensormap_list(tensors) + ) + + for name, tensors in proj_pos_accumulators.items(): + if tensors: + transformed_results[name + "_character_projection_plus"] = ( + tensors[0] if len(tensors) == 1 else _join_tensormap_list(tensors) + ) + + for name, tensors in proj_neg_accumulators.items(): + if tensors: + transformed_results[name + "_character_projection_minus"] = ( + tensors[0] if len(tensors) == 1 else _join_tensormap_list(tensors) + ) + + LOGGER.debug( + "SymmetrizedModel finished %s systems in %.2fs " + "(project_tokens=%s, outputs=%s, batch_size=%s, offload_to_cpu=%s)", + len(systems), + time.perf_counter() - eval_start, + return_transformed, + len(requested_output_names), + self.batch_size, + offload_to_cpu, + ) + + return transformed_results, backtransformed_results + + +def _evaluate_with_gradients( + model: ModelInterface, + system: System, + rotations: torch.Tensor, + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels], + device: torch.device, + dtype: torch.dtype, +) -> Dict[str, TensorMap]: + """ + Evaluate model on a batch of rotated copies of one system and compute conservative + forces/stress via autograd. + + Forces are ``-dE/d(positions)`` in each rotated frame; stress is ``(1/V) dE/d(strain)`` + via the strain trick. Both are packaged as Cartesian TensorMaps with one entry per + rotation (sample axis = ``[system, atom]`` for forces, ``[system]`` for stress) so the + downstream back-rotation pipeline can treat them like any other per-system output. + + :param model: atomistic model to evaluate + :param system: input system (original frame) + :param rotations: ``(N, 3, 3)`` rotation matrices (each may include inversion) + :param outputs: model output specifications + :param selected_atoms: optional atom selection (in the local batch index space) + :param device: device for tensors + :param dtype: dtype for tensors + :return: model output dict with added ``"forces"`` and (if periodic) ``"stress"`` + """ + if rotations.dim() != 3 or rotations.shape[-2:] != (3, 3): + raise ValueError( + f"rotations must have shape (N, 3, 3), got {tuple(rotations.shape)}" + ) + + n_rot = rotations.shape[0] + n_atoms = system.positions.shape[0] + has_cell = bool(torch.any(system.pbc).item()) + + rotated_positions_list: List[torch.Tensor] = [] + strain_list: List[torch.Tensor] = [] + transformed_systems: List[System] = [] + + detached_positions = system.positions.detach() + detached_cell = system.cell.detach() + # hoist device/dtype cast out of the per-rotation loop + rotations = rotations.to(device=device, dtype=dtype) + + for i in range(n_rot): + R = rotations[i] + + rotated_positions = (detached_positions @ R.T).requires_grad_(True) + rotated_cell = detached_cell @ R.T + rotated_positions_list.append(rotated_positions) + + if has_cell: + strain = torch.eye(3, requires_grad=True, device=device, dtype=dtype) + final_positions = rotated_positions @ strain + final_cell = rotated_cell @ strain + strain_list.append(strain) + else: + final_positions = rotated_positions + final_cell = rotated_cell + + transformed = System( + types=system.types, + positions=final_positions, + cell=final_cell, + pbc=system.pbc, + ) + + # each rotated copy needs its own neighbor list block so autograd can + # flow through the rotated positions independently per system + for options in system.known_neighbor_lists(): + neighbors = mts.detach_block(system.get_neighbor_list(options)) + neighbors.values[:] = (neighbors.values.squeeze(-1) @ R.T).unsqueeze(-1) + register_autograd_neighbors(transformed, neighbors) + transformed.add_neighbor_list(options, neighbors) + + transformed_systems.append(transformed) + + out = model(transformed_systems, outputs, selected_atoms) + + if "energy" not in out: + raise ValueError("compute_gradients=True requires the model to output 'energy'") + + # The model treats the N systems independently, so d(sum)/d(rotated_positions[i]) + # equals dE_i/d(rotated_positions[i]) — no cross-system contamination. + energy_sum = out["energy"].block().values.sum() + + grad_targets: List[torch.Tensor] = list(rotated_positions_list) + if has_cell: + grad_targets.extend(strain_list) + grads = torch.autograd.grad(energy_sum, grad_targets, create_graph=False) + + position_grads = grads[:n_rot] + strain_grads = grads[n_rot:] if has_cell else [] + + forces_values = torch.cat([-g for g in position_grads], dim=0) # (n_rot*n_atoms, 3) + + atom_range = torch.arange(n_atoms, dtype=torch.int64, device=device) + system_indices = torch.arange( + n_rot, dtype=torch.int64, device=device + ).repeat_interleave(n_atoms) + atom_indices = atom_range.repeat(n_rot) + forces_samples = Labels( + names=["system", "atom"], + values=torch.stack([system_indices, atom_indices], dim=1), + ) + + key_labels = Labels( + names=["_"], + values=torch.tensor([[0]], dtype=torch.int64, device=device), + ) + + forces_block = TensorBlock( + values=forces_values.unsqueeze(-1), # (n_rot*n_atoms, 3, 1) + samples=forces_samples, + components=[ + Labels( + "xyz", + torch.arange(3, dtype=torch.int64, device=device).reshape(-1, 1), + ) + ], + properties=Labels( + names=["force"], + values=torch.tensor([[0]], dtype=torch.int64, device=device), + ), + ) + forces_tmap = TensorMap(key_labels, [forces_block]) + if selected_atoms is not None: + forces_tmap = mts.slice(forces_tmap, axis="samples", selection=selected_atoms) + out["forces"] = forces_tmap + + if has_cell: + # volume is rotation-invariant, so the original cell volume is correct for + # every rotated copy + volume = torch.abs(torch.linalg.det(detached_cell)) + stress_values = torch.stack(strain_grads, dim=0) / volume # (n_rot, 3, 3) + + stress_block = TensorBlock( + values=stress_values.unsqueeze(-1), # (n_rot, 3, 3, 1) + samples=Labels( + names=["system"], + values=torch.arange( + n_rot, dtype=torch.int64, device=device + ).reshape(-1, 1), + ), + components=[ + Labels( + "xyz_1", + torch.arange(3, dtype=torch.int64, device=device).reshape(-1, 1), + ), + Labels( + "xyz_2", + torch.arange(3, dtype=torch.int64, device=device).reshape(-1, 1), + ), + ], + properties=Labels( + names=["stress"], + values=torch.tensor([[0]], dtype=torch.int64, device=device), + ), + ) + out["stress"] = TensorMap(key_labels, [stress_block]) + + return out diff --git a/python/metatomic_torch/tests/augmentation.py b/python/metatomic_torch/tests/augmentation.py new file mode 100644 index 000000000..248e9bf4e --- /dev/null +++ b/python/metatomic_torch/tests/augmentation.py @@ -0,0 +1,275 @@ +import numpy as np +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap + +from metatomic.torch import System +from metatomic.torch._augmentation import _apply_augmentations +from metatomic.torch._wigner import compute_wigner_batch + + +def _make_system(types, positions=None, cell=None, pbc=None): + n_atoms = len(types) + if positions is None: + positions = torch.zeros((n_atoms, 3), dtype=torch.float64) + if cell is None: + cell = torch.zeros((3, 3), dtype=torch.float64) + if pbc is None: + pbc = torch.tensor([False, False, False]) + return System( + types=torch.tensor(types, dtype=torch.int32), + positions=positions, + cell=cell, + pbc=pbc, + ) + + +def _rotation_batch(alphas): + transformations = [] + for alpha in alphas: + transformations.append( + torch.tensor( + [ + [np.cos(alpha), -np.sin(alpha), 0.0], + [np.sin(alpha), np.cos(alpha), 0.0], + [0.0, 0.0, 1.0], + ], + dtype=torch.float64, + ) + ) + + zeros = np.zeros(len(alphas)) + wigner_D_matrices = { + ell: list(matrix.unbind(0)) + for ell, matrix in compute_wigner_batch( + 1, + (np.asarray(alphas), zeros, zeros), + device=torch.device("cpu"), + dtype=torch.float64, + ).items() + } + return transformations, wigner_D_matrices + + +def _row_indices(samples, n_systems): + system_ids = samples.column("system").to(dtype=torch.long) + return [ + torch.nonzero(system_ids == system_index, as_tuple=False).reshape(-1) + for system_index in range(n_systems) + ] + + +def test_sparse_atomic_basis_rank1_augmentation(): + systems = [_make_system([1, 8]), _make_system([1, 8])] + transformations, wigner_D_matrices = _rotation_batch([np.pi / 2, np.pi]) + + component = Labels( + ["o3_mu"], + torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1), + ) + property_labels = Labels(["n"], torch.tensor([[0]], dtype=torch.int32)) + tensor = TensorMap( + Labels( + ["o3_lambda", "o3_sigma", "atom_type"], + torch.tensor([[1, 1, 1], [1, 1, 8]], dtype=torch.int32), + ), + [ + TensorBlock( + values=torch.tensor( + [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]], + dtype=torch.float64, + ), + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [1, 0]], dtype=torch.int32), + ), + components=[component], + properties=property_labels, + ), + TensorBlock( + values=torch.tensor( + [[[7.0], [8.0], [9.0]], [[10.0], [11.0], [12.0]]], + dtype=torch.float64, + ), + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 1], [1, 1]], dtype=torch.int32), + ), + components=[component], + properties=property_labels, + ), + ], + ) + + _, augmented_targets, _ = _apply_augmentations( + systems, + {"target": tensor}, + transformations, + wigner_D_matrices, + ) + augmented = augmented_targets["target"] + + expected_blocks = [] + for block in tensor.blocks(): + expected_values = block.values.clone() + for rows, wigner_D_matrix in zip( + _row_indices(block.samples, len(systems)), + wigner_D_matrices[1], + strict=True, + ): + rotated = block.values[rows].clone().transpose(1, 2) + rotated = rotated @ wigner_D_matrix.T + expected_values[rows] = rotated.transpose(1, 2) + + expected_blocks.append( + TensorBlock( + values=expected_values, + samples=block.samples, + components=block.components, + properties=block.properties, + ) + ) + + expected = TensorMap(tensor.keys, expected_blocks) + assert augmented.keys == expected.keys + for block_id in range(len(expected.keys)): + assert ( + augmented.block_by_id(block_id).samples + == expected.block_by_id(block_id).samples + ) + assert torch.allclose( + augmented.block_by_id(block_id).values, + expected.block_by_id(block_id).values, + atol=1e-12, + ) + + +def test_sparse_atomic_basis_rank2_augmentation_with_missing_system_rows(): + systems = [_make_system([1, 1]), _make_system([8, 8])] + transformations, wigner_D_matrices = _rotation_batch([np.pi / 2, np.pi]) + + components = [ + Labels( + ["o3_mu_1"], + torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1), + ), + Labels( + ["o3_mu_2"], + torch.arange(-1, 2, dtype=torch.int32).reshape(-1, 1), + ), + ] + property_labels = Labels( + ["n_1", "n_2"], + torch.tensor([[0, 0]], dtype=torch.int32), + ) + values = torch.arange(18, dtype=torch.float64).reshape(2, 3, 3, 1) + tensor = TensorMap( + Labels( + ["o3_lambda_1", "o3_lambda_2", "o3_sigma_1", "o3_sigma_2", "atom_type"], + torch.tensor([[1, 1, 1, 1, 1]], dtype=torch.int32), + ), + [ + TensorBlock( + values=values, + samples=Labels( + ["system", "atom"], + torch.tensor([[0, 0], [0, 1]], dtype=torch.int32), + ), + components=components, + properties=property_labels, + ) + ], + ) + + _, augmented_targets, _ = _apply_augmentations( + systems, + {"target": tensor}, + transformations, + wigner_D_matrices, + ) + augmented = augmented_targets["target"] + + expected_values = values.clone() + rows = _row_indices(tensor.block().samples, len(systems))[0] + expected_values[rows] = torch.einsum( + "Aa,iabp,bB->iABp", + wigner_D_matrices[1][0], + values[rows], + wigner_D_matrices[1][0].T, + ) + expected = TensorMap( + tensor.keys, + [ + TensorBlock( + values=expected_values, + samples=tensor.block().samples, + components=tensor.block().components, + properties=tensor.block().properties, + ) + ], + ) + + assert augmented.block().samples == expected.block().samples + assert torch.allclose(augmented.block().values, expected.block().values, atol=1e-12) + assert not torch.allclose(augmented.block().values, values) + + +def test_system_positions_and_cell_are_rotated(): + # Non-trivial positions and cell so the rotation is observable; verifies that + # `_apply_augmentations` does not silently leave the System unchanged. + positions_a = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]], + dtype=torch.float64, + ) + positions_b = torch.tensor( + [[0.5, 0.5, 0.0], [1.0, -1.0, 0.5]], + dtype=torch.float64, + ) + cell_a = torch.eye(3, dtype=torch.float64) * 3.0 + cell_b = torch.eye(3, dtype=torch.float64) * 4.0 + pbc = torch.tensor([True, True, True]) + systems = [ + _make_system([1, 1, 1], positions=positions_a, cell=cell_a, pbc=pbc), + _make_system([8, 8], positions=positions_b, cell=cell_b, pbc=pbc), + ] + transformations, wigner_D_matrices = _rotation_batch([np.pi / 3, np.pi / 4]) + + new_systems, _, _ = _apply_augmentations( + systems, + {}, + transformations, + wigner_D_matrices, + ) + + assert len(new_systems) == 2 + for original, new, R in zip(systems, new_systems, transformations, strict=True): + assert torch.allclose(new.positions, original.positions @ R.T, atol=1e-12) + assert torch.allclose(new.cell, original.cell @ R.T, atol=1e-12) + # types and pbc must pass through unchanged + assert torch.equal(new.types, original.types) + assert torch.equal(new.pbc, original.pbc) + + +def test_extra_data_caller_dict_is_not_mutated(): + # Regression: _apply_augmentations used to pop "_mask" entries from the caller's + # dict, which corrupted the caller's view. + systems = [_make_system([1])] + transformations, wigner_D_matrices = _rotation_batch([np.pi / 2]) + + mask_values = torch.tensor([[1.0]], dtype=torch.float64) + mask_tmap = TensorMap( + Labels(["_"], torch.tensor([[0]], dtype=torch.int32)), + [ + TensorBlock( + values=mask_values, + samples=Labels(["system"], torch.tensor([[0]], dtype=torch.int32)), + components=[], + properties=Labels(["p"], torch.tensor([[0]], dtype=torch.int32)), + ) + ], + ) + extra_data = {"some_mask": mask_tmap} + original_keys = set(extra_data.keys()) + + _apply_augmentations(systems, {}, transformations, wigner_D_matrices, extra_data) + + assert set(extra_data.keys()) == original_keys diff --git a/python/metatomic_torch/tests/symmetrized_model.py b/python/metatomic_torch/tests/symmetrized_model.py new file mode 100644 index 000000000..d05374b61 --- /dev/null +++ b/python/metatomic_torch/tests/symmetrized_model.py @@ -0,0 +1,967 @@ +"""Tests for symmetrized_model.py standalone functions and SymmetrizedModel class.""" + +from pathlib import Path +from typing import Dict, List, Optional +import warnings + +import metatensor.torch as mts +import numpy as np +import pytest +import torch +from metatensor.torch import Labels, TensorBlock, TensorMap +from scipy.spatial.transform import Rotation + +from metatomic.torch import ModelOutput, System, systems_to_torch +from metatomic.torch._wigner import ( + compute_real_wigner_matrices as _compute_real_wigner_matrices, +) +from metatomic.torch.symmetrized_model import ( + SymmetrizedModel, + _choose_quadrature, + _evaluate_with_gradients, + _l0_components_from_matrices, + _l2_components_from_matrices, + _rotations_from_angles, + _transform_system, + get_euler_angles_quadrature, +) + + +REAL_CHECKPOINT = ( + Path(__file__).resolve().parents[3] + / "SYMMOD_EXAMPLE" + / "pet-mad-xs-v1.5.0.ckpt" +) + + +class TestL0Components: + """Test extraction of L=0 (trace) components from 3x3 matrices.""" + + def test_identity_trace(self): + # Identity matrix has trace 3. The function expects shape (a, 3, 3, b). + A = torch.eye(3, dtype=torch.float64).unsqueeze(0).unsqueeze(-1) + result = _l0_components_from_matrices(A) + assert result.shape == (1, 1, 1) + assert torch.allclose(result, torch.tensor([[[3.0]]], dtype=torch.float64)) + + def test_traceless_matrix(self): + # A traceless matrix should give L=0 = 0 + M = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 0.0]], + dtype=torch.float64, + ) + A = M.unsqueeze(0).unsqueeze(-1) + result = _l0_components_from_matrices(A) + assert torch.allclose( + result, torch.tensor([[[0.0]]], dtype=torch.float64), atol=1e-14 + ) + + def test_batch_dimensions(self): + # Test with batch size > 1 and multiple properties + batch = 5 + n_prop = 3 + A = torch.randn(batch, 3, 3, n_prop, dtype=torch.float64) + result = _l0_components_from_matrices(A) + assert result.shape == (batch, 1, n_prop) + for i in range(batch): + for p in range(n_prop): + expected_trace = A[i, 0, 0, p] + A[i, 1, 1, p] + A[i, 2, 2, p] + assert torch.allclose(result[i, 0, p], expected_trace, atol=1e-14) + + +class TestL2Components: + """Test extraction of L=2 (symmetric traceless) components from 3x3 matrices.""" + + def test_identity_gives_zero(self): + # Identity is proportional to L=0 only; L=2 components should be zero. + A = torch.eye(3, dtype=torch.float64).unsqueeze(0).unsqueeze(-1) + result = _l2_components_from_matrices(A) + assert result.shape == (1, 5, 1) + assert torch.allclose( + result, torch.zeros(1, 5, 1, dtype=torch.float64), atol=1e-14 + ) + + def test_diagonal_traceless(self): + # diag(1, -1, 0) is traceless and has known L=2 components + M = torch.tensor( + [[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, 0.0]], + dtype=torch.float64, + ) + A = M.unsqueeze(0).unsqueeze(-1) + result = _l2_components_from_matrices(A) + assert result.shape == (1, 5, 1) + # m=0: (2*0 - 1 - (-1)) / (2*sqrt(3)) = 0 + assert torch.allclose(result[0, 2, 0], torch.tensor(0.0, dtype=torch.float64)) + # m=2 (last component): (1 - (-1)) / 2 = 1 + assert torch.allclose(result[0, 4, 0], torch.tensor(1.0, dtype=torch.float64)) + + def test_frobenius_norm_relation(self): + """For a symmetric traceless matrix S, the L=2 decomposition should satisfy + a norm relation: sum(c_i^2) relates to (1/2) * sum(S_ij * S_ji). + """ + # Build a symmetric traceless matrix + S = torch.tensor( + [[2.0, 1.0, 0.5], [1.0, -1.0, 0.3], [0.5, 0.3, -1.0]], + dtype=torch.float64, + ) + A = S.unsqueeze(0).unsqueeze(-1) + l2 = _l2_components_from_matrices(A) + l2_norm_sq = (l2**2).sum() + + # The L=2 norm squared should equal half the Frobenius norm of the + # symmetric part (since the decomposition extracts the symmetric part) + sym_S = 0.5 * (S + S.T) + half_frob = 0.5 * (sym_S**2).sum() + # They won't be exactly equal because S has an L=0 part too. + # But for a traceless symmetric matrix, L=0 is zero, so they match. + trace = S[0, 0] + S[1, 1] + S[2, 2] + assert abs(trace) < 1e-14, "Matrix should be traceless for this test" + assert torch.allclose(l2_norm_sq, half_frob, atol=1e-12) + + +class TestDecomposeStressRoundtrip: + """Test that L=0 + L=2 decomposition covers the symmetric part of a 3x3 tensor.""" + + def test_norm_conservation(self): + """The sum of L=0 and L=2 squared norms should equal + the Frobenius norm squared of the symmetrized matrix.""" + M = torch.randn(1, 3, 3, 1, dtype=torch.float64) + sym_M = 0.5 * (M + M.transpose(1, 2)) + + l0 = _l0_components_from_matrices(sym_M) + l2 = _l2_components_from_matrices(sym_M) + + # L=0 norm: trace^2 / 3 (the trace component carries norm trace^2/3 + # in the irrep normalization). Actually, the L=0 extraction returns + # the raw trace, and L=2 the 5 components. Let's check reconstruction. + trace_val = l0[0, 0, 0] + # Reconstruct L=0 part: (trace/3) * I + l0_matrix = (trace_val / 3.0) * torch.eye(3, dtype=torch.float64) + + # Reconstruct L=2 part from components + c = l2[0, :, 0] # 5 components: (m=-2, m=-1, m=0, m=1, m=2) + l2_matrix = torch.zeros(3, 3, dtype=torch.float64) + # Reverse of the extraction formulas: + l2_matrix[0, 1] = c[0] + l2_matrix[1, 0] = c[0] + l2_matrix[1, 2] = c[1] + l2_matrix[2, 1] = c[1] + l2_matrix[0, 2] = c[3] + l2_matrix[2, 0] = c[3] + l2_matrix[0, 0] = c[4] + c[2] * np.sqrt(3) / 3 * (-1) + l2_matrix[1, 1] = -c[4] + c[2] * np.sqrt(3) / 3 * (-1) + l2_matrix[2, 2] = c[2] * 2.0 * np.sqrt(3) / 3 + + reconstructed = l0_matrix + l2_matrix + original_sym = sym_M[0, :, :, 0] + assert torch.allclose(reconstructed, original_sym, atol=1e-12) + + +class TestWignerD: + """Test properties of real Wigner D matrices.""" + + def test_orthogonality(self): + """D(R)^T D(R) = I for all ell.""" + rng = np.random.default_rng(42) + R = Rotation.random(5, random_state=rng) + angles = ( + np.zeros(5), + np.zeros(5), + np.zeros(5), + ) + # Use actual rotation angles + euler = R.as_euler("ZYZ") + angles = (euler[:, 0], euler[:, 1], euler[:, 2]) + + l_max = 4 + wigner = _compute_real_wigner_matrices(l_max, angles) + for ell in range(l_max + 1): + D = wigner[ell] # shape (5, 2l+1, 2l+1) + for i in range(5): + Di = D[i] + product = Di.T @ Di + identity = torch.eye(2 * ell + 1, dtype=Di.dtype) + assert torch.allclose(product, identity, atol=1e-10), ( + f"D^T D != I for ell={ell}, rotation {i}" + ) + + def test_identity_rotation(self): + """D(identity) = I for all ell.""" + angles = (np.array([0.0]), np.array([0.0]), np.array([0.0])) + l_max = 4 + wigner = _compute_real_wigner_matrices(l_max, angles) + for ell in range(l_max + 1): + D = wigner[ell][0] + identity = torch.eye(2 * ell + 1, dtype=D.dtype) + assert torch.allclose(D, identity, atol=1e-10), ( + f"D(identity) != I for ell={ell}" + ) + + def test_composition(self): + """D(R1) @ D(R2) ≈ D(R1 @ R2) for random rotations.""" + rng = np.random.default_rng(123) + R1 = Rotation.random(random_state=rng) + R2 = Rotation.random(random_state=rng) + R12 = R1 * R2 + + l_max = 3 + e1 = np.atleast_2d(R1.as_euler("ZYZ")) + e2 = np.atleast_2d(R2.as_euler("ZYZ")) + e12 = np.atleast_2d(R12.as_euler("ZYZ")) + + D1 = _compute_real_wigner_matrices(l_max, (e1[:, 0], e1[:, 1], e1[:, 2])) + D2 = _compute_real_wigner_matrices(l_max, (e2[:, 0], e2[:, 1], e2[:, 2])) + D12 = _compute_real_wigner_matrices(l_max, (e12[:, 0], e12[:, 1], e12[:, 2])) + + for ell in range(l_max + 1): + product = D1[ell][0] @ D2[ell][0] + expected = D12[ell][0] + assert torch.allclose(product, expected, atol=1e-10), ( + f"D(R1)D(R2) != D(R1R2) for ell={ell}" + ) + + +class TestQuadrature: + """Test quadrature weights and grid properties.""" + + def test_weights_sum(self): + """Quadrature weights should sum to 1 (normalized Haar measure on SO(3)).""" + for L_max in [3, 5, 7]: + lebedev_order, n_inplane = _choose_quadrature(L_max) + _, _, _, w = get_euler_angles_quadrature(lebedev_order, n_inplane) + # The weights are w_i / (4*pi*K) repeated K times, where w_i sum to 4*pi + # So total sum = sum(w_i)/(4*pi*K) * K = sum(w_i)/(4*pi) = 1 + assert np.allclose(w.sum(), 1.0, atol=1e-12), ( + f"Weights don't sum to 1 for L_max={L_max}: sum={w.sum()}" + ) + + def test_choose_quadrature_monotone(self): + """Higher L_max should give equal or larger quadrature grids.""" + prev_n = 0 + for L_max in [3, 5, 7, 11, 15]: + n, K = _choose_quadrature(L_max) + assert n >= prev_n + assert K == L_max + 1 + prev_n = n + + def test_rotations_are_proper(self): + """All rotation matrices from the quadrature should have det = +1.""" + lebedev_order, n_inplane = _choose_quadrature(5) + alpha, beta, gamma, _ = get_euler_angles_quadrature(lebedev_order, n_inplane) + R = _rotations_from_angles(alpha, beta, gamma) + matrices = R.as_matrix() + dets = np.linalg.det(matrices) + assert np.allclose(dets, 1.0, atol=1e-10) + + +class _QuadraticEnergyModel(torch.nn.Module): + """Minimal model where E = sum(positions^2). Analytical forces = -2*positions.""" + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + n_sys = len(systems) + energies = [] + for sys in systems: + energies.append(torch.sum(sys.positions**2)) + + key = Labels( + names=["_"], + values=torch.tensor([[0]], dtype=torch.int64), + ) + energy_block = TensorBlock( + values=torch.stack(energies).unsqueeze(-1), + samples=Labels( + names=["system"], + values=torch.arange(n_sys, dtype=torch.int64).unsqueeze(1), + ), + components=[], + properties=Labels( + names=["energy"], + values=torch.tensor([[0]], dtype=torch.int64), + ), + ) + return {"energy": TensorMap(key, [energy_block])} + + def requested_neighbor_lists(self): + return [] + + +class _EnergyAndVectorModel(torch.nn.Module): + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + n_sys = len(systems) + key = Labels(names=["_"], values=torch.tensor([[0]], dtype=torch.int64)) + result = {} + + if "energy" in outputs: + energy_block = TensorBlock( + values=torch.stack( + [torch.sum(sys.positions**2) for sys in systems] + ).unsqueeze(-1), + samples=Labels( + names=["system"], + values=torch.arange(n_sys, dtype=torch.int64).unsqueeze(1), + ), + components=[], + properties=Labels( + names=["energy"], + values=torch.tensor([[0]], dtype=torch.int64), + ), + ) + result["energy"] = TensorMap(key, [energy_block]) + + if "non_conservative_forces" in outputs: + values = torch.cat([sys.positions.unsqueeze(-1) for sys in systems], dim=0) + samples = [] + for i_sys, sys in enumerate(systems): + for atom in range(len(sys)): + samples.append([i_sys, atom]) + force_block = TensorBlock( + values=values, + samples=Labels( + names=["system", "atom"], + values=torch.tensor(samples, dtype=torch.int64), + ), + components=[ + Labels( + names=["xyz"], + values=torch.arange(3, dtype=torch.int64).reshape(-1, 1), + ) + ], + properties=Labels( + names=["p"], + values=torch.tensor([[0]], dtype=torch.int64), + ), + ) + force_tmap = TensorMap(key, [force_block]) + if selected_atoms is not None: + force_tmap = mts.slice( + force_tmap, + axis="samples", + selection=selected_atoms, + ) + result["non_conservative_forces"] = force_tmap + + return result + + def requested_neighbor_lists(self): + return [] + + +class TestGradientForces: + """Test conservative forces from autograd via _evaluate_with_gradients.""" + + def test_forces_identity_rotation(self): + """With identity rotation, forces should be -2*positions for E=sum(pos^2).""" + model = _QuadraticEnergyModel() + positions = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float64 + ) + system = System( + types=torch.tensor([1, 1]), + positions=positions, + cell=torch.zeros(3, 3, dtype=torch.float64), + pbc=torch.tensor([False, False, False]), + ) + rotation = torch.eye(3, dtype=torch.float64).unsqueeze(0) # (1, 3, 3) + outputs = {"energy": ModelOutput(sample_kind="system")} + + out = _evaluate_with_gradients( + model, + system, + rotation, + outputs, + None, + device=torch.device("cpu"), + dtype=torch.float64, + ) + + assert "forces" in out + forces = out["forces"].block().values.squeeze(-1) # (n_atoms, 3) for N=1 + expected = -2.0 * positions + assert torch.allclose(forces, expected, atol=1e-12) + + def test_forces_with_rotation(self): + """Forces in rotated frame should equal R @ (forces in lab frame). + For E=sum(pos^2), forces_lab = -2*pos_lab. + In rotated frame: forces_rot = -dE/d(pos_rot) where pos_rot = pos_lab @ R.T. + Since E = sum((pos_rot @ R)^2) = sum(pos_rot^2) (R is orthogonal), + forces_rot = -2*pos_rot = -2*(pos_lab @ R.T). + """ + model = _QuadraticEnergyModel() + positions = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float64 + ) + system = System( + types=torch.tensor([1, 1]), + positions=positions, + cell=torch.zeros(3, 3, dtype=torch.float64), + pbc=torch.tensor([False, False, False]), + ) + # Random rotation + rng = np.random.default_rng(42) + R_scipy = Rotation.random(random_state=rng) + R = torch.tensor(R_scipy.as_matrix(), dtype=torch.float64) + outputs = {"energy": ModelOutput(sample_kind="system")} + + out = _evaluate_with_gradients( + model, + system, + R.unsqueeze(0), + outputs, + None, + device=torch.device("cpu"), + dtype=torch.float64, + ) + + forces_rot = out["forces"].block().values.squeeze(-1) + expected_rot = -2.0 * (positions @ R.T) + assert torch.allclose(forces_rot, expected_rot, atol=1e-12) + + def test_stress_periodic_system(self): + """For a periodic system with E=sum(pos^2), check stress via strain trick. + + With strain trick: pos_final = pos_rot @ strain, so + E = sum((pos_rot @ strain)^2) = sum_i sum_a (sum_b pos_rot_ib * strain_ba)^2 + dE/d(strain_cd) = 2 * sum_i sum_a (pos_rot @ strain)_ia * pos_rot_ic * delta_da + = 2 * (pos_rot.T @ (pos_rot @ strain))_{ca} (at strain=I) + = 2 * pos_rot.T @ pos_rot + stress = (1/V) * dE/d(strain) = (2/V) * pos_rot.T @ pos_rot + """ + model = _QuadraticEnergyModel() + positions = torch.tensor( + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float64 + ) + cell = torch.eye(3, dtype=torch.float64) * 5.0 + system = System( + types=torch.tensor([1, 1]), + positions=positions, + cell=cell, + pbc=torch.tensor([True, True, True]), + ) + R = torch.eye(3, dtype=torch.float64).unsqueeze(0) + outputs = {"energy": ModelOutput(sample_kind="system")} + + out = _evaluate_with_gradients( + model, + system, + R, + outputs, + None, + device=torch.device("cpu"), + dtype=torch.float64, + ) + + assert "stress" in out + stress = out["stress"].block().values.squeeze(0).squeeze(-1) # (3, 3) + volume = torch.abs(torch.linalg.det(cell)) + expected_stress = 2.0 * positions.T @ positions / volume + assert torch.allclose(stress, expected_stress, atol=1e-12) + + def test_no_stress_for_nonperiodic(self): + """Non-periodic systems should not produce stress output.""" + model = _QuadraticEnergyModel() + system = System( + types=torch.tensor([1]), + positions=torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float64), + cell=torch.zeros(3, 3, dtype=torch.float64), + pbc=torch.tensor([False, False, False]), + ) + R = torch.eye(3, dtype=torch.float64).unsqueeze(0) + outputs = {"energy": ModelOutput(sample_kind="system")} + + out = _evaluate_with_gradients( + model, + system, + R, + outputs, + None, + device=torch.device("cpu"), + dtype=torch.float64, + ) + + assert "forces" in out + assert "stress" not in out + + def test_forces_batched_rotations(self): + """N>1 rotations should produce per-system forces that match what the same + rotations would produce one at a time. Validates the batched-gradient refactor. + """ + model = _QuadraticEnergyModel() + positions = torch.tensor( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=torch.float64 + ) + system = System( + types=torch.tensor([1, 1]), + positions=positions, + cell=torch.zeros(3, 3, dtype=torch.float64), + pbc=torch.tensor([False, False, False]), + ) + + rng = np.random.default_rng(7) + rotations = torch.stack( + [ + torch.tensor( + Rotation.random(random_state=rng).as_matrix(), + dtype=torch.float64, + ) + for _ in range(3) + ], + dim=0, + ) + outputs = {"energy": ModelOutput(sample_kind="system")} + + out = _evaluate_with_gradients( + model, + system, + rotations, + outputs, + None, + device=torch.device("cpu"), + dtype=torch.float64, + ) + + forces = out["forces"].block().values.squeeze(-1) # (3 * n_atoms, 3) + n_atoms = positions.shape[0] + for i in range(rotations.shape[0]): + R = rotations[i] + expected = -2.0 * (positions @ R.T) + actual = forces[i * n_atoms : (i + 1) * n_atoms] + assert torch.allclose(actual, expected, atol=1e-12) + + # also verify sample labels are (system=i, atom=j) in row-major order + samples = out["forces"].block().samples + sys_col = samples.column("system") + atom_col = samples.column("atom") + for i in range(rotations.shape[0]): + for j in range(n_atoms): + row = i * n_atoms + j + assert int(sys_col[row]) == i + assert int(atom_col[row]) == j + + def test_stress_batched_rotations(self): + """Stress under N>1 rotations: rotation-invariant magnitude per system, with + rotated principal axes.""" + model = _QuadraticEnergyModel() + positions = torch.tensor( + [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=torch.float64 + ) + cell = torch.eye(3, dtype=torch.float64) * 5.0 + system = System( + types=torch.tensor([1, 1]), + positions=positions, + cell=cell, + pbc=torch.tensor([True, True, True]), + ) + + rng = np.random.default_rng(11) + rotations = torch.stack( + [ + torch.tensor( + Rotation.random(random_state=rng).as_matrix(), + dtype=torch.float64, + ) + for _ in range(2) + ], + dim=0, + ) + outputs = {"energy": ModelOutput(sample_kind="system")} + + out = _evaluate_with_gradients( + model, + system, + rotations, + outputs, + None, + device=torch.device("cpu"), + dtype=torch.float64, + ) + + assert "stress" in out + stress_values = out["stress"].block().values.squeeze(-1) # (N, 3, 3) + volume = torch.abs(torch.linalg.det(cell)) + for i in range(rotations.shape[0]): + R = rotations[i] + rotated_pos = positions @ R.T + expected = 2.0 * rotated_pos.T @ rotated_pos / volume + assert torch.allclose(stress_values[i], expected, atol=1e-12) + + +class TestSymmetrizedModelForward: + def _make_system(self, dtype=torch.float64): + return System( + types=torch.tensor([1, 1], dtype=torch.int32), + positions=torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]], dtype=dtype), + cell=torch.zeros((3, 3), dtype=dtype), + pbc=torch.tensor([False, False, False]), + ) + + def _make_second_system(self, dtype=torch.float64): + return System( + types=torch.tensor([1, 1], dtype=torch.int32), + positions=torch.tensor([[0.0, 0.0, 3.0], [4.0, 0.0, 0.0]], dtype=dtype), + cell=torch.zeros((3, 3), dtype=dtype), + pbc=torch.tensor([False, False, False]), + ) + + def test_scalar_forward_outputs(self): + model = SymmetrizedModel( + _QuadraticEnergyModel(), + max_o3_lambda_character=1, + max_o3_lambda_target=0, + batch_size=2, + ).to(dtype=torch.float64) + + outputs = {"energy": ModelOutput(sample_kind="system")} + result = model([self._make_system()], outputs) + + assert "energy_l0_mean" in result + assert "energy_l0_var" in result + assert "energy_l0_norm_squared" in result + assert torch.allclose( + result["energy_l0_mean"].block().values, + torch.tensor([[[5.0]]], dtype=torch.float64), + atol=1e-10, + ) + + def test_forward_project_tokens(self): + model = SymmetrizedModel( + _QuadraticEnergyModel(), + max_o3_lambda_character=1, + max_o3_lambda_target=0, + batch_size=1, + ).to(dtype=torch.float64) + + outputs = {"energy": ModelOutput(sample_kind="system")} + result = model( + [self._make_system()], + outputs, + project_tokens=True, + ) + + expected_keys = { + "energy_l0", + "energy_l0_mean", + "energy_l0_var", + "energy_l0_norm_squared", + "energy_l0_character_projection", + "energy_l0_character_projection_plus", + "energy_l0_character_projection_minus", + } + assert expected_keys.issubset(result.keys()) + + def test_project_tokens_without_gradients_runs_under_no_grad(self): + class _GradStateModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.recorded_grad_states = [] + + def forward(self, systems, outputs, selected_atoms): + self.recorded_grad_states.append(torch.is_grad_enabled()) + # shape (n_systems=1, n_properties=1) — no components, so 2D + values = torch.tensor( + [[1.0]], + dtype=systems[0].positions.dtype, + device=systems[0].positions.device, + ) + tensor = TensorMap( + Labels(["_"], torch.tensor([[0]], dtype=torch.int64, device=values.device)), + [ + TensorBlock( + values=values, + samples=Labels( + ["system"], + torch.tensor([[0]], dtype=torch.int64, device=values.device), + ), + components=[], + properties=Labels( + ["energy"], + torch.tensor([[0]], dtype=torch.int64, device=values.device), + ), + ) + ], + ) + return {"energy": tensor} + + base_model = _GradStateModel() + model = SymmetrizedModel( + base_model, + max_o3_lambda_character=1, + max_o3_lambda_target=0, + batch_size=1, + offload_to_cpu=False, + ).to(dtype=torch.float64) + + outputs = {"energy": ModelOutput(sample_kind="system")} + result = model( + [self._make_system()], + outputs, + project_tokens=True, + compute_gradients=False, + ) + + assert "energy_l0_character_projection" in result + assert len(base_model.recorded_grad_states) > 0 + assert all(state is False for state in base_model.recorded_grad_states) + + def test_offload_to_cpu_does_not_change_outputs(self): + # Regression: with `compute_gradients=False`, switching offload_to_cpu + # must only move tensors, not change numerical results. + outputs = {"energy": ModelOutput(sample_kind="system")} + systems = [self._make_system()] + + results = {} + for offload in (False, True): + base_model = _QuadraticEnergyModel() + model = SymmetrizedModel( + base_model, + max_o3_lambda_character=1, + max_o3_lambda_target=0, + batch_size=2, + offload_to_cpu=offload, + ).to(dtype=torch.float64) + results[offload] = model(systems, outputs, project_tokens=True) + + shared_keys = set(results[False].keys()) & set(results[True].keys()) + assert shared_keys, "no shared output keys between offload modes" + for name in shared_keys: + tensor_false = results[False][name] + tensor_true = results[True][name] + assert tensor_false.keys == tensor_true.keys + for key in tensor_false.keys: + block_false = tensor_false.block(key) + block_true = tensor_true.block(key) + assert torch.allclose( + block_false.values.cpu(), + block_true.values.cpu(), + atol=1e-12, + ), f"offload mode changed values for '{name}' / key {key}" + + def test_compute_gradients_produces_forces(self): + # Regression: with `compute_gradients=True`, _evaluate_with_gradients must + # run autograd through the base model and inject "forces" into the output + # stream so the symmetrization pipeline produces forces_l1_*. The previous + # bug coupled autograd enablement to offload_to_cpu and could silently + # break this path. + outputs = {"energy": ModelOutput(sample_kind="system")} + model = SymmetrizedModel( + _QuadraticEnergyModel(), + max_o3_lambda_character=1, + max_o3_lambda_target=1, + batch_size=1, + offload_to_cpu=False, + ).to(dtype=torch.float64) + + result = model([self._make_system()], outputs, compute_gradients=True) + + assert "forces_l1_mean" in result + forces = result["forces_l1_mean"].block().values + # Forces for E=sum(pos^2) at positions [[1,0,0],[0,2,0]] have magnitudes + # [2, 4]; back-rotated/averaged forces retain non-trivial values for at + # least one atom (averaging is over O(3), not over atoms). + assert torch.any(forces.abs() > 0.5) + + def test_vector_like_forward_outputs(self): + model = SymmetrizedModel( + _EnergyAndVectorModel(), + max_o3_lambda_character=1, + max_o3_lambda_target=1, + batch_size=2, + ).to(dtype=torch.float64) + + outputs = { + "energy": ModelOutput(sample_kind="system"), + "non_conservative_forces": ModelOutput(sample_kind="atom"), + } + result = model([self._make_system()], outputs) + + assert "non_conservative_forces_l1" in result + assert "non_conservative_forces_l1_mean" in result + assert "non_conservative_forces_l1_var" in result + assert "non_conservative_forces_l1_norm_squared" in result + + def test_selected_atoms_are_mapped_per_outer_system(self): + systems = [self._make_system(), self._make_second_system()] + model = SymmetrizedModel( + _EnergyAndVectorModel(), + max_o3_lambda_character=1, + max_o3_lambda_target=1, + batch_size=5, + ).to(dtype=torch.float64) + + outputs = { + "energy": ModelOutput(sample_kind="system"), + "non_conservative_forces": ModelOutput(sample_kind="atom"), + } + selected_atoms = Labels( + names=["system", "atom"], + values=torch.tensor([[0, 0], [1, 1]], dtype=torch.int64), + ) + + result = model(systems, outputs, selected_atoms=selected_atoms) + + energy_block = result["energy_l0_mean"].block() + assert energy_block.samples.values.tolist() == [[0], [1]] + assert torch.allclose( + energy_block.values[:, 0, 0], + torch.tensor([5.0, 25.0], dtype=torch.float64), + atol=1e-10, + ) + + force_block = result["non_conservative_forces_l1_mean"].block() + assert force_block.samples.values.tolist() == [[0, 0], [1, 1]] + assert torch.allclose( + force_block.values.roll(1, 1).squeeze(-1), + torch.stack([systems[0].positions[0], systems[1].positions[1]]), + atol=1e-10, + ) + + def test_selected_atoms_can_be_empty_for_some_systems(self): + systems = [self._make_system(), self._make_second_system()] + model = SymmetrizedModel( + _EnergyAndVectorModel(), + max_o3_lambda_character=1, + max_o3_lambda_target=1, + batch_size=5, + ).to(dtype=torch.float64) + + outputs = { + "non_conservative_forces": ModelOutput(sample_kind="atom"), + } + selected_atoms = Labels( + names=["system", "atom"], + values=torch.tensor([[1, 1]], dtype=torch.int64), + ) + + result = model(systems, outputs, selected_atoms=selected_atoms) + + force_block = result["non_conservative_forces_l1_mean"].block() + assert force_block.samples.values.tolist() == [[1, 1]] + assert torch.allclose( + force_block.values.roll(1, 1).squeeze(-1), + systems[1].positions[1].unsqueeze(0), + atol=1e-10, + ) + + +@pytest.mark.skipif( + not REAL_CHECKPOINT.is_file(), + reason="requires local SYMMOD_EXAMPLE checkpoint", +) +def test_real_checkpoint_energy_variance_matches_explicit_o3_reference(capfd): + pytest.importorskip("ase") + load_model = pytest.importorskip("metatrain.utils.io").load_model + get_system_with_neighbor_lists = pytest.importorskip( + "metatrain.utils.neighbor_lists" + ).get_system_with_neighbor_lists + from ase.build import bulk + + dtype = torch.float64 + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="the 'features' output name is deprecated, please update the model to use 'feature' instead", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + message="the 'non_conservative_forces' output name is deprecated, please update the model to use 'non_conservative_force' instead", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"Lebedev order may be insufficient for character projections\.", + category=UserWarning, + ) + warnings.filterwarnings( + "ignore", + message=r"`per_atom` is deprecated, please use `sample_kind` instead.*", + category=DeprecationWarning, + ) + + model = load_model(REAL_CHECKPOINT) + model.eval() + model = model.to(dtype=dtype, device="cpu") + + atoms = bulk("Si", cubic=True) + atoms.rattle(0.1, seed=0) + system = systems_to_torch([atoms], device="cpu", dtype=dtype)[0] + system = get_system_with_neighbor_lists( + system.to(dtype=dtype, device="cpu"), + model.model.requested_neighbor_lists(), + ) + + outputs = {"energy": ModelOutput(sample_kind="system")} + symm_model = SymmetrizedModel( + model, + max_o3_lambda_grid=3, + max_o3_lambda_target=2, + max_o3_lambda_character=2, + batch_size=1, + ).to(device="cpu", dtype=dtype) + + with torch.no_grad(): + result = symm_model([system], outputs) + + weights = [] + energies = [] + for inversion in [1, -1]: + for weight, rotation in zip( + symm_model.so3_weights, symm_model.so3_rotations + ): + transformed = _transform_system( + system, + (inversion * rotation).to( + dtype=system.positions.dtype, + device=system.positions.device, + ), + ) + energy = model([transformed], outputs, None)["energy"].block().values.squeeze() + weights.append(0.5 * weight.to(dtype=dtype, device="cpu")) + energies.append(energy.to(dtype=dtype, device="cpu")) + + captured = capfd.readouterr() + assert captured.out == "" + allowed_stderr_fragments = [ + "`per_atom` is deprecated, please use `sample_kind` instead", + "output 'energy' has an empty unit. Consider adding a unit to ensure correct unit conversion.", + "ModelOutput.quantity is deprecated and will be removed in a future version", + "the 'features' quantity is deprecated, please update this code to use 'feature' instead.", + ] + unexpected_stderr = [ + line + for line in captured.err.splitlines() + if line != "" + and not any(fragment in line for fragment in allowed_stderr_fragments) + ] + assert unexpected_stderr == [] + + weights_tensor = torch.stack(weights) + energies_tensor = torch.stack(energies) + mean_reference = torch.sum(weights_tensor * energies_tensor) + norm_squared_reference = torch.sum(weights_tensor * energies_tensor.square()) + variance_reference = norm_squared_reference - mean_reference.square() + + mean_value = result["energy_l0_mean"].block().values.squeeze() + norm_squared_value = result["energy_l0_norm_squared"].block().values.squeeze() + variance_value = result["energy_l0_var"].block().values.squeeze() + + # The current checkpoint is not exactly O(3)-equivariant, so validate against + # the explicit quadrature reference instead of assuming near-zero variance. + assert torch.allclose(mean_value, mean_reference, atol=1e-12, rtol=0.0) + assert torch.allclose( + norm_squared_value, + norm_squared_reference, + atol=1e-12, + rtol=0.0, + ) + assert torch.allclose(variance_value, variance_reference, atol=1e-12, rtol=0.0)