Draft
28 commits
30d0ce2
feat(variational): StreamingDataset for out-of-core minibatch VI
YichengYang-Ethan Jun 5, 2026
cc3658d
harden StreamingDataset validation (deep-review fixes)
YichengYang-Ethan Jun 5, 2026
2c47255
strengthen shuffle_buffer: reshuffle each epoch
YichengYang-Ethan Jun 5, 2026
90b5b83
prototype: total_size="auto" + rows_streamed sanity warning
YichengYang-Ethan Jun 6, 2026
a75e7cf
fix(streaming): normalize int sizes; harden factory, callback, valida…
YichengYang-Ethan Jun 6, 2026
7b9fe85
feat(streaming): forward shuffle_buffer's source .n_rows for total_si…
YichengYang-Ethan Jun 6, 2026
20d7870
refactor(streaming): adopt PyTorch-style Dataset/DataLoader + add Tra…
YichengYang-Ethan Jun 6, 2026
77229a8
refactor(streaming): align Trainer/DataLoader with VI-rework blueprint
YichengYang-Ethan Jun 7, 2026
03f3a76
Refine streaming VI after mentor design review
YichengYang-Ethan Jun 9, 2026
36f0955
Register streaming tests in CI matrix (check_all_tests_are_covered)
YichengYang-Ethan Jun 9, 2026
9b8a914
Re-batch arbitrary block sizes in the plain DataLoader path
YichengYang-Ethan Jun 10, 2026
32ec88f
Address review comments
YichengYang-Ethan Jun 10, 2026
c59cbb1
Fix mypy error comparing Integral with int
YichengYang-Ethan Jun 10, 2026
c5a029e
Match docstring punctuation to the rest of the codebase
YichengYang-Ethan Jun 10, 2026
6f9eed2
Follow numpydoc section conventions
YichengYang-Ethan Jun 10, 2026
45cb513
Declare __all__ like the neighboring modules
YichengYang-Ethan Jun 10, 2026
14568b1
Add docstrings to the last three tests
YichengYang-Ethan Jun 11, 2026
51ee22c
Plainer docstring wording
YichengYang-Ethan Jun 11, 2026
54f13e4
State the memory bound precisely
YichengYang-Ethan Jun 11, 2026
fc5eb11
Promote single samples before the shuffle buffer
YichengYang-Ethan Jun 11, 2026
5769fbf
Apply formatter
YichengYang-Ethan Jun 11, 2026
13a6a05
Fix loader and trainer edge cases; tighten docstring claims
YichengYang-Ethan Jun 11, 2026
4bbb470
Fix batch accounting and parquet edge cases from review
YichengYang-Ethan Jun 11, 2026
4b97774
Address fourth review round: refine support, boundary check, parquet …
YichengYang-Ethan Jun 12, 2026
ab178a0
Align the DataLoader summary with the module's bounded-chunks wording
YichengYang-Ethan Jun 13, 2026
cb85e8c
Tighten parquet type checks, fit/refine wording from review
YichengYang-Ethan Jun 13, 2026
796306d
Split the Trainer into a follow-up PR; share the loader test helper
YichengYang-Ethan Jun 16, 2026
d831969
Add the streaming Trainer (stacked on the DataLoader PR)
YichengYang-Ethan Jun 16, 2026
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -190,7 +190,7 @@ jobs:
linker: [cvm, numba]
python-version: ["3.12"]
test-subset:
- tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/test_initial_point.py
- tests/variational/test_approximations.py tests/variational/test_callbacks.py tests/variational/test_inference.py tests/variational/test_opvi.py tests/variational/test_streaming.py tests/variational/test_streaming_autosize.py tests/variational/test_streaming_trainer.py tests/test_initial_point.py
- tests/model/test_core.py tests/sampling/test_mcmc.py
- tests/gp/test_cov.py tests/gp/test_gp.py tests/gp/test_mean.py tests/gp/test_util.py tests/ode/test_ode.py tests/ode/test_utils.py tests/smc/test_smc.py tests/sampling/test_parallel.py
- tests/step_methods/test_metropolis.py tests/step_methods/test_slicer.py tests/step_methods/hmc/test_nuts.py tests/step_methods/test_compound.py tests/step_methods/hmc/test_hmc.py tests/step_methods/test_state.py
15 changes: 15 additions & 0 deletions docs/source/api/vi.rst
@@ -68,6 +68,21 @@ Special

Stein

Streaming
---------
Out-of-core minibatching for variational inference on datasets that do not fit in
memory (see :mod:`pymc.variational.streaming`).

.. currentmodule:: pymc.variational
.. autosummary::
   :toctree: generated/

   DataLoader
   IterableDataset
   Trainer
   shuffle_buffer
   parquet_source

.. currentmodule:: pymc
.. autosummary::
   :toctree: generated/
12 changes: 12 additions & 0 deletions pymc/variational/__init__.py
@@ -44,6 +44,13 @@

# special
from pymc.variational.stein import Stein
from pymc.variational.streaming import (
    DataLoader,
    IterableDataset,
    Trainer,
    parquet_source,
    shuffle_buffer,
)
from pymc.variational.updates import (
    adadelta,
    adagrad,
@@ -64,11 +71,14 @@
    "ADVI",
    "ASVGD",
    "SVGD",
    "DataLoader",
    "Empirical",
    "FullRank",
    "FullRankADVI",
    "Group",
    "IterableDataset",
    "MeanField",
    "Trainer",
    "adadelta",
    "adagrad",
    "adagrad_window",
Expand All @@ -80,8 +90,10 @@
"momentum",
"nesterov_momentum",
"norm_constraint",
"parquet_source",
"rmsprop",
"sample_approx",
"sgd",
"shuffle_buffer",
"total_norm_constraint",
)
889 changes: 889 additions & 0 deletions pymc/variational/streaming.py

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions tests/variational/streaming_helpers.py
@@ -0,0 +1,34 @@
# Copyright 2024 - present The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared helpers for the streaming-dataset tests."""


def chunked_factory(data, size):
    """Return a zero-arg factory that replays ``data`` in ``size``-row chunks.

    A ``DataLoader`` restarts its source once per epoch, so the source has to be
    re-readable. This returns a *factory* (a zero-arg callable) that produces a
    fresh generator each call, the way an out-of-core source like
    ``parquet_source`` does; a bare generator would be one-shot and could not be
    replayed. The final chunk may hold fewer than ``size`` rows -- the loader
    re-batches the stream to ``batch_size`` regardless -- so this also exercises
    the loader's re-batching across uneven source blocks.
    """

    def factory():
        for i in range(0, len(data), size):
            yield data[i : i + size]

    return factory
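
The docstring's distinction between a replayable factory and a one-shot generator can be sketched in isolation. The snippet below is an illustration only, not code from this PR: `chunked_factory` repeats the helper above, while `rebatch` is a hypothetical stand-in for the loader's re-batching step. The real `DataLoader` in `pymc/variational/streaming.py` is not rendered in this diff and may behave differently, for instance in how it treats a trailing partial batch. Plain lists stand in for arrays so the example needs no third-party imports.

```python
def chunked_factory(data, size):
    """Replay ``data`` in ``size``-row chunks (same logic as the test helper)."""

    def factory():
        for i in range(0, len(data), size):
            yield data[i : i + size]

    return factory


def rebatch(chunks, batch_size):
    """Hypothetical re-batcher: regroup uneven chunks into fixed-size batches.

    This is an assumption made for illustration, not PyMC's implementation.
    """
    pending = []
    for chunk in chunks:
        pending.extend(chunk)
        # Emit as many full batches as the accumulated rows allow.
        while len(pending) >= batch_size:
            yield pending[:batch_size]
            pending = pending[batch_size:]
    if pending:
        # Trailing partial batch; a real loader might drop or pad it instead.
        yield pending


data = list(range(10))
factory = chunked_factory(data, 4)

# Each call to the factory starts a fresh pass, so two "epochs" see the same chunks.
epoch_one = list(factory())
epoch_two = list(factory())

# A bare generator is exhausted after one pass and yields nothing the second time.
one_shot = factory()
first_pass = list(one_shot)
second_pass = list(one_shot)

# Source chunks of 4, 4 and 2 rows are regrouped into batches of 3.
batches = list(rebatch(factory(), batch_size=3))
```

Here `epoch_one` and `epoch_two` hold identical chunks, `second_pass` is empty, and `batches` covers every row of `data` in order even though the batch size does not divide the chunk size.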