Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
2cfe7c1
Add checkpoint utilities and workflow support
NickGeneva Jun 15, 2026
a14e263
Move restart example support into Persistence
NickGeneva Jun 15, 2026
2e06c13
Merge remote-tracking branch 'upstream/main' into codex/checkpoint-ca…
NickGeneva Jun 15, 2026
af5f39a
Add UCast checkpoint restart support
NickGeneva Jun 15, 2026
55c9ee8
Document model checkpoint opt-in warning
NickGeneva Jun 15, 2026
8e6acf6
Remove legacy checkpoint state policy aliases
NickGeneva Jun 15, 2026
31cda94
Rename checkpoint history limit parameter
NickGeneva Jun 15, 2026
7900ae5
Avoid rank folders for serial checkpoints
NickGeneva Jun 15, 2026
0aa55fc
Group checkpoint serialization helpers in codec
NickGeneva Jun 15, 2026
dcff29b
Use FCN for checkpoint restart support
NickGeneva Jun 15, 2026
3492e2c
Inline checkpoint helper functions
NickGeneva Jun 16, 2026
2864940
Use checkpoint context in docs examples
NickGeneva Jun 16, 2026
4977267
Simplify workflow checkpoint selection
NickGeneva Jun 16, 2026
2c0187f
Tighten checkpointing documentation
NickGeneva Jun 16, 2026
01ebee3
Feedback 1
NickGeneva Jun 24, 2026
b868a8b
Feedback 2
NickGeneva Jun 24, 2026
e464cd7
Feedback 2
NickGeneva Jun 24, 2026
622a22b
Feedback 2
NickGeneva Jun 24, 2026
af8635c
Feedback 2
NickGeneva Jun 24, 2026
c21fb44
Updates
NickGeneva Jun 24, 2026
a09ee39
Merge branch 'main' into codex/checkpoint-catalog
NickGeneva Jun 24, 2026
3ae26cc
Few fixes
NickGeneva Jun 24, 2026
444e4d5
Few fixes
NickGeneva Jun 24, 2026
be2197f
Few fixes
NickGeneva Jun 24, 2026
5dc41b7
Few fixes
NickGeneva Jun 24, 2026
1d55b1b
Convience
NickGeneva Jun 24, 2026
12f0257
Convience
NickGeneva Jun 24, 2026
6d68551
test fixes
NickGeneva Jun 24, 2026
5332852
Use numeric checkpoint levels
NickGeneva Jun 24, 2026
2d36f5b
Guard workflow checkpoint resumes by level
NickGeneva Jun 24, 2026
a4bdde3
Select ensemble checkpoints by member
NickGeneva Jun 24, 2026
7674409
Merge branch 'main' into codex/checkpoint-catalog
NickGeneva Jun 24, 2026
a63b416
Lint
NickGeneva Jun 24, 2026
5c44475
Fix checkpoint typing
NickGeneva Jun 24, 2026
2cbda59
Report checkpoint metadata serialization keys
NickGeneva Jun 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
through 2011-03-27 with cycles every 5 days, served from the NCEI HTTPS
archive
- Added IBTrACS tropical cyclone track DataFrame source (`IBTrACS`)
- Added checkpoint/session utilities and restart support for deterministic,
diagnostic, and ensemble inference workflows
- Added EUMETNET OPERA European weather radar composite DataSource for DBZH
reflectivity, rain rate, and 1-hour accumulation (`OPERA`)
- Added support for cumulative variables in ARCO data source
Expand Down
25 changes: 25 additions & 0 deletions docs/modules/utils_all.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,31 @@ The following functions can be used to convert to and from these numpy arrays.
utils.time.timearray_to_datetime
utils.time.to_time_array

.. _earth2studio.utils.checkpoint:

:mod:`earth2studio.utils`: Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Checkpoint utilities for restartable inference workflows. Checkpoints track
workflow progress, restart rows, optional artifacts, and component state needed
to resume long-running workflows while forecast fields remain in the selected
IO backend.

