Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
atleast_nd,
cov,
create_diagonal,
diag_indices,
expand_dims,
isclose,
isin,
Expand All @@ -15,6 +16,8 @@
searchsorted,
setdiff1d,
sinc,
tril_indices,
triu_indices,
union1d,
)
from ._lib._at import at
Expand All @@ -40,6 +43,7 @@
"cov",
"create_diagonal",
"default_dtype",
"diag_indices",
"expand_dims",
"isclose",
"isin",
Expand All @@ -53,5 +57,7 @@
"searchsorted",
"setdiff1d",
"sinc",
"tril_indices",
"triu_indices",
"union1d",
]
162 changes: 162 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@
"atleast_nd",
"cov",
"create_diagonal",
"diag_indices",
"expand_dims",
"isclose",
"nan_to_num",
"one_hot",
"pad",
"searchsorted",
"sinc",
"tril_indices",
"triu_indices",
]


Expand Down Expand Up @@ -238,6 +241,49 @@ def create_diagonal(
return _funcs.create_diagonal(x, offset=offset, xp=xp)


def diag_indices(n: int, /, *, ndim: int = 2, xp: ModuleType) -> tuple[Array, ...]:
"""
Return the indices to access the main diagonal of an array.

Equivalent to ``numpy.diag_indices``.

Parameters
----------
n : int
The size of each dimension of the (hyper-)cube ``(n, n, ..., n)``
that the returned indices index into.
ndim : int, optional
The number of dimensions. Default: ``2``.
xp : array_namespace
The standard-compatible namespace to create the indices in.

Returns
-------
tuple of array
``ndim`` 1-D integer arrays of length ``n`` that together index
Comment thread
bruAristimunha marked this conversation as resolved.
Outdated
the main diagonal of an array of shape ``(n,) * ndim``.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> rows, cols = xpx.diag_indices(3, xp=xp)
>>> rows
Array([0, 1, 2], dtype=array_api_strict.int64)
>>> cols
Array([0, 1, 2], dtype=array_api_strict.int64)
"""
if n < 0:
msg = f"`n` must be non-negative, got {n}"
raise ValueError(msg)
if ndim < 1:
msg = f"`ndim` must be >= 1, got {ndim}"
raise ValueError(msg)
if is_numpy_namespace(xp) or is_cupy_namespace(xp) or is_jax_namespace(xp):
return xp.diag_indices(n, ndim=ndim)
return _funcs.diag_indices(n, ndim=ndim, xp=xp)


def expand_dims(
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
) -> Array:
Expand Down Expand Up @@ -1150,3 +1196,119 @@ def union1d(a: Array, b: Array, /, *, xp: ModuleType | None = None) -> Array:
return xp.union1d(a, b)

return _funcs.union1d(a, b, xp=xp)


def tril_indices(
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
) -> tuple[Array, Array]:
"""
Return the indices of the lower triangle of an ``(n, m)`` array.

Equivalent to ``numpy.tril_indices`` with parameter ``k`` renamed to
``offset`` to match ``xp.linalg.diagonal``'s naming.

Parameters
----------
n : int
The row dimension of the array.
offset : int, optional
Diagonal offset; ``0`` (default) is the main diagonal. Corresponds
to ``k`` in ``numpy.tril_indices``.
m : int, optional
The column dimension. If ``None`` (default), assumed equal to `n`.
xp : array_namespace
The standard-compatible namespace to create the indices in.

Returns
-------
tuple of array
Row and column indices ``(rows, cols)`` of the lower triangle of
the ``(n, m)`` matrix, shifted by `offset`.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> rows, cols = xpx.tril_indices(3, xp=xp)
>>> rows
Array([0, 1, 1, 2, 2, 2], dtype=array_api_strict.int64)
>>> cols
Array([0, 0, 1, 0, 1, 2], dtype=array_api_strict.int64)
"""
if n < 0:
msg = f"`n` must be non-negative, got {n}"
raise ValueError(msg)
if m is not None and m < 0:
msg = f"`m` must be non-negative, got {m}"
raise ValueError(msg)
if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_dask_namespace(xp)
):
return xp.tril_indices(n, k=offset, m=m)
if is_torch_namespace(xp):
# `torch.tril_indices` returns a 2xN tensor, not a tuple, and
# takes (row, col) rather than (n, *, m=None).
cols = n if m is None else m
idx = xp.tril_indices(n, cols, offset=offset)
return (idx[0], idx[1])
return _funcs.tril_indices(n, offset=offset, m=m, xp=xp)


