diff --git a/.claude/skills-dev/developer-model-checkpoint-update/SKILL.md b/.claude/skills-dev/developer-model-checkpoint-update/SKILL.md new file mode 100644 index 000000000..d6a6c5f99 --- /dev/null +++ b/.claude/skills-dev/developer-model-checkpoint-update/SKILL.md @@ -0,0 +1,319 @@ +--- +name: developer-model-checkpoint-update +version: 0.16.0 +license: Apache-2.0 +metadata: + author: NVIDIA Earth-2 Team + tags: + - earth2studio + - earth2 + - python + - checkpoint + - restart + - models + - inference + - testing +description: > + Add checkpoint restart support to an existing Earth2Studio model or component + by inspecting its call and iterator semantics, identifying continuation state + and random state, binding a pickle-free dataclass with bind_checkpoint_state, + honoring minimal/state/full checkpoint policies, handling checkpoint device + staging, and adding focused restart unit tests. +--- + +# Model Checkpoint Update - Add Restart Support + +Add checkpoint restart support to an existing Earth2Studio model, perturbation, +or component without expanding its public model API. Use this skill when a +component needs to opt into `earth2studio.utils.checkpoint` restart state. + +The goal is a small, explicit implementation: one internal dataclass, one +`bind_checkpoint_state(...)` call during construction, minimal changes to the +existing call/iterator path, and one focused unit test proving restart behavior. + +--- + +## Design Rules + +Follow these rules throughout the change: + +1. Do not add required user-facing model arguments for checkpointing. +2. Do not add a new model base class, wrapper class, inheritance layer, or global + registry. +3. Do not use pickle, `torch.save`, or arbitrary object serialization for + checkpoint state. +4. Keep checkpoint state in a dataclass local to the component module unless + there is already a better local pattern. +5. Use `bind_checkpoint_state` exactly once per restart state dataclass instance, + normally in `__init__` after normal model fields are initialized. +6. Treat `minimal`, `state`, and `full` as user intent hints exposed through the + bound checkpoint proxy. +7. Use `self..device` or `self..checkpoint_device` when staging live + tensors so `Checkpoint(device=...)` controls where restart tensors live. +8. Keep IO backend output separate from model restart state. Do not assume saved + user-facing output variables are sufficient to rehydrate model internals. +9. Add or update one focused unit test using existing test utility models when + possible. Avoid broad mocks and avoid large integration fixtures. + +--- + +## Step 1 - Read the Component Before Editing + +Inspect the target component and its tests before deciding what to save. + +Look for: + +- Constructor inputs and persistent attributes. +- `__call__`, `forward`, `_forward`, `create_iterator`, or generator methods. +- Autoregressive state that is advanced between yields. +- Random number generation, including `torch.Generator`, NumPy generators, + global Torch/NumPy RNG usage, perturbations, dropout-like behavior, or sampling. +- Pre/post hooks that mutate tensors or coordinates. +- Device moves, dtype casts, autocast blocks, and CPU/GPU assumptions. +- Existing test utility models such as Phoo, Random, Persistence, Identity, or + local dummy modules. + +Identify the restart boundary. For prognostic iterators, the saved boundary is +usually the latest completed forecast state. On restore, the iterator should +consume that saved boundary internally and yield the next forecast state. + +--- + +## Step 2 - Choose the State Policy Behavior + +Implement behavior for all three checkpoint policies: + +- `minimal`: do not store component state. The workflow may still record catalog + progress and explicit artifacts. +- `state`: store lightweight state needed to resume at workflow or forecast-run + boundaries, such as RNG state, counters, selected member IDs, or replayable + metadata. Do not store large continuation tensors here. +- `full`: store everything supported by the component to resume mid-rollout, + including continuation tensors and coordinate metadata when user-facing IO is + not restart-complete. + +When a model cannot support a policy, make the fallback explicit in code. For +example, `state` may save RNG state while `full` saves RNG plus continuation +state. `minimal` should clear or avoid updating component state. + +Do not expose these policies as new model constructor parameters. Read them from +the bound state proxy: + +```python +if self.restart.checkpoint_state_policy == "full": + ... +elif self.restart.checkpoint_state_policy == "state": + ... +else: + ... +``` + +--- + +## Step 3 - Add a Dataclass State + +Create a dataclass containing only pickle-free serializable fields. Good field +examples include: + +- `torch.Tensor | None` +- `np.ndarray` +- JSON-like scalars, tuples, lists, and dicts +- `np.datetime64`, `np.timedelta64`, `torch.dtype`, `torch.device` +- coordinate keys and coordinate values needed to reconstruct an `OrderedDict` + +Avoid fields that hold live model objects, modules, hooks, data sources, IO +backends, callables, open files, generators, or arbitrary Python objects. + +Example: + +```python +@dataclass +class MyModelCheckpointState: + x: torch.Tensor | None = None + coord_keys: tuple[str, ...] = () + coord_values: tuple[np.ndarray, ...] = () + rng_state: torch.Tensor | None = None +``` + +Bind it in the constructor: + +```python +self.restart = bind_checkpoint_state(MyModelCheckpointState()) +``` + +Use a component-specific field name such as `restart`, `checkpoint`, or +`_checkpoint_state` that matches local style. Do not name dataclass fields +`device` or `checkpoint_*`; those names are reserved for checkpoint metadata on +the proxy. + +--- + +## Step 4 - Restore State at the Right Boundary + +Restore state where the component first has enough runtime context to do so. + +For a simple callable component, this may be in `__call__` before generating the +next stochastic value. For a prognostic model, prefer the iterator construction +path so workflow code can still fetch and pass the normal initial condition. + +Pattern for an iterator: + +```python +restored = False +if self.restart.checkpoint_state_loaded and self.restart.x is not None: + x = self.restart.x.to(x.device) + coords = OrderedDict( + (key, np.asarray(value).copy()) + for key, value in zip(self.restart.coord_keys, self.restart.coord_values) + ) + restored = True + +iterator = super().create_iterator(x, coords) +if restored: + next(iterator) # consume the saved boundary internally + +for x_out, coords_out in iterator: + self._save_checkpoint_state(x_out, coords_out) + yield x_out, coords_out +``` + +If the component has both `__call__` and `create_iterator`, put shared save logic +in a small private helper so both paths update the dataclass consistently. + +If hooks mutate the returned tensor or coordinates, save the post-hook state. +Checkpoint state should match the boundary that future computation will continue +from, not an earlier internal intermediate unless that is intentional and tested. + +--- + +## Step 5 - Handle Random State Explicitly + +Prefer a component-owned `torch.Generator` when possible. This avoids saving or +restoring global Torch or NumPy RNG state. + +For stochastic components: + +1. Create the generator internally from existing seed/input parameters. +2. If checkpoint state is loaded, restore the generator from the saved state. +3. For `state`, save the generator state needed to reproduce the next component + call or forecast instance. +4. For `full`, save the generator state that correctly continues after the saved + mid-rollout boundary. + +Be precise about pre-state versus post-state: + +- Save pre-call state when restart should replay the just-started stochastic + operation. +- Save post-call state when restart should continue after the completed + stochastic operation. + +Do not gate RNG dataclass updates on `checkpoint_is_flush_due` unless stale state +cannot affect correctness. `flush_interval` should usually control disk writes, +not whether the live dataclass has the latest restart state. + +--- + +## Step 6 - Stage Tensors on the Checkpoint Device + +When saving live tensor state, detach and clone before staging it: + +```python +self.restart.x = x.detach().clone().to(self.restart.device) +``` + +Use `self.restart.device` or `self.restart.checkpoint_device`; both are provided +by the bound checkpoint proxy. The default is CPU, but users may set +`Checkpoint(device=torch.device("cuda:0"))` to keep full checkpoint tensors on +the active inference device and reduce device transfers. + +When restoring, move staged tensors to the runtime input/model device: + +```python +x = self.restart.x.to(x.device) +``` + +Only store large tensors for `full`. For `minimal` and usually `state`, clear +large fields or leave them unset: + +```python +self.restart.x = None +self.restart.coord_keys = () +self.restart.coord_values = () +``` + +--- + +## Step 7 - Write One Focused Unit Test + +Add a restart test close to the component's existing tests. Prefer an existing +small model fixture over a broad mock. + +The test should cover: + +1. Construct a `Checkpoint(..., state_policy="full")` or `state` as appropriate. +2. Run the component long enough to write a checkpoint row. +3. Re-open or re-select the checkpoint with `checkpoint.select(-1)`. +4. Construct the component inside the selected checkpoint context. +5. Continue the run and assert the restarted result matches the uninterrupted + reference or expected continuation. +6. Assert the component actually used hydrated state when that is observable. + +Example skeleton: + +```python +checkpoint = Checkpoint("model", path=tmp_path, mode="append", state_policy="full") + +with checkpoint as ckpt: + model = MyRestartableModel(...) + x1, coords1 = next(model.create_iterator(x0, coords0)) + ckpt.write(lead_time=coords1["lead_time"][-1]) + ckpt.flush() + +with checkpoint.select(-1) as ckpt: + restarted = MyRestartableModel(...) + out, coords = next(restarted.create_iterator(x0, coords0)) + +assert ... +``` + +For random components, compare the restarted sample sequence against a reference +sequence produced by the same seed. For prognostic models, compare both tensor +values and lead-time progression. + +Run the smallest relevant test first, then the local test file if optional +dependencies allow it. + +--- + +## Step 8 - Validate the Change + +Run targeted checks before committing: + +```bash +uv run ruff check +uv run pytest :: -q +git diff --check +``` + +If the full test file cannot run because an optional dependency group is missing, +run the targeted new test and clearly report the optional dependency limitation. +Do not skip the focused restart test. + +--- + +## Review Checklist + +Before opening the PR, verify: + +- Existing model construction still works without a checkpoint. +- Existing public model APIs are unchanged unless the user explicitly requested + otherwise. +- `bind_checkpoint_state` is called during construction. +- `minimal`, `state`, and `full` behavior is explicit. +- Tensor state is staged with `.detach().clone().to(self.restart.device)` when + relevant. +- Restore logic moves tensors back to the runtime device. +- Iterator restart consumes the saved boundary internally and yields the next + forecast state. +- Random state uses a component-owned generator where practical. +- The test proves restart behavior rather than only checking serialization. diff --git a/CHANGELOG.md b/CHANGELOG.md index 5105ce3ac..32196c139 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,12 +24,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 through 2011-03-27 with cycles every 5 days, served from the NCEI HTTPS archive - Added IBTrACS tropical cyclone track DataFrame source (`IBTrACS`) +- Added checkpoint/session utilities and restart support for deterministic, + diagnostic, and ensemble inference workflows +- Added checkpoint state policy hints exposed through bound state proxies +- Added U-CAST checkpoint state support for mid-rollout deterministic forecast restarts - Added EUMETNET OPERA European weather radar composite DataSource for DBZH reflectivity, rain rate, and 1-hour accumulation (`OPERA`) - Added support for cumulative variables in ARCO data source ### Changed +- Simplified workflow checkpoint handling with package-level no-op checkpoint + sessions and idempotent final flushes +- Deterministic checkpoint resume now leaves restart-state restoration to the + prognostic iterator after fetching the normal initial condition +- Renamed checkpoint state policies to `minimal`, `state`, and `full` - UFS GSI observation sources (`UFSObsConv`, `UFSObsSat`) now fetch from S3 via native `obstore` instead of `s3fs` to avoid the Python-GIL bottleneck that caps fsspec's concurrent S3 read throughput (~22% faster obs fetch, ~20% HealDA e2e on B200; output unchanged). diff --git a/docs/modules/utils_all.rst b/docs/modules/utils_all.rst index 0aa55f12b..f373e9a32 100644 --- a/docs/modules/utils_all.rst +++ b/docs/modules/utils_all.rst @@ -48,6 +48,30 @@ The following functions can be used to convert to and from these numpy arrays. utils.time.timearray_to_datetime utils.time.to_time_array +.. _earth2studio.utils.checkpoint: + +:mod:`earth2studio.utils`: Checkpointing +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Checkpoint utilities for restartable inference workflows. Checkpoints track +explicit workflow labels, optional artifacts, and dataclass state bound by +components during a selected checkpoint session. + +.. autosummary:: + :toctree: generated/utils/ + :template: class.rst + + utils.checkpoint.Checkpoint + utils.checkpoint.CheckpointSession + utils.checkpoint.CheckpointState + utils.checkpoint.NullCheckpointSession + +.. autosummary:: + :toctree: generated/utils/ + :template: function.rst + + utils.checkpoint.bind_checkpoint_state + .. _earth2studio.data.functions: :mod:`earth2studio.data`: Data diff --git a/docs/userguide/developer/checkpointing.md b/docs/userguide/developer/checkpointing.md new file mode 100644 index 000000000..a00555756 --- /dev/null +++ b/docs/userguide/developer/checkpointing.md @@ -0,0 +1,265 @@ +# Checkpointing + +Earth2Studio checkpointing is designed for long-running inference jobs that need +to restart without asking every model, perturbation, or custom loop to adopt a +new component API. A `Checkpoint` manages a small set of saved restart points. +Each row is selected by labels such as `time`, `ensemble`, or a user-defined +scenario name, and each row can contain small workflow artifacts plus dataclass +state from components that opt in. + +The checkpoint does not store model weights or duplicate data already written by +an IO backend. Forecast fields should continue to be written to the selected IO +backend. The checkpoint stores the position and small state needed to decide +where to restart. + +## Basic Use + +```python +from earth2studio.run import deterministic +from earth2studio.utils.checkpoint import Checkpoint + +checkpoint = Checkpoint("my-forecast", flush_interval=6, state_policy="full") + +deterministic( + time=["2024-01-01"], + nsteps=24, + prognostic=model, + data=data, + io=io, + checkpoint=checkpoint, +) +``` + +The built-in deterministic, diagnostic, and ensemble workflows write checkpoint +rows after successful IO writes. `flush_interval` controls durable writes to +disk. The workflow can call `write` on the active checkpoint session every +iteration while the checkpoint decides whether a real disk commit is due. +Ensemble workflow rows are tracked independently by `ensemble_batch` when a +checkpoint is provided. + +Use `flush_interval=1` for every write call to commit immediately, or +`flush_interval=None` to keep updates pending until `flush` is called. The +workflow flushes the final pending state before returning. When checkpointing is +omitted, built-in workflows use a no-op checkpoint session with the same +`write` and `flush` methods. + +`state_policy` is a user intent hint exposed to bound component state. It does +not force a model to checkpoint a particular payload. Supported values are: + +- `minimal`: save only catalog progress and explicit workflow artifacts. +- `state`: save lightweight component restart state such as RNG, generator, or + counter state needed to resume at workflow or forecast-instance boundaries. +- `full`: save all supported restart state, including heavy tensors needed to + resume inside a rollout. This is the default. + +## Selecting Restart Points + +`Checkpoint.select` returns a `CheckpointSession`, which chooses a saved row or a +future label set. The selected labels also become the labels for future writes in +that session. + +```python +checkpoint = Checkpoint("my-forecast") + +print(checkpoint) + +latest = checkpoint.select(-1) +latest_time = checkpoint.select(time=-1) +first_member = checkpoint.select(time="2024-01-01T00:00:00", ensemble=0) +``` + +The integer positional argument selects catalog rows, with negative indexing +supported. Keyword selections choose the latest matching row. A keyword value of +`-1` means the latest saved value for that label after all other labels are +applied. + +## Custom Loops + +Custom workflows can call `write` after a safe iteration boundary, usually after +the forecast fields for that iteration have been written to the IO backend. + +```python +checkpoint = Checkpoint("ensemble-forecast", mode="append", history_size=8) + +with checkpoint.select(time=time, ensemble=member) as ckpt: + for lead_time in lead_times: + coords, array = step_model(...) + io.write(coords, array) + + ckpt.write( + lead_time=lead_time, + artifacts={"last_complete_lead_time": lead_time}, + ) + + ckpt.flush() +``` + +`write` accepts explicit `lead_time` and `artifacts` keyword arguments. +`lead_time` records the latest completed forecast position. `artifacts` is for +small user-provided restart metadata. `flush()` with no arguments commits the +latest pending `write`; if there is nothing pending, it is a no-op. Arbitrary +keyword arguments are intentionally not accepted so checkpoint payloads remain +explicit. + +`mode="overwrite"` keeps only the latest row for a label set. `mode="append"` +keeps a history, and `history_size` can cap that history. + +## Workflow Resume + +Built-in workflows always fetch the normal initial condition and feed it to the +prognostic model iterator. When a checkpoint row is selected, the workflow uses +the row's lead time only as the completed workflow position. A checkpoint-aware +model is responsible for using its bound dataclass state inside `create_iterator` +to restore the selected restart point. If state is restored, the iterator should +consume that saved boundary internally and yield the next forecast state. If no +state is restored, the iterator should keep the normal convention of yielding +the initial condition first. + +This keeps restart independent from model internals and avoids assuming that +user-facing IO output is restart-complete. Diagnostic workflows still track the +prognostic lead time before diagnostic output is written. Ensemble workflows store +progress per mini-batch using an `ensemble_batch` label, allowing completed +batches to be skipped and partially completed batches to continue from their +latest saved lead time. + +For custom loops, print the checkpoint and select the desired row by index, for +example `checkpoint.select(-1)` for the latest row. `Checkpoint` and +`CheckpointSession` both support context-manager use. Built-in workflows accept +either the checkpoint manager or a selected `CheckpointSession`. Passing the +manager while a session is active uses that active session; passing a manager +with no active session chooses the latest matching workflow row, or starts a new +row when no matching checkpoint exists. + +## Component State + +Models, perturbations, and user-defined components can opt in by binding a +dataclass instance. Existing components do not need to change. + +```python +from dataclasses import dataclass + +import torch + +from earth2studio.utils.checkpoint import bind_checkpoint_state + + +@dataclass +class NoiseState: + calls: int = 0 + rng_state: torch.Tensor | None = None + + +class NoisePerturbation: + def __init__(self, generator: torch.Generator): + self.generator = generator + self.checkpoint = bind_checkpoint_state(NoiseState()) + + if self.checkpoint.rng_state is not None: + self.generator.set_state(self.checkpoint.rng_state) + + def __call__(self, x): + y = add_noise(x, generator=self.generator) + self.checkpoint.calls += 1 + if self.checkpoint.checkpoint_enabled: + self.checkpoint.rng_state = self.generator.get_state() + return y +``` + +`bind_checkpoint_state` returns a proxy around the original dataclass. Normal +dataclass fields are accessed directly, while checkpoint metadata is exposed +through `checkpoint_*` properties such as `checkpoint_enabled`, +`checkpoint_state_policy`, `checkpoint_device`, `checkpoint_flush_interval`, +`checkpoint_write_count`, `checkpoint_is_flush_due`, `checkpoint_selected`, +`checkpoint_state_loaded`, and `checkpoint_lead_time`. The shorter `device` +alias is provided for tensor staging, for example `x.to(self.restart.device)`. +Because `device` is reserved checkpoint metadata, component state dataclasses +should use a different field name for their own device values. Live tensor state +is staged on CPU by default; pass `device=` to `Checkpoint` when components +should stage it elsewhere. + +When a checkpoint session is active, `bind_checkpoint_state` loads the matching +saved state if one exists and registers the live dataclass instance for future +writes. Model state can be lightweight `state` policy metadata, such as lead +time and RNG state, or heavier `full` policy tensors when the model cannot be +restarted from user-facing IO output alone. When no session is active but a +`Checkpoint` has been instantiated, `bind_checkpoint_state` buffers the live +dataclass for that checkpoint. The buffered state is registered when a session +for that checkpoint is entered. + + +A model can use the checkpoint policy hint without adding another API call: + +```python +if self.restart.checkpoint_state_loaded: + x = self.restart.x.to(x.device) + +if self.restart.checkpoint_state_policy == "full": + self.restart.x = x.detach().clone().to(self.restart.device) +elif self.restart.checkpoint_state_policy == "state": + self.restart.step = step + self.restart.rng_state = generator.get_state() +else: + self.restart.x = None +``` + +This makes new runs simple: + +```python +checkpoint = Checkpoint("my-forecast") +model = MyRestartableModel(...) +deterministic(..., checkpoint=checkpoint) +``` + +For strict restart hydration from an existing row, construct restartable +components inside the selected session: + +```python +checkpoint = Checkpoint("my-forecast") + +with checkpoint.select(-1): + model = MyRestartableModel(...) + perturbation = NoisePerturbation(generator) + deterministic(..., checkpoint=checkpoint) +``` + +If a component binds state before an existing checkpoint session is entered, the +session will still hydrate that dataclass when it opens, but Earth2Studio emits a +warning because constructor side effects that already used the default state will +not be replayed. + +State identity is based on the dataclass type, using its fully qualified module +and class name. Binding the same dataclass type more than once in one checkpoint +session raises an error, because the checkpoint would otherwise not know which +saved payload belongs to which component. Use separate dataclass types for +distinct restartable components. + +## Serialization Rules + +Checkpoint state is intentionally pickle-free. Supported values include +JSON-like scalars and containers, dataclasses, `datetime`, `date`, `timedelta`, +`numpy.datetime64`, `numpy.timedelta64`, `numpy.dtype`, NumPy scalars, +`torch.device`, `torch.dtype`, `torch.Tensor`, and `numpy.ndarray` with +non-object dtype. Tensors and arrays are stored as separate `.npy` files with +pickle disabled. + +Unsupported objects raise `CheckpointSerializationError` during checkpoint +writes. If a dataclass definition changes incompatibly between save and restore, +checkpoint binding raises `CheckpointStateSchemaError` instead of guessing. + +## Distributed Runs + +Serial runs write directly into the checkpoint directory. Distributed runs write +each process to its own rank directory, such as `rank_000000` or `rank_000001`. +Rank detection first checks PhysicsNeMo's distributed manager when available, +then common distributed environment variables. The checkpoint does not use file +locks across ranks. + +## Storage Layout + +By default, checkpoints are stored under +`$EARTH2STUDIO_CACHE/checkpoints/` or `~/.cache/earth2studio/checkpoints/`. +Pass `path=` to store them elsewhere. + +Each durable write is staged in a temporary directory and then atomically moved +into the commit directory. The catalog file is also written atomically. Incomplete +temporary commits are ignored and cleaned up by later writes. diff --git a/docs/userguide/developer/index.md b/docs/userguide/developer/index.md index 2f5bf59c1..efa6d651b 100644 --- a/docs/userguide/developer/index.md +++ b/docs/userguide/developer/index.md @@ -11,4 +11,5 @@ testing build recipes skills +checkpointing ``` diff --git a/docs/userguide/index.md b/docs/userguide/index.md index 6dbf74309..186807543 100644 --- a/docs/userguide/index.md +++ b/docs/userguide/index.md @@ -68,6 +68,7 @@ run(["2024-01-01"], 10, model, ds, io) - [Testing](developer/testing) - [Build](developer/build) - [Recipes](developer/recipes) +- [Checkpointing](developer/checkpointing) ## Support diff --git a/earth2studio/models/px/fcn3.py b/earth2studio/models/px/fcn3.py index 9f6bd4d0c..83931788c 100644 --- a/earth2studio/models/px/fcn3.py +++ b/earth2studio/models/px/fcn3.py @@ -16,6 +16,7 @@ import json from collections import OrderedDict from collections.abc import Generator, Iterator +from dataclasses import dataclass from datetime import datetime import numpy as np @@ -27,6 +28,7 @@ from earth2studio.models.px.base import PrognosticModel from earth2studio.models.px.utils import PrognosticMixin from earth2studio.utils import handshake_coords, handshake_dim +from earth2studio.utils.checkpoint import bind_checkpoint_state from earth2studio.utils.imports import ( OptionalDependencyFailure, check_optional_dependencies, @@ -129,6 +131,14 @@ ] +@dataclass +class _FCN3CheckpointState: + x: torch.Tensor | None = None + coord_keys: tuple[str, ...] = () + coord_values: tuple[np.ndarray, ...] = () + internal_noise_states: tuple[tuple[torch.Tensor | None, ...], ...] = () + + @check_optional_dependencies() class FCN3(torch.nn.Module, AutoModelMixin, PrognosticMixin): """FourCastNet 3 advances global weather modeling by implementing a scalable, @@ -177,6 +187,7 @@ def __init__( self.variables[self.variables == "2d"] = "d2m" self.set_rng(reset=True, seed=seed) + self.checkpoint = bind_checkpoint_state(_FCN3CheckpointState()) def __str__(self) -> str: return "fcn3" @@ -307,6 +318,71 @@ def load_model( return cls(model, variables=variables) + def _restore_checkpoint_state( + self, x: torch.Tensor, coords: CoordSystem + ) -> tuple[torch.Tensor, CoordSystem, bool]: + if not self.checkpoint.checkpoint_state_loaded: + return x, coords, False + if self.checkpoint.checkpoint_state_policy != "full": + return x, coords, False + if self.checkpoint.x is None or not self.checkpoint.coord_keys: + return x, coords, False + if len(self.checkpoint.coord_keys) != len(self.checkpoint.coord_values): + raise RuntimeError("FCN3 checkpoint coordinate state is incomplete.") + if not self.checkpoint.internal_noise_states: + raise RuntimeError( + "FCN3 checkpoint is missing internal noise state required for full restart." + ) + + restored_x = self.checkpoint.x.to(x.device) + restored_coords = OrderedDict( + (key, np.asarray(value).copy()) + for key, value in zip( + self.checkpoint.coord_keys, self.checkpoint.coord_values + ) + ) + restored_states = [ + [None if state is None else state.to(restored_x.device) for state in states] + for states in self.checkpoint.internal_noise_states + ] + if len(restored_states) != len(restored_coords["batch"]) or any( + len(states) != len(restored_coords["time"]) for states in restored_states + ): + raise RuntimeError( + "FCN3 checkpoint internal noise state does not match saved coordinates." + ) + + self._internal_noise_states = restored_states + return restored_x, restored_coords, True + + def _save_checkpoint_state(self, x: torch.Tensor, coords: CoordSystem) -> None: + if not self.checkpoint.checkpoint_enabled: + return + if self.checkpoint.checkpoint_state_policy != "full": + self.checkpoint.x = None + self.checkpoint.coord_keys = () + self.checkpoint.coord_values = () + self.checkpoint.internal_noise_states = () + return + + def checkpoint_tensor(tensor: torch.Tensor) -> torch.Tensor: + tensor = tensor.detach() + if tensor.device == self.checkpoint.device: + return tensor.clone() + return tensor.to(self.checkpoint.device) + + self.checkpoint.x = checkpoint_tensor(x) + self.checkpoint.coord_keys = tuple(coords.keys()) + self.checkpoint.coord_values = tuple( + np.asarray(value).copy() for value in coords.values() + ) + self.checkpoint.internal_noise_states = tuple( + tuple( + None if state is None else checkpoint_tensor(state) for state in states + ) + for states in getattr(self, "_internal_noise_states", ()) + ) + def _get_internal_state(self, ensemble_index: int, time_index: int) -> torch.Tensor: """Get the internal RNG state for the given ensemble and time index @@ -391,6 +467,7 @@ def _forward( self._set_internal_state(j, i) x = x.unsqueeze(2) + self._save_checkpoint_state(x, output_coords) return x, output_coords @batch_func() @@ -413,11 +490,12 @@ def __call__( tuple[torch.Tensor, CoordSystem] Output tensor and coordinate system """ - # Initialize the internal noise states - # for each batch index, we will have a list of noise states for each separate time - self._reset_internal_state(len(coords["batch"]), len(coords["time"])) - output, coords = self._forward(x, coords) - return output, coords + x, coords, restored = self._restore_checkpoint_state(x, coords) + if not restored: + # Initialize the internal noise states for each batch and time. + self._reset_internal_state(len(coords["batch"]), len(coords["time"])) + + return self._forward(x, coords) @batch_func() def _default_generator( @@ -426,15 +504,17 @@ def _default_generator( coords = coords.copy() self.output_coords(coords) - # Initialize the internal noise states - self._reset_internal_state(len(coords["batch"]), len(coords["time"])) + x, coords, restored = self._restore_checkpoint_state(x, coords) + if not restored: + # Initialize the internal noise states for each batch and time. + self._reset_internal_state(len(coords["batch"]), len(coords["time"])) + self._save_checkpoint_state(x, coords) + yield x, coords - # Yield the initial condition - yield x, coords while True: # Front hook x, coords = self.front_hook(x, coords) - # Forward is identity operator + # Advance FCN3 one forecast step. Restored checkpoints resume here. x, coords = self._forward(x, coords) # Rear hook x, coords = self.rear_hook(x, coords) diff --git a/earth2studio/models/px/persistence.py b/earth2studio/models/px/persistence.py index 7779d553a..d6d72194e 100644 --- a/earth2studio/models/px/persistence.py +++ b/earth2studio/models/px/persistence.py @@ -16,6 +16,7 @@ from collections import OrderedDict from collections.abc import Generator, Iterator +from dataclasses import dataclass import numpy as np import torch @@ -23,9 +24,17 @@ from earth2studio.models.batch import batch_coords, batch_func from earth2studio.models.px.utils import PrognosticMixin from earth2studio.utils import handshake_coords, handshake_dim +from earth2studio.utils.checkpoint import bind_checkpoint_state from earth2studio.utils.type import CoordSystem +@dataclass +class _PersistenceCheckpointState: + x: torch.Tensor | None = None + coord_keys: tuple[str, ...] = () + coord_values: tuple[np.ndarray, ...] = () + + class Persistence(torch.nn.Module, PrognosticMixin): """Persistence model that generates a forecast by applying the identity operator on the initial condition and indexing the lead time by 6 hours. Primarily used in @@ -78,6 +87,7 @@ def __init__( self._history = history self._dt = dt + self.checkpoint = bind_checkpoint_state(_PersistenceCheckpointState()) def __str__( self, @@ -130,6 +140,40 @@ def output_coords(self, input_coords: CoordSystem) -> CoordSystem: ) return output_coords + def _restore_checkpoint_state( + self, x: torch.Tensor, coords: CoordSystem + ) -> tuple[torch.Tensor, CoordSystem, bool]: + if ( + self.checkpoint.checkpoint_state_policy == "full" + and self.checkpoint.checkpoint_state_loaded + and self.checkpoint.x is not None + and self.checkpoint.coord_keys + ): + x = self.checkpoint.x.to(x.device) + coords = OrderedDict( + (key, np.asarray(value).copy()) + for key, value in zip( + self.checkpoint.coord_keys, self.checkpoint.coord_values + ) + ) + return x, coords, True + return x, coords, False + + def _save_checkpoint_state(self, x: torch.Tensor, coords: CoordSystem) -> None: + if ( + self.checkpoint.checkpoint_enabled + and self.checkpoint.checkpoint_state_policy == "full" + ): + self.checkpoint.x = x.detach().clone().to(self.checkpoint.device) + self.checkpoint.coord_keys = tuple(coords.keys()) + self.checkpoint.coord_values = tuple( + np.asarray(value).copy() for value in coords.values() + ) + else: + self.checkpoint.x = None + self.checkpoint.coord_keys = () + self.checkpoint.coord_values = () + @torch.inference_mode() def _forward( self, @@ -162,17 +206,22 @@ def __call__( x : torch.Tensor coords : CoordSystem """ - return self._forward(x, coords) + x_out, coords_out = self._forward(x, coords) + self._save_checkpoint_state(x_out, coords_out) + return x_out, coords_out @batch_func() def _default_generator( self, x: torch.Tensor, coords: CoordSystem ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]: + x, coords, restored = self._restore_checkpoint_state(x, coords) self.output_coords(coords.copy()) - coords_out = coords.copy() - coords_out["lead_time"] = coords["lead_time"][-1:] - yield x[:, -1:], coords_out + if not restored: + coords_out = coords.copy() + coords_out["lead_time"] = coords["lead_time"][-1:] + self._save_checkpoint_state(x, coords) + yield x[:, -1:], coords_out while True: # Front hook @@ -188,6 +237,7 @@ def _default_generator( [coords["lead_time"][1:], coords_out["lead_time"]] ) x = torch.cat([x[:, 1:], x_out], dim=1) + self._save_checkpoint_state(x, coords) yield x_out, coords_out diff --git a/earth2studio/models/px/ucast.py b/earth2studio/models/px/ucast.py index c9cbbf02b..979ee3c5a 100644 --- a/earth2studio/models/px/ucast.py +++ b/earth2studio/models/px/ucast.py @@ -19,6 +19,7 @@ import math from collections import OrderedDict from collections.abc import Generator, Iterator +from dataclasses import dataclass from pathlib import Path from typing import TypedDict @@ -33,6 +34,7 @@ from earth2studio.models.px.base import PrognosticModel from earth2studio.models.px.utils import PrognosticMixin from earth2studio.utils import handshake_coords, handshake_dim, handshake_size +from earth2studio.utils.checkpoint import bind_checkpoint_state from earth2studio.utils.type import CoordSystem LEVELS = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000] @@ -611,6 +613,17 @@ def _compute_forcings( return forcing +@dataclass +class _UCastCheckpointState: + x: torch.Tensor | None = None + x_norm: torch.Tensor | None = None + sst_mask: torch.Tensor | None = None + coord_keys: tuple[str, ...] = () + coord_values: tuple[np.ndarray, ...] = () + rng_state: torch.Tensor | None = None + rng_device_type: str | None = None + + class UCast(torch.nn.Module, AutoModelMixin, PrognosticMixin): """U-CAST 1.5 degree global probabilistic weather model. @@ -711,6 +724,7 @@ def __init__( "lon": np.linspace(0, 360, 240, endpoint=False), } ) + self.checkpoint = bind_checkpoint_state(_UCastCheckpointState()) def input_coords(self) -> CoordSystem: """Input coordinate system of the prognostic model.""" @@ -849,6 +863,106 @@ def _normalize(self, x: torch.Tensor) -> torch.Tensor: def _denormalize(self, x: torch.Tensor) -> torch.Tensor: return x * self.scale.to(dtype=x.dtype) + self.center.to(dtype=x.dtype) + def _rng_state( + self, device: torch.device + ) -> tuple[torch.Tensor, str] | tuple[None, None]: + if not self.stochastic: + return None, None + if device.type == "cuda": + return torch.cuda.get_rng_state(device).cpu().clone(), "cuda" + return torch.get_rng_state().cpu().clone(), "cpu" + + def _restore_rng_state(self, device: torch.device) -> None: + if ( + not self.stochastic + or not self.checkpoint.checkpoint_state_loaded + or self.checkpoint.rng_state is None + ): + return + if self.checkpoint.rng_device_type == "cuda" and device.type == "cuda": + torch.cuda.set_rng_state(self.checkpoint.rng_state.cpu(), device) + elif self.checkpoint.rng_device_type == "cpu" and device.type == "cpu": + torch.set_rng_state(self.checkpoint.rng_state.cpu()) + + def _restore_checkpoint_state( + self, x: torch.Tensor, coords: CoordSystem + ) -> tuple[ + torch.Tensor, CoordSystem, torch.Tensor | None, torch.Tensor | None, bool + ]: + policy = self.checkpoint.checkpoint_state_policy + if policy in ("state", "full"): + self._restore_rng_state(x.device) + if ( + policy == "full" + and self.checkpoint.checkpoint_state_loaded + and self.checkpoint.x is not None + and self.checkpoint.coord_keys + ): + x = self.checkpoint.x.to(x.device) + coords = OrderedDict( + (key, np.asarray(value).copy()) + for key, value in zip( + self.checkpoint.coord_keys, self.checkpoint.coord_values + ) + ) + x_norm = ( + None + if self.checkpoint.x_norm is None + else self.checkpoint.x_norm.to(x.device) + ) + sst_mask = ( + None + if self.checkpoint.sst_mask is None + else self.checkpoint.sst_mask.to(x.device) + ) + if x_norm is not None and sst_mask is None: + x_norm = None + return x, coords, x_norm, sst_mask, True + return x, coords, None, None, False + + def _save_checkpoint_state( + self, + x: torch.Tensor, + coords: CoordSystem, + x_norm: torch.Tensor | None, + sst_mask: torch.Tensor | None, + device: torch.device, + ) -> None: + if not self.checkpoint.checkpoint_enabled: + return + + rng_state, rng_device_type = self._rng_state(device) + policy = self.checkpoint.checkpoint_state_policy + if policy in ("state", "full"): + self.checkpoint.rng_state = rng_state + self.checkpoint.rng_device_type = rng_device_type + else: + self.checkpoint.rng_state = None + self.checkpoint.rng_device_type = None + + if policy == "full": + self.checkpoint.x = x.detach().clone().to(self.checkpoint.device) + self.checkpoint.x_norm = ( + None + if x_norm is None + else x_norm.detach().clone().to(self.checkpoint.device) + ) + self.checkpoint.sst_mask = ( + None + if sst_mask is None + else sst_mask.detach().clone().to(self.checkpoint.device) + ) + self.checkpoint.coord_keys = tuple(coords.keys()) + self.checkpoint.coord_values = tuple( + np.asarray(value).copy() for value in coords.values() + ) + else: + self.checkpoint.x = None + self.checkpoint.x_norm = None + self.checkpoint.sst_mask = None + self.checkpoint.coord_keys = () + self.checkpoint.coord_values = () + @torch.inference_mode() def _forward( self, @@ -971,7 +1085,27 @@ def __call__( x = x[:, :, :, : len(VARIABLES)] coords = coords.copy() coords["variable"] = self._output_coords["variable"].copy() - return self._forward(x, coords, static_condition), out_coords + + x, coords, x_norm, sst_mask, restored = self._restore_checkpoint_state( + x, coords + ) + if restored: + out_coords = self.output_coords(coords) + out, x_norm, sst_mask = self._forward( + x, + coords, + x_norm=x_norm, + sst_mask=sst_mask, + static_condition=static_condition, + return_state=True, + ) + next_x = torch.cat([x[:, :, 1:], out], dim=2) + next_coords = coords.copy() + next_coords["lead_time"] = np.array( + [coords["lead_time"][-1], out_coords["lead_time"][-1]] + ) + self._save_checkpoint_state(next_x, next_coords, x_norm, sst_mask, x.device) + return out, out_coords @batch_func() def _default_generator( @@ -998,14 +1132,16 @@ def _default_generator( x = x[:, :, :, : len(VARIABLES)] coords["variable"] = self._output_coords["variable"].copy() - out = x[:, :, 1:] - out_coords = coords.copy() - out_coords["lead_time"] = out_coords["lead_time"][1:] - out_coords["variable"] = self._output_coords["variable"].copy() - yield out, out_coords - - x_norm = None - sst_mask = None + x, coords, x_norm, sst_mask, restored = self._restore_checkpoint_state( + x, coords + ) + if not restored: + out = x[:, :, 1:] + out_coords = coords.copy() + out_coords["lead_time"] = out_coords["lead_time"][1:] + out_coords["variable"] = self._output_coords["variable"].copy() + self._save_checkpoint_state(x, coords, None, None, x.device) + yield out, out_coords while True: x, coords = self.front_hook(x, coords) @@ -1024,6 +1160,7 @@ def _default_generator( coords["lead_time"] = np.array( [coords["lead_time"][-1], out_coords["lead_time"][-1]] ) + self._save_checkpoint_state(x, coords, x_norm, sst_mask, x.device) yield out, out_coords.copy() diff --git a/earth2studio/perturbation/gaussian.py b/earth2studio/perturbation/gaussian.py index 5809db6d9..2d8e74cac 100644 --- a/earth2studio/perturbation/gaussian.py +++ b/earth2studio/perturbation/gaussian.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Any import numpy as np @@ -21,6 +22,7 @@ from typing_extensions import Self from earth2studio.utils import handshake_dim +from earth2studio.utils.checkpoint import bind_checkpoint_state from earth2studio.utils.imports import ( OptionalDependencyFailure, check_optional_dependencies, @@ -34,6 +36,12 @@ InverseRealSHT = None +@dataclass +class _GaussianCheckpointState: + generator_state: torch.Tensor | None = None + generator_device_type: str | None = None + + class Gaussian: """Standard Gaussian peturbation @@ -50,6 +58,9 @@ def __init__(self, noise_amplitude: float | torch.Tensor = 0.05): if isinstance(noise_amplitude, torch.Tensor) else torch.Tensor([noise_amplitude]) ) + self.generator: torch.Generator | None = None + self.checkpoint = bind_checkpoint_state(_GaussianCheckpointState()) + self._restore_generator_state() @torch.inference_mode() def __call__( @@ -71,8 +82,61 @@ def __call__( tuple[torch.Tensor, CoordSystem]: Output tensor and respective coordinate system dictionary """ + generator = self._get_generator(x.device) + pre_state = generator.get_state() noise_amplitude = self.noise_amplitude.to(x.device) - return x + noise_amplitude * torch.randn_like(x), coords + y = x + noise_amplitude * torch.randn( + x.shape, dtype=x.dtype, device=x.device, generator=generator + ) + self._save_generator_state(pre_state, generator.get_state(), generator) + return y, coords + + def _get_generator(self, device: torch.device) -> torch.Generator: + if self.generator is None: + self.generator = torch.Generator(device=device) + self.generator.seed() + return self.generator + + def _restore_generator_state(self) -> None: + if not self.checkpoint.checkpoint_state_loaded: + return + if self.checkpoint.generator_state is None: + return + if self.checkpoint.generator_device_type is None: + raise RuntimeError("Gaussian checkpoint generator device is missing.") + if self.generator is None: + self.generator = torch.Generator( + device=self.checkpoint.generator_device_type + ) + if self.generator.device.type != self.checkpoint.generator_device_type: + raise RuntimeError( + "Gaussian checkpoint generator state was saved for " + f"{self.checkpoint.generator_device_type!r}, but generator is on " + f"{self.generator.device.type!r}." + ) + self.generator.set_state(self.checkpoint.generator_state.cpu()) + + def _save_generator_state( + self, + pre_state: torch.Tensor, + post_state: torch.Tensor, + generator: torch.Generator, + ) -> None: + if not ( + self.checkpoint.checkpoint_enabled + and self.checkpoint.checkpoint_is_flush_due + ): + return + + if self.checkpoint.checkpoint_state_policy == "state": + self.checkpoint.generator_state = pre_state.cpu().clone() + self.checkpoint.generator_device_type = generator.device.type + elif self.checkpoint.checkpoint_state_policy == "full": + self.checkpoint.generator_state = post_state.cpu().clone() + self.checkpoint.generator_device_type = generator.device.type + else: + self.checkpoint.generator_state = None + self.checkpoint.generator_device_type = None @check_optional_dependencies() diff --git a/earth2studio/run.py b/earth2studio/run.py index 706acaa93..0930899ec 100644 --- a/earth2studio/run.py +++ b/earth2studio/run.py @@ -28,6 +28,12 @@ from earth2studio.models.dx import DiagnosticModel from earth2studio.models.px import PrognosticModel from earth2studio.perturbation import Perturbation +from earth2studio.utils.checkpoint import ( + NO_CHECKPOINT, + Checkpoint, + CheckpointSession, + NullCheckpointSession, +) from earth2studio.utils.coords import CoordSystem, map_coords, split_coords from earth2studio.utils.time import to_time_array @@ -45,6 +51,10 @@ def deterministic( output_coords: CoordSystem = OrderedDict({}), device: torch.device | None = None, verbose: bool = True, + checkpoint: Checkpoint + | CheckpointSession + | NullCheckpointSession + | None = NO_CHECKPOINT, ) -> IOBackend: """Built in deterministic workflow. This workflow creates a determinstic inference pipeline to produce a forecast @@ -68,6 +78,11 @@ def deterministic( Device to run inference on, by default None verbose : bool, optional Print inference progress, by default True + checkpoint : Checkpoint, optional + Checkpoint manager or checkpoint session used to record and resume workflow + progress, by default no checkpoint. If a checkpoint has an active session, the + workflow uses that session; otherwise it selects the latest matching row + or starts a new labeled row. Returns ------- @@ -89,26 +104,6 @@ def deterministic( prognostic_ic = prognostic.input_coords() time = to_time_array(time) - if hasattr(prognostic, "interp_method"): - interp_to = prognostic_ic - interp_method = prognostic.interp_method - else: - interp_to = None - interp_method = "nearest" - - x, coords = fetch_data( - source=data, - time=time, - variable=prognostic_ic["variable"], - lead_time=prognostic_ic["lead_time"], - device=device, - interp_to=interp_to, - interp_method=interp_method, - ) - - logger.success(f"Fetched data from {data.__class__.__name__}") - # sphinx - fetch data end - # Set up IO backend total_coords = prognostic.output_coords(prognostic.input_coords()).copy() for key, value in prognostic.output_coords( @@ -129,24 +124,79 @@ def deterministic( for key, value in total_coords.items(): total_coords[key] = output_coords.get(key, value) var_names = total_coords.pop("variable") - io.add_array(total_coords, var_names) - - # Map lat and lon if needed - x, coords = map_coords(x, coords, prognostic.input_coords()) - # Create prognostic iterator - model = prognostic.create_iterator(x, coords) - - logger.info("Inference starting!") - with tqdm( - total=nsteps + 1, desc="Running inference", position=1, disable=(not verbose) - ) as pbar: - for step, (x, coords) in enumerate(model): - # Subselect domain/variables as indicated in output_coords - x, coords = map_coords(x, coords, output_coords) - io.write(*split_coords(x, coords)) - pbar.update(1) - if step == nsteps: - break + missing = list(var_names) + try: + missing = [name for name in var_names if name not in io] + except TypeError: + pass + if missing: + io.add_array(total_coords, missing) + + if checkpoint is None: + checkpoint = NO_CHECKPOINT + if isinstance(checkpoint, Checkpoint): + active = checkpoint.active + checkpoint = active if active is not None else checkpoint.select(time=time) + + with checkpoint as ckpt: + restart_step = None + if ckpt.exists and ckpt.write_count > 0: + restart_step = ckpt.write_count - 1 + if restart_step >= nsteps: + logger.success("\nInference complete") + return io + + if hasattr(prognostic, "interp_method"): + interp_to = prognostic_ic + interp_method = prognostic.interp_method + else: + interp_to = None + interp_method = "nearest" + + x, coords = fetch_data( + source=data, + time=time, + variable=prognostic_ic["variable"], + lead_time=prognostic_ic["lead_time"], + device=device, + interp_to=interp_to, + interp_method=interp_method, + ) + + logger.success(f"Fetched data from {data.__class__.__name__}") + # sphinx - fetch data end + + # Map lat and lon if needed + x, coords = map_coords(x, coords, prognostic.input_coords()) + # Create prognostic iterator + model = prognostic.create_iterator(x, coords) + + logger.info("Inference starting!") + initial_progress = 0 if restart_step is None else restart_step + 1 + with tqdm( + total=nsteps + 1, + initial=initial_progress, + desc="Running inference", + position=1, + disable=(not verbose), + ) as pbar: + for local_step, (x, coords) in enumerate(model): + step = ( + local_step + if restart_step is None + else restart_step + local_step + 1 + ) + + current_lead_time = coords["lead_time"][-1] + # Subselect domain/variables as indicated in output_coords + x, coords = map_coords(x, coords, output_coords) + io.write(*split_coords(x, coords)) + ckpt.write(lead_time=current_lead_time) + pbar.update(1) + if step == nsteps: + break + + ckpt.flush() logger.success("\nInference complete") return io @@ -163,6 +213,10 @@ def diagnostic( output_coords: CoordSystem = OrderedDict({}), device: torch.device | None = None, verbose: bool = True, + checkpoint: Checkpoint + | CheckpointSession + | NullCheckpointSession + | None = NO_CHECKPOINT, ) -> IOBackend: """Built in diagnostic workflow. This workflow creates a determinstic inference pipeline that couples a prognostic @@ -188,6 +242,11 @@ def diagnostic( Device to run inference on, by default None verbose : bool, optional Print inference progress, by default True + checkpoint : Checkpoint, optional + Checkpoint manager or checkpoint session used to record and resume workflow + progress, by default no checkpoint. When resuming, the workflow fetches the + normal initial condition and checkpoint-aware models restore from their own + bound checkpoint state. Returns ------- @@ -196,7 +255,6 @@ def diagnostic( """ # sphinx - diagnostic end logger.info("Running diagnostic workflow!") - # Load model onto the device device = ( device if device is not None @@ -205,29 +263,11 @@ def diagnostic( logger.info(f"Inference device: {device}") prognostic = prognostic.to(device) diagnostic = diagnostic.to(device) - # Fetch data from data source and load onto device + prognostic_ic = prognostic.input_coords() diagnostic_ic = diagnostic.input_coords() time = to_time_array(time) - if hasattr(prognostic, "interp_method"): - interp_to = prognostic_ic - interp_method = prognostic.interp_method - else: - interp_to = None - interp_method = "nearest" - - x, coords = fetch_data( - source=data, - time=time, - variable=prognostic_ic["variable"], - lead_time=prognostic_ic["lead_time"], - device=device, - interp_to=interp_to, - interp_method=interp_method, - ) - logger.success(f"Fetched data from {data.__class__.__name__}") - # Set up IO backend total_coords = prognostic.output_coords(prognostic.input_coords()) for key, value in prognostic.output_coords( prognostic.input_coords() @@ -249,29 +289,76 @@ def diagnostic( for key, value in total_coords.items(): total_coords[key] = output_coords.get(key, value) var_names = total_coords.pop("variable") - io.add_array(total_coords, var_names) - - # Map lat and lon if needed - x, coords = map_coords(x, coords, prognostic_ic) - - # Create prognostic iterator - model = prognostic.create_iterator(x, coords) - - logger.info("Inference starting!") - with tqdm( - total=nsteps + 1, desc="Running inference", position=1, disable=(not verbose) - ) as pbar: - for step, (x, coords) in enumerate(model): - - # Run diagnostic - x, coords = map_coords(x, coords, diagnostic_ic) - x, coords = diagnostic(x, coords) - # Subselect domain/variables as indicated in output_coords - x, coords = map_coords(x, coords, output_coords) - io.write(*split_coords(x, coords)) - pbar.update(1) - if step == nsteps: - break + missing = list(var_names) + try: + missing = [name for name in var_names if name not in io] + except TypeError: + pass + if missing: + io.add_array(total_coords, missing) + + if checkpoint is None: + checkpoint = NO_CHECKPOINT + if isinstance(checkpoint, Checkpoint): + active = checkpoint.active + checkpoint = active if active is not None else checkpoint.select(time=time) + + with checkpoint as ckpt: + restart_step = None + if ckpt.exists and ckpt.write_count > 0: + restart_step = ckpt.write_count - 1 + if restart_step >= nsteps: + logger.success("\nInference complete") + return io + + if hasattr(prognostic, "interp_method"): + interp_to = prognostic_ic + interp_method = prognostic.interp_method + else: + interp_to = None + interp_method = "nearest" + + x, coords = fetch_data( + source=data, + time=time, + variable=prognostic_ic["variable"], + lead_time=prognostic_ic["lead_time"], + device=device, + interp_to=interp_to, + interp_method=interp_method, + ) + logger.success(f"Fetched data from {data.__class__.__name__}") + + x, coords = map_coords(x, coords, prognostic_ic) + model = prognostic.create_iterator(x, coords) + + logger.info("Inference starting!") + initial_progress = 0 if restart_step is None else restart_step + 1 + with tqdm( + total=nsteps + 1, + initial=initial_progress, + desc="Running inference", + position=1, + disable=(not verbose), + ) as pbar: + for local_step, (x, coords) in enumerate(model): + step = ( + local_step + if restart_step is None + else restart_step + local_step + 1 + ) + + current_lead_time = coords["lead_time"][-1] + x, coords = map_coords(x, coords, diagnostic_ic) + x, coords = diagnostic(x, coords) + x, coords = map_coords(x, coords, output_coords) + io.write(*split_coords(x, coords)) + ckpt.write(lead_time=current_lead_time) + pbar.update(1) + if step == nsteps: + break + + ckpt.flush() logger.success("\nInference complete") return io @@ -290,6 +377,10 @@ def ensemble( output_coords: CoordSystem = OrderedDict({}), device: torch.device | None = None, verbose: bool = True, + checkpoint: Checkpoint + | CheckpointSession + | NullCheckpointSession + | None = NO_CHECKPOINT, ) -> IOBackend: """Built in ensemble workflow. @@ -318,6 +409,10 @@ def ensemble( Device to run inference on, by default None verbose : bool, optional Print inference progress, by default True + checkpoint : Checkpoint, optional + Checkpoint manager or checkpoint session used to record and resume workflow + progress, by default no checkpoint. When a checkpoint manager is provided, rows are tracked + independently for each ensemble batch. Returns ------- @@ -327,7 +422,6 @@ def ensemble( # sphinx - ensemble end logger.info("Running ensemble inference!") - # Load model onto the device device = ( device if device is not None @@ -336,7 +430,6 @@ def ensemble( logger.info(f"Inference device: {device}") prognostic = prognostic.to(device) - # Fetch data from data source and load onto device prognostic_ic = prognostic.input_coords() time = to_time_array(time) if hasattr(prognostic, "interp_method"): @@ -357,7 +450,6 @@ def ensemble( ) logger.success(f"Fetched data from {data.__class__.__name__}") - # Set up IO backend with information from output_coords (if applicable). total_coords = prognostic.output_coords(prognostic.input_coords()).copy() if "batch" in total_coords: del total_coords["batch"] @@ -375,9 +467,14 @@ def ensemble( for key, value in total_coords.items(): total_coords[key] = output_coords.get(key, value) variables_to_save = total_coords.pop("variable") - io.add_array(total_coords, variables_to_save) + missing = list(variables_to_save) + try: + missing = [name for name in variables_to_save if name not in io] + except TypeError: + pass + if missing: + io.add_array(total_coords, missing) - # Compute batch sizes if batch_size is None: batch_size = nensemble batch_size = min(nensemble, batch_size) @@ -387,7 +484,6 @@ def ensemble( f"Starting {nensemble} Member Ensemble Inference with \ {number_of_batches} number of batches." ) - batch_id = 0 for batch_id in tqdm( range(0, nensemble, batch_size), total=number_of_batches, @@ -395,46 +491,56 @@ def ensemble( position=2, disable=(not verbose), ): - - # Get fresh batch data - x = x0.to(device) - - # Expand x, coords for ensemble mini_batch_size = min(batch_size, nensemble - batch_id) - coords = ( - OrderedDict({"ensemble": np.arange(batch_id, batch_id + mini_batch_size)}) - | coords0.copy() - ) - - # Unsqueeze x for batching ensemble - x = x.unsqueeze(0).repeat(mini_batch_size, *([1] * x.ndim)) - - # Map lat and lon if needed - x, coords = map_coords(x, coords, prognostic_ic) - - # Perturb ensemble - x, coords = perturbation(x, coords) - - # Create prognostic iterator - model = prognostic.create_iterator(x, coords) - - with tqdm( - total=nsteps + 1, - desc=f"Running batch {batch_id} inference", - position=1, - leave=False, - disable=(not verbose), - ) as pbar: - for step, (x, coords) in enumerate(model): - # Subselect domain/variables as indicated in output_coords - x, coords = map_coords(x, coords, output_coords) - - io.write(*split_coords(x, coords)) - pbar.update(1) - if step == nsteps: - break - - batch_id += 1 + ensemble_coords = np.arange(batch_id, batch_id + mini_batch_size) + batch_checkpoint = NO_CHECKPOINT if checkpoint is None else checkpoint + if isinstance(batch_checkpoint, Checkpoint): + active = batch_checkpoint.active + batch_checkpoint = ( + active + if active is not None + else batch_checkpoint.select(time=time, ensemble_batch=batch_id) + ) + + with batch_checkpoint as ckpt: + restart_step = None + if ckpt.exists and ckpt.write_count > 0: + restart_step = ckpt.write_count - 1 + if restart_step >= nsteps: + continue + + x = x0.to(device) + coords = OrderedDict({"ensemble": ensemble_coords}) | coords0.copy() + x = x.unsqueeze(0).repeat(mini_batch_size, *([1] * x.ndim)) + x, coords = map_coords(x, coords, prognostic_ic) + x, coords = perturbation(x, coords) + + model = prognostic.create_iterator(x, coords) + initial_progress = 0 if restart_step is None else restart_step + 1 + with tqdm( + total=nsteps + 1, + initial=initial_progress, + desc=f"Running batch {batch_id} inference", + position=1, + leave=False, + disable=(not verbose), + ) as pbar: + for local_step, (x, coords) in enumerate(model): + step = ( + local_step + if restart_step is None + else restart_step + local_step + 1 + ) + + current_lead_time = coords["lead_time"][-1] + x, coords = map_coords(x, coords, output_coords) + io.write(*split_coords(x, coords)) + ckpt.write(lead_time=current_lead_time) + pbar.update(1) + if step == nsteps: + break + + ckpt.flush() logger.success("\nInference complete") return io diff --git a/earth2studio/utils/checkpoint.py b/earth2studio/utils/checkpoint.py new file mode 100644 index 000000000..8dd5344cf --- /dev/null +++ b/earth2studio/utils/checkpoint.py @@ -0,0 +1,1325 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import json +import os +import shutil +import uuid +import warnings +from collections.abc import Mapping +from contextvars import ContextVar, Token +from dataclasses import MISSING, dataclass, fields, is_dataclass +from datetime import date, datetime, timedelta, timezone +from hashlib import sha256 +from pathlib import Path +from typing import Any, Generic, Literal, TypeVar + +import numpy as np +import torch + +T = TypeVar("T") +CheckpointStatePolicy = Literal["minimal", "state", "full"] + +_CHECKPOINT_VERSION = 1 +_ACTIVE_SESSION: ContextVar[CheckpointSession | None] = ContextVar( + "earth2studio_checkpoint_session", default=None +) +_CURRENT_CHECKPOINT: ContextVar[Checkpoint | None] = ContextVar( + "earth2studio_checkpoint", default=None +) +_PENDING_STATES: ContextVar[tuple[PendingCheckpointState, ...]] = ContextVar( + "earth2studio_checkpoint_pending_states", default=() +) + + +class CheckpointError(RuntimeError): + """Base error for checkpoint failures.""" + + +class CheckpointStateCollision(CheckpointError): + """Raised when a state dataclass type is bound more than once.""" + + +class CheckpointSerializationError(CheckpointError): + """Raised when checkpoint state cannot be serialized without pickle.""" + + +class CheckpointStateSchemaError(CheckpointError): + """Raised when a saved dataclass payload does not match the current schema.""" + + +@dataclass(frozen=True) +class PendingCheckpointState: + """Dataclass state bound before a checkpoint session is active.""" + + checkpoint: Checkpoint + state_id: str + state: Any + reusable: bool = False + + +@dataclass(frozen=True) +class CheckpointEntry: + """A committed checkpoint catalog row.""" + + commit_id: str + labels: dict[str, Any] + lead_time: Any | None + write_count: int + saved_at: str + rank: int + world_size: int + state_ids: tuple[str, ...] + artifacts: tuple[str, ...] + + +class CheckpointState(Generic[T]): + """Bound checkpoint state proxy returned by :func:`bind_checkpoint_state`. + + The proxy forwards normal attribute access to the wrapped dataclass while + exposing checkpoint metadata through ``checkpoint_*`` properties. + """ + + __slots__ = ("_state", "_checkpoint", "_session", "_state_loaded") + + def __init__( + self, + state: T, + checkpoint: Checkpoint | None = None, + session: CheckpointSession | None = None, + state_loaded: bool = False, + ) -> None: + object.__setattr__(self, "_state", state) + object.__setattr__(self, "_checkpoint", checkpoint) + object.__setattr__(self, "_session", session) + object.__setattr__(self, "_state_loaded", state_loaded) + + @property + def checkpoint_dataclass(self) -> T: + """Wrapped dataclass instance serialized by the checkpoint.""" + return self._state + + @property + def checkpoint_enabled(self) -> bool: + """Whether this state is associated with a checkpoint.""" + return self._checkpoint is not None + + @property + def checkpoint_state_policy(self) -> CheckpointStatePolicy: + """Checkpoint state policy requested by the user.""" + if self._checkpoint is None: + return "minimal" + return self._checkpoint.state_policy + + @property + def checkpoint_device(self) -> torch.device: + """Device used for live checkpoint state tensors.""" + if self._checkpoint is None: + return torch.device("cpu") + return self._checkpoint.device + + @property + def device(self) -> torch.device: + """Alias for the checkpoint tensor state device.""" + return self.checkpoint_device + + @property + def checkpoint_flush_interval(self) -> int | None: + """Flush interval configured on the associated checkpoint.""" + if self._checkpoint is None: + return None + return self._checkpoint.flush_interval + + @property + def checkpoint_write_count(self) -> int: + """Number of write boundaries recorded in the active session.""" + if self._session is None: + return 0 + return self._session.write_count + + @property + def checkpoint_is_flush_due(self) -> bool: + """Whether the next checkpoint write is expected to flush to disk.""" + interval = self.checkpoint_flush_interval + if self._session is None or interval is None: + return False + return (self._session.write_count + 1) % interval == 0 + + @property + def checkpoint_selected(self) -> bool: + """Whether this state is bound to an existing checkpoint row.""" + return self._session is not None and self._session.exists + + @property + def checkpoint_state_loaded(self) -> bool: + """Whether this dataclass was hydrated from the selected checkpoint row.""" + return self._state_loaded + + @property + def checkpoint_lead_time(self) -> Any | None: + """Selected checkpoint lead time, if one exists.""" + if self._session is None: + return None + return self._session.lead_time + + @property + def checkpoint_labels(self) -> Mapping[str, Any]: + """Labels for the active checkpoint session or pending checkpoint.""" + if self._session is not None: + return self._session.labels + return {} + + def _bind_checkpoint( + self, + checkpoint: Checkpoint | None, + session: CheckpointSession | None = None, + state_loaded: bool = False, + ) -> None: + object.__setattr__(self, "_checkpoint", checkpoint) + object.__setattr__(self, "_session", session) + object.__setattr__(self, "_state_loaded", state_loaded) + + def __getattr__(self, name: str) -> Any: + return getattr(self._state, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in self.__slots__: + object.__setattr__(self, name, value) + return + if name == "device" or name.startswith("checkpoint_"): + raise AttributeError(f"{name!r} is checkpoint metadata and is read-only.") + setattr(self._state, name, value) + + def __repr__(self) -> str: + return repr(self._state) + + +class NullCheckpointSession: + """No-op checkpoint session used when checkpointing is disabled.""" + + exists = False + lead_time = None + device = torch.device("cpu") + checkpoint_device = torch.device("cpu") + + @property + def labels(self) -> Mapping[str, Any]: + """Labels for the no-op session.""" + return {} + + @property + def write_count(self) -> int: + """Number of checkpoint writes accepted by the no-op session.""" + return 0 + + @property + def is_active(self) -> bool: + """Whether this checkpoint session is active in the current context.""" + return False + + def write( + self, + lead_time: Any | None = None, + artifacts: Mapping[str, Any] | None = None, + ) -> None: + """Accept a checkpoint boundary without committing anything.""" + return None + + def flush( + self, + lead_time: Any | None = None, + artifacts: Mapping[str, Any] | None = None, + ) -> None: + """Accept a flush request without committing anything.""" + return None + + def __enter__(self) -> NullCheckpointSession: + return self + + def __exit__(self, *args: Any) -> None: + return None + + def __bool__(self) -> bool: + return False + + +NO_CHECKPOINT = NullCheckpointSession() + + +def default_checkpoint_path(name: str) -> Path: + """Return the default path for a named checkpoint store.""" + base = Path( + os.environ.get("EARTH2STUDIO_CACHE", Path.home() / ".cache" / "earth2studio") + ) + return base / "checkpoints" / name + + +def bind_checkpoint_state(state: T) -> CheckpointState[T]: + """Bind a dataclass instance to checkpoint state metadata. + + The returned proxy forwards normal dataclass field access and exposes + checkpoint metadata through ``checkpoint_*`` properties. When no checkpoint + session is active, state is buffered for the most recently instantiated + :class:`Checkpoint` in this context. + """ + if isinstance(state, CheckpointState): + bound_state = state + else: + if not is_dataclass(state) or isinstance(state, type): + raise TypeError("bind_checkpoint_state requires a dataclass instance.") + bound_state = CheckpointState(state) + + session = _ACTIVE_SESSION.get() + if session is not None: + return session.bind(bound_state) + + checkpoint = _CURRENT_CHECKPOINT.get() + if checkpoint is not None: + bound_state._bind_checkpoint(checkpoint) + return _buffer_pending_state(checkpoint, bound_state) + return bound_state + + +class Checkpoint: + """Catalog of restart checkpoints for a named inference run. + + Checkpoints store small restart metadata, optional artifacts, and dataclass state + bound by components through :func:`bind_checkpoint_state`. + """ + + def __init__( + self, + name: str, + path: str | Path | None = None, + mode: Literal["overwrite", "append"] = "overwrite", + flush_interval: int | None = 1, + history_size: int | None = None, + state_policy: CheckpointStatePolicy = "full", + rank: int | None = None, + world_size: int | None = None, + device: str | torch.device = torch.device("cpu"), + ) -> None: + if mode not in ("overwrite", "append"): + raise ValueError("mode must be 'overwrite' or 'append'.") + if flush_interval is not None and flush_interval < 1: + raise ValueError("flush_interval must be a positive integer or None.") + if history_size is not None and history_size < 1: + raise ValueError("history_size must be a positive integer or None.") + state_policy = _normalize_state_policy(state_policy) + + detected_rank, detected_world_size = _detect_distributed_rank() + self.name = name + self.path = Path(path) if path is not None else default_checkpoint_path(name) + self.mode = mode + self.flush_interval = flush_interval + self.history_size = 1 if mode == "overwrite" else history_size + self.state_policy = state_policy + self.device = torch.device(device) + self.rank = detected_rank if rank is None else rank + self.world_size = detected_world_size if world_size is None else world_size + self._catalog: list[CheckpointEntry] | None = None + self._context_sessions: list[CheckpointSession] = [] + _CURRENT_CHECKPOINT.set(self) + + @property + def rank_path(self) -> Path: + """Directory for this process checkpoint writes.""" + if self.world_size == 1: + return self.path + return self.path / f"rank_{self.rank:06d}" + + @property + def catalog(self) -> tuple[CheckpointEntry, ...]: + """Committed checkpoint entries for the current rank.""" + self.refresh() + return tuple(self._catalog or []) + + @property + def active(self) -> CheckpointSession | None: + """Active checkpoint selected from this catalog, if one is in scope.""" + selected = _ACTIVE_SESSION.get() + if selected is not None and selected.catalog is self: + return selected + return None + + def refresh(self) -> None: + """Refresh the checkpoint catalog from disk.""" + self._catalog = _read_catalog(self.rank_path) + + def select(self, row: int | None = None, **labels: Any) -> CheckpointSession: + """Select a checkpoint row or label set. + + A positional integer selects a catalog row, with negative indexing supported. + Keyword labels select the latest matching row and also define labels for + future writes. A keyword value of ``-1`` selects the latest saved value for + that label after all other labels are applied. + """ + self.refresh() + entries = list(self._catalog or []) + encoded_labels = _encode_labels(labels, entries) + + if encoded_labels: + entries = [ + entry + for entry in entries + if all( + entry.labels.get(key) == value + for key, value in encoded_labels.items() + ) + ] + + selected_entry: CheckpointEntry | None = None + if row is not None: + if not entries: + raise IndexError("No checkpoint entries match this selection.") + selected_entry = entries[row] + if not encoded_labels: + encoded_labels = selected_entry.labels.copy() + elif entries: + selected_entry = entries[-1] + + return CheckpointSession(self, encoded_labels, selected_entry) + + def __enter__(self) -> CheckpointSession: + active = self.active + session = active if active is not None else self.select() + self._context_sessions.append(session) + return session.__enter__() + + def __exit__(self, *args: Any) -> None: + if self._context_sessions: + self._context_sessions.pop().__exit__(*args) + + def __repr__(self) -> str: + entries = self.catalog + lines = [ + f'Checkpoint("{self.name}")', + f"path: {self.path}", + f"mode: {self.mode}", + f"state_policy: {self.state_policy}", + f"rank: {self.rank}/{self.world_size}", + ] + if not entries: + return "\n".join(lines + ["catalog: empty"]) + + label_names = sorted({key for entry in entries for key in entry.labels}) + columns = ["id", *label_names, "lead_time", "write_count", "saved_at"] + rows = [ + [ + str(index), + *[_display_value(entry.labels.get(name)) for name in label_names], + _display_value(entry.lead_time), + str(entry.write_count), + entry.saved_at, + ] + for index, entry in enumerate(entries) + ] + widths = [ + max(len(columns[index]), *(len(row[index]) for row in rows)) + for index in range(len(columns)) + ] + header = " ".join( + column.ljust(width) for column, width in zip(columns, widths) + ) + body = [ + " ".join(value.ljust(width) for value, width in zip(row, widths)) + for row in rows + ] + return "\n".join(lines + ["", header, *body]) + + def _commit( + self, + session: CheckpointSession, + lead_time: Any | None, + artifacts: Mapping[str, Any] | None, + ) -> CheckpointEntry: + self.rank_path.mkdir(parents=True, exist_ok=True) + commits_path = self.rank_path / "commits" + commits_path.mkdir(parents=True, exist_ok=True) + + commit_id = f"commit_{session.write_count:08d}_{uuid.uuid4().hex[:12]}" + tmp_path = self.rank_path / f".tmp_{commit_id}" + commit_path = commits_path / commit_id + if tmp_path.exists(): + shutil.rmtree(tmp_path) + tmp_path.mkdir(parents=True) + + manifest = { + "version": _CHECKPOINT_VERSION, + "checkpoint": self.name, + "commit_id": commit_id, + "mode": self.mode, + "rank": self.rank, + "world_size": self.world_size, + "labels": session.labels, + "lead_time": CheckpointCodec.encode_json_value( + CheckpointCodec.normalize_lead_time(lead_time) + ), + "write_count": session.write_count, + "saved_at": datetime.now(timezone.utc).isoformat(), + "states": {}, + "artifacts": {}, + } + + states_path = tmp_path / "states" + for state_id, state in session.bound_states.items(): + dataclass_state = _unwrap_checkpoint_state(state) + state_path = states_path / _safe_dir_name(state_id) + _write_dataclass_state(dataclass_state, state_id, state_path) + manifest["states"][state_id] = { + "path": str(state_path.relative_to(tmp_path)), + "schema_hash": _schema_hash(dataclass_state), + } + + if artifacts: + artifacts_path = tmp_path / "artifacts" + artifacts_path.mkdir(parents=True, exist_ok=True) + for name, value in artifacts.items(): + if not isinstance(name, str): + raise CheckpointSerializationError( + "checkpoint artifact names must be strings." + ) + manifest["artifacts"][name] = CheckpointCodec.dump_value( + value, artifacts_path, (_safe_dir_name(name),) + ) + + _write_json(tmp_path / "manifest.json", manifest) + tmp_path.rename(commit_path) + + entry = _entry_from_manifest(manifest) + self._update_catalog(entry) + session._entry = entry + session._loaded_states = _load_state_index(commit_path, manifest) + return entry + + def _update_catalog(self, entry: CheckpointEntry) -> None: + entries = ( + list(self._catalog) + if self._catalog is not None + else _read_catalog(self.rank_path) + ) + entries = [item for item in entries if item.commit_id != entry.commit_id] + if self.mode == "overwrite": + entries = [item for item in entries if item.labels != entry.labels] + entries.append(entry) + entries = _apply_history_size(entries, entry.labels, self.history_size) + _write_catalog(self.rank_path, entries) + self._catalog = entries + _prune_commits(self.rank_path, {item.commit_id for item in entries}) + + +class CheckpointSession: + """Active checkpoint row or future label set.""" + + def __init__( + self, + catalog: Checkpoint, + labels: dict[str, Any], + entry: CheckpointEntry | None, + ) -> None: + self.catalog = catalog + self.labels = labels + self._entry = entry + self.bound_states: dict[str, Any] = {} + self._reusable_state_ids: set[str] = set() + self.write_count = entry.write_count if entry is not None else 0 + self._pending_lead_time: Any | None = None + self._pending_artifacts: Mapping[str, Any] | None = None + self._pending_dirty = False + self._tokens: list[Token[CheckpointSession | None]] = [] + self._pending_adopted = False + self._loaded_states = self._load_selected_states() + + @property + def exists(self) -> bool: + """Whether this session resolves to an existing checkpoint row.""" + return self._entry is not None + + @property + def is_active(self) -> bool: + """Whether this checkpoint session is active in the current context.""" + return _ACTIVE_SESSION.get() is self + + @property + def commit_id(self) -> str | None: + """Selected commit identifier, if one exists.""" + return None if self._entry is None else self._entry.commit_id + + @property + def lead_time(self) -> Any | None: + """Lead time recorded for this session, if present.""" + return None if self._entry is None else self._entry.lead_time + + @property + def device(self) -> torch.device: + """Device used for live checkpoint state tensors.""" + return self.catalog.device + + @property + def artifacts(self) -> dict[str, Any]: + """Load artifacts recorded for the selected checkpoint row.""" + if self._entry is None: + return {} + manifest = self._read_manifest() + artifacts_path = self._commit_path / "artifacts" + return { + name: CheckpointCodec.load_value(payload, artifacts_path) + for name, payload in manifest.get("artifacts", {}).items() + } + + def artifact(self, name: str) -> Any: + """Load one artifact by name.""" + artifacts = self.artifacts + if name not in artifacts: + raise KeyError(f"Artifact {name!r} not found in checkpoint session.") + return artifacts[name] + + def bind(self, state: T | CheckpointState[T]) -> CheckpointState[T]: + """Bind and hydrate a dataclass state object.""" + bound_state = _as_checkpoint_state(state) + dataclass_state = bound_state.checkpoint_dataclass + state_id = _state_id(dataclass_state) + existing = self.bound_states.get(state_id) + if existing is not None: + if _unwrap_checkpoint_state(existing) is dataclass_state: + return existing + if state_id not in self._reusable_state_ids: + raise CheckpointStateCollision( + f"{state_id} was registered more than once in this checkpoint session." + ) + self._reusable_state_ids.remove(state_id) + + loaded_state = self._loaded_states.get(state_id) + if loaded_state is not None: + _populate_dataclass_state(dataclass_state, state_id, loaded_state) + + bound_state._bind_checkpoint(self.catalog, self, loaded_state is not None) + self.bound_states[state_id] = bound_state + return bound_state + + def _adopt_pending_states(self) -> None: + pending = _PENDING_STATES.get() + if not pending: + return + + adopted = [item for item in pending if item.checkpoint is self.catalog] + if not adopted: + return + + if self.exists and any(not item.reusable for item in adopted): + warnings.warn( + "Checkpoint state was bound before an existing checkpoint session " + "was active. Saved dataclass state is being hydrated late; " + "constructor side effects that depended on that state will not be " + "replayed. Construct restartable components inside " + "`with checkpoint.select(...):` when hydration must happen during " + "initialization.", + UserWarning, + stacklevel=3, + ) + + for item in adopted: + self.bind(item.state) + if item.reusable: + self._reusable_state_ids.add(item.state_id) + + _PENDING_STATES.set( + tuple(item for item in pending if item.checkpoint is not self.catalog) + ) + + def write( + self, + lead_time: Any | None = None, + artifacts: Mapping[str, Any] | None = None, + ) -> CheckpointEntry | None: + """Record a safe checkpoint boundary and flush if due.""" + self.write_count += 1 + self._pending_lead_time = CheckpointCodec.normalize_lead_time(lead_time) + self._pending_artifacts = artifacts + self._pending_dirty = True + interval = self.catalog.flush_interval + if interval is not None and self.write_count % interval == 0: + return self.flush() + return None + + def flush( + self, + lead_time: Any | None = None, + artifacts: Mapping[str, Any] | None = None, + ) -> CheckpointEntry | None: + """Force an atomic checkpoint commit for the current session.""" + has_updates = lead_time is not None or artifacts is not None + commit_lead_time = ( + self._pending_lead_time + if lead_time is None + else CheckpointCodec.normalize_lead_time(lead_time) + ) + commit_artifacts = self._pending_artifacts if artifacts is None else artifacts + if not self._pending_dirty and not has_updates: + return None + + entry = self.catalog._commit(self, commit_lead_time, commit_artifacts) + self._pending_lead_time = commit_lead_time + self._pending_artifacts = commit_artifacts + self._pending_dirty = False + return entry + + def __enter__(self) -> CheckpointSession: + if not self._pending_adopted: + self._adopt_pending_states() + self._pending_adopted = True + self._tokens.append(_ACTIVE_SESSION.set(self)) + return self + + def __exit__(self, *args: Any) -> None: + if self._tokens: + _ACTIVE_SESSION.reset(self._tokens.pop()) + for state in self.bound_states.values(): + state._bind_checkpoint(self.catalog) + _buffer_pending_state(self.catalog, state, reusable=True) + + def __bool__(self) -> bool: + return self.exists + + @property + def _commit_path(self) -> Path: + if self._entry is None: + raise CheckpointError("This checkpoint session does not exist.") + return self.catalog.rank_path / "commits" / self._entry.commit_id + + def _read_manifest(self) -> dict[str, Any]: + return _read_json(self._commit_path / "manifest.json") + + def _load_selected_states(self) -> dict[str, LoadedState]: + if self._entry is None: + return {} + manifest = self._read_manifest() + return _load_state_index(self._commit_path, manifest) + + +@dataclass(frozen=True) +class LoadedState: + """Serialized dataclass payload waiting to be bound.""" + + path: Path + manifest: dict[str, Any] + + +def _buffer_pending_state( + checkpoint: Checkpoint, state: CheckpointState[T], reusable: bool = False +) -> CheckpointState[T]: + state_id = _state_id(state.checkpoint_dataclass) + pending = _PENDING_STATES.get() + for index, item in enumerate(pending): + if item.checkpoint is not checkpoint or item.state_id != state_id: + continue + if _unwrap_checkpoint_state(item.state) is state.checkpoint_dataclass: + return item.state + if item.reusable: + updated = list(pending) + updated[index] = PendingCheckpointState( + checkpoint, state_id, state, reusable=reusable + ) + _PENDING_STATES.set(tuple(updated)) + return state + _PENDING_STATES.set( + (*pending, PendingCheckpointState(checkpoint, state_id, state, reusable)) + ) + return state + + +def _as_checkpoint_state(state: T | CheckpointState[T]) -> CheckpointState[T]: + if isinstance(state, CheckpointState): + return state + if not is_dataclass(state) or isinstance(state, type): + raise TypeError("checkpoint state requires a dataclass instance.") + return CheckpointState(state) + + +def _unwrap_checkpoint_state(state: Any) -> Any: + if isinstance(state, CheckpointState): + return state.checkpoint_dataclass + return state + + +def _detect_distributed_rank() -> tuple[int, int]: + try: + from physicsnemo.distributed import DistributedManager + + manager = DistributedManager() + rank = getattr(manager, "rank", None) + world_size = getattr(manager, "world_size", None) + if rank is not None and world_size is not None: + return int(rank), int(world_size) + except ImportError: + pass + except (AttributeError, RuntimeError, ValueError, TypeError): + pass + + for rank_name in ("RANK", "LOCAL_RANK", "SLURM_PROCID"): + rank = os.environ.get(rank_name) + if rank is not None: + world_size = os.environ.get("WORLD_SIZE", "1") + return int(rank), int(world_size) + return 0, 1 + + +def _encode_labels( + labels: Mapping[str, Any], entries: list[CheckpointEntry] +) -> dict[str, Any]: + encoded = { + key: CheckpointCodec.encode_json_value(value) for key, value in labels.items() + } + latest_keys = [key for key, value in labels.items() if _is_latest_selector(value)] + if not latest_keys: + return encoded + + filtered = entries + for key, value in encoded.items(): + if key not in latest_keys: + filtered = [entry for entry in filtered if entry.labels.get(key) == value] + if not filtered: + return encoded + latest = filtered[-1] + for key in latest_keys: + if key in latest.labels: + encoded[key] = latest.labels[key] + return encoded + + +def _read_catalog(rank_path: Path) -> list[CheckpointEntry]: + catalog_path = rank_path / "catalog.json" + if catalog_path.exists(): + try: + payload = _read_json(catalog_path) + return [_entry_from_catalog(item) for item in payload.get("entries", [])] + except (OSError, json.JSONDecodeError, KeyError, TypeError): + pass + return _scan_catalog(rank_path) + + +def _scan_catalog(rank_path: Path) -> list[CheckpointEntry]: + commits_path = rank_path / "commits" + if not commits_path.exists(): + return [] + entries: list[CheckpointEntry] = [] + for manifest_path in sorted(commits_path.glob("*/manifest.json")): + try: + entries.append(_entry_from_manifest(_read_json(manifest_path))) + except (OSError, json.JSONDecodeError, KeyError, TypeError): + continue + return sorted(entries, key=lambda entry: entry.saved_at) + + +def _write_catalog(rank_path: Path, entries: list[CheckpointEntry]) -> None: + payload = { + "version": _CHECKPOINT_VERSION, + "entries": [_entry_to_catalog(entry) for entry in entries], + } + _write_json(rank_path / "catalog.json", payload) + + +def _apply_history_size( + entries: list[CheckpointEntry], labels: dict[str, Any], history_size: int | None +) -> list[CheckpointEntry]: + if history_size is None: + return entries + matching = [entry for entry in entries if entry.labels == labels] + remove = {entry.commit_id for entry in matching[:-history_size]} + return [entry for entry in entries if entry.commit_id not in remove] + + +def _prune_commits(rank_path: Path, keep: set[str]) -> None: + commits_path = rank_path / "commits" + if not commits_path.exists(): + return + for commit_path in commits_path.iterdir(): + if commit_path.is_dir() and commit_path.name not in keep: + shutil.rmtree(commit_path) + for tmp_path in rank_path.glob(".tmp_*"): + if tmp_path.is_dir(): + shutil.rmtree(tmp_path) + + +def _entry_from_manifest(manifest: Mapping[str, Any]) -> CheckpointEntry: + return CheckpointEntry( + commit_id=str(manifest["commit_id"]), + labels=dict(manifest.get("labels", {})), + lead_time=CheckpointCodec.decode_json_value(manifest.get("lead_time")), + write_count=int(manifest.get("write_count", 0)), + saved_at=str(manifest["saved_at"]), + rank=int(manifest.get("rank", 0)), + world_size=int(manifest.get("world_size", 1)), + state_ids=tuple(manifest.get("states", {}).keys()), + artifacts=tuple(manifest.get("artifacts", {}).keys()), + ) + + +def _entry_to_catalog(entry: CheckpointEntry) -> dict[str, Any]: + return { + "commit_id": entry.commit_id, + "labels": entry.labels, + "lead_time": CheckpointCodec.encode_json_value(entry.lead_time), + "write_count": entry.write_count, + "saved_at": entry.saved_at, + "rank": entry.rank, + "world_size": entry.world_size, + "state_ids": list(entry.state_ids), + "artifacts": list(entry.artifacts), + } + + +def _entry_from_catalog(payload: Mapping[str, Any]) -> CheckpointEntry: + return CheckpointEntry( + commit_id=str(payload["commit_id"]), + labels=dict(payload.get("labels", {})), + lead_time=CheckpointCodec.decode_json_value(payload.get("lead_time")), + write_count=int(payload.get("write_count", 0)), + saved_at=str(payload["saved_at"]), + rank=int(payload.get("rank", 0)), + world_size=int(payload.get("world_size", 1)), + state_ids=tuple(payload.get("state_ids", ())), + artifacts=tuple(payload.get("artifacts", ())), + ) + + +def _write_dataclass_state(state: Any, state_id: str, state_path: Path) -> None: + state = _unwrap_checkpoint_state(state) + state_path.mkdir(parents=True, exist_ok=True) + manifest = { + "state_id": state_id, + "schema_hash": _schema_hash(state), + "fields": { + field.name: CheckpointCodec.dump_value( + getattr(state, field.name), state_path, (field.name,) + ) + for field in fields(state) + }, + } + _write_json(state_path / "metadata.json", manifest) + + +def _load_state_index( + commit_path: Path, manifest: Mapping[str, Any] +) -> dict[str, LoadedState]: + loaded: dict[str, LoadedState] = {} + for state_id, payload in manifest.get("states", {}).items(): + state_path = commit_path / payload["path"] + loaded[state_id] = LoadedState( + path=state_path, + manifest=_read_json(state_path / "metadata.json"), + ) + return loaded + + +def _populate_dataclass_state( + state: Any, state_id: str, loaded_state: LoadedState +) -> None: + state = _unwrap_checkpoint_state(state) + if loaded_state.manifest.get("state_id") != state_id: + raise CheckpointStateSchemaError( + f"Saved checkpoint state {loaded_state.manifest.get('state_id')} does not match {state_id}." + ) + expected_hash = _schema_hash(state) + if loaded_state.manifest.get("schema_hash") != expected_hash: + raise CheckpointStateSchemaError( + f"Saved checkpoint state {state_id} does not match the current dataclass schema." + ) + + current_fields = {field.name: field for field in fields(state)} + saved_fields = loaded_state.manifest.get("fields", {}) + if set(saved_fields) != set(current_fields): + raise CheckpointStateSchemaError( + f"Saved checkpoint state {state_id} fields do not match the current dataclass fields." + ) + for name, payload in saved_fields.items(): + current_value = getattr(state, name) + setattr( + state, + name, + CheckpointCodec.load_value(payload, loaded_state.path, current_value), + ) + + +def _normalize_state_policy( + policy: str, +) -> CheckpointStatePolicy: + if policy not in ("minimal", "state", "full"): + raise ValueError("state_policy must be 'minimal', 'state', or 'full'.") + return policy + + +def _state_id(state: Any) -> str: + state = _unwrap_checkpoint_state(state) + if not is_dataclass(state) or isinstance(state, type): + raise TypeError("Checkpoint state must be a dataclass instance.") + cls = type(state) + return f"{cls.__module__}.{cls.__qualname__}" + + +def _schema_hash(state: Any) -> str: + state = _unwrap_checkpoint_state(state) + schema = "|".join( + f"{field.name}:{field.type!r}:{_field_default_id(field)}" + for field in fields(state) + ) + return sha256(schema.encode("utf-8")).hexdigest() + + +def _field_default_id(field: Any) -> str: + if field.default is not MISSING: + default_type = type(field.default) + return f"default:{default_type.__module__}.{default_type.__qualname__}" + if field.default_factory is not MISSING: + factory = field.default_factory + module = getattr(factory, "__module__", type(factory).__module__) + qualname = getattr(factory, "__qualname__", type(factory).__qualname__) + return f"factory:{module}.{qualname}" + return "required" + + +class CheckpointCodec: + """Codec for pickle-free checkpoint metadata and state payloads.""" + + SCALAR_KINDS = frozenset(("bool", "int", "float", "str")) + JSON_SCALAR_TYPES = (bool, int, float, str) + + @classmethod + def dump_value( + cls, value: Any, base_path: Path, rel_parts: tuple[str, ...] + ) -> dict[str, Any]: + if value is None: + return {"kind": "none"} + if isinstance(value, bool): + return {"kind": "bool", "value": value} + if isinstance(value, int) and not isinstance(value, bool): + return {"kind": "int", "value": value} + if isinstance(value, float): + return {"kind": "float", "value": value} + if isinstance(value, str): + return {"kind": "str", "value": value} + if isinstance(value, datetime): + return {"kind": "datetime", "value": value.isoformat()} + if isinstance(value, date): + return {"kind": "date", "value": value.isoformat()} + if isinstance(value, timedelta): + return { + "kind": "timedelta", + "days": value.days, + "seconds": value.seconds, + "microseconds": value.microseconds, + } + if isinstance(value, np.datetime64): + return cls.encode_np_datetime(value) + if isinstance(value, np.timedelta64): + return cls.encode_np_timedelta(value) + if isinstance(value, torch.device): + return {"kind": "torch_device", "value": str(value)} + if isinstance(value, torch.dtype): + return {"kind": "torch_dtype", "value": str(value)} + if isinstance(value, np.dtype): + return {"kind": "np_dtype", "value": str(value)} + if isinstance(value, np.generic): + return cls.dump_value(value.item(), base_path, rel_parts) + if isinstance(value, torch.Tensor): + array = value.detach().cpu().numpy() + return cls.dump_array(array, "tensor", base_path, rel_parts) + if isinstance(value, np.ndarray): + return cls.dump_array(value, "ndarray", base_path, rel_parts) + if isinstance(value, list): + return { + "kind": "list", + "items": [ + cls.dump_value(item, base_path, (*rel_parts, str(index))) + for index, item in enumerate(value) + ], + } + if isinstance(value, tuple): + return { + "kind": "tuple", + "items": [ + cls.dump_value(item, base_path, (*rel_parts, str(index))) + for index, item in enumerate(value) + ], + } + if isinstance(value, dict): + payload = {} + for key, item in value.items(): + if not isinstance(key, str): + raise CheckpointSerializationError( + "checkpoint dictionaries must use string keys." + ) + payload[key] = cls.dump_value( + item, base_path, (*rel_parts, _safe_dir_name(key)) + ) + return {"kind": "dict", "items": payload} + if is_dataclass(value) and not isinstance(value, type): + return { + "kind": "dataclass", + "state_id": _state_id(value), + "schema_hash": _schema_hash(value), + "fields": { + field.name: cls.dump_value( + getattr(value, field.name), + base_path, + (*rel_parts, field.name), + ) + for field in fields(value) + }, + } + raise CheckpointSerializationError( + f"Unsupported checkpoint value {type(value).__module__}.{type(value).__qualname__}." + ) + + @staticmethod + def dump_array( + array: np.ndarray, + kind: Literal["tensor", "ndarray"], + base_path: Path, + rel_parts: tuple[str, ...], + ) -> dict[str, Any]: + if array.dtype == object: + raise CheckpointSerializationError( + "object dtype arrays cannot be checkpointed." + ) + rel_path = Path(*rel_parts).with_suffix(".npy") + full_path = base_path / rel_path + full_path.parent.mkdir(parents=True, exist_ok=True) + np.save(full_path, array, allow_pickle=False) + return { + "kind": kind, + "path": str(rel_path), + "dtype": str(array.dtype), + "shape": list(array.shape), + } + + @classmethod + def load_value( + cls, payload: Mapping[str, Any], base_path: Path, current_value: Any = None + ) -> Any: + kind = payload["kind"] + if kind == "none": + return None + if kind in cls.SCALAR_KINDS: + return payload["value"] + if kind == "datetime": + return datetime.fromisoformat(payload["value"]) + if kind == "date": + return date.fromisoformat(payload["value"]) + if kind == "timedelta": + return timedelta( + days=payload["days"], + seconds=payload["seconds"], + microseconds=payload["microseconds"], + ) + if kind == "np_datetime64": + return np.datetime64(payload["value"], payload["unit"]) + if kind == "np_timedelta64": + return np.timedelta64(payload["value"], payload["unit"]) + if kind == "torch_device": + return torch.device(payload["value"]) + if kind == "torch_dtype": + return getattr(torch, payload["value"].split(".")[-1]) + if kind == "np_dtype": + return np.dtype(payload["value"]) + if kind in ("tensor", "ndarray"): + array = np.load(base_path / payload["path"], allow_pickle=False) + if ( + list(array.shape) != payload["shape"] + or str(array.dtype) != payload["dtype"] + ): + raise CheckpointSerializationError( + "checkpoint array metadata does not match stored data." + ) + if kind == "tensor": + return torch.from_numpy(array) + return array + if kind == "list": + return [cls.load_value(item, base_path) for item in payload["items"]] + if kind == "tuple": + return tuple(cls.load_value(item, base_path) for item in payload["items"]) + if kind == "dict": + return { + key: cls.load_value(item, base_path) + for key, item in payload["items"].items() + } + if kind == "dataclass": + if is_dataclass(current_value) and not isinstance(current_value, type): + expected_hash = _schema_hash(current_value) + if payload.get("schema_hash") != expected_hash: + raise CheckpointStateSchemaError( + f"Saved nested state {payload.get('state_id')} does not match the current dataclass schema." + ) + for field in fields(current_value): + setattr( + current_value, + field.name, + cls.load_value( + payload["fields"][field.name], + base_path, + getattr(current_value, field.name), + ), + ) + return current_value + return { + key: cls.load_value(item, base_path) + for key, item in payload["fields"].items() + } + raise CheckpointSerializationError( + f"Unsupported checkpoint payload kind {kind!r}." + ) + + @classmethod + def encode_json_value(cls, value: Any) -> Any: + if value is None or isinstance(value, cls.JSON_SCALAR_TYPES): + return value + if isinstance(value, datetime): + return {"kind": "datetime", "value": value.isoformat()} + if isinstance(value, date): + return {"kind": "date", "value": value.isoformat()} + if isinstance(value, timedelta): + return { + "kind": "timedelta", + "days": value.days, + "seconds": value.seconds, + "microseconds": value.microseconds, + } + if isinstance(value, np.datetime64): + return cls.encode_np_datetime(value) + if isinstance(value, np.timedelta64): + return cls.encode_np_timedelta(value) + if isinstance(value, torch.device): + return {"kind": "torch_device", "value": str(value)} + if isinstance(value, torch.dtype): + return {"kind": "torch_dtype", "value": str(value)} + if isinstance(value, np.dtype): + return {"kind": "np_dtype", "value": str(value)} + if isinstance(value, np.generic): + return cls.encode_json_value(value.item()) + if isinstance(value, np.ndarray): + if value.dtype == object: + raise CheckpointSerializationError( + "object dtype arrays cannot be used as checkpoint labels." + ) + return { + "kind": "ndarray_label", + "dtype": str(value.dtype), + "values": [cls.encode_json_value(item) for item in value.reshape(-1)], + "shape": list(value.shape), + } + if isinstance(value, list | tuple): + return [cls.encode_json_value(item) for item in value] + if isinstance(value, dict): + encoded = {} + for key, item in value.items(): + if not isinstance(key, str): + raise CheckpointSerializationError( + "checkpoint label dictionaries must use string keys." + ) + encoded[key] = cls.encode_json_value(item) + return encoded + raise CheckpointSerializationError( + f"Unsupported checkpoint metadata value {type(value).__module__}.{type(value).__qualname__}." + ) + + @classmethod + def decode_json_value(cls, value: Any) -> Any: + if isinstance(value, list): + return [cls.decode_json_value(item) for item in value] + if not isinstance(value, dict) or "kind" not in value: + return value + kind = value["kind"] + if kind == "datetime": + return datetime.fromisoformat(value["value"]) + if kind == "date": + return date.fromisoformat(value["value"]) + if kind == "timedelta": + return timedelta( + days=value["days"], + seconds=value["seconds"], + microseconds=value["microseconds"], + ) + if kind == "np_datetime64": + return np.datetime64(value["value"], value["unit"]) + if kind == "np_timedelta64": + return np.timedelta64(value["value"], value["unit"]) + if kind == "torch_device": + return torch.device(value["value"]) + if kind == "torch_dtype": + return getattr(torch, value["value"].split(".")[-1]) + if kind == "np_dtype": + return np.dtype(value["value"]) + if kind == "ndarray_label": + decoded = [cls.decode_json_value(item) for item in value["values"]] + return np.asarray(decoded, dtype=np.dtype(value["dtype"])).reshape( + value["shape"] + ) + return value + + @staticmethod + def encode_np_datetime(value: np.datetime64) -> dict[str, Any]: + unit, _ = np.datetime_data(value.dtype) + return {"kind": "np_datetime64", "value": str(value), "unit": unit} + + @staticmethod + def encode_np_timedelta(value: np.timedelta64) -> dict[str, Any]: + unit, _ = np.datetime_data(value.dtype) + return { + "kind": "np_timedelta64", + "value": int(value.astype(f"timedelta64[{unit}]").astype("int64")), + "unit": unit, + } + + @staticmethod + def normalize_lead_time(value: Any | None) -> Any | None: + if isinstance(value, np.ndarray) and value.size == 1: + return value.reshape(-1)[0] + if isinstance(value, torch.Tensor) and value.numel() == 1: + return value.detach().cpu().reshape(-1)[0].item() + return value + + +def _is_latest_selector(value: Any) -> bool: + return isinstance(value, int) and not isinstance(value, bool) and value == -1 + + +def _safe_dir_name(name: str) -> str: + return sha256(name.encode("utf-8")).hexdigest()[:24] + + +def _write_json(path: Path, payload: Mapping[str, Any]) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + tmp_path = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp") + with tmp_path.open("w", encoding="utf-8") as handle: + json.dump(payload, handle, indent=2, sort_keys=True) + handle.write("\n") + handle.flush() + os.fsync(handle.fileno()) + tmp_path.replace(path) + + +def _read_json(path: Path) -> dict[str, Any]: + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) + + +def _display_value(value: Any) -> str: + value = CheckpointCodec.decode_json_value(value) + if isinstance(value, np.ndarray): + return np.array2string(value, separator=", ") + return "" if value is None else str(value) diff --git a/examples/01_getting_started/04_checkpoint_restart.py b/examples/01_getting_started/04_checkpoint_restart.py new file mode 100644 index 000000000..2910c5144 --- /dev/null +++ b/examples/01_getting_started/04_checkpoint_restart.py @@ -0,0 +1,269 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# %% +""" +Restarting a Deterministic Forecast +=================================== + +This example shows how to use :py:class:`earth2studio.utils.checkpoint.Checkpoint` +to restart a deterministic forecast after it stops partway through a run. + +The example uses :py:class:`earth2studio.data.Random` and +:py:class:`earth2studio.models.px.UCast`. To keep the example runnable without +downloading the full U-Cast package, the public U-Cast wrapper is paired with a +small zero-residual PyTorch core. The checkpointing mechanics are identical for +the packaged U-Cast model. + +In this example you will learn: + +- Creating a persistent checkpoint +- Running a forecast that stops before the requested final horizon +- Re-opening the IO backend and checkpoint +- Resuming the deterministic workflow from the latest completed lead time +""" +# /// script +# dependencies = [ +# "earth2studio @ git+https://github.com/NVIDIA/earth2studio.git", +# ] +# /// + +# %% +# Set Up +# ------ +# A restartable forecast needs two persistent locations: one for forecast fields +# and one for the checkpoint. The IO backend owns the forecast arrays. +# The checkpoint owns restart metadata plus any model state required to continue +# the rollout. Model weights and forecast fields are not copied into the +# checkpoint. +# +# .. warning:: +# +# Model checkpoint state is opt-in. Before relying on restartable inference, +# verify that the model you plan to use documents checkpoint support. If a +# model does not support checkpointing yet, open a feature request on the +# `Earth2Studio GitHub `_. + +# %% +import os +import shutil +from collections import OrderedDict +from pathlib import Path + +import numpy as np +import torch + +import earth2studio.run as run +from earth2studio.data import Random +from earth2studio.io import ZarrBackend +from earth2studio.models.px import UCast +from earth2studio.models.px.ucast import VARIABLES as UCAST_VARIABLES +from earth2studio.utils.checkpoint import Checkpoint +from earth2studio.utils.time import to_time_array + +os.makedirs("outputs", exist_ok=True) + +forecast_store = Path("outputs/04_checkpoint_restart.zarr") +checkpoint_store = Path("outputs/04_checkpoint_restart_checkpoint") + +for path in (forecast_store, checkpoint_store): + if path.exists(): + shutil.rmtree(path) + +# %% +# Build a small U-Cast forecast problem. The zero-residual core keeps the example +# fast while preserving U-Cast's normal input/output coordinates and restart +# behavior. +# +# Full checkpoint state can be staged on the same device used for inference. +# Setting ``device`` to the current CUDA device can reduce CPU/GPU transfers for +# restart tensors during a run. Set it to ``torch.device("cpu")`` for +# CPU-only development. + +# %% +compute_device = torch.device( + f"cuda:{torch.cuda.current_device()}" if torch.cuda.is_available() else "cpu" +) + + +class ZeroResidualUCastCore(torch.nn.Module): + def forward( + self, + inputs: torch.Tensor, + dynamical_condition: torch.Tensor | None = None, + static_condition: torch.Tensor | None = None, + ) -> torch.Tensor: + return torch.zeros( + inputs.shape[0], + len(UCAST_VARIABLES), + inputs.shape[-2], + inputs.shape[-1], + device=inputs.device, + dtype=inputs.dtype, + ) + + +def make_ucast_model() -> UCast: + n_variables = len(UCAST_VARIABLES) + return UCast( + model=ZeroResidualUCastCore(), + center=torch.zeros(n_variables), + scale=torch.ones(n_variables), + residual_scale=torch.ones(n_variables), + static_condition=torch.zeros(2, 121, 240), + sst_fill_value=0.0, + stochastic=False, + ) + + +domain_coords = OrderedDict( + { + "lat": np.linspace(90, -90, 121), + "lon": np.linspace(0, 360, 240, endpoint=False), + } +) +output_variables = np.array(["t2m", "u10m"]) +time = ["2024-01-01T00:00:00"] +final_nsteps = 3 +first_attempt_nsteps = 1 + + +# %% +# Preallocate the full output store. A real full-length deterministic run does +# this before the first model step. We do it explicitly here because this example +# simulates a mid-run stop by intentionally running only the first forecast step. +# The IO store writes only two variables, while the checkpoint keeps U-Cast's full +# restart state internally when ``state_policy="full"`` is used. + +# %% + + +def deterministic_output_coords(model, time, nsteps, variables): + input_coords = model.input_coords() + output_coords = model.output_coords(input_coords).copy() + for key, value in model.output_coords(input_coords).items(): + if value.shape == (0,): + del output_coords[key] + + output_coords["time"] = to_time_array(time) + output_coords["lead_time"] = np.asarray( + [model.output_coords(input_coords)["lead_time"] * i for i in range(nsteps + 1)] + ).flatten() + output_coords["variable"] = variables + output_coords.move_to_end("lead_time", last=False) + output_coords.move_to_end("time", last=False) + return output_coords + + +prealloc_model = make_ucast_model() + +io = ZarrBackend(str(forecast_store), backend_kwargs={"overwrite": True}) +coords = deterministic_output_coords( + prealloc_model, time, final_nsteps, output_variables +) +var_names = coords.pop("variable") +io.add_array(coords, var_names) + +# %% +# First Attempt +# ------------- +# Every restartable run should be performed inside a checkpoint context. On an +# empty checkpoint, ``with checkpoint`` opens a new session for future writes. +# Construct restart-aware components inside that context so their dataclass state +# binds to the active checkpoint session. The workflow records a checkpoint row +# after each successful IO write because ``flush_interval=1`` and +# ``mode="append"`` keeps each row in the printed checkpoint table. + +# %% +checkpoint = Checkpoint( + "restart-demo", + path=checkpoint_store, + mode="append", + flush_interval=1, + history_size=4, + state_policy="full", + device=compute_device, +) + +with checkpoint as ckpt: + data = Random(domain_coords=domain_coords) + model = make_ucast_model() + run.deterministic( + time=time, + nsteps=first_attempt_nsteps, + prognostic=model, + data=data, + io=io, + output_coords=OrderedDict({"variable": output_variables}), + device=compute_device, + verbose=False, + checkpoint=ckpt, + ) + +print("Checkpoint after the stopped run:") +print(checkpoint) + +# %% +# Resume +# ------ +# In a new process, re-open the same IO store and checkpoint. The printout above +# shows the available row ids. Select ``-1`` to resume from the latest row. +# +# The selected checkpoint session is used as a context manager so the chosen row +# is the active restart state while components are constructed and while the +# workflow runs. ``UCast`` hydrates its restart dataclass during construction. Its +# iterator consumes the selected checkpoint boundary internally and yields the +# next forecast state, while the workflow still fetches the normal initial +# condition and feeds it to the iterator. + +# %% +io = ZarrBackend(str(forecast_store)) +checkpoint = Checkpoint( + "restart-demo", + path=checkpoint_store, + mode="append", + history_size=4, + state_policy="full", + device=compute_device, +) + +with checkpoint.select(-1) as ckpt: + data = Random(domain_coords=domain_coords) + model = make_ucast_model() + run.deterministic( + time=time, + nsteps=final_nsteps, + prognostic=model, + data=data, + io=io, + output_coords=OrderedDict({"variable": output_variables}), + device=compute_device, + verbose=False, + checkpoint=ckpt, + ) + +print("Checkpoint after resume:") +print(checkpoint) +print(io.root.tree()) + +# %% +# The latest checkpoint row now points at the final completed lead time. If the +# second process stopped too, selecting ``-1`` again would continue from the new +# latest row. + +# %% +latest = checkpoint.select(-1) +print(f"Latest restart lead time: {latest.lead_time}") diff --git a/examples/01_getting_started/README.rst b/examples/01_getting_started/README.rst index 54e499c44..add18eb50 100644 --- a/examples/01_getting_started/README.rst +++ b/examples/01_getting_started/README.rst @@ -4,3 +4,5 @@ Getting Started --------------- Introductory examples demonstrating the core inference workflows in Earth2Studio. +This section includes deterministic, diagnostic, ensemble, and checkpoint restart +examples for basic forecast workflows. diff --git a/test/models/px/test_fcn3.py b/test/models/px/test_fcn3.py index b3c38b636..b25d5df11 100644 --- a/test/models/px/test_fcn3.py +++ b/test/models/px/test_fcn3.py @@ -24,10 +24,10 @@ from earth2studio.data import Random, fetch_data from earth2studio.models.px import FCN3 from earth2studio.utils import handshake_dim +from earth2studio.utils.checkpoint import Checkpoint class PhooFCN3Preprocessor(torch.nn.Module): - def __init__( self, ): @@ -67,6 +67,17 @@ def set_rng(self, reset: bool = True, seed: int = 333): return +class PhooRestartFCN3ModelWrapper(PhooFCN3ModelWrapper): + def forward(self, x, t, normalized_data: bool = False, replace_state: bool = False): + return x + self.model.preprocessor.state[0].to(x.device) + + +def _phoo_fcn3(model: torch.nn.Module, variables: np.ndarray) -> FCN3: + fcn3 = FCN3.__new__(FCN3) + FCN3.__init__.__wrapped__(fcn3, model, variables=variables) + return fcn3 + + @pytest.fixture(scope="function") def dummy_model(): preprocessor = PhooFCN3Preprocessor() @@ -167,6 +178,61 @@ def test_fcn3_iter(ensemble, device, dummy_model): break +def test_fcn3_checkpoint_state_round_trip_with_phoo_model(tmp_path): + variables = np.array(["u10m"]) + time = np.array([np.datetime64("1993-04-05T00:00")]) + checkpoint = Checkpoint( + "fcn3", path=tmp_path / "fcn3", flush_interval=2, state_policy="full" + ) + + torch.manual_seed(123) + model = PhooRestartFCN3ModelWrapper(PhooFCN3Model(PhooFCN3Preprocessor())) + p = _phoo_fcn3(model, variables=variables).to("cpu") + input_coords = p.input_coords() + coords = OrderedDict( + { + "time": time, + "lead_time": input_coords["lead_time"], + "variable": input_coords["variable"], + "lat": input_coords["lat"], + "lon": input_coords["lon"], + } + ) + x = torch.zeros( + len(coords["time"]), + len(coords["lead_time"]), + len(coords["variable"]), + len(coords["lat"]), + len(coords["lon"]), + ) + + with checkpoint.select(time=time) as ckpt: + iterator = p.create_iterator(x, coords) + _, initial_coords = next(iterator) + ckpt.write(lead_time=initial_coords["lead_time"]) + step1, step1_coords = next(iterator) + ckpt.write(lead_time=step1_coords["lead_time"]) + step2, step2_coords = next(iterator) + step3, step3_coords = next(iterator) + + torch.manual_seed(999) + with checkpoint.select(-1): + resumed_model = PhooRestartFCN3ModelWrapper( + PhooFCN3Model(PhooFCN3Preprocessor()) + ) + resumed = _phoo_fcn3(resumed_model, variables=variables).to("cpu") + assert resumed.checkpoint.checkpoint_state_loaded + + resumed_iterator = resumed.create_iterator(torch.full_like(x, -100.0), coords) + resumed_step2, resumed_step2_coords = next(resumed_iterator) + resumed_step3, resumed_step3_coords = next(resumed_iterator) + + assert torch.allclose(resumed_step2, step2) + assert torch.allclose(resumed_step3, step3) + assert resumed_step2_coords["lead_time"][0] == step2_coords["lead_time"][0] + assert resumed_step3_coords["lead_time"][0] == step3_coords["lead_time"][0] + + @pytest.mark.parametrize( "dc", [ diff --git a/test/models/px/test_persistence.py b/test/models/px/test_persistence.py index 67f7191b6..00aa96559 100644 --- a/test/models/px/test_persistence.py +++ b/test/models/px/test_persistence.py @@ -22,6 +22,7 @@ from earth2studio.data import Random, fetch_data from earth2studio.models.px import Persistence +from earth2studio.utils.checkpoint import Checkpoint @pytest.mark.parametrize( @@ -144,6 +145,37 @@ def test_persistence_iter(ensemble, variable, history, device): break +def test_persistence_checkpoint_state_round_trip(tmp_path): + variable = ["t2m", "tcwv"] + time = np.array([np.datetime64("1993-04-05T00:00")]) + domain_coords = OrderedDict({"lat": np.arange(2), "lon": np.arange(3)}) + lead_time = np.asarray([np.timedelta64(-6, "h"), np.timedelta64(0, "h")]) + data = Random(domain_coords) + x, coords = fetch_data(data, time, np.asarray(variable), lead_time, device="cpu") + checkpoint = Checkpoint( + "persistence", path=tmp_path, mode="append", state_policy="full" + ) + + with checkpoint as ckpt: + model = Persistence(variable, domain_coords, history=2) + iterator = model.create_iterator(x, coords) + _, initial_coords = next(iterator) + _, saved_coords = next(iterator) + assert initial_coords["lead_time"][0] == np.timedelta64(0, "h") + assert saved_coords["lead_time"][0] == np.timedelta64(6, "h") + ckpt.write(lead_time=saved_coords["lead_time"][-1]) + ckpt.flush() + + with checkpoint.select(-1): + model = Persistence(variable, domain_coords, history=2) + iterator = model.create_iterator(x, coords) + out, out_coords = next(iterator) + assert model.checkpoint.checkpoint_state_loaded + + assert out_coords["lead_time"][0] == np.timedelta64(12, "h") + assert torch.allclose(out, x[:, -1:]) + + @pytest.mark.parametrize( "dc", [ diff --git a/test/models/px/test_ucast.py b/test/models/px/test_ucast.py index f72e80449..beefb9c0b 100644 --- a/test/models/px/test_ucast.py +++ b/test/models/px/test_ucast.py @@ -25,6 +25,7 @@ from earth2studio.models.px import UCast from earth2studio.models.px.ucast import VARIABLES from earth2studio.utils import handshake_dim +from earth2studio.utils.checkpoint import Checkpoint class PhooUCastModel(torch.nn.Module): @@ -65,8 +66,7 @@ def forward( return out -@pytest.fixture(scope="function") -def ucast_model() -> UCast: +def _make_ucast_model(stochastic: bool = False) -> UCast: n_variables = len(VARIABLES) return UCast( model=PhooUCastModel(), @@ -75,10 +75,15 @@ def ucast_model() -> UCast: residual_scale=torch.ones(n_variables), static_condition=torch.zeros(2, 121, 240), sst_fill_value=0.0, - stochastic=False, + stochastic=stochastic, ) +@pytest.fixture(scope="function") +def ucast_model() -> UCast: + return _make_ucast_model() + + @pytest.fixture(scope="function") def model() -> UCast: package = UCast.load_default_package() @@ -200,6 +205,56 @@ def test_ucast_iter(ucast_model: UCast, ensemble: int, device: str) -> None: break +def test_ucast_checkpoint_state_round_trip(tmp_path) -> None: + time = np.array([np.datetime64("2020-01-01T00:00")]) + source_model = _make_ucast_model() + x, coords = _input(source_model, time) + x = x.unsqueeze(0) + coords.update({"ensemble": np.array([0])}) + coords.move_to_end("ensemble", last=False) + + checkpoint = Checkpoint( + "ucast", + path=tmp_path, + mode="append", + flush_interval=1, + state_policy="full", + ) + with checkpoint as ckpt: + model = _make_ucast_model() + iterator = model.create_iterator(x, coords) + initial, initial_coords = next(iterator) + first, first_coords = next(iterator) + ckpt.write(lead_time=first_coords["lead_time"][-1]) + ckpt.flush() + + np.testing.assert_array_equal( + initial_coords["lead_time"], np.array([np.timedelta64(0, "h")]) + ) + np.testing.assert_array_equal( + first_coords["lead_time"], np.array([np.timedelta64(12, "h")]) + ) + assert torch.allclose(initial, x[:, :, -1:]) + assert torch.allclose(first, x[:, :, -1:]) + + checkpoint = Checkpoint( + "ucast", + path=tmp_path, + mode="append", + flush_interval=1, + state_policy="full", + ) + with checkpoint.select(-1): + model = _make_ucast_model() + resumed, resumed_coords = next(model.create_iterator(x, coords)) + assert model.checkpoint.checkpoint_state_loaded + + np.testing.assert_array_equal( + resumed_coords["lead_time"], np.array([np.timedelta64(24, "h")]) + ) + assert torch.allclose(resumed, x[:, :, -1:]) + + def test_ucast_iter_uses_internal_normalized_state() -> None: n_variables = len(VARIABLES) ucast_model = UCast( diff --git a/test/perturbation/test_gaussian.py b/test/perturbation/test_gaussian.py index 04516d8c3..d8ce9a399 100644 --- a/test/perturbation/test_gaussian.py +++ b/test/perturbation/test_gaussian.py @@ -16,10 +16,12 @@ from collections import OrderedDict +import numpy as np import pytest import torch from earth2studio.perturbation import CorrelatedSphericalGaussian, Gaussian +from earth2studio.utils.checkpoint import Checkpoint @pytest.mark.parametrize( @@ -75,6 +77,51 @@ def test_gaussian(x, coords, amplitude, device): assert dx.device == x.device +def test_gaussian_checkpoint_state_round_trip(tmp_path): + x = torch.zeros(2, 3) + coords = OrderedDict([("batch", []), ("variable", [])]) + + replay_checkpoint = Checkpoint( + "gaussian-state", path=tmp_path / "state", state_policy="state" + ) + with replay_checkpoint.select(time="2024-01-01") as ckpt: + perturbation = Gaussian(1.0) + expected_replay, _ = perturbation(x, coords) + assert perturbation.checkpoint.generator_state is not None + ckpt.write(lead_time=np.timedelta64(0, "h")) + + with replay_checkpoint.select(-1): + perturbation = Gaussian(1.0) + replayed, _ = perturbation(x, coords) + assert perturbation.checkpoint.checkpoint_state_loaded + assert torch.allclose(replayed, expected_replay) + + direct_checkpoint = Checkpoint( + "gaussian-full", + path=tmp_path / "full", + flush_interval=2, + state_policy="full", + ) + with direct_checkpoint.select(time="2024-01-01") as ckpt: + perturbation = Gaussian(1.0) + perturbation(x, coords) + assert perturbation.checkpoint.generator_state is None + ckpt.write(lead_time=np.timedelta64(0, "h")) + perturbation(x, coords) + assert perturbation.checkpoint.generator_state is not None + ckpt.write(lead_time=np.timedelta64(6, "h")) + expected_direct_next, _ = perturbation(x, coords) + expected_direct_third, _ = perturbation(x, coords) + + with direct_checkpoint.select(-1): + perturbation = Gaussian(1.0) + resumed, _ = perturbation(x, coords) + assert perturbation.checkpoint.checkpoint_state_loaded + assert torch.allclose(resumed, expected_direct_next) + next_perturbed, _ = perturbation(x, coords) + assert torch.allclose(next_perturbed, expected_direct_third) + + def test_correlated_spherical_gaussian_no_amplitude(): """Test that CorrelatedSphericalGaussian raises error without amplitude""" with pytest.raises(ValueError): diff --git a/test/utils/test_checkpoint.py b/test/utils/test_checkpoint.py new file mode 100644 index 000000000..64beed479 --- /dev/null +++ b/test/utils/test_checkpoint.py @@ -0,0 +1,794 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from collections import OrderedDict +from dataclasses import dataclass, field +from datetime import date, datetime, timedelta, timezone + +import numpy as np +import pytest +import torch + +import earth2studio.run as run +from earth2studio.data import Random +from earth2studio.io import ZarrBackend +from earth2studio.models.dx import Identity +from earth2studio.models.px import Persistence +from earth2studio.perturbation import Zero +from earth2studio.utils.checkpoint import ( + NO_CHECKPOINT, + Checkpoint, + CheckpointError, + CheckpointSerializationError, + CheckpointState, + CheckpointStateCollision, + CheckpointStateSchemaError, + bind_checkpoint_state, + default_checkpoint_path, +) +from earth2studio.utils.time import to_time_array + + +@dataclass +class NestedState: + label: str = "inner" + count: int = 0 + + +@dataclass +class ToyState: + calls: int = 0 + rng: torch.Tensor | None = None + weights: np.ndarray = field( + default_factory=lambda: np.asarray([1.0, 2.0], dtype=np.float32) + ) + timestamp: np.datetime64 = np.datetime64("2026-06-08T00", "h") + delta: np.timedelta64 = np.timedelta64(0, "h") + tensor_device: torch.device = torch.device("cpu") + dtype: torch.dtype = torch.float32 + np_dtype: np.dtype = np.dtype("float32") + created: datetime = datetime(2024, 1, 1, tzinfo=timezone.utc) + day: date = date(2024, 1, 2) + window: timedelta = timedelta(hours=2) + items: list = field(default_factory=lambda: [True, np.float32(2.0)]) + pair: tuple = ("left", 1) + mapping: dict = field(default_factory=lambda: {"value": np.int64(3)}) + nested: NestedState = field(default_factory=NestedState) + loose: object | None = None + + +@dataclass +class OtherState: + value: int = 0 + + +@dataclass +class BadState: + value: object = object() + + +@dataclass +class RequiredState: + value: int + + +def test_checkpoint_contexts_and_no_checkpoint_session(tmp_path): + with NO_CHECKPOINT as ckpt: + assert not ckpt + assert not ckpt.exists + assert not ckpt.is_active + assert ckpt.write(lead_time=_lead_time(0)) is None + assert ckpt.flush() is None + + checkpoint = Checkpoint("forecast", path=tmp_path, mode="append", flush_interval=1) + with checkpoint.select(time="2024-01-01") as selected: + with checkpoint as active: + assert active is selected + assert active.write(lead_time=_lead_time(0)) is not None + assert active.flush() is None + + assert len(checkpoint.catalog) == 1 + + +def test_checkpoint_state_proxy_metadata_and_rebinding(tmp_path): + proxy = CheckpointState(ToyState()) + assert repr(proxy) == repr(proxy.checkpoint_dataclass) + assert proxy.checkpoint_state_policy == "minimal" + assert proxy.checkpoint_device == torch.device("cpu") + assert proxy.device == torch.device("cpu") + assert proxy.checkpoint_flush_interval is None + assert proxy.checkpoint_write_count == 0 + assert not proxy.checkpoint_is_flush_due + assert proxy.checkpoint_lead_time is None + assert proxy.checkpoint_labels == {} + assert not proxy.checkpoint_state_loaded + with pytest.raises(AttributeError): + proxy.device = torch.device("cpu") + assert NO_CHECKPOINT.labels == {} + assert NO_CHECKPOINT.write_count == 0 + assert not NO_CHECKPOINT.is_active + assert NO_CHECKPOINT.device == torch.device("cpu") + + checkpoint = Checkpoint( + "forecast", path=tmp_path, flush_interval=None, device="cpu" + ) + assert checkpoint.device == torch.device("cpu") + with checkpoint.select(time="2024-01-01") as ckpt: + rebound = bind_checkpoint_state(proxy) + assert rebound is proxy + assert ckpt.device == torch.device("cpu") + assert proxy.device == torch.device("cpu") + assert bind_checkpoint_state(proxy) is proxy + assert ckpt.is_active + assert proxy.checkpoint_labels == {"time": "2024-01-01"} + assert ckpt.artifacts == {} + with pytest.raises(TypeError): + ckpt.bind(object()) + assert ckpt.write(lead_time=torch.tensor([6])) is None + entry = ckpt.flush() + + assert entry is not None + assert entry.lead_time == 6 + + +class DroppingLeadTimeDiagnostic(torch.nn.Module): + def __init__(self, variables: list[str], domain_coords: OrderedDict): + super().__init__() + self.variables = np.asarray(variables) + self.domain_coords = domain_coords + + def input_coords(self): + coords = OrderedDict( + { + "batch": np.empty(0), + "time": np.empty(0), + "lead_time": np.empty(0), + "variable": self.variables, + } + ) + coords.update(self.domain_coords) + return coords + + def output_coords(self, input_coords): + output_coords = input_coords.copy() + output_coords.pop("lead_time", None) + return output_coords + + def __call__(self, x, coords): + output_coords = coords.copy() + lead_time_index = list(output_coords).index("lead_time") + output_coords.pop("lead_time") + return x.squeeze(lead_time_index), output_coords + + +class RecordingIO: + def __init__(self): + self.coords = None + self.array_names = None + self.writes = [] + + def add_array(self, coords, array_name, **kwargs): + self.coords = coords + self.array_names = array_name + + def write(self, x, coords, array_name): + self.writes.append((x, coords, array_name)) + + +def _lead_time(hours: int): + return np.timedelta64(hours, "h") + + +def test_bind_round_trip_hydrates_dataclass_and_catalog(tmp_path): + checkpoint = Checkpoint("forecast", path=tmp_path, mode="overwrite") + time = np.asarray([np.datetime64("2024-01-01T00")]) + + with checkpoint.select(time=time, ensemble=0) as ckpt: + state = bind_checkpoint_state(ToyState()) + state.calls = 7 + state.rng = torch.arange(4, dtype=torch.uint8) + state.weights = np.asarray([3.0, 4.0], dtype=np.float32) + state.delta = np.timedelta64(6, "h") + state.nested.count = 2 + state.loose = NestedState(count=5) + ckpt.write(lead_time=_lead_time(6), artifacts={"sample": 3}) + + checkpoint = Checkpoint("forecast", path=tmp_path) + with checkpoint.select(time=time, ensemble=0) as ckpt: + restored = bind_checkpoint_state(ToyState()) + + assert ckpt.exists + assert ckpt.lead_time == np.timedelta64(6, "h") + assert restored.calls == 7 + assert torch.equal(restored.rng, torch.arange(4, dtype=torch.uint8)) + assert np.array_equal( + restored.weights, np.asarray([3.0, 4.0], dtype=np.float32) + ) + assert restored.delta == np.timedelta64(6, "h") + assert restored.tensor_device == torch.device("cpu") + assert restored.dtype == torch.float32 + assert restored.np_dtype == np.dtype("float32") + assert restored.created == datetime(2024, 1, 1, tzinfo=timezone.utc) + assert restored.day == date(2024, 1, 2) + assert restored.window == timedelta(hours=2) + assert restored.items == [True, 2.0] + assert restored.pair == ("left", 1) + assert restored.mapping == {"value": 3} + assert restored.nested.count == 2 + assert restored.loose == {"label": "inner", "count": 5} + assert ckpt.commit_id is not None + assert ckpt.artifact("sample") == 3 + + text = repr(checkpoint) + assert 'Checkpoint("forecast")' in text + assert "ensemble" in text + assert "6 hours" in text + + +def test_duplicate_state_type_errors_but_different_selections_are_independent(tmp_path): + checkpoint = Checkpoint("forecast", path=tmp_path) + + with checkpoint.select(time="2024-01-01") as ckpt: + bind_checkpoint_state(ToyState()) + with pytest.raises(CheckpointStateCollision): + bind_checkpoint_state(ToyState()) + ckpt.flush(lead_time=_lead_time(0)) + + with checkpoint.select(time="2024-01-02") as ckpt: + state = bind_checkpoint_state(ToyState()) + state.calls = 2 + ckpt.flush(lead_time=_lead_time(6)) + + assert len(checkpoint.catalog) == 2 + + +def test_write_interval_overwrite_and_manual_flush_prune_old_commits(tmp_path): + checkpoint = Checkpoint( + "forecast", path=tmp_path, mode="overwrite", flush_interval=2 + ) + + with checkpoint.select(time="2024-01-01") as ckpt: + state = bind_checkpoint_state(ToyState()) + state.calls = 1 + assert ckpt.write(lead_time=_lead_time(6)) is None + + state.calls = 2 + first = ckpt.write(lead_time=_lead_time(12)) + assert first is not None + + state.calls = 3 + assert ckpt.write(lead_time=_lead_time(18)) is None + final = ckpt.flush() + + assert final.write_count == 3 + assert len(checkpoint.catalog) == 1 + commits = list((checkpoint.rank_path / "commits").iterdir()) + assert [commit.name for commit in commits] == [final.commit_id] + + with checkpoint.select(time="2024-01-01") as ckpt: + state = bind_checkpoint_state(ToyState()) + assert ckpt.lead_time == np.timedelta64(18, "h") + assert state.calls == 3 + + +def test_append_history_size_and_positional_selection(tmp_path): + checkpoint = Checkpoint( + "forecast", path=tmp_path, mode="append", flush_interval=1, history_size=2 + ) + + for hours in (6, 12, 18): + with checkpoint.select(time="2024-01-01", ensemble=0) as ckpt: + bind_checkpoint_state(ToyState()).calls = hours + ckpt.write(lead_time=_lead_time(hours)) + + assert len(checkpoint.catalog) == 2 + assert checkpoint.select(-1).lead_time == np.timedelta64(18, "h") + assert checkpoint.select(-2).lead_time == np.timedelta64(12, "h") + assert checkpoint.select(time=-1, ensemble=0).lead_time == np.timedelta64(18, "h") + + +def test_bind_before_new_session_is_adopted_on_enter(tmp_path): + checkpoint = Checkpoint( + "forecast", path=tmp_path, flush_interval=2, state_policy="state" + ) + + dataclass_state = ToyState() + state = bind_checkpoint_state(dataclass_state) + assert isinstance(state, CheckpointState) + assert state.checkpoint_dataclass is dataclass_state + assert bind_checkpoint_state(dataclass_state) is state + assert state.checkpoint_enabled + assert state.checkpoint_state_policy == "state" + assert state.checkpoint_flush_interval == 2 + with pytest.raises(AttributeError): + state.checkpoint_state_policy = "full" + state.calls = 5 + + with checkpoint.select(time="2024-01-01") as ckpt: + assert list(ckpt.bound_states.values()) == [state] + assert not state.checkpoint_selected + assert not state.checkpoint_state_loaded + assert state.checkpoint_write_count == 0 + assert not state.checkpoint_is_flush_due + ckpt.write(lead_time=_lead_time(3)) + assert state.checkpoint_write_count == 1 + assert state.checkpoint_is_flush_due + ckpt.flush(lead_time=_lead_time(6)) + + with Checkpoint("forecast", path=tmp_path).select(time="2024-01-01"): + restored = bind_checkpoint_state(ToyState()) + assert restored.checkpoint_selected + assert restored.checkpoint_state_loaded + assert restored.checkpoint_lead_time == np.timedelta64(6, "h") + + assert restored.calls == 5 + assert not restored.checkpoint_selected + assert not restored.checkpoint_state_loaded + + +def test_bind_before_existing_session_warns_and_hydrates_late(tmp_path): + checkpoint = Checkpoint("forecast", path=tmp_path) + with checkpoint.select(time="2024-01-01") as ckpt: + state = bind_checkpoint_state(ToyState()) + state.calls = 9 + ckpt.flush(lead_time=_lead_time(6)) + + checkpoint = Checkpoint("forecast", path=tmp_path) + state = bind_checkpoint_state(ToyState()) + assert state.calls == 0 + + with pytest.warns(UserWarning, match="bound before an existing checkpoint session"): + with checkpoint.select(time="2024-01-01") as ckpt: + assert ckpt.exists + assert state.checkpoint_state_loaded + assert state.calls == 9 + + +def test_defensive_paths_and_catalog_rebuild(tmp_path): + plain = ToyState() + plain_state = bind_checkpoint_state(plain) + assert isinstance(plain_state, CheckpointState) + assert plain_state.checkpoint_dataclass is plain + plain_state.calls = 1 + assert plain.calls == 1 + with pytest.raises(TypeError): + bind_checkpoint_state(object()) + + with pytest.raises(ValueError): + Checkpoint("bad", path=tmp_path, mode="bad") + with pytest.raises(ValueError): + Checkpoint("bad", path=tmp_path, flush_interval=0) + with pytest.raises(ValueError): + Checkpoint("bad", path=tmp_path, history_size=0) + with pytest.raises(ValueError): + Checkpoint("bad", path=tmp_path, state_policy="bad") + with pytest.raises(ValueError): + Checkpoint("legacy-replay", path=tmp_path, state_policy="replay") + with pytest.raises(ValueError): + Checkpoint("legacy-direct", path=tmp_path, state_policy="direct") + + checkpoint = Checkpoint("forecast", path=tmp_path / "catalog", mode="append") + assert "catalog: empty" in repr(checkpoint) + with pytest.raises(IndexError): + checkpoint.select(-1) + assert checkpoint.select(time="missing").artifacts == {} + assert not checkpoint.select(time="missing") + with pytest.raises(CheckpointError): + checkpoint.select(time="missing")._commit_path + + with pytest.raises(CheckpointSerializationError): + checkpoint.select(meta={1: 2}) + with pytest.raises(CheckpointSerializationError): + checkpoint.select(meta=object()) + with pytest.raises(CheckpointSerializationError): + checkpoint.select(meta=np.asarray([object()], dtype=object)) + + with checkpoint.select(time="2024-01-01") as ckpt: + with pytest.raises(CheckpointSerializationError): + ckpt.flush(artifacts={1: 2}) + with pytest.raises(CheckpointSerializationError): + ckpt.flush(artifacts={"bad": {1: 2}}) + bind_checkpoint_state(RequiredState(1)) + ckpt.flush(lead_time=torch.tensor([6])) + assert ckpt.flush() is None + ckpt.write(lead_time=torch.tensor([7])) + + meta_checkpoint = Checkpoint("metadata", path=tmp_path / "metadata") + with meta_checkpoint.select( + day=date(2024, 1, 1), + window=timedelta(hours=2), + device=torch.device("cpu"), + dtype=torch.float32, + np_dtype=np.dtype("float32"), + values=[np.int64(1)], + meta={"ok": np.float32(1.0)}, + ) as ckpt: + ckpt.flush(lead_time=np.asarray([1, 2])) + assert Checkpoint("metadata", path=tmp_path / "metadata").select(-1).exists + + assert len(checkpoint.catalog) == 2 + (checkpoint.rank_path / "catalog.json").write_text("{") + assert len(Checkpoint("forecast", path=tmp_path / "catalog").catalog) == 2 + + (checkpoint.rank_path / "catalog.json").unlink() + bad_commit = checkpoint.rank_path / "commits" / "bad" + bad_commit.mkdir() + (bad_commit / "manifest.json").write_text("{") + assert len(Checkpoint("forecast", path=tmp_path / "catalog").catalog) == 2 + + +def test_artifacts_round_trip_and_unsupported_objects_reject(tmp_path): + checkpoint = Checkpoint("forecast", path=tmp_path) + + with checkpoint.select(time="2024-01-01") as ckpt: + ckpt.write( + lead_time=_lead_time(6), + artifacts={ + "mask": torch.tensor([True, False]), + "scores": np.asarray([1.0, 2.0], dtype=np.float32), + "meta": {"name": "sample"}, + }, + ) + + selected = checkpoint.select(time="2024-01-01") + assert torch.equal(selected.artifact("mask"), torch.tensor([True, False])) + assert np.array_equal( + selected.artifact("scores"), np.asarray([1.0, 2.0], dtype=np.float32) + ) + assert selected.artifact("meta") == {"name": "sample"} + with pytest.raises(KeyError): + selected.artifact("missing") + + with checkpoint.select(time="bad") as ckpt: + bind_checkpoint_state(BadState()) + with pytest.raises(CheckpointSerializationError): + ckpt.flush(lead_time=_lead_time(0)) + + with checkpoint.select(time="bad-array") as ckpt: + with pytest.raises(CheckpointSerializationError): + ckpt.flush(artifacts={"bad": np.asarray([object()], dtype=object)}) + + +def test_schema_mismatch_errors_before_hydration(tmp_path): + checkpoint = Checkpoint("forecast", path=tmp_path) + + with checkpoint.select(time="2024-01-01") as ckpt: + state = bind_checkpoint_state(ToyState()) + state.calls = 4 + ckpt.flush(lead_time=_lead_time(6)) + + metadata_path = next( + (checkpoint.rank_path / "commits").glob("*/states/*/metadata.json") + ) + original_metadata = json.loads(metadata_path.read_text()) + + metadata = original_metadata.copy() + metadata["state_id"] = "bad" + metadata_path.write_text(json.dumps(metadata)) + with Checkpoint("forecast", path=tmp_path).select(time="2024-01-01"): + with pytest.raises(CheckpointStateSchemaError): + bind_checkpoint_state(ToyState()) + + metadata = original_metadata.copy() + metadata["fields"] = metadata["fields"].copy() + metadata["fields"].pop("calls") + metadata_path.write_text(json.dumps(metadata)) + with Checkpoint("forecast", path=tmp_path).select(time="2024-01-01"): + with pytest.raises(CheckpointStateSchemaError): + bind_checkpoint_state(ToyState()) + + metadata = original_metadata.copy() + metadata["schema_hash"] = "bad" + metadata_path.write_text(json.dumps(metadata)) + with Checkpoint("forecast", path=tmp_path).select(time="2024-01-01"): + with pytest.raises(CheckpointStateSchemaError): + bind_checkpoint_state(ToyState()) + + +def test_default_path_and_rank_directory_detection(tmp_path, monkeypatch): + monkeypatch.setenv("EARTH2STUDIO_CACHE", str(tmp_path)) + + assert default_checkpoint_path("forecast") == tmp_path / "checkpoints" / "forecast" + + serial = Checkpoint("serial", path=tmp_path / "serial", rank=0, world_size=1) + with serial.select(time="2024-01-01") as ckpt: + ckpt.flush(lead_time=_lead_time(0)) + + assert serial.rank_path == serial.path + assert (serial.path / "catalog.json").exists() + assert not (serial.path / "rank_000000").exists() + + monkeypatch.setenv("RANK", "2") + monkeypatch.setenv("WORLD_SIZE", "4") + checkpoint = Checkpoint("forecast") + with checkpoint.select(time="2024-01-01") as ckpt: + ckpt.flush(lead_time=_lead_time(0)) + + assert checkpoint.rank == 2 + assert checkpoint.world_size == 4 + assert checkpoint.rank_path.name == "rank_000002" + assert checkpoint.catalog[0].rank == 2 + + +def test_deterministic_workflow_records_checkpoint(tmp_path): + coords = OrderedDict([("lat", np.arange(2)), ("lon", np.arange(3))]) + variables = ["u10m", "v10m"] + data = Random(domain_coords=coords) + model = Persistence(variables, coords) + io = ZarrBackend() + checkpoint = Checkpoint( + "deterministic", path=tmp_path, mode="overwrite", flush_interval=2 + ) + + run.deterministic( + ["2024-01-01"], + 3, + model, + data, + io, + device=torch.device("cpu"), + verbose=False, + checkpoint=checkpoint, + ) + + selected = checkpoint.select(-1) + assert selected.exists + assert selected.lead_time == np.timedelta64(18, "h") + assert selected.write_count == 4 + + +def test_deterministic_workflow_resumes_from_checkpoint(tmp_path): + coords = OrderedDict([("lat", np.arange(2)), ("lon", np.arange(3))]) + variables = ["u10m", "v10m"] + io = ZarrBackend() + io.add_array( + OrderedDict( + { + "time": np.asarray(["2024-01-01T00"], dtype="datetime64[ns]"), + "lead_time": np.asarray([np.timedelta64(6 * i, "h") for i in range(4)]), + **coords, + } + ), + variables, + ) + checkpoint = Checkpoint( + "deterministic", path=tmp_path, mode="append", flush_interval=1 + ) + + data = Random(domain_coords=coords) + model = Persistence(variables, coords) + run.deterministic( + ["2024-01-01"], + 1, + model, + data, + io, + device=torch.device("cpu"), + verbose=False, + checkpoint=checkpoint, + ) + assert checkpoint.select(-1).lead_time == np.timedelta64(6, "h") + + with checkpoint.select(-1): + data = Random(domain_coords=coords) + model = Persistence(variables, coords) + run.deterministic( + ["2024-01-01"], + 3, + model, + data, + io, + device=torch.device("cpu"), + verbose=False, + checkpoint=checkpoint, + ) + + selected = checkpoint.select(-1) + assert selected.lead_time == np.timedelta64(18, "h") + assert selected.write_count == 4 + assert io["u10m"].shape[1] == 4 + + +def test_deterministic_workflow_uses_model_checkpoint_state_when_io_is_filtered( + tmp_path, +): + coords = OrderedDict([("lat", np.arange(2)), ("lon", np.arange(3))]) + variables = ["u10m", "v10m"] + output_coords = OrderedDict({"variable": np.asarray(["u10m"])}) + io = ZarrBackend() + io.add_array( + OrderedDict( + { + "time": np.asarray(["2024-01-01T00"], dtype="datetime64[ns]"), + "lead_time": np.asarray([np.timedelta64(6 * i, "h") for i in range(4)]), + **coords, + } + ), + ["u10m"], + ) + checkpoint = Checkpoint( + "deterministic", path=tmp_path, mode="append", flush_interval=1 + ) + + data = Random(domain_coords=coords) + model = Persistence(variables, coords) + run.deterministic( + ["2024-01-01"], + 1, + model, + data, + io, + output_coords=output_coords, + device=torch.device("cpu"), + verbose=False, + checkpoint=checkpoint, + ) + + with checkpoint.select(-1): + data = Random(domain_coords=coords) + model = Persistence(variables, coords) + run.deterministic( + ["2024-01-01"], + 3, + model, + data, + io, + output_coords=output_coords, + device=torch.device("cpu"), + verbose=False, + checkpoint=checkpoint, + ) + + selected = checkpoint.select(-1) + assert selected.lead_time == np.timedelta64(18, "h") + assert selected.write_count == 4 + assert io["u10m"].shape[1] == 4 + assert "v10m" not in io + + +def test_diagnostic_workflow_resumes_from_checkpoint(tmp_path): + coords = OrderedDict([("lat", np.arange(2)), ("lon", np.arange(3))]) + variables = ["u10m", "v10m"] + io = ZarrBackend() + io.add_array( + OrderedDict( + { + "time": np.asarray(["2024-01-01T00"], dtype="datetime64[ns]"), + "lead_time": np.asarray([np.timedelta64(6 * i, "h") for i in range(4)]), + **coords, + } + ), + variables, + ) + checkpoint = Checkpoint( + "diagnostic", path=tmp_path, mode="append", flush_interval=1 + ) + + with checkpoint.select(time=to_time_array(["2024-01-01"])) as ckpt: + data = Random(domain_coords=coords) + model = Persistence(variables, coords) + diagnostic = Identity() + run.diagnostic( + ["2024-01-01"], + 1, + model, + diagnostic, + data, + io, + device=torch.device("cpu"), + verbose=False, + checkpoint=ckpt, + ) + assert checkpoint.select(-1).lead_time == np.timedelta64(6, "h") + + with checkpoint.select(-1) as ckpt: + data = Random(domain_coords=coords) + model = Persistence(variables, coords) + diagnostic = Identity() + run.diagnostic( + ["2024-01-01"], + 3, + model, + diagnostic, + data, + io, + device=torch.device("cpu"), + verbose=False, + checkpoint=ckpt, + ) + + selected = checkpoint.select(-1) + assert selected.lead_time == np.timedelta64(18, "h") + assert selected.write_count == 4 + assert io["u10m"].shape[1] == 4 + + +def test_diagnostic_checkpoint_tracks_prognostic_lead_time(tmp_path): + coords = OrderedDict([("lat", np.arange(2)), ("lon", np.arange(3))]) + variables = ["u10m", "v10m"] + checkpoint = Checkpoint("diagnostic", path=tmp_path, flush_interval=1) + io = RecordingIO() + + run.diagnostic( + ["2024-01-01"], + 1, + Persistence(variables, coords), + DroppingLeadTimeDiagnostic(variables, coords), + Random(domain_coords=coords), + io, + device=torch.device("cpu"), + verbose=False, + checkpoint=checkpoint, + ) + + assert len(io.writes) == 2 + assert checkpoint.select(-1).lead_time == np.timedelta64(6, "h") + + +def test_ensemble_workflow_resumes_each_batch_from_checkpoint(tmp_path): + coords = OrderedDict([("lat", np.arange(2)), ("lon", np.arange(3))]) + variables = ["u10m", "v10m"] + nensemble = 2 + io = ZarrBackend() + io.add_array( + OrderedDict( + { + "ensemble": np.arange(nensemble), + "time": np.asarray(["2024-01-01T00"], dtype="datetime64[ns]"), + "lead_time": np.asarray([np.timedelta64(6 * i, "h") for i in range(4)]), + **coords, + } + ), + variables, + ) + checkpoint = Checkpoint("ensemble", path=tmp_path, mode="append", flush_interval=1) + + run.ensemble( + ["2024-01-01"], + 1, + nensemble, + Persistence(variables, coords), + Random(domain_coords=coords), + io, + Zero(), + batch_size=1, + device=torch.device("cpu"), + verbose=False, + checkpoint=checkpoint, + ) + + run.ensemble( + ["2024-01-01"], + 3, + nensemble, + Persistence(variables, coords), + Random(domain_coords=coords), + io, + Zero(), + batch_size=1, + device=torch.device("cpu"), + verbose=False, + checkpoint=checkpoint, + ) + + time = to_time_array(["2024-01-01"]) + for batch_id in range(nensemble): + selected = checkpoint.select(time=time, ensemble_batch=batch_id) + assert selected.lead_time == np.timedelta64(18, "h") + assert selected.write_count == 4 + assert len(checkpoint.catalog) == 8 + assert io["u10m"].shape[:3] == (nensemble, 1, 4)