.. autosummary::
:toctree: generated/utils/
:template: class.rst

utils.checkpoint.Checkpoint
utils.checkpoint.CheckpointSession
utils.checkpoint.CheckpointState
utils.checkpoint.NullCheckpoint

.. autosummary::
:toctree: generated/utils/
:template: function.rst

utils.checkpoint.bind_checkpoint_state

.. _earth2studio.data.functions:

:mod:`earth2studio.data`: Data
Expand Down
3 changes: 1 addition & 2 deletions docs/userguide/about/install.md
Original file line number Diff line number Diff line change
Expand Up @@ -620,8 +620,7 @@ entries in the pipeline configuration must reference the full path to the
Docker container, the binaries are copied to `/usr/local/bin` and are therefore
available on the `PATH`; in that case only the executable names are needed
(e.g. `DetectNodes ...`). Examples for both commands are provided in the
docstring of the `TempestExtremes` class and in the
[TC tracking recipe](../../recipes/tc_tracking/README.md).
docstring of the `TempestExtremes` class and in the TC tracking recipe.

:::
::::
Expand Down
212 changes: 212 additions & 0 deletions docs/userguide/advanced/checkpointing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
# Checkpointing

Earth2Studio checkpoints enable users to restart inference workflows. A
checkpoint is a catalog of restart points for one named run. Each row in that
catalog is a checkpoint session: it records the completed lead time, optional
metadata, and opt-in component state(s) needed to restart the workflow.

Checkpoint storage is independent of the user-facing IO backend and independent
of any particular model implementation. Forecast fields remain in the IO backend
you choose, model weights are not copied into checkpoints, and the checkpoint
catalog stores only the restart metadata and component state requested by the
configured checkpoint level. Exactly what gets logged is user-configurable
through the checkpoint options and through the checkpoint support implemented
by each component.

```{warning}
Checkpointing support is opt-in for every component. Not all models,
perturbations, or custom components support checkpointed restarts. Always verify
that the model you plan to use has checkpoint support before relying on it
for restartable inference. If checkpointing is missing for a model you need, open
a [feature request](https://github.com/NVIDIA/earth2studio/issues).
```

## Basic Use

```python
from earth2studio.run import deterministic
from earth2studio.utils.checkpoint import Checkpoint

checkpoint = Checkpoint("my-forecast", flush_interval=6, level=2)

with checkpoint as ckpt:
deterministic(
time=["2024-01-01"],
nsteps=24,
prognostic=model,
data=data,
io=io,
checkpoint=ckpt,
)
```

Use checkpointed workflows inside a checkpoint context. This makes the active
session clear and lets restart-aware models or perturbations bind their state
before the run starts. Built-in workflows automatically interact with the
checkpoint session: they call `write` after successful IO writes and call
`flush` before returning. If no checkpoint is supplied, they use the
`NullCheckpoint` no-op checkpoint.

`Checkpoint` options:

- `flush_interval=1`: commit every workflow write.
- `flush_interval=None`: keep writes pending until `flush()`.
- `mode="overwrite"`: keep only the latest row.
- `mode="append"`: keep a row history; cap it with `history_size`.
- `device=torch.device("cpu")`: device used by components for staged tensor state.

Components can opt into checkpoint state when they need restart information
(RNG state, counters, tensors, etc.). The user chooses the requested level,
and each component decides what it supports:

- `0`: no component logging; workflows still record catalog progress and explicit metadata.
- `1`: enough component state to restart a workflow item such as an ensemble member.
- `2`: full component state for restarting inside a rollout when the component supports it.

Comment thread
pzharrington marked this conversation as resolved.
## Selecting Rows

Print a checkpoint to inspect its catalog, then select a row by integer index.
Negative indexing is supported.

```python
checkpoint = Checkpoint("my-forecast")

print(checkpoint)

latest = checkpoint.select(-1)
Comment thread
NickGeneva marked this conversation as resolved.
first = checkpoint.select(0)
```

