-
Notifications
You must be signed in to change notification settings - Fork 222
HealDA IO Improvement: Parallel decode + Obstore #914
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
f6c4bc6
8ba2c4e
07b457c
d904d2c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,7 +16,10 @@ | |
|
|
||
| from __future__ import annotations | ||
|
|
||
| import concurrent.futures | ||
| import hashlib | ||
| import math | ||
| import multiprocessing | ||
| import os | ||
| import pathlib | ||
| import shutil | ||
|
|
@@ -27,10 +30,11 @@ | |
|
|
||
| import h5netcdf | ||
| import numpy as np | ||
| import obstore as obs | ||
| import pandas as pd | ||
| import pyarrow as pa | ||
| import s3fs | ||
| from loguru import logger | ||
| from obstore.store import S3Store | ||
| from tqdm.asyncio import tqdm | ||
|
|
||
| from earth2studio.data.utils import _sync_async, datasource_cache_root, prep_data_inputs | ||
|
|
@@ -53,6 +57,37 @@ class _GSIAsyncTask: | |
| satellite: str | None = None | ||
|
|
||
|
|
||
# Transient context for the parallel-decode worker pool. Set just before the
# ProcessPoolExecutor runs and cleared after; forked workers inherit it via
# copy-on-write (which also carries the otherwise-unpicklable lexicon closures).
# The single key "args" holds the (self, chunks, variables, schema) tuple that
# _decode_chunk_idx unpacks inside each worker.
_DECODE_CTX: dict = {}

# Cap for automatic decode-worker selection. Decode throughput plateaus once it
# drops below the network floor (~16 workers in practice), and more workers than
# files is wasted; this bounds the "auto" setting.
_DECODE_WORKERS_CAP = 16
|
|
||
|
|
||
| def _cuda_initialized() -> bool: | ||
| """True if a CUDA context already exists in this process. | ||
|
|
||
| Forking after CUDA init is unsafe, so parallel (fork-based) decode falls | ||
| back to serial when this is True. | ||
| """ | ||
| try: | ||
| import torch | ||
|
|
||
| return torch.cuda.is_initialized() | ||
| except Exception: | ||
| return False | ||
|
|
||
|
|
||
def _decode_chunk_idx(i: int) -> pd.DataFrame | None:
    """Decode the ``i``-th task chunk; entry point run inside a forked worker.

    The owning data source, the list of chunks, the requested variables and
    the schema are read from the module-level ``_DECODE_CTX["args"]`` tuple
    rather than passed as arguments.

    Parameters
    ----------
    i : int
        Index into the chunk list stored in ``_DECODE_CTX["args"]``

    Returns
    -------
    pd.DataFrame | None
        Whatever ``_compile_chunk`` returns for the selected chunk
    """
    owner, task_chunks, variable_names, field_schema = _DECODE_CTX["args"]
    selected = task_chunks[i]
    return owner._compile_chunk(selected, variable_names, field_schema)
