Merged
Changes from 14 commits
35 commits
2cfe7c1
Add checkpoint utilities and workflow support
NickGeneva Jun 15, 2026
a14e263
Move restart example support into Persistence
NickGeneva Jun 15, 2026
2e06c13
Merge remote-tracking branch 'upstream/main' into codex/checkpoint-ca…
NickGeneva Jun 15, 2026
af5f39a
Add UCast checkpoint restart support
NickGeneva Jun 15, 2026
55c9ee8
Document model checkpoint opt-in warning
NickGeneva Jun 15, 2026
8e6acf6
Remove legacy checkpoint state policy aliases
NickGeneva Jun 15, 2026
31cda94
Rename checkpoint history limit parameter
NickGeneva Jun 15, 2026
7900ae5
Avoid rank folders for serial checkpoints
NickGeneva Jun 15, 2026
0aa55fc
Group checkpoint serialization helpers in codec
NickGeneva Jun 15, 2026
dcff29b
Use FCN for checkpoint restart support
NickGeneva Jun 15, 2026
3492e2c
Inline checkpoint helper functions
NickGeneva Jun 16, 2026
2864940
Use checkpoint context in docs examples
NickGeneva Jun 16, 2026
4977267
Simplify workflow checkpoint selection
NickGeneva Jun 16, 2026
2c0187f
Tighten checkpointing documentation
NickGeneva Jun 16, 2026
01ebee3
Feedback 1
NickGeneva Jun 24, 2026
b868a8b
Feedback 2
NickGeneva Jun 24, 2026
e464cd7
Feedback 2
NickGeneva Jun 24, 2026
622a22b
Feedback 2
NickGeneva Jun 24, 2026
af8635c
Feedback 2
NickGeneva Jun 24, 2026
c21fb44
Updates
NickGeneva Jun 24, 2026
a09ee39
Merge branch 'main' into codex/checkpoint-catalog
NickGeneva Jun 24, 2026
3ae26cc
Few fixes
NickGeneva Jun 24, 2026
444e4d5
Few fixes
NickGeneva Jun 24, 2026
be2197f
Few fixes
NickGeneva Jun 24, 2026
5dc41b7
Few fixes
NickGeneva Jun 24, 2026
1d55b1b
Convience
NickGeneva Jun 24, 2026
12f0257
Convience
NickGeneva Jun 24, 2026
6d68551
test fixes
NickGeneva Jun 24, 2026
5332852
Use numeric checkpoint levels
NickGeneva Jun 24, 2026
2d36f5b
Guard workflow checkpoint resumes by level
NickGeneva Jun 24, 2026
a4bdde3
Select ensemble checkpoints by member
NickGeneva Jun 24, 2026
7674409
Merge branch 'main' into codex/checkpoint-catalog
NickGeneva Jun 24, 2026
a63b416
Lint
NickGeneva Jun 24, 2026
5c44475
Fix checkpoint typing
NickGeneva Jun 24, 2026
2cbda59
Report checkpoint metadata serialization keys
NickGeneva Jun 24, 2026
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -24,12 +24,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
  through 2011-03-27 with cycles every 5 days, served from the NCEI HTTPS
  archive
- Added IBTrACS tropical cyclone track DataFrame source (`IBTrACS`)
- Added checkpoint/session utilities and restart support for deterministic,
  diagnostic, and ensemble inference workflows
- Added checkpoint state policy hints exposed through bound state proxies
- Added FCN/FourCastNet checkpoint state support for mid-rollout deterministic forecast restarts
- Added EUMETNET OPERA European weather radar composite DataSource for DBZH
  reflectivity, rain rate, and 1-hour accumulation (`OPERA`)
- Added support for cumulative variables in ARCO data source

### Changed

- Simplified workflow checkpoint handling with package-level no-op checkpoint
  sessions and idempotent final flushes
- Deterministic checkpoint resume now leaves restart-state restoration to the
  prognostic iterator after fetching the normal initial condition
- Renamed checkpoint state policies to `minimal`, `state`, and `full`
- UFS GSI observation sources (`UFSObsConv`, `UFSObsSat`) now fetch from S3 via native
  `obstore` instead of `s3fs` to avoid the Python-GIL bottleneck that caps fsspec's
  concurrent S3 read throughput (~22% faster obs fetch, ~20% HealDA e2e on B200; output unchanged).
24 changes: 24 additions & 0 deletions docs/modules/utils_all.rst
@@ -48,6 +48,30 @@ The following functions can be used to convert to and from these numpy arrays.
   utils.time.timearray_to_datetime
   utils.time.to_time_array

.. _earth2studio.utils.checkpoint:

:mod:`earth2studio.utils`: Checkpointing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Checkpoint utilities for restartable inference workflows. Checkpoints track
explicit workflow labels, optional artifacts, and dataclass state bound by
components during a selected checkpoint session.

.. autosummary::
   :toctree: generated/utils/
   :template: class.rst

   utils.checkpoint.Checkpoint
   utils.checkpoint.CheckpointSession
   utils.checkpoint.CheckpointState
   utils.checkpoint.NullCheckpointSession

.. autosummary::
   :toctree: generated/utils/
   :template: function.rst

   utils.checkpoint.bind_checkpoint_state

.. _earth2studio.data.functions:

:mod:`earth2studio.data`: Data
194 changes: 194 additions & 0 deletions docs/userguide/developer/checkpointing.md
@@ -0,0 +1,194 @@
# Checkpointing

Earth2Studio checkpoints are small restart catalogs for inference workflows. A
checkpoint row records progress labels, the latest completed lead time, optional
small artifacts, and dataclass state from components that opt in. Forecast fields
stay in the selected IO backend; model weights are not copied into checkpoints.

## Basic Use

```python
from earth2studio.run import deterministic
from earth2studio.utils.checkpoint import Checkpoint

checkpoint = Checkpoint("my-forecast", flush_interval=6, state_policy="full")

with checkpoint as ckpt:
    deterministic(
        time=["2024-01-01"],
        nsteps=24,
        prognostic=model,
        data=data,
        io=io,
        checkpoint=ckpt,
    )
```

Use checkpointed workflows inside a checkpoint context. This makes the active
session clear and lets restart-aware models or perturbations bind their state
before the run starts. Built-in workflows call `write` after successful IO writes
and call `flush` before returning. If no checkpoint is supplied, they use the
package no-op checkpoint session.

`Checkpoint` options:

- `flush_interval=1`: commit every workflow write.
- `flush_interval=None`: keep writes pending until `flush()`.
- `mode="overwrite"`: keep the latest row for a label set.
- `mode="append"`: keep a row history; cap it with `history_size`.
- `device=torch.device("cpu")`: device used by components for staged tensor state.
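
The interaction between `flush_interval` and `flush()` can be pictured with a toy buffer. This is a minimal sketch and not Earth2Studio's implementation: `ToyCheckpoint`, `pending`, and `committed` are hypothetical names, and the sketch assumes that only the latest pending write is committed on a flush.

```python
from __future__ import annotations


class ToyCheckpoint:
    """Hypothetical stand-in that only models write/flush bookkeeping."""

    def __init__(self, flush_interval: int | None = 1):
        self.flush_interval = flush_interval
        self.pending = None  # latest write that has not been committed
        self.committed = []  # lead times that reached "disk"
        self._writes_since_flush = 0

    def write(self, lead_time: int) -> None:
        self.pending = lead_time
        self._writes_since_flush += 1
        # flush_interval=None defers every commit to an explicit flush()
        if (
            self.flush_interval is not None
            and self._writes_since_flush >= self.flush_interval
        ):
            self.flush()

    def flush(self) -> None:
        # Idempotent: flushing with nothing pending commits nothing
        if self.pending is not None:
            self.committed.append(self.pending)
            self.pending = None
        self._writes_since_flush = 0


every_third = ToyCheckpoint(flush_interval=3)
for lead_time in range(1, 8):
    every_third.write(lead_time)
every_third.flush()  # final flush picks up the trailing write
```

With `flush_interval=3`, seven writes commit at the third and sixth write, and the final `flush()` commits the seventh.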

`state_policy` is a hint for opt-in components:

- `minimal`: catalog progress and explicit artifacts only.
- `state`: lightweight restart state such as RNG or counters.
- `full`: all supported restart state, including tensors needed to resume inside a rollout.
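
One way a component could honor the policy hint is sketched below. The `ToyState` dataclass, its fields, and `update_state` are hypothetical and exist only for illustration; real components read the hint from their bound state proxy.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class ToyState:
    """Hypothetical component state; field names are illustrative only."""

    step_count: int = 0
    rng_seed: int | None = None
    last_field: list[float] = field(default_factory=list)


def update_state(
    state: ToyState, policy: str, step: int, seed: int, x: list[float]
) -> None:
    """Store only what the policy hint asks for."""
    if policy == "minimal":
        return  # progress lives in the catalog; no component state is kept
    # "state" and "full" both keep lightweight counters and RNG information
    state.step_count = step
    state.rng_seed = seed
    if policy == "full":
        # Only "full" keeps the potentially large rollout data
        state.last_field = list(x)


states = {policy: ToyState() for policy in ("minimal", "state", "full")}
for policy, state in states.items():
    update_state(state, policy, step=4, seed=123, x=[0.1, 0.2])
```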

## Selecting Rows

Print a checkpoint to inspect its catalog, then select a row by index or labels.
Negative indexing is supported. Label value `-1` means the latest saved value for
that label after applying any other label filters.

```python
checkpoint = Checkpoint("my-forecast")

print(checkpoint)

latest = checkpoint.select(-1)
latest_time = checkpoint.select(time=-1)
member = checkpoint.select(time="2024-01-01T00:00:00", ensemble=0)
```
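
The `-1` label rule can be illustrated on a plain list of rows. This is a sketch under stated assumptions: `select_rows` is a hypothetical helper, rows are simple dictionaries, and "latest" is taken to mean the largest value of a label (ISO time strings sort chronologically), which may differ from how the real catalog orders rows.

```python
def select_rows(rows: list[dict], **labels) -> list[dict]:
    """Filter catalog rows by labels; -1 means 'latest value of this label'."""
    exact = {k: v for k, v in labels.items() if v != -1}
    latest = [k for k, v in labels.items() if v == -1]

    # Apply the exact-match filters first
    matches = [r for r in rows if all(r.get(k) == v for k, v in exact.items())]
    # Then resolve each -1 against the rows that survived the other filters
    for key in latest:
        if not matches:
            break
        newest = max(r[key] for r in matches)
        matches = [r for r in matches if r[key] == newest]
    return matches


rows = [
    {"time": "2024-01-01T00:00:00", "ensemble": 0},
    {"time": "2024-01-01T00:00:00", "ensemble": 1},
    {"time": "2024-01-02T00:00:00", "ensemble": 0},
]
latest_time = select_rows(rows, time=-1)
latest_for_member = select_rows(rows, time=-1, ensemble=1)
```

Here `time=-1` alone picks the 2024-01-02 row, while `time=-1, ensemble=1` picks the 2024-01-01 row because the `ensemble` filter is applied before the latest time is resolved.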

`with checkpoint` selects the latest catalog row when one exists, or opens a new
session when the catalog is empty. Use `with checkpoint.select(...):` when a
specific row or label set is required.

## Custom Loops

Call `write` after a safe restart boundary, usually after forecast fields have
been written to IO. Call `flush()` to force the latest pending write to disk.

```python
checkpoint = Checkpoint("ensemble-forecast", mode="append", history_size=8)

with checkpoint.select(time=time, ensemble=member) as ckpt:
    for lead_time in lead_times:
        x, coords = step_model(...)
        io.write(*split_coords(x, coords))
        ckpt.write(
            lead_time=lead_time,
            artifacts={"last_complete_lead_time": lead_time},
        )

    ckpt.flush()
```

`write` accepts only `lead_time` and `artifacts`. Artifacts are for small explicit
restart metadata; large forecast arrays should remain in the IO backend.

## Workflow Resume

Built-in workflows always fetch the normal initial condition and pass it to the
prognostic iterator. The checkpoint row tells the workflow which writes already
completed. Checkpoint-aware models decide whether to restore their own dataclass
state and, if restored, should yield the next forecast state rather than the
saved boundary.

Leaving restoration to the model avoids assuming that user-facing IO output is
restart-complete. This matters when IO stores only selected variables but a model
needs more internal state to continue a rollout. Ensemble workflows use the
`ensemble_batch` label so each
mini-batch can resume independently.
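
The "yield the next forecast state rather than the saved boundary" rule can be sketched with a toy iterator. The names `toy_rollout` and `take` are hypothetical, and the `+ 1.0` update stands in for a real model step.

```python
from __future__ import annotations

from collections.abc import Iterator


def toy_rollout(x0: float, saved: float | None = None) -> Iterator[float]:
    """Yield forecast states; skip the boundary state when restoring."""
    x = x0
    if saved is None:
        yield x  # fresh run: the first item is the initial condition
    else:
        x = saved  # restored run: continue from the saved boundary
    while True:
        x = x + 1.0  # stand-in for one model step
        yield x


def take(iterator: Iterator[float], n: int) -> list[float]:
    return [next(iterator) for _ in range(n)]


fresh = take(toy_rollout(0.0), 4)
resumed = take(toy_rollout(0.0, saved=2.0), 2)
```

The resumed iterator still receives the normal initial condition (`0.0`) but ignores it in favor of the saved state, so its first item is the step after the boundary and lines up with the remainder of the fresh run.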

## Component State

Models, perturbations, and custom components opt in by binding a dataclass.
Existing components do not need to change.

```python
from dataclasses import dataclass

import torch

from earth2studio.utils.checkpoint import bind_checkpoint_state


@dataclass
class NoiseState:
    rng_state: torch.Tensor | None = None


class NoisePerturbation:
    def __init__(self, generator: torch.Generator):
        self.generator = generator
        self.checkpoint = bind_checkpoint_state(NoiseState())

        if self.checkpoint.rng_state is not None:
            self.generator.set_state(self.checkpoint.rng_state)

    def __call__(self, x):
        y = add_noise(x, generator=self.generator)
        if self.checkpoint.checkpoint_enabled:
            self.checkpoint.rng_state = self.generator.get_state()
        return y
```

`bind_checkpoint_state` returns a proxy around the dataclass. Normal dataclass
fields are accessed directly. Checkpoint metadata is available through read-only
properties such as `checkpoint_enabled`, `checkpoint_state_policy`,
`checkpoint_state_loaded`, `checkpoint_lead_time`, `checkpoint_labels`, and
`device`. Use `device` for staging tensor state, for example
`x.detach().clone().to(self.checkpoint.device)`.
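
The proxy behavior described above can be approximated in a few lines. This is a hedged sketch and not the library's proxy: `ToyStateProxy` and `Counter` are hypothetical, and only one metadata property is modeled.

```python
from dataclasses import dataclass, fields


@dataclass
class Counter:
    count: int = 0


class ToyStateProxy:
    """Hypothetical proxy: forwards dataclass fields, exposes read-only metadata."""

    def __init__(self, state, enabled: bool):
        # Bypass our own __setattr__ while wiring up internals
        object.__setattr__(self, "_state", state)
        object.__setattr__(self, "_enabled", enabled)
        object.__setattr__(self, "_field_names", {f.name for f in fields(state)})

    @property
    def checkpoint_enabled(self) -> bool:
        return self._enabled

    def __getattr__(self, name):
        # Only called when normal lookup fails, so properties take precedence
        if name in self._field_names:
            return getattr(self._state, name)
        raise AttributeError(name)

    def __setattr__(self, name, value):
        if name in self._field_names:
            setattr(self._state, name, value)
        else:
            raise AttributeError(f"{name} is read-only or unknown")


proxy = ToyStateProxy(Counter(), enabled=True)
proxy.count += 1  # reads and writes go through to the dataclass
```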

Construct restart-aware components inside the checkpoint context when hydration
must happen during initialization:

```python
checkpoint = Checkpoint("my-forecast")

with checkpoint.select(-1) as ckpt:
    model = MyRestartableModel(...)
    deterministic(..., checkpoint=ckpt)
```

If a component binds before an existing session is active, Earth2Studio hydrates
it when the session opens and warns because constructor side effects may already
have used default state.

State identity is the dataclass type's fully qualified module and class name.
Binding the same dataclass type twice in one session raises
`CheckpointStateCollision`; use distinct dataclass types for distinct components.
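
A type-keyed registry makes the collision rule concrete. The sketch below is an assumption-laden illustration: `ToySession`, `ToyStateCollision`, and `state_key` are hypothetical, and using `__qualname__` for the class part of the key is a guess at what "fully qualified" means here.

```python
from dataclasses import dataclass


class ToyStateCollision(RuntimeError):
    """Hypothetical stand-in for CheckpointStateCollision."""


def state_key(state: object) -> str:
    cls = type(state)
    return f"{cls.__module__}.{cls.__qualname__}"


class ToySession:
    def __init__(self):
        self._bound: dict[str, object] = {}

    def bind(self, state: object) -> object:
        key = state_key(state)
        if key in self._bound:
            # Same dataclass type bound twice in one session
            raise ToyStateCollision(key)
        self._bound[key] = state
        return state


@dataclass
class ModelState:
    step: int = 0


@dataclass
class NoiseSeedState:
    seed: int = 0


session = ToySession()
session.bind(ModelState())
session.bind(NoiseSeedState())  # distinct type, so no collision
```

Binding a second `ModelState()` to the same session would raise `ToyStateCollision`, which is why two components that need separate state should each define their own dataclass type.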

## Serialization

Checkpoint state is pickle-free. Supported values include JSON-like scalars and
containers, dataclasses, `datetime`, `date`, `timedelta`, `numpy.datetime64`,
`numpy.timedelta64`, NumPy scalars and dtypes, `torch.device`, `torch.dtype`,
`torch.Tensor`, and non-object `numpy.ndarray`. Tensors and arrays are stored as
separate `.npy` files with pickle disabled.

Unsupported values raise `CheckpointSerializationError`. Incompatible dataclass
schema changes raise `CheckpointStateSchemaError` during hydration.
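
Pickle-free encoding generally means mapping each supported type to tagged JSON. The sketch below shows the idea for a few standard-library types only; the `__type__` tag, `encode`, `decode`, and `ToySerializationError` are hypothetical and say nothing about the on-disk format Earth2Studio actually uses.

```python
import json
from datetime import datetime, timedelta


class ToySerializationError(TypeError):
    """Hypothetical stand-in for CheckpointSerializationError."""


def encode(value):
    """Convert a value into JSON-safe data, tagging non-JSON types."""
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    if isinstance(value, datetime):
        return {"__type__": "datetime", "value": value.isoformat()}
    if isinstance(value, timedelta):
        return {"__type__": "timedelta", "seconds": value.total_seconds()}
    if isinstance(value, (list, tuple)):
        return [encode(v) for v in value]
    if isinstance(value, dict):
        return {str(k): encode(v) for k, v in value.items()}
    raise ToySerializationError(f"unsupported type: {type(value).__name__}")


def decode(data):
    """Invert encode(); tuples come back as lists in this sketch."""
    if isinstance(data, list):
        return [decode(v) for v in data]
    if isinstance(data, dict):
        if data.get("__type__") == "datetime":
            return datetime.fromisoformat(data["value"])
        if data.get("__type__") == "timedelta":
            return timedelta(seconds=data["seconds"])
        return {k: decode(v) for k, v in data.items()}
    return data


state = {"start": datetime(2024, 1, 1), "lead_time": timedelta(hours=6), "steps": [1, 2]}
restored = decode(json.loads(json.dumps(encode(state))))
```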

## Storage

Default location:

```text
$EARTH2STUDIO_CACHE/checkpoints/<name>
```

or, if `EARTH2STUDIO_CACHE` is unset:

```text
~/.cache/earth2studio/checkpoints/<name>
```

Pass `path=` to choose another location. Serial runs write directly into that
folder. Distributed runs write per-rank folders such as `rank_000000`; rank
detection checks PhysicsNeMo's distributed manager first, then common distributed
environment variables.
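
The folder layout above can be reproduced with a small path helper. This is a sketch, not the library's resolver: `default_checkpoint_dir` is a hypothetical name, the rank is passed in explicitly instead of being detected, and treating an empty `EARTH2STUDIO_CACHE` like an unset one is an assumption.

```python
from __future__ import annotations

import os
from pathlib import Path


def default_checkpoint_dir(name: str, rank: int | None = None) -> Path:
    """Resolve a checkpoint folder following the layout described above."""
    cache = os.environ.get("EARTH2STUDIO_CACHE")
    root = Path(cache) if cache else Path.home() / ".cache" / "earth2studio"
    path = root / "checkpoints" / name
    # Serial runs (rank=None) write directly into the checkpoint folder
    if rank is not None:
        path = path / f"rank_{rank:06d}"
    return path
```

For example, `default_checkpoint_dir("my-forecast", rank=3)` ends in `checkpoints/my-forecast/rank_000003`.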

Checkpoint writes are staged in temporary directories and atomically moved into
place when complete. Catalog JSON writes use the same temporary-file pattern.
Incomplete temporary writes are ignored and cleaned up by later writes.
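
The stage-then-rename pattern for a single JSON file looks roughly like the sketch below. `atomic_write_json` is a hypothetical helper built on the standard library, shown only to illustrate the pattern; the real code may stage whole directories and handle cleanup differently.

```python
import json
import os
import tempfile
from pathlib import Path


def atomic_write_json(path: Path, payload: dict) -> None:
    """Write JSON so readers see either the old file or the complete new one."""
    path.parent.mkdir(parents=True, exist_ok=True)
    # Stage in the destination directory so the rename stays on one filesystem
    fd, tmp_name = tempfile.mkstemp(dir=path.parent, prefix=path.name, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as handle:
            json.dump(payload, handle)
            handle.flush()
            os.fsync(handle.fileno())
        os.replace(tmp_name, path)  # rename over the target; atomic on POSIX
    except BaseException:
        # Leave no half-written staging file behind
        if os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
```

Because the final step is a rename, an interrupted write leaves at most a stray `.tmp` file next to an intact previous catalog.
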
1 change: 1 addition & 0 deletions docs/userguide/developer/index.md
@@ -11,4 +11,5 @@ testing
build
recipes
skills
checkpointing
```
1 change: 1 addition & 0 deletions docs/userguide/index.md
@@ -68,6 +68,7 @@ run(["2024-01-01"], 10, model, ds, io)
- [Testing](developer/testing)
- [Build](developer/build)
- [Recipes](developer/recipes)
- [Checkpointing](developer/checkpointing)

## Support

52 changes: 51 additions & 1 deletion earth2studio/models/px/fcn.py
@@ -16,6 +16,7 @@

 from collections import OrderedDict
 from collections.abc import Generator, Iterator
+from dataclasses import dataclass

 import numpy as np
 import torch
@@ -25,6 +26,7 @@
 from earth2studio.models.px.base import PrognosticModel
 from earth2studio.models.px.utils import PrognosticMixin
 from earth2studio.utils import handshake_coords, handshake_dim
+from earth2studio.utils.checkpoint import bind_checkpoint_state
 from earth2studio.utils.imports import (
     OptionalDependencyFailure,
     check_optional_dependencies,
@@ -67,6 +69,13 @@
 ]


+@dataclass
+class _FCNCheckpointState:
+    x: torch.Tensor | None = None
+    coord_keys: tuple[str, ...] = ()
+    coord_values: tuple[np.ndarray, ...] = ()
+
+
 @check_optional_dependencies()
 class FCN(torch.nn.Module, AutoModelMixin, PrognosticMixin):
     """FourCastNet global prognostic model. Consists of a single model with a time-step
@@ -105,6 +114,7 @@ def __init__(
         self.model = core_model
         self.register_buffer("center", center)
         self.register_buffer("scale", scale)
+        self.checkpoint = bind_checkpoint_state(_FCNCheckpointState())

     # sphinx - coords start
     def input_coords(self) -> CoordSystem:
@@ -175,6 +185,40 @@ def __str__(
     ) -> str:
         return "fcn"

+    def _restore_checkpoint_state(
+        self, x: torch.Tensor, coords: CoordSystem
+    ) -> tuple[torch.Tensor, CoordSystem, bool]:
+        if (
+            self.checkpoint.checkpoint_state_policy == "full"
+            and self.checkpoint.checkpoint_state_loaded
+            and self.checkpoint.x is not None
+            and self.checkpoint.coord_keys
+        ):
+            x = self.checkpoint.x.to(x.device)
+            coords = OrderedDict(
+                (key, np.asarray(value).copy())
+                for key, value in zip(
+                    self.checkpoint.coord_keys, self.checkpoint.coord_values
+                )
+            )
+            return x, coords, True
+        return x, coords, False
+
+    def _save_checkpoint_state(self, x: torch.Tensor, coords: CoordSystem) -> None:
+        if (
+            self.checkpoint.checkpoint_enabled
+            and self.checkpoint.checkpoint_state_policy == "full"
+        ):
+            self.checkpoint.x = x.detach().clone().to(self.checkpoint.device)
+            self.checkpoint.coord_keys = tuple(coords.keys())
+            self.checkpoint.coord_values = tuple(
+                np.asarray(value).copy() for value in coords.values()
+            )
+        else:
+            self.checkpoint.x = None
+            self.checkpoint.coord_keys = ()
+            self.checkpoint.coord_values = ()
+
     @classmethod
     def load_default_package(cls) -> Package:
         """Load prognostic package"""
@@ -235,9 +279,11 @@ def __call__(
         tuple[torch.Tensor, CoordSystem]
             Output tensor and coordinate system 6 hours in the future
         """
+        x, coords, _ = self._restore_checkpoint_state(x, coords)
         output_coords = self.output_coords(coords)

         x = self._forward(x)
+        self._save_checkpoint_state(x, output_coords)

         return x, output_coords

@@ -246,10 +292,13 @@ def _default_generator(
         self, x: torch.Tensor, coords: CoordSystem
     ) -> Generator[tuple[torch.Tensor, CoordSystem], None, None]:
         coords = coords.copy()
+        x, coords, restored = self._restore_checkpoint_state(x, coords)

         self.output_coords(coords)

-        yield x, coords
+        if not restored:
+            self._save_checkpoint_state(x, coords)
+            yield x, coords

         while True:
             # Front hook
@@ -261,6 +310,7 @@ def _default_generator(

             # Rear hook
             x, coords = self.rear_hook(x, coords)
+            self._save_checkpoint_state(x, coords)

             yield x, coords.copy()
