diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1d9278dd34..f84e8a5110 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -190,7 +190,7 @@ jobs: linker: [cvm, numba] python-version: ["3.12"] test-subset: - - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py + - tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/variational/test_streaming.py tests/variational/test_streaming_autosize.py tests/variational/test_streaming_trainer.py tests/test_initial_point.py - tests/model/test_core.py tests/sampling/test_mcmc.py - tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py - tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py diff --git a/docs/source/api/vi.rst b/docs/source/api/vi.rst index cca88dfde2..3e59294f67 100644 --- a/docs/source/api/vi.rst +++ b/docs/source/api/vi.rst @@ -68,6 +68,21 @@ Special Stein +Streaming +--------- +Out-of-core minibatching for variational inference on datasets that do not fit in +memory (see :mod:`pymc.variational.streaming`). + +.. currentmodule:: pymc.variational +.. autosummary:: + :toctree: generated/ + + DataLoader + IterableDataset + Trainer + shuffle_buffer + parquet_source + .. currentmodule:: pymc .. 
autosummary:: :toctree: generated/ diff --git a/pymc/variational/__init__.py b/pymc/variational/__init__.py index 17b3cf3f7f..61896fb068 100644 --- a/pymc/variational/__init__.py +++ b/pymc/variational/__init__.py @@ -44,6 +44,13 @@ # special from pymc.variational.stein import Stein +from pymc.variational.streaming import ( + DataLoader, + IterableDataset, + Trainer, + parquet_source, + shuffle_buffer, +) from pymc.variational.updates import ( adadelta, adagrad, @@ -64,11 +71,14 @@ "ADVI", "ASVGD", "SVGD", + "DataLoader", "Empirical", "FullRank", "FullRankADVI", "Group", + "IterableDataset", "MeanField", + "Trainer", "adadelta", "adagrad", "adagrad_window", @@ -80,8 +90,10 @@ "momentum", "nesterov_momentum", "norm_constraint", + "parquet_source", "rmsprop", "sample_approx", "sgd", + "shuffle_buffer", "total_norm_constraint", ) diff --git a/pymc/variational/streaming.py b/pymc/variational/streaming.py new file mode 100644 index 0000000000..9317e61627 --- /dev/null +++ b/pymc/variational/streaming.py @@ -0,0 +1,889 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Out-of-core minibatching for variational inference. + +``pm.Minibatch`` random-indexes an array that is fully resident in memory; its +peak memory is therefore O(N) in the dataset size. 
This module instead streams +minibatches from an out-of-core source into a ``pm.Data`` placeholder, so peak +memory is set by the batch, the source chunk, and the optional shuffle buffer, +independent of N. + +The API follows PyTorch's ``torch.utils.data``: + +* :class:`IterableDataset`: a re-iterable, out-of-core source of rows + (e.g. :func:`parquet_source` over a directory of shards). It never loads the + whole dataset; it yields it a chunk at a time. +* :class:`DataLoader`: turns a dataset into fixed-size (optionally shuffled) + minibatches; it is iterable (the minibatch stream) and sized. Note ``len(loader)`` + is the row count ``N`` (what the observed distribution needs for ``total_size``), + not the batch count ``torch.utils.data.DataLoader.__len__`` returns. +* :class:`Trainer`: drives variational inference (ADVI, ...) over a + ``DataLoader`` with no user-facing callbacks; + ``Trainer(method=..., dataloader=...).fit(n)`` streams each minibatch into the + model's ``pm.Data`` placeholder with ``set_data``. + +With bounded source chunks the full data never sits in RAM at once. The model +graph observes only a ``(batch_size, *sample_shape)`` ``pm.Data`` placeholder +that the ``Trainer`` overwrites with the next minibatch every step. Passing a +directory of Parquet shards far larger than RAM still gives a model whose +resident footprint is one batch (:func:`parquet_source` reads one row group at +a time). + +The unbiased-gradient rescaling is the same as for ``pm.Minibatch``: the +observed log-likelihood must be scaled by ``N / batch_size`` through the existing +:func:`~pymc.variational.minibatch_rv.create_minibatch_rv`. ``N`` is exactly +``len(loader)`` (the loader is sized; ``len`` returns the row count ``N``), so the +model passes ``total_size=len(loader)``. (Folding that scaling into the inference +step, so it drops out of the model body, is the next step in PyMC's VI rework.) 
+ +Batches have exactly ``batch_size`` rows, so each pass drops the final +``N mod batch_size`` rows (torch's ``drop_last``). With ``shuffle=True`` that +remainder is re-drawn every epoch, so all rows participate across epochs; with +a source that replays a fixed order, the same rows are dropped every pass (after +a one-time on-disk pre-shuffle that fixed remainder is a random subset). + +One difference from ``pm.Minibatch`` is shuffling. +``pm.Minibatch`` draws a fresh uniform index over all N rows every step, so its +minibatches are i.i.d. by construction. A streaming source is only as well +mixed as the order it yields rows in: reading time/row-ordered data through a +bounded buffer is merely a block-shuffle, and the resulting non-representative +minibatches can bias the variational posterior. +Pre-shuffle the data once on disk (or interleave shards) and/or pass +``shuffle=True``. + +Examples +-------- +.. code-block:: python + + import numpy as np + import pymc as pm + from pymc.variational.streaming import DataLoader, Trainer, parquet_source + + # The data was pre-shuffled on disk once (see the module note on shuffling), + # so the loader streams it sequentially. The full table stays on disk. + loader = DataLoader( + parquet_source("shuffled/"), # an IterableDataset over the shards + batch_size=4096, + sample_shape=(4,), # 3 features + 1 observed column + total_size="auto", # infer N from Parquet metadata; N == len(loader) + ) + + with pm.Model() as model: + b = pm.Normal("b", 0.0, 3.0, shape=4) + batch = pm.Data("batch", np.zeros((4096, 4))) # placeholder for one minibatch + logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2] + pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader)) + + # No callbacks: the Trainer streams each minibatch into "batch" with set_data. 
+ with model: + approx = Trainer(method="advi", dataloader=loader, data_name="batch").fit(20_000) +""" + +from __future__ import annotations + +import glob +import numbers +import os +import warnings + +from collections.abc import Callable, Iterable, Iterator + +import numpy as np + +from pymc.model import modelcontext +from pymc.variational.inference import Inference +from pymc.variational.inference import fit as _fit + +__all__ = ["DataLoader", "IterableDataset", "Trainer", "parquet_source", "shuffle_buffer"] + + +def _is_positive_int(value: object) -> bool: + """True for a strictly positive integer (incl. numpy integer types), excluding bool.""" + return isinstance(value, numbers.Integral) and not isinstance(value, bool) and int(value) > 0 + + +class IterableDataset: + """A re-iterable, out-of-core source of rows, like ``torch.utils.data.IterableDataset``. + + Subclass and implement :meth:`__iter__` to yield ``np.ndarray`` blocks of rows + (shape ``(rows, *sample_shape)``); :class:`DataLoader` re-batches those blocks + into fixed-size minibatches. ``__iter__`` must return a fresh iterator each + call so the dataset can be replayed across epochs. + + Optionally set :attr:`n_rows` (the total row count, if known cheaply, e.g. + from file metadata) so a :class:`DataLoader` with ``total_size="auto"`` can + resolve ``N`` without a counting pass. + + A plain zero-arg factory (``Callable[[], Iterator[np.ndarray]]``) or any + re-iterable is also accepted directly by :class:`DataLoader`; this base class + is only needed when you want to attach behavior or ``n_rows`` to a custom + source. + """ + + n_rows: int | None = None + + def __iter__(self) -> Iterator[np.ndarray]: + raise NotImplementedError("IterableDataset subclasses must implement __iter__") + + +class DataLoader: + """Turn an out-of-core dataset into fixed-size minibatches for variational inference. 
class DataLoader:
    """Turn an out-of-core dataset into fixed-size minibatches for variational inference.

    Like ``torch.utils.data.DataLoader``, it batches (and optionally
    shuffles) an :class:`IterableDataset` into the minibatch stream that
    :class:`Trainer` feeds to the model. It is iterable and sized (``len(loader)``
    is the dataset size ``N``). With bounded source chunks the full dataset is
    never resident at once.

    Parameters
    ----------
    dataset : IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]]
        The source of rows. An :class:`IterableDataset`, a re-iterable (including a
        plain ``np.ndarray``), or a zero-arg factory returning a fresh iterator
        (preferred, so the stream can be restarted each epoch). It may yield single
        samples (e.g. the rows of a raw array) or blocks of any size; the loader
        re-batches them, in order, to exactly ``batch_size`` rows. Trailing rows
        that do not fill a final batch are dropped at the end of a pass, like
        ``drop_last=True`` in PyTorch (required here because the model observes a
        fixed-shape placeholder). With ``shuffle=True`` the dropped remainder
        differs per epoch; with a fixed replay order it is the same rows every
        pass.
    batch_size : int
        Leading dimension of every yielded minibatch.
    shuffle : bool, default False
        If ``True``, wrap the source in a bounded :func:`shuffle_buffer` of
        ``buffer_size`` rows. This only approximates i.i.d. batches for an
        already unordered stream; a bounded buffer cannot fix strongly
        time/row-ordered data (pre-shuffle on disk for that; see the module
        docstring).
    buffer_size : int, optional
        Shuffle-buffer size in rows when ``shuffle=True``. Defaults to
        ``50 * batch_size``. Ignored when ``shuffle=False``. A buffer at least
        as large as the dataset holds all of it in memory (a full shuffle).
    seed : int, optional
        Seed for the shuffle buffer (ignored when ``shuffle=False``).
    sample_shape : tuple of int, optional
        Trailing shape of a single observation. ``()`` for scalar observations,
        ``(k,)`` to stream ``k`` columns (e.g. features + the observed column).
        Defaults to ``dataset.shape[1:]`` for a raw ``np.ndarray`` source (its
        rows are the samples, like torch's ``TensorDataset``), else ``()``.
    dtype : str, default "float64"
        Dtype each prepared batch is cast to; match the dtype of the ``pm.Data``
        placeholder the batches are streamed into.
    total_size : int or "auto", optional
        The true dataset size ``N`` (a positive integer), or ``"auto"`` to infer
        it (from the source's ``n_rows`` if available, else a single counting
        pass). Pass it on to the observed distribution as
        ``total_size=len(loader)`` so the minibatch log-likelihood is rescaled by
        ``N / batch_size`` (the same mechanism as ``pm.Minibatch``). Unlike
        ``pm.Minibatch`` it cannot be inferred from a resident array; ``None``
        warns at construction and a non-positive value raises (it would otherwise
        silently disable or invert the rescaling).
    preprocess_fn : callable, optional
        Pure transform applied to each batch before validation (e.g.
        normalization). It must preserve the row count and ``sample_shape``;
        to select columns, do it at the source instead
        (``parquet_source(columns=...)``).

    Raises
    ------
    ValueError
        If ``batch_size`` is not a positive integer, if ``total_size`` is a
        string other than ``"auto"``, or if ``total_size`` is a non-string value
        that is not a positive integer.

    Warns
    -----
    UserWarning
        At construction when ``total_size`` is ``None``.
    """

    def __init__(
        self,
        dataset: IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]],
        *,
        batch_size: int,
        shuffle: bool = False,
        buffer_size: int | None = None,
        seed: int | None = None,
        sample_shape: tuple[int, ...] | None = None,
        dtype: str = "float64",
        total_size: int | str | None = None,
        preprocess_fn: Callable[[np.ndarray], np.ndarray] | None = None,
    ):
        if not _is_positive_int(batch_size):
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        if sample_shape is None:
            # A raw array is rows-of-samples; without this default a 2-D array
            # would be read as blocks of scalars and silently flattened.
            sample_shape = dataset.shape[1:] if isinstance(dataset, np.ndarray) else ()
        sample_shape = tuple(sample_shape)

        # raw_factory is kept separately from the (possibly shuffled) streaming
        # factory because the "auto" size resolution below must see the
        # unshuffled source.
        raw_factory = _make_factory(dataset)
        source_factory = raw_factory
        if shuffle:
            if buffer_size is None:
                buffer_size = 50 * int(batch_size)
            # shuffle_buffer concatenates yields along the leading axis, so single
            # samples must be promoted to one-row blocks before shuffling.
            source_factory = shuffle_buffer(
                _block_factory(raw_factory, sample_shape),
                buffer_size=buffer_size,
                batch_size=batch_size,
                seed=seed,
            )
        self._source_factory = source_factory

        if isinstance(total_size, str):
            if total_size != "auto":
                raise ValueError(f"total_size string must be 'auto', got {total_size!r}")
            # Count the unshuffled source: the shuffle wrapper drops the trailing
            # partial batch, so counting through it would undercount N.
            total_size = _auto_total_size(raw_factory, dataset, sample_shape)
        elif total_size is None:
            warnings.warn(
                "DataLoader created with total_size=None: the minibatch "
                "log-likelihood will not be rescaled and the posterior will be "
                "biased. Pass total_size=N (the true dataset size) or total_size='auto'.",
                UserWarning,
                stacklevel=2,
            )
        elif not _is_positive_int(total_size):
            # 0 is falsy (the rescaling would be silently skipped) and a negative
            # value flips the sign of the data log-likelihood; raise on both.
            raise ValueError(
                "total_size must be a positive integer (the true dataset size N) so "
                "the minibatch log-likelihood is rescaled by N / batch_size; got "
                f"{total_size!r}."
            )

        # Plain Python ints: create_minibatch_rv rejects np.int64 for total_size.
        self._batch_size = int(batch_size)
        self._sample_shape = sample_shape
        self._dtype = dtype
        self._total_size = None if total_size is None else int(total_size)
        self._preprocess_fn = preprocess_fn

        # Accounting state; only _stream_batches() mutates these.
        self._batches_seen = 0
        self._rows_streamed = 0
        self._warned_size = False

    @property
    def batch_size(self) -> int:
        """Number of rows in every minibatch (the constructor argument as a plain ``int``)."""
        return self._batch_size

    @property
    def total_size(self) -> int | None:
        """The dataset size ``N`` (pass to the distribution's ``total_size``)."""
        return self._total_size

    @property
    def batches_seen(self) -> int:
        """Number of minibatches produced by :meth:`_stream_batches` so far.

        Plain iteration over the loader does not advance this counter.
        """
        return self._batches_seen

    @property
    def rows_streamed(self) -> int:
        """Total rows streamed into the model (grows past ``N`` across epochs)."""
        return self._rows_streamed

    def _rebatched(self) -> Iterator[np.ndarray]:
        """A fresh pass of exactly ``batch_size``-row batches from the source."""
        return _rebatch(self._source_factory(), self._batch_size, self._sample_shape)

    def __iter__(self) -> Iterator[np.ndarray]:
        """Yield one epoch of validated ``(batch_size, *sample_shape)`` minibatches.

        The same batches the :class:`Trainer` streams into the model's ``pm.Data``
        placeholder (it consumes them through an accounting wrapper, so plain
        iteration leaves the counters untouched). Re-iterate the loader for
        another epoch.
        """
        for batch in self._rebatched():
            yield self._prepare(batch)

    def __len__(self) -> int:
        """The dataset size ``N`` (row count); pass it to the distribution's ``total_size``.

        ``total_size=len(loader)`` is how the model gets the ``N / batch_size``
        rescaling. Note this returns the row count ``N``, not the batch count
        that ``torch.utils.data.DataLoader.__len__`` returns; ``total_size``
        needs ``N``. :attr:`total_size` is the same value.

        Raises
        ------
        TypeError
            If the loader was built with ``total_size=None``.
        """
        if self._total_size is None:
            raise TypeError(
                "len(DataLoader) is the dataset size N, but this loader was built with "
                "total_size=None; construct it with total_size=N or total_size='auto'."
            )
        return self._total_size

    def _stream_batches(self) -> Iterator[np.ndarray]:
        """One epoch of prepared minibatches, with accounting (the Trainer's path).

        Like :meth:`__iter__` but it updates :attr:`batches_seen` /
        :attr:`rows_streamed` and runs the one-shot ``total_size`` sanity check on
        the pass's final batch. The rebatcher is kept one batch ahead so the check
        still fires when a fit stops exactly at the pass boundary; without the
        lookahead the generator would be abandoned right before its epilogue.
        :meth:`__iter__` stays side-effect-free so plain iteration does not mutate
        counters.
        """
        seen_this_pass = 0
        it = self._rebatched()
        batch = next(it, None)
        while batch is not None:
            # Pull the next raw batch *before* yielding the current one: a
            # ``None`` here marks the current batch as the last of the pass.
            following = next(it, None)
            prepared = self._prepare(batch)
            self._batches_seen += 1
            self._rows_streamed += int(prepared.shape[0])
            seen_this_pass += int(prepared.shape[0])
            if following is None:
                self._maybe_warn_total_size(seen_this_pass)
            yield prepared
            batch = following

    def _prepare(self, batch: np.ndarray) -> np.ndarray:
        """Preprocess, validate, and return an owned copy of one batch.

        A source may legitimately yield views into a reused array; the copy
        prevents the consumer from aliasing it.
        """
        if self._preprocess_fn is not None:
            batch = self._preprocess_fn(batch)
        # Validation runs on the preprocessed batch, before the dtype cast.
        self._validate(batch)
        return np.array(batch, dtype=self._dtype)

    def _maybe_warn_total_size(self, seen: int) -> None:
        """Warn once if ``total_size`` is inconsistent with the rows of one full pass.

        ``seen`` is the row count of the pass that just completed (not the
        cumulative :attr:`rows_streamed`, which keeps growing across partial
        streams and earlier fits). A correct ``N`` satisfies
        ``seen <= N < seen + batch_size`` after a full pass (the trailing partial
        batch is dropped), so that window never warns; outside it a 10% slack
        absorbs sources that are only approximately sized.
        """
        if self._warned_size or self._total_size is None:
            return
        # The flag is set before the comparison, so the check itself runs at
        # most once per loader, whether or not it ends up warning.
        self._warned_size = True
        if not seen or seen <= self._total_size < seen + self._batch_size:
            return
        if abs(self._total_size - seen) > 0.1 * seen:
            warnings.warn(
                f"total_size={self._total_size} disagrees with the {seen} rows streamed "
                f"in one full pass; the N/batch_size rescaling, and therefore the "
                f"posterior width, is likely wrong. Pass the true dataset size (or, if "
                f"'auto' resolved it from the source's n_rows, fix that attribute).",
                UserWarning,
                stacklevel=3,
            )

    def _validate(self, batch: np.ndarray) -> None:
        """Check that ``batch`` is a ``(batch_size, *sample_shape)`` ``np.ndarray``.

        Raises
        ------
        TypeError
            If ``batch`` is not an ``np.ndarray``.
        ValueError
            If ``batch`` is 0-dimensional, its leading dimension differs from
            ``batch_size``, or its trailing shape differs from ``sample_shape``.
        """
        if not isinstance(batch, np.ndarray):
            raise TypeError(f"expected np.ndarray batch, got {type(batch).__name__}")
        if batch.ndim < 1:
            raise ValueError(
                "batch needs a leading batch dimension; got a scalar array with "
                f"shape {batch.shape}."
            )
        if batch.shape[0] != self._batch_size:
            raise ValueError(
                f"batch shape[0] = {batch.shape[0]} does not match batch_size = {self._batch_size}."
            )
        if batch.shape[1:] != self._sample_shape:
            raise ValueError(
                f"batch sample-shape {batch.shape[1:]} does not match declared "
                f"sample_shape={self._sample_shape}"
            )