|
|
||
|
|
||
| class _UFSObsBase: | ||
| """Base class for GSI data sources. | ||
|
|
||
|
|
@@ -78,17 +113,26 @@ def __init__( | |
| self._max_workers = max_workers | ||
| self.async_timeout = async_timeout | ||
| self._tmp_cache_hash: str | None = None | ||
| self.fs: s3fs.S3FileSystem | None = None | ||
| # Anonymous obstore S3 stores, cached per bucket (created lazily). | ||
| self._stores: dict[str, S3Store] = {} | ||
|
|
||
| lower, upper = normalize_time_tolerance(time_tolerance) | ||
| self._tolerance_lower = pd.to_timedelta(lower).to_pytimedelta() | ||
| self._tolerance_upper = pd.to_timedelta(upper).to_pytimedelta() | ||
|
|
||
| async def _async_init(self) -> None: | ||
| """Async initialization of S3 filesystem""" | ||
| self.fs = s3fs.S3FileSystem( | ||
| anon=True, client_kwargs={}, asynchronous=True, skip_instance_cache=True | ||
| ) | ||
| # NOAA UFS GEFSv13 replay archive is a public bucket in us-east-1. | ||
| _region = "us-east-1" | ||
|
|
||
| def _store(self, bucket: str) -> S3Store: | ||
| """Return a cached anonymous obstore S3Store for ``bucket``.""" | ||
| if bucket not in self._stores: | ||
| self._stores[bucket] = S3Store( | ||
| bucket, | ||
| region=self._region, | ||
| skip_signature=True, | ||
| client_options={"pool_max_idle_per_host": str(self._max_workers)}, | ||
| ) | ||
| return self._stores[bucket] | ||
|
|
||
| def __call__( | ||
| self, | ||
|
|
@@ -124,11 +168,7 @@ async def fetch( | |
| fields: str | list[str] | pa.Schema | None = None, | ||
| ) -> pd.DataFrame: | ||
| """Async function to get data.""" | ||
| if self.fs is None: | ||
| await self._async_init() | ||
|
|
||
| session = await self.fs.set_session(refresh=True) # type: ignore[union-attr] | ||
|
|
||
| # obstore S3 stores are created lazily per bucket in _fetch_remote_file. | ||
| time_list, variable_list = prep_data_inputs(time, variable) | ||
| self._validate_time(time_list) | ||
| schema = self.resolve_fields(fields) | ||
|
|
@@ -141,9 +181,6 @@ async def fetch( | |
| *fetch_jobs, desc="Fetching GSI files", disable=(not self._verbose) | ||
| ) | ||
|
|
||
| if session: | ||
| await session.close() | ||
|
|
||
| df = self._compile_dataframe(async_tasks, variable_list, schema) | ||
|
|
||
| return df | ||
|
|
@@ -171,32 +208,114 @@ async def _fetch_remote_file( | |
| byte_length : int | None, optional | ||
| Number of bytes to read, by default None (read all) | ||
| """ | ||
| if self.fs is None: | ||
| raise ValueError("File system is not initialized") | ||
|
|
||
| cache_path = self.cache_path(path, byte_offset, byte_length) | ||
| if not pathlib.Path(cache_path).is_file(): | ||
| if pathlib.Path(cache_path).is_file(): | ||
| return | ||
|
|
||
| # path is an S3 URI ("bucket/key" or "s3://bucket/key"); split into | ||
| # bucket + key for the per-bucket obstore store. | ||
| key = path[5:] if path.startswith("s3://") else path | ||
| bucket, _, object_key = key.partition("/") | ||
| store = self._store(bucket) | ||
| try: | ||
| if byte_length: | ||
| byte_length = int(byte_offset + byte_length) | ||
| try: | ||
| data = await self.fs._cat_file(path, start=byte_offset, end=byte_length) | ||
| with open(cache_path, "wb") as file: | ||
| file.write(data) | ||
| except FileNotFoundError: | ||
| self._handle_missing_file(path) | ||
| result = await obs.get_range_async( | ||
| store, object_key, start=byte_offset, end=byte_offset + byte_length | ||
| ) | ||
| else: | ||
| response = await obs.get_async(store, object_key) | ||
| result = await response.bytes_async() | ||
| data = result.to_bytes() if hasattr(result, "to_bytes") else bytes(result) | ||
| with open(cache_path, "wb") as file: | ||
| file.write(data) | ||
| except (FileNotFoundError, obs.exceptions.NotFoundError): | ||
| self._handle_missing_file(path) | ||
| except Exception as err: | ||
| raise | ||
|
|
||
def _handle_missing_file(self, path: str) -> None:
    """Log and raise for a remote file that could not be found.

    Subclasses can override this to change the missing-file behaviour.

    Parameters
    ----------
    path : str
        Remote path that was requested

    Raises
    ------
    FileNotFoundError
        Always, after the error has been logged
    """
    message = f"File {path} not found"
    logger.error(message)
    raise FileNotFoundError(message)
