Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/scanpy/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
"DaskArray",
"SpBase",
"fullname",
"get_namespace",
"is_array_api",
"pkg_metadata",
"pkg_version",
"set_module",
Expand Down Expand Up @@ -66,6 +68,29 @@ def pkg_metadata(package: str) -> PackageMetadata:
return metadata(package)


def is_array_api(x: object) -> bool:
    """Return whether ``x`` should take the generic array-API code path.

    NumPy arrays, Dask arrays and sparse containers (``SpBase``) are reported
    as ``False`` even where array-api-compat would accept them, because those
    types are already handled by dedicated code paths.

    Parameters
    ----------
    x
        Any object.

    Returns
    -------
    ``True`` only if ``x`` is none of the excluded types and
    ``array_api_compat.is_array_api_obj`` recognises it.
    """
    import numpy as np

    # NumPy is checked first: returning here means plain ndarrays never
    # trigger the array-api-compat import below.
    if isinstance(x, np.ndarray):
        return False
    if isinstance(x, (DaskArray, SpBase)):
        return False

    from array_api_compat import is_array_api_obj

    return is_array_api_obj(x)


def get_namespace(x):
    """Return the array-API namespace that array-api-compat resolves for ``x``."""
    # Module-qualified access keeps this wrapper's own name from being
    # shadowed by the imported helper of the same name.
    import array_api_compat

    return array_api_compat.get_namespace(x)


@cache
def pkg_version(package: str) -> Version:
from importlib.metadata import version
Expand Down
29 changes: 29 additions & 0 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,8 +602,22 @@ def axis_mul_or_truediv(
allow_divide_by_zero: bool = True,
out: ArrayLike | None = None,
) -> np.ndarray:
from .._compat import get_namespace, is_array_api