class Trainer:
    """Drive variational inference over a :class:`DataLoader` without user callbacks.

    Follows the design in PyMC's variational-inference rework and PyTorch
    Lightning: the ``Trainer`` owns the training loop, the
    :class:`DataLoader` owns batching (and ``len(dataloader)`` is the dataset size
    ``N``), and the model owns the math. The model exposes a ``pm.Data`` placeholder;
    the ``Trainer`` streams minibatches into it with ``model.set_data`` once per
    step; no user callbacks are needed.

    Parameters
    ----------
    method : str or Inference, default "advi"
        Variational method, forwarded to :func:`pymc.fit`: a name (``"advi"``,
        ``"fullrank_advi"``, ...) or an :class:`~pymc.variational.inference.Inference`
        instance. ``pm.fit`` applies ``model`` and ``random_seed`` only to a name;
        an instance is already bound to a model, so configure it at construction
        (e.g. ``ADVI(random_seed=...)``).
    dataloader : DataLoader
        The minibatch source. ``len(dataloader)`` is ``N``; the model should pass
        it to the observed distribution's ``total_size``.
    model : pymc.Model, optional
        Defaults to the model on the context stack.
    data_name : str, default "batch"
        Name of the ``pm.Data`` placeholder minibatches are streamed into. Must
        match the name used for ``pm.Data(name, ...)`` in the model.
    **fit_kwargs
        Default keyword arguments forwarded to :func:`pymc.fit` (e.g.
        ``obj_optimizer``); per-call kwargs to :meth:`fit` override them.

    Notes
    -----
    The per-step ``set_data`` currently lives in the ``Trainer``. Once the VI
    rework's ``Inference.step(batch)`` lands it moves there, at which point the
    ``total_size`` rescaling can be derived from ``len(dataloader)`` and dropped
    from the model body entirely.

    Examples
    --------
    .. code-block:: python

        loader = DataLoader(
            parquet_source("shuffled/"), batch_size=4096, sample_shape=(4,), total_size="auto"
        )
        with pm.Model() as model:
            b = pm.Normal("b", 0.0, 3.0, shape=4)
            batch = pm.Data("batch", np.zeros((4096, 4)))  # placeholder
            logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2]
            pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader))
            approx = Trainer(method="advi", dataloader=loader, data_name="batch").fit(20_000)
    """

    def __init__(
        self,
        *,
        method: str | Inference = "advi",
        dataloader: DataLoader,
        model=None,
        data_name: str = "batch",
        **fit_kwargs,
    ):
        # The constructor only stores its arguments; the loader type, the model
        # context and ``data_name`` are all checked in fit().
        self.method = method
        self.dataloader = dataloader
        self.model = model
        self.data_name = data_name
        self._fit_kwargs = fit_kwargs

    def fit(self, n: int = 10_000, **kwargs):
        """Fit for ``n`` steps, streaming minibatches into the model's placeholder.

        Exactly ``n`` minibatches are fed to the model: the first seeds the
        placeholder before step 0, and the advance after the final step is skipped.
        The accounting stream reads one batch ahead so the pass-size check can fire
        at a pass boundary, so a re-readable source (the only kind the loader
        accepts) may be read one batch past the ``n`` the model uses. Keyword
        arguments are forwarded to :func:`pymc.fit` on top of the constructor's
        ``fit_kwargs`` (per-call wins); ``progressbar`` defaults to ``False``
        unless either sets it.

        Returns
        -------
        :class:`Approximation`
            The fitted approximation, as returned by :func:`pymc.fit`.

        Raises
        ------
        ValueError
            If ``n`` is not a positive integer.
        TypeError
            If ``dataloader`` is not a :class:`DataLoader`.
        KeyError
            If ``data_name`` is not found in the model.
        RuntimeError
            If a pass over the dataloader produces no batches.
        """
        if not _is_positive_int(n):
            raise ValueError(f"n must be a positive integer (the number of fit steps), got {n!r}")
        loader = self.dataloader
        if not isinstance(loader, DataLoader):
            raise TypeError(
                f"Trainer needs a DataLoader for `dataloader`, got {type(loader).__name__}."
            )
        model = modelcontext(self.model)
        if self.data_name not in model:
            # Checked before the stream starts so no batch is consumed (and no
            # counter advances) on a typo.
            raise KeyError(
                f"data_name {self.data_name!r} is not a variable in the model; it "
                f"must name the pm.Data placeholder the minibatches are streamed into."
            )

        def _stream() -> Iterator[np.ndarray]:
            # Endless stream: restart the loader's accounting pass whenever one
            # ends; a pass that produced nothing raises instead of spinning.
            while True:
                empty = True
                for batch in loader._stream_batches():
                    empty = False
                    yield batch
                if empty:
                    raise RuntimeError("dataloader yielded no batches")

        batches = _stream()
        # Seed the placeholder before step 0: pm.fit runs callbacks after each step,
        # so without this the first step would train on the placeholder's contents.
        model.set_data(self.data_name, next(batches))

        steps_done = 0

        def _advance(*_):
            # pm.fit fires callbacks after every step including the last; skip the
            # advance on this fit's final step so exactly n batches are consumed.
            # Only that one call is skipped (not every call past n): Inference.refine
            # replays the saved callbacks and must keep streaming fresh batches.
            nonlocal steps_done
            steps_done += 1
            if steps_done != n:
                model.set_data(self.data_name, next(batches))

        # Later keys win, so per-call kwargs override the constructor defaults.
        merged = {**self._fit_kwargs, **kwargs}
        merged.setdefault("progressbar", False)
        # User callbacks (e.g. convergence trackers) are appended after the
        # internal advance instead of colliding with it on the keyword.
        user_callbacks = merged.pop("callbacks", None) or []
        return _fit(
            n,
            method=self.method,
            model=model,
            callbacks=[_advance, *user_callbacks],
            **merged,
        )
