diff --git a/earth2studio/perturbation/gaussian.py b/earth2studio/perturbation/gaussian.py index 5809db6d9..723660e21 100644 --- a/earth2studio/perturbation/gaussian.py +++ b/earth2studio/perturbation/gaussian.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Any import numpy as np @@ -21,6 +22,7 @@ from typing_extensions import Self from earth2studio.utils import handshake_dim +from earth2studio.utils.checkpoint import bind_checkpoint_state from earth2studio.utils.imports import ( OptionalDependencyFailure, check_optional_dependencies, @@ -34,6 +36,11 @@ InverseRealSHT = None +@dataclass +class _GaussianCheckpointState: + generator_state: torch.Tensor | None = None + + class Gaussian: """Standard Gaussian peturbation @@ -50,6 +57,8 @@ def __init__(self, noise_amplitude: float | torch.Tensor = 0.05): if isinstance(noise_amplitude, torch.Tensor) else torch.Tensor([noise_amplitude]) ) + self.generator: torch.Generator | None = None + self.checkpoint = bind_checkpoint_state(_GaussianCheckpointState()) @torch.inference_mode() def __call__( @@ -71,8 +80,45 @@ def __call__( tuple[torch.Tensor, CoordSystem]: Output tensor and respective coordinate system dictionary """ + generator = self._get_generator(x.device) + pre_state = generator.get_state() noise_amplitude = self.noise_amplitude.to(x.device) - return x + noise_amplitude * torch.randn_like(x), coords + y = x + noise_amplitude * torch.randn( + x.shape, dtype=x.dtype, device=x.device, generator=generator + ) + self._save_generator_state(pre_state, generator.get_state(), generator) + return y, coords + + def _get_generator(self, device: torch.device) -> torch.Generator: + if self.generator is None or self.generator.device != device: + self.generator = torch.Generator(device=device) + if ( + self.checkpoint.checkpoint_state_loaded + and self.checkpoint.generator_state is not None + ): + self.generator.set_state(self.checkpoint.generator_state.cpu()) + else: + self.generator.seed() + return self.generator + + def _save_generator_state( + self, + pre_state: torch.Tensor, + post_state: torch.Tensor, + generator: torch.Generator, + ) -> None: + if not self.checkpoint.checkpoint_enabled: + return + + level = self.checkpoint.checkpoint_level + if level < 1: + self.checkpoint.generator_state = None + return + + generator_state = pre_state if level == 1 else post_state + self.checkpoint.generator_state = generator_state.to( + self.checkpoint.device + ).clone() @check_optional_dependencies() diff --git a/test/perturbation/test_gaussian.py b/test/perturbation/test_gaussian.py index 04516d8c3..2c9d04002 100644 --- a/test/perturbation/test_gaussian.py +++ b/test/perturbation/test_gaussian.py @@ -16,10 +16,12 @@ from collections import OrderedDict +import numpy as np import pytest import torch from earth2studio.perturbation import CorrelatedSphericalGaussian, Gaussian +from earth2studio.utils.checkpoint import Checkpoint @pytest.mark.parametrize( @@ -75,6 +77,51 @@ def test_gaussian(x, coords, amplitude, device): assert dx.device == x.device +def test_gaussian_checkpoint(tmp_path): + x = torch.zeros(2, 3) + coords = OrderedDict([("batch", []), ("variable", [])]) + + level_one_checkpoint = Checkpoint( + "gaussian-level-1", path=tmp_path / "level-1", level=1 + ) + with level_one_checkpoint as ckpt: + perturbation = Gaussian(1.0) + expected_level_one, _ = perturbation(x, coords) + assert perturbation.checkpoint.generator_state is not None + ckpt.write(lead_time=np.timedelta64(0, "h")) + + with level_one_checkpoint.select(-1): + perturbation = Gaussian(1.0) + level_one_output, _ = perturbation(x, coords) + assert perturbation.checkpoint.checkpoint_state_loaded + assert torch.allclose(level_one_output, expected_level_one) + + level_two_checkpoint = Checkpoint( + "gaussian-level-2", + path=tmp_path / "level-2", + flush_interval=2, + level=2, + ) + with level_two_checkpoint as ckpt: + perturbation = Gaussian(1.0) + perturbation(x, coords) + assert perturbation.checkpoint.generator_state is not None + ckpt.write(lead_time=np.timedelta64(0, "h")) + perturbation(x, coords) + assert perturbation.checkpoint.generator_state is not None + ckpt.write(lead_time=np.timedelta64(6, "h")) + expected_level_two_next, _ = perturbation(x, coords) + expected_level_two_third, _ = perturbation(x, coords) + + with level_two_checkpoint.select(-1): + perturbation = Gaussian(1.0) + resumed, _ = perturbation(x, coords) + assert perturbation.checkpoint.checkpoint_state_loaded + assert torch.allclose(resumed, expected_level_two_next) + next_perturbed, _ = perturbation(x, coords) + assert torch.allclose(next_perturbed, expected_level_two_third) + + def test_correlated_spherical_gaussian_no_amplitude(): """Test that CorrelatedSphericalGaussian raises error without amplitude""" with pytest.raises(ValueError):