From 286925fb83961c1cfee39e1804a045bd11180549 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 12 Jun 2026 16:42:27 +0200 Subject: [PATCH] Declare ndim_supp on value transforms --- pymc/distributions/transforms.py | 7 +++++++ pymc/logprob/transforms.py | 12 ++++++++++++ 2 files changed, 19 insertions(+) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 4561152d50..b7e9a6718c 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -57,6 +57,7 @@ def _default_transform(op: Op, rv: TensorVariable): class LogExpM1(Transform): name = "log_exp_m1" + ndim_supp = 0 def backward(self, value, *inputs): return pt.softplus(value) @@ -84,6 +85,7 @@ class Ordered(Transform): """ name = "ordered" + ndim_supp = 1 def __init__(self, positive=False, ascending=True): self.positive = positive @@ -124,6 +126,7 @@ class SumTo1(Transform): """ name = "sumto1" + ndim_supp = 1 def backward(self, value, *inputs): remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True) @@ -179,6 +182,7 @@ class CholeskyCorrTransform(Transform): """ name = "cholesky_corr" + ndim_supp = 1 def __init__(self, n, upper: bool = False): """ @@ -425,6 +429,7 @@ class CholeskyCovPacked(Transform): """Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the log scale.""" name = "cholesky-cov-packed" + ndim_supp = 1 def __init__(self, n): """Create a CholeskyCovPack object. @@ -494,6 +499,7 @@ class CholeskyCovTransform(Transform): """ name = "cholesky-cov" + ndim_supp = 1 def __init__(self, n): """Create a CholeskyCovTransform. @@ -649,6 +655,7 @@ class ZeroSumTransform(Transform): def __init__(self, zerosum_axes): self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes) + self.ndim_supp = len(self.zerosum_axes) @staticmethod def extend_axis(array, axis): diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 3b551f6660..19c1e5bb12 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -744,6 +744,7 @@ def calc_delta_x(value, prior_result): class LocTransform(Transform): name = "loc" + ndim_supp = 0 def __init__(self, transform_args_fn): self.transform_args_fn = transform_args_fn @@ -762,6 +763,7 @@ def log_jac_det(self, value, *inputs): class ScaleTransform(Transform): name = "scale" + ndim_supp = 0 def __init__(self, transform_args_fn): self.transform_args_fn = transform_args_fn @@ -781,6 +783,7 @@ def log_jac_det(self, value, *inputs): class LogTransform(Transform): name = "log" + ndim_supp = 0 def forward(self, value, *inputs): return pt.log(value) @@ -794,6 +797,7 @@ def log_jac_det(self, value, *inputs): class ExpTransform(Transform): name = "exp" + ndim_supp = 0 def forward(self, value, *inputs): return pt.exp(value) @@ -807,6 +811,7 @@ def log_jac_det(self, value, *inputs): class AbsTransform(Transform): name = "abs" + ndim_supp = 0 def forward(self, value, *inputs): return pt.abs(value) @@ -821,6 +826,7 @@ def log_jac_det(self, value, *inputs): class PowerTransform(Transform): name = "power" + ndim_supp = 0 def __init__(self, power=None): if not isinstance(power, int | float): @@ -864,6 +870,7 @@ def log_jac_det(self, value, *inputs): class IntervalTransform(Transform): name = "interval" + ndim_supp = 0 def __init__(self, args_fn: Callable[..., tuple[Variable | None, Variable | None]]): """Create the IntervalTransform object. @@ -972,6 +979,7 @@ def log_jac_det(self, value, *inputs): class LogOddsTransform(Transform): name = "logodds" + ndim_supp = 0 def backward(self, value, *inputs): return pt.expit(value) @@ -986,6 +994,7 @@ def log_jac_det(self, value, *inputs): class SimplexTransform(Transform): name = "simplex" + ndim_supp = 1 def forward(self, value, *inputs): value = pt.as_tensor(value) @@ -1013,6 +1022,7 @@ def log_jac_det(self, value, *inputs): class CircularTransform(Transform): name = "circular" + ndim_supp = 0 def backward(self, value, *inputs): return pt.arctan2(pt.sin(value), pt.cos(value)) @@ -1029,6 +1039,8 @@ class ChainedTransform(Transform): def __init__(self, transform_list): self.transform_list = transform_list + ndims_supp = [transform.ndim_supp for transform in transform_list] + self.ndim_supp = max(ndims_supp) if None not in ndims_supp else None def forward(self, value, *inputs): for transform in self.transform_list: