From 506e61b9995b799ec31c154d45f65e4e5ba40324 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 20 May 2026 07:42:32 -0400 Subject: [PATCH 1/7] Add dims-aware CustomDist for pymc.dims Supports both symbolic (dist=) and black-box (logp=) paths, enabling user-defined distributions with named dims. The symbolic path auto-derives logprob from inner XRV nodes; the black-box path creates a dynamic RandomVariable subclass and registers _logprob dispatches that reconstruct XTensorVariables for the value and dims-bearing params. --- pymc/dims/distributions/__init__.py | 1 + pymc/dims/distributions/custom.py | 397 ++++++++++++++++++++++++++++ 2 files changed, 398 insertions(+) create mode 100644 pymc/dims/distributions/custom.py diff --git a/pymc/dims/distributions/__init__.py b/pymc/dims/distributions/__init__.py index 6c49789089..fb87ee3a4c 100644 --- a/pymc/dims/distributions/__init__.py +++ b/pymc/dims/distributions/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from pymc.dims.distributions.censored import Censored +from pymc.dims.distributions.custom import CustomDist from pymc.dims.distributions.scalar import * from pymc.dims.distributions.vector import * diff --git a/pymc/dims/distributions/custom.py b/pymc/dims/distributions/custom.py new file mode 100644 index 0000000000..dffea79b9e --- /dev/null +++ b/pymc/dims/distributions/custom.py @@ -0,0 +1,397 @@ +# Copyright 2026 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import functools + +from collections.abc import Callable, Sequence + +import pytensor.tensor as pt + +from pytensor.tensor.random.op import RandomVariable +from pytensor.tensor.utils import safe_signature +from pytensor.xtensor.basic import xtensor_from_tensor +from pytensor.xtensor.random.variable import shared_rng as xtensor_shared_rng +from pytensor.xtensor.type import XTensorVariable + +from pymc.dims.distributions.core import DimDistribution, expand_dist_dims +from pymc.distributions.distribution import _call_rv_op, _support_point +from pymc.distributions.shape_utils import rv_size_is_none +from pymc.logprob.abstract import _logcdf, _logprob +from pymc.model.core import new_or_existing_block_model_access + + +class _DimCustomDistRV(RandomVariable): + """Minimal RandomVariable base for the black-box path. + + Only ``signature`` is set on dynamic subclasses to avoid + ``FutureWarning`` from ``ndim_supp``/``ndims_params`` class attributes. + """ + + name = "DimCustomDistRV" + _print_name = ("DimCustomDist", "\\operatorname{DimCustomDist}") + + @classmethod + def rng_fn(cls, rng, *args): + args = list(args) + size = args.pop(-1) + return cls._random_fn(*args, rng=rng, size=size) + + +def _default_not_implemented(rv_name, method_name): + msg = ( + f"Attempted to run {method_name} on the CustomDist '{rv_name}', " + f"but this method had not been provided when the distribution was " + f"constructed. 
Please re-build your model and provide a callable " + f"to '{rv_name}'s {method_name} keyword argument.\n" + ) + + def func(*args, **kwargs): + raise NotImplementedError(msg) + + return func + + +def _default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False): + if None not in rv.type.shape: + return pt.zeros(rv.type.shape) + elif rv.owner.op.ndim_supp == 0 and not rv_size_is_none(size): + return pt.zeros(size) + elif has_fallback: + return pt.zeros_like(rv) + else: + raise TypeError( + "Cannot safely infer the size of a multivariate random variable's " + "support_point. Please provide a support_point function when " + f"instantiating the {rv_name} random variable." + ) + + +class CustomDist(DimDistribution): + """Dims-aware CustomDist for pymc.dims. + + Supports the same ``dist=`` (symbolic) and ``logp=`` (black-box) paths as + ``pm.CustomDist``, but operates on ``XTensorVariable`` with named dims. + + Symbolic path (``dist`` function receives XTensorVariable params, + use ``ptx.*`` ops for transforms):: + + import pytensor.xtensor.math as ptxm + + + def logitnormal_dist(mu, sigma): + return ptxm.sigmoid(pmd.Normal.dist(mu=mu, sigma=sigma)) + + + pmd.CustomDist("x", mu, sigma, dist=logitnormal_dist, dims="city") + + Black-box path (``logp`` function receives the ``value`` as an + ``XTensorVariable`` with dims; params are plain tensors unless they + have non-empty dims. 
Use ``value.values`` to access the underlying + tensor for ``pt.*`` operations, or use ``ptx.*`` for dim-aware ops):: + + import pytensor.tensor as pt + + + def normal_logp(value, mu, sigma): + v = value.values # strip dims for tensor ops + return pt.sum( + pt.pow(v - mu, 2) / (2 * sigma**2) + pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) + ) + + + pmd.CustomDist("y", mu, sigma, logp=normal_logp, observed=y, dims="city") + + For dim-aware operations via ``pytensor.xtensor.math``:: + + import pytensor.xtensor.math as ptx + + + def tweedie_logp(value, mu, phi, p): + ll_core = (value * mu ** (1 - p) / (1 - p) - mu ** (2 - p) / (2 - p)) / phi + j = pt.arange(1, 30, dtype="float64") + alpha = (2 - p) / (p - 1) + log_Wj = j[:, None] * alpha * ptx.log(value)[None, :] - ... + log_a = ptx.logsumexp(log_Wj, dim="policy") - ptx.log(value) + ... + + + pmd.CustomDist("y", mu, phi, p, logp=tweedie_logp, observed=y, dims="policy") + + **Why ``value`` arrives as XTensor but params may not be:** the + dispatch in ``_logprob`` reconstructs ``value`` from a lowered tensor + (``MeasurableXTensorFromTensor``) so the original dims are available; + each param is independently wrapped in XTensor only if the user + provided dims for it. Params without dims (scalars, observed data) + stay as plain tensors to avoid shape mismatches — PyTensor's + ``RandomVariable`` stores broadcast scalars with shape ``(1,)``, but + ``xtensor_from_tensor(x, dims=())`` rejects that. Use + ``value.values`` to access the raw tensor for ``pt.*`` operations, or + ``ptx.*`` for dim-aware code on any XTensorVariable. 
+ """ + + @classmethod + def dist( + cls, + *dist_params, + dist: Callable | None = None, + random: Callable | None = None, + logp: Callable | None = None, + logcdf: Callable | None = None, + support_point: Callable | None = None, + ndim_supp: int | None = None, + ndims_params: Sequence[int] | None = None, + signature: str | None = None, + dtype: str = "floatX", + dim_lengths: dict | None = None, + core_dims: str | Sequence[str] | None = None, + **kwargs, + ): + kwargs.update( + dist=dist, + random=random, + logp=logp, + logcdf=logcdf, + support_point=support_point, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + signature=signature, + dtype=dtype, + ) + return super().dist( + list(dist_params), + dim_lengths=dim_lengths, + core_dims=core_dims, + **kwargs, + ) + + @classmethod + def _infer_output_dims(cls, params, extra_dims, core_dims): + """Infer output dims from params, extra_dims and core_dims.""" + param_dims = set() + for p in params: + try: + param_dims |= set(p.dims) + except AttributeError: + pass + if extra_dims: + batch_dims = tuple(d for d in extra_dims if d in param_dims) + tuple( + d for d in extra_dims if d not in param_dims + ) + else: + batch_dims = tuple(param_dims) + return batch_dims + (core_dims if core_dims else ()) + + @classmethod + def xrv_op( + cls, + *dist_params, + dist: Callable | None = None, + random: Callable | None = None, + logp: Callable | None = None, + logcdf: Callable | None = None, + support_point: Callable | None = None, + ndim_supp: int | None = None, + ndims_params: Sequence[int] | None = None, + signature: str | None = None, + dtype: str = "floatX", + class_name: str = "CustomDist", + core_dims: str | Sequence[str] | None = None, + extra_dims: dict[str, int] | None = None, + rng=None, + return_next_rng: bool = False, + **kwargs, + ): + if dist is not None: + return cls._symbolic_xrv_op( + list(dist_params), + dist=dist, + logp=logp, + logcdf=logcdf, + support_point=support_point, + class_name=class_name, + 
core_dims=core_dims, + extra_dims=extra_dims or {}, + rng=rng, + return_next_rng=return_next_rng, + ) + else: + return cls._blackbox_xrv_op( + list(dist_params), + logp=logp, + logcdf=logcdf, + support_point=support_point, + random=random, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + signature=signature, + dtype=dtype, + class_name=class_name, + core_dims=core_dims, + extra_dims=extra_dims or {}, + rng=rng, + return_next_rng=return_next_rng, + ) + + @classmethod + def _symbolic_xrv_op( + cls, + dist_params: list, + *, + dist: Callable, + logp: Callable | None, + logcdf: Callable | None, + support_point: Callable | None, + class_name: str, + core_dims: str | Sequence[str] | None, + extra_dims: dict[str, int], + rng, + return_next_rng: bool, + ): + xtensor_params = [cls._as_xtensor(p) for p in dist_params] + + with new_or_existing_block_model_access( + error_msg_on_access="Model variables cannot be created in the dist function. Use the `.dist` API" + ): + rv = dist(*xtensor_params) + + if isinstance(rv, XTensorVariable): + missing_extra_dims = {d: s for d, s in extra_dims.items() if d not in rv.dims} + if missing_extra_dims: + rv = expand_dist_dims(rv, missing_extra_dims) + else: + output_dims = cls._infer_output_dims(xtensor_params, extra_dims, core_dims) + rv = xtensor_from_tensor(rv, dims=output_dims) + + if return_next_rng: + return xtensor_shared_rng(seed=None), rv + return rv + + @classmethod + def _blackbox_xrv_op( + cls, + dist_params: list, + *, + logp: Callable | None, + logcdf: Callable | None, + support_point: Callable | None, + random: Callable | None, + ndim_supp: int | None, + ndims_params: Sequence[int] | None, + signature: str | None, + dtype: str, + class_name: str, + core_dims: str | Sequence[str] | None, + extra_dims: dict[str, int], + rng, + return_next_rng: bool, + ): + # Strip dims from XTensor params for the RandomVariable internals + # but store the original dims so _logprob can reconstruct XTensorVariables + tensor_params = [] + 
param_dims = [] + for p in dist_params: + try: + tensor_params.append(p.values) + param_dims.append(p.dims) + except AttributeError: + tensor_params.append(p) + param_dims.append(None) + + # Build signature if not provided + if signature is None: + if ndim_supp is None: + ndim_supp = 0 + if ndims_params is None: + ndims_params = [0] * len(tensor_params) + signature = safe_signature( + core_inputs_ndim=ndims_params, + core_outputs_ndim=[ndim_supp], + ) + + # Infer output dims for the XTensor wrapping + output_dims = cls._infer_output_dims(dist_params, extra_dims, core_dims) + + # Dynamically create a RandomVariable subclass with ONLY signature + # (no ndim_supp/ndims_params class attributes) to avoid deprecation warnings. + # Store dims info for _logprob/_logcdf/_support_point to reconstruct + # XTensorVariables from tensor params during logp computation. + # NOTE: user callables (logp, logcdf, support_point) are captured in + # closures below, NOT stored as class attributes, to avoid Python's + # descriptor protocol binding them to the op instance. 
+ rv_type = type( + class_name, + (_DimCustomDistRV,), + { + "signature": signature, + "dtype": dtype, + "_print_name": (class_name, f"\\operatorname{{{class_name}}}"), + "_random_fn": random, + "_param_dims": tuple(param_dims), + "_output_dims": output_dims, + }, + ) + + # Dispatch logprob — reconstruct XTensor value and params + _logp_fn = logp if logp is not None else _default_not_implemented(class_name, "logp") + + @_logprob.register(rv_type) + def _custom_dist_logp(op, values, rng, size, *dist_params, **kwargs): + value_xt = xtensor_from_tensor(values[0], dims=op._output_dims) + xtensor_params = [ + xtensor_from_tensor(p, dims=dims) if dims else p + for p, dims in zip(dist_params, op._param_dims) + ] + result = _logp_fn(value_xt, *xtensor_params) + return result.values if isinstance(result, XTensorVariable) else result + + # Dispatch logcdf (only when user provided it) + if logcdf is not None: + + @_logcdf.register(rv_type) + def _custom_dist_logcdf(op, value, rng, size, *dist_params, **kwargs): + value_xt = xtensor_from_tensor(value, dims=op._output_dims) + xtensor_params = [ + xtensor_from_tensor(p, dims=dims) if dims else p + for p, dims in zip(dist_params, op._param_dims) + ] + result = logcdf(value_xt, *xtensor_params) + return result.values if isinstance(result, XTensorVariable) else result + + # Dispatch support_point + _support_point_fn = ( + support_point + if support_point is not None + else functools.partial( + _default_support_point, + rv_name=class_name, + has_fallback=random is not None, + ) + ) + + @_support_point.register(rv_type) + def _custom_dist_support_point(op, rv, rng, size, *dist_params): + return _support_point_fn(rv, size, *dist_params) + + # Convert extra_dims to size for RandomVariable + size = tuple(extra_dims.values()) if extra_dims else None + + # Create the RV — _call_rv_op handles rng default + return_next_rng + rv_op = rv_type() + _, tensor_rv = _call_rv_op(rv_op, *tensor_params, size=size, rng=rng) + + # Wrap as 
XTensorVariable with inferred dims + rv = xtensor_from_tensor(tensor_rv, dims=output_dims) + + if return_next_rng: + return rng, rv + return rv From 3d816e8003320d52e638869d1bc8f375a138b8a7 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 20 May 2026 07:42:48 -0400 Subject: [PATCH 2/7] Add tests for pmd.CustomDist Covers both symbolic (dist=) and black-box (logp=/random=) paths: graph comparison against regular distributions, dim propagation, observed data, custom support points, and model variables as params. --- tests/dims/distributions/test_custom.py | 221 ++++++++++++++++++++++++ 1 file changed, 221 insertions(+) create mode 100644 tests/dims/distributions/test_custom.py diff --git a/tests/dims/distributions/test_custom.py b/tests/dims/distributions/test_custom.py new file mode 100644 index 0000000000..6a0928da4c --- /dev/null +++ b/tests/dims/distributions/test_custom.py @@ -0,0 +1,221 @@ +# Copyright 2026 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import numpy as np +import pytensor.tensor as pt +import pytest + +from pytensor.xtensor import as_xtensor + +import pymc.distributions as regular_distributions + +from pymc.dims import CustomDist, Normal +from pymc.model.core import Model +from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph + +pytestmark = pytest.mark.filterwarnings( + "error", + r"ignore:^Numba will use object mode to run.*perform method\.:UserWarning", +) + + +class TestCustomDistSymbolic: + """Tests for the symbolic (dist=) path of pmd.CustomDist.""" + + def test_basic(self): + """Symbolic path: dist function wrapping Normal.dist, compared against regular Normal.""" + + def normal_dist(mu, sigma): + return Normal.dist(mu, sigma) + + coords = {"city": range(5)} + with Model(coords=coords) as model: + CustomDist("x", 0, 1, dist=normal_dist, dims="city") + + with Model(coords=coords) as reference_model: + regular_distributions.Normal("x", 0, 1, dims="city") + + assert_equivalent_random_graph(model, reference_model) + assert_equivalent_logp_graph(model, reference_model) + + def test_param_dims_propagate(self): + """Params with dims propagate to the output.""" + + def normal_dist(mu, sigma): + return Normal.dist(mu, sigma) + + coords = {"city": range(5)} + mu = as_xtensor(np.array([0, 1, 2, 3, 4]), dims=("city",)) + sigma = as_xtensor(np.array([1, 2, 3, 4, 5]), dims=("city",)) + + with Model(coords=coords) as model: + x = CustomDist("x", mu, sigma, dist=normal_dist) + + assert set(x.dims) == {"city"} + assert x.type.shape == (5,) + + +class TestCustomDistBlackbox: + """Tests for the black-box (logp=/random=) path of pmd.CustomDist.""" + + def test_logp_basic(self): + """Black-box path with logp function and dims on output.""" + + def normal_logp(value, mu, sigma): + v = value.values + return pt.sum( + -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) + ) + + coords = {"city": range(5)} + rng = np.random.default_rng(42) + observed 
= as_xtensor(rng.normal(0, 1, size=5), dims=("city",)) + + with Model(coords=coords) as model: + CustomDist( + "x", + 0, + 1, + logp=normal_logp, + observed=observed, + dims="city", + ) + + # Test that logp evaluates without error and returns finite values + ip = model.initial_point() + logp_val = model.compile_logp()(ip) + assert np.isfinite(logp_val) + + def test_random_logp(self): + """Black-box path with both random and logp.""" + + def normal_logp(value, mu, sigma): + v = value.values + return pt.sum( + -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) + ) + + def normal_random(mu, sigma, rng=None, size=None): + return rng.normal(loc=mu, scale=sigma, size=size) + + coords = {"city": range(5)} + with Model(coords=coords) as model: + CustomDist( + "x", + 0, + 1, + logp=normal_logp, + random=normal_random, + dims="city", + ) + + # Verify shape via draw + from pymc import draw as pm_draw + + draws = pm_draw(model["x"], draws=3) + assert draws.shape == (3, 5) + + # Verify logp + ip = model.initial_point() + logp_val = model.compile_logp()(ip) + assert np.isfinite(logp_val) + + def test_logcdf(self): + """Black-box path with logcdf function.""" + + def normal_logp(value, mu, sigma): + v = value.values + return pt.sum( + -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) + ) + + def normal_logcdf(value, mu, sigma): + v = value.values + return pt.sum( + pt.log(pt.erf((v - mu) / (sigma * pt.sqrt(2.0))) + 1.0) - pt.log(pt.constant(2.0)) + ) + + coords = {"city": range(5)} + rng = np.random.default_rng(42) + observed = as_xtensor(rng.normal(0, 1, size=5), dims=("city",)) + + with Model(coords=coords) as model: + CustomDist( + "x", + 0, + 1, + logp=normal_logp, + logcdf=normal_logcdf, + observed=observed, + dims="city", + ) + + ip = model.initial_point() + logp_val = model.compile_logp()(ip) + assert np.isfinite(logp_val) + + def test_mu_as_model_var(self): + """Black-box path with mu as a model variable (no 
dims on mu).""" + + def normal_logp(value, mu, sigma): + v = value.values + return pt.sum( + -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) + ) + + coords = {"city": range(5)} + rng = np.random.default_rng(42) + observed = as_xtensor(rng.normal(0, 1, size=5), dims=("city",)) + + with Model(coords=coords) as model: + mu = Normal("mu", 0, 1) + CustomDist( + "x", + mu, + 1, + logp=normal_logp, + observed=observed, + dims="city", + ) + + ip = model.initial_point() + logp_val = model.compile_logp()(ip) + assert np.isfinite(logp_val) + + def test_support_point(self): + """Black-box path with custom support_point.""" + + def normal_logp(value, mu, sigma): + v = value.values + return pt.sum( + -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) + ) + + def custom_support_point(rv, size, mu, sigma): + return pt.full_like(rv, mu) + + coords = {"city": range(5)} + with Model(coords=coords) as model: + CustomDist( + "x", + 0, + 1, + logp=normal_logp, + support_point=custom_support_point, + dims="city", + ) + + from pymc.distributions.distribution import support_point + + sp = support_point(model["x"]) + np.testing.assert_allclose(sp.eval(), np.zeros(5)) From f850ccd1ed9ae598730edc1f22bd1d58e2202f9b Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 20 May 2026 07:54:53 -0400 Subject: [PATCH 3/7] don't wrap np.pi --- pymc/dims/distributions/custom.py | 4 +--- tests/dims/distributions/test_custom.py | 20 +++++--------------- 2 files changed, 6 insertions(+), 18 deletions(-) diff --git a/pymc/dims/distributions/custom.py b/pymc/dims/distributions/custom.py index dffea79b9e..e3dff11435 100644 --- a/pymc/dims/distributions/custom.py +++ b/pymc/dims/distributions/custom.py @@ -104,9 +104,7 @@ def logitnormal_dist(mu, sigma): def normal_logp(value, mu, sigma): v = value.values # strip dims for tensor ops - return pt.sum( - pt.pow(v - mu, 2) / (2 * sigma**2) + pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) - ) + return
pt.sum(pt.pow(v - mu, 2) / (2 * sigma**2) + pt.log(sigma * pt.sqrt(2 * np.pi))) pmd.CustomDist("y", mu, sigma, logp=normal_logp, observed=y, dims="city") diff --git a/tests/dims/distributions/test_custom.py b/tests/dims/distributions/test_custom.py index 6a0928da4c..d0f21b2d10 100644 --- a/tests/dims/distributions/test_custom.py +++ b/tests/dims/distributions/test_custom.py @@ -73,9 +73,7 @@ def test_logp_basic(self): def normal_logp(value, mu, sigma): v = value.values - return pt.sum( - -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) - ) + return pt.sum(-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi))) coords = {"city": range(5)} rng = np.random.default_rng(42) @@ -101,9 +99,7 @@ def test_random_logp(self): def normal_logp(value, mu, sigma): v = value.values - return pt.sum( - -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) - ) + return pt.sum(-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi))) def normal_random(mu, sigma, rng=None, size=None): return rng.normal(loc=mu, scale=sigma, size=size) @@ -135,9 +131,7 @@ def test_logcdf(self): def normal_logp(value, mu, sigma): v = value.values - return pt.sum( - -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) - ) + return pt.sum(-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi))) def normal_logcdf(value, mu, sigma): v = value.values @@ -169,9 +163,7 @@ def test_mu_as_model_var(self): def normal_logp(value, mu, sigma): v = value.values - return pt.sum( - -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * pt.constant(np.pi))) - ) + return pt.sum(-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi))) coords = {"city": range(5)} rng = np.random.default_rng(42) @@ -197,9 +189,7 @@ def test_support_point(self): def normal_logp(value, mu, sigma): v = value.values - return pt.sum( - -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * 
pt.constant(np.pi))) - ) + return pt.sum(-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi))) def custom_support_point(rv, size, mu, sigma): return pt.full_like(rv, mu) From a6a0fc8d46e5e3361a025fc332febbc0095c4065 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 20 May 2026 08:22:15 -0400 Subject: [PATCH 4/7] Add test_custom.py to CI test matrix --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1d9278dd34..6ada4ec369 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -142,6 +142,7 @@ jobs: - | tests/dims/distributions/test_core.py tests/dims/distributions/test_censored.py + tests/dims/distributions/test_custom.py tests/dims/distributions/test_scalar.py tests/dims/distributions/test_vector.py tests/dims/test_model.py From 8cf047068ef19148001b9f2f1e28d19a84a552bd Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 20 May 2026 09:25:37 -0400 Subject: [PATCH 5/7] feat: add dist+logp hybrid path to pmd.CustomDist, remove random arg --- pymc/dims/distributions/core.py | 3 + pymc/dims/distributions/custom.py | 252 +++++++++++++++++------- tests/dims/distributions/test_custom.py | 128 +++++++++++- 3 files changed, 299 insertions(+), 84 deletions(-) diff --git a/pymc/dims/distributions/core.py b/pymc/dims/distributions/core.py index 42a642b826..fa614dbbbf 100644 --- a/pymc/dims/distributions/core.py +++ b/pymc/dims/distributions/core.py @@ -191,6 +191,7 @@ class DimDistribution: xrv_op: Callable default_transform: DimTransform | None = None + _forward_dim_lengths: bool = False @staticmethod def _as_xtensor(x): @@ -325,6 +326,8 @@ def dist( } if kwargs.get("rng") is None: kwargs["rng"] = pt.random.shared_rng(seed=None) + if cls._forward_dim_lengths and dim_lengths is not None: + kwargs["dim_lengths"] = dim_lengths _, rv = cls.xrv_op( *dist_params, extra_dims=extra_dims, diff --git a/pymc/dims/distributions/custom.py 
b/pymc/dims/distributions/custom.py index e3dff11435..6700b6f0be 100644 --- a/pymc/dims/distributions/custom.py +++ b/pymc/dims/distributions/custom.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import functools - from collections.abc import Callable, Sequence +import numpy as np +import pytensor import pytensor.tensor as pt from pytensor.tensor.random.op import RandomVariable @@ -25,7 +25,6 @@ from pymc.dims.distributions.core import DimDistribution, expand_dist_dims from pymc.distributions.distribution import _call_rv_op, _support_point -from pymc.distributions.shape_utils import rv_size_is_none from pymc.logprob.abstract import _logcdf, _logprob from pymc.model.core import new_or_existing_block_model_access @@ -61,29 +60,34 @@ def func(*args, **kwargs): return func -def _default_support_point(rv, size, *rv_inputs, rv_name=None, has_fallback=False): - if None not in rv.type.shape: - return pt.zeros(rv.type.shape) - elif rv.owner.op.ndim_supp == 0 and not rv_size_is_none(size): - return pt.zeros(size) - elif has_fallback: - return pt.zeros_like(rv) - else: - raise TypeError( - "Cannot safely infer the size of a multivariate random variable's " - "support_point. Please provide a support_point function when " - f"instantiating the {rv_name} random variable." - ) +def _default_support_point(rv, size=None, *rv_inputs): + return pt.zeros_like(rv) + + +def _prep_logp_params(dist_params, param_dims, size): + """Prepare params for logp dispatch. + + Params with non-empty dims are wrapped as XTensorVariable for dim-aware ops. + Params without dims (None) or with empty dims (scalar) stay as plain tensors. + """ + del size # Unused; params may have extra batch dims from explicit_expand_dims, + # but broadcasting handles that at the tensor level. 
+ result = [] + for p, dims in zip(dist_params, param_dims): + if dims: + result.append(xtensor_from_tensor(p, dims=dims)) + else: + result.append(p) + return result class CustomDist(DimDistribution): """Dims-aware CustomDist for pymc.dims. - Supports the same ``dist=`` (symbolic) and ``logp=`` (black-box) paths as - ``pm.CustomDist``, but operates on ``XTensorVariable`` with named dims. + Provides ``dist`` (symbolic) and/or ``logp`` construction paths, + operating on ``XTensorVariable`` with named dims. - Symbolic path (``dist`` function receives XTensorVariable params, - use ``ptx.*`` ops for transforms):: + Symbolic path (``dist`` function receives XTensorVariable params):: import pytensor.xtensor.math as ptxm @@ -94,55 +98,49 @@ def logitnormal_dist(mu, sigma): pmd.CustomDist("x", mu, sigma, dist=logitnormal_dist, dims="city") - Black-box path (``logp`` function receives the ``value`` as an - ``XTensorVariable`` with dims; params are plain tensors unless they - have non-empty dims. Use ``value.values`` to access the underlying - tensor for ``pt.*`` operations, or use ``ptx.*`` for dim-aware ops):: + When ``dist`` is provided without ``logp``, PyMC auto-derives the logp + from the inner graph. When ``logp`` is also given, it overrides the + auto-derived logp while ``dist`` still drives the random path. 
- import pytensor.tensor as pt + Logp path (``logp`` function receives the ``value`` and all params as + ``XTensorVariable`` — use ``ptx.*`` for dim-aware operations):: + + import pytensor.xtensor.math as ptx def normal_logp(value, mu, sigma): - v = value.values # strip dims for tensor ops - return pt.sum(pt.pow(v - mu, 2) / (2 * sigma**2) + pt.log(sigma * pt.sqrt(2 * np.pi))) + return ptx.sum( + -0.5 * ((value - mu) / sigma) ** 2 - ptx.log(sigma * ptx.sqrt(2 * np.pi)) + ) pmd.CustomDist("y", mu, sigma, logp=normal_logp, observed=y, dims="city") - For dim-aware operations via ``pytensor.xtensor.math``:: + When ``logp`` is provided without ``dist``, prior/posterior predictive + sampling is not available — only MCMC (logp evaluation). For both, + provide ``dist`` + ``logp`` together. - import pytensor.xtensor.math as ptx + For tensor-level operations use ``value.values`` to access the + underlying tensor:: + import pytensor.tensor as pt - def tweedie_logp(value, mu, phi, p): - ll_core = (value * mu ** (1 - p) / (1 - p) - mu ** (2 - p) / (2 - p)) / phi - j = pt.arange(1, 30, dtype="float64") - alpha = (2 - p) / (p - 1) - log_Wj = j[:, None] * alpha * ptx.log(value)[None, :] - ... - log_a = ptx.logsumexp(log_Wj, dim="policy") - ptx.log(value) - ... + def normal_logp(value, mu, sigma): + v = value.values + return pt.sum(-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi))) - pmd.CustomDist("y", mu, phi, p, logp=tweedie_logp, observed=y, dims="policy") - **Why ``value`` arrives as XTensor but params may not be:** the - dispatch in ``_logprob`` reconstructs ``value`` from a lowered tensor - (``MeasurableXTensorFromTensor``) so the original dims are available; - each param is independently wrapped in XTensor only if the user - provided dims for it. 
Params without dims (scalars, observed data) - stay as plain tensors to avoid shape mismatches — PyTensor's - ``RandomVariable`` stores broadcast scalars with shape ``(1,)``, but - ``xtensor_from_tensor(x, dims=())`` rejects that. Use - ``value.values`` to access the raw tensor for ``pt.*`` operations, or - ``ptx.*`` for dim-aware code on any XTensorVariable. + pmd.CustomDist("y", mu, sigma, logp=normal_logp, observed=y, dims="city") """ + _forward_dim_lengths = True + @classmethod def dist( cls, *dist_params, dist: Callable | None = None, - random: Callable | None = None, logp: Callable | None = None, logcdf: Callable | None = None, support_point: Callable | None = None, @@ -156,7 +154,6 @@ def dist( ): kwargs.update( dist=dist, - random=random, logp=logp, logcdf=logcdf, support_point=support_point, @@ -173,7 +170,7 @@ def dist( ) @classmethod - def _infer_output_dims(cls, params, extra_dims, core_dims): + def _infer_output_dims(cls, params, extra_dims, core_dims, dim_lengths=None): """Infer output dims from params, extra_dims and core_dims.""" param_dims = set() for p in params: @@ -181,7 +178,9 @@ def _infer_output_dims(cls, params, extra_dims, core_dims): param_dims |= set(p.dims) except AttributeError: pass - if extra_dims: + if dim_lengths: + batch_dims = tuple(dim_lengths.keys()) + elif extra_dims: batch_dims = tuple(d for d in extra_dims if d in param_dims) + tuple( d for d in extra_dims if d not in param_dims ) @@ -194,7 +193,6 @@ def xrv_op( cls, *dist_params, dist: Callable | None = None, - random: Callable | None = None, logp: Callable | None = None, logcdf: Callable | None = None, support_point: Callable | None = None, @@ -209,6 +207,7 @@ def xrv_op( return_next_rng: bool = False, **kwargs, ): + dim_lengths = kwargs.pop("dim_lengths", None) if dist is not None: return cls._symbolic_xrv_op( list(dist_params), @@ -219,6 +218,7 @@ def xrv_op( class_name=class_name, core_dims=core_dims, extra_dims=extra_dims or {}, + dim_lengths=dim_lengths, rng=rng, 
return_next_rng=return_next_rng, ) @@ -228,7 +228,6 @@ def xrv_op( logp=logp, logcdf=logcdf, support_point=support_point, - random=random, ndim_supp=ndim_supp, ndims_params=ndims_params, signature=signature, @@ -236,6 +235,7 @@ def xrv_op( class_name=class_name, core_dims=core_dims, extra_dims=extra_dims or {}, + dim_lengths=dim_lengths, rng=rng, return_next_rng=return_next_rng, ) @@ -252,6 +252,7 @@ def _symbolic_xrv_op( class_name: str, core_dims: str | Sequence[str] | None, extra_dims: dict[str, int], + dim_lengths: dict | None, rng, return_next_rng: bool, ): @@ -267,11 +268,126 @@ def _symbolic_xrv_op( if missing_extra_dims: rv = expand_dist_dims(rv, missing_extra_dims) else: - output_dims = cls._infer_output_dims(xtensor_params, extra_dims, core_dims) + output_dims = cls._infer_output_dims(xtensor_params, extra_dims, core_dims, dim_lengths) rv = xtensor_from_tensor(rv, dims=output_dims) + # If no user-provided functions to override, return the symbolic RV as-is + if logp is None and logcdf is None and support_point is None: + if return_next_rng: + return xtensor_shared_rng(seed=None), rv + return rv + + # Hybrid: use dist for sampling but user functions for logp/logcdf/support_point + # Walk the owner chain to find the underlying RandomVariable op for its rng_fn + tensor_rv = rv.values if isinstance(rv, XTensorVariable) else rv + current = tensor_rv + while current.owner is not None: + op = current.owner.op + if hasattr(op, "rng_fn"): + orig_op = op + break + if hasattr(op, "core_op") and hasattr(op.core_op, "rng_fn"): + orig_op = op.core_op + break + current = current.owner.inputs[0] + else: + raise ValueError( + "Could not find a RandomVariable op in the dist graph. " + "The dist function must return a distribution with a sampler." + ) + + # Compile the full dist output graph for sampling. + # This is essential for compound processes (Poisson→Gamma→...) 
— the + # graph contains multiple RandomVariables whose chained evaluation + # produces correct draws from the compound, bypassing a single op's rng_fn. + from pytensor.graph.basic import Constant + from pytensor.graph.traversal import graph_inputs + + # Only pass variables that are actual graph inputs (skip constants/unused params) + graph_deps = set(graph_inputs([rv.values])) + _input_indices = [ + i + for i, p in enumerate(xtensor_params) + if p in graph_deps and not isinstance(p, Constant) + ] + sample_inputs = [xtensor_params[i] for i in _input_indices] + _sample_fn = pytensor.function( + inputs=sample_inputs, + outputs=rv.values, + ) + + def random_fn(*args, rng=None, size=None): + fn_args = [args[i] for i in _input_indices] + result = _sample_fn(*fn_args) + if size is not None and result.shape != tuple(size): + result = np.broadcast_to(result, tuple(size)).copy() + return result + + ndim_supp = getattr(orig_op, "ndim_supp", 0) + ndims_params = [0] * len(xtensor_params) + hybrid_signature = safe_signature( + core_inputs_ndim=ndims_params, + core_outputs_ndim=[ndim_supp], + ) + + output_dims = rv.type.dims if isinstance(rv, XTensorVariable) else () + param_dims = [] + for p in xtensor_params: + try: + param_dims.append(p.dims) + except AttributeError: + param_dims.append(None) + + rv_type = type( + class_name, + (_DimCustomDistRV,), + { + "signature": hybrid_signature, + "dtype": str(tensor_rv.dtype), + "_print_name": (class_name, f"\\operatorname{{{class_name}}}"), + "_random_fn": random_fn, + "_param_dims": tuple(param_dims), + "_output_dims": output_dims, + }, + ) + + if logp is not None: + + @_logprob.register(rv_type) + def _custom_dist_logp(op, values, rng, size, *dist_params, **kwargs): + value_xt = xtensor_from_tensor(values[0], dims=op._output_dims) + xtensor_params = _prep_logp_params(dist_params, op._param_dims, size) + result = logp(value_xt, *xtensor_params) + return result.values if isinstance(result, XTensorVariable) else result + + if logcdf is 
not None: + + @_logcdf.register(rv_type) + def _custom_dist_logcdf(op, value, rng, size, *dist_params, **kwargs): + value_xt = xtensor_from_tensor(value, dims=op._output_dims) + xtensor_params = _prep_logp_params(dist_params, op._param_dims, size) + result = logcdf(value_xt, *xtensor_params) + return result.values if isinstance(result, XTensorVariable) else result + + _support_point_fn = support_point if support_point is not None else _default_support_point + + @_support_point.register(rv_type) + def _custom_dist_support_point(op, rv, rng, size, *dist_params): + return _support_point_fn(rv, size, *dist_params) + + size = None + if output_dims and dim_lengths: + size = tuple(dim_lengths[d] for d in output_dims if d in dim_lengths) + if not size: + size = tuple(extra_dims.values()) if extra_dims else None + rv_op = rv_type() + _, new_tensor_rv = _call_rv_op( + rv_op, *[p.values for p in xtensor_params], size=size, rng=rng + ) + rv = xtensor_from_tensor(new_tensor_rv, dims=output_dims) + if return_next_rng: - return xtensor_shared_rng(seed=None), rv + return rng, rv return rv @classmethod @@ -282,7 +398,6 @@ def _blackbox_xrv_op( logp: Callable | None, logcdf: Callable | None, support_point: Callable | None, - random: Callable | None, ndim_supp: int | None, ndims_params: Sequence[int] | None, signature: str | None, @@ -290,6 +405,7 @@ def _blackbox_xrv_op( class_name: str, core_dims: str | Sequence[str] | None, extra_dims: dict[str, int], + dim_lengths: dict | None, rng, return_next_rng: bool, ): @@ -317,7 +433,7 @@ def _blackbox_xrv_op( ) # Infer output dims for the XTensor wrapping - output_dims = cls._infer_output_dims(dist_params, extra_dims, core_dims) + output_dims = cls._infer_output_dims(dist_params, extra_dims, core_dims, dim_lengths) # Dynamically create a RandomVariable subclass with ONLY signature # (no ndim_supp/ndims_params class attributes) to avoid deprecation warnings. 
@@ -333,7 +449,7 @@ def _blackbox_xrv_op( "signature": signature, "dtype": dtype, "_print_name": (class_name, f"\\operatorname{{{class_name}}}"), - "_random_fn": random, + "_random_fn": _default_not_implemented(class_name, "random"), "_param_dims": tuple(param_dims), "_output_dims": output_dims, }, @@ -345,10 +461,7 @@ def _blackbox_xrv_op( @_logprob.register(rv_type) def _custom_dist_logp(op, values, rng, size, *dist_params, **kwargs): value_xt = xtensor_from_tensor(values[0], dims=op._output_dims) - xtensor_params = [ - xtensor_from_tensor(p, dims=dims) if dims else p - for p, dims in zip(dist_params, op._param_dims) - ] + xtensor_params = _prep_logp_params(dist_params, op._param_dims, size) result = _logp_fn(value_xt, *xtensor_params) return result.values if isinstance(result, XTensorVariable) else result @@ -358,23 +471,12 @@ def _custom_dist_logp(op, values, rng, size, *dist_params, **kwargs): @_logcdf.register(rv_type) def _custom_dist_logcdf(op, value, rng, size, *dist_params, **kwargs): value_xt = xtensor_from_tensor(value, dims=op._output_dims) - xtensor_params = [ - xtensor_from_tensor(p, dims=dims) if dims else p - for p, dims in zip(dist_params, op._param_dims) - ] + xtensor_params = _prep_logp_params(dist_params, op._param_dims, size) result = logcdf(value_xt, *xtensor_params) return result.values if isinstance(result, XTensorVariable) else result # Dispatch support_point - _support_point_fn = ( - support_point - if support_point is not None - else functools.partial( - _default_support_point, - rv_name=class_name, - has_fallback=random is not None, - ) - ) + _support_point_fn = support_point if support_point is not None else _default_support_point @_support_point.register(rv_type) def _custom_dist_support_point(op, rv, rng, size, *dist_params): diff --git a/tests/dims/distributions/test_custom.py b/tests/dims/distributions/test_custom.py index d0f21b2d10..8a75efff7f 100644 --- a/tests/dims/distributions/test_custom.py +++ 
b/tests/dims/distributions/test_custom.py @@ -19,7 +19,7 @@ import pymc.distributions as regular_distributions -from pymc.dims import CustomDist, Normal +from pymc.dims import CustomDist, Normal, Poisson from pymc.model.core import Model from tests.dims.utils import assert_equivalent_logp_graph, assert_equivalent_random_graph @@ -94,38 +94,148 @@ def normal_logp(value, mu, sigma): logp_val = model.compile_logp()(ip) assert np.isfinite(logp_val) - def test_random_logp(self): - """Black-box path with both random and logp.""" + def test_hybrid_dist_logp(self): + """Hybrid path: dist for sampling + logp override.""" + + def normal_dist(mu, sigma): + return Normal.dist(mu, sigma) def normal_logp(value, mu, sigma): v = value.values return pt.sum(-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi))) - def normal_random(mu, sigma, rng=None, size=None): - return rng.normal(loc=mu, scale=sigma, size=size) - coords = {"city": range(5)} with Model(coords=coords) as model: CustomDist( "x", 0, 1, + dist=normal_dist, logp=normal_logp, - random=normal_random, dims="city", ) - # Verify shape via draw + # Verify sampling works (via draw) + from pymc import draw as pm_draw + + draws = pm_draw(model["x"], draws=3) + assert draws.shape == (3, 5) + + # Verify logp evaluates + ip = model.initial_point() + logp_val = model.compile_logp()(ip) + assert np.isfinite(logp_val) + + def test_hybrid_derived_params(self): + """Hybrid path: dist uses more params than the internal RV expects. + + The dist function derives a single param from multiple inputs. + This was broken before because ``_symbolic_xrv_op`` passed all + ``xtensor_params`` to ``_call_rv_op``, whose signature used the + discovered op's ``ndims_params`` — causing a ``zip()`` mismatch. 
+ """ + + def poisson_dist(a, b, c): + lam = a + b + c + return Poisson.dist(mu=lam) + + def poisson_logp(value, a, b, c): + v = value.values + lam = a + b + c + return pt.sum(v * pt.log(lam) - lam - pt.gammaln(v + 1)) + + coords = {"city": range(5)} + with Model(coords=coords) as model: + CustomDist( + "x", + 0.5, + 0.3, + 0.2, + dist=poisson_dist, + logp=poisson_logp, + dims="city", + ) + + # pm.draw — evaluates the full dist graph (compound inference) from pymc import draw as pm_draw draws = pm_draw(model["x"], draws=3) assert draws.shape == (3, 5) + assert draws.dtype.kind == "i" + + # logp evaluates + ip = model.initial_point() + logp_val = model.compile_logp()(ip) + assert np.isfinite(logp_val) + + def test_hybrid_logp_override(self): + """Hybrid path: verify user logp overrides auto-derived logp.""" + + def normal_dist(mu, sigma): + return Normal.dist(mu, sigma) + + def scaled_logp(value, mu, sigma): + """Custom logp that multiplies normal logp by 2.""" + v = value.values + normal_logp = -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi)) + return 2.0 * pt.sum(normal_logp) + + coords = {"city": range(5)} + with Model(coords=coords) as model_hybrid: + CustomDist( + "x", + 0, + 1, + dist=normal_dist, + logp=scaled_logp, + dims="city", + ) + + # Auto-derived logp (no logp override) + with Model(coords=coords) as model_auto: + CustomDist( + "x", + 0, + 1, + dist=normal_dist, + dims="city", + ) + + ip = model_hybrid.initial_point() + hybrid_logp = model_hybrid.compile_logp()(ip) + auto_logp = model_auto.compile_logp()(ip) + # Hybrid logp should be 2x auto logp + np.testing.assert_allclose(hybrid_logp, 2.0 * auto_logp) - # Verify logp + def test_hybrid_basic_dims(self): + """Hybrid path with dims on params.""" + + def normal_dist(mu, sigma): + return Normal.dist(mu, sigma) + + def normal_logp(value, mu, sigma): + v = value.values + m = mu.values if hasattr(mu, "values") else mu + s = sigma.values if hasattr(sigma, "values") else sigma + return 
pt.sum(-0.5 * ((v - m) / s) ** 2 - pt.log(s * pt.sqrt(2 * np.pi))) + + coords = {"city": range(5)} + mu = as_xtensor(np.array([0.0, 0.5, 1.0, 1.5, 2.0]), dims=("city",)) + sigma = as_xtensor(np.array([1.0, 1.1, 1.2, 1.3, 1.4]), dims=("city",)) + + with Model(coords=coords) as model: + x = CustomDist("x", mu, sigma, dist=normal_dist, logp=normal_logp) + + assert set(x.dims) == {"city"} ip = model.initial_point() logp_val = model.compile_logp()(ip) assert np.isfinite(logp_val) + from pymc import draw as pm_draw + + draws = pm_draw(model["x"], draws=3) + assert draws.shape == (3, 5) + def test_logcdf(self): """Black-box path with logcdf function.""" From 79dcc0c87948bc1816d915f0994bb01913d39f7a Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 20 May 2026 11:45:35 -0400 Subject: [PATCH 6/7] docs: trim verbose test docstring --- tests/dims/distributions/test_custom.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/dims/distributions/test_custom.py b/tests/dims/distributions/test_custom.py index 8a75efff7f..6c3aefb29d 100644 --- a/tests/dims/distributions/test_custom.py +++ b/tests/dims/distributions/test_custom.py @@ -127,13 +127,7 @@ def normal_logp(value, mu, sigma): assert np.isfinite(logp_val) def test_hybrid_derived_params(self): - """Hybrid path: dist uses more params than the internal RV expects. - - The dist function derives a single param from multiple inputs. - This was broken before because ``_symbolic_xrv_op`` passed all - ``xtensor_params`` to ``_call_rv_op``, whose signature used the - discovered op's ``ndims_params`` — causing a ``zip()`` mismatch. 
- """ + """Hybrid path: dist derives params.""" def poisson_dist(a, b, c): lam = a + b + c From c57045495f8471fd2c458d8847fc27e363a88fc8 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Mon, 25 May 2026 15:51:19 -0400 Subject: [PATCH 7/7] refactor: use DimSymbolicRandomVariable for hybrid path, deduplicate signature inference, fix compound dists - Replace compiled-function + graph-walking hybrid path with DimSymbolicRandomVariable(SymbolicRandomVariable) + OpFromGraph - Deduplicate _infer_dims_signature / _infer_final_signature - Add XElemwise support to expand_dist_dims for compound dists - Drop _forward_dim_lengths, enforce strict XTensorVariable output - Add tests: compound non-XRV output, hybrid support_point --- pymc/dims/distributions/core.py | 13 +- pymc/dims/distributions/custom.py | 261 +++++++++++++----------- pymc/distributions/custom.py | 66 +++--- tests/dims/distributions/test_custom.py | 68 +++++- 4 files changed, 246 insertions(+), 162 deletions(-) diff --git a/pymc/dims/distributions/core.py b/pymc/dims/distributions/core.py index fa614dbbbf..76b8bce3b7 100644 --- a/pymc/dims/distributions/core.py +++ b/pymc/dims/distributions/core.py @@ -28,7 +28,7 @@ from pytensor.xtensor.basic import XTensorFromTensor, xtensor_from_tensor from pytensor.xtensor.shape import Transpose from pytensor.xtensor.type import XTensorVariable -from pytensor.xtensor.vectorization import XRV +from pytensor.xtensor.vectorization import XRV, XElemwise from pymc import SymbolicRandomVariable, modelcontext from pymc.dims.distributions.transforms import DimTransform, log_odds_transform, log_transform @@ -191,7 +191,6 @@ class DimDistribution: xrv_op: Callable default_transform: DimTransform | None = None - _forward_dim_lengths: bool = False @staticmethod def _as_xtensor(x): @@ -326,8 +325,6 @@ def dist( } if kwargs.get("rng") is None: kwargs["rng"] = pt.random.shared_rng(seed=None) - if cls._forward_dim_lengths and dim_lengths is not None: - kwargs["dim_lengths"] = dim_lengths _, rv 
= cls.xrv_op( *dist_params, extra_dims=extra_dims, @@ -377,6 +374,14 @@ def expand_dist_dims(dist: XTensorVariable, extra_dims: dict[str, Any]) -> XTens # We don't propagate the old RNG, because we don't want the new and old dists to be correlated new_rng = pt.random.shared_rng(seed=None) return new_dist_op(new_rng, *extra_dims.values(), *params_and_dim_lengths) + case XElemwise(): + expanded_inputs = [ + expand_dist_dims(inp, extra_dims=extra_dims) + if isinstance(inp, XTensorVariable) + else inp + for inp in dist.owner.inputs + ] + return dist.owner.op.make_node(*expanded_inputs).outputs[0] case Transpose(): return expand_dist_dims(dist.owner.inputs[0], extra_dims=extra_dims).transpose( ..., *dist.dims diff --git a/pymc/dims/distributions/custom.py b/pymc/dims/distributions/custom.py index 6700b6f0be..97cde88917 100644 --- a/pymc/dims/distributions/custom.py +++ b/pymc/dims/distributions/custom.py @@ -13,8 +13,6 @@ # limitations under the License. from collections.abc import Callable, Sequence -import numpy as np -import pytensor import pytensor.tensor as pt from pytensor.tensor.random.op import RandomVariable @@ -23,14 +21,29 @@ from pytensor.xtensor.random.variable import shared_rng as xtensor_shared_rng from pytensor.xtensor.type import XTensorVariable +from pymc import SymbolicRandomVariable from pymc.dims.distributions.core import DimDistribution, expand_dist_dims +from pymc.distributions.custom import _infer_final_signature from pymc.distributions.distribution import _call_rv_op, _support_point from pymc.logprob.abstract import _logcdf, _logprob from pymc.model.core import new_or_existing_block_model_access +from pymc.pytensorf import collect_default_updates + + +class DimSymbolicRandomVariable(SymbolicRandomVariable): + """XTensor-aware SymbolicRandomVariable for dims-supporting CustomDist. + + Stores output and param dims so that ``_logprob`` dispatch can + reconstruct ``XTensorVariable`` inputs from the tensor-level graph. 
+ """ + + default_output = 0 + _output_dims: tuple[str, ...] = () + _param_dims: tuple[tuple[str, ...] | None, ...] = () class _DimCustomDistRV(RandomVariable): - """Minimal RandomVariable base for the black-box path. + """Minimal RandomVariable base for the arbitrarily-defined path. Only ``signature`` is set on dynamic subclasses to avoid ``FutureWarning`` from ``ndim_supp``/``ndims_params`` class attributes. @@ -70,8 +83,7 @@ def _prep_logp_params(dist_params, param_dims, size): Params with non-empty dims are wrapped as XTensorVariable for dim-aware ops. Params without dims (None) or with empty dims (scalar) stay as plain tensors. """ - del size # Unused; params may have extra batch dims from explicit_expand_dims, - # but broadcasting handles that at the tensor level. + del size result = [] for p, dims in zip(dist_params, param_dims): if dims: @@ -96,14 +108,16 @@ def logitnormal_dist(mu, sigma): return ptxm.sigmoid(pmd.Normal.dist(mu=mu, sigma=sigma)) - pmd.CustomDist("x", mu, sigma, dist=logitnormal_dist, dims="city") + with pm.Model(coords={"city": range(5)}): + pmd.CustomDist("x", mu, sigma, dist=logitnormal_dist, dims="city") When ``dist`` is provided without ``logp``, PyMC auto-derives the logp - from the inner graph. When ``logp`` is also given, it overrides the + from the inner graph (via ``DimSymbolicRandomVariable`` with + ``inline_logprob=True``). When ``logp`` is also given, it overrides the auto-derived logp while ``dist`` still drives the random path. 
- Logp path (``logp`` function receives the ``value`` and all params as - ``XTensorVariable`` — use ``ptx.*`` for dim-aware operations):: + Arbitrarily-defined logp path (``logp`` function receives the ``value`` + and all params as ``XTensorVariable``):: import pytensor.xtensor.math as ptx @@ -114,28 +128,14 @@ def normal_logp(value, mu, sigma): ) - pmd.CustomDist("y", mu, sigma, logp=normal_logp, observed=y, dims="city") + with pm.Model(coords={"city": range(5)}): + pmd.CustomDist("y", mu, sigma, logp=normal_logp, dims="city") When ``logp`` is provided without ``dist``, prior/posterior predictive sampling is not available — only MCMC (logp evaluation). For both, provide ``dist`` + ``logp`` together. - - For tensor-level operations use ``value.values`` to access the - underlying tensor:: - - import pytensor.tensor as pt - - - def normal_logp(value, mu, sigma): - v = value.values - return pt.sum(-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi))) - - - pmd.CustomDist("y", mu, sigma, logp=normal_logp, observed=y, dims="city") """ - _forward_dim_lengths = True - @classmethod def dist( cls, @@ -178,9 +178,7 @@ def _infer_output_dims(cls, params, extra_dims, core_dims, dim_lengths=None): param_dims |= set(p.dims) except AttributeError: pass - if dim_lengths: - batch_dims = tuple(dim_lengths.keys()) - elif extra_dims: + if extra_dims: batch_dims = tuple(d for d in extra_dims if d in param_dims) + tuple( d for d in extra_dims if d not in param_dims ) @@ -207,7 +205,7 @@ def xrv_op( return_next_rng: bool = False, **kwargs, ): - dim_lengths = kwargs.pop("dim_lengths", None) + kwargs.pop("dim_lengths", None) if dist is not None: return cls._symbolic_xrv_op( list(dist_params), @@ -215,15 +213,17 @@ def xrv_op( logp=logp, logcdf=logcdf, support_point=support_point, + ndim_supp=ndim_supp, + ndims_params=ndims_params, + signature=signature, class_name=class_name, core_dims=core_dims, extra_dims=extra_dims or {}, - dim_lengths=dim_lengths, rng=rng, 
return_next_rng=return_next_rng, ) else: - return cls._blackbox_xrv_op( + return cls._arbitrary_xrv_op( list(dist_params), logp=logp, logcdf=logcdf, @@ -235,7 +235,6 @@ def xrv_op( class_name=class_name, core_dims=core_dims, extra_dims=extra_dims or {}, - dim_lengths=dim_lengths, rng=rng, return_next_rng=return_next_rng, ) @@ -249,10 +248,12 @@ def _symbolic_xrv_op( logp: Callable | None, logcdf: Callable | None, support_point: Callable | None, + ndim_supp: int | None, + ndims_params: Sequence[int] | None, + signature: str | None, class_name: str, core_dims: str | Sequence[str] | None, extra_dims: dict[str, int], - dim_lengths: dict | None, rng, return_next_rng: bool, ): @@ -263,135 +264,154 @@ def _symbolic_xrv_op( ): rv = dist(*xtensor_params) - if isinstance(rv, XTensorVariable): - missing_extra_dims = {d: s for d, s in extra_dims.items() if d not in rv.dims} - if missing_extra_dims: - rv = expand_dist_dims(rv, missing_extra_dims) - else: - output_dims = cls._infer_output_dims(xtensor_params, extra_dims, core_dims, dim_lengths) - rv = xtensor_from_tensor(rv, dims=output_dims) + if not isinstance(rv, XTensorVariable): + raise TypeError( + "The `dist` function must return an XTensorVariable. " + "Use `pmd.Normal.dist(...)` or `xtensor_from_tensor(rv, dims=...)` " + "to ensure dims are attached to the output." + ) - # If no user-provided functions to override, return the symbolic RV as-is + missing_extra_dims = {d: s for d, s in extra_dims.items() if d not in rv.dims} + if missing_extra_dims: + rv = expand_dist_dims(rv, missing_extra_dims) + + # If no user-provided functions to override, return the symbolic RV as-is. + # This avoids the DimSymbolicRandomVariable OpFromGraph wrapper for the + # common case where logp is auto-derived from the inner graph. 
if logp is None and logcdf is None and support_point is None: if return_next_rng: return xtensor_shared_rng(seed=None), rv return rv - # Hybrid: use dist for sampling but user functions for logp/logcdf/support_point - # Walk the owner chain to find the underlying RandomVariable op for its rng_fn - tensor_rv = rv.values if isinstance(rv, XTensorVariable) else rv - current = tensor_rv - while current.owner is not None: - op = current.owner.op - if hasattr(op, "rng_fn"): - orig_op = op - break - if hasattr(op, "core_op") and hasattr(op.core_op, "rng_fn"): - orig_op = op.core_op - break - current = current.owner.inputs[0] - else: - raise ValueError( - "Could not find a RandomVariable op in the dist graph. " - "The dist function must return a distribution with a sampler." - ) + output_dims = rv.type.dims + param_dims = tuple(p.dims for p in xtensor_params) - # Compile the full dist output graph for sampling. - # This is essential for compound processes (Poisson→Gamma→...) — the - # graph contains multiple RandomVariables whose chained evaluation - # produces correct draws from the compound, bypassing a single op's rng_fn. - from pytensor.graph.basic import Constant - from pytensor.graph.traversal import graph_inputs - - # Only pass variables that are actual graph inputs (skip constants/unused params) - graph_deps = set(graph_inputs([rv.values])) - _input_indices = [ - i - for i, p in enumerate(xtensor_params) - if p in graph_deps and not isinstance(p, Constant) + # Build dummy inner graph (same tensor types as actual params) + dummy_tensor_params = [p.values.type() for p in xtensor_params] + dummy_xtensor_params = [ + xtensor_from_tensor(p, dims=pd) for p, pd in zip(dummy_tensor_params, param_dims) ] - sample_inputs = [xtensor_params[i] for i in _input_indices] - _sample_fn = pytensor.function( - inputs=sample_inputs, - outputs=rv.values, - ) + with new_or_existing_block_model_access( + error_msg_on_access="Model variables cannot be created in the dist function. 
Use the `.dist` API" + ): + dummy_rv_xt = dist(*dummy_xtensor_params) - def random_fn(*args, rng=None, size=None): - fn_args = [args[i] for i in _input_indices] - result = _sample_fn(*fn_args) - if size is not None and result.shape != tuple(size): - result = np.broadcast_to(result, tuple(size)).copy() - return result + if not isinstance(dummy_rv_xt, XTensorVariable): + raise TypeError( + "The `dist` function must return an XTensorVariable. " + "Use `pmd.Normal.dist(...)` or `xtensor_from_tensor(rv, dims=...)` " + "to ensure dims are attached to the output." + ) + + if missing_extra_dims: + dummy_rv_xt = expand_dist_dims(dummy_rv_xt, missing_extra_dims) - ndim_supp = getattr(orig_op, "ndim_supp", 0) - ndims_params = [0] * len(xtensor_params) - hybrid_signature = safe_signature( + tensor_dummy_rv = dummy_rv_xt.values + + # Find RNG updates from inner graph + updates = collect_default_updates(inputs=dummy_tensor_params, outputs=(tensor_dummy_rv,)) + if updates: + rngs, rngs_updates = zip(*updates.items()) + else: + rngs, rngs_updates = (), () + + # Build extended signature + if ndims_params is None: + ndims_params = [0] * len(xtensor_params) + if ndim_supp is None: + ndim_supp = 0 + sig = safe_signature( core_inputs_ndim=ndims_params, core_outputs_ndim=[ndim_supp], ) + n_inputs = len(dummy_tensor_params) + len(rngs) + n_outputs = 1 + len(rngs_updates) + n_rngs = len(rngs) + extended_sig = _infer_final_signature(sig, n_inputs, n_outputs, n_rngs, add_size=False) - output_dims = rv.type.dims if isinstance(rv, XTensorVariable) else () - param_dims = [] - for p in xtensor_params: - try: - param_dims.append(p.dims) - except AttributeError: - param_dims.append(None) - + # Create DimSymbolicRandomVariable subclass rv_type = type( class_name, - (_DimCustomDistRV,), + (DimSymbolicRandomVariable,), { - "signature": hybrid_signature, - "dtype": str(tensor_rv.dtype), + "inline_logprob": False, "_print_name": (class_name, f"\\operatorname{{{class_name}}}"), - "_random_fn": 
random_fn, - "_param_dims": tuple(param_dims), "_output_dims": output_dims, + "_param_dims": param_dims, }, ) + n_params = len(param_dims) + if logp is not None: @_logprob.register(rv_type) - def _custom_dist_logp(op, values, rng, size, *dist_params, **kwargs): - value_xt = xtensor_from_tensor(values[0], dims=op._output_dims) - xtensor_params = _prep_logp_params(dist_params, op._param_dims, size) + def _custom_dist_logp(op, values, *inputs, **kwargs): + [value] = values + value_xt = xtensor_from_tensor(value, dims=op._output_dims) + xtensor_params = _prep_logp_params( + list(inputs[:n_params]), op._param_dims, size=None + ) result = logp(value_xt, *xtensor_params) return result.values if isinstance(result, XTensorVariable) else result if logcdf is not None: @_logcdf.register(rv_type) - def _custom_dist_logcdf(op, value, rng, size, *dist_params, **kwargs): + def _custom_dist_logcdf(op, value, *inputs, **kwargs): value_xt = xtensor_from_tensor(value, dims=op._output_dims) - xtensor_params = _prep_logp_params(dist_params, op._param_dims, size) + xtensor_params = _prep_logp_params( + list(inputs[:n_params]), op._param_dims, size=None + ) result = logcdf(value_xt, *xtensor_params) return result.values if isinstance(result, XTensorVariable) else result _support_point_fn = support_point if support_point is not None else _default_support_point @_support_point.register(rv_type) - def _custom_dist_support_point(op, rv, rng, size, *dist_params): - return _support_point_fn(rv, size, *dist_params) - - size = None - if output_dims and dim_lengths: - size = tuple(dim_lengths[d] for d in output_dims if d in dim_lengths) - if not size: - size = tuple(extra_dims.values()) if extra_dims else None - rv_op = rv_type() - _, new_tensor_rv = _call_rv_op( - rv_op, *[p.values for p in xtensor_params], size=size, rng=rng + def _custom_dist_support_point(op, rv, *inputs): + return _support_point_fn(rv, None, *inputs[:n_params]) + + # Build OpFromGraph + # strict=False because the inner graph 
may reference shared + # dim_lengths variables (extra_dims from XRV nodes). + ofg_inputs = [*dummy_tensor_params, *rngs] + ofg_outputs = [tensor_dummy_rv, *rngs_updates] + rv_op = rv_type( + inputs=ofg_inputs, + outputs=ofg_outputs, + extended_signature=extended_sig, + strict=False, ) - rv = xtensor_from_tensor(new_tensor_rv, dims=output_dims) + + # Call with concrete inputs + tensor_params = [p.values for p in xtensor_params] + if rng is not None: + if len(rngs) != 1: + raise ValueError( + f"CustomDist received an explicit rng but it requires {len(rngs)} rngs." + ) + actual_rngs = (rng,) + else: + actual_rngs = rngs + + result = rv_op(*tensor_params, *actual_rngs) + + # result is the sample (default_output=0), RNG update is first output + # of the Apply node. + # Wrap as XTensorVariable with output dims. + rv_out = xtensor_from_tensor(result, dims=output_dims) if return_next_rng: - return rng, rv - return rv + if actual_rngs: + next_rng = rv_out.owner.outputs[0] + else: + next_rng = xtensor_shared_rng(seed=None) + return next_rng, rv_out + return rv_out @classmethod - def _blackbox_xrv_op( + def _arbitrary_xrv_op( cls, dist_params: list, *, @@ -405,7 +425,6 @@ def _blackbox_xrv_op( class_name: str, core_dims: str | Sequence[str] | None, extra_dims: dict[str, int], - dim_lengths: dict | None, rng, return_next_rng: bool, ): @@ -433,7 +452,7 @@ def _blackbox_xrv_op( ) # Infer output dims for the XTensor wrapping - output_dims = cls._infer_output_dims(dist_params, extra_dims, core_dims, dim_lengths) + output_dims = cls._infer_output_dims(dist_params, extra_dims, core_dims) # Dynamically create a RandomVariable subclass with ONLY signature # (no ndim_supp/ndims_params class attributes) to avoid deprecation warnings. 
diff --git a/pymc/distributions/custom.py b/pymc/distributions/custom.py index d7b8e14a89..62b3228c2b 100644 --- a/pymc/distributions/custom.py +++ b/pymc/distributions/custom.py @@ -354,7 +354,7 @@ def change_custom_dist_size(op, rv, new_size, expand): inputs = [*dummy_params, *rngs] outputs = [dummy_rv, *rngs_updates] - extended_signature = cls._infer_final_signature( + extended_signature = _infer_final_signature( signature, n_inputs=len(inputs), n_outputs=len(outputs), n_rngs=len(rngs) ) rv_op = rv_type( @@ -372,36 +372,44 @@ def change_custom_dist_size(op, rv, new_size, expand): rngs = (rng,) return rv_op(size, *dist_params, *rngs) - @staticmethod - def _infer_final_signature(signature: str, n_inputs, n_outputs, n_rngs) -> str: - """Add size and updates to user provided gufunc signature if they are missing.""" - # Regex to split across outer commas - # Copied from https://stackoverflow.com/a/26634150 - outer_commas = re.compile(r",\s*(?![^()]*\))") - - input_sig, output_sig = signature.split("->") - # It's valid to have a signature without params inputs, as in a Flat RV - n_inputs_sig = len(outer_commas.split(input_sig)) if input_sig.strip() else 0 - n_outputs_sig = len(outer_commas.split(output_sig)) - - if n_inputs_sig == n_inputs and n_outputs_sig == n_outputs: - # User provided a signature with no missing parts - return signature - - size_sig = "[size]" - rngs_sig = ("[rng]",) * n_rngs - if n_inputs_sig == (n_inputs - n_rngs - 1): - # Assume size and rngs are missing - if input_sig.strip(): - input_sig = ",".join((size_sig, input_sig, *rngs_sig)) - else: - input_sig = ",".join((size_sig, *rngs_sig)) - if n_outputs_sig == (n_outputs - n_rngs): - # Assume updates are missing - output_sig = ",".join((output_sig, *rngs_sig)) - signature = "->".join((input_sig, output_sig)) + +def _infer_final_signature(signature: str, n_inputs, n_outputs, n_rngs, *, add_size=True) -> str: + """Add size and updates to user provided gufunc signature if they are missing. 
+ + Parameters + ---------- + add_size : bool + Whether to include ``[size]`` in the input signature. Set to ``False`` + when batch dimensions are baked into params (e.g. dims path). + """ + # Regex to split across outer commas + # Copied from https://stackoverflow.com/a/26634150 + outer_commas = re.compile(r",\s*(?![^()]*\))") + + input_sig, output_sig = signature.split("->") + # It's valid to have a signature without params inputs, as in a Flat RV + n_inputs_sig = len(outer_commas.split(input_sig)) if input_sig.strip() else 0 + n_outputs_sig = len(outer_commas.split(output_sig)) + + if n_inputs_sig == n_inputs and n_outputs_sig == n_outputs: + # User provided a signature with no missing parts return signature + rngs_sig = ("[rng]",) * n_rngs + n_extra_inputs = (1 if add_size else 0) + n_rngs + if n_inputs_sig == (n_inputs - n_extra_inputs): + parts = [] + if add_size: + parts.append("[size]") + if input_sig.strip(): + parts.append(input_sig) + parts.extend(rngs_sig) + input_sig = ",".join(parts) + if n_outputs_sig == (n_outputs - n_rngs): + output_sig = ",".join((output_sig, *rngs_sig)) + signature = "->".join((input_sig, output_sig)) + return signature + class SupportPointRewrite(GraphRewriter): def rewrite_support_point_scan_node(self, node): diff --git a/tests/dims/distributions/test_custom.py b/tests/dims/distributions/test_custom.py index 6c3aefb29d..5aeaa9850e 100644 --- a/tests/dims/distributions/test_custom.py +++ b/tests/dims/distributions/test_custom.py @@ -32,6 +32,26 @@ class TestCustomDistSymbolic: """Tests for the symbolic (dist=) path of pmd.CustomDist.""" + def test_compound_non_xrv_output(self): + """Compound dist with non-XRV output gets dims via expand_dist_dims.""" + + def logitnormal_dist(mu, sigma): + import pytensor.xtensor.math as ptxm + + return ptxm.sigmoid(Normal.dist(mu=mu, sigma=sigma)) + + coords = {"city": range(5)} + with Model(coords=coords) as model: + x = CustomDist("x", 0, 1, dist=logitnormal_dist, dims="city") + + assert 
set(x.dims) == {"city"} + + from pymc import draw as pm_draw + + draws = pm_draw(model["x"], draws=5) + assert draws.shape == (5, 5) + assert (draws > 0).all() and (draws < 1).all() + def test_basic(self): """Symbolic path: dist function wrapping Normal.dist, compared against regular Normal.""" @@ -65,11 +85,11 @@ def normal_dist(mu, sigma): assert x.type.shape == (5,) -class TestCustomDistBlackbox: - """Tests for the black-box (logp=/random=) path of pmd.CustomDist.""" +class TestCustomDistArbitrary: + """Tests for the arbitrarily-defined (logp=) path of pmd.CustomDist.""" def test_logp_basic(self): - """Black-box path with logp function and dims on output.""" + """Arbitrary path with logp function and dims on output.""" def normal_logp(value, mu, sigma): v = value.values @@ -171,8 +191,10 @@ def normal_dist(mu, sigma): def scaled_logp(value, mu, sigma): """Custom logp that multiplies normal logp by 2.""" v = value.values - normal_logp = -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi)) - return 2.0 * pt.sum(normal_logp) + normal_logp = pt.sum( + -0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi)) + ) + return 2.0 * normal_logp coords = {"city": range(5)} with Model(coords=coords) as model_hybrid: @@ -231,7 +253,7 @@ def normal_logp(value, mu, sigma): assert draws.shape == (3, 5) def test_logcdf(self): - """Black-box path with logcdf function.""" + """Arbitrary path with logcdf function.""" def normal_logp(value, mu, sigma): v = value.values @@ -263,7 +285,7 @@ def normal_logcdf(value, mu, sigma): assert np.isfinite(logp_val) def test_mu_as_model_var(self): - """Black-box path with mu as a model variable (no dims on mu).""" + """Arbitrary path with mu as a model variable (no dims on mu).""" def normal_logp(value, mu, sigma): v = value.values @@ -289,7 +311,36 @@ def normal_logp(value, mu, sigma): assert np.isfinite(logp_val) def test_support_point(self): - """Black-box path with custom support_point.""" + """Arbitrary path with 
custom support_point.""" + + def normal_logp(value, mu, sigma): + v = value.values + return pt.sum(-0.5 * ((v - mu) / sigma) ** 2 - pt.log(sigma * pt.sqrt(2 * np.pi))) + + def custom_support_point(rv, size, mu, sigma): + return pt.full_like(rv, mu) + + coords = {"city": range(5)} + with Model(coords=coords) as model: + CustomDist( + "x", + 0, + 1, + logp=normal_logp, + support_point=custom_support_point, + dims="city", + ) + + from pymc.distributions.distribution import support_point + + sp = support_point(model["x"]) + np.testing.assert_allclose(sp.eval(), np.zeros(5)) + + def test_hybrid_support_point(self): + """Hybrid path with custom support_point.""" + + def normal_dist(mu, sigma): + return Normal.dist(mu, sigma) def normal_logp(value, mu, sigma): v = value.values @@ -304,6 +355,7 @@ def custom_support_point(rv, size, mu, sigma): "x", 0, 1, + dist=normal_dist, logp=normal_logp, support_point=custom_support_point, dims="city",