Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def _default_transform(op: Op, rv: TensorVariable):

class LogExpM1(Transform):
name = "log_exp_m1"
ndim_supp = 0

def backward(self, value, *inputs):
return pt.softplus(value)
Expand Down Expand Up @@ -84,6 +85,7 @@ class Ordered(Transform):
"""

name = "ordered"
ndim_supp = 1

def __init__(self, positive=False, ascending=True):
self.positive = positive
Expand Down Expand Up @@ -124,6 +126,7 @@ class SumTo1(Transform):
"""

name = "sumto1"
ndim_supp = 1

def backward(self, value, *inputs):
remaining = 1 - pt.sum(value[..., :], axis=-1, keepdims=True)
Expand Down Expand Up @@ -179,6 +182,7 @@ class CholeskyCorrTransform(Transform):
"""

name = "cholesky_corr"
ndim_supp = 1

def __init__(self, n, upper: bool = False):
"""
Expand Down Expand Up @@ -425,6 +429,7 @@ class CholeskyCovPacked(Transform):
"""Transforms the diagonal elements of the LKJCholeskyCov distribution to be on the log scale."""

name = "cholesky-cov-packed"
ndim_supp = 1

def __init__(self, n):
"""Create a CholeskyCovPack object.
Expand Down Expand Up @@ -494,6 +499,7 @@ class CholeskyCovTransform(Transform):
"""

name = "cholesky-cov"
ndim_supp = 1

def __init__(self, n):
"""Create a CholeskyCovTransform.
Expand Down Expand Up @@ -649,6 +655,7 @@ class ZeroSumTransform(Transform):

def __init__(self, zerosum_axes):
self.zerosum_axes = tuple(int(axis) for axis in zerosum_axes)
self.ndim_supp = len(self.zerosum_axes)

@staticmethod
def extend_axis(array, axis):
Expand Down
12 changes: 12 additions & 0 deletions pymc/logprob/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,7 @@ def calc_delta_x(value, prior_result):

class LocTransform(Transform):
name = "loc"
ndim_supp = 0

def __init__(self, transform_args_fn):
self.transform_args_fn = transform_args_fn
Expand All @@ -762,6 +763,7 @@ def log_jac_det(self, value, *inputs):

class ScaleTransform(Transform):
name = "scale"
ndim_supp = 0

def __init__(self, transform_args_fn):
self.transform_args_fn = transform_args_fn
Expand All @@ -781,6 +783,7 @@ def log_jac_det(self, value, *inputs):

class LogTransform(Transform):
name = "log"
ndim_supp = 0

def forward(self, value, *inputs):
return pt.log(value)
Expand All @@ -794,6 +797,7 @@ def log_jac_det(self, value, *inputs):

class ExpTransform(Transform):
name = "exp"
ndim_supp = 0

def forward(self, value, *inputs):
return pt.exp(value)
Expand All @@ -807,6 +811,7 @@ def log_jac_det(self, value, *inputs):

class AbsTransform(Transform):
name = "abs"
ndim_supp = 0

def forward(self, value, *inputs):
return pt.abs(value)
Expand All @@ -821,6 +826,7 @@ def log_jac_det(self, value, *inputs):

class PowerTransform(Transform):
name = "power"
ndim_supp = 0

def __init__(self, power=None):
if not isinstance(power, int | float):
Expand Down Expand Up @@ -864,6 +870,7 @@ def log_jac_det(self, value, *inputs):

class IntervalTransform(Transform):
name = "interval"
ndim_supp = 0

def __init__(self, args_fn: Callable[..., tuple[Variable | None, Variable | None]]):
"""Create the IntervalTransform object.
Expand Down Expand Up @@ -972,6 +979,7 @@ def log_jac_det(self, value, *inputs):

class LogOddsTransform(Transform):
name = "logodds"
ndim_supp = 0

def backward(self, value, *inputs):
return pt.expit(value)
Expand All @@ -986,6 +994,7 @@ def log_jac_det(self, value, *inputs):

class SimplexTransform(Transform):
name = "simplex"
ndim_supp = 1

def forward(self, value, *inputs):
value = pt.as_tensor(value)
Expand Down Expand Up @@ -1013,6 +1022,7 @@ def log_jac_det(self, value, *inputs):

class CircularTransform(Transform):
name = "circular"
ndim_supp = 0

def backward(self, value, *inputs):
return pt.arctan2(pt.sin(value), pt.cos(value))
Expand All @@ -1029,6 +1039,8 @@ class ChainedTransform(Transform):

def __init__(self, transform_list):
self.transform_list = transform_list
ndims_supp = [transform.ndim_supp for transform in transform_list]
self.ndim_supp = max(ndims_supp) if None not in ndims_supp else None

def forward(self, value, *inputs):
for transform in self.transform_list:
Expand Down
Loading