_check_op(op)
scaling_array = _broadcast_axis(scaling_array, axis)
# array api version
if is_array_api(x): ### double check if numpy skips this
xp = get_namespace(x)
scaling_array = xp.asarray(scaling_array)
if op is mul:
return x * scaling_array
if not allow_divide_by_zero:
scaling_array = xp.where(
scaling_array == 0, xp.ones_like(scaling_array), scaling_array
)
return x / scaling_array
# numpy version
if op is mul:
return np.multiply(x, scaling_array, out=out)
if not allow_divide_by_zero:
Expand Down Expand Up @@ -725,6 +739,12 @@ def _[T: (DaskArray, np.ndarray)](

@singledispatch
def axis_nnz(x: ArrayLike, /, axis: Literal[0, 1]) -> np.ndarray:
    """Count the non-zero entries of ``x`` along ``axis``.

    Inputs for which ``is_array_api`` is true are counted with their own
    namespace's ``count_nonzero``; everything else goes through NumPy.
    """
    from .._compat import get_namespace, is_array_api

    backend = get_namespace(x) if is_array_api(x) else np
    return backend.count_nonzero(x, axis=axis)


Expand Down Expand Up @@ -758,6 +778,15 @@ def _(x: DaskArray, /, axis: Literal[0, 1]) -> DaskArray:
@singledispatch
def check_nonnegative_integers(x: _SupportedArray, /) -> bool | DaskArray:
    """Check values of X to ensure it is count data.

    This fallback only handles inputs for which ``is_array_api`` is true.

    Raises
    ------
    NotImplementedError
        If ``x`` is not recognised as an array-API object.
    """
    from .._compat import get_namespace, is_array_api

    if not is_array_api(x):
        raise NotImplementedError

    xp = get_namespace(x)
    # Any negative entry rules out count data.
    if bool(xp.any(x < 0)):
        return False
    # Integer dtypes cannot hold fractional values, so no further check is needed.
    if xp.isdtype(x.dtype, "integral"):
        return True
    # Otherwise every value must have a zero remainder modulo 1.
    has_fraction = xp.any((x % 1) != 0)
    return not bool(has_fraction)


Expand Down
9 changes: 6 additions & 3 deletions src/scanpy/metrics/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from .._utils import NeighborsView

if TYPE_CHECKING:
from typing import NoReturn

from anndata import AnnData
from numpy.typing import NDArray

Expand Down Expand Up @@ -93,7 +91,12 @@ def _resolve_vals(val: pd.DataFrame | pd.Series) -> NDArray: ...


@singledispatch
def _resolve_vals(val: object) -> NDArray:
    """Fallback for value containers without a registered handler.

    Inputs for which ``is_array_api`` is true are converted with
    ``np.asarray``; anything else is rejected.

    Raises
    ------
    TypeError
        If ``val`` is not recognised as an array-API object.
    """
    from .._compat import is_array_api

    if is_array_api(val):
        # Moran's I / Geary's C use numba kernels, so convert at the boundary.
        return np.asarray(val)
    msg = f"Unsupported type {type(val)}"
    raise TypeError(msg)

Expand Down
5 changes: 5 additions & 0 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,12 @@ def compute_neighbors(
self._rp_forest = None
self.n_neighbors = n_neighbors
self.knn = knn
from .._compat import is_array_api

x = _choose_representation(self._adata, use_rep=use_rep, n_pcs=n_pcs)
if is_array_api(x):
# sklearn transformers require NumPy input, so convert at the boundary.
x = np.asarray(x)
self._distances = transformer.fit_transform(x)
knn_indices, knn_distances = _get_indices_distances_from_sparse_matrix(
self._distances, n_neighbors
Expand Down
10 changes: 9 additions & 1 deletion src/scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,13 +363,21 @@ def _highly_variable_genes_single_batch(
if n_removed:
x = x[:, filt].copy()

from .._compat import get_namespace, is_array_api ### double check

if flavor == "seurat":
x = x.copy()
if (base := adata.uns.get("log1p", {}).get("base")) is not None:
x *= np.log(base)
if is_array_api(x):
x = x * float(np.log(base))
else:
x *= np.log(base)
# use out if possible. only possible since we copy the data matrix
if isinstance(x, np.ndarray):
np.expm1(x, out=x)
elif is_array_api(x):
xp = get_namespace(x)
x = xp.expm1(x)
else:
x = np.expm1(x)

Expand Down
7 changes: 7 additions & 0 deletions src/scanpy/preprocessing/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,15 @@

def _compute_nnz_median(counts: np.ndarray | DaskArray) -> np.floating:
    """Return the median of the strictly positive entries of a 1D count array.

    Dask inputs are computed first, and array-API inputs are converted to
    NumPy, so the median itself is always taken with ``np.median``.
    """
    from .._compat import is_array_api

    if isinstance(counts, DaskArray):
        counts = counts.compute()

    if is_array_api(counts):
        # The median is computed with NumPy, so convert here.
        # NOTE(review): the array API standard appears to define no median — confirm.
        counts = np.asarray(counts)

    return np.median(counts[counts > 0])
Expand Down
48 changes: 37 additions & 11 deletions src/scanpy/preprocessing/_scale.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@
def clip[A: _Array](
x: ArrayLike | A, *, max_value: float, zero_center: bool = True
) -> A:
# clip_array cannot trace JAX arrays = example
from .._compat import get_namespace, is_array_api

if is_array_api(x):
xp = get_namespace(x)
if zero_center:
return xp.clip(x, -max_value, max_value) ### double check
return xp.clip(x, None, max_value)

return clip_array(x, max_value=max_value, zero_center=zero_center)


Expand Down Expand Up @@ -169,8 +178,15 @@ def scale_array[A: _Array](
logg.info( # Be careful of what? This should be more specific
"... be careful when using `max_value` without `zero_center`."
)
from .._compat import get_namespace, is_array_api

if np.issubdtype(x.dtype, np.integer):
if is_array_api(x):
xp = get_namespace(x)
if xp.isdtype(x.dtype, "integral"): ### double check if integral is needed
logg.info("...")
x = xp.astype(x, xp.float64)

elif np.issubdtype(x.dtype, np.integer):
logg.info(
"... as scaling leads to float results, integer "
"input is cast to float, returning copy."
Expand All @@ -192,18 +208,28 @@ def scale_array[A: _Array](
max_value=max_value,
return_mean_std=return_mean_std,
)
from .._compat import get_namespace, is_array_api

mean, var = mean_var(x, axis=0, correction=1)
std = np.sqrt(var)
std[std == 0] = 1
if zero_center:
if isinstance(x, CSBase) or (
isinstance(x, DaskArray) and isinstance(x._meta, CSBase)
):
msg = "zero-centering a sparse array/matrix densifies it."
warn(msg, UserWarning)
x -= mean
x = dematrix(x)

if is_array_api(x):
xp = get_namespace(x)
std = xp.sqrt(var)
std = xp.where(std == 0, xp.ones_like(std), std)

if zero_center:
x = x - mean ### double check this formula
else:
std = np.sqrt(var)
std[std == 0] = 1
if zero_center:
if isinstance(x, CSBase) or (
isinstance(x, DaskArray) and isinstance(x._meta, CSBase)
):
msg = "zero-centering a sparse array/matrix densifies it."
warn(msg, UserWarning)
x -= mean
x = dematrix(x)

x = axis_mul_or_truediv(
x,
Expand Down
8 changes: 8 additions & 0 deletions src/scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,9 +349,17 @@ def log1p(
Returns or updates `data`, depending on `copy`.

"""
from .._compat import get_namespace, is_array_api

check_array_function_arguments(
chunked=chunked, chunk_size=chunk_size, layer=layer, obsm=obsm
)
if is_array_api(data):
xp = get_namespace(data)
result = xp.log1p(data)
if base is not None:
result = result / float(np.log(base))
return result
return log1p_array(data, copy=copy, base=base)


Expand Down
5 changes: 5 additions & 0 deletions src/scanpy/tools/_rank_genes_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,11 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915
"""
from scanpy import settings

from .._compat import is_array_api

# rank_genes_groups uses numba kernels internally, so we need to convert to NumPy at entry.
if is_array_api(adata.X): ### double check
adata.X = np.asarray(adata.X)
if isinstance(mask_var, Default):
mask_var = settings.preset.rank_genes_groups.mask_var
if isinstance(mean_in_log_space, Default):
Expand Down
Loading