Skip to content

Streaming variational inference: Trainer for minibatch ADVI#8333

Draft
YichengYang-Ethan wants to merge 28 commits into
pymc-devs:mainfrom
YichengYang-Ethan:streaming-trainer
Draft

Streaming variational inference: Trainer for minibatch ADVI#8333
YichengYang-Ethan wants to merge 28 commits into
pymc-devs:mainfrom
YichengYang-Ethan:streaming-trainer

Conversation

@YichengYang-Ethan

Copy link
Copy Markdown

Follow-up to #8325.

pm.Minibatch random-indexes a fully-resident array (peak memory O(N)).
StreamingDataset feeds minibatches from an arbitrary source into a small
pytensor.shared buffer (peak memory O(batch_size)), reusing the existing
total_size / create_minibatch_rv rescaling unchanged. Adds a shuffle_buffer
helper and an equivalence test (streaming ADVI == in-RAM pm.Minibatch ADVI).
Close three silent-corruption holes found in a 5-lens review:
- reject total_size <= 0 in __init__: 0 is falsy and skips the N/batch_size
  rescaling entirely (posterior collapses to prior); negative flips the data
  log-likelihood's sign via get_scaling.
- shuffle_buffer now accumulates max(buffer_size, batch_size) rows before
  emitting, so buffer_size < batch_size no longer silently discards the whole
  stream; also validate buffer_size/batch_size as positive ints.
- positive-int checks use numbers.Integral (accept numpy ints, reject bool).

+5 regression tests; existing 10 unchanged and passing.
A seeded shuffle_buffer rebuilt its RNG from the same seed on every factory
call, so under cycle=True every epoch replayed one fixed permutation -- which
weakens the very mixing the buffer exists to provide and compounds the
block-shuffle bias on ordered data. Derive a fresh per-epoch sub-stream from a
SeedSequence so the order differs across epochs while staying reproducible for
a given seed. +2 tests.
Cuts the "user must pass total_size" burden (open question pymc-devs#1 for the design review):

- total_size="auto" resolves N from a source's .n_rows (cheap -- e.g. Parquet
  footer metadata via the new parquet_source) else one counting pass over a
  finite, re-readable source. One-shot / infinite sources still pass total_size
  explicitly (and are rejected with a clear error under "auto").
- a free sanity check using the existing rows_streamed counter: at the first
  epoch boundary, warn if total_size grossly disagrees with the rows actually
  streamed in one pass (catches a wrong-but-positive total_size).
- parquet_source(directory): a finite, re-readable source carrying .n_rows read
  from Parquet metadata (no data scan).

+7 tests; the existing 17 are unchanged and still pass.
…tion

An adversarial re-review surfaced edge cases the first hardening pass missed:

- total_size / batch_size: numpy integers were accepted but stored unchanged,
  so a stored np.int64 reached create_minibatch_rv and raised "Invalid type
  for total_size". Normalize to Python int at construction.
- _make_factory: a zero-arg factory returning a non-iterator iterable (e.g. a
  list) crashed in __next__ ("'list' object is not an iterator"); wrap in iter().
- total_size="auto": a factory that returns the same one-shot iterator each call
  now raises, instead of leaving the first advance() empty.
- fit_callback: seeds the buffer by default. PyMC runs callbacks after each
  step, so an unseeded first step trained on the zero-initialized placeholder.
- _validate: a 0-D batch now raises a clear ValueError instead of IndexError.

Adds 7 regression tests (31 total).
…ze="auto"

shuffle_buffer now propagates a known .n_rows (e.g. parquet_source's, read from
Parquet metadata) to its wrapped factory, so the common composition

    StreamingDataset(shuffle_buffer(parquet_source(dir)), total_size="auto")

resolves N for free instead of doing a full counting pass over the data. The
only discrepancy is the single dropped trailing partial batch (< batch_size
rows), which is within the auto-size sanity tolerance.

Adds 2 regression tests (33 total).
…iner

