diff --git a/earth2studio/models/px/fcn3.py b/earth2studio/models/px/fcn3.py
index 9f6bd4d0c..a6c10e18c 100644
--- a/earth2studio/models/px/fcn3.py
+++ b/earth2studio/models/px/fcn3.py
@@ -16,6 +16,7 @@
 import json
 from collections import OrderedDict
 from collections.abc import Generator, Iterator
+from dataclasses import dataclass
 from datetime import datetime
 
 import numpy as np
@@ -27,6 +28,7 @@
 from earth2studio.models.px.base import PrognosticModel
 from earth2studio.models.px.utils import PrognosticMixin
 from earth2studio.utils import handshake_coords, handshake_dim
+from earth2studio.utils.checkpoint import bind_checkpoint_state
 from earth2studio.utils.imports import (
     OptionalDependencyFailure,
     check_optional_dependencies,
@@ -129,6 +131,19 @@
 ]
 
 
+@dataclass
+class _FCN3CheckpointState:
+    """FCN3 restart state handed to ``bind_checkpoint_state``"""
+
+    # Last tensor given to `_save_checkpoint_state`; None when no full state is held.
+    x: torch.Tensor | None = None
+    # Coordinate names and values, stored as parallel tuples in the same order.
+    coord_keys: tuple[str, ...] = ()
+    coord_values: tuple[np.ndarray, ...] = ()
+    # Noise states indexed [batch][time]; individual entries may be None.
+    internal_noise_states: tuple[tuple[torch.Tensor | None, ...], ...] = ()
+
+
 @check_optional_dependencies()
 class FCN3(torch.nn.Module, AutoModelMixin, PrognosticMixin):
     """FourCastNet 3 advances global weather modeling by implementing a scalable,
@@ -177,6 +192,7 @@ def __init__(
         self.variables[self.variables == "2d"] = "d2m"
 
         self.set_rng(reset=True, seed=seed)
+        self.checkpoint = bind_checkpoint_state(_FCN3CheckpointState())
 
     def __str__(self) -> str:
         return "fcn3"
@@ -307,6 +323,111 @@ def load_model(
 
         return cls(model, variables=variables)
 
+    def _restore_checkpoint_state(
+        self, x: torch.Tensor, coords: CoordSystem
+    ) -> tuple[torch.Tensor, CoordSystem, bool]:
+        """Swap the caller's inputs for the saved restart state, if one is loaded
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Input tensor, returned untouched when nothing is restored
+        coords : CoordSystem
+            Input coordinate system, returned untouched when nothing is restored
+
+        Returns
+        -------
+        tuple[torch.Tensor, CoordSystem, bool]
+            Tensor, coordinates and a flag that is True only when both were
+            replaced by the checkpointed values
+
+        Raises
+        ------
+        RuntimeError
+            If the loaded checkpoint state is incomplete or inconsistent
+        """
+        # NOTE(review): nothing in this method clears `checkpoint_state_loaded`;
+        # confirm the checkpoint utilities do, otherwise every `__call__` keeps
+        # restoring the same saved state.
+        if not self.checkpoint.checkpoint_state_loaded:
+            return x, coords, False
+        if self.checkpoint.checkpoint_level < 2:
+            return x, coords, False
+        if self.checkpoint.x is None or not self.checkpoint.coord_keys:
+            return x, coords, False
+        if len(self.checkpoint.coord_keys) != len(self.checkpoint.coord_values):
+            raise RuntimeError("FCN3 checkpoint coordinate state is incomplete.")
+        if not self.checkpoint.internal_noise_states:
+            raise RuntimeError(
+                "FCN3 checkpoint is missing internal noise state required for full restart."
+            )
+
+        restored_coords = OrderedDict(
+            (key, np.asarray(value).copy())
+            for key, value in zip(
+                self.checkpoint.coord_keys, self.checkpoint.coord_values
+            )
+        )
+        saved_states = self.checkpoint.internal_noise_states
+        if len(saved_states) != len(restored_coords["batch"]) or any(
+            len(states) != len(restored_coords["time"]) for states in saved_states
+        ):
+            raise RuntimeError(
+                "FCN3 checkpoint internal noise state does not match saved coordinates."
+            )
+
+        # copy=True keeps the restored tensors independent of the checkpoint
+        # buffers even when no device transfer is needed.
+        restored_x = self.checkpoint.x.to(x.device, copy=True)
+        self._internal_noise_states = [
+            [
+                None if state is None else state.to(restored_x.device, copy=True)
+                for state in states
+            ]
+            for states in saved_states
+        ]
+        return restored_x, restored_coords, True
+
+    def _save_checkpoint_state(self, x: torch.Tensor, coords: CoordSystem) -> None:
+        """Record the current tensor, coordinates and noise states for restart
+
+        Does nothing when checkpointing is disabled, and clears any previously
+        stored state when the checkpoint level is below 2.
+
+        Parameters
+        ----------
+        x : torch.Tensor
+            Tensor to snapshot
+        coords : CoordSystem
+            Coordinate system that describes ``x``
+        """
+        if not self.checkpoint.checkpoint_enabled:
+            return
+        if self.checkpoint.checkpoint_level < 2:
+            self.checkpoint.x = None
+            self.checkpoint.coord_keys = ()
+            self.checkpoint.coord_values = ()
+            self.checkpoint.internal_noise_states = ()
+            return
+
+        def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor:
+            # copy=True always allocates a new tensor. A plain ``.to`` returns
+            # its input when no conversion is needed, and comparing devices
+            # first is unreliable: torch.device("cuda") != torch.device("cuda:0").
+            return tensor.detach().to(self.checkpoint.device, copy=True)
+
+        self.checkpoint.x = checkpoint_tensor(x)
+        self.checkpoint.coord_keys = tuple(coords.keys())
+        self.checkpoint.coord_values = tuple(
+            np.asarray(value).copy() for value in coords.values()
+        )
+        self.checkpoint.internal_noise_states = tuple(
+            tuple(
+                None if state is None else checkpoint_tensor(state) for state in states
+            )
+            for states in getattr(self, "_internal_noise_states", ())
+        )
+
     def _get_internal_state(self, ensemble_index: int, time_index: int) -> torch.Tensor:
         """Get the internal RNG state for the given ensemble and time index
 
