Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions xrspatial/geotiff/_vrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,53 @@ def _codec_decode_exceptions() -> tuple[type[BaseException], ...]:
_ALLOWED_ROOTS_ENV = 'XRSPATIAL_VRT_ALLOWED_ROOTS'


# Cap on the VRT XML file read. VRTs are XML metadata only (pixel data lives
# in source TIFFs); a 50k-source VRT is around 25 MB, so 64 MiB is a safe
# default that still rejects pathological inputs without reading the whole
# file into memory.
# Name of the environment variable that overrides the cap (value in bytes).
_MAX_XML_BYTES_ENV = 'XRSPATIAL_VRT_MAX_XML_BYTES'
# Cap used when the environment variable is unset or empty: 64 MiB.
_DEFAULT_MAX_XML_BYTES = 64 * 1024 * 1024


def _get_vrt_max_xml_bytes() -> int:
    """Return the cap on VRT XML file size, in bytes.

    The cap is looked up in the environment variable named by
    ``_MAX_XML_BYTES_ENV`` on every call. When that variable is unset or
    set to the empty string, ``_DEFAULT_MAX_XML_BYTES`` is returned.

    Returns
    -------
    int
        A strictly positive byte count.

    Raises
    ------
    ValueError
        If the variable holds something ``int()`` cannot parse, or a value
        that is zero or negative. The message names the environment
        variable and echoes the offending value.
    """
    raw = os.environ.get(_MAX_XML_BYTES_ENV)
    if raw is None or raw == '':
        return _DEFAULT_MAX_XML_BYTES
    try:
        value = int(raw)
    except (TypeError, ValueError) as exc:
        # Chain explicitly so the traceback reports the parse failure as
        # the cause instead of "another exception occurred while handling".
        raise ValueError(
            f"{_MAX_XML_BYTES_ENV} must be a positive integer, got "
            f"{raw!r}") from exc
    if value <= 0:
        raise ValueError(
            f"{_MAX_XML_BYTES_ENV} must be a positive integer, got "
            f"{value}")
    return value


def _read_vrt_xml(vrt_path: str) -> str:
    """Load the VRT XML at *vrt_path*, refusing files larger than the cap.

    The file is opened in binary mode and consumed in pieces of at most
    64 KiB. Reading stops with an error as soon as the running total
    passes the limit returned by ``_get_vrt_max_xml_bytes()``, so no more
    than one byte beyond the limit is ever requested from the file.

    Returns
    -------
    str
        The file contents decoded as UTF-8.

    Raises
    ------
    ValueError
        If the file holds more bytes than the cap allows.
    """
    limit = _get_vrt_max_xml_bytes()
    collected = bytearray()
    with open(vrt_path, 'rb') as fh:
        while True:
            # Request at most one byte past the limit: that single extra
            # byte is enough to prove the file is oversized.
            want = min(65536, limit + 1 - len(collected))
            piece = fh.read(want)
            if not piece:
                break
            collected += piece
            if len(collected) > limit:
                raise ValueError(
                    f"VRT XML at {vrt_path!r} exceeds the "
                    f"{limit:,}-byte cap. Raise the cap by setting "
                    f"{_MAX_XML_BYTES_ENV} if this file is legitimate.")
    return collected.decode('utf-8')


def _allowed_source_roots() -> list[str]:
"""Return the operator-supplied allowlist of trusted source roots.

Expand Down Expand Up @@ -637,8 +684,7 @@ def read_vrt(vrt_path: str, *, window=None,
"""
from ._reader import PixelSafetyLimitError, read_to_array

with open(vrt_path, 'r') as f:
xml_str = f.read()
xml_str = _read_vrt_xml(vrt_path)

vrt_dir = os.path.dirname(os.path.abspath(vrt_path))
vrt = parse_vrt(xml_str, vrt_dir)
Expand Down
94 changes: 94 additions & 0 deletions xrspatial/geotiff/tests/test_vrt_xml_size_cap_1815.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
"""VRT XML reads must be bounded to avoid unbounded memory allocation.

Regression test for issue #1815: ``read_vrt`` previously called
``f.read()`` on the VRT XML path with no size limit, so a multi-gigabyte
XML file would consume all that memory before parsing. The fix adds a
64 MiB default cap, overridable via ``XRSPATIAL_VRT_MAX_XML_BYTES``.
"""
from __future__ import annotations

import os

import numpy as np
import pytest

from xrspatial.geotiff import to_geotiff
from xrspatial.geotiff._vrt import read_vrt


