-
Notifications
You must be signed in to change notification settings - Fork 223
Add FCN3 checkpoint state #923
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: codex/checkpoint-gaussian
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| import json | ||
| from collections import OrderedDict | ||
| from collections.abc import Generator, Iterator | ||
| from dataclasses import dataclass | ||
| from datetime import datetime | ||
|
|
||
| import numpy as np | ||
|
|
@@ -27,6 +28,7 @@ | |
| from earth2studio.models.px.base import PrognosticModel | ||
| from earth2studio.models.px.utils import PrognosticMixin | ||
| from earth2studio.utils import handshake_coords, handshake_dim | ||
| from earth2studio.utils.checkpoint import bind_checkpoint_state | ||
| from earth2studio.utils.imports import ( | ||
| OptionalDependencyFailure, | ||
| check_optional_dependencies, | ||
|
|
@@ -129,6 +131,14 @@ | |
| ] | ||
|
|
||
|
|
||
@dataclass
class _FCN3CheckpointState:
    """Snapshot of the FCN3 state that is kept for a checkpoint restart.

    The fields are written by ``FCN3._save_checkpoint_state`` and read back by
    ``FCN3._restore_checkpoint_state``. The defaults mean "nothing saved".
    """

    # Detached copy of the most recent state tensor; None when nothing is saved.
    x: torch.Tensor | None = None
    # Coordinate names, in the same order as the saved coordinate system.
    coord_keys: tuple[str, ...] = ()
    # Copies of the coordinate arrays, index-aligned with ``coord_keys``.
    coord_values: tuple[np.ndarray, ...] = ()
    # Internal noise states: one inner tuple per "batch" entry, holding one
    # tensor (or None) per "time" entry of the saved coordinates.
    internal_noise_states: tuple[tuple[torch.Tensor | None, ...], ...] = ()
