Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 90 additions & 10 deletions earth2studio/models/px/fcn3.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import json
from collections import OrderedDict
from collections.abc import Generator, Iterator
from dataclasses import dataclass
from datetime import datetime

import numpy as np
Expand All @@ -27,6 +28,7 @@
from earth2studio.models.px.base import PrognosticModel
from earth2studio.models.px.utils import PrognosticMixin
from earth2studio.utils import handshake_coords, handshake_dim
from earth2studio.utils.checkpoint import bind_checkpoint_state
from earth2studio.utils.imports import (
OptionalDependencyFailure,
check_optional_dependencies,
Expand Down Expand Up @@ -129,6 +131,14 @@
]


@dataclass
class _FCN3CheckpointState:
x: torch.Tensor | None = None
coord_keys: tuple[str, ...] = ()
coord_values: tuple[np.ndarray, ...] = ()
internal_noise_states: tuple[tuple[torch.Tensor | None, ...], ...] = ()


@check_optional_dependencies()
class FCN3(torch.nn.Module, AutoModelMixin, PrognosticMixin):
"""FourCastNet 3 advances global weather modeling by implementing a scalable,
Expand Down Expand Up @@ -177,6 +187,7 @@ def __init__(
self.variables[self.variables == "2d"] = "d2m"

self.set_rng(reset=True, seed=seed)
self.checkpoint = bind_checkpoint_state(_FCN3CheckpointState())

def __str__(self) -> str:
    """Return the short, fixed identifier string of this model."""
    model_name = "fcn3"
    return model_name
Expand Down Expand Up @@ -307,6 +318,71 @@ def load_model(

return cls(model, variables=variables)

def _restore_checkpoint_state(
    self, x: torch.Tensor, coords: CoordSystem
) -> tuple[torch.Tensor, CoordSystem, bool]:
    """Replace the caller's inputs with the state held on ``self.checkpoint``.

    The inputs are returned untouched (with ``False``) when no checkpoint
    state has been loaded, when the checkpoint level is below 2, or when no
    tensor / coordinate keys were saved.

    When a restore does happen, the values of ``x`` and ``coords`` passed in
    are discarded: only ``x.device`` is used, as the device the restored
    tensors are placed on. ``self._internal_noise_states`` is overwritten
    with the restored noise states as a side effect.

    Parameters
    ----------
    x : torch.Tensor
        Caller-provided input tensor
    coords : CoordSystem
        Caller-provided coordinate system

    Returns
    -------
    tuple[torch.Tensor, CoordSystem, bool]
        Tensor, coordinate system and a flag that is ``True`` only when the
        returned values come from the checkpoint

    Raises
    ------
    RuntimeError
        If the saved state is inconsistent: coordinate keys and values differ
        in length, the noise states are missing, the saved coordinates lack
        "batch" or "time", or the noise states do not match those coordinates
    """
    ckpt = self.checkpoint
    if not ckpt.checkpoint_state_loaded:
        return x, coords, False
    if ckpt.checkpoint_level < 2:
        return x, coords, False
    if ckpt.x is None or not ckpt.coord_keys:
        return x, coords, False
    if len(ckpt.coord_keys) != len(ckpt.coord_values):
        raise RuntimeError("FCN3 checkpoint coordinate state is incomplete.")
    if not ckpt.internal_noise_states:
        raise RuntimeError(
            "FCN3 checkpoint is missing internal noise state required for full restart."
        )

    # copy=True: a plain ``.to(device)`` returns the stored tensor itself when
    # it already lives on ``device``, which would let later in-place updates
    # of the live tensor rewrite the checkpoint's copy. The coordinates below
    # are copied for the same reason.
    restored_x = ckpt.x.to(x.device, copy=True)
    restored_coords = OrderedDict(
        (key, np.asarray(value).copy())
        for key, value in zip(ckpt.coord_keys, ckpt.coord_values)
    )

    # Report missing dimensions as a checkpoint problem rather than letting
    # the length checks below fail with a bare KeyError.
    missing = [dim for dim in ("batch", "time") if dim not in restored_coords]
    if missing:
        raise RuntimeError(
            "FCN3 checkpoint coordinates are missing required dimension(s): "
            + ", ".join(missing)
        )

    restored_states = [
        [
            None if state is None else state.to(restored_x.device, copy=True)
            for state in states
        ]
        for states in ckpt.internal_noise_states
    ]
    if len(restored_states) != len(restored_coords["batch"]) or any(
        len(states) != len(restored_coords["time"]) for states in restored_states
    ):
        raise RuntimeError(
            "FCN3 checkpoint internal noise state does not match saved coordinates."
        )

    self._internal_noise_states = restored_states
    return restored_x, restored_coords, True

def _save_checkpoint_state(self, x: torch.Tensor, coords: CoordSystem) -> None:
if not self.checkpoint.checkpoint_enabled:
return
if self.checkpoint.checkpoint_level < 2:
self.checkpoint.x = None
self.checkpoint.coord_keys = ()
self.checkpoint.coord_values = ()
self.checkpoint.internal_noise_states = ()
return

def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Return a detached copy of ``tensor`` on the checkpoint device.

    ``Tensor.to`` returns the very same tensor when no conversion is needed,
    and comparing device objects is not a reliable way to detect that case
    (``torch.device("cuda")`` and ``torch.device("cuda:0")`` compare unequal
    even when they name the same device). ``copy=True`` forces a new tensor
    on every path, so the saved value never shares storage with the live
    tensor.
    """
    return tensor.detach().to(self.checkpoint.device, copy=True)