def _write_source(td: str) -> str:
    """Write the source raster the test VRT points at and return its path.

    The raster is a 10x10 array of ``uint8`` zeros, written into *td* as
    ``tmp_1815_src.tif`` through ``to_geotiff`` with ``compression='none'``.
    """
    target = os.path.join(td, 'tmp_1815_src.tif')
    pixels = np.zeros((10, 10), dtype=np.uint8)
    to_geotiff(pixels, target, compression='none')
    return target


def _write_vrt(td: str, *, pad_bytes: int = 0) -> str:
    """Write a VRT, optionally padded with a large XML comment.

    Parameters
    ----------
    td : str
        Directory in which ``tmp_1815_mosaic.vrt`` is created.
    pad_bytes : int, keyword-only
        When positive, an XML comment holding this many ``x`` characters is
        inserted right after the opening ``VRTDataset`` tag to inflate the
        file. Zero or negative values add no comment.

    Returns
    -------
    str
        Path of the VRT file that was written.
    """
    vrt_path = os.path.join(td, 'tmp_1815_mosaic.vrt')
    comment = ''
    if pad_bytes > 0:
        comment = '<!-- ' + ('x' * pad_bytes) + ' -->\n'
    vrt_xml = (
        '<VRTDataset rasterXSize="10" rasterYSize="10">\n'
        + comment +
        ' <VRTRasterBand dataType="Byte" band="1">\n'
        ' <SimpleSource>\n'
        ' <SourceFilename relativeToVRT="1">'
        'tmp_1815_src.tif</SourceFilename>\n'
        ' <SourceBand>1</SourceBand>\n'
        ' <SrcRect xOff="0" yOff="0" xSize="10" ySize="10"/>\n'
        ' <DstRect xOff="0" yOff="0" xSize="10" ySize="10"/>\n'
        ' </SimpleSource>\n'
        ' </VRTRasterBand>\n'
        '</VRTDataset>\n'
    )
    # Pin the encoding: the reader under test decodes the file strictly as
    # UTF-8, so the fixture must not depend on the platform's locale default.
    with open(vrt_path, 'w', encoding='utf-8') as f:
        f.write(vrt_xml)
    return vrt_path


def test_small_vrt_parses_under_default_cap(tmp_path):
    """An unpadded VRT is read successfully under the default cap."""
    workdir = str(tmp_path)
    _write_source(workdir)
    data, _ = read_vrt(_write_vrt(workdir))
    assert data.shape == (10, 10)


def test_oversized_vrt_raises_value_error(tmp_path, monkeypatch):
    """Exceeding the cap raises ValueError mentioning the cap and env var."""
    workdir = str(tmp_path)
    _write_source(workdir)
    # Shrink the cap to 1 KiB, then write a VRT whose comment alone is 4 KiB.
    monkeypatch.setenv('XRSPATIAL_VRT_MAX_XML_BYTES', '1024')
    padded_vrt = _write_vrt(workdir, pad_bytes=4096)
    with pytest.raises(ValueError) as caught:
        read_vrt(padded_vrt)
    message = str(caught.value)
    # The error must name the override variable and show the formatted cap.
    for expected in ('XRSPATIAL_VRT_MAX_XML_BYTES', '1,024'):
        assert expected in message


def test_raising_cap_lets_padded_vrt_parse(tmp_path, monkeypatch):
    """A padded VRT parses once the env var sets a cap above its size."""
    workdir = str(tmp_path)
    _write_source(workdir)
    padded_vrt = _write_vrt(workdir, pad_bytes=4096)
    # The 64 MiB default would already be enough; an explicit 1 MiB
    # override checks that a user-supplied larger cap is honoured too.
    one_mib = 1024 * 1024
    monkeypatch.setenv('XRSPATIAL_VRT_MAX_XML_BYTES', str(one_mib))
    data, _ = read_vrt(padded_vrt)
    assert data.shape == (10, 10)


@pytest.mark.parametrize('bad_value', ['not_a_number', '0', '-1', '-1024'])
def test_invalid_cap_raises_value_error(tmp_path, monkeypatch, bad_value):
    """A cap that is not a positive integer is rejected with a clear error."""
    workdir = str(tmp_path)
    _write_source(workdir)
    small_vrt = _write_vrt(workdir)
    monkeypatch.setenv('XRSPATIAL_VRT_MAX_XML_BYTES', bad_value)
    # The message has to name the variable so the user knows what to fix.
    with pytest.raises(ValueError, match='XRSPATIAL_VRT_MAX_XML_BYTES'):
        read_vrt(small_vrt)
Loading