def _fill_shuffle_buffer(
    it: Iterator[np.ndarray],
    carry: np.ndarray | None,
    target: int,
    batch_size: int,
) -> tuple[np.ndarray | None, bool]:
    """Pull chunks from ``it`` until ``target`` rows are buffered or it runs dry.

    Returns ``(buffer, exhausted)``. ``buffer`` is one freshly concatenated array
    (``carry`` first, then the pulled chunks), or ``None`` when fewer than
    ``batch_size`` rows are available. ``exhausted`` is ``True`` when ``it`` ended
    before the threshold was crossed. All chunk references live only in this
    function's locals, so they are released as soon as it returns.
    """
    parts: list[np.ndarray] = [] if carry is None else [carry]
    have = 0 if carry is None else carry.shape[0]
    exhausted = True
    for arr in it:
        a = np.asarray(arr)
        parts.append(a)
        have += a.shape[0]
        if have >= target:
            # Stopped on the threshold, not on the end of the stream.
            exhausted = False
            break
    if have < batch_size:
        return None, exhausted
    return np.concatenate(parts, axis=0), exhausted


def shuffle_buffer(
    chunk_source: Callable[[], Iterator[np.ndarray]],
    *,
    buffer_size: int,
    batch_size: int,
    seed: int | None = None,
) -> Callable[[], Iterator[np.ndarray]]:
    """Wrap a chunk source into a shuffled, fixed-size batch source.

    Accumulates rows from ``chunk_source`` into a buffer of at least
    ``buffer_size`` rows, shuffles it, and yields ``batch_size`` slices; rows that
    do not fill a final batch are carried over into the next buffer (never
    dropped) until the source is exhausted, at which point a single trailing
    partial batch (< ``batch_size`` rows) is dropped. This approximates i.i.d.
    minibatches from an unordered or pre-shuffled stream.

    :class:`DataLoader` calls this for you when ``shuffle=True``; use it directly
    when you want explicit control over ``buffer_size`` independently of the
    loader.

    It does not by itself fix a strongly time/row-ordered stream (a bounded
    buffer only block-shuffles such data); pre-shuffle on disk, or interleave
    shards into ``chunk_source``, for that. ``buffer_size`` is a lower bound:
    each fill accumulates at least ``max(buffer_size, batch_size)`` rows before
    shuffling (so a ``buffer_size`` smaller than ``batch_size`` still yields full
    batches; the final fill stops at whatever the source has left), and the chunk
    that crosses the threshold is kept whole, so the buffer holds fewer than
    ``max(buffer_size, batch_size)`` plus one chunk's rows. Concatenating a fill
    into one shuffleable array transiently allocates a second copy of those
    rows, so peak allocation is about twice that bound; the chunk references are
    dropped before any batch is yielded, so (unless the source itself keeps its
    chunks alive) only the single concatenated buffer stays resident while
    batches are consumed.

    Each epoch (each call of the returned factory) draws a fresh permutation from
    a sub-stream of ``seed``, so the shuffle order differs across epochs while
    staying reproducible for a given ``seed``.

    Parameters
    ----------
    chunk_source : callable
        Zero-argument factory returning an iterable of ``(rows, ...)`` arrays.
        A ``.n_rows`` attribute on it, if not ``None``, is copied onto the
        returned factory.
    buffer_size : int
        Lower bound on the rows accumulated per shuffle (positive integer).
    batch_size : int
        Number of rows in every yielded batch (positive integer).
    seed : int, optional
        Seed for ``np.random.SeedSequence``.

    Returns
    -------
    callable
        Zero-argument factory; each call returns a generator over one shuffled
        epoch of exact ``batch_size``-row arrays.

    Raises
    ------
    ValueError
        If ``batch_size`` or ``buffer_size`` is not a positive integer.
    """
    if not _is_positive_int(batch_size):
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if not _is_positive_int(buffer_size):
        raise ValueError(f"buffer_size must be a positive integer, got {buffer_size!r}")
    seed_seq = np.random.SeedSequence(seed)

    def factory() -> Iterator[np.ndarray]:
        # A fresh sub-stream per epoch: re-iterating reshuffles instead of
        # replaying one fixed permutation, yet stays reproducible per seed.
        rng = np.random.default_rng(seed_seq.spawn(1)[0])
        # A factory may return a re-iterable (a list of chunks, ...); normalize so
        # each buffer fill continues one stream instead of restarting it forever.
        it = iter(chunk_source())
        carry: np.ndarray | None = None
        exhausted = False
        # Accumulate at least one batch even when buffer_size < batch_size,
        # otherwise the guard below would silently discard the whole stream.
        target = max(buffer_size, batch_size)
        while not exhausted:
            # The fill runs in a helper so the per-chunk references die when it
            # returns; holding them here would keep a second copy of the buffer
            # alive for as long as the consumer takes to drain the batches.
            buf, exhausted = _fill_shuffle_buffer(it, carry, target, batch_size)
            carry = None
            if buf is None:
                # Only reachable once the source is exhausted: drop the final
                # partial batch.
                return
            rng.shuffle(buf)
            n_full = buf.shape[0] // batch_size
            for i in range(n_full):
                yield buf[i * batch_size : (i + 1) * batch_size]
            rem = buf.shape[0] - n_full * batch_size
            # Copy the remainder so the carried rows do not pin the whole buffer.
            carry = buf[n_full * batch_size :].copy() if rem else None

    # Forward a known row count so total_size="auto" stays metadata-cheap
    # through the shuffle wrapper.
    source_n_rows = getattr(chunk_source, "n_rows", None)
    if source_n_rows is not None:
        factory.n_rows = source_n_rows  # type: ignore[attr-defined]

    return factory
