Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 148 additions & 27 deletions earth2studio/data/ufs.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@

from __future__ import annotations

import concurrent.futures
import hashlib
import math
import multiprocessing
import os
import pathlib
import shutil
Expand All @@ -27,10 +30,11 @@

import h5netcdf
import numpy as np
import obstore as obs
import pandas as pd
import pyarrow as pa
import s3fs
from loguru import logger
from obstore.store import S3Store
from tqdm.asyncio import tqdm

from earth2studio.data.utils import _sync_async, datasource_cache_root, prep_data_inputs
Expand All @@ -53,6 +57,37 @@ class _GSIAsyncTask:
satellite: str | None = None


# Transient context for the parallel-decode worker pool. Set just before the
# ProcessPoolExecutor runs and cleared after; forked workers inherit it via
# copy-on-write (which also carries the otherwise-unpicklable lexicon closures).
# The only key used is "args", holding the tuple
# ``(datasource, chunks, variables, schema)`` that ``_decode_chunk_idx`` unpacks.
_DECODE_CTX: dict = {}

# Cap for automatic decode-worker selection. Adding decode workers stops paying
# off once decode time drops below the network floor (~16 workers in practice),
# and more workers than files is wasted; this bounds the "auto" setting.
_DECODE_WORKERS_CAP = 16


def _cuda_initialized() -> bool:
    """Return True if PyTorch already holds a CUDA context in this process.

    Forking after CUDA init is unsafe, so parallel (fork-based) decode falls
    back to serial when this is True.

    PyTorch is looked up in ``sys.modules`` instead of being imported: if it
    has not been imported yet it cannot have initialized CUDA, and importing
    it here would only add a heavy import to processes that never use it.

    Returns
    -------
    bool
        True only when ``torch`` is already loaded and reports an initialized
        CUDA state; False otherwise, including when the probe itself fails.
    """
    import sys  # function-scope import keeps this helper self-contained

    torch = sys.modules.get("torch")
    if torch is None:
        # torch was never imported here, so it cannot own a CUDA context.
        return False
    try:
        return bool(torch.cuda.is_initialized())
    except Exception:
        # Best-effort probe: a broken or partially imported torch is treated
        # as "not initialized", matching the previous behaviour.
        return False


def _decode_chunk_idx(i: int) -> pd.DataFrame | None:
    """Decode the ``i``-th task chunk inside a forked worker process.

    The data source, the list of task chunks, the requested variables and the
    output schema are read from ``_DECODE_CTX["args"]``, so only the integer
    chunk index has to be sent to the worker.

    Parameters
    ----------
    i : int
        Index into the chunk list stored in the decode context

    Returns
    -------
    pd.DataFrame | None
        Whatever ``_compile_chunk`` returns for the selected chunk
    """
    datasource, task_chunks, variables, schema = _DECODE_CTX["args"]
    selected = task_chunks[i]
    return datasource._compile_chunk(selected, variables, schema)


class _UFSObsBase:
"""Base class for GSI data sources.

Expand All @@ -78,17 +113,26 @@ def __init__(
self._max_workers = max_workers
self.async_timeout = async_timeout
self._tmp_cache_hash: str | None = None
self.fs: s3fs.S3FileSystem | None = None
# Anonymous obstore S3 stores, cached per bucket (created lazily).
self._stores: dict[str, S3Store] = {}

lower, upper = normalize_time_tolerance(time_tolerance)
self._tolerance_lower = pd.to_timedelta(lower).to_pytimedelta()
self._tolerance_upper = pd.to_timedelta(upper).to_pytimedelta()

async def _async_init(self) -> None:
"""Async initialization of S3 filesystem"""
self.fs = s3fs.S3FileSystem(
anon=True, client_kwargs={}, asynchronous=True, skip_instance_cache=True
)
# NOAA UFS GEFSv13 replay archive is a public bucket in us-east-1.
_region = "us-east-1"

def _store(self, bucket: str) -> S3Store:
    """Return the anonymous obstore ``S3Store`` for ``bucket``, creating it once.

    Stores are kept in ``self._stores`` keyed by bucket name, so repeated
    requests for the same bucket reuse one client.

    Parameters
    ----------
    bucket : str
        Name of the S3 bucket

    Returns
    -------
    S3Store
        Cached store configured with ``self._region`` and unsigned requests
    """
    store = self._stores.get(bucket)
    if store is None:
        # First request for this bucket: build the store and remember it.
        store = S3Store(
            bucket,
            region=self._region,
            skip_signature=True,
            client_options={"pool_max_idle_per_host": str(self._max_workers)},
        )
        self._stores[bucket] = store
    return store

def __call__(
self,
Expand Down Expand Up @@ -124,11 +168,7 @@ async def fetch(
fields: str | list[str] | pa.Schema | None = None,
) -> pd.DataFrame:
"""Async function to get data."""
if self.fs is None:
await self._async_init()

session = await self.fs.set_session(refresh=True) # type: ignore[union-attr]

# obstore S3 stores are created lazily per bucket in _fetch_remote_file.
time_list, variable_list = prep_data_inputs(time, variable)
self._validate_time(time_list)
schema = self.resolve_fields(fields)
Expand All @@ -141,9 +181,6 @@ async def fetch(
*fetch_jobs, desc="Fetching GSI files", disable=(not self._verbose)
)

if session:
await session.close()

df = self._compile_dataframe(async_tasks, variable_list, schema)

return df
Expand Down Expand Up @@ -171,32 +208,114 @@ async def _fetch_remote_file(
byte_length : int | None, optional
Number of bytes to read, by default None (read all)
"""
if self.fs is None:
raise ValueError("File system is not initialized")