Design-review feedback from Rob (mentor): the streaming API should mirror
torch.utils.data so the mental model transfers, and the user-facing callback
should go away in favour of a Lightning-style trainer.

- IterableDataset: re-iterable out-of-core source base (parquet_source now
  returns one); carries an optional .n_rows for total_size="auto".
- DataLoader: the former StreamingDataset, renamed; gains PyTorch-style
  shuffle=/buffer_size=/seed= (wraps shuffle_buffer internally). Still owns the
  fixed pytensor.shared buffer the model observes; advance()/as_tensor() kept.
- Trainer: Trainer(method="advi").fit(model, loader, n) drives VI with NO
  user-facing callbacks -- it seeds the buffer and advances it each step
  internally. The per-step advance is wired into pm.fit privately.

All hardening preserved (int normalization, total_size guards + "auto",
shuffle row-conservation + per-epoch reshuffle, copy-before-borrow, validation).
shuffle_buffer/parquet_source stay public. 36 tests pass (1 skipped: pyarrow).

total_size still appears in the model (total_size=loader.total_size); removing
it is an open design question for Rob -- see notes. It is compiled into the logp
graph at register_rv time (MinibatchRandomVariable Op), so fit-time injection
needs either Trainer graph surgery or a dims-based rule in core.
Follows jessegrabowski/pymc VI_Overview.ipynb (the VI rework Rob/Jesse are
building) instead of my ad-hoc shapes:

- DataLoader.__len__ == total_size N (sized like a PyTorch DataLoader), and
  __iter__ yields the validated minibatch stream. This is the answer to Rob's
  open question: total_size leaves the model and becomes len(loader).
- Trainer takes (method=, dataloader=, model=, data_name=) and fit(n); it streams
  each minibatch into the model's pm.Data placeholder via model.set_data, so the
  model is fully decoupled from the loader and the user writes no callbacks.
- Model idiom is now pm.Data("batch", placeholder) + total_size=len(loader),
  matching the blueprint; verified end-to-end (recovers in-RAM pm.Minibatch ADVI).
- Kept the as_tensor()/advance() shared-buffer path as a documented advanced
  escape hatch; dropped the now-unused _seed_buffer/_advance_callback.

