From 6523b379e3cdfcc49d85103493dc98392eda080e Mon Sep 17 00:00:00 2001 From: ricardoV94 Date: Fri, 15 May 2026 15:25:54 +0200 Subject: [PATCH] Simpler LKJ Corr unconstraining transform --- pymc/distributions/transforms.py | 291 ++++---------------------- pymc/logprob/utils.py | 93 ++++++++ tests/distributions/test_transform.py | 39 +--- 3 files changed, 144 insertions(+), 279 deletions(-) diff --git a/pymc/distributions/transforms.py b/pymc/distributions/transforms.py index 4561152d50..7e7ede0e3b 100644 --- a/pymc/distributions/transforms.py +++ b/pymc/distributions/transforms.py @@ -138,35 +138,27 @@ def log_jac_det(self, value, *inputs): class CholeskyCorrTransform(Transform): - """ - Map an unconstrained real vector the Cholesky factor of a correlation matrix. + r"""Map an unconstrained real vector to the Cholesky factor of a correlation matrix. - For detailed description of the transform, [1]_ and [2]_. + Constrained space: ``(n, n)`` lower-triangular Cholesky factor ``L`` of a + correlation matrix, with unit-norm rows (so ``L @ L.T`` has ones on the diagonal). + Unconstrained space: ``(n*(n-1)/2,)`` flat real vector packed in row-major + strictly-lower-triangular order. - This is typically used with :class:`~pymc.distributions.LKJCholeskyCov` to place priors on correlation structures. - For a related transform that additionally rescales diagonal elements (working on covariance factors), see - :class:`~pymc.distributions.transforms.CholeskyCovPacked`. + The transform composes three steps: - Adapted from the implementation in TensorFlow Probability [3]_: - https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31 + 1. Scatter the flat vector into the strictly lower-triangular positions of an + ``(n, n)`` matrix and set the diagonal to 1. + 2. Normalize each row to unit L2 norm, producing the Cholesky factor ``L``. - Examples - -------- + The Jacobian of the composite map uses the diagonal of the normalized factor: - .. code-block:: python + .. math:: - import numpy as np - import pytensor.tensor as pt - from pymc.distributions.transforms import CholeskyCorr + \log |J| = \sum_{k=0}^{n-1} (k+2)\,\log L_{kk} + = -\tfrac12 \sum_{k=0}^{n-1} (k+2)\,\log\!\bigl(1 + \|r_k\|^2\bigr) - unconstrained_vector = pt.as_tensor(np.array([2.0, 2.0, 1.0])) - n = unconstrained_vector.shape[0] - tr = CholeskyCorr(n) - constrained_matrix = tr.forward(unconstrained_vector) - y.eval() - array( - [[1.0, 0.0, 0.0], [0.70710678, 0.70710678, 0.0], [0.66666667, 0.66666667, 0.33333333]] - ) + where :math:`r_k` are the off-diagonal elements placed in row *k*. References ---------- @@ -174,251 +166,60 @@ class CholeskyCorrTransform(Transform): Generating random correlation matrices based on vines and extended onion method. Journal of Multivariate Analysis, 100(9), 1989–2001. .. [2] Stan Development Team. Stan Functions Reference. Section on LKJ / Cholesky correlation. - .. [3] TensorFlow Probability. Correlation Cholesky bijector implementation. - https://github.com/tensorflow/probability/ """ name = "cholesky_corr" def __init__(self, n, upper: bool = False): - """ - Initialize the CholeskyCorr transform. - - Parameters - ---------- - n : int - Size of the correlation matrix. - upper: bool, default False - If True, transform to an upper triangular matrix. If False, transform to a lower triangular matrix. - """ + if upper: + raise NotImplementedError("upper=True is not supported") self.n = n - self.m = (n * (n + 1)) // 2 # Number of triangular elements self.upper = upper + self.tril_idxs = pt.tril_indices(n, -1) super().__init__() - def _fill_triangular_spiral( - self, x_raveled: TensorLike, unit_diag: bool = True - ) -> TensorVariable: - """ - Create a triangular matrix from a vector by filling it in a spiral order. - - This code is adapted from the `fill_triangular` function in TensorFlow Probability: - https://github.com/tensorflow/probability/blob/a26f4cbe5ce1549767e13798d9bf5032dac4257b/tensorflow_probability/python/math/linalg.py#L925 - - Parameters - ---------- - x_raveled: TensorLike - The input vector to be reshaped into a triangular matrix. - unit_diag: bool, default False - If True, the diagonal elements are assumed to be 1 and are not filled from the input vector. The input - vector is expected to have length m = n * (n - 1) / 2 in this case, containing only the off-diagonal - elements. - - Returns - ------- - triangular_matrix: TensorVariable - The resulting triangular matrix. - - Notes - ----- - By "spiral order", it is meant that the matrix is filled by jumping between the top and bottom rows, flipping - the fill order from left-to-right to right-to-left on each jump. For example, to fill a 4x4 matrix with - `order=True`, the matrix is filled in the following order: - - - Row 0, left to right - - Row 3, right to left - - Row 1, left to right - - Row 2, right to left - - When `upper` if False, everything is reversed: - - - Row 3, right to left - - Row 0, left to right - - Row 2, right to left - - Row 1, left to right - - After filling, entries not part of the triangular matrix are set to zero. - - Examples - -------- - - .. code-block:: python - - import numpy as np - from pymc.distributions.transforms import CholeskyCorr - - tr = CholeskyCorr(n=4) - x_unconstrained = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) - tr._fill_triangular_spiral(x_unconstrained, upper=False).eval() - - # Out: - # array([[ 5, 0, 0, 0], - # [ 9, 10, 0, 0], - # [ 8, 7, 6, 0], - # [ 4, 3, 2, 1]]) - """ - x_raveled = pt.as_tensor(x_raveled) - *batch_shape, _ = x_raveled.shape - n, m = self.n, self.m - upper = self.upper - - if unit_diag: - n = n - 1 - - tail = x_raveled[..., n:] - - if upper: - xc = pt.concatenate([x_raveled, pt.flip(tail, -1)], axis=-1) - else: - xc = pt.concatenate([tail, pt.flip(x_raveled, -1)], axis=-1) - - y = pt.reshape(xc, (*batch_shape, n, n)) - return pt.triu(y) if upper else pt.tril(y) - - def _inverse_fill_triangular_spiral( - self, x: TensorLike, unit_diag: bool = True - ) -> TensorVariable: - """ - Inverse operation of `_fill_triangular_spiral`. - - Extracts the elements of a triangular matrix in spiral order and returns them as a vector. For details about - what is meant by "spiral order", see the docstring of `_fill_triangular_spiral`. - - Parameters - ---------- - x: TensorVariable - The input triangular matrix. - unit_diag: bool - If True, the diagonal elements are assumed to be 1 and are not included in the output vector. - - Returns - ------- - x_raveled: TensorVariable - The resulting vector containing the elements of the triangular matrix in spiral order. - """ - x = pt.as_tensor(x) - *batch_shape, _, _ = x.shape - n, m = self.n, self.m - - if unit_diag: - m = m - n - n = n - 1 - - upper = self.upper - - if upper: - initial_elements = x[..., 0, :] - triangular_portion = x[..., 1:, :] - else: - initial_elements = pt.flip(x[..., -1, :], axis=-1) - triangular_portion = x[..., :-1, :] - - rotated_triangular_portion = pt.flip(triangular_portion, axis=(-1, -2)) # type: ignore[arg-type] - consolidated_matrix = triangular_portion + rotated_triangular_portion - end_sequence = pt.reshape( - consolidated_matrix, - (*batch_shape, pt.cast(n * (n - 1), "int64")), - ) - y = pt.concatenate([initial_elements, end_sequence[..., : m - n]], axis=-1) - - return y - def forward(self, chol_corr_matrix: TensorLike, *inputs): - """ - Transform the Cholesky factor of a correlation matrix into a real-valued vector. - - Parameters - ---------- - chol_corr_matrix : TensorVariable - Cholesky factor of a correlation matrix R = L @ L.T of shape (n,n). - inputs: - Additional input values. Not used; included for signature compatibility with other transformations. - - Returns - ------- - unconstrained_vector: TensorVariable - Real-valued vector of length m = n * (n - 1) / 2. - """ - chol_corr_matrix = pt.as_tensor(chol_corr_matrix) - n = self.n + chol_corr_matrix = pt.as_tensor_variable(chol_corr_matrix) - # Extract the reciprocal of the row norms from the diagonal. + # Divide each row by its diagonal element to undo the normalization. diag = pt.diagonal(chol_corr_matrix, axis1=-2, axis2=-1)[..., None] + unconstrained = chol_corr_matrix / diag - # Set the diagonal to 0s. - diag_idx = pt.arange(n) - chol_corr_matrix = chol_corr_matrix[..., diag_idx, diag_idx].set(0) - - # Multiply with the norm (or divide by its reciprocal) to recover the - # unconstrained reals in the (strictly) lower triangular part. - unconstrained_matrix = chol_corr_matrix / diag - - # Remove the first row and last column before inverting the fill_triangular_spiral - # transformation. - return self._inverse_fill_triangular_spiral( - unconstrained_matrix[..., 1:, :-1], unit_diag=True - ) + # The strictly lower-triangular elements (row-major) are the free parameters. + return unconstrained[..., self.tril_idxs[0], self.tril_idxs[1]] def backward(self, unconstrained_vector: TensorLike, *inputs): - """ - Transform a real-valued vector of length m = n * (n - 1) / 2 into the Cholesky factor of a correlation matrix. - - Parameters - ---------- - unconstrained_vector : TensorLike - Real-valued vector of length m = n * (n - 1) / 2. - inputs: - Additional input values. Not used; included for signature compatibility with other transformations. - - Returns - ------- - unconstrained_vector: TensorVariable - Unconstrained real numbers. - """ - unconstrained_vector = pt.as_tensor(unconstrained_vector) - chol_corr_matrix = self._fill_triangular_spiral(unconstrained_vector, unit_diag=True) - - # Pad zeros on the top row and right column. - ndim = chol_corr_matrix.ndim - paddings = [*([(0, 0)] * (ndim - 2)), [1, 0], [0, 1]] - chol_corr_matrix = pt.pad(chol_corr_matrix, paddings) + unconstrained_vector = pt.as_tensor_variable(unconstrained_vector) + n = self.n - diag_idx = pt.arange(self.n) - chol_corr_matrix = chol_corr_matrix[..., diag_idx, diag_idx].set(1) + # Scatter into the strictly-lower-triangular positions of an (n, n) matrix. + L = pt.zeros((*unconstrained_vector.shape[:-1], n, n), dtype=unconstrained_vector.dtype) + L = L[..., self.tril_idxs[0], self.tril_idxs[1]].set(unconstrained_vector) - # Normalize each row to have Euclidean (L2) norm 1. - chol_corr_matrix /= pt.linalg.norm(chol_corr_matrix, axis=-1, ord=2)[..., None] + # Set diagonal to 1 (before normalization). + diag_idx = pt.arange(n) + L = L[..., diag_idx, diag_idx].set(1) - return chol_corr_matrix + # Normalize each row to unit L2 norm. + L /= pt.linalg.norm(L, axis=-1, ord=2)[..., None] + return L def log_jac_det(self, unconstrained_vector: TensorLike, *inputs) -> TensorVariable: - """ - Compute the log determinant of the Jacobian. - - Parameters - ---------- - unconstrained_vector : TensorLike - Real-valued vector of length m = n * (n - 1) / 2. - inputs: - Additional input values. Not used; included for signature compatibility with other transformations. - - Returns - ------- - log_jac_det: TensorVariable - Log determinant of the Jacobian of the transformation. - """ - unconstrained_vector = pt.as_tensor(unconstrained_vector) - chol_corr_matrix = self.backward(unconstrained_vector, *inputs) + unconstrained_vector = pt.as_tensor_variable(unconstrained_vector) n = self.n - input_dtype = unconstrained_vector.dtype - - # TODO: tfp has a negative sign here; verify if it is needed - return pt.sum( - pt.arange(2, 2 + n, dtype=input_dtype) - * pt.log(pt.diagonal(chol_corr_matrix, axis1=-2, axis2=-1)), - axis=-1, - ) + dtype = unconstrained_vector.dtype + + # Compute per-row sum of squares of the off-diagonal elements directly, + # without constructing the full normalized matrix. + sq = unconstrained_vector**2 + row_sums_sq = pt.zeros((*unconstrained_vector.shape[:-1], n), dtype=dtype) + row_sums_sq = pt.inc_subtensor(row_sums_sq[..., self.tril_idxs[0]], sq) + + # After setting diagonal to 1 and normalizing, diag_k = 1/sqrt(1 + row_sums_sq_k). + log_diag = pt.cast(-0.5, dtype) * pt.log1p(row_sums_sq) + coeffs = pt.arange(2, 2 + n, dtype=dtype) + return pt.sum(coeffs * log_diag, axis=-1) class CholeskyCovPacked(Transform): diff --git a/pymc/logprob/utils.py b/pymc/logprob/utils.py index bbb30456d1..f116fd1246 100644 --- a/pymc/logprob/utils.py +++ b/pymc/logprob/utils.py @@ -52,6 +52,7 @@ from pytensor.tensor.basic import get_underlying_scalar_constant_value from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize from pytensor.tensor.variable import TensorVariable from pymc.logprob.abstract import MeasurableOp, ValuedRV, _logprob @@ -238,6 +239,98 @@ def local_check_parameter_to_ninf_switch(fgraph, node): ) +@register_canonicalize +@register_specialize +@node_rewriter([pt.math.sqr]) +def local_sqr_of_sqrt_div(fgraph, node): + """Sqr(Sqrt(A) / B) -> A / Sqr(B) and Sqr(A / Sqrt(B)) -> Sqr(A) / B. + + Cancels Sqr with Sqrt when they are separated by a Mul or True_div. + """ + import pytensor.scalar.basic as ps + + [x] = node.inputs + if not (x.owner and isinstance(x.owner.op, Elemwise)): + return None + + inner_scalar_op = x.owner.op.scalar_op + if not isinstance(inner_scalar_op, (ps.TrueDiv, ps.Mul)): + return None + + terms = x.owner.inputs + is_div = isinstance(inner_scalar_op, ps.TrueDiv) + + sqrt_indices = [] + for i, term in enumerate(terms): + if ( + term.owner + and isinstance(term.owner.op, Elemwise) + and isinstance(term.owner.op.scalar_op, ps.Sqrt) + ): + sqrt_indices.append(i) + + if not sqrt_indices: + return None + + if is_div: + numerator, denominator = terms + if ( + numerator.owner + and isinstance(numerator.owner.op, Elemwise) + and isinstance(numerator.owner.op.scalar_op, ps.Sqrt) + ): + # Sqr(Sqrt(A) / B) -> A / Sqr(B) + return [numerator.owner.inputs[0] / pt.sqr(denominator)] + elif ( + denominator.owner + and isinstance(denominator.owner.op, Elemwise) + and isinstance(denominator.owner.op.scalar_op, ps.Sqrt) + ): + # Sqr(A / Sqrt(B)) -> Sqr(A) / B + return [pt.sqr(numerator) / denominator.owner.inputs[0]] + else: + # Mul: cancel one Sqrt factor, square the rest + # Sqr(Sqrt(A) * B * ...) -> A * Sqr(B) * Sqr(...) + idx = sqrt_indices[0] + new_terms = [] + for i, term in enumerate(terms): + if i == idx: + new_terms.append(term.owner.inputs[0]) + else: + new_terms.append(pt.sqr(term)) + result = new_terms[0] + for t in new_terms[1:]: + result = result * t + return [result] + + +@register_canonicalize +@register_specialize +@node_rewriter([pt.math.sqr]) +def local_sqr_of_abs(fgraph, node): + """Sqr(Abs(x)) -> Sqr(x), since squaring already eliminates sign.""" + import pytensor.scalar.basic as ps + + [x] = node.inputs + if ( + x.owner + and isinstance(x.owner.op, Elemwise) + and isinstance(x.owner.op.scalar_op, ps.Abs) + ): + return [pt.sqr(x.owner.inputs[0])] + + +@register_specialize +@node_rewriter([pt.true_div]) +def local_x_div_x(fgraph, node): + """x / x -> ones_like(x), using structural equality.""" + from pytensor.graph.basic import equal_computations + + x, y = node.inputs + if x is y or equal_computations([x], [y]): + return [pt.ones_like(x)] + + class DiracDelta(MeasurableOp, Op): """An `Op` that represents a Dirac-delta distribution.""" diff --git a/tests/distributions/test_transform.py b/tests/distributions/test_transform.py index 036f2e46f7..abad0fae16 100644 --- a/tests/distributions/test_transform.py +++ b/tests/distributions/test_transform.py @@ -688,44 +688,15 @@ class TestLJKCholeskyCorrTransform: def _get_test_values(self): x_unconstrained = np.array([2.0, 2.0, 1.0], dtype=config.floatX) x_constrained = np.array( - [[1.0, 0.0, 0.0], [0.70710678, 0.70710678, 0.0], [0.66666667, 0.66666667, 0.33333333]], + [ + [1.0, 0.0, 0.0], + [2.0 / np.sqrt(5), 1.0 / np.sqrt(5), 0.0], + [2.0 / np.sqrt(6), 1.0 / np.sqrt(6), 1.0 / np.sqrt(6)], + ], dtype=config.floatX, ) return x_unconstrained, x_constrained - @pytest.mark.parametrize("upper", [True, False], ids=["upper", "lower"]) - def test_fill_triangular_spiral(self, upper): - x_unconstrained = np.array([1, 2, 3, 4, 5, 6]) - - if upper: - x_constrained = np.array( - [ - [1, 2, 3], - [0, 5, 6], - [0, 0, 4], - ] - ) - else: - x_constrained = np.array( - [ - [4, 0, 0], - [6, 5, 0], - [3, 2, 1], - ] - ) - - transform = tr.CholeskyCorrTransform(n=3, upper=upper) - - np.testing.assert_allclose( - transform._fill_triangular_spiral(x_unconstrained, unit_diag=False).eval(), - x_constrained, - ) - - np.testing.assert_allclose( - transform._inverse_fill_triangular_spiral(x_constrained, unit_diag=False).eval(), - x_unconstrained, - ) - def test_forward(self): transform = tr.CholeskyCorrTransform(n=3, upper=False) x_unconstrained, x_constrained = self._get_test_values()