+ target = max(buffer_size, batch_size) + while not exhausted: + bufs: list[np.ndarray] = [] + have = 0 + if carry is not None: + bufs.append(carry) + have += carry.shape[0] + carry = None + for arr in it: + a = np.asarray(arr) + bufs.append(a) + have += a.shape[0] + if have >= target: + break + else: + exhausted = True + if have < batch_size: + # Only reachable once the source is exhausted: drop the final + # partial batch. + return + buf = np.concatenate(bufs, axis=0) + rng.shuffle(buf) + n_full = buf.shape[0] // batch_size + for i in range(n_full): + yield buf[i * batch_size : (i + 1) * batch_size] + rem = buf.shape[0] - n_full * batch_size + carry = buf[n_full * batch_size :].copy() if rem else None + + # Forward a known row count so total_size="auto" stays metadata-cheap + # through the shuffle wrapper. + source_n_rows = getattr(chunk_source, "n_rows", None) + if source_n_rows is not None: + factory.n_rows = source_n_rows # type: ignore[attr-defined] + + return factory + + +def _promote_to_block(a: np.ndarray, sample_shape: tuple[int, ...]) -> np.ndarray: + """Return ``a`` as a ``(rows, *sample_shape)`` block; a single sample becomes one row.""" + if a.shape == sample_shape: + return a[None, ...] + if a.ndim != len(sample_shape) + 1 or a.shape[1:] != sample_shape: + raise ValueError( + f"source yielded shape {a.shape}; expected one sample of shape " + f"{sample_shape} or a (rows, *sample_shape) block; if the source is " + f"right, declare its trailing shape with DataLoader(sample_shape=...)" + ) + return a + + +def _block_factory( + factory: Callable[[], Iterator[np.ndarray]], + sample_shape: tuple[int, ...], +) -> Callable[[], Iterator[np.ndarray]]: + """Wrap ``factory`` so every yield is a block, promoting single samples. + + :func:`shuffle_buffer` counts and concatenates yields along the leading axis, + so single-sample yields (e.g. the rows of a raw array) must be promoted to + one-row blocks before shuffling. A known ``.n_rows`` is forwarded. 
+ """ + + def f() -> Iterator[np.ndarray]: + for arr in factory(): + yield _promote_to_block(np.asarray(arr), sample_shape) + + n_rows = getattr(factory, "n_rows", None) + if n_rows is not None: + f.n_rows = n_rows # type: ignore[attr-defined] + return f + + +def _rebatch( + blocks: Iterable[np.ndarray], + batch_size: int, + sample_shape: tuple[int, ...], +) -> Iterator[np.ndarray]: + """Slice a stream of samples/blocks into exact ``batch_size``-row batches, in order. + + Accepts single samples (shape ``sample_shape``, e.g. the rows of a raw array) + and blocks of any size (shape ``(rows, *sample_shape)``), carrying remainders + across blocks so no row is lost mid-stream. Trailing rows that do not fill a + final batch are dropped when the stream ends (``drop_last=True`` behavior; the + model observes a fixed-shape placeholder, so a partial batch cannot be fed). + Sources that already yield exact ``batch_size`` blocks (e.g. + :func:`shuffle_buffer`) pass through without copying. + """ + buf: list[np.ndarray] = [] + have = 0 + for arr in blocks: + a = _promote_to_block(np.asarray(arr), sample_shape) + buf.append(a) + have += a.shape[0] + if have < batch_size: + continue + merged = np.concatenate(buf, axis=0) if len(buf) > 1 else buf[0] + n_full = merged.shape[0] // batch_size + for i in range(n_full): + yield merged[i * batch_size : (i + 1) * batch_size] + rem = merged.shape[0] - n_full * batch_size + buf = [merged[n_full * batch_size :].copy()] if rem else [] + have = rem + + +def _make_factory( + source: Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]], +) -> Callable[[], Iterator[np.ndarray]]: + """Coerce ``source`` into a zero-arg callable returning a fresh iterator. + + A callable that is not itself an iterator is treated as the factory; a bare + iterator is wrapped (and refuses a second epoch); any other iterable (incl. an + :class:`IterableDataset`) is re-``iter``-ed each epoch. 
def _auto_total_size(
    factory: Callable[[], Iterator[np.ndarray]],
    source: object,
    sample_shape: tuple[int, ...] = (),
) -> int:
    """Resolve ``total_size="auto"``: a source ``.n_rows`` (cheap) else a counting pass.

    Fast path: if ``source`` (or, failing that, ``factory``) carries an ``n_rows``
    attribute, it is validated as a positive integer and returned without
    touching the data. Otherwise one full pass over ``factory()`` counts the
    rows, and a second ``factory()`` call is probed for one chunk to confirm the
    source can be read again.

    Parameters
    ----------
    factory
        Zero-arg callable returning the stream to count. It may return an
        iterator or any iterable; the result is normalized with ``iter``.
    source
        The object the factory was built from; consulted for ``n_rows`` and
        rejected when it is a bare one-shot iterator.
    sample_shape
        Shape of one sample. A chunk whose shape equals ``sample_shape`` counts
        as one row; any other chunk contributes the length of its leading axis.

    Returns
    -------
    int
        The resolved number of rows.

    Raises
    ------
    ValueError
        If ``n_rows`` is not a positive integer; if ``source`` is a one-shot
        iterator; if a chunk is 0-dimensional while ``sample_shape`` is not
        ``()``; if the counting pass finds no rows; or if a second call to
        ``factory`` yields nothing (same or already-consumed iterator).

    Warns
    -----
    UserWarning
        Whenever the full counting pass is taken instead of the ``n_rows`` path.
    """
    n = getattr(source, "n_rows", None)
    if n is None:
        n = getattr(factory, "n_rows", None)
    if n is not None:
        if not _is_positive_int(n):
            raise ValueError(f"source.n_rows must be a positive integer, got {n!r}")
        return int(n)
    if isinstance(source, Iterator):
        raise ValueError(
            "total_size='auto' needs a re-readable source (a zero-arg factory or an "
            "iterable), not a one-shot iterator; pass total_size=N explicitly instead."
        )
    warnings.warn(
        "total_size='auto' is doing a full counting pass over the source; for a cheap "
        "path use a source exposing .n_rows (e.g. parquet_source, from Parquet metadata).",
        UserWarning,
        stacklevel=3,
    )
    # Compare shapes tuple-to-tuple even if a list was passed in.
    sample_shape = tuple(sample_shape)
    # Normalize to a true iterator so a factory returning a re-iterable (e.g. a
    # list) is handled the same as one returning a generator.
    first_iter = iter(factory())
    count = 0
    for chunk in first_iter:
        a = np.asarray(chunk)
        if a.shape == sample_shape:
            # A yield of shape exactly `sample_shape` is one sample, not a block.
            count += 1
        elif a.ndim == 0:
            # A scalar has no leading axis to count; fail with a message that
            # names the shape rather than leaking an IndexError from shape[0].
            raise ValueError(
                f"source yielded shape {a.shape}, which is neither one sample of shape "
                f"{sample_shape} nor a block of such samples; check sample_shape."
            )
        else:
            count += int(a.shape[0])
    if count <= 0:
        raise ValueError("total_size='auto' counted 0 rows (empty or non-re-readable source).")
    # A genuine factory yields a fresh, non-empty stream each call; one that
    # returns the same exhausted iterator (or a new generator over consumed
    # state) would leave the loader with nothing to stream. The probe costs one
    # chunk, which the counting pass has already dwarfed.
    second_iter = iter(factory())
    exhausted = object()  # private sentinel: cannot collide with a real chunk
    if second_iter is first_iter or next(second_iter, exhausted) is exhausted:
        raise ValueError(
            "total_size='auto' counted rows but the factory's next stream was empty "
            "(it returns the same one-shot iterator, or closes over an already-"
            "consumed one); pass a factory that creates a fresh iterator each call, "
            "or total_size=N explicitly."
        )
    return count
