Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 46 additions & 245 deletions pymc/distributions/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,287 +138,88 @@ def log_jac_det(self, value, *inputs):


class CholeskyCorrTransform(Transform):
"""
Map an unconstrained real vector the Cholesky factor of a correlation matrix.
r"""Map an unconstrained real vector to the Cholesky factor of a correlation matrix.

For detailed description of the transform, [1]_ and [2]_.
Constrained space: ``(n, n)`` lower-triangular Cholesky factor ``L`` of a
correlation matrix, with unit-norm rows (so ``L @ L.T`` has ones on the diagonal).
Unconstrained space: ``(n*(n-1)/2,)`` flat real vector packed in row-major
strictly-lower-triangular order.

This is typically used with :class:`~pymc.distributions.LKJCholeskyCov` to place priors on correlation structures.
For a related transform that additionally rescales diagonal elements (working on covariance factors), see
:class:`~pymc.distributions.transforms.CholeskyCovPacked`.
The transform composes three steps:

Adapted from the implementation in TensorFlow Probability [3]_:
https://github.com/tensorflow/probability/blob/94f592af363e13391858b48f785eb4c250912904/tensorflow_probability/python/bijectors/correlation_cholesky.py#L31
1. Scatter the flat vector into the strictly lower-triangular positions of an
``(n, n)`` matrix and set the diagonal to 1.
2. Normalize each row to unit L2 norm, producing the Cholesky factor ``L``.

Examples
--------
The Jacobian of the composite map uses the diagonal of the normalized factor:

.. code-block:: python
.. math::

import numpy as np
import pytensor.tensor as pt
from pymc.distributions.transforms import CholeskyCorr
\log |J| = \sum_{k=0}^{n-1} (k+2)\,\log L_{kk}
= -\tfrac12 \sum_{k=0}^{n-1} (k+2)\,\log\!\bigl(1 + \|r_k\|^2\bigr)

unconstrained_vector = pt.as_tensor(np.array([2.0, 2.0, 1.0]))
n = unconstrained_vector.shape[0]
tr = CholeskyCorr(n)
constrained_matrix = tr.forward(unconstrained_vector)
y.eval()
array(
[[1.0, 0.0, 0.0], [0.70710678, 0.70710678, 0.0], [0.66666667, 0.66666667, 0.33333333]]
)
where :math:`r_k` are the off-diagonal elements placed in row *k*.

References
----------
.. [1] Lewandowski, D., Kurowicka, D., & Joe, H. (2009).
Generating random correlation matrices based on vines and extended onion method.
Journal of Multivariate Analysis, 100(9), 1989–2001.
.. [2] Stan Development Team. Stan Functions Reference. Section on LKJ / Cholesky correlation.
.. [3] TensorFlow Probability. Correlation Cholesky bijector implementation.
https://github.com/tensorflow/probability/
"""

name = "cholesky_corr"

def __init__(self, n, upper: bool = False):
"""
Initialize the CholeskyCorr transform.

Parameters
----------
n : int
Size of the correlation matrix.
upper: bool, default False
If True, transform to an upper triangular matrix. If False, transform to a lower triangular matrix.
"""
if upper:
raise NotImplementedError("upper=True is not supported")
self.n = n
self.m = (n * (n + 1)) // 2 # Number of triangular elements
self.upper = upper
self.tril_idxs = pt.tril_indices(n, -1)

super().__init__()

def _fill_triangular_spiral(
self, x_raveled: TensorLike, unit_diag: bool = True
) -> TensorVariable:
"""
Create a triangular matrix from a vector by filling it in a spiral order.

This code is adapted from the `fill_triangular` function in TensorFlow Probability:
https://github.com/tensorflow/probability/blob/a26f4cbe5ce1549767e13798d9bf5032dac4257b/tensorflow_probability/python/math/linalg.py#L925

Parameters
----------
x_raveled: TensorLike
The input vector to be reshaped into a triangular matrix.
unit_diag: bool, default False
If True, the diagonal elements are assumed to be 1 and are not filled from the input vector. The input
vector is expected to have length m = n * (n - 1) / 2 in this case, containing only the off-diagonal
elements.

Returns
-------
triangular_matrix: TensorVariable
The resulting triangular matrix.