|
|
||
|
|
||
| @check_optional_dependencies() | ||
| class FCN3(torch.nn.Module, AutoModelMixin, PrognosticMixin): | ||
| """FourCastNet 3 advances global weather modeling by implementing a scalable, | ||
|
|
@@ -177,6 +187,7 @@ def __init__( | |
| self.variables[self.variables == "2d"] = "d2m" | ||
|
|
||
| self.set_rng(reset=True, seed=seed) | ||
| self.checkpoint = bind_checkpoint_state(_FCN3CheckpointState()) | ||
|
|
||
def __str__(self) -> str:
    """Return the short identifier string of this model (``"fcn3"``)."""
    return "fcn3"
|
|
@@ -307,6 +318,71 @@ def load_model( | |
|
|
||
| return cls(model, variables=variables) | ||
|
|
||
| def _restore_checkpoint_state( | ||
| self, x: torch.Tensor, coords: CoordSystem | ||
| ) -> tuple[torch.Tensor, CoordSystem, bool]: | ||
| if not self.checkpoint.checkpoint_state_loaded: | ||
| return x, coords, False | ||
| if self.checkpoint.checkpoint_level < 2: | ||
| return x, coords, False | ||
| if self.checkpoint.x is None or not self.checkpoint.coord_keys: | ||
| return x, coords, False | ||
| if len(self.checkpoint.coord_keys) != len(self.checkpoint.coord_values): | ||
| raise RuntimeError("FCN3 checkpoint coordinate state is incomplete.") | ||
| if not self.checkpoint.internal_noise_states: | ||
| raise RuntimeError( | ||
| "FCN3 checkpoint is missing internal noise state required for full restart." | ||
| ) | ||
|
|
||
| restored_x = self.checkpoint.x.to(x.device) | ||
| restored_coords = OrderedDict( | ||
| (key, np.asarray(value).copy()) | ||
| for key, value in zip( | ||
| self.checkpoint.coord_keys, self.checkpoint.coord_values | ||
| ) | ||
| ) | ||
| restored_states = [ | ||
| [None if state is None else state.to(restored_x.device) for state in states] | ||
| for states in self.checkpoint.internal_noise_states | ||
| ] | ||
| if len(restored_states) != len(restored_coords["batch"]) or any( | ||
| len(states) != len(restored_coords["time"]) for states in restored_states | ||
| ): | ||
| raise RuntimeError( | ||
| "FCN3 checkpoint internal noise state does not match saved coordinates." | ||
| ) | ||
|
|
||
| self._internal_noise_states = restored_states | ||
| return restored_x, restored_coords, True | ||
|
|
||
| def _save_checkpoint_state(self, x: torch.Tensor, coords: CoordSystem) -> None: | ||
| if not self.checkpoint.checkpoint_enabled: | ||
| return | ||
| if self.checkpoint.checkpoint_level < 2: | ||
| self.checkpoint.x = None | ||
| self.checkpoint.coord_keys = () | ||
| self.checkpoint.coord_values = () | ||
| self.checkpoint.internal_noise_states = () | ||
| return | ||
|
|
||
| def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor: | ||
| tensor = tensor.detach() | ||
| if tensor.device == self.checkpoint.device: | ||
| return tensor.clone() | ||
| return tensor.to(self.checkpoint.device) | ||
|
|
||
| self.checkpoint.x = checkpoint_tensor(x) | ||
| self.checkpoint.coord_keys = tuple(coords.keys()) | ||
| self.checkpoint.coord_values = tuple( | ||
| np.asarray(value).copy() for value in coords.values() | ||
| ) | ||
| self.checkpoint.internal_noise_states = tuple( | ||
| tuple( | ||
| None if state is None else checkpoint_tensor(state) for state in states | ||
| ) | ||
| for states in getattr(self, "_internal_noise_states", ()) | ||
| ) | ||
|
|
||
| def _get_internal_state(self, ensemble_index: int, time_index: int) -> torch.Tensor: | ||
| """Get the internal RNG state for the given ensemble and time index | ||
|
|
||
|
|
@@ -391,6 +467,7 @@ def _forward( | |
| self._set_internal_state(j, i) | ||
|
|
||
| x = x.unsqueeze(2) | ||
| self._save_checkpoint_state(x, output_coords) | ||
| return x, output_coords | ||
|
|
||
| @batch_func() | ||
|
|
@@ -413,11 +490,12 @@ def __call__( | |
| tuple[torch.Tensor, CoordSystem] | ||
| Output tensor and coordinate system | ||
| """ | ||
| # Initialize the internal noise states | ||
| # for each batch index, we will have a list of noise states for each separate time | ||
| self._reset_internal_state(len(coords["batch"]), len(coords["time"])) | ||
| output, coords = self._forward(x, coords) | ||
| return output, coords | ||
| x, coords, restored = self._restore_checkpoint_state(x, coords) | ||
| if not restored: | ||
| # Initialize the internal noise states for each batch and time. | ||
| self._reset_internal_state(len(coords["batch"]), len(coords["time"])) | ||
|
|
||
| return self._forward(x, coords) | ||
|
Comment on lines
473
to
+498
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
When |
||
|
|
||
| @batch_func() | ||
| def _default_generator( | ||
|
|
@@ -426,15 +504,17 @@ def _default_generator( | |
| coords = coords.copy() | ||
| self.output_coords(coords) | ||
|
|
||
| # Initialize the internal noise states | ||
| self._reset_internal_state(len(coords["batch"]), len(coords["time"])) | ||
| x, coords, restored = self._restore_checkpoint_state(x, coords) | ||
| if not restored: | ||
| # Initialize the internal noise states for each batch and time. | ||
| self._reset_internal_state(len(coords["batch"]), len(coords["time"])) | ||
| self._save_checkpoint_state(x, coords) | ||
| yield x, coords | ||
|
|
||
| # Yield the initial condition | ||
| yield x, coords | ||
| while True: | ||
| # Front hook | ||
| x, coords = self.front_hook(x, coords) | ||
| # Forward is identity operator | ||
| # Advance FCN3 one forecast step. Restored checkpoints resume here. | ||
| x, coords = self._forward(x, coords) | ||
| # Rear hook | ||
| x, coords = self.rear_hook(x, coords) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`checkpoint_tensor` can skip `clone()` on same-device CUDA tensors

When `Checkpoint(device="cuda")` (without an explicit index) is used and the model tensor lives on `cuda:0`, `torch.device("cuda") != torch.device("cuda:0")` evaluates to `True`, so the branch falls through to `.to(self.checkpoint.device)`. Calling `.to()` on a tensor that is already on the target device returns `self` (shared storage), not a copy. Any subsequent in-place operation on the live `x` would then silently corrupt the saved checkpoint field. Normalising both sides to an indexed device before comparing eliminates the ambiguity.