Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions xrspatial/geotiff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3753,6 +3753,165 @@ def _gpu_compress_to_part(gpu_arr, w, h, spp):
_write_bytes(file_bytes, path)


def _vrt_effective_dtype(vrt, band):
    """Predict the dtype that reading ``vrt`` will produce.

    Parameters
    ----------
    vrt
        Parsed VRT object exposing ``bands``; each band is read for its
        ``dtype``, ``sources`` (each with ``scale`` / ``offset``) and
        ``nodata`` attributes.
    band : int or None
        Index into ``vrt.bands`` to restrict the prediction to one band,
        or ``None`` to consider every band.

    Returns
    -------
    numpy.dtype
        ``np.result_type`` over the per-band predictions.

    Raises
    ------
    ValueError
        If no band is available to inspect.

    Notes
    -----
    A band is predicted as float64 when any of its sources carries a scale
    other than 1.0 or an offset other than 0.0, or when the band is an
    integer type whose nodata value is a whole number representable in
    that type.
    NOTE(review): the nodata promotion presumably mirrors NaN masking done
    by the eager reader -- confirm against ``read_vrt``.
    """
    bands = vrt.bands if band is None else [vrt.bands[band]]
    if not bands:
        raise ValueError(
            "VRT has no <VRTRasterBand> elements; cannot determine "
            "output dtype"
        )

    float64 = np.dtype(np.float64)

    def _rescales(source):
        # True when the source applies a non-identity scale or offset.
        has_scale = source.scale is not None and source.scale != 1.0
        has_offset = source.offset is not None and source.offset != 0.0
        return has_scale or has_offset

    def _whole_number(value):
        # Integer form of ``value``, or None when it is fractional or
        # non-finite. May raise TypeError / ValueError on odd inputs.
        if isinstance(value, (int, np.integer)):
            return int(value)
        as_float = float(value)
        if np.isfinite(as_float) and as_float.is_integer():
            return int(as_float)
        return None

    predicted = []
    for current in bands:
        band_dtype = current.dtype
        if any(_rescales(source) for source in current.sources):
            band_dtype = float64
        if band_dtype.kind in ('u', 'i') and current.nodata is not None:
            try:
                sentinel = _whole_number(current.nodata)
                if sentinel is not None:
                    limits = np.iinfo(band_dtype)
                    if limits.min <= sentinel <= limits.max:
                        band_dtype = float64
            except (TypeError, ValueError):
                # Best effort: an unparseable nodata leaves the dtype as is.
                pass
        predicted.append(band_dtype)
    return np.result_type(*predicted)