cache_path = self.cache_path(path, byte_offset, byte_length)
if not pathlib.Path(cache_path).is_file():
if pathlib.Path(cache_path).is_file():
return

# path is an S3 URI ("bucket/key" or "s3://bucket/key"); split into
# bucket + key for the per-bucket obstore store.
key = path[5:] if path.startswith("s3://") else path
bucket, _, object_key = key.partition("/")
store = self._store(bucket)
try:
if byte_length:
byte_length = int(byte_offset + byte_length)
try:
data = await self.fs._cat_file(path, start=byte_offset, end=byte_length)
with open(cache_path, "wb") as file:
file.write(data)
except FileNotFoundError:
self._handle_missing_file(path)
result = await obs.get_range_async(
store, object_key, start=byte_offset, end=byte_offset + byte_length
)
else:
response = await obs.get_async(store, object_key)
result = await response.bytes_async()
data = result.to_bytes() if hasattr(result, "to_bytes") else bytes(result)
with open(cache_path, "wb") as file:
file.write(data)
except (FileNotFoundError, obs.exceptions.NotFoundError):
self._handle_missing_file(path)
except Exception as err:
raise
Comment on lines +231 to +234

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The except Exception as err: raise block does nothing — it catches the exception and immediately re-raises it with no logging, wrapping, or side-effects. Remove it so unhandled exceptions propagate naturally from the try block.

Suggested change
except (FileNotFoundError, obs.exceptions.NotFoundError):
self._handle_missing_file(path)
except Exception as err:
raise
except (FileNotFoundError, obs.exceptions.NotFoundError):
self._handle_missing_file(path)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!


def _handle_missing_file(self, path: str) -> None:
    """Log and raise for a remote file that could not be found.

    Subclasses can override this to change how missing files are handled.

    Parameters
    ----------
    path : str
        Remote path of the missing file

    Raises
    ------
    FileNotFoundError
        Always, with the same message that is logged
    """
    message = f"File {path} not found"
    logger.error(message)
    raise FileNotFoundError(message)

def _resolve_decode_workers(self, n_tasks: int) -> int:
    """Automatically choose the NetCDF->DataFrame decode worker count.

    Picks ``min(available_cpus, cap, n_tasks)`` -- decode throughput plateaus
    once it drops below the network floor (cap ``_DECODE_WORKERS_CAP``) and
    never needs more workers than files.

    Safety guard: parallel decode uses the ``fork`` start method (so workers
    inherit the unpicklable GSI lexicon closures via copy-on-write). Serial
    decode (a single worker) is selected whenever forking is not safe:

    - ``fork`` is not an available start method (e.g. Windows);
    - the platform is macOS, where ``fork`` is still *listed* as available
      but is considered unsafe, so an availability check alone does not
      catch it;
    - a CUDA context already exists in this process.

    Parameters
    ----------
    n_tasks : int
        Number of files (tasks) that need decoding

    Returns
    -------
    int
        Number of decode workers to use, always >= 1
    """
    import sys  # local import: only needed for the platform guard below

    if n_tasks <= 1:
        # Zero or one file can never use more than one worker.
        return 1

    try:
        avail = len(os.sched_getaffinity(0))  # CPUs available to process
    except AttributeError:  # not available on this platform
        avail = os.cpu_count() or 1
    workers = max(1, min(_DECODE_WORKERS_CAP, avail, n_tasks))
    if workers <= 1:
        return 1

    # Parallel decode requires a *safe* 'fork' start method. macOS reports
    # 'fork' in get_all_start_methods(), so it needs an explicit platform
    # check. Forking is also unsafe once CUDA is initialized.
    fork_is_safe = (
        sys.platform != "darwin"
        and "fork" in multiprocessing.get_all_start_methods()
        and not _cuda_initialized()
    )
    return workers if fork_is_safe else 1