def parquet_source(
    directory: str,
    *,
    columns: list[str] | None = None,
    pattern: str = "*.parquet",
) -> _ParquetDataset:
    """An :class:`IterableDataset` over a directory of Parquet files.

    Files matching ``pattern`` inside ``directory`` are collected in sorted path
    order. The column order is frozen at construction — ``columns`` if given,
    else the first file's schema order — and handed to the returned dataset
    together with the file list. The selected columns are validated against the
    first file's schema: each must exist and be of integer, floating or boolean
    type. The returned dataset carries an ``n_rows`` attribute summed from each
    file's Parquet metadata (no data scan), so that
    ``DataLoader(parquet_source(dir), ..., total_size="auto")`` can resolve the
    dataset size without a counting pass. Pass ``shuffle=True`` to the
    :class:`DataLoader` (or wrap in :func:`shuffle_buffer`) to get shuffled
    batches.

    Parameters
    ----------
    directory
        Directory holding the Parquet shards. It is matched literally: glob
        metacharacters in the directory name are escaped.
    columns
        Names of the columns to stream, in the order they should appear. ``None``
        selects every column of the first file, in its schema order.
    pattern
        Glob pattern, relative to ``directory``, selecting the shard files.

    Raises
    ------
    ValueError
        If no file matches, if ``columns`` is empty, if a requested column is
        missing from the first file, or if a selected column is not numeric.
    """
    # pyarrow is an optional dependency, so it is imported on use.
    import pyarrow as pa
    import pyarrow.parquet as pq

    # Only `pattern` is meant to be a glob. Escaping the directory keeps names
    # containing "[", "]", "*" or "?" from being interpreted as wildcards.
    search = os.path.join(glob.escape(os.fspath(directory)), pattern)
    paths = sorted(glob.glob(search))
    if not paths:
        raise ValueError(f"no Parquet files match {os.path.join(directory, pattern)!r}")
    schema = pq.read_schema(paths[0])
    if columns is None:
        columns = list(schema.names)
    else:
        # Own a copy so later mutation of the caller's list cannot change the
        # frozen column order.
        columns = list(columns)
        if not columns:
            # An empty selection would only fail later, when stacking zero
            # columns into a batch, with no hint at the cause.
            raise ValueError(
                "columns must name at least one column; pass columns=None to read "
                "every column of the first file"
            )
        missing = sorted(set(columns) - set(schema.names))
        if missing:
            raise ValueError(
                f"columns {missing} not found in {paths[0]!r}; available: {sorted(schema.names)}"
            )
    non_numeric = [
        c
        for c in columns
        if not (
            pa.types.is_integer(schema.field(c).type)
            or pa.types.is_floating(schema.field(c).type)
            or pa.types.is_boolean(schema.field(c).type)
        )
    ]
    if non_numeric:
        # A string/dictionary column would turn whole chunks object-dtype and only
        # fail later at the batch cast, without naming the column.
        raise ValueError(
            f"columns {non_numeric} in {paths[0]!r} are not numeric and cannot be "
            f"streamed into a float batch; select numeric columns with columns=."
        )
    n_rows = sum(pq.read_metadata(p).num_rows for p in paths)
    return _ParquetDataset(paths, columns, n_rows)
