From a625c5594e13ff368fb857f40cdd14920096c0d9 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 26 Mar 2026 14:44:35 +0100 Subject: [PATCH 01/48] perf: "two-pass" seurat hvg3 via `scanpy.get.aggregate` --- .../preprocessing/_highly_variable_genes.py | 76 +++++++++++++------ 1 file changed, 52 insertions(+), 24 deletions(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 10dbec1117..7224e87ebd 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -20,7 +20,7 @@ raise_if_dask_feature_axis_chunked, sanitize_anndata, ) -from ..get import _get_obs_rep +from ..get import _get_obs_rep, aggregate from ._distributed import materialize_as_ndarray from ._simple import filter_genes @@ -36,7 +36,7 @@ @singledispatch def clip_square_sum( data_batch: np.ndarray, clip_val: np.ndarray -) -> tuple[np.ndarray, np.ndarray]: +) -> tuple[np.ndarray, np.ndarray] | tuple[DaskArray, DaskArray]: """Clip data_batch by clip_val. Parameters @@ -64,24 +64,19 @@ def clip_square_sum( @clip_square_sum.register(DaskArray) -def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]: +def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[DaskArray, DaskArray]: n_blocks = data_batch.blocks.size def sum_and_sum_squares_clipped_from_block(block): return np.vstack(clip_square_sum(block, clip_val))[None, ...] 
- squared_batch_counts_sum, batch_counts_sum = ( - data_batch - .map_blocks( - sum_and_sum_squares_clipped_from_block, - new_axis=(1,), - chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)), - meta=np.array([]), - dtype=np.float64, - ) - .sum(axis=0) - .compute() - ) + squared_batch_counts_sum, batch_counts_sum = data_batch.map_blocks( + sum_and_sum_squares_clipped_from_block, + new_axis=(1,), + chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)), + meta=np.array([]), + dtype=np.float64, + ).sum(axis=0) return squared_batch_counts_sum, batch_counts_sum @@ -172,17 +167,43 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 batch_info = ( pd.Categorical(np.zeros(adata.shape[0], dtype=int)) if batch_key is None - else adata.obs[batch_key].to_numpy() + else adata.obs[batch_key] ) - norm_gene_vars = [] - for b in np.unique(batch_info): - data_batch = data[batch_info == b] - mean, var = stats.mean_var(data_batch, axis=0, correction=1) - # These get computed anyway for loess - if isinstance(mean, DaskArray): - mean, var = mean.compute(), var.compute() + if can_aggregate := (inplace or not adata.is_view): + adata.obs["__hvg_v3_batch_info__"] = batch_info + aggregated_mean_var = aggregate( + adata, by="__hvg_v3_batch_info__", func=["mean", "var"] + ) + mean_global, var_global = ( + aggregated_mean_var.layers[l] for l in ["mean", "var"] + ) + if isinstance(mean_global, DaskArray): + mean_global, var_global = mean_global.compute(), var_global.compute() + aggregated_mean_var.layers["mean"] = mean_global + aggregated_mean_var.layers["var"] = var_global + batch_info = batch_info.to_numpy() + for b in batch_info: + data_batch = data[batch_info == b] + if can_aggregate: + mean, var = ( + aggregated_mean_var[ + aggregated_mean_var.obs["__hvg_v3_batch_info__"] == b + ].layers[l] + for l in ["mean", "var"] + ) + if isinstance(mean, CSBase): + mean = mean.toarray() + mean = mean.ravel() + if isinstance(var, CSBase): + var = var.toarray() + var = var.ravel() + else: 
+ mean, var = stats.mean_var(data_batch, axis=0, correction=1) + # These get computed anyway for loess + if isinstance(mean, DaskArray): + mean, var = mean.compute(), var.compute() estimat_var = np.zeros(data.shape[1], dtype=np.float64) if (not_const := var > 0).any(): y = np.log10(var[not_const]) @@ -204,8 +225,15 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 + squared_batch_counts_sum - 2 * batch_counts_sum * mean ) - norm_gene_vars.append(norm_gene_var.reshape(1, -1)) + norm_gene_vars.append(norm_gene_var) + if can_aggregate: + del adata.obs["__hvg_v3_batch_info__"] + if any(isinstance(e, DaskArray) for e in norm_gene_vars): + import dask.array as da + + norm_gene_vars = da.compute(*norm_gene_vars) + norm_gene_vars = [ngv.reshape(1, -1) for ngv in norm_gene_vars] norm_gene_vars = np.concatenate(norm_gene_vars, axis=0) # argsort twice gives ranks, small rank means most variable ranked_norm_gene_vars = np.argsort(np.argsort(-norm_gene_vars, axis=1), axis=1) From d839e98dd66733f63c647cea747c5fdbf5ac4e0a Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 26 Mar 2026 14:48:51 +0100 Subject: [PATCH 02/48] chore: hvg v3 benchmark --- benchmarks/asv.conf.json | 2 +- benchmarks/benchmarks/preprocessing_log.py | 37 +++++++++++++++------- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..2921d82f40 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -83,7 +83,7 @@ // "psutil": [""] "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 - // "scikit-misc": [""], + "scikit-misc": [""], }, // Combinations of libraries/python versions can be excluded/included diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 9633c8e208..84a356548c 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -47,17 +47,6 @@ 
def time_pca(self, *_) -> None: def peakmem_pca(self, *_) -> None: sc.pp.pca(self.adata, svd_solver="arpack") - def time_highly_variable_genes(self, *_) -> None: - # the default flavor runs on log-transformed data - sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5 - ) - - def peakmem_highly_variable_genes(self, *_) -> None: - sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5 - ) - # regress_out is very slow for this dataset @skip_when(dataset={"pbmc3k"}) def time_regress_out(self, *_) -> None: @@ -72,3 +61,29 @@ def time_scale(self, *_) -> None: def peakmem_scale(self, *_) -> None: sc.pp.scale(self.adata, max_value=10) + + +class HVGSuite: # noqa: D101 + params = (["seurat_v3", "cell_ranger", "seurat"],) + param_names = ("flavor",) + + def setup_cache(self) -> None: + """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" + adata, _ = get_dataset("pbmc3k") + adata.write_h5ad("pbmc3k.h5ad") + + def setup(self, flavor) -> None: + self.adata = ad.read_h5ad("pbmc3k.h5ad") + sc.pp.filter_genes(self.adata, min_cells=3) + self.flavor = flavor + + def time_highly_variable_genes(self, *_) -> None: + # the default flavor runs on log-transformed data + sc.pp.highly_variable_genes( + self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + ) + + def peakmem_highly_variable_genes(self, *_) -> None: + sc.pp.highly_variable_genes( + self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + ) From 86db4990934de4e1b1751512e05c6faca159b376 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 26 Mar 2026 14:56:52 +0100 Subject: [PATCH 03/48] fix: use counts --- benchmarks/benchmarks/preprocessing_log.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 84a356548c..cea301dc1d 100644 --- 
a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -80,10 +80,20 @@ def setup(self, flavor) -> None: def time_highly_variable_genes(self, *_) -> None: # the default flavor runs on log-transformed data sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + self.adata, + min_mean=0.0125, + max_mean=3, + min_disp=0.5, + flavor=self.flavor, + **({"layer": "counts"} if self.self.flavor == "seurat_v3" else {}), ) def peakmem_highly_variable_genes(self, *_) -> None: sc.pp.highly_variable_genes( - self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5, flavor=self.flavor + self.adata, + min_mean=0.0125, + max_mean=3, + min_disp=0.5, + flavor=self.flavor, + **({"layer": "counts"} if self.self.flavor == "seurat_v3" else {}), ) From d5a6a7833f738ab1500f31a0d601eba51991487f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 26 Mar 2026 15:07:56 +0100 Subject: [PATCH 04/48] fix: use a batch key --- benchmarks/benchmarks/preprocessing_log.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index cea301dc1d..3ffc5560c1 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -69,11 +69,11 @@ class HVGSuite: # noqa: D101 def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" - adata, _ = get_dataset("pbmc3k") - adata.write_h5ad("pbmc3k.h5ad") + adata, _ = get_dataset("lung93k") + adata.write_h5ad("lung93k.h5ad") def setup(self, flavor) -> None: - self.adata = ad.read_h5ad("pbmc3k.h5ad") + self.adata = ad.read_h5ad("lung93k.h5ad") sc.pp.filter_genes(self.adata, min_cells=3) self.flavor = flavor @@ -85,7 +85,8 @@ def time_highly_variable_genes(self, *_) -> None: max_mean=3, min_disp=0.5, flavor=self.flavor, - **({"layer": "counts"} if 
self.self.flavor == "seurat_v3" else {}), + batch_key="PatientNumber", + **({"layer": "counts"} if self.flavor == "seurat_v3" else {}), ) def peakmem_highly_variable_genes(self, *_) -> None: @@ -95,5 +96,6 @@ def peakmem_highly_variable_genes(self, *_) -> None: max_mean=3, min_disp=0.5, flavor=self.flavor, + batch_key="PatientNumber", **({"layer": "counts"} if self.self.flavor == "seurat_v3" else {}), ) From fdc5653bb78715b8ba085a43c51ef102e9e66a4f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 26 Mar 2026 15:12:32 +0100 Subject: [PATCH 05/48] fix: not again --- benchmarks/benchmarks/preprocessing_log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 3ffc5560c1..6f597e7a66 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -97,5 +97,5 @@ def peakmem_highly_variable_genes(self, *_) -> None: min_disp=0.5, flavor=self.flavor, batch_key="PatientNumber", - **({"layer": "counts"} if self.self.flavor == "seurat_v3" else {}), + **({"layer": "counts"} if self.flavor == "seurat_v3" else {}), ) From 8f0e426cc95c6168c3c4704c36c6012580ee8dc5 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 8 Apr 2026 15:54:29 +0200 Subject: [PATCH 06/48] fix: `compute` single pass! 
--- src/scanpy/preprocessing/_highly_variable_genes.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 7224e87ebd..8a8ef4cee4 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -180,7 +180,9 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 aggregated_mean_var.layers[l] for l in ["mean", "var"] ) if isinstance(mean_global, DaskArray): - mean_global, var_global = mean_global.compute(), var_global.compute() + import dask.array as da + + mean_global, var_global = da.compute(mean_global, var_global) aggregated_mean_var.layers["mean"] = mean_global aggregated_mean_var.layers["var"] = var_global batch_info = batch_info.to_numpy() From 7e0390ee10fe2c38d382b97ad2ff0bf8e6280e6b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 9 Apr 2026 09:31:03 +0200 Subject: [PATCH 07/48] fix: unique --- src/scanpy/preprocessing/_highly_variable_genes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 10d59e1af9..9d06ef11dc 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -186,7 +186,7 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 aggregated_mean_var.layers["mean"] = mean_global aggregated_mean_var.layers["var"] = var_global batch_info = batch_info.to_numpy() - for b in batch_info: + for b in np.unique(batch_info): data_batch = data[batch_info == b] if can_aggregate: mean, var = ( From 96c16e91841bb80cda49564c5927a77a8351293c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 12:04:57 +0200 Subject: [PATCH 08/48] chore: add new `dask` benchmark --- benchmarks/benchmarks/preprocessing_log.py | 22 +++++++++++++++++----- 1 file changed, 17 
insertions(+), 5 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 6f597e7a66..3f10728cee 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -9,12 +9,15 @@ from typing import TYPE_CHECKING import anndata as ad +import numpy as np import scanpy as sc from ._utils import get_dataset, param_skipper if TYPE_CHECKING: + from typing import Literal + from ._utils import Dataset, KeyX @@ -64,16 +67,25 @@ def peakmem_scale(self, *_) -> None: class HVGSuite: # noqa: D101 - params = (["seurat_v3", "cell_ranger", "seurat"],) - param_names = ("flavor",) + params = (["seurat_v3", "cell_ranger", "seurat"], [True, False]) + param_names = ("flavor", "use_dask") def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" adata, _ = get_dataset("lung93k") adata.write_h5ad("lung93k.h5ad") - - def setup(self, flavor) -> None: - self.adata = ad.read_h5ad("lung93k.h5ad") + obs = np.arange(adata.shape[0]) + np.random.default_rng().shuffle(obs) + adata[obs].write_h5ad("lung93k_shuffled.h5ad") + + def setup( + self, + flavor: Literal["seurat_v3", "cell_ranger", "seurat"], + use_dask: bool, # noqa: FBT001 + ) -> None: + self.adata = ad.read_h5ad( + "lung93k_shuffled.h5ad" if use_dask else "lung93k.h5ad" + ) sc.pp.filter_genes(self.adata, min_cells=3) self.flavor = flavor From 478af4af39b46fb2e94ee7e9a0fbf10900681dc8 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 13:08:31 +0200 Subject: [PATCH 09/48] fix: actually use dask lol --- benchmarks/benchmarks/preprocessing_log.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 3f10728cee..e36aecfbdc 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ 
-83,9 +83,14 @@ def setup( flavor: Literal["seurat_v3", "cell_ranger", "seurat"], use_dask: bool, # noqa: FBT001 ) -> None: - self.adata = ad.read_h5ad( - "lung93k_shuffled.h5ad" if use_dask else "lung93k.h5ad" - ) + if use_dask: + self.adata = ad.experimental.read_lazy("lung93k_shuffled.h5ad") + self.adata.obs = self.adata.obs.to_memory() + self.adata.var = self.adata.var.to_memory() + else: + self.adata = ad.read_h5ad( + "lung93k_shuffled.h5ad" if use_dask else "lung93k.h5ad" + ) sc.pp.filter_genes(self.adata, min_cells=3) self.flavor = flavor From 54db31b20ad894c7d7cab8d4585b3b3d9cac6bfa Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 13:46:24 +0200 Subject: [PATCH 10/48] chore: really do dask --- benchmarks/asv.conf.json | 1 + benchmarks/benchmarks/preprocessing_log.py | 21 ++++++++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index 2921d82f40..655b33d9a8 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -84,6 +84,7 @@ "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 "scikit-misc": [""], + "dask": [""], }, // Combinations of libraries/python versions can be excluded/included diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index e36aecfbdc..4a13febb12 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -10,6 +10,7 @@ import anndata as ad import numpy as np +import zarr import scanpy as sc @@ -73,10 +74,10 @@ class HVGSuite: # noqa: D101 def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" adata, _ = get_dataset("lung93k") - adata.write_h5ad("lung93k.h5ad") + adata.write_zarr("lung93k.zarr") obs = np.arange(adata.shape[0]) np.random.default_rng().shuffle(obs) - 
adata[obs].write_h5ad("lung93k_shuffled.h5ad") + adata[obs].write_zarr("lung93k_shuffled.zarr") def setup( self, @@ -84,12 +85,18 @@ def setup( use_dask: bool, # noqa: FBT001 ) -> None: if use_dask: - self.adata = ad.experimental.read_lazy("lung93k_shuffled.h5ad") - self.adata.obs = self.adata.obs.to_memory() - self.adata.var = self.adata.var.to_memory() + z = zarr.open("lung93k_shuffled.zarr") + self.adata = ad.AnnData( + obs=ad.io.read_elem(z["obs"]), + var=ad.io.read_elem(z["var"]), + layers={ + "counts": ad.experimental.read_elem_lazy(z["layers"]["counts"]) + }, + X=ad.experimental.read_elem_lazy(z["X"]), + ) else: - self.adata = ad.read_h5ad( - "lung93k_shuffled.h5ad" if use_dask else "lung93k.h5ad" + self.adata = ad.read_zarr( + "lung93k_shuffled.zarr" if use_dask else "lung93k.zarr" ) sc.pp.filter_genes(self.adata, min_cells=3) self.flavor = flavor From 4fe84c524380b15670bdcab8c7501d113beba9cc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 14:13:20 +0200 Subject: [PATCH 11/48] fix: layers support --- src/scanpy/preprocessing/_highly_variable_genes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index d79e1c2f43..375e2cdacf 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -174,7 +174,7 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 if can_aggregate := (inplace or not adata.is_view): adata.obs["__hvg_v3_batch_info__"] = batch_info aggregated_mean_var = aggregate( - adata, by="__hvg_v3_batch_info__", func=["mean", "var"] + adata, by="__hvg_v3_batch_info__", func=["mean", "var"], layer=layer ) mean_global, var_global = ( aggregated_mean_var.layers[l] for l in ["mean", "var"] From 35590a4a90e7dfab222c937b211238975b1519cc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 15:17:13 +0200 Subject: [PATCH 12/48] fix: no view 
check needed --- .../preprocessing/_highly_variable_genes.py | 61 +++++++++---------- 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index 375e2cdacf..fc156e5ed3 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -171,41 +171,38 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 ) norm_gene_vars = [] - if can_aggregate := (inplace or not adata.is_view): - adata.obs["__hvg_v3_batch_info__"] = batch_info - aggregated_mean_var = aggregate( - adata, by="__hvg_v3_batch_info__", func=["mean", "var"], layer=layer - ) - mean_global, var_global = ( - aggregated_mean_var.layers[l] for l in ["mean", "var"] - ) - if isinstance(mean_global, DaskArray): - import dask.array as da + adata_agg = AnnData( + X=data, + var=pd.DataFrame(index=adata.var_names), + obs=pd.DataFrame( + index=adata.obs_names, data={"__hvg_v3_batch_info__": batch_info} + ), + ) + aggregated_mean_var = aggregate( + adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"], layer=layer + ) + mean_global, var_global = (aggregated_mean_var.layers[l] for l in ["mean", "var"]) + if isinstance(mean_global, DaskArray): + import dask.array as da - mean_global, var_global = da.compute(mean_global, var_global) - aggregated_mean_var.layers["mean"] = mean_global - aggregated_mean_var.layers["var"] = var_global + mean_global, var_global = da.compute(mean_global, var_global) + aggregated_mean_var.layers["mean"] = mean_global + aggregated_mean_var.layers["var"] = var_global batch_info = batch_info.to_numpy() for b in np.unique(batch_info): data_batch = data[batch_info == b] - if can_aggregate: - mean, var = ( - aggregated_mean_var[ - aggregated_mean_var.obs["__hvg_v3_batch_info__"] == b - ].layers[l] - for l in ["mean", "var"] - ) - if isinstance(mean, CSBase): - mean = mean.toarray() - mean = mean.ravel() - if 
isinstance(var, CSBase): - var = var.toarray() - var = var.ravel() - else: - mean, var = stats.mean_var(data_batch, axis=0, correction=1) - # These get computed anyway for loess - if isinstance(mean, DaskArray): - mean, var = mean.compute(), var.compute() + mean, var = ( + aggregated_mean_var[ + aggregated_mean_var.obs["__hvg_v3_batch_info__"] == b + ].layers[l] + for l in ["mean", "var"] + ) + if isinstance(mean, CSBase): + mean = mean.toarray() + mean = mean.ravel() + if isinstance(var, CSBase): + var = var.toarray() + var = var.ravel() estimat_var = np.zeros(data.shape[1], dtype=np.float64) if (not_const := var > 0).any(): y = np.log10(var[not_const]) @@ -228,8 +225,6 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 - 2 * batch_counts_sum * mean ) norm_gene_vars.append(norm_gene_var) - if can_aggregate: - del adata.obs["__hvg_v3_batch_info__"] if any(isinstance(e, DaskArray) for e in norm_gene_vars): import dask.array as da From db81d6eb60e7172743728aa613ea6fb1a75687cc Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 4 May 2026 16:23:22 +0200 Subject: [PATCH 13/48] fix: no layers eeded --- benchmarks/benchmarks/preprocessing_log.py | 4 +--- src/scanpy/preprocessing/_highly_variable_genes.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 4a13febb12..6d6f6eef55 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -95,9 +95,7 @@ def setup( X=ad.experimental.read_elem_lazy(z["X"]), ) else: - self.adata = ad.read_zarr( - "lung93k_shuffled.zarr" if use_dask else "lung93k.zarr" - ) + self.adata = ad.read_zarr("lung93k.zarr") sc.pp.filter_genes(self.adata, min_cells=3) self.flavor = flavor diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index fc156e5ed3..4f7c38ccf2 100644 --- 
a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -179,7 +179,7 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 ), ) aggregated_mean_var = aggregate( - adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"], layer=layer + adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"] ) mean_global, var_global = (aggregated_mean_var.layers[l] for l in ["mean", "var"]) if isinstance(mean_global, DaskArray): From b37444e388367ea482f0f30b3d2a24c84f71d3e7 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 5 May 2026 12:06:32 +0200 Subject: [PATCH 14/48] fix: reduce number of batches --- benchmarks/benchmarks/preprocessing_log.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 6d6f6eef55..5ea2fa58fd 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -94,6 +94,10 @@ def setup( }, X=ad.experimental.read_elem_lazy(z["X"]), ) + # Times out on the benchmark machine with full dataset + self.adata = self.adata[ + self.adata.obs["PatientNumber"].isin(["1", "2"]) + ].copy() else: self.adata = ad.read_zarr("lung93k.zarr") sc.pp.filter_genes(self.adata, min_cells=3) From cf65665a6591de5158c94013474ac051358f065b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 5 May 2026 14:53:15 +0200 Subject: [PATCH 15/48] fix: a little bit more --- benchmarks/benchmarks/preprocessing_log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 5ea2fa58fd..70e658ffbf 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -96,7 +96,7 @@ def setup( ) # Times out on the benchmark machine with full dataset self.adata = self.adata[ - self.adata.obs["PatientNumber"].isin(["1", "2"]) + 
self.adata.obs["PatientNumber"].isin(["1", "2", "3"]) ].copy() else: self.adata = ad.read_zarr("lung93k.zarr") From e62493948004da5214c5b35ab8fa995c52d838bb Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 14:02:32 +0200 Subject: [PATCH 16/48] perf: chan's parallel mean-var algorithm for dask --- benchmarks/asv.conf.json | 1 + benchmarks/benchmarks/preprocessing_counts.py | 34 ++++- src/scanpy/get/_aggregated.py | 134 ++++++++++++++++-- 3 files changed, 151 insertions(+), 18 deletions(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..a1a8d31a42 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -84,6 +84,7 @@ "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 // "scikit-misc": [""], + "dask": [""], }, // Combinations of libraries/python versions can be excluded/included diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 9a20e7eda3..672a7df5fc 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING import anndata as ad +import zarr import scanpy as sc from scanpy._utils import get_literal_vals @@ -18,6 +19,7 @@ from ._utils import get_count_dataset, get_dataset if TYPE_CHECKING: + from collections.abc import KeysView from typing import Any from ._utils import Dataset, KeyCount @@ -151,17 +153,35 @@ def peakmem_log1p(self, *_) -> None: class Agg: # noqa: D101 - params: tuple[AggType] = tuple(get_literal_vals(AggType)) - param_names = ("agg_name",) + params: tuple[KeysView[AggType], tuple[bool]] = ( + get_literal_vals(AggType), + (True, False), + ) + param_names = ("agg_name", "use_dask") def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" adata, _ = get_dataset("lung93k") - 
adata.write_h5ad("lung93k.h5ad") - - def setup(self, agg_name: AggType) -> None: - self.adata = ad.read_h5ad("lung93k.h5ad") - self.agg_name = agg_name + adata.write_zarr("lung93k.zarr") + + def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001 + if use_dask: + z = zarr.open("lung93k_shuffled.zarr") + self.adata = ad.AnnData( + obs=ad.io.read_elem(z["obs"]), + var=ad.io.read_elem(z["var"]), + layers={ + "counts": ad.experimental.read_elem_lazy(z["layers"]["counts"]) + }, + X=ad.experimental.read_elem_lazy(z["X"]), + ) + # Times out on the benchmark machine with full dataset + self.adata = self.adata[ + self.adata.obs["PatientNumber"].isin(["1", "2", "3"]) + ].copy() + else: + self.adata = ad.read_zarr("lung93k.zarr") + self.agg_name: AggType = agg_name def time_agg(self, *_) -> None: sc.get.aggregate( diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index fe5140171b..66f251c316 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -6,7 +6,6 @@ import numpy as np import pandas as pd from anndata import AnnData -from fast_array_utils.stats._power import power as fau_power # TODO: upstream from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 @@ -371,16 +370,129 @@ def aggregate_dask_mean_var( mask: NDArray[np.bool] | None = None, dof: int = 1, ) -> MeanVarDict: - mean = aggregate_dask(data, by, "mean", mask=mask, dof=dof)["mean"] - sq_mean = aggregate_dask(fau_power(data, 2), by, "mean", mask=mask, dof=dof)["mean"] - # TODO: If we don't compute here, the results are not deterministic under the process cluster for sparse. - if isinstance(data._meta, CSRBase): - sq_mean = sq_mean.compute() - var = sq_mean - fau_power(mean, 2) - if dof != 0: - group_counts = np.bincount(by.codes) - var *= (group_counts / (group_counts - dof))[:, np.newaxis] - return MeanVarDict(mean=mean, var=var) + """Compute group-wise mean and variance for a dask array. 
+ + Per chunk we compute ``(count, mean, M2)`` (where ``M2 = sum((x - mean)**2)``), + then combine across chunks with the pairwise parallel algorithm from + Chan et al. (https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm) + so the across-chunk reduction avoids the catastrophic cancellation of + ``E[X**2] - E[X]**2``. + """ + import dask.array as da + + n_categories = len(by.categories) + n_features = data.shape[1] + chunked_axis = 0 if isinstance(data._meta, CSRBase | np.ndarray) else 1 + + if chunked_axis == 1: + # Each block already sees every observation, so mean/var per chunk is final. + def per_block_col(chunk: Array) -> NDArray[np.float64]: + mean_, var_ = Aggregate(groupby=by, data=chunk, mask=mask).mean_var(dof=dof) + return np.concatenate([mean_, var_], axis=0) + + combined = data.map_blocks( + per_block_col, + chunks=((2 * n_categories,), data.chunks[1]), + meta=np.array([], dtype=np.float64), + ) + return MeanVarDict(mean=combined[:n_categories], var=combined[n_categories:]) + + n_blocks = data.numblocks[0] + + def per_block_row( + chunk: Array, block_info: dict | None = None + ) -> NDArray[np.float64]: + row_subset = slice(*block_info[0]["array-location"][0]) + by_sub = by[row_subset] + mask_sub = mask[row_subset] if mask is not None else None + return _block_moments(chunk, by_sub, mask=mask_sub, n_categories=n_categories)[ + None + ] + + per_block_stats = data.map_blocks( + per_block_row, + chunks=((1,) * n_blocks, (3,), (n_categories,), (n_features,)), + new_axis=(1, 2), + meta=np.array([], dtype=np.float64), + ) + + combined = da.reduction( + per_block_stats, + chunk=lambda x, axis=None, keepdims=False: x, + aggregate=_chan_reduce_axis_0, + axis=0, + keepdims=False, + concatenate=True, + dtype=np.float64, + meta=np.array([], dtype=np.float64), + ) + counts = combined[0] + mean_ = combined[1] + m2 = combined[2] + denom = counts - dof if dof > 0 else counts + return MeanVarDict(mean=mean_, var=m2 / denom) + + +def 
_block_moments( + data: np.ndarray | CSBase, + by: pd.Categorical, + *, + mask: NDArray[np.bool] | None, + n_categories: int, +) -> NDArray[np.float64]: + """Per-chunk ``(count, mean, M2)`` array of shape ``(3, n_categories, n_features)``. + + Groups with no observations in the chunk get zeros for mean and M2 so + they combine cleanly under ``_chan_combine``. + """ + codes = by.codes + valid = codes >= 0 + if mask is not None: + valid = valid & mask + counts = np.bincount(codes[valid], minlength=n_categories).astype(np.float64) + + out = np.zeros((3, n_categories, data.shape[1]), dtype=np.float64) + out[0] = counts[:, None] + nonempty = counts > 0 + if not nonempty.any(): + return out + + agg = Aggregate(groupby=by, data=data, mask=mask) + sum_ = agg.sum() + sum_sq = agg._sum(_power(data, 2)) + safe_counts = np.where(nonempty, counts, 1)[:, None] + mean_ = sum_ / safe_counts + # M2 = sum((x - mean)**2) = sum_sq - count * mean**2; clip cancellation noise to 0. + m2 = np.maximum(sum_sq - sum_ * mean_, 0) + out[1, nonempty] = mean_[nonempty] + out[2, nonempty] = m2[nonempty] + return out + + +def _chan_combine( + a: NDArray[np.float64], b: NDArray[np.float64] +) -> NDArray[np.float64]: + """Combine two ``(3, K, F)`` ``(count, mean, M2)`` stat blocks pairwise.""" + n_a, mean_a, m2_a = a[0], a[1], a[2] + n_b, mean_b, m2_b = b[0], b[1], b[2] + n = n_a + n_b + safe_n = np.where(n > 0, n, 1) + delta = mean_b - mean_a + new_mean = mean_a + delta * n_b / safe_n + new_m2 = m2_a + m2_b + delta * delta * n_a * n_b / safe_n + return np.stack([n, new_mean, new_m2]) + + +def _chan_reduce_axis_0( + stats: NDArray[np.float64], + axis: int | None, + keepdims: bool, # noqa: FBT001 +) -> NDArray[np.float64]: + """Aggregate per-block stats along axis 0 with the parallel variance algorithm.""" + result = stats[0] + for i in range(1, stats.shape[0]): + result = _chan_combine(result, stats[i]) + return result[None] if keepdims else result @_aggregate.register(DaskArray) From 
61332fd103d7d904136e701b9f831eca6bbe041f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 15:08:26 +0200 Subject: [PATCH 17/48] fix: params --- benchmarks/benchmarks/preprocessing_counts.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 672a7df5fc..db6c23f0f0 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -153,9 +153,9 @@ def peakmem_log1p(self, *_) -> None: class Agg: # noqa: D101 - params: tuple[KeysView[AggType], tuple[bool]] = ( + params: tuple[KeysView[AggType], list[bool]] = ( get_literal_vals(AggType), - (True, False), + [True, False], ) param_names = ("agg_name", "use_dask") From 1df5fdaed65ce603511c7be00cacc35126f61d3b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 15:30:21 +0200 Subject: [PATCH 18/48] fix: iteration --- benchmarks/benchmarks/preprocessing_counts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index db6c23f0f0..1809c6d3bb 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -19,7 +19,6 @@ from ._utils import get_count_dataset, get_dataset if TYPE_CHECKING: - from collections.abc import KeysView from typing import Any from ._utils import Dataset, KeyCount @@ -153,8 +152,8 @@ def peakmem_log1p(self, *_) -> None: class Agg: # noqa: D101 - params: tuple[KeysView[AggType], list[bool]] = ( - get_literal_vals(AggType), + params: tuple[list[str], list[bool]] = ( + list(get_literal_vals(AggType)), [True, False], ) param_names = ("agg_name", "use_dask") From 9a705812bce5d556678a7717004380a89454dc54 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 15:50:18 +0200 Subject: [PATCH 19/48] fix: zarr link --- 
benchmarks/benchmarks/preprocessing_counts.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 1809c6d3bb..a17d96e9c9 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -165,7 +165,7 @@ def setup_cache(self) -> None: def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001 if use_dask: - z = zarr.open("lung93k_shuffled.zarr") + z = zarr.open("lung93k.zarr") self.adata = ad.AnnData( obs=ad.io.read_elem(z["obs"]), var=ad.io.read_elem(z["var"]), @@ -174,10 +174,6 @@ def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001 }, X=ad.experimental.read_elem_lazy(z["X"]), ) - # Times out on the benchmark machine with full dataset - self.adata = self.adata[ - self.adata.obs["PatientNumber"].isin(["1", "2", "3"]) - ].copy() else: self.adata = ad.read_zarr("lung93k.zarr") self.agg_name: AggType = agg_name From 5313ea24ab162fb994cf2a16b5dfe34151c0b411 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 16:30:59 +0200 Subject: [PATCH 20/48] fix: `median` calculation skipped --- benchmarks/benchmarks/preprocessing_counts.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index a17d96e9c9..bc3bedce63 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -165,6 +165,9 @@ def setup_cache(self) -> None: def setup(self, agg_name: AggType, use_dask: bool) -> None: # noqa: FBT001 if use_dask: + if agg_name == "median": + # Skip this one: https://asv.readthedocs.io/en/stable/writing_benchmarks.html#setup-and-teardown-functions + raise NotImplementedError() z = zarr.open("lung93k.zarr") self.adata = ad.AnnData( obs=ad.io.read_elem(z["obs"]), From e19a7d88d057a81f139e8de688ff71a4c5f90ade Mon Sep 17 
00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 2026 16:59:21 +0200 Subject: [PATCH 21/48] fix: no-batch-key accel --- .../preprocessing/_highly_variable_genes.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index f588aa019b..cad3ff470d 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -178,16 +178,29 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 index=adata.obs_names, data={"__hvg_v3_batch_info__": batch_info} ), ) - aggregated_mean_var = aggregate( - adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"] - ) - mean_global, var_global = (aggregated_mean_var.layers[l] for l in ["mean", "var"]) - if isinstance(mean_global, DaskArray): - import dask.array as da + if batch_key is not None: + aggregated_mean_var = aggregate( + adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"] + ) + mean_global, var_global = ( + aggregated_mean_var.layers[l] for l in ["mean", "var"] + ) + if isinstance(mean_global, DaskArray): + import dask.array as da - mean_global, var_global = da.compute(mean_global, var_global) - aggregated_mean_var.layers["mean"] = mean_global - aggregated_mean_var.layers["var"] = var_global + mean_global, var_global = da.compute(mean_global, var_global) + aggregated_mean_var.layers["mean"] = mean_global + aggregated_mean_var.layers["var"] = var_global + else: + aggregated_mean_var = AnnData( + obs=pd.DataFrame( + index=np.array(["one"]), data={"__hvg_v3_batch_info__": np.array([0])} + ), + layers={ + "mean": df["means"].to_numpy().reshape((1, -1)), + "var": df["variances"].to_numpy().reshape((1, -1)), + }, + ) batch_info = batch_info.to_numpy() for b in np.unique(batch_info): data_batch = data[batch_info == b] From 8482561ecbcdcee2b68667ab79acbe52df12e17b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 5 Jun 
2026 17:05:40 +0200 Subject: [PATCH 22/48] fix: don't run all benchmarks with dask --- benchmarks/benchmarks/preprocessing_log.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/benchmarks/benchmarks/preprocessing_log.py b/benchmarks/benchmarks/preprocessing_log.py index 70e658ffbf..350bb66883 100644 --- a/benchmarks/benchmarks/preprocessing_log.py +++ b/benchmarks/benchmarks/preprocessing_log.py @@ -85,6 +85,9 @@ def setup( use_dask: bool, # noqa: FBT001 ) -> None: if use_dask: + if flavor != "seurat_v3": + # This benchmark only really makes sense for seurat v3 as that has been optimized. + raise NotImplementedError() z = zarr.open("lung93k_shuffled.zarr") self.adata = ad.AnnData( obs=ad.io.read_elem(z["obs"]), From a2b390b88e83112cb981570afa08200b72687e48 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 8 Jun 2026 13:13:17 +0200 Subject: [PATCH 23/48] chore: relnote --- docs/release-notes/4143.perf.md | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 docs/release-notes/4143.perf.md diff --git a/docs/release-notes/4143.perf.md b/docs/release-notes/4143.perf.md new file mode 100644 index 0000000000..2ffd9c00d6 --- /dev/null +++ b/docs/release-notes/4143.perf.md @@ -0,0 +1,3 @@ +Use Chan's mean-var algorithm for acceleration of dask-backed {func}`scanpy.get.aggregate` {smaller}`I Gold` + +[Chan's mean-var]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm From 21f5ddc79f55f08c26c3b3d34de1508fee214d9b Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 8 Jun 2026 14:09:13 +0200 Subject: [PATCH 24/48] perf: welford's algorithm for mean-var --- src/scanpy/get/_aggregated.py | 27 +++---- src/scanpy/get/_kernels.py | 131 +++++++++++++++++++++++++--------- 2 files changed, 108 insertions(+), 50 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index fe5140171b..3fad260f7d 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -13,7 +13,13 @@ from 
scanpy._compat import CSBase, CSRBase, DaskArray from .._utils import _resolve_axis, get_literal_vals -from ._kernels import agg_sum_csc, agg_sum_csr, mean_var_csc, mean_var_csr +from ._kernels import ( + agg_sum_csc, + agg_sum_csr, + mean_var_csc, + mean_var_csr, + mean_var_dense, +) from .get import _check_mask if TYPE_CHECKING: @@ -117,11 +123,8 @@ def mean(self) -> Array: def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: """Compute the count, as well as mean and variance per feature, per group of observations. - The formula `Var(X) = E(X^2) - E(X)^2` suffers loss of precision when the variance is a - very small fraction of the squared mean. In particular, when X is constant, the formula may - nonetheless be non-zero. By default, our implementation resets the variance to exactly zero - when the computed variance, relative to the squared mean, nears limit of precision of the - floating-point significand. + Mean and variance are computed with Welford's online algorithm, which is + numerically stable for constant or near-constant inputs. Params ------ @@ -137,21 +140,11 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: group_counts = np.bincount(self.groupby.codes) if isinstance(self.data, np.ndarray): - mean_ = self.mean() - # sparse matrices do not support ** for elementwise power. - mean_sq = self._sum(_power(self.data, 2)) / group_counts[:, None] - sq_mean = mean_**2 - var_ = mean_sq - sq_mean + mean_, var_ = mean_var_dense(self.indicator_matrix.tocsr(), self.data) else: mean_, var_ = ( mean_var_csr if isinstance(self.data, CSRBase) else mean_var_csc )(self.indicator_matrix, self.data) - sq_mean = mean_**2 - # TODO: Why these values exactly? Because they are high relative to the datatype? 
- # (unchanged from original code: https://github.com/scverse/anndata/pull/564) - precision = 2 << (42 if self.data.dtype == np.float64 else 20) - # detects loss of precision in mean_sq - sq_mean, which suggests variance is 0 - var_[precision * var_ < sq_mean] = 0 if dof != 0: var_ *= (group_counts / (group_counts - dof))[:, np.newaxis] return mean_, var_ diff --git a/src/scanpy/get/_kernels.py b/src/scanpy/get/_kernels.py index 4d25bd06be..a0ecd7f35f 100644 --- a/src/scanpy/get/_kernels.py +++ b/src/scanpy/get/_kernels.py @@ -48,34 +48,84 @@ def agg_sum_csc(indicator: CSRBase, data: CSCBase, out: np.ndarray) -> None: out[cat, col] += data.data[j] +@njit +def mean_var_dense( + indicator: CSRBase, data: NDArray +) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + # Welford's online algorithm, parallelized over categories. The indicator + # CSR lists which observations belong to each category, allowing mask + # handling to be folded in naturally. + n_cats = indicator.shape[0] + n_features = data.shape[1] + mean = np.zeros((n_cats, n_features), dtype="float64") + var = np.zeros((n_cats, n_features), dtype="float64") + + for cat in numba.prange(n_cats): + start = indicator.indptr[cat] + stop = indicator.indptr[cat + 1] + n = 0 + for row_num in range(start, stop): + obs = indicator.indices[row_num] + n += 1 + for col in range(n_features): + value = np.float64(data[obs, col]) + delta = value - mean[cat, col] + mean[cat, col] += delta / n + delta2 = value - mean[cat, col] + var[cat, col] += delta * delta2 + if n > 0: + for col in range(n_features): + var[cat, col] /= n + return mean, var + + @njit def mean_var_csr( indicator: CSRBase, data: CSCBase, ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: - mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") - var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") - - for cat_num in numba.prange(indicator.shape[0]): + # Welford's online algorithm over nonzeros, then merge with the block of 
+ # implicit zeros per (category, feature). Merging a Welford accumulator + # (n_A, mean_A, M2_A) with k zeros gives: + # mean = mean_A * n_A / (n_A + k) + # M2_new = M2_A + mean_A^2 * n_A * k / (n_A + k) + n_cats = indicator.shape[0] + n_features = data.shape[1] + mean = np.zeros((n_cats, n_features), dtype="float64") + var = np.zeros((n_cats, n_features), dtype="float64") + + for cat_num in numba.prange(n_cats): start_cat_idx = indicator.indptr[cat_num] stop_cat_idx = indicator.indptr[cat_num + 1] + n_obs = stop_cat_idx - start_cat_idx + if n_obs == 0: + continue + + n_nonzero = np.zeros(n_features, dtype=np.int64) + for row_num in range(start_cat_idx, stop_cat_idx): obs_per_cat = indicator.indices[row_num] - start_obs = data.indptr[obs_per_cat] end_obs = data.indptr[obs_per_cat + 1] for j in range(start_obs, end_obs): col = data.indices[j] value = np.float64(data.data[j]) - value = data.data[j] - mean[cat_num, col] += value - var[cat_num, col] += value * value - - n_obs = stop_cat_idx - start_cat_idx - mean_cat = mean[cat_num, :] / n_obs - mean[cat_num, :] = mean_cat - var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat) + n_nonzero[col] += 1 + n = n_nonzero[col] + delta = value - mean[cat_num, col] + mean[cat_num, col] += delta / n + delta2 = value - mean[cat_num, col] + var[cat_num, col] += delta * delta2 + + for col in range(n_features): + n_nz = n_nonzero[col] + k = n_obs - n_nz + if k > 0 and n_nz > 0: + mean_a = mean[cat_num, col] + mean[cat_num, col] = mean_a * n_nz / n_obs + var[cat_num, col] += mean_a * mean_a * n_nz * k / n_obs + var[cat_num, col] /= n_obs return mean, var @@ -83,34 +133,49 @@ def mean_var_csr( def mean_var_csc( indicator: CSRBase, data: CSCBase ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: + # Welford's online algorithm, parallelized over columns. 
For each column + # we accumulate per-category over the explicit nonzeros, then merge each + # category's accumulator with its block of implicit zeros (see merge + # formula in `mean_var_csr`). + n_cats = indicator.shape[0] + n_features = data.shape[1] obs_to_cat = np.full(data.shape[0], -1, dtype=np.int64) - - mean = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") - var = np.zeros((indicator.shape[0], data.shape[1]), dtype="float64") - - for cat in range(indicator.shape[0]): + n_obs_per_cat = np.zeros(n_cats, dtype=np.int64) + for cat in range(n_cats): + n_obs_per_cat[cat] = indicator.indptr[cat + 1] - indicator.indptr[cat] for k in range(indicator.indptr[cat], indicator.indptr[cat + 1]): obs_to_cat[indicator.indices[k]] = cat - for col in numba.prange(data.shape[1]): + mean = np.zeros((n_cats, n_features), dtype="float64") + var = np.zeros((n_cats, n_features), dtype="float64") + + for col in numba.prange(n_features): + n_nonzero = np.zeros(n_cats, dtype=np.int64) start = data.indptr[col] end = data.indptr[col + 1] for j in range(start, end): obs = data.indices[j] cat = obs_to_cat[obs] - - if cat != -1: - value = np.float64(data.data[j]) - value = data.data[j] - mean[cat, col] += value - var[cat, col] += value * value - - for cat_num in numba.prange(indicator.shape[0]): - start_cat_idx = indicator.indptr[cat_num] - stop_cat_idx = indicator.indptr[cat_num + 1] - n_obs = stop_cat_idx - start_cat_idx - mean_cat = mean[cat_num, :] / n_obs - mean[cat_num, :] = mean_cat - var[cat_num, :] = (var[cat_num, :] / n_obs) - (mean_cat * mean_cat) + if cat == -1: + continue + value = np.float64(data.data[j]) + n_nonzero[cat] += 1 + n = n_nonzero[cat] + delta = value - mean[cat, col] + mean[cat, col] += delta / n + delta2 = value - mean[cat, col] + var[cat, col] += delta * delta2 + + for cat in range(n_cats): + n_obs = n_obs_per_cat[cat] + if n_obs == 0: + continue + n_nz = n_nonzero[cat] + k = n_obs - n_nz + if k > 0 and n_nz > 0: + mean_a = mean[cat, col] + 
mean[cat, col] = mean_a * n_nz / n_obs + var[cat, col] += mean_a * mean_a * n_nz * k / n_obs + var[cat, col] /= n_obs return mean, var From 514bd1771a65459da39b45c1304d36bf82a0e2a0 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 8 Jun 2026 14:15:16 +0200 Subject: [PATCH 25/48] chore: relnote --- docs/release-notes/4147.perf.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/4147.perf.md diff --git a/docs/release-notes/4147.perf.md b/docs/release-notes/4147.perf.md new file mode 100644 index 0000000000..f6044655c8 --- /dev/null +++ b/docs/release-notes/4147.perf.md @@ -0,0 +1 @@ +Use [Welford's algorithm][] for mean-var calculation in {func}`scanpy.get.aggregate` for in-memory (i.e., non-dask) arrays {smaller}`I Gold` From b71eb686174bd998ddeefd4e0034bcc186c0a154 Mon Sep 17 00:00:00 2001 From: Zach Boldyga Date: Thu, 11 Jun 2026 12:00:14 -0700 Subject: [PATCH 26/48] njit support for chan algorithm (#4153) Co-authored-by: Ilan Gold Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/scanpy/get/_aggregated.py | 41 ++++++++++++++++++++++++++--------- 1 file changed, 31 insertions(+), 10 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 66f251c316..b05b80e435 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -3,9 +3,11 @@ from functools import partial, singledispatch from typing import TYPE_CHECKING, Literal, TypedDict, get_args +import numba import numpy as np import pandas as pd from anndata import AnnData +from fast_array_utils.numba import njit from scipy import sparse from sklearn.utils.sparsefuncs import csc_median_axis_0 @@ -469,18 +471,37 @@ def _block_moments( return out -def _chan_combine( +@numba.njit(inline="always") # noqa: TID251 +def _chan_combine( # noqa: PLR0917 + n_a: float, mean_a: float, m2_a: float, n_b: float, mean_b: float, m2_b: float +) -> tuple[float, float, float]: + """Combine two ``(count, 
mean, M2)`` groups pairwise.""" + if n_a == 0.0: + return n_b, mean_b, m2_b + if n_b == 0.0: + return n_a, mean_a, m2_a + n = n_a + n_b + delta = mean_b - mean_a + return n, mean_a + delta * n_b / n, m2_a + m2_b + delta * delta * n_a * n_b / n + + +@njit +def _chan_combine_blocks( a: NDArray[np.float64], b: NDArray[np.float64] ) -> NDArray[np.float64]: """Combine two ``(3, K, F)`` ``(count, mean, M2)`` stat blocks pairwise.""" - n_a, mean_a, m2_a = a[0], a[1], a[2] - n_b, mean_b, m2_b = b[0], b[1], b[2] - n = n_a + n_b - safe_n = np.where(n > 0, n, 1) - delta = mean_b - mean_a - new_mean = mean_a + delta * n_b / safe_n - new_m2 = m2_a + m2_b + delta * delta * n_a * n_b / safe_n - return np.stack([n, new_mean, new_m2]) + out = np.empty_like(a) + for i in numba.prange(a.shape[1]): + for j in range(a.shape[2]): + out[0, i, j], out[1, i, j], out[2, i, j] = _chan_combine( + a[0, i, j], + a[1, i, j], + a[2, i, j], + b[0, i, j], + b[1, i, j], + b[2, i, j], + ) + return out def _chan_reduce_axis_0( @@ -491,7 +512,7 @@ def _chan_reduce_axis_0( """Aggregate per-block stats along axis 0 with the parallel variance algorithm.""" result = stats[0] for i in range(1, stats.shape[0]): - result = _chan_combine(result, stats[i]) + result = _chan_combine_blocks(result, stats[i]) return result[None] if keepdims else result From 1f430883795ee7c1efb17c6b23221016b874a226 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 15 Jun 2026 14:58:23 +0200 Subject: [PATCH 27/48] chore: integrate welford's directly into chans --- src/scanpy/get/_aggregated.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 9a5079bc1a..c2fb6e33ce 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -147,7 +147,7 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: mean_var_csr if isinstance(self.data, CSRBase) else mean_var_csc )(self.indicator_matrix, self.data) if dof 
!= 0: - var_ *= (group_counts / (group_counts - dof))[:, np.newaxis] + var_ *= (group_counts / np.maximum(group_counts - dof, 1))[:, np.newaxis] return mean_, var_ def median(self) -> Array: @@ -453,14 +453,11 @@ def _block_moments( return out agg = Aggregate(groupby=by, data=data, mask=mask) - sum_ = agg.sum() - sum_sq = agg._sum(_power(data, 2)) - safe_counts = np.where(nonempty, counts, 1)[:, None] - mean_ = sum_ / safe_counts - # M2 = sum((x - mean)**2) = sum_sq - count * mean**2; clip cancellation noise to 0. - m2 = np.maximum(sum_sq - sum_ * mean_, 0) - out[1, nonempty] = mean_[nonempty] - out[2, nonempty] = m2[nonempty] + mean_, var_ = agg.mean_var() + # M2 is the variance times the counts directly minus correction. + m2 = var_ * (counts - 1)[:, np.newaxis] + out[1, :] = mean_ + out[2, :] = m2 return out From 80a4a94286f11de74281f06381e6b1d190f2e43a Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Tue, 16 Jun 2026 15:12:47 +0200 Subject: [PATCH 28/48] Apply suggestion from @ilan-gold --- docs/release-notes/4143.perf.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/release-notes/4143.perf.md b/docs/release-notes/4143.perf.md index 2ffd9c00d6..e5644ca038 100644 --- a/docs/release-notes/4143.perf.md +++ b/docs/release-notes/4143.perf.md @@ -1,3 +1,3 @@ -Use Chan's mean-var algorithm for acceleration of dask-backed {func}`scanpy.get.aggregate` {smaller}`I Gold` +Use [Chan's mean-var][] algorithm for acceleration of dask-backed {func}`scanpy.get.aggregate` {smaller}`I Gold` [Chan's mean-var]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm From 8405ff12f8838dc89f6ca21ce42ee6a68bab68c9 Mon Sep 17 00:00:00 2001 From: Ilan Gold Date: Wed, 17 Jun 2026 13:41:10 +0200 Subject: [PATCH 29/48] Update benchmarks/benchmarks/preprocessing_counts.py Co-authored-by: Philipp A. 
--- benchmarks/benchmarks/preprocessing_counts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index bc3bedce63..65b7502044 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -152,7 +152,7 @@ def peakmem_log1p(self, *_) -> None: class Agg: # noqa: D101 - params: tuple[list[str], list[bool]] = ( + params: tuple[list[AggType], list[bool]] = ( list(get_literal_vals(AggType)), [True, False], ) From cd73fe221007b870b79e5ceafdaa294ac2171401 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 17 Jun 2026 13:47:35 +0200 Subject: [PATCH 30/48] fix: correct chan unstable step usable --- src/scanpy/get/_aggregated.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index c2fb6e33ce..8eb02d5730 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -472,7 +472,11 @@ def _chan_combine( # noqa: PLR0917 return n_a, mean_a, m2_a n = n_a + n_b delta = mean_b - mean_a - return n, mean_a + delta * n_b / n, m2_a + m2_b + delta * delta * n_a * n_b / n + return ( + n, + (n_a * mean_a + n_b * mean_b) / n, + m2_a + m2_b + delta * delta * n_a * n_b / n, + ) @njit From 05daadbe8af805ad875bc657db71f3dcbe97f36c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 17 Jun 2026 14:33:19 +0200 Subject: [PATCH 31/48] chore: add cancelling test --- tests/test_aggregated.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index e9caef4f4e..e66e77cf23 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -16,6 +16,7 @@ from testing.scanpy._helpers.data import pbmc3k_processed from testing.scanpy._pytest.marks import needs from testing.scanpy._pytest.params import ARRAY_TYPES as ARRAY_TYPES_ALL +from 
testing.scanpy._pytest.params import ARRAY_TYPES_MEM if TYPE_CHECKING: from collections.abc import Callable @@ -544,3 +545,41 @@ def test_nan() -> None: "s2_control_C", ] assert adata_agg.obs["n_obs_aggregated"].tolist() == [1, 2, 1] + + +@pytest.mark.parametrize("array_type", ARRAY_TYPES_MEM) +def test_var_no_catastrophic_cancellation(array_type) -> None: + # Values of the form `offset + tiny_noise` make the textbook two-pass + # formula sum(x**2)/n - (sum(x)/n)**2 lose ~all precision: both terms are + # ~n*offset**2 ≈ 1e19 in float64 (precision ~1e3) but their difference is + # the variance ~1e-3, far below the rounding noise. Welford's online + # algorithm. + rng = np.random.default_rng(0) + n_per_group, n_features = 1000, 4 + offset, std = 1e8, 1e-3 + groups = ["a", "b"] + x = np.vstack([ + offset + std * rng.standard_normal((n_per_group, n_features)) for _ in groups + ]).astype(np.float64) + obs = pd.DataFrame( + {"group": pd.Categorical(np.repeat(groups, n_per_group))}, + index=[f"cell_{i}" for i in range(x.shape[0])], + ) + adata = ad.AnnData(X=array_type(x), obs=obs) + + expected = np.vstack([ + np.var(x[i * n_per_group : (i + 1) * n_per_group], axis=0, ddof=0) + for i in range(len(groups)) + ]) + # Sanity: textbook formula on this data is catastrophically wrong + # (off by >1e5x the true variance — proving the scenario actually triggers + # cancellation rather than being a vacuous test). 
+ naive = (x**2).mean(axis=0) - x.mean(axis=0) ** 2 + assert ( + np.abs(naive - np.var(x, axis=0, ddof=0)) / np.var(x, axis=0, ddof=0) > 1e5 + ).all() + + result = sc.get.aggregate(adata, by="group", func="var", dof=0).layers["var"] + if isinstance(result, DaskArray): + result = result.compute() + np.testing.assert_allclose(result, expected, rtol=1e-4) From e4377143000289e154032a664b73ac20e785ed7e Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 17 Jun 2026 14:34:25 +0200 Subject: [PATCH 32/48] chore: finish sentence --- tests/test_aggregated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index e66e77cf23..78b8340f67 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -553,7 +553,7 @@ def test_var_no_catastrophic_cancellation(array_type) -> None: # formula sum(x**2)/n - (sum(x)/n)**2 lose ~all precision: both terms are # ~n*offset**2 ≈ 1e19 in float64 (precision ~1e3) but their difference is # the variance ~1e-3, far below the rounding noise. Welford's online - # algorithm. + # algorithm avoids the subtraction entirely. 
rng = np.random.default_rng(0) n_per_group, n_features = 1000, 4 offset, std = 1e8, 1e-3 From ec7267903e858b6a5f45b331003954dd11b2dffb Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 17 Jun 2026 14:35:18 +0200 Subject: [PATCH 33/48] chore: relnote --- docs/release-notes/4147.perf.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/release-notes/4147.perf.md b/docs/release-notes/4147.perf.md index f6044655c8..2943ad2083 100644 --- a/docs/release-notes/4147.perf.md +++ b/docs/release-notes/4147.perf.md @@ -1 +1,3 @@ Use [Welford's algorithm][] for mean-var calculation in {func}`scanpy.get.aggregate` for in-memory (i.e., non-dask) arrays {smaller}`I Gold` + +[Welford's algorithm]: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm From 9dcfcc7e5f3c2779b069af09f17354137c0baa63 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 17 Jun 2026 14:36:36 +0200 Subject: [PATCH 34/48] chore: add context --- src/scanpy/get/_aggregated.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scanpy/get/_aggregated.py b/src/scanpy/get/_aggregated.py index 3fad260f7d..a21ec7f806 100644 --- a/src/scanpy/get/_aggregated.py +++ b/src/scanpy/get/_aggregated.py @@ -124,7 +124,8 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]: """Compute the count, as well as mean and variance per feature, per group of observations. Mean and variance are computed with Welford's online algorithm, which is - numerically stable for constant or near-constant inputs. + numerically stable for constant or near-constant inputs + compared to subtracting E[X^2] - E[X]^2 since both values will be so close. 
Params ------ From cd781f005d4abf5cc9b9d5b93723fa91e53305b9 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 17 Jun 2026 14:40:23 +0200 Subject: [PATCH 35/48] chore: bring in dask --- tests/test_aggregated.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index b6c238fd27..78f3ec1a2d 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -16,7 +16,6 @@ from testing.scanpy._helpers.data import pbmc3k_processed from testing.scanpy._pytest.marks import needs from testing.scanpy._pytest.params import ARRAY_TYPES as ARRAY_TYPES_ALL -from testing.scanpy._pytest.params import ARRAY_TYPES_MEM if TYPE_CHECKING: from collections.abc import Callable @@ -547,7 +546,7 @@ def test_nan() -> None: assert adata_agg.obs["n_obs_aggregated"].tolist() == [1, 2, 1] -@pytest.mark.parametrize("array_type", ARRAY_TYPES_MEM) +@pytest.mark.parametrize("array_type", VALID_ARRAY_TYPES) def test_var_no_catastrophic_cancellation(array_type) -> None: # Values of the form `offset + tiny_noise` make the textbook two-pass # formula sum(x**2)/n - (sum(x)/n)**2 lose ~all precision: both terms are From bd85e03c1fbc44b32a572e6772ac0c781d28107f Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 18 Jun 2026 14:28:04 +0200 Subject: [PATCH 36/48] perf: less memory touches --- src/scanpy/get/_kernels.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/src/scanpy/get/_kernels.py b/src/scanpy/get/_kernels.py index a0ecd7f35f..06154ebe03 100644 --- a/src/scanpy/get/_kernels.py +++ b/src/scanpy/get/_kernels.py @@ -111,12 +111,13 @@ def mean_var_csr( for j in range(start_obs, end_obs): col = data.indices[j] value = np.float64(data.data[j]) - n_nonzero[col] += 1 - n = n_nonzero[col] - delta = value - mean[cat_num, col] - mean[cat_num, col] += delta / n - delta2 = value - mean[cat_num, col] - var[cat_num, col] += delta * delta2 + n = n_nonzero[col] + 1 + n_nonzero[col] = n + m 
= mean[cat_num, col] + delta = value - m + m += delta / n + mean[cat_num, col] = m + var[cat_num, col] += delta * (value - m) for col in range(n_features): n_nz = n_nonzero[col] @@ -160,12 +161,13 @@ def mean_var_csc( if cat == -1: continue value = np.float64(data.data[j]) - n_nonzero[cat] += 1 - n = n_nonzero[cat] - delta = value - mean[cat, col] - mean[cat, col] += delta / n - delta2 = value - mean[cat, col] - var[cat, col] += delta * delta2 + n = n_nonzero[cat] + 1 + n_nonzero[cat] = n + m = mean[cat, col] + delta = value - m + m += delta / n + mean[cat, col] = m + var[cat, col] += delta * (value - m) for cat in range(n_cats): n_obs = n_obs_per_cat[cat] From 497809e96bb14b03e1b6dcbf20eb7d155fc10ed2 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 22 Jun 2026 12:28:57 +0200 Subject: [PATCH 37/48] refactor: cleanup --- tests/test_aggregated.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index 78b8340f67..b9e1750bef 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -554,13 +554,14 @@ def test_var_no_catastrophic_cancellation(array_type) -> None: # ~n*offset**2 ≈ 1e19 in float64 (precision ~1e3) but their difference is # the variance ~1e-3, far below the rounding noise. Welford's online # algorithm avoids the subtraction entirely. 
- rng = np.random.default_rng(0) n_per_group, n_features = 1000, 4 offset, std = 1e8, 1e-3 groups = ["a", "b"] x = np.vstack([ - offset + std * rng.standard_normal((n_per_group, n_features)) for _ in groups - ]).astype(np.float64) + offset + + std * np.random.default_rng().standard_normal((n_per_group, n_features)) + for _ in groups + ]) obs = pd.DataFrame( {"group": pd.Categorical(np.repeat(groups, n_per_group))}, index=[f"cell_{i}" for i in range(x.shape[0])], From 25e6bfcdbb14be0ca79a555e0939fe81dfe34bdb Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 22 Jun 2026 13:44:19 +0200 Subject: [PATCH 38/48] chore: csc benchmarks --- benchmarks/benchmarks/preprocessing_counts.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 9a20e7eda3..51933d88cf 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -151,16 +151,21 @@ def peakmem_log1p(self, *_) -> None: class Agg: # noqa: D101 - params: tuple[AggType] = tuple(get_literal_vals(AggType)) - param_names = ("agg_name",) + params: tuple[list[AggType], list[bool]] = ( + list(get_literal_vals(AggType)), + [True, False], + ) + param_names = ("agg_name", "use_csc") def setup_cache(self) -> None: """Without this caching, asv was running several processes which meant the data was repeatedly downloaded.""" adata, _ = get_dataset("lung93k") adata.write_h5ad("lung93k.h5ad") - def setup(self, agg_name: AggType) -> None: + def setup(self, agg_name: AggType, use_csc: bool) -> None: # noqa: FBT001 self.adata = ad.read_h5ad("lung93k.h5ad") + if use_csc: + self.adata.layers["counts"] = self.adata.layers["counts"].tocsc() self.agg_name = agg_name def time_agg(self, *_) -> None: From cd9ad03e7777fde11ec2e529f36d210de9b5661b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Jun 
2026 13:15:24 +0000 Subject: [PATCH 39/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- benchmarks/benchmarks/preprocessing_counts.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index f0ebe34486..08cadd9764 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -176,7 +176,9 @@ def setup(self, agg_name: AggType, use_csc: bool, use_dask: bool) -> None: # no obs=ad.io.read_elem(z["obs"]), var=ad.io.read_elem(z["var"]), layers={ - "counts": ad.experimental.read_elem_lazy(z["layers"][counts_src_key]) + "counts": ad.experimental.read_elem_lazy( + z["layers"][counts_src_key] + ) }, X=ad.experimental.read_elem_lazy(z["X"]), ) From 9004cc016455f770540a37b23afea32ce41b551e Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 23 Jun 2026 14:49:54 +0200 Subject: [PATCH 40/48] fix: tests --- tests/test_aggregated.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index b9e1750bef..afd37d4938 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -572,13 +572,20 @@ def test_var_no_catastrophic_cancellation(array_type) -> None: np.var(x[i * n_per_group : (i + 1) * n_per_group], axis=0, ddof=0) for i in range(len(groups)) ]) - # Sanity: textbook formula on this data is catastrophically wrong - # (off by >1e5x the true variance — proving the scenario actually triggers - # cancellation rather than being a vacuous test). 
- naive = (x**2).mean(axis=0) - x.mean(axis=0) ** 2 - assert ( - np.abs(naive - np.var(x, axis=0, ddof=0)) / np.var(x, axis=0, ddof=0) > 1e5 - ).all() + # Sanity: textbook formula on this data is either catastrophically wrong by a large magnitude relative to the epected + # or the sum-sq and sq-sum in naive are literally identical due to precision errors at the upper bound of the range. + naive = np.vstack([ + (xg**2).mean(axis=0) - xg.mean(axis=0) ** 2 + for xg in ( + x[i * n_per_group : (i + 1) * n_per_group] for i in range(len(groups)) + ) + ]) + diff_magnitude = np.abs(naive - expected) / expected + all_large = (diff_magnitude > 1e5).all() + if not all_large: + does_naive_fully_cancel = naive == 0 + assert does_naive_fully_cancel.any() + assert (diff_magnitude[does_naive_fully_cancel] == 1).all() result = sc.get.aggregate(adata, by="group", func="var", dof=0).layers["var"] if isinstance(result, DaskArray): From ff0ac25cc4596e011e8759d76e044d3eccc77997 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 23 Jun 2026 14:53:34 +0200 Subject: [PATCH 41/48] chore: spelling --- tests/test_aggregated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index afd37d4938..9d5db64e08 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -572,7 +572,7 @@ def test_var_no_catastrophic_cancellation(array_type) -> None: np.var(x[i * n_per_group : (i + 1) * n_per_group], axis=0, ddof=0) for i in range(len(groups)) ]) - # Sanity: textbook formula on this data is either catastrophically wrong by a large magnitude relative to the epected + # Sanity: textbook formula on this data is either catastrophically wrong by a large magnitude relative to the expected # or the sum-sq and sq-sum in naive are literally identical due to precision errors at the upper bound of the range. 
naive = np.vstack([ (xg**2).mean(axis=0) - xg.mean(axis=0) ** 2 From 3c3c5b0e3f7dd5cedd1e2050b1ab8395ccec1d9a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Jun 2026 13:19:19 +0000 Subject: [PATCH 42/48] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_aggregated.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_aggregated.py b/tests/test_aggregated.py index 6c9acd9461..6b3ebf409f 100644 --- a/tests/test_aggregated.py +++ b/tests/test_aggregated.py @@ -16,7 +16,6 @@ from testing.scanpy._helpers.data import pbmc3k_processed from testing.scanpy._pytest.marks import needs from testing.scanpy._pytest.params import ARRAY_TYPES as ARRAY_TYPES_ALL -from testing.scanpy._pytest.params import ARRAY_TYPES_MEM if TYPE_CHECKING: from collections.abc import Callable From 65a46ea53088d83164da83667c582a05911438ac Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Tue, 23 Jun 2026 15:39:10 +0200 Subject: [PATCH 43/48] chore: clean up `counts` key --- benchmarks/benchmarks/preprocessing_counts.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/benchmarks/benchmarks/preprocessing_counts.py b/benchmarks/benchmarks/preprocessing_counts.py index 08cadd9764..5c5114d902 100644 --- a/benchmarks/benchmarks/preprocessing_counts.py +++ b/benchmarks/benchmarks/preprocessing_counts.py @@ -184,8 +184,9 @@ def setup(self, agg_name: AggType, use_csc: bool, use_dask: bool) -> None: # no ) else: self.adata = ad.read_zarr("lung93k.zarr") - self.adata.layers["counts"] = self.adata.layers[counts_src_key] - del self.adata.layers[counts_src_key] + if counts_src_key != "counts": + self.adata.layers["counts"] = self.adata.layers[counts_src_key] + del self.adata.layers[counts_src_key] self.agg_name: AggType = agg_name def time_agg(self, *_) -> None: From c99d04d405d188c6912393e455f299803cdeab3d Mon Sep 17 00:00:00 2001 From: ilan-gold 
Date: Wed, 24 Jun 2026 10:07:07 +0200 Subject: [PATCH 44/48] fix: try no dask --- benchmarks/asv.conf.json | 1 - 1 file changed, 1 deletion(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index a1a8d31a42..d19b822178 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -84,7 +84,6 @@ "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 // "scikit-misc": [""], - "dask": [""], }, // Combinations of libraries/python versions can be excluded/included From 31d42baa543b7ac8cb3a85c6be5f5e0486d92537 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Wed, 24 Jun 2026 10:07:42 +0200 Subject: [PATCH 45/48] fix: back to dask --- benchmarks/asv.conf.json | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index d19b822178..a1a8d31a42 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -84,6 +84,7 @@ "pooch": [""], "scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29 // "scikit-misc": [""], + "dask": [""], }, // Combinations of libraries/python versions can be excluded/included From 7bf2db4aa11c28d5e6ed01644453bb742bab6375 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 25 Jun 2026 12:19:06 +0200 Subject: [PATCH 46/48] fix: no defaults --- benchmarks/asv.conf.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/asv.conf.json b/benchmarks/asv.conf.json index a1a8d31a42..777d2bb980 100644 --- a/benchmarks/asv.conf.json +++ b/benchmarks/asv.conf.json @@ -58,7 +58,7 @@ // The list of conda channel names to be searched for benchmark // dependency packages in the specified order - "conda_channels": ["conda-forge", "defaults"], + "conda_channels": ["conda-forge"], // The matrix of dependencies to test. Each key is the name of a // package (in PyPI) and the values are version numbers. 
An empty From 06ecaa2178931caec6ab99c67edc1ff1f96cf10c Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 25 Jun 2026 13:48:53 +0200 Subject: [PATCH 47/48] fix: var space --- src/scanpy/preprocessing/_highly_variable_genes.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/scanpy/preprocessing/_highly_variable_genes.py b/src/scanpy/preprocessing/_highly_variable_genes.py index cad3ff470d..489af1e3f8 100644 --- a/src/scanpy/preprocessing/_highly_variable_genes.py +++ b/src/scanpy/preprocessing/_highly_variable_genes.py @@ -193,6 +193,7 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915 aggregated_mean_var.layers["var"] = var_global else: aggregated_mean_var = AnnData( + var=pd.DataFrame(index=adata.var_names), obs=pd.DataFrame( index=np.array(["one"]), data={"__hvg_v3_batch_info__": np.array([0])} ), From 1302d26bcf01b07fd20f028456b71c78da8757b0 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Thu, 25 Jun 2026 14:20:18 +0200 Subject: [PATCH 48/48] chore: relnote --- docs/release-notes/4013.perf.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/4013.perf.md diff --git a/docs/release-notes/4013.perf.md b/docs/release-notes/4013.perf.md new file mode 100644 index 0000000000..58893247da --- /dev/null +++ b/docs/release-notes/4013.perf.md @@ -0,0 +1 @@ +{func}`scanpy.pp.highly_variable_genes` now makes only two sequential passes over the data for the `seurat_v3` flavors, greatly reducing run time for `dask` inputs {smaller}`I Gold`