From a8fc77e74caf914c01651c184d6b264e632cc280 Mon Sep 17 00:00:00 2001 From: amalia-k510 Date: Thu, 25 Jun 2026 10:14:44 +0200 Subject: [PATCH 1/2] first try at array-api integration --- src/scanpy/_compat.py | 22 +++++++++ src/scanpy/_utils/__init__.py | 29 +++++++++++ src/scanpy/metrics/_common.py | 9 ++-- src/scanpy/neighbors/__init__.py | 5 ++ .../preprocessing/_highly_variable_genes.py | 10 +++- src/scanpy/preprocessing/_normalization.py | 7 +++ src/scanpy/preprocessing/_scale.py | 48 ++++++++++++++----- src/scanpy/preprocessing/_simple.py | 8 ++++ src/scanpy/tools/_rank_genes_groups.py | 5 ++ 9 files changed, 128 insertions(+), 15 deletions(-) diff --git a/src/scanpy/_compat.py b/src/scanpy/_compat.py index 3027a81107..575877b9f3 100644 --- a/src/scanpy/_compat.py +++ b/src/scanpy/_compat.py @@ -22,6 +22,8 @@ "DaskArray", "SpBase", "fullname", + "get_namespace", + "is_array_api", "pkg_metadata", "pkg_version", "set_module", @@ -66,6 +68,26 @@ def pkg_metadata(package: str) -> PackageMetadata: return metadata(package) +def is_array_api(x: object) -> bool: + # returns true if x is array api compatible + # exclusing the ones that are already handled by the script + from array_api_compat import is_array_api_obj + + # excluding packages that are handled by both array-api-compat and script + if isinstance(x, DaskArray): + return False + if isinstance(x, DaskArray): + return False + return is_array_api_obj(x) + + +def get_namespace(x): + # get array-api namespace for x + from array_api_compat import get_namespace + + return get_namespace(x) + + @cache def pkg_version(package: str) -> Version: from importlib.metadata import version diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index f0d67b550d..d61c150461 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -602,8 +602,22 @@ def axis_mul_or_truediv( allow_divide_by_zero: bool = True, out: ArrayLike | None = None, ) -> np.ndarray: + from .._compat import get_namespace, is_array_api + _check_op(op) scaling_array = _broadcast_axis(scaling_array, axis) + # array api version + if is_array_api(x): ### double check if numpy skips this + xp = get_namespace(x) + scaling_array = xp.asarray(scaling_array) + if op is mul: + return x * scaling_array + if not allow_divide_by_zero: + scaling_array = xp.where( + scaling_array == 0, xp.ones_like(scaling_array), scaling_array + ) + return x / scaling_array + # numpy version if op is mul: return np.multiply(x, scaling_array, out=out) if not allow_divide_by_zero: @@ -725,6 +739,12 @@ def _[T: (DaskArray, np.ndarray)]( @singledispatch def axis_nnz(x: ArrayLike, /, axis: Literal[0, 1]) -> np.ndarray: + from .._compat import get_namespace, is_array_api + + if is_array_api(x): + xp = get_namespace(x) + return xp.count_nonzero(x, axis=axis) + return np.count_nonzero(x, axis=axis) @@ -758,6 +778,15 @@ def _(x: DaskArray, /, axis: Literal[0, 1]) -> DaskArray: @singledispatch def check_nonnegative_integers(x: _SupportedArray, /) -> bool | DaskArray: """Check values of X to ensure it is count data.""" + from .._compat import get_namespace, is_array_api + + if is_array_api(x): + xp = get_namespace(x) + if bool(xp.any(x < 0)): + return False + if xp.isdtype(x.dtype, "integral"): + return True + return not bool(xp.any((x % 1) != 0)) ### double check raise NotImplementedError diff --git a/src/scanpy/metrics/_common.py b/src/scanpy/metrics/_common.py index c9c90e1dc8..87c2cbd45d 100644 --- a/src/scanpy/metrics/_common.py +++ b/src/scanpy/metrics/_common.py @@ -12,8 +12,6 @@ from .._utils import NeighborsView if TYPE_CHECKING: - from typing import NoReturn - from anndata import AnnData from numpy.typing import NDArray @@ -93,7 +91,12 @@ def _resolve_vals(val: pd.DataFrame | pd.Series) -> NDArray: ... @singledispatch -def _resolve_vals(val: object) -> NoReturn: +def _resolve_vals(val: object): ### double check + from .._compat import is_array_api + + if is_array_api(val): + # Moran's I / Geary's C use numba kernels, so need to convert at boundary + return np.asarray(val) msg = f"Unsupported type {type(val)}" raise TypeError(msg) diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 7bc2470df3..2a4aafbf41 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -641,7 +641,12 @@ def compute_neighbors( self._rp_forest = None self.n_neighbors = n_neighbors self.knn = knn + from .._compat import is_array_api + x = _choose_representation(self._adata, use_rep=use_rep, n_pcs=n_pcs) + if is_array_api(x): + # sklearn transformers require numpy, so need to convert at boundary + x = np.asarray(x) self._distances = transformer.fit_transform(x) knn_indices, knn_distances = _get_indices_distances_from_sparse_matrix( self._distances, n_neighbors diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index d5f3d2cc79..788f4c3e5f 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -363,13 +363,21 @@ def _highly_variable_genes_single_batch( if n_removed: x = x[:, filt].copy() + from .._compat import get_namespace, is_array_api ### double check + if flavor == "seurat": x = x.copy() if (base := adata.uns.get("log1p", {}).get("base")) is not None: - x *= np.log(base) + if is_array_api(x): + x = x * float(np.log(base)) + else: + x *= np.log(base) # use out if possible. only possible since we copy the data matrix if isinstance(x, np.ndarray): np.expm1(x, out=x) + elif is_array_api(x): + xp = get_namespace(x) + x = xp.expm1(x) else: x = np.expm1(x) diff --git a/src/scanpy/preprocessing/_normalization.py b/src/scanpy/preprocessing/_normalization.py index 79eb5cf0d1..4836e231de 100644 --- a/src/scanpy/preprocessing/_normalization.py +++ b/src/scanpy/preprocessing/_normalization.py @@ -19,8 +19,15 @@ def _compute_nnz_median(counts: np.ndarray | DaskArray) -> np.floating: """Given a 1D array of counts, compute the median of the non-zero counts.""" + from .._compat import is_array_api + if isinstance(counts, DaskArray): counts = counts.compute() + + if is_array_api(counts): + # there is no xp.median? ### double check + counts = np.asarray(counts) + counts_greater_than_zero = counts[counts > 0] median = np.median(counts_greater_than_zero) return median diff --git a/src/scanpy/preprocessing/_scale.py b/src/scanpy/preprocessing/_scale.py index 1a1a047544..6bb09d7c3b 100644 --- a/src/scanpy/preprocessing/_scale.py +++ b/src/scanpy/preprocessing/_scale.py @@ -32,6 +32,15 @@ def clip[A: _Array]( x: ArrayLike | A, *, max_value: float, zero_center: bool = True ) -> A: + # clip_array cannot trace JAX arrays = example + from .._compat import get_namespace, is_array_api + + if is_array_api(x): + xp = get_namespace(x) + if zero_center: + return xp.clip(x, -max_value, max_value) ### double check + return xp.clip(x, None, max_value) + return clip_array(x, max_value=max_value, zero_center=zero_center) @@ -169,8 +178,15 @@ def scale_array[A: _Array]( logg.info( # Be careful of what? This should be more specific "... be careful when using `max_value` without `zero_center`." ) + from .._compat import get_namespace, is_array_api - if np.issubdtype(x.dtype, np.integer): + if is_array_api(x): + xp = get_namespace(x) + if xp.isdtype(x.dtype, "integral"): ### double check if integral is needed + logg.info("...") + x = xp.astype(x, xp.float64) + + elif np.issubdtype(x.dtype, np.integer): logg.info( "... as scaling leads to float results, integer " "input is cast to float, returning copy." @@ -192,18 +208,28 @@ def scale_array[A: _Array]( max_value=max_value, return_mean_std=return_mean_std, ) + from .._compat import get_namespace, is_array_api mean, var = mean_var(x, axis=0, correction=1) - std = np.sqrt(var) - std[std == 0] = 1 - if zero_center: - if isinstance(x, CSBase) or ( - isinstance(x, DaskArray) and isinstance(x._meta, CSBase) - ): - msg = "zero-centering a sparse array/matrix densifies it." - warn(msg, UserWarning) - x -= mean - x = dematrix(x) + + if is_array_api(x): + xp = get_namespace(x) + std = xp.sqrt(var) + std = xp.where(std == 0, xp.ones_like(std), std) + + if zero_center: + x = x - mean ### double check this formula + else: + std = np.sqrt(var) + std[std == 0] = 1 + if zero_center: + if isinstance(x, CSBase) or ( + isinstance(x, DaskArray) and isinstance(x._meta, CSBase) + ): + msg = "zero-centering a sparse array/matrix densifies it." + warn(msg, UserWarning) + x -= mean + x = dematrix(x) x = axis_mul_or_truediv( x, diff --git a/src/scanpy/preprocessing/_simple.py b/src/scanpy/preprocessing/_simple.py index fb325dec35..ee42f54267 100644 --- a/src/scanpy/preprocessing/_simple.py +++ b/src/scanpy/preprocessing/_simple.py @@ -349,9 +349,17 @@ def log1p( Returns or updates `data`, depending on `copy`. """ + from .._compat import get_namespace, is_array_api + check_array_function_arguments( chunked=chunked, chunk_size=chunk_size, layer=layer, obsm=obsm ) + if is_array_api(data): + xp = get_namespace(data) + result = xp.log1p(data) + if base is not None: + result = result / float(np.log(base)) + return result return log1p_array(data, copy=copy, base=base) diff --git a/src/scanpy/tools/_rank_genes_groups.py b/src/scanpy/tools/_rank_genes_groups.py index eb32fb4bdb..c91564cb3b 100644 --- a/src/scanpy/tools/_rank_genes_groups.py +++ b/src/scanpy/tools/_rank_genes_groups.py @@ -645,6 +645,11 @@ def rank_genes_groups( # noqa: PLR0912, PLR0913, PLR0915 """ from scanpy import settings + from .._compat import is_array_api + + # rank_genes_groups uses numba kernels internally, so need convert at entry. + if is_array_api(adata.X): ### double check + adata.X = np.asarray(adata.X) if isinstance(mask_var, Default): mask_var = settings.preset.rank_genes_groups.mask_var if isinstance(mean_in_log_space, Default): From 790ddde6b45bee06b3e8cded1ef2398517ba0bb8 Mon Sep 17 00:00:00 2001 From: amalia-k510 Date: Thu, 25 Jun 2026 10:17:50 +0200 Subject: [PATCH 2/2] typo + forgot numpy exclusion --- src/scanpy/_compat.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/scanpy/_compat.py b/src/scanpy/_compat.py index 575877b9f3..c1b2bcece4 100644 --- a/src/scanpy/_compat.py +++ b/src/scanpy/_compat.py @@ -71,13 +71,16 @@ def pkg_metadata(package: str) -> PackageMetadata: def is_array_api(x: object) -> bool: # returns true if x is array api compatible # exclusing the ones that are already handled by the script + import numpy as np from array_api_compat import is_array_api_obj # excluding packages that are handled by both array-api-compat and script - if isinstance(x, DaskArray): + if isinstance(x, np.ndarray): return False if isinstance(x, DaskArray): return False + if isinstance(x, SpBase): + return False return is_array_api_obj(x)