def triu_indices(
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
) -> tuple[Array, Array]:
"""
Return the indices of the upper triangle of an ``(n, m)`` array.

Equivalent to ``numpy.triu_indices`` with parameter ``k`` renamed to
``offset`` to match ``xp.linalg.diagonal``'s naming.

Parameters
----------
n : int
The row dimension of the array.
offset : int, optional
Diagonal offset; ``0`` (default) is the main diagonal. Corresponds
to ``k`` in ``numpy.triu_indices``.
m : int, optional
The column dimension. If ``None`` (default), assumed equal to `n`.
xp : array_namespace
The standard-compatible namespace to create the indices in.

Returns
-------
tuple of array
Row and column indices ``(rows, cols)`` of the upper triangle of
the ``(n, m)`` matrix, shifted by `offset`.

Examples
--------
>>> import array_api_strict as xp
>>> import array_api_extra as xpx
>>> rows, cols = xpx.triu_indices(3, xp=xp)
>>> rows
Array([0, 0, 0, 1, 1, 2], dtype=array_api_strict.int64)
>>> cols
Array([0, 1, 2, 1, 2, 2], dtype=array_api_strict.int64)
"""
if n < 0:
msg = f"`n` must be non-negative, got {n}"
raise ValueError(msg)
if m is not None and m < 0:
msg = f"`m` must be non-negative, got {m}"
raise ValueError(msg)
if (
is_numpy_namespace(xp)
or is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_dask_namespace(xp)
):
return xp.triu_indices(n, k=offset, m=m)
if is_torch_namespace(xp):
cols = n if m is None else m
idx = xp.triu_indices(n, cols, offset=offset)
return (idx[0], idx[1])
return _funcs.triu_indices(n, offset=offset, m=m, xp=xp)
38 changes: 38 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,16 @@
"broadcast_shapes",
"cov",
"create_diagonal",
"diag_indices",
"expand_dims",
"kron",
"nunique",
"pad",
"searchsorted",
"setdiff1d",
"sinc",
"tril_indices",
"triu_indices",
]


Expand Down Expand Up @@ -346,6 +349,41 @@ def create_diagonal(
return xp.reshape(diag, (*batch_dims, n, n))


def diag_indices(
n: int, /, *, ndim: int = 2, xp: ModuleType
) -> tuple[Array, ...]: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
idx = xp.arange(n)
return (idx,) * ndim


def _tri_indices(
n: int, *, offset: int, m: int | None, upper: bool, xp: ModuleType
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
"""Shared implementation for `tril_indices` and `triu_indices`."""
cols = n if m is None else m
rows = xp.arange(n)[:, None]
cols_a = xp.arange(cols)[None, :]
delta = cols_a - rows
mask = delta >= offset if upper else delta <= offset
r, c = xp.nonzero(mask)
return (r, c)


def tril_indices(
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
return _tri_indices(n, offset=offset, m=m, upper=False, xp=xp)


def triu_indices(
n: int, /, *, offset: int = 0, m: int | None = None, xp: ModuleType
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
"""See docstring in array_api_extra._delegation."""
return _tri_indices(n, offset=offset, m=m, upper=True, xp=xp)
Comment thread
lucascolley marked this conversation as resolved.
Outdated


def default_dtype(
xp: ModuleType,
kind: Literal[
Expand Down
7 changes: 5 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,12 @@ def as_readonly(o: T) -> T: # numpydoc ignore=PR01,RT01
# Cannot interpret as a data type
return o

# This works with namedtuples too
if isinstance(o, tuple | list):
return type(o)(*(as_readonly(i) for i in o)) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType]
# namedtuple wants positional args; plain tuple/list wants an iterable.
items = (as_readonly(i) for i in o)
if hasattr(o, "_fields"):
return type(o)(*items) # type: ignore[arg-type,return-value] # pyright: ignore[reportArgumentType]
return type(o)(items) # type: ignore[return-value]

return o

Expand Down
Loading