@@ -391,6 +512,7 @@ def _forward(
                 self._set_internal_state(j, i)
 
         x = x.unsqueeze(2)
+        self._save_checkpoint_state(x, output_coords)
         return x, output_coords
 
     @batch_func()
@@ -413,11 +535,12 @@ def __call__(
         tuple[torch.Tensor, CoordSystem]
             Output tensor and coordinate system
         """
-        # Initialize the internal noise states
-        # for each batch index, we will have a list of noise states for each separate time
-        self._reset_internal_state(len(coords["batch"]), len(coords["time"]))
-        output, coords = self._forward(x, coords)
-        return output, coords
+        x, coords, restored = self._restore_checkpoint_state(x, coords)
+        if not restored:
+            # Initialize the internal noise states for each batch and time.
+            self._reset_internal_state(len(coords["batch"]), len(coords["time"]))
+
+        return self._forward(x, coords)
 
     @batch_func()
     def _default_generator(
@@ -426,15 +549,17 @@ def _default_generator(
         coords = coords.copy()
         self.output_coords(coords)
 
-        # Initialize the internal noise states
-        self._reset_internal_state(len(coords["batch"]), len(coords["time"]))
+        x, coords, restored = self._restore_checkpoint_state(x, coords)
+        if not restored:
+            # Initialize the internal noise states for each batch and time.
+            self._reset_internal_state(len(coords["batch"]), len(coords["time"]))
+            self._save_checkpoint_state(x, coords)
+            yield x, coords
 
-        # Yield the initial condition
-        yield x, coords
         while True:
             # Front hook
             x, coords = self.front_hook(x, coords)
-            # Forward is identity operator
+            # Advance FCN3 one forecast step. Restored checkpoints resume here.
             x, coords = self._forward(x, coords)
             # Rear hook
             x, coords = self.rear_hook(x, coords)
diff --git a/test/models/px/test_fcn3.py b/test/models/px/test_fcn3.py
index b3c38b636..51d17254e 100644
--- a/test/models/px/test_fcn3.py
+++ b/test/models/px/test_fcn3.py
@@ -24,10 +24,10 @@
 from earth2studio.data import Random, fetch_data
 from earth2studio.models.px import FCN3
 from earth2studio.utils import handshake_dim
+from earth2studio.utils.checkpoint import Checkpoint
 
 
 class PhooFCN3Preprocessor(torch.nn.Module):
-
     def __init__(
         self,
     ):
@@ -67,6 +67,17 @@ def set_rng(self, reset: bool = True, seed: int = 333):
         return
 
 
+class PhooRestartFCN3ModelWrapper(PhooFCN3ModelWrapper):
+    def forward(self, x, t, normalized_data: bool = False, replace_state: bool = False):
+        return x + self.model.preprocessor.state[0].to(x.device)
+
+
+def _phoo_fcn3(model: torch.nn.Module, variables: np.ndarray) -> FCN3:
+    fcn3 = FCN3.__new__(FCN3)
+    FCN3.__init__.__wrapped__(fcn3, model, variables=variables)
+    return fcn3
+
+
 @pytest.fixture(scope="function")
 def dummy_model():
     preprocessor = PhooFCN3Preprocessor()
@@ -167,6 +178,59 @@ def test_fcn3_iter(ensemble, device, dummy_model):
             break
 
 
+def test_fcn3_checkpoint_state_round_trip_with_phoo_model(tmp_path):
+    variables = np.array(["u10m"])
+    time = np.array([np.datetime64("1993-04-05T00:00")])
+    checkpoint = Checkpoint("fcn3", path=tmp_path / "fcn3", flush_interval=2, level=2)
+
+    torch.manual_seed(123)
+    model = PhooRestartFCN3ModelWrapper(PhooFCN3Model(PhooFCN3Preprocessor()))
+    p = _phoo_fcn3(model, variables=variables).to("cpu")
+    input_coords = p.input_coords()
+    coords = OrderedDict(
+        {
+            "time": time,
+            "lead_time": input_coords["lead_time"],
+            "variable": input_coords["variable"],
+            "lat": input_coords["lat"],
+            "lon": input_coords["lon"],
+        }
+    )
+    x = torch.zeros(
+        len(coords["time"]),
+        len(coords["lead_time"]),
+        len(coords["variable"]),
+        len(coords["lat"]),
+        len(coords["lon"]),
+    )
+
+    with checkpoint.select(time=time) as ckpt:
+        iterator = p.create_iterator(x, coords)
+        _, initial_coords = next(iterator)
+        ckpt.write(lead_time=initial_coords["lead_time"])
+        step1, step1_coords = next(iterator)
+        ckpt.write(lead_time=step1_coords["lead_time"])
+        step2, step2_coords = next(iterator)
+        step3, step3_coords = next(iterator)
+
+    torch.manual_seed(999)
+    with checkpoint.select(-1):
+        resumed_model = PhooRestartFCN3ModelWrapper(
+            PhooFCN3Model(PhooFCN3Preprocessor())
+        )
+        resumed = _phoo_fcn3(resumed_model, variables=variables).to("cpu")
+        assert resumed.checkpoint.checkpoint_state_loaded
+
+        resumed_iterator = resumed.create_iterator(torch.full_like(x, -100.0), coords)
+        resumed_step2, resumed_step2_coords = next(resumed_iterator)
+        resumed_step3, resumed_step3_coords = next(resumed_iterator)
+
+    assert torch.allclose(resumed_step2, step2)
+    assert torch.allclose(resumed_step3, step3)
+    assert resumed_step2_coords["lead_time"][0] == step2_coords["lead_time"][0]
+    assert resumed_step3_coords["lead_time"][0] == step3_coords["lead_time"][0]
+
+
 @pytest.mark.parametrize(
     "dc",
     [