38 tests pass (1 skipped: pyarrow). Open for Rob: spelling DataLoader (PyTorch,
per his "match PyTorch") vs Dataloader (Jesse's draft); method-as-string until
the ADVI(Inference).step rework lands.
- Trainer's stream now updates batches_seen/rows_streamed and runs the
  one-shot total_size sanity check at each epoch boundary (previously dead
  on the Trainer path; __iter__ stays side-effect-free).
- total_size="auto" with shuffle=True counts the unshuffled source, fixing
  an undercount of up to batch_size-1 rows.
- Trainer default data_name "data" -> "batch" to match the examples/tests.
- Clarify len(loader)==N (rows, not batches) in docstrings; raise a clear
  error when a cycled source restarts empty.
- Register the streaming API in docs/source/api/vi.rst.
- Add regression tests for the auto-size shuffle count and Trainer counters.
The non-shuffle path previously required the source to yield exact
batch_size blocks and raised on anything else, while the docstrings
promised re-batching. Now both paths re-batch: blocks of any size are
sliced in order with remainders carried across blocks, and a raw array
(or any single-sample stream) is accepted directly, so the VI-rework
sketch usage Dataloader(<array>, batch_size=...) works as written.
Trailing rows that do not fill a final batch are dropped, like
drop_last=True in torch, since the model observes a fixed-shape
placeholder.

Also: total_size="auto" counts a single-sample stream as rows rather
than flattened elements; Trainer.fit(callbacks=...) appends user
callbacks after the internal advance instead of raising a duplicate
keyword error.
- Drop the shared-buffer path (as_tensor/advance and the cycle/name
  parameters): neither exists in torch.utils.data and the Trainer never
  used it. Manual stepping stays available through plain iteration plus
  set_data.
- Move modelcontext/fit imports to module level.
- Replace test comments with docstrings, drop redundant comments and
  section banners, and rename the reshuffle test descriptively.
shuffle_buffer concatenates yields along the leading axis, so a raw
array source under shuffle=True had its rows flattened (2-D) or crashed
on shape[0] (scalars). Promote single samples to one-row blocks before
the shuffle wrap, with the same helper the re-batcher uses.

Also tighten a few docstring claims: the parquet dtype follows the file
columns, and the shuffle buffer bound is stated as rows held.
DataLoader infers sample_shape from a raw array source, so
DataLoader(arr, batch_size=...) batches rows instead of silently
flattening them to scalars. The total_size check no longer warns on an
exact N when drop-last truncates the final batch, and its advice covers
a wrong source n_rows. Trainer.fit routes all kwargs through one merge
so constructor defaults like random_seed work as documented, accepts an
Inference instance, and rejects an unknown data_name before consuming a
batch. parquet_source validates columns against the schema up front.
The shuffle_buffer docstring states the true buffer bound.
- Trainer.fit(n) consumes exactly n batches: the advance after the final
  step is skipped, so a finite source is not over-consumed
- the total_size sanity check counts the pass that completed instead of
  the cumulative row counter, which inflated across partial streams
- parquet_source freezes the column order at construction and reads one
  row group at a time, so a permuted shard schema cannot silently swap
  features and peak read memory is a row group, not a file
- warn at construction when a fixed-order loader would drop the same
  non-divisible tail every pass
- total_size='auto' probes that the factory can actually be re-read,
  catching factories that close over an already-consumed iterator
- document the shuffle-buffer transient concatenation copy and the
  full-buffer case
…diagnostics

- the internal advance skips only fit's own final step, so
  Inference.refine on a method instance keeps streaming instead of
  silently retraining on the last batch
- keep the rebatcher one batch ahead in the accounting stream, so the
  total_size sanity check still fires when fit(n) stops exactly at the
  pass boundary
- drop the fixed-order divisibility warning: it false-alarmed on the
  module's own pre-shuffled-on-disk example and on manual shuffle_buffer
  wrapping; the drop-last caveat lives in the docs instead
- validate n in Trainer.fit (fit(0) consumed the seed batch; fit(-1)
  failed deep inside PyTensor)
- normalize shuffle_buffer's factory output with iter(), which a
  re-iterable-returning factory would otherwise restart every fill
- parquet_source rejects non-numeric columns at construction and names
  the shard when a later file is missing a frozen column
- name the sample_shape remedy in the block-shape error; spell behavior
  consistently
The class summary still claimed the full dataset never enters memory in
the absolute; match the module docstring's bounded-source-chunks framing
and fix a double space.
- _ParquetDataset checks each shard's column types, so a later shard
  whose column turned non-numeric is named instead of failing as an
  opaque float cast downstream (parquet_source only saw the first shard)
- the fit docstring no longer says 'exactly n consumed'; it feeds exactly
  n batches to the model, but the one-batch lookahead can read a
  re-readable source one batch further
- the refine test now uses distinct batch markers and pins the honest
  resume-not-reseed behavior (its first step reuses fit's last batch)
  instead of only checking counters on all-ones data
Keep this PR to the dataset/loader layer (IterableDataset, DataLoader,
shuffle_buffer, parquet_source); the Trainer and its tests move to a stacked
follow-up PR. Fold the re-readable chunked-source factory the loader tests
share into tests/variational/streaming_helpers.py, which doubles as a place to
explain why a re-readable factory (not a one-shot generator) is needed.
Drive variational inference over a DataLoader with no user-facing callbacks:
Trainer(method=..., dataloader=...).fit(n) streams each minibatch into the
model's pm.Data placeholder once per step. Re-adds the Trainer class, its docs
entry, and the tests the DataLoader PR split out; the tests reuse the shared
chunked-source helper, and the CI subset gains test_streaming_trainer.py.
@welcome

welcome Bot commented Jun 16, 2026

Copy link
Copy Markdown

Thank You Banner]
💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant