Add pmd.CustomDist — dims-aware custom distribution for pymc.dims#8311
Add pmd.CustomDist — dims-aware custom distribution for pymc.dims#8311williambdean wants to merge 7 commits into
Conversation
These are orthogonal to having a dist argument. You can have dist with logp (or without, maybe it derives it). The only incompatible case is dist AND random, since they both represent the random path |
| return func | ||
|
|
||
|
|
||
| def _default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False): |
There was a problem hiding this comment.
I think this is just rv.zeros_like() ?
There was a problem hiding this comment.
no ones_like or zeros_like on xtensor FYI
import pytensor.xtensor as px
x = px.xtensor("x", dims=("covariate", ))
x.ones_like()---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[35], line 1
----> 1 px.xtensor("x", dims=("covariate", )).ones_like()
AttributeError: 'XTensorVariable' object has no attribute 'ones_like'
but can use px.zeros_like
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #8311 +/- ##
===========================================
- Coverage 91.72% 79.85% -11.88%
===========================================
Files 125 126 +1
Lines 20526 20716 +190
===========================================
- Hits 18828 16542 -2286
- Misses 1698 4174 +2476
🚀 New features to boost your workflow:
|
|
Good start, I think we should drop the random argument and a lot of complexity falls out of the way |
dae3580 to
68ae365
Compare
Supports both symbolic (dist=) and black-box (logp=) paths, enabling user-defined distributions with named dims. The symbolic path auto-derives logprob from inner XRV nodes; the black-box path creates a dynamic RandomVariable subclass and registers _logprob dispatches that reconstruct XTensorVariables for the value and dims-bearing params.
Covers both symbolic (dist=) and black-box (logp=/random=) paths: graph comparison against regular distributions, dim propagation, observed data, custom support points, and model variables as params.
…signature inference, fix compound dists - Replace compiled-function + graph-walking hybrid path with DimSymbolicRandomVariable(SymbolicRandomVariable) + OpFromGraph - Deduplicate _infer_dims_signature / _infer_final_signature - Add XElemwise support to expand_dist_dims for compound dists - Drop _forward_dim_lengths, enforce strict XTensorVariable output - Add tests: compound non-XRV output, hybrid support_point
15684f8 to
c570454
Compare
Adds
CustomDisttopymc.dims.distributions, a sibling topm.CustomDistthat operates onXTensorVariablewith named dims.Two construction paths:
Symbolic (
dist=kwarg): receives XTensorVariable params, returns an XTensorVariable RV (e.g., composingpmd.Normal.dist). Auto-derives logp from inner XRV nodes.Black-box (
logp=kwarg): dynamically creates aRandomVariablesubclass; dispatches_logprob,_logcdf,_support_point. Thevaluearrives asXTensorVariable; use.valuesforpt.*ops orptx.*for dim-aware ops.Key design points:
DimDistribution._as_xtensorpath aspmd.Normaletc. — identical behavior (scalars auto-convert, non-scalars require dims).logp,logcdf,support_point) captured in closures to avoid Python descriptor protocol issues.RandomVariablesubclass sets onlysignature(notndim_supp/ndims_params) to avoid deprecation warnings.