By default, the context manager `with checkpoint` selects the latest catalog row
when one exists, or opens a new session when the catalog is empty. Use
`with checkpoint.select(-1):` when a specific saved row should be restored.

## Custom Loops

Call `write` after a safe restart boundary, usually after forecast fields have
been written to IO. Call `flush()` to force the latest pending write to disk.

```python
checkpoint = Checkpoint("forecast", mode="append", history_size=8)

with checkpoint as ckpt:
for lead_time in lead_times:
x, coords = step_model(...)
io.write(*split_coords(x, coords))
ckpt.write(
lead_time=lead_time,
last_complete_lead_time=lead_time,
)

ckpt.flush()
```

`write` and `flush` accept generic keyword metadata. Use metadata for small
restart properties; large forecast arrays should remain in the IO backend.

## Workflow Resume

Built-in workflows, like `run.deterministic`, always fetch the normal initial
condition and pass it to the prognostic iterator.
The checkpoint session tells the workflow what has been already completed.
For ensemble runs, the workflow records `completed_ensembles` metadata
and resumes from the first incomplete member.
Checkpoint-aware models decide whether to restore their own state and, if restored,
should yield the next forecast state rather than the existing saved boundary.

This decouples user-facing IO from what might be required to resume a run via a
checkpoint: users often save only a subset of generated forecast variables, but
continuing a rollout may require all fields and possibly internal model state.

## Component State

Models, perturbations, and custom components opt in by binding a dataclass.
Existing components do not need to change.

```python
from dataclasses import dataclass

import torch

from earth2studio.utils.checkpoint import bind_checkpoint_state


@dataclass
class NoiseState:
rng_state: torch.Tensor | None = None


class NoisePerturbation:
def __init__(self, generator: torch.Generator):
self.generator = generator
self.checkpoint = bind_checkpoint_state(NoiseState())

if self.checkpoint.rng_state is not None:
self.generator.set_state(self.checkpoint.rng_state)

def __call__(self, x):
y = add_noise(x, generator=self.generator)
if self.checkpoint.checkpoint_level >= 1:
self.checkpoint.rng_state = self.generator.get_state()
return y
```

`bind_checkpoint_state` returns a proxy around the dataclass. Normal dataclass
fields are accessed directly. Checkpoint metadata is available through read-only
properties such as `checkpoint_enabled`, `checkpoint_level`,
`checkpoint_state_loaded`, `checkpoint_metadata`, `checkpoint_lead_time`, and
`device`. Use `device` for staging tensor state, for example
`x.detach().clone().to(self.checkpoint.device)`.

Construct restart-aware components inside the checkpoint context when saved
state must be restored during initialization:

```python
checkpoint = Checkpoint("my-forecast")

with checkpoint.select(-1) as ckpt:
model = MyRestartableModel(...)
deterministic(..., checkpoint=ckpt)
```

If a component binds before an existing session is active, Earth2Studio restores
it when the session opens and warns because constructor side effects may already
Comment thread
NickGeneva marked this conversation as resolved.
have used default state.

State identity is the dataclass type's fully qualified module and class name.
Binding the same dataclass type twice in one session raises
`CheckpointStateCollision`; use distinct dataclass types for distinct components.

## Serialization

Checkpoint state is pickle-free. Supported values include JSON-like scalars and
containers, dataclasses, `datetime`, `date`, `timedelta`, `numpy.datetime64`,
`numpy.timedelta64`, NumPy scalars and dtypes, `torch.device`, `torch.dtype`,
`torch.Tensor`, and non-object `numpy.ndarray`. Tensors and arrays are stored as
separate `.npy` files with pickle disabled.

Unsupported values raise `CheckpointSerializationError`. Incompatible dataclass
schema changes raise `CheckpointStateSchemaError` during restore.

## Storage

Default location:

```text
$EARTH2STUDIO_CACHE/checkpoints/<name>
```