def test_raw_array_source_like_vi_rework_sketch():
    """A bare ndarray is accepted as the source of a ``DataLoader``.

    With ``total_size="auto"`` the loader warns about its counting pass and
    resolves 20 rows for a ``(20, 2)`` array; one epoch at ``batch_size=8``
    yields two ``(8, 2)`` batches holding the first 16 rows in order.
    """
    rows = np.arange(40, dtype="float64").reshape(20, 2)

    with pytest.warns(UserWarning, match="counting pass"):
        loader = DataLoader(rows, batch_size=8, sample_shape=(2,), total_size="auto")
    assert loader.total_size == 20

    emitted = [batch for batch in loader]
    shapes = [batch.shape for batch in emitted]
    assert shapes == [(8, 2), (8, 2)]

    # The last 4 rows cannot fill a batch and are not emitted.
    np.testing.assert_array_equal(np.concatenate(emitted), rows[:16])
def test_len_raises_when_total_size_none():
    """``len(loader)`` is a ``TypeError`` when the loader was built without ``total_size``.

    Construction itself emits the ``total_size=None`` warning; asking for the
    length afterwards must raise instead of returning a number.
    """
    block = np.ones((4, 1))

    def five_blocks():
        # Fresh iterator on every call, so the source is re-readable.
        return iter([block] * 5)

    with pytest.warns(UserWarning, match="total_size=None"):
        unsized = DataLoader(five_blocks, batch_size=4, sample_shape=(1,))

    with pytest.raises(TypeError, match="total_size=None"):
        len(unsized)
def test_shuffle_buffer_seed_reproducible_across_runs():
    """Two independently built buffers with the same seed emit the same first epoch."""
    data = np.arange(60, dtype="float64").reshape(60, 1)

    def first_epoch():
        # Build a brand-new shuffle buffer each time; only the seed is shared.
        stream = shuffle_buffer(
            chunked_factory(data, 10), buffer_size=60, batch_size=10, seed=7
        )
        return np.concatenate([batch.ravel() for batch in stream()])

    run_one = first_epoch()
    run_two = first_epoch()
    np.testing.assert_array_equal(run_one, run_two)
def test_raw_2d_array_infers_sample_shape():
    """Omitting ``sample_shape`` for a raw 2-D array still batches whole rows.

    A ``(20, 2)`` array with ``batch_size=8`` and ``total_size="auto"`` resolves
    20 rows (after the counting-pass warning) and yields two ``(8, 2)`` batches
    equal to the first 16 rows, i.e. rows are not flattened into scalars.
    """
    rows = np.arange(40, dtype="float64").reshape(20, 2)

    with pytest.warns(UserWarning, match="counting pass"):
        loader = DataLoader(rows, batch_size=8, total_size="auto")
    assert loader.total_size == 20

    emitted = [batch for batch in loader]
    shapes = [batch.shape for batch in emitted]
    assert shapes == [(8, 2), (8, 2)]
    np.testing.assert_array_equal(np.concatenate(emitted), rows[:16])