Comment on lines +368 to +372

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 checkpoint_tensor can skip clone() on same-device CUDA tensors

When Checkpoint(device="cuda") (without an explicit index) is used and the model tensor lives on cuda:0, torch.device("cuda") != torch.device("cuda:0") evaluates to True, so the branch falls through to .to(self.checkpoint.device). Calling .to() on a tensor that is already on the target device returns self (shared storage), not a copy. Any subsequent in-place operation on the live x would then silently corrupt the saved checkpoint field. Normalising both sides to an indexed device before comparing eliminates the ambiguity.

Suggested change
def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor:
tensor = tensor.detach()
if tensor.device == self.checkpoint.device:
return tensor.clone()
return tensor.to(self.checkpoint.device)
def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor:
tensor = tensor.detach()
target = self.checkpoint.device
# Normalise both sides so cuda == cuda:0 comparisons work correctly.
src = torch.device(tensor.device.type, tensor.device.index or 0)
tgt = torch.device(target.type, target.index or 0)
if src == tgt:
return tensor.clone()
return tensor.to(target)


self.checkpoint.x = checkpoint_tensor(x)
self.checkpoint.coord_keys = tuple(coords.keys())
self.checkpoint.coord_values = tuple(
np.asarray(value).copy() for value in coords.values()
)
self.checkpoint.internal_noise_states = tuple(
tuple(
None if state is None else checkpoint_tensor(state) for state in states
)
for states in getattr(self, "_internal_noise_states", ())
)

def _get_internal_state(self, ensemble_index: int, time_index: int) -> torch.Tensor:
"""Get the internal RNG state for the given ensemble and time index

Expand Down Expand Up @@ -391,6 +467,7 @@ def _forward(
self._set_internal_state(j, i)

x = x.unsqueeze(2)
self._save_checkpoint_state(x, output_coords)
return x, output_coords

@batch_func()
Expand All @@ -413,11 +490,12 @@ def __call__(
tuple[torch.Tensor, CoordSystem]
Output tensor and coordinate system
"""
# Initialize the internal noise states
# for each batch index, we will have a list of noise states for each separate time
self._reset_internal_state(len(coords["batch"]), len(coords["time"]))
output, coords = self._forward(x, coords)
return output, coords
x, coords, restored = self._restore_checkpoint_state(x, coords)
if not restored:
# Initialize the internal noise states for each batch and time.
self._reset_internal_state(len(coords["batch"]), len(coords["time"]))

return self._forward(x, coords)
Comment on lines 473 to +498

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 __call__ silently ignores user-provided x and coords when a full checkpoint is loaded

When checkpoint_state_loaded=True with policy="full", _restore_checkpoint_state replaces the caller's x and coords with the checkpoint's saved state. The user's inputs are completely discarded with no warning, and the returned tensor/coords come from a different starting point than what was passed in. Direct __call__ invocations inside a with checkpoint.select(-1): block will produce silently wrong outputs. There is no test covering this path, and the method's docstring gives no hint that inputs may be overridden.


