Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion earth2studio/perturbation/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any

import numpy as np
import torch
from typing_extensions import Self

from earth2studio.utils import handshake_dim
from earth2studio.utils.checkpoint import bind_checkpoint_state
from earth2studio.utils.imports import (
OptionalDependencyFailure,
check_optional_dependencies,
Expand All @@ -34,6 +36,11 @@
InverseRealSHT = None


@dataclass
class _GaussianCheckpointState:
generator_state: torch.Tensor | None = None


class Gaussian:
"""Standard Gaussian perturbation

Expand All @@ -50,6 +57,8 @@ def __init__(self, noise_amplitude: float | torch.Tensor = 0.05):
if isinstance(noise_amplitude, torch.Tensor)
else torch.Tensor([noise_amplitude])
)
self.generator: torch.Generator | None = None
self.checkpoint = bind_checkpoint_state(_GaussianCheckpointState())

@torch.inference_mode()
def __call__(
Expand All @@ -71,8 +80,45 @@ def __call__(
tuple[torch.Tensor, CoordSystem]:
Output tensor and respective coordinate system dictionary
"""
generator = self._get_generator(x.device)
pre_state = generator.get_state()
noise_amplitude = self.noise_amplitude.to(x.device)
return x + noise_amplitude * torch.randn_like(x), coords
y = x + noise_amplitude * torch.randn(
x.shape, dtype=x.dtype, device=x.device, generator=generator
)
self._save_generator_state(pre_state, generator.get_state(), generator)
return y, coords

def _get_generator(self, device: torch.device) -> torch.Generator:
    """Return the random generator to use for ``device``.

    The generator cached on ``self.generator`` is reused as long as it lives
    on the requested device. Otherwise a new generator is created for that
    device, cached, and initialised in one of two ways:

    * if the bound checkpoint reports that state was loaded and holds a
      generator state, that state is restored into the new generator;
    * otherwise the generator is seeded via ``Generator.seed()``.

    Parameters
    ----------
    device : torch.device
        Device the returned generator must live on

    Returns
    -------
    torch.Generator
        Generator cached on ``self.generator`` for ``device``
    """
    cached = self.generator
    if cached is not None and cached.device == device:
        return cached

    # Cache the new generator before initialising it, so the attribute is
    # updated even if restoring the state raises.
    fresh = torch.Generator(device=device)
    self.generator = fresh

    ckpt = self.checkpoint
    # ``generator_state`` is only consulted when the checkpoint says state
    # was actually loaded.
    restored = ckpt.generator_state if ckpt.checkpoint_state_loaded else None
    if restored is None:
        fresh.seed()
    else:
        # The stored state is moved to CPU before it is handed to set_state.
        fresh.set_state(restored.cpu())
    return fresh
Comment thread
NickGeneva marked this conversation as resolved.

def _save_generator_state(
    self,
    pre_state: torch.Tensor,
    post_state: torch.Tensor,
    generator: torch.Generator,
) -> None:
    """Record a generator state on the bound checkpoint.

    Nothing is written when checkpointing is disabled. When it is enabled,
    the checkpoint level selects what is stored:

    * level below 1: the stored state is cleared (set to ``None``);
    * level equal to 1: ``pre_state`` is stored;
    * any other level: ``post_state`` is stored.

    The stored tensor is a clone placed on ``self.checkpoint.device``.

    Parameters
    ----------
    pre_state : torch.Tensor
        Generator state captured before the noise was drawn
    post_state : torch.Tensor
        Generator state captured after the noise was drawn
    generator : torch.Generator
        Generator the states were taken from; not used by this method
    """
    ckpt = self.checkpoint
    if ckpt.checkpoint_enabled:
        level = ckpt.checkpoint_level
        if level < 1:
            ckpt.generator_state = None
        else:
            chosen = post_state if level != 1 else pre_state
            # ``.to`` can return its input unchanged when no move is needed,
            # so clone to keep the stored state independent of the argument.
            ckpt.generator_state = chosen.to(ckpt.device).clone()


@check_optional_dependencies()
Expand Down
47 changes: 47 additions & 0 deletions test/perturbation/test_gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@

from collections import OrderedDict

import numpy as np
import pytest
import torch

from earth2studio.perturbation import CorrelatedSphericalGaussian, Gaussian
from earth2studio.utils.checkpoint import Checkpoint


@pytest.mark.parametrize(
Expand Down Expand Up @@ -75,6 +77,51 @@ def test_gaussian(x, coords, amplitude, device):
assert dx.device == x.device


def test_gaussian_checkpoint(tmp_path):
    """Check that ``Gaussian`` noise is reproducible across checkpoint resume.

    Level 1: a perturbation created under the selected checkpoint must
    reproduce the sample drawn before the checkpoint was written.
    Level 2: a perturbation created under the selected checkpoint must
    continue with the samples the original run drew after its last write.
    """
    x = torch.zeros(2, 3)
    coords = OrderedDict([("batch", []), ("variable", [])])

    def sample(perturb):
        # Draw one perturbed tensor, discarding the returned coordinates.
        out, _ = perturb(x, coords)
        return out

    # ---- Level 1 ----
    ckpt_l1 = Checkpoint("gaussian-level-1", path=tmp_path / "level-1", level=1)
    with ckpt_l1 as ckpt:
        noise = Gaussian(1.0)
        first_draw = sample(noise)
        assert noise.checkpoint.generator_state is not None
        ckpt.write(lead_time=np.timedelta64(0, "h"))

    with ckpt_l1.select(-1):
        noise = Gaussian(1.0)
        replayed = sample(noise)
        assert noise.checkpoint.checkpoint_state_loaded
        assert torch.allclose(replayed, first_draw)

    # ---- Level 2 ----
    ckpt_l2 = Checkpoint(
        "gaussian-level-2",
        path=tmp_path / "level-2",
        flush_interval=2,
        level=2,
    )
    with ckpt_l2 as ckpt:
        noise = Gaussian(1.0)
        sample(noise)
        assert noise.checkpoint.generator_state is not None
        ckpt.write(lead_time=np.timedelta64(0, "h"))
        sample(noise)
        assert noise.checkpoint.generator_state is not None
        ckpt.write(lead_time=np.timedelta64(6, "h"))
        # Draws made after the final write; the resumed run must match them.
        third_draw = sample(noise)
        fourth_draw = sample(noise)

    with ckpt_l2.select(-1):
        noise = Gaussian(1.0)
        continued = sample(noise)
        assert noise.checkpoint.checkpoint_state_loaded
        assert torch.allclose(continued, third_draw)
        followup = sample(noise)
        assert torch.allclose(followup, fourth_draw)


def test_correlated_spherical_gaussian_no_amplitude():
"""Test that CorrelatedSphericalGaussian raises error without amplitude"""
with pytest.raises(ValueError):
Expand Down