def test_shuffle_buffer_forwards_n_rows_for_auto():
    """``shuffle_buffer`` passes a source's ``n_rows`` through to its wrapper.

    The wrapped factory exposes the same ``n_rows``, and a ``DataLoader`` built
    on it with ``total_size="auto"`` resolves that value without emitting any
    ``UserWarning`` (so no counting pass is announced).
    """
    rows = np.arange(40, dtype="float64").reshape(40, 1)
    counted_source = chunked_factory(rows, 8)
    counted_source.n_rows = 40

    shuffled = shuffle_buffer(counted_source, buffer_size=20, batch_size=10, seed=0)
    assert shuffled.n_rows == 40

    # Escalate UserWarning to an error so a counting-pass warning fails the test.
    with warnings.catch_warnings():
        warnings.simplefilter("error", UserWarning)
        loader = DataLoader(shuffled, batch_size=10, sample_shape=(1,), total_size="auto")

    assert loader.total_size == 40
def test_parquet_source_columns_and_shard_order(tmp_path):
    """``columns=`` picks a subset of columns and shards come back in sorted path order.

    Two two-row shards are written with columns ``a``, ``b``, ``c``; selecting
    ``["a", "c"]`` must give one ``(2, 2)`` block per shard, with shard 0 first
    and column ``c`` in the second position.
    """
    pa = pytest.importorskip("pyarrow")
    pq = pytest.importorskip("pyarrow.parquet")

    for i in range(2):
        shard = pa.table(
            {"a": [float(i)] * 2, "b": [9.0] * 2, "c": [float(10 + i)] * 2},
        )
        pq.write_table(shard, f"{tmp_path}/part_{i}.parquet")

    dataset = parquet_source(str(tmp_path), columns=["a", "c"])
    blocks = [block for block in dataset]

    assert [block.shape for block in blocks] == [(2, 2), (2, 2)]
    first_shard, second_shard = blocks[0], blocks[1]
    np.testing.assert_array_equal(first_shard[:, 0], [0.0, 0.0])
    np.testing.assert_array_equal(second_shard[:, 1], [11.0, 11.0])
def test_auto_rejects_factory_closing_over_consumed_iterator():
    """``total_size="auto"`` refuses a generator function that drains a shared iterator.

    Each call to ``gen`` creates a new generator object, but all of them pull
    from the same one-shot iterator, so every call after the first yields
    nothing. Construction must warn about the counting pass and then raise.
    """
    data = np.zeros((20, 1))
    slices = [data[start : start + 4] for start in range(0, 20, 4)]
    underlying = iter(slices)

    def gen():
        yield from underlying

    with pytest.warns(UserWarning, match="counting pass"):
        with pytest.raises(ValueError, match="fresh iterator"):
            DataLoader(gen, batch_size=4, sample_shape=(1,), total_size="auto")