Notes
-----
By "spiral order", it is meant that the matrix is filled by jumping between the top and bottom rows, flipping
the fill order from left-to-right to right-to-left on each jump. For example, to fill a 4x4 matrix with
`order=True`, the matrix is filled in the following order:

- Row 0, left to right
- Row 3, right to left
- Row 1, left to right
- Row 2, right to left

When `upper` if False, everything is reversed:

- Row 3, right to left
- Row 0, left to right
- Row 2, right to left
- Row 1, left to right

After filling, entries not part of the triangular matrix are set to zero.

Examples
--------

.. code-block:: python

import numpy as np
from pymc.distributions.transforms import CholeskyCorr

tr = CholeskyCorr(n=4)
x_unconstrained = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
tr._fill_triangular_spiral(x_unconstrained, upper=False).eval()

# Out:
# array([[ 5, 0, 0, 0],
# [ 9, 10, 0, 0],
# [ 8, 7, 6, 0],
# [ 4, 3, 2, 1]])
"""
x_raveled = pt.as_tensor(x_raveled)
*batch_shape, _ = x_raveled.shape
n, m = self.n, self.m
upper = self.upper

if unit_diag:
n = n - 1

tail = x_raveled[..., n:]

if upper:
xc = pt.concatenate([x_raveled, pt.flip(tail, -1)], axis=-1)
else:
xc = pt.concatenate([tail, pt.flip(x_raveled, -1)], axis=-1)

y = pt.reshape(xc, (*batch_shape, n, n))
return pt.triu(y) if upper else pt.tril(y)

def _inverse_fill_triangular_spiral(
self, x: TensorLike, unit_diag: bool = True
) -> TensorVariable:
"""
Inverse operation of `_fill_triangular_spiral`.

Extracts the elements of a triangular matrix in spiral order and returns them as a vector. For details about
what is meant by "spiral order", see the docstring of `_fill_triangular_spiral`.

Parameters
----------
x: TensorVariable
The input triangular matrix.
unit_diag: bool
If True, the diagonal elements are assumed to be 1 and are not included in the output vector.

Returns
-------
x_raveled: TensorVariable
The resulting vector containing the elements of the triangular matrix in spiral order.
"""
x = pt.as_tensor(x)
*batch_shape, _, _ = x.shape
n, m = self.n, self.m

if unit_diag:
m = m - n
n = n - 1

upper = self.upper

if upper:
initial_elements = x[..., 0, :]
triangular_portion = x[..., 1:, :]
else:
initial_elements = pt.flip(x[..., -1, :], axis=-1)
triangular_portion = x[..., :-1, :]

rotated_triangular_portion = pt.flip(triangular_portion, axis=(-1, -2)) # type: ignore[arg-type]
consolidated_matrix = triangular_portion + rotated_triangular_portion
end_sequence = pt.reshape(
consolidated_matrix,
(*batch_shape, pt.cast(n * (n - 1), "int64")),
)
y = pt.concatenate([initial_elements, end_sequence[..., : m - n]], axis=-1)

return y

def forward(self, chol_corr_matrix: TensorLike, *inputs):
"""
Transform the Cholesky factor of a correlation matrix into a real-valued vector.

Parameters
----------
chol_corr_matrix : TensorVariable
Cholesky factor of a correlation matrix R = L @ L.T of shape (n,n).
inputs:
Additional input values. Not used; included for signature compatibility with other transformations.

Returns
-------
unconstrained_vector: TensorVariable
Real-valued vector of length m = n * (n - 1) / 2.
"""
chol_corr_matrix = pt.as_tensor(chol_corr_matrix)
n = self.n
chol_corr_matrix = pt.as_tensor_variable(chol_corr_matrix)

# Extract the reciprocal of the row norms from the diagonal.
# Divide each row by its diagonal element to undo the normalization.
diag = pt.diagonal(chol_corr_matrix, axis1=-2, axis2=-1)[..., None]
unconstrained = chol_corr_matrix / diag

# Set the diagonal to 0s.
diag_idx = pt.arange(n)
chol_corr_matrix = chol_corr_matrix[..., diag_idx, diag_idx].set(0)

# Multiply with the norm (or divide by its reciprocal) to recover the
# unconstrained reals in the (strictly) lower triangular part.
unconstrained_matrix = chol_corr_matrix / diag

# Remove the first row and last column before inverting the fill_triangular_spiral
# transformation.
return self._inverse_fill_triangular_spiral(
unconstrained_matrix[..., 1:, :-1], unit_diag=True
)
# The strictly lower-triangular elements (row-major) are the free parameters.
return unconstrained[..., self.tril_idxs[0], self.tril_idxs[1]]

def backward(self, unconstrained_vector: TensorLike, *inputs):
"""
Transform a real-valued vector of length m = n * (n - 1) / 2 into the Cholesky factor of a correlation matrix.

