Streaming variational inference: Trainer for minibatch ADVI#8333
Draft
YichengYang-Ethan wants to merge 28 commits into
Draft
Streaming variational inference: Trainer for minibatch ADVI#8333YichengYang-Ethan wants to merge 28 commits into
YichengYang-Ethan wants to merge 28 commits into
Conversation
pm.Minibatch random-indexes a fully-resident array (peak memory O(N)). StreamingDataset feeds minibatches from an arbitrary source into a small pytensor.shared buffer (peak memory O(batch_size)), reusing the existing total_size / create_minibatch_rv rescaling unchanged. Adds a shuffle_buffer helper and an equivalence test (streaming ADVI == in-RAM pm.Minibatch ADVI).
Close three silent-corruption holes found in a 5-lens review: - reject total_size <= 0 in __init__: 0 is falsy and skips the N/batch_size rescaling entirely (posterior collapses to prior); negative flips the data log-likelihood's sign via get_scaling. - shuffle_buffer now accumulates max(buffer_size, batch_size) rows before emitting, so buffer_size < batch_size no longer silently discards the whole stream; also validate buffer_size/batch_size as positive ints. - positive-int checks use numbers.Integral (accept numpy ints, reject bool). +5 regression tests; existing 10 unchanged and passing.
A seeded shuffle_buffer rebuilt its RNG from the same seed on every factory call, so under cycle=True every epoch replayed one fixed permutation -- which weakens the very mixing the buffer exists to provide and compounds the block-shuffle bias on ordered data. Derive a fresh per-epoch sub-stream from a SeedSequence so the order differs across epochs while staying reproducible for a given seed. +2 tests.
Cuts the "user must pass total_size" burden (open question pymc-devs#1 for the design review): - total_size="auto" resolves N from a source's .n_rows (cheap -- e.g. Parquet footer metadata via the new parquet_source) else one counting pass over a finite, re-readable source. One-shot / infinite sources still pass total_size explicitly (and are rejected with a clear error under "auto"). - a free sanity check using the existing rows_streamed counter: at the first epoch boundary, warn if total_size grossly disagrees with the rows actually streamed in one pass (catches a wrong-but-positive total_size). - parquet_source(directory): a finite, re-readable source carrying .n_rows read from Parquet metadata (no data scan). +7 tests; the existing 17 are unchanged and still pass.
…tion
An adversarial re-review surfaced edge cases the first hardening pass missed:
- total_size / batch_size: numpy integers were accepted but stored unchanged,
so a stored np.int64 reached create_minibatch_rv and raised "Invalid type
for total_size". Normalize to Python int at construction.
- _make_factory: a zero-arg factory returning a non-iterator iterable (e.g. a
list) crashed in __next__ ("'list' object is not an iterator"); wrap in iter().
- total_size="auto": a factory that returns the same one-shot iterator each call
now raises, instead of leaving the first advance() empty.
- fit_callback: seeds the buffer by default. PyMC runs callbacks after each
step, so an unseeded first step trained on the zero-initialized placeholder.
- _validate: a 0-D batch now raises a clear ValueError instead of IndexError.
Adds 7 regression tests (31 total).
…ze="auto"
shuffle_buffer now propagates a known .n_rows (e.g. parquet_source's, read from
Parquet metadata) to its wrapped factory, so the common composition
StreamingDataset(shuffle_buffer(parquet_source(dir)), total_size="auto")
resolves N for free instead of doing a full counting pass over the data. The
only discrepancy is the single dropped trailing partial batch (< batch_size
rows), which is within the auto-size sanity tolerance.
Adds 2 regression tests (33 total).
…iner Design-review feedback from Rob (mentor): the streaming API should mirror torch.utils.data so the mental model transfers, and the user-facing callback should go away in favour of a Lightning-style trainer. - IterableDataset: re-iterable out-of-core source base (parquet_source now returns one); carries an optional .n_rows for total_size="auto". - DataLoader: the former StreamingDataset, renamed; gains PyTorch-style shuffle=/buffer_size=/seed= (wraps shuffle_buffer internally). Still owns the fixed pytensor.shared buffer the model observes; advance()/as_tensor() kept. - Trainer: Trainer(method="advi").fit(model, loader, n) drives VI with NO user-facing callbacks -- it seeds the buffer and advances it each step internally. The per-step advance is wired into pm.fit privately. All hardening preserved (int normalization, total_size guards + "auto", shuffle row-conservation + per-epoch reshuffle, copy-before-borrow, validation). shuffle_buffer/parquet_source stay public. 36 tests pass (1 skipped: pyarrow). total_size still appears in the model (total_size=loader.total_size); removing it is an open design question for Rob -- see notes. It is compiled into the logp graph at register_rv time (MinibatchRandomVariable Op), so fit-time injection needs either Trainer graph surgery or a dims-based rule in core.
Follows jessegrabowski/pymc VI_Overview.ipynb (the VI rework Rob/Jesse are
building) instead of my ad-hoc shapes:
- DataLoader.__len__ == total_size N (sized like a PyTorch DataLoader), and
__iter__ yields the validated minibatch stream. This is the answer to Rob's
open question: total_size leaves the model and becomes len(loader).
- Trainer takes (method=, dataloader=, model=, data_name=) and fit(n); it streams
each minibatch into the model's pm.Data placeholder via model.set_data, so the
model is fully decoupled from the loader and the user writes no callbacks.
- Model idiom is now pm.Data("batch", placeholder) + total_size=len(loader),
matching the blueprint; verified end-to-end (recovers in-RAM pm.Minibatch ADVI).
- Kept the as_tensor()/advance() shared-buffer path as a documented advanced
escape hatch; dropped the now-unused _seed_buffer/_advance_callback.
38 tests pass (1 skipped: pyarrow). Open for Rob: spelling DataLoader (PyTorch,
per his "match PyTorch") vs Dataloader (Jesse's draft); method-as-string until
the ADVI(Inference).step rework lands.
- Trainer's stream now updates batches_seen/rows_streamed and runs the one-shot total_size sanity check at each epoch boundary (previously dead on the Trainer path; __iter__ stays side-effect-free). - total_size="auto" with shuffle=True counts the unshuffled source, fixing an undercount of up to batch_size-1 rows. - Trainer default data_name "data" -> "batch" to match the examples/tests. - Clarify len(loader)==N (rows, not batches) in docstrings; raise a clear error when a cycled source restarts empty. - Register the streaming API in docs/source/api/vi.rst. - Add regression tests for the auto-size shuffle count and Trainer counters.
The non-shuffle path previously required the source to yield exact batch_size blocks and raised on anything else, while the docstrings promised re-batching. Now both paths re-batch: blocks of any size are sliced in order with remainders carried across blocks, and a raw array (or any single-sample stream) is accepted directly, so the VI-rework sketch usage Dataloader(<array>, batch_size=...) works as written. Trailing rows that do not fill a final batch are dropped, like drop_last=True in torch, since the model observes a fixed-shape placeholder. Also: total_size="auto" counts a single-sample stream as rows rather than flattened elements; Trainer.fit(callbacks=...) appends user callbacks after the internal advance instead of raising a duplicate keyword error.
- Drop the shared-buffer path (as_tensor/advance and the cycle/name parameters): neither exists in torch.utils.data and the Trainer never used it. Manual stepping stays available through plain iteration plus set_data. - Move modelcontext/fit imports to module level. - Replace test comments with docstrings, drop redundant comments and section banners, and rename the reshuffle test descriptively.
shuffle_buffer concatenates yields along the leading axis, so a raw array source under shuffle=True had its rows flattened (2-D) or crashed on shape[0] (scalars). Promote single samples to one-row blocks before the shuffle wrap, with the same helper the re-batcher uses. Also tighten a few docstring claims: the parquet dtype follows the file columns, and the shuffle buffer bound is stated as rows held.
DataLoader infers sample_shape from a raw array source, so DataLoader(arr, batch_size=...) batches rows instead of silently flattening them to scalars. The total_size check no longer warns on an exact N when drop-last truncates the final batch, and its advice covers a wrong source n_rows. Trainer.fit routes all kwargs through one merge so constructor defaults like random_seed work as documented, accepts an Inference instance, and rejects an unknown data_name before consuming a batch. parquet_source validates columns against the schema up front. The shuffle_buffer docstring states the true buffer bound.
- Trainer.fit(n) consumes exactly n batches: the advance after the final step is skipped, so a finite source is not over-consumed - the total_size sanity check counts the pass that completed instead of the cumulative row counter, which inflated across partial streams - parquet_source freezes the column order at construction and reads one row group at a time, so a permuted shard schema cannot silently swap features and peak read memory is a row group, not a file - warn at construction when a fixed-order loader would drop the same non-divisible tail every pass - total_size='auto' probes that the factory can actually be re-read, catching factories that close over an already-consumed iterator - document the shuffle-buffer transient concatenation copy and the full-buffer case
…diagnostics - the internal advance skips only fit's own final step, so Inference.refine on a method instance keeps streaming instead of silently retraining on the last batch - keep the rebatcher one batch ahead in the accounting stream, so the total_size sanity check still fires when fit(n) stops exactly at the pass boundary - drop the fixed-order divisibility warning: it false-alarmed on the module's own pre-shuffled-on-disk example and on manual shuffle_buffer wrapping; the drop-last caveat lives in the docs instead - validate n in Trainer.fit (fit(0) consumed the seed batch; fit(-1) failed deep inside PyTensor) - normalize shuffle_buffer's factory output with iter(), which a re-iterable-returning factory would otherwise restart every fill - parquet_source rejects non-numeric columns at construction and names the shard when a later file is missing a frozen column - name the sample_shape remedy in the block-shape error; spell behavior consistently
The class summary still claimed the full dataset never enters memory in the absolute; match the module docstring's bounded-source-chunks framing and fix a double space.
- _ParquetDataset checks each shard's column types, so a later shard whose column turned non-numeric is named instead of failing as an opaque float cast downstream (parquet_source only saw the first shard) - the fit docstring no longer says 'exactly n consumed'; it feeds exactly n batches to the model, but the one-batch lookahead can read a re-readable source one batch further - the refine test now uses distinct batch markers and pins the honest resume-not-reseed behavior (its first step reuses fit's last batch) instead of only checking counters on all-ones data
Keep this PR to the dataset/loader layer (IterableDataset, DataLoader, shuffle_buffer, parquet_source); the Trainer and its tests move to a stacked follow-up PR. Fold the re-readable chunked-source factory the loader tests share into tests/variational/streaming_helpers.py, which doubles as a place to explain why a re-readable factory (not a one-shot generator) is needed.
Drive variational inference over a DataLoader with no user-facing callbacks: Trainer(method=..., dataloader=...).fit(n) streams each minibatch into the model's pm.Data placeholder once per step. Re-adds the Trainer class, its docs entry, and the tests the DataLoader PR split out; the tests reuse the shared chunked-source helper, and the CI subset gains test_streaming_trainer.py.
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.

Follow-up to #8325.