def _read_vrt_dask(source: str, *, dtype=None, window=None, band=None,
                   name=None, chunks=None, max_pixels=None):
    """Assemble a dask-backed DataArray for a VRT from per-window tasks.

    The VRT XML is parsed up front to obtain extent, band list,
    geotransform and CRS. Pixel data is deferred: each chunk is a
    ``dask.delayed`` call to ``read_vrt`` restricted to that chunk's
    window.

    Parameters
    ----------
    source : str
        Path to the ``.vrt`` file.
    dtype : optional
        Requested output dtype; checked against the predicted native dtype
        with ``_validate_dtype_cast``. ``None`` keeps the predicted dtype.
    window : tuple or None
        ``(row0, col0, row1, col1)`` within the VRT extent, or ``None`` for
        the full raster.
    band : int or None
        Zero-based band index, or ``None`` for all bands.
    name : str or None
        Name of the returned DataArray; defaults to the file stem.
    chunks : int or pair
        Chunk size as a single int (square) or ``(height, width)``.
        Assumes the caller has already validated it -- TODO confirm.
    max_pixels : int or None
        Pixel budget passed to ``_check_dimensions``; ``None`` selects
        ``MAX_PIXELS_DEFAULT``.

    Returns
    -------
    xarray.DataArray
        Dimensions ``('y', 'x')``, or ``('y', 'x', 'band')`` when all bands
        of a multi-band VRT are requested.

    Raises
    ------
    ValueError
        For an invalid ``band`` or a ``window`` that is empty or falls
        outside the VRT extent.
    """
    import os
    import dask
    import dask.array as da
    from ._reader import _check_dimensions, MAX_PIXELS_DEFAULT
    from ._vrt import parse_vrt

    with open(source, 'r') as f:
        xml_str = f.read()
    vrt = parse_vrt(xml_str, os.path.dirname(os.path.abspath(source)))

    # --- argument validation -------------------------------------------
    if band is not None:
        # bool is an int subclass, so it has to be rejected explicitly.
        is_index = (isinstance(band, (int, np.integer))
                    and not isinstance(band, bool))
        if not is_index:
            raise ValueError(f"band must be a non-negative int, got {band!r}")
        if band < 0 or band >= len(vrt.bands):
            raise ValueError(
                f"band index {band} out of range for VRT with "
                f"{len(vrt.bands)} band(s)")

    if window is None:
        win_r0, win_c0, win_r1, win_c1 = 0, 0, vrt.height, vrt.width
    else:
        win_r0, win_c0, win_r1, win_c1 = window
        exceeds_extent = (win_r0 < 0 or win_c0 < 0
                          or win_r1 > vrt.height or win_c1 > vrt.width)
        if exceeds_extent or win_r0 >= win_r1 or win_c0 >= win_c1:
            raise ValueError(
                f"window={window} is outside the VRT extent "
                f"({vrt.height}x{vrt.width}) or has non-positive size.")

    height = win_r1 - win_r0
    width = win_c1 - win_c0
    n_bands = 1 if band is not None else len(vrt.bands)
    if max_pixels is None:
        max_pixels = MAX_PIXELS_DEFAULT
    _check_dimensions(width, height, n_bands, max_pixels)

    # --- output dtype --------------------------------------------------
    if dtype is None:
        out_dtype = _vrt_effective_dtype(vrt, band)
    else:
        out_dtype = np.dtype(dtype)
        _validate_dtype_cast(_vrt_effective_dtype(vrt, band), out_dtype)

    # --- chunk grid ----------------------------------------------------
    ch_h, ch_w = (chunks, chunks) if isinstance(chunks, int) else chunks
    row_starts = range(0, height, ch_h)
    col_starts = range(0, width, ch_w)
    out_has_band_axis = band is None and n_bands > 1

    @dask.delayed
    def _read_chunk(chunk_window):
        # Eager, CPU-only read of a single window, run at compute time.
        chunk_da = read_vrt(
            source, dtype=dtype, window=chunk_window, band=band,
            chunks=None, gpu=False, max_pixels=max_pixels,
        )
        arr = np.asarray(chunk_da.values)
        # Force the dtype promised to dask when the block was declared.
        return arr if arr.dtype == out_dtype else arr.astype(out_dtype)

    def _tile(r0, c0):
        # Lazy block for the chunk whose window-relative origin is (r0, c0).
        r1 = min(r0 + ch_h, height)
        c1 = min(c0 + ch_w, width)
        # The task window is expressed in full-VRT pixel coordinates.
        chunk_window = (r0 + win_r0, c0 + win_c0,
                        r1 + win_r0, c1 + win_c0)
        shape = (r1 - r0, c1 - c0)
        if out_has_band_axis:
            shape = shape + (n_bands,)
        return da.from_delayed(
            _read_chunk(chunk_window), shape=shape, dtype=out_dtype)

    dask_arr = da.concatenate(
        [da.concatenate([_tile(r0, c0) for c0 in col_starts], axis=1)
         for r0 in row_starts],
        axis=0,
    )

    # --- coordinates ---------------------------------------------------
    coords = {}
    gt = vrt.geo_transform
    if gt is not None:
        origin_x, res_x, _, origin_y, _, res_y = gt
        if vrt.raster_type == 'point':
            x_shift = win_c0 * res_x
            y_shift = win_r0 * res_y
        else:
            # Non-point rasters get coordinates offset by half a pixel.
            x_shift = (win_c0 + 0.5) * res_x
            y_shift = (win_r0 + 0.5) * res_y
        coords = {
            'x': np.arange(width, dtype=np.float64) * res_x + origin_x + x_shift,
            'y': np.arange(height, dtype=np.float64) * res_y + origin_y + y_shift,
        }

    # --- attributes ----------------------------------------------------
    attrs = {}
    if vrt.crs_wkt:
        epsg = _wkt_to_epsg(vrt.crs_wkt)
        if epsg is not None:
            attrs['crs'] = epsg
        attrs['crs_wkt'] = vrt.crs_wkt
    if vrt.raster_type == 'point':
        attrs['raster_type'] = 'point'
    if vrt.bands:
        # With no explicit band, the first band's nodata is reported.
        nodata = vrt.bands[0 if band is None else band].nodata
        if nodata is not None:
            attrs['nodata'] = nodata
    if gt is not None:
        origin_x, res_x, _, origin_y, _, res_y = gt
        attrs['transform'] = (
            float(res_x), 0.0, float(origin_x) + win_c0 * float(res_x),
            0.0, float(res_y), float(origin_y) + win_r0 * float(res_y),
        )

    if name is None:
        name = os.path.splitext(os.path.basename(source))[0]
    if out_has_band_axis:
        dims = ['y', 'x', 'band']
        coords['band'] = np.arange(n_bands)
    else:
        dims = ['y', 'x']
    return xr.DataArray(dask_arr, dims=dims, coords=coords,
                        name=name, attrs=attrs)