Parameters
----------
unconstrained_vector : TensorLike
Real-valued vector of length m = n * (n - 1) / 2.
inputs:
Additional input values. Not used; included for signature compatibility with other transformations.

Returns
-------
unconstrained_vector: TensorVariable
Unconstrained real numbers.
"""
unconstrained_vector = pt.as_tensor(unconstrained_vector)
chol_corr_matrix = self._fill_triangular_spiral(unconstrained_vector, unit_diag=True)

# Pad zeros on the top row and right column.
ndim = chol_corr_matrix.ndim
paddings = [*([(0, 0)] * (ndim - 2)), [1, 0], [0, 1]]
chol_corr_matrix = pt.pad(chol_corr_matrix, paddings)
unconstrained_vector = pt.as_tensor_variable(unconstrained_vector)
n = self.n

diag_idx = pt.arange(self.n)
chol_corr_matrix = chol_corr_matrix[..., diag_idx, diag_idx].set(1)
# Scatter into the strictly-lower-triangular positions of an (n, n) matrix.
L = pt.zeros((*unconstrained_vector.shape[:-1], n, n), dtype=unconstrained_vector.dtype)
L = L[..., self.tril_idxs[0], self.tril_idxs[1]].set(unconstrained_vector)

# Normalize each row to have Euclidean (L2) norm 1.
chol_corr_matrix /= pt.linalg.norm(chol_corr_matrix, axis=-1, ord=2)[..., None]
# Set diagonal to 1 (before normalization).
diag_idx = pt.arange(n)
L = L[..., diag_idx, diag_idx].set(1)

return chol_corr_matrix
# Normalize each row to unit L2 norm.
L /= pt.linalg.norm(L, axis=-1, ord=2)[..., None]
return L

def log_jac_det(self, unconstrained_vector: TensorLike, *inputs) -> TensorVariable:
"""
Compute the log determinant of the Jacobian.

Parameters
----------
unconstrained_vector : TensorLike
Real-valued vector of length m = n * (n - 1) / 2.
inputs:
Additional input values. Not used; included for signature compatibility with other transformations.

Returns
-------
log_jac_det: TensorVariable
Log determinant of the Jacobian of the transformation.
"""
unconstrained_vector = pt.as_tensor(unconstrained_vector)
chol_corr_matrix = self.backward(unconstrained_vector, *inputs)
unconstrained_vector = pt.as_tensor_variable(unconstrained_vector)
n = self.n
input_dtype = unconstrained_vector.dtype

# TODO: tfp has a negative sign here; verify if it is needed
return pt.sum(
pt.arange(2, 2 + n, dtype=input_dtype)
* pt.log(pt.diagonal(chol_corr_matrix, axis1=-2, axis2=-1)),
axis=-1,
)
dtype = unconstrained_vector.dtype

# Compute per-row sum of squares of the off-diagonal elements directly,
# without constructing the full normalized matrix.
sq = unconstrained_vector**2
row_sums_sq = pt.zeros((*unconstrained_vector.shape[:-1], n), dtype=dtype)
row_sums_sq = pt.inc_subtensor(row_sums_sq[..., self.tril_idxs[0]], sq)

# After setting diagonal to 1 and normalizing, diag_k = 1/sqrt(1 + row_sums_sq_k).
log_diag = pt.cast(-0.5, dtype) * pt.log1p(row_sums_sq)
coeffs = pt.arange(2, 2 + n, dtype=dtype)
return pt.sum(coeffs * log_diag, axis=-1)


class CholeskyCovPacked(Transform):
Expand Down
Loading
Loading