Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion earth2studio/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .gfs import GFS, GFS_FX
from .ghcn import GHCNDaily
from .goes import GOES
from .goes_glm import GOESGLM
from .goes_glm import GOESGLM, GOESGLMGrid
from .himawari_ahi import HimawariAHI
from .hrrr import HRRR, HRRR_FX
from .ibtracs import IBTrACS
Expand Down
190 changes: 181 additions & 9 deletions earth2studio/data/goes_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import pandas as pd
import pyarrow as pa
import s3fs
import xarray as xr
from loguru import logger

from earth2studio.data.utils import (
Expand Down Expand Up @@ -317,8 +318,12 @@ async def fetch(
pd.DataFrame
Event-level lightning observations.
"""
if self.fs is None:
await self._async_init()
# Always build a fresh asynchronous filesystem for this fetch. The
# instance is created with ``skip_instance_cache=True`` and its aiohttp
# session is closed by ``managed_session`` below; reusing it across
# repeated calls (e.g. one per 5-min bin from ``GOESGLMGrid``) would
# hand later calls a torn-down aiobotocore client.
await self._async_init()

time_list, variable_list = prep_data_inputs(time, variable)
self._validate_time(time_list)
Expand All @@ -332,14 +337,17 @@ async def fetch(
schema = self.resolve_fields(fields)
pathlib.Path(self.cache).mkdir(parents=True, exist_ok=True)

files = await self._discover_files(time_list)
unique_uris = sorted({f.s3_uri for f in files})
logger.info(
f"[{self.SOURCE_ID}] discovered {len(unique_uris)} unique GLM "
f"files across {len(time_list)} requested times"
)

# Listing and fetching share a single managed session so prefix
# discovery does not leak an unclosed s3fs session and both use the
# same refreshed client.
async with managed_session(self.fs):
files = await self._discover_files(time_list)
unique_uris = sorted({f.s3_uri for f in files})
logger.info(
f"[{self.SOURCE_ID}] discovered {len(unique_uris)} unique GLM "
f"files across {len(time_list)} requested times"
)

coros = [
async_retry(
self._fetch_remote_file,
Expand Down Expand Up @@ -721,3 +729,167 @@ def _normalize_lat_lon_bbox(
f"into two boxes. Got {lat_lon_bbox}."
)
return (lat_min, lon_min, lat_max, lon_max)


@check_optional_dependencies()
class GOESGLMGrid:
    """Gridded GOES GLM lightning product for StormScope.

    Wraps :py:class:`GOESGLM` (a per-event LCFA source) and turns the event
    point cloud into a regular **0.1-degree** lat/lon grid by 5-minute temporal
    binning and 2D histogramming, matching the GLM product the StormScope
    MRMS+GLM nowcast model was trained on. Unlike :py:class:`GOESGLM` (which
    returns a :py:class:`pandas.DataFrame` of events), this source returns a
    gridded :py:class:`xarray.DataArray` consumable by
    :py:func:`earth2studio.data.fetch_data`.

    For each requested (5-minute-aligned) time ``t`` the events whose timestamps
    fall in ``[t, t + 5 min)`` are accumulated and histogrammed:

    - ``glm_density`` : raw **event count** per cell (the "density" name is
      historical; it is an unweighted count, matching training).
    - ``glm_energy_density`` : summed **event energy** (J) per cell.

    The field is **not** mean/std normalized; downstream the StormScope model
    applies ``log1p`` (and ``expm1`` on output). This source emits raw counts/sums
    on the 0.1-degree grid; the model bilinearly regrids to its own grid.

    Parameters
    ----------
    satellite : str, optional
        GOES platform selector passed to :py:class:`GOESGLM` (``"east"`` default).
    cache : bool, optional
        Cache downloaded NetCDFs, by default True.
    verbose : bool, optional
        Show progress, by default True.
    **goes_glm_kwargs : Any
        Additional keyword arguments forwarded to the underlying
        :py:class:`GOESGLM` (e.g. ``async_workers``, ``retries``).

    Note
    ----
    Grid geometry (must match training): regular 0.1-degree grid over
    lat ``[20, 55]`` / lon ``[-130, -60]`` (350 x 700 cells), with cell centres at
    ``edge + 0.5 * resolution``. Output longitudes are returned in the Earth2Studio
    ``[0, 360)`` convention. The accumulation window is fixed at 5 minutes,
    bin-start labeled (the training cadence); do not substitute a 10-minute window.

    Badges
    ------
    region:na dataclass:observation product:sat
    """

    # Accumulation window (minutes), bin-start labeled. Fixed to match training.
    BIN_MINUTES = 5
    # Regular 0.1-degree CONUS grid (degrees, [-180, 180) longitude internally).
    _RES = 0.1
    _LAT_MIN, _LAT_MAX = 20.0, 55.0
    _LON_MIN, _LON_MAX = -130.0, -60.0
    # CONUS parse-time bounding box (lat_min, lon_min, lat_max, lon_max).
    _CONUS_BBOX = (24.5, -125.0, 49.5, -66.0)
    # E2S variable -> underlying GOESGLM event variable.
    _VARIABLE_MAP = {"glm_density": "flashc", "glm_energy_density": "flashe"}

    def __init__(
        self,
        satellite: str = "east",
        cache: bool = True,
        verbose: bool = True,
        **goes_glm_kwargs: object,
    ) -> None:
        # The event source is asked for exactly one forward-looking bin per
        # requested time: tolerance (0, +BIN_MINUTES) mirrors the [t, t + 5 min)
        # window that __call__ histograms.
        self._events = GOESGLM(
            satellite=satellite,
            lat_lon_bbox=self._CONUS_BBOX,
            time_tolerance=(
                np.timedelta64(0, "m"),
                np.timedelta64(self.BIN_MINUTES, "m"),
            ),
            cache=cache,
            verbose=verbose,
            **goes_glm_kwargs,  # type: ignore[arg-type]
        )

        # Bin edges and centres. arange end padded by a small epsilon so the final
        # edge is included; centres sit at edge + 0.5 * resolution.
        self._lat_edges = np.arange(self._LAT_MIN, self._LAT_MAX + 1e-9, self._RES)
        self._lon_edges = np.arange(self._LON_MIN, self._LON_MAX + 1e-9, self._RES)
        self._lat_centres = 0.5 * (self._lat_edges[:-1] + self._lat_edges[1:])
        lon_centres = 0.5 * (self._lon_edges[:-1] + self._lon_edges[1:])
        # Return longitudes in the Earth2Studio [0, 360) convention.
        self._lon_centres = (lon_centres + 360.0) % 360.0

    def __call__(
        self,
        time: datetime | list[datetime] | TimeArray,
        variable: str | list[str] | VariableArray,
    ) -> xr.DataArray:
        """Fetch the gridded GLM product for the requested times and variables.

        Parameters
        ----------
        time : datetime | list[datetime] | TimeArray
            5-minute-aligned timestamps (UTC). Each labels a ``[t, t+5min)`` bin.
        variable : str | list[str] | VariableArray
            One or more of ``"glm_density"`` / ``"glm_energy_density"``.

        Returns
        -------
        xr.DataArray
            Array with dims ``[time, variable, lat, lon]`` on the 0.1-degree grid.

        Raises
        ------
        KeyError
            If a requested variable is not one of the supported grid variables.
        """
        time_list, variable_list = prep_data_inputs(time, variable)
        for v in variable_list:
            if v not in self._VARIABLE_MAP:
                raise KeyError(
                    f"Variable id {v!r} not supported by GOESGLMGrid. "
                    f"Available: {list(self._VARIABLE_MAP)}"
                )

        ny, nx = self._lat_centres.size, self._lon_centres.size
        # Bins with no events stay at zero, so start from a zero-filled array.
        out = np.zeros((len(time_list), len(variable_list), ny, nx), dtype=np.float32)

        underlying = sorted({self._VARIABLE_MAP[v] for v in variable_list})
        # Fetch all requested times in a single GOESGLM call so that S3 session
        # setup, prefix listing, and file downloads are amortised across the
        # entire sliding window rather than paying per-timestep.
        df_all = self._events(time_list, underlying)

        # Use pandas types for the bin arithmetic: ``pd.Timestamp + pd.Timedelta``
        # is well defined whether ``t`` is a ``datetime`` or an ``np.datetime64``,
        # unlike adding an ``np.timedelta64`` to a plain ``datetime``.
        bin_width = pd.Timedelta(minutes=self.BIN_MINUTES)
        grid_edges = [self._lat_edges, self._lon_edges]

        for ti, t in enumerate(time_list):
            t_start = pd.Timestamp(t)
            t_end = t_start + bin_width
            # Half-open window [t, t + BIN_MINUTES): bin-start labeled.
            in_bin = df_all[(df_all["time"] >= t_start) & (df_all["time"] < t_end)]
            if len(in_bin) == 0:
                continue
            for vi, v in enumerate(variable_list):
                sub = in_bin[in_bin["variable"] == self._VARIABLE_MAP[v]]
                if len(sub) == 0:
                    continue
                # Events use [0, 360) longitude; convert to the grid's [-180, 180).
                ev_lon = ((sub["lon"].to_numpy() + 180.0) % 360.0) - 180.0
                hist, _, _ = np.histogram2d(
                    sub["lat"].to_numpy(),
                    ev_lon,
                    bins=grid_edges,
                    weights=sub["observation"].to_numpy(),
                )
                # Assignment casts the float64 histogram to the float32 output.
                out[ti, vi] = hist

        return xr.DataArray(
            data=out,
            dims=["time", "variable", "lat", "lon"],
            coords={
                "time": np.asarray(time_list, dtype="datetime64[ns]"),
                "variable": np.asarray(variable_list),
                "lat": self._lat_centres,
                "lon": self._lon_centres,
            },
        )

    @property
    def lat(self) -> np.ndarray:
        """1D array of grid-cell-centre latitudes."""
        return self._lat_centres

    @property
    def lon(self) -> np.ndarray:
        """1D array of grid-cell-centre longitudes ([0, 360) convention)."""
        return self._lon_centres
Loading