|
|
||
def _resolve_decode_workers(self, n_tasks: int) -> int:
    """Automatically choose the NetCDF->DataFrame decode worker count.

    Picks ``min(available_cpus, cap, n_tasks)`` (never less than 1) -- decode
    throughput plateaus around the cap ``_DECODE_WORKERS_CAP`` and more
    workers than files is wasted.

    Safety guard: parallel decode uses the ``fork`` start method (so workers
    inherit the unpicklable GSI lexicon closures via copy-on-write). The
    result falls back to serial decode (1) when any of these hold:

    - ``fork`` is not offered by this platform (e.g. Windows);
    - the platform is macOS, where ``fork`` is listed as available but is
      unsafe, so a membership test on the start methods alone never triggers
      the fallback;
    - a CUDA context already exists, since forking after CUDA init is unsafe.

    Parameters
    ----------
    n_tasks : int
        Number of fetched files (tasks) to decode

    Returns
    -------
    int
        Number of worker processes to use; 1 means decode serially
    """
    import sys

    try:
        avail = len(os.sched_getaffinity(0))  # CPUs available to process
    except AttributeError:  # not available on this platform
        avail = os.cpu_count() or 1
    workers = max(1, min(_DECODE_WORKERS_CAP, avail, n_tasks))
    if workers == 1:
        return 1

    fork_usable = (
        sys.platform != "darwin"
        and "fork" in multiprocessing.get_all_start_methods()
    )
    # The CUDA probe only runs when fork would otherwise be used.
    if not fork_usable or _cuda_initialized():
        return 1
    return workers
|
|
||
| def _compile_dataframe( | ||
| self, | ||
| async_tasks: list[_GSIAsyncTask], | ||
| variables: list[str], | ||
| schema: pa.Schema, | ||
| ) -> pd.DataFrame: | ||
| """Compile fetched data into a DataFrame.""" | ||
| """Compile fetched GSI files into a DataFrame. | ||
|
|
||
| Each file's HDF5->pandas decode is CPU- and GIL-bound, so the files are | ||
| decoded across forked worker processes when more than one is selected | ||
| (the count is chosen automatically; see :meth:`_resolve_decode_workers`). | ||
| Falls back to serial when CUDA is already initialized, since the 'fork' | ||
| start method (used to inherit the unpicklable lexicon closures) is unsafe | ||
| after CUDA init. | ||
| """ | ||
| workers = self._resolve_decode_workers(len(async_tasks)) | ||
| if workers <= 1: | ||
| result = self._compile_chunk(async_tasks, variables, schema) | ||
| return result if result is not None else pd.DataFrame() | ||
|
|
||
| size = math.ceil(len(async_tasks) / workers) | ||
| chunks = [ | ||
| c | ||
| for c in ( | ||
| async_tasks[i * size : (i + 1) * size] for i in range(workers) | ||
| ) | ||
| if c | ||
| ] | ||
| _DECODE_CTX["args"] = (self, chunks, variables, schema) | ||
| try: | ||
| with concurrent.futures.ProcessPoolExecutor( | ||
| max_workers=len(chunks), | ||
| mp_context=multiprocessing.get_context("fork"), | ||
| ) as executor: | ||
| parts = list(executor.map(_decode_chunk_idx, range(len(chunks)))) | ||
| finally: | ||
| _DECODE_CTX.clear() | ||
|
Comment on lines
+296
to
+304
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
|
||
| frames = [p for p in parts if p is not None and len(p)] | ||
| if not frames: # all chunks empty (e.g. all files missing) -> match serial | ||
| return pd.DataFrame() | ||
| result = pd.concat(frames, ignore_index=True) | ||
| return result[[name for name in schema.names if name in result.columns]] | ||
|
|
||
| def _compile_chunk( | ||
| self, | ||
| async_tasks: list[_GSIAsyncTask], | ||
| variables: list[str], | ||
| schema: pa.Schema, | ||
| ) -> pd.DataFrame | None: | ||
| """Decode one set of GSI files into a DataFrame (the per-process unit).""" | ||
| # Identify schema fields that are per-channel (need Channel_Index lookup) | ||
| channel_indexed_fields: dict[str, str] = {} | ||
| for field in schema: | ||
|
|
@@ -273,6 +392,8 @@ def _compile_dataframe( | |
| df = df.loc[mask] | ||
| frames.append(task.gsi_modifier(df)) | ||
|
|
||
| if not frames: | ||
| return None | ||
| result = pd.concat(frames, ignore_index=True) | ||
| return result[[name for name in schema.names if name in result.columns]] | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `except Exception as err: raise` block does nothing — it catches the exception and immediately re-raises it with no logging, wrapping, or side effects. Remove it so unhandled exceptions propagate naturally from the `try` block.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!