def read_vrt(source: str, *,
dtype: str | np.dtype | None = None,
window: tuple | None = None,
Expand Down Expand Up @@ -3828,6 +3987,12 @@ def read_vrt(source: str, *,
# default (eager read), so allow it through here.
chunks = _validate_chunks_arg(chunks, allow_none=True)

if chunks is not None and not gpu:
return _read_vrt_dask(
source, dtype=dtype, window=window, band=band, name=name,
chunks=chunks, max_pixels=max_pixels,
)

arr, vrt = _read_vrt_internal(source, window=window, band=band,
max_pixels=max_pixels)

Expand Down
51 changes: 51 additions & 0 deletions xrspatial/geotiff/tests/test_read_vrt_lazy_chunks_1798.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""read_vrt(chunks=...) should build lazy window tasks (#1798)."""
from __future__ import annotations

import os
import warnings

import numpy as np

from xrspatial.geotiff import to_geotiff, read_vrt


def _write_vrt(vrt_path, source_name):
    """Write a minimal one-band Float32 VRT (6x4) to ``vrt_path``.

    ``vrt_path`` must provide ``write_text`` (e.g. a ``pathlib.Path``).
    ``source_name`` is inserted verbatim as the ``relativeToVRT``
    source filename of the single ``SimpleSource``.
    """
    pieces = [
        '<VRTDataset rasterXSize="6" rasterYSize="4">\n',
        ' <VRTRasterBand dataType="Float32" band="1">\n',
        ' <SimpleSource>\n',
        f' <SourceFilename relativeToVRT="1">{source_name}',
        '</SourceFilename>\n',
        ' <SourceBand>1</SourceBand>\n',
        ' <SrcRect xOff="0" yOff="0" xSize="6" ySize="4"/>\n',
        ' <DstRect xOff="0" yOff="0" xSize="6" ySize="4"/>\n',
        ' </SimpleSource>\n',
        ' </VRTRasterBand>\n',
        '</VRTDataset>\n',
    ]
    vrt_path.write_text(''.join(pieces))


def test_read_vrt_chunks_matches_eager_values(tmp_path):
    """A chunked VRT read computes to the same pixels as an eager read."""
    source_tif = tmp_path / "tmp_1798_source.tif"
    vrt_file = tmp_path / "tmp_1798_source.vrt"

    pixels = np.arange(24, dtype=np.float32).reshape(4, 6)
    to_geotiff(pixels, str(source_tif), compression='none')
    _write_vrt(vrt_file, source_tif.name)

    vrt_str = str(vrt_file)
    expected = read_vrt(vrt_str)
    chunked = read_vrt(vrt_str, chunks=2)

    # A 4x6 raster in 2x2 tiles: two row chunks by three column chunks.
    assert chunked.data.chunks == ((2, 2), (2, 2, 2))
    np.testing.assert_array_equal(chunked.compute().values, expected.values)


def test_read_vrt_chunks_does_not_read_sources_during_construction(tmp_path):
    """Building a chunked VRT read must not touch the source rasters.

    The VRT points at a file that does not exist. Constructing the lazy
    array is expected to emit no warning and to return a dask-style
    object exposing ``compute``.
    """
    vrt = tmp_path / "tmp_1798_missing_source.vrt"
    _write_vrt(vrt, "missing.tif")

    with warnings.catch_warnings(record=True) as caught:
        # Record every warning regardless of ambient filters; otherwise an
        # inherited "ignore" filter would let the assertion pass vacuously.
        warnings.simplefilter("always")
        lazy = read_vrt(str(vrt), chunks=2)

    assert caught == [], [str(w.message) for w in caught]
    assert hasattr(lazy.data, 'compute')

Loading