def _compile_dataframe(
    self,
    async_tasks: list[_GSIAsyncTask],
    variables: list[str],
    schema: pa.Schema,
) -> pd.DataFrame:
    """Compile fetched GSI files into a DataFrame.

    Each file's HDF5->pandas decode is CPU- and GIL-bound, so the files are
    decoded across forked worker processes when more than one worker is
    selected (the count is chosen automatically; see
    :meth:`_resolve_decode_workers`). With a single worker the tasks are
    decoded serially in this process.

    Parameters
    ----------
    async_tasks : list[_GSIAsyncTask]
        Fetch tasks describing the files to decode
    variables : list[str]
        Variables to extract
    schema : pa.Schema
        Output schema; its field order decides the column order

    Returns
    -------
    pd.DataFrame
        Decoded observations, or an empty DataFrame when nothing was decoded
    """
    n_workers = self._resolve_decode_workers(len(async_tasks))
    if n_workers <= 1:
        # Serial path: decode everything as one chunk in this process.
        decoded = self._compile_chunk(async_tasks, variables, schema)
        if decoded is None:
            return pd.DataFrame()
        return decoded

    # Split the tasks into at most ``n_workers`` contiguous, non-empty chunks.
    chunk_len = math.ceil(len(async_tasks) / n_workers)
    chunks = [
        async_tasks[start : start + chunk_len]
        for start in range(0, len(async_tasks), chunk_len)
    ]

    # Publish the decode context for the forked workers, which read it back in
    # ``_decode_chunk_idx``; only chunk indices are sent through the pool.
    # NOTE(review): ``_DECODE_CTX`` is a module-level dict, so two overlapping
    # calls to this method from different threads would share it -- confirm
    # callers never run this concurrently.
    _DECODE_CTX["args"] = (self, chunks, variables, schema)
    try:
        fork_ctx = multiprocessing.get_context("fork")
        with concurrent.futures.ProcessPoolExecutor(
            max_workers=len(chunks), mp_context=fork_ctx
        ) as pool:
            parts = list(pool.map(_decode_chunk_idx, range(len(chunks))))
    finally:
        _DECODE_CTX.clear()

    frames = [part for part in parts if part is not None and len(part)]
    if not frames:  # all chunks empty (e.g. all files missing) -> match serial
        return pd.DataFrame()
    combined = pd.concat(frames, ignore_index=True)
    ordered = [name for name in schema.names if name in combined.columns]
    return combined[ordered]

def _compile_chunk(
self,
async_tasks: list[_GSIAsyncTask],
variables: list[str],
schema: pa.Schema,
) -> pd.DataFrame | None:
"""Decode one set of GSI files into a DataFrame (the per-process unit)."""
# Identify schema fields that are per-channel (need Channel_Index lookup)
channel_indexed_fields: dict[str, str] = {}
for field in schema:
Expand Down Expand Up @@ -273,6 +392,8 @@ def _compile_dataframe(
df = df.loc[mask]
frames.append(task.gsi_modifier(df))

if not frames:
return None
result = pd.concat(frames, ignore_index=True)
return result[[name for name in schema.names if name in result.columns]]

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies = [
"huggingface-hub>=0.27.0",
"loguru",
"netCDF4>=1.6.4,<1.7.3", # https://github.com/Unidata/netcdf4-python/issues/1438
"obstore>=0.8",
"pygrib",
"python-dotenv",
"pandas",
Expand Down