or, if `EARTH2STUDIO_CACHE` is unset:

```text
~/.cache/earth2studio/checkpoints/<name>
```

Pass `path=` to choose another location. Serial runs write directly into that
folder. Distributed runs write per-rank folders such as `rank_000000`; rank
detection checks PhysicsNeMo's distributed manager first, then common distributed
environment variables.

Checkpoint writes are staged in temporary directories and atomically moved into
place when complete. Catalog JSON writes use the same temporary-file pattern.
Incomplete temporary writes are ignored and cleaned up by later writes.
1 change: 1 addition & 0 deletions docs/userguide/advanced/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
```{toctree}
:maxdepth: 1

checkpointing
batch
auto
lexicon
Expand Down
1 change: 1 addition & 0 deletions docs/userguide/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ run(["2024-01-01"], 10, model, ds, io)

## Advanced Usage

- [Checkpointing](advanced/checkpointing)
- [Batch Dimension](advanced/batch)
- [AutoModels](advanced/auto)
- [Lexicon](advanced/lexicon)
Expand Down
4 changes: 3 additions & 1 deletion earth2studio/io/kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import torch
import xarray
from loguru import logger

from earth2studio.utils.coords import convert_multidim_to_singledim
from earth2studio.utils.type import CoordSystem
Expand Down Expand Up @@ -123,7 +124,8 @@ def add_array(

for name, di in zip(array_name, data):
if name in self.root:
raise AssertionError(f"Warning! {name} is already in KV Store.")
logger.warning("{} is already in KV Store. Skipping add_array.", name)
continue

self.dims[name] = list(adjusted_coords)

Expand Down
8 changes: 4 additions & 4 deletions earth2studio/io/netcdf4.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import torch
from cftime import date2num, num2date
from loguru import logger
from netCDF4 import Dataset, Variable

from earth2studio.utils.coords import convert_multidim_to_singledim
Expand Down Expand Up @@ -208,11 +209,10 @@ def add_array(

for name, di in zip(array_name, data):
if name in self.root.variables:
raise RuntimeError(
f"{name} is already in NetCDF Store. "
+ "NetCDF does not allow variables to be redefined. "
+ r"To overwrite entire NetCDF, create object with backend_kwargs=\{'mode': 'w'\}"
logger.warning(
"{} is already in NetCDF Store. Skipping add_array.", name
)
continue

di = di.cpu().numpy() if di is not None else None
dtype = di.dtype if di is not None else "float32"
Expand Down
6 changes: 5 additions & 1 deletion earth2studio/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import torch
import xarray as xr
from loguru import logger

from earth2studio.utils.coords import convert_multidim_to_singledim
from earth2studio.utils.type import CoordSystem
Expand Down Expand Up @@ -141,7 +142,10 @@ def add_array(

for name, di in zip(array_name, data):
if name in self.root:
raise AssertionError(f"Warning! {name} is already in xarray Dataset.")
logger.warning(
"{} is already in xarray Dataset. Skipping add_array.", name
)
continue

if di is not None:
self.root[name] = xr.DataArray(
Expand Down
8 changes: 3 additions & 5 deletions earth2studio/io/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import torch
import zarr
from loguru import logger
from zarr.core.array import Array as ZarrArray
from zarr.core.array import CompressorsLike

Expand Down Expand Up @@ -202,11 +203,8 @@ def add_array(

for name, di in zip(array_name, data):
if name in self.root and not kwargs.get("overwrite", False):
raise RuntimeError(
f"{name} is already in Zarr Store. "
+ "To overwrite Zarr array pass overwrite=True to this function or"
+ " backend_kwargs = {'overwrite': True} to the ZarrBackend constructor"
)
logger.warning("{} is already in Zarr Store. Skipping add_array.", name)
continue

di = di.cpu().numpy() if di is not None else None
dtype = di.dtype if di is not None else "float32"
Expand Down
Loading