@batch_func()
def _default_generator(
Expand All @@ -426,15 +504,17 @@ def _default_generator(
coords = coords.copy()
self.output_coords(coords)

# Initialize the internal noise states
self._reset_internal_state(len(coords["batch"]), len(coords["time"]))
x, coords, restored = self._restore_checkpoint_state(x, coords)
if not restored:
# Initialize the internal noise states for each batch and time.
self._reset_internal_state(len(coords["batch"]), len(coords["time"]))
self._save_checkpoint_state(x, coords)
yield x, coords

# Yield the initial condition
yield x, coords
while True:
# Front hook
x, coords = self.front_hook(x, coords)
# Forward is identity operator
# Advance FCN3 one forecast step. Restored checkpoints resume here.
x, coords = self._forward(x, coords)
# Rear hook
x, coords = self.rear_hook(x, coords)
Expand Down
66 changes: 65 additions & 1 deletion test/models/px/test_fcn3.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from earth2studio.data import Random, fetch_data
from earth2studio.models.px import FCN3
from earth2studio.utils import handshake_dim
from earth2studio.utils.checkpoint import Checkpoint


class PhooFCN3Preprocessor(torch.nn.Module):

def __init__(
self,
):
Expand Down Expand Up @@ -67,6 +67,17 @@ def set_rng(self, reset: bool = True, seed: int = 333):
return


class PhooRestartFCN3ModelWrapper(PhooFCN3ModelWrapper):
    """Phoo wrapper whose forward pass adds the preprocessor's first state entry.

    NOTE(review): presumably this ties the output to the internal state so a
    restart test can tell whether that state was restored -- confirm.
    """

    def forward(self, x, t, normalized_data: bool = False, replace_state: bool = False):
        # ``t`` and both flags are accepted but not used by this stand-in.
        offset = self.model.preprocessor.state[0]
        return x + offset.to(x.device)


def _phoo_fcn3(model: torch.nn.Module, variables: np.ndarray) -> FCN3:
    """Build an ``FCN3`` around ``model`` via the undecorated initializer.

    The instance is allocated with ``FCN3.__new__`` and initialised through
    ``FCN3.__init__.__wrapped__`` instead of a normal constructor call.
    NOTE(review): presumably this bypasses a wrapper applied to ``__init__``
    (e.g. an optional-dependency check) -- confirm.
    """
    raw_init = FCN3.__init__.__wrapped__
    instance = FCN3.__new__(FCN3)
    raw_init(instance, model, variables=variables)
    return instance


@pytest.fixture(scope="function")
def dummy_model():
preprocessor = PhooFCN3Preprocessor()
Expand Down Expand Up @@ -167,6 +178,59 @@ def test_fcn3_iter(ensemble, device, dummy_model):
break


def test_fcn3_checkpoint_state_round_trip_with_phoo_model(tmp_path):
    """Check that a resumed FCN3 iterator reproduces steps 2 and 3 of a reference run.

    A reference run yields the initial condition plus three forecast steps and
    writes to the checkpoint after the initial condition and after step 1. A
    second, freshly built model is then created under ``checkpoint.select(-1)``
    with a different torch seed and a deliberately wrong input tensor; its
    first two yields must equal steps 2 and 3 of the reference run.

    NOTE(review): the extent of the two ``with`` blocks below is the most
    plausible reading of the test -- confirm against the repository.
    """
    variables = np.array(["u10m"])
    time = np.array([np.datetime64("1993-04-05T00:00")])
    checkpoint = Checkpoint("fcn3", path=tmp_path / "fcn3", flush_interval=2, level=2)

    # Seed before building the reference model.
    torch.manual_seed(123)
    model = PhooRestartFCN3ModelWrapper(PhooFCN3Model(PhooFCN3Preprocessor()))
    p = _phoo_fcn3(model, variables=variables).to("cpu")
    input_coords = p.input_coords()
    coords = OrderedDict(
        {
            "time": time,
            "lead_time": input_coords["lead_time"],
            "variable": input_coords["variable"],
            "lat": input_coords["lat"],
            "lon": input_coords["lon"],
        }
    )
    # All-zero input with one axis per coordinate, in the same order as ``coords``.
    x = torch.zeros(
        len(coords["time"]),
        len(coords["lead_time"]),
        len(coords["variable"]),
        len(coords["lat"]),
        len(coords["lon"]),
    )

    # Reference run: initial condition followed by three forecast steps.
    with checkpoint.select(time=time) as ckpt:
        iterator = p.create_iterator(x, coords)
        _, initial_coords = next(iterator)
        ckpt.write(lead_time=initial_coords["lead_time"])
        step1, step1_coords = next(iterator)
        ckpt.write(lead_time=step1_coords["lead_time"])
        # No further writes: steps 2 and 3 are kept only as expected values.
        step2, step2_coords = next(iterator)
        step3, step3_coords = next(iterator)

    # A different seed, so matching outputs cannot come from the global RNG alone.
    torch.manual_seed(999)
    with checkpoint.select(-1):
        resumed_model = PhooRestartFCN3ModelWrapper(
            PhooFCN3Model(PhooFCN3Preprocessor())
        )
        resumed = _phoo_fcn3(resumed_model, variables=variables).to("cpu")
        assert resumed.checkpoint.checkpoint_state_loaded

        # The -100 input is a sentinel: if it were used, the outputs below
        # could not match the reference run that started from zeros.
        resumed_iterator = resumed.create_iterator(torch.full_like(x, -100.0), coords)
        resumed_step2, resumed_step2_coords = next(resumed_iterator)
        resumed_step3, resumed_step3_coords = next(resumed_iterator)

    # The resumed iterator's first yield corresponds to reference step 2.
    assert torch.allclose(resumed_step2, step2)
    assert torch.allclose(resumed_step3, step3)
    assert resumed_step2_coords["lead_time"][0] == step2_coords["lead_time"][0]
    assert resumed_step3_coords["lead_time"][0] == step3_coords["lead_time"][0]


@pytest.mark.parametrize(
"dc",
[
Expand Down