column turned non-numeric is caught at iteration with + that shard's path, not as an opaque float-cast error downstream.""" + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + pq.write_table(pa.table({"a": [1.0, 2.0]}), f"{tmp_path}/p0.parquet") + pq.write_table(pa.table({"a": ["bad", "worse"]}), f"{tmp_path}/p1.parquet") + src = parquet_source(str(tmp_path)) # construction sees only the numeric p0 + with pytest.raises(ValueError, match=r"p1\.parquet.*not numeric"): + list(src) + + +def test_parquet_source_rejects_non_numeric_columns(tmp_path): + """A string column cannot be streamed into a float batch; the default + all-columns freeze rejects it at construction, naming the column and the + columns= remedy, instead of failing later at the batch cast.""" + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + pq.write_table(pa.table({"x": [1.0, 2.0], "id": ["a", "b"]}), f"{tmp_path}/p.parquet") + with pytest.raises(ValueError, match="not numeric"): + parquet_source(str(tmp_path)) + src = parquet_source(str(tmp_path), columns=["x"]) + np.testing.assert_array_equal(next(iter(src)), [[1.0], [2.0]]) + + +def test_parquet_source_names_the_shard_missing_a_column(tmp_path): + """read_row_group silently drops unknown column names, so a later shard + missing a frozen column must raise an error that names that shard.""" + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + pq.write_table(pa.table({"a": [1.0], "b": [2.0]}), f"{tmp_path}/p0.parquet") + pq.write_table(pa.table({"a": [3.0]}), f"{tmp_path}/p1.parquet") + src = parquet_source(str(tmp_path)) + with pytest.raises(ValueError, match="p1.parquet"): + list(src) + + +def test_parquet_source_rejects_unknown_columns(tmp_path): + """A typo in columns= raises a clear ValueError at construction instead of a + pyarrow error at first iteration.""" + pa = pytest.importorskip("pyarrow") + pq = pytest.importorskip("pyarrow.parquet") + 
pq.write_table(pa.table({"a": [1.0], "b": [2.0]}), f"{tmp_path}/p.parquet") + with pytest.raises(ValueError, match="not found"): + parquet_source(str(tmp_path), columns=["a", "nope"]) diff --git a/tests/variational/test_streaming_trainer.py b/tests/variational/test_streaming_trainer.py new file mode 100644 index 0000000000..2e03ba78d5 --- /dev/null +++ b/tests/variational/test_streaming_trainer.py @@ -0,0 +1,280 @@ +# Copyright 2024 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Trainer: drive variational inference over a DataLoader with no user callbacks.""" + +import numpy as np +import pytest + +import pymc as pm + +from pymc.variational.streaming import DataLoader, Trainer +from tests.variational.streaming_helpers import chunked_factory + + +def test_trainer_end_to_end_matches_in_ram_minibatch(): + """End-to-end: Trainer-driven streaming ADVI reproduces in-RAM pm.Minibatch ADVI. + + Exercises the whole API: a pm.Data placeholder, total_size=len(loader), and a + Trainer that streams minibatches into the placeholder with set_data while the + user writes no callbacks. Runs long enough to cycle the loader across epochs. 
+ """ + seed = 0 + rng = np.random.default_rng(seed) + N, bs = 60_000, 2048 + X = rng.normal(size=(N, 2)) + b_true = np.array([0.3, -1.1, 0.7]) + y = (rng.random(N) < 1 / (1 + np.exp(-(b_true[0] + X @ b_true[1:])))).astype("float64") + data = np.column_stack([X, y]) + + with pm.Model(): + b = pm.Normal("b", 0, 3, shape=3) + xb, zb, yb = pm.Minibatch(X[:, 0].copy(), X[:, 1].copy(), y, batch_size=bs) + pm.Bernoulli("o", logit_p=b[0] + b[1] * xb + b[2] * zb, observed=yb, total_size=N) + ap = pm.fit( + 6000, + method="advi", + obj_optimizer=pm.adam(learning_rate=0.02), + progressbar=False, + random_seed=seed, + ) + in_ram = ap.sample(400).posterior["b"].values.reshape(-1, 3).mean(0) + + loader = DataLoader( + chunked_factory(data, 20_000), + batch_size=bs, + shuffle=True, + buffer_size=40_000, + seed=seed, + sample_shape=(3,), + total_size=N, + ) + with pm.Model() as model: + b = pm.Normal("b", 0, 3, shape=3) + batch = pm.Data("batch", np.zeros((bs, 3))) + pm.Bernoulli( + "o", + logit_p=b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1], + observed=batch[:, 2], + total_size=len(loader), + ) + ap = Trainer( + method="advi", + dataloader=loader, + data_name="batch", + obj_optimizer=pm.adam(learning_rate=0.02), + ).fit(6000, random_seed=seed) + stream = ap.sample(400).posterior["b"].values.reshape(-1, 3).mean(0) + + np.testing.assert_allclose(in_ram, stream, atol=0.1) + + +def test_trainer_streams_into_placeholder(): + """The Trainer seeds the pm.Data placeholder before step 0 (pm.fit runs + callbacks after each step) and overwrites it each step; after fitting it holds + a real batch, not the zero seed.""" + data = np.ones((4, 1)) + loader = DataLoader(lambda: iter([data] * 100), batch_size=4, sample_shape=(1,), total_size=4) + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + Trainer(method="advi", dataloader=loader, data_name="batch").fit( + 5, 
progressbar=False, random_seed=0 + ) + np.testing.assert_array_equal(model["batch"].get_value(), data) + + +def test_trainer_raises_when_loader_cannot_restart(): + """A source that streams one epoch and then comes back empty cannot be cycled; + the Trainer surfaces a clear error instead of training on stale data.""" + calls = {"n": 0} + + def factory(): + calls["n"] += 1 + if calls["n"] == 1: + yield np.zeros((4, 1)) + + loader = DataLoader(factory, batch_size=4, sample_shape=(1,), total_size=4) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + with pytest.raises(RuntimeError, match="yielded no batches"): + Trainer(method="advi", dataloader=loader, data_name="batch").fit( + 5, progressbar=False, random_seed=0 + ) + + +def test_trainer_rejects_non_dataloader(): + """The isinstance guard fires before any model lookup.""" + with pytest.raises(TypeError, match="DataLoader"): + Trainer(method="advi", dataloader=object()).fit(10) + + +def test_trainer_appends_user_callbacks_and_streams_distinct_batches(): + """User callbacks (e.g. convergence trackers) compose with the internal + advance callback instead of colliding on the keyword, and the placeholder + holds a different batch on successive steps. 
Also exercises the default + data_name ("batch").""" + blocks = [np.full((4, 1), float(i)) for i in range(60)] + loader = DataLoader(lambda: iter(blocks), batch_size=4, sample_shape=(1,), total_size=240) + seen = [] + with pm.Model() as model: + x = pm.Normal("x", 0.0, 1.0) + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", x, 1.0, observed=batch[:, 0], total_size=len(loader)) + Trainer(method="advi", dataloader=loader).fit( + 5, callbacks=[lambda *_: seen.append(float(model["batch"].get_value()[0, 0]))] + ) + assert len(seen) == 5 + assert len(set(seen)) > 1 + + +def test_trainer_accepts_inference_instance(): + """An Inference instance is forwarded to pm.fit unchanged; it is bound to + the model it was built under, so the Trainer only streams the batches.""" + data = np.ones((4, 1)) + loader = DataLoader(lambda: iter([data] * 50), batch_size=4, sample_shape=(1,), total_size=4) + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + approx = Trainer(method=pm.ADVI(random_seed=0), dataloader=loader).fit(5) + assert len(approx.hist) == 5 + np.testing.assert_array_equal(model["batch"].get_value(), data) + + +def test_constructor_fit_kwargs_take_random_seed(): + """random_seed works as a constructor default, as the docstring promises, + and a per-call value overrides the constructor's.""" + data = np.ones((4, 1)) + + def fit_with(ctor_kwargs, fit_kwargs): + loader = DataLoader( + lambda: iter([data] * 50), batch_size=4, sample_shape=(1,), total_size=4 + ) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + return Trainer(method="advi", dataloader=loader, data_name="batch", **ctor_kwargs).fit( + 5, **fit_kwargs + ) + + a = fit_with({"random_seed": 7}, {}) + b = fit_with({"random_seed": 0}, {"random_seed": 7}) + 
np.testing.assert_array_equal(a.hist, b.hist) + + +def test_fit_consumes_exactly_n_batches(): + """fit(n) consumes exactly n minibatches: one seeds the placeholder before + step 0 and the advance after the final step is skipped, so an (n+1)-th + batch is never fetched.""" + blocks = [np.full((2, 1), float(i)) for i in range(2)] + loader = DataLoader(lambda: iter(blocks), batch_size=2, sample_shape=(1,), total_size=4) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((2, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + Trainer(method="advi", dataloader=loader).fit(3, random_seed=0) + assert loader.batches_seen == 3 + assert loader.rows_streamed == 6 + + +def test_fit_one_step_on_single_batch_one_shot_source(): + """A finite stream with exactly the batches needed must not be over-consumed: + fit(1) on a one-batch, one-shot source trains and returns instead of failing + on a post-final restart.""" + loader = DataLoader(iter([np.ones((2, 1))]), batch_size=2, sample_shape=(1,), total_size=2) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((2, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + approx = Trainer(method="advi", dataloader=loader).fit(1, random_seed=0) + assert len(approx.hist) == 1 + assert loader.batches_seen == 1 + + +def test_refine_after_fit_resumes_the_stream(): + """Inference.refine replays pm.fit's saved callbacks. Because the advance + skips only fit's own final step (and not every step past n), refine resumes + advancing the stream instead of going permanently dead on the last batch. + + refine does not re-seed, so its first step still trains on the batch fit left + in the placeholder; this pins that resume-not-reseed behavior with distinct + batch markers rather than claiming every refine step is fresh. 
+ """ + blocks = [np.full((4, 1), float(i)) for i in range(50)] + loader = DataLoader(lambda: iter(blocks), batch_size=4, sample_shape=(1,), total_size=4) + sets = [] + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((4, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + original = model.set_data + model.set_data = lambda name, values, *a, **k: ( # type: ignore[method-assign] + sets.append(float(np.asarray(values)[0, 0])), + original(name, values, *a, **k), + )[1] + inference = pm.ADVI(random_seed=0) + Trainer(method=inference, dataloader=loader).fit(3) + assert sets == [0.0, 1.0, 2.0] # fit seeds 0, advances to 1 and 2, skips its last + sets.clear() + inference.refine(4, progressbar=False) + # refine resumes from where the stream stopped (3, 4, 5, ...), not stuck on 2 + assert sets == [3.0, 4.0, 5.0, 6.0] + assert loader.batches_seen == 7 + + +def test_total_size_check_fires_when_fit_ends_at_pass_boundary(): + """fit(n) with n exactly the batches in one pass still runs the total_size + sanity check: the stream is kept one batch ahead, so stopping at the + boundary does not abandon the check right before it would fire.""" + data = np.zeros((40, 1)) + loader = DataLoader(chunked_factory(data, 10), batch_size=10, sample_shape=(1,), total_size=400) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", np.zeros((10, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + with pytest.warns(UserWarning, match="disagrees with"): + Trainer(method="advi", dataloader=loader).fit(4, random_seed=0) + + +def test_fit_rejects_nonpositive_n(): + """fit consumes the seed batch before pm.fit could reject n itself, so a + non-positive n is refused up front, before touching the stream.""" + loader = DataLoader( + lambda: iter([np.zeros((2, 1))]), batch_size=2, sample_shape=(1,), total_size=2 + ) + with pm.Model(): + mu = pm.Normal("mu", 0, 1) + batch = pm.Data("batch", 
np.zeros((2, 1))) + pm.Normal("y", mu, 1, observed=batch[:, 0], total_size=len(loader)) + with pytest.raises(ValueError, match="positive integer"): + Trainer(method="advi", dataloader=loader).fit(0) + assert loader.batches_seen == 0 + + +def test_unknown_data_name_raises_before_consuming(): + """A data_name that is not in the model raises a guided KeyError before any + batch is pulled from the loader.""" + loader = DataLoader( + lambda: iter([np.zeros((4, 1))] * 3), batch_size=4, sample_shape=(1,), total_size=4 + ) + with pm.Model(): + pm.Normal("mu", 0, 1) + with pytest.raises(KeyError, match="pm.Data placeholder"): + Trainer(method="advi", dataloader=loader, data_name="nope").fit(2) + assert loader.batches_seen == 0 + assert loader.rows_streamed == 0