Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
85 commits
Select commit Hold shift + click to select a range
a625c55
perf: "two-pass" seurat hvg3 via `scanpy.get.aggregate`
ilan-gold Mar 26, 2026
d839e98
chore: hvg v3 benchmark
ilan-gold Mar 26, 2026
86db499
fix: use counts
ilan-gold Mar 26, 2026
d5a6a78
fix: use a batch key
ilan-gold Mar 26, 2026
fdc5653
fix: not again
ilan-gold Mar 26, 2026
8f0e426
fix: `compute` single pass!
ilan-gold Apr 8, 2026
8ad893d
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold Apr 8, 2026
7e0390e
fix: unique
ilan-gold Apr 9, 2026
17be530
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold Apr 10, 2026
cc0d67e
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold Apr 16, 2026
96c16e9
chore: add new `dask` benchmark
ilan-gold May 4, 2026
db4bc2c
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold May 4, 2026
478af4a
fix: actually use dask lol
ilan-gold May 4, 2026
54db31b
chore: really do dask
ilan-gold May 4, 2026
4fe84c5
fix: layers support
ilan-gold May 4, 2026
35590a4
fix: no view check needed
ilan-gold May 4, 2026
db81d6e
fix: no layers eeded
ilan-gold May 4, 2026
b37444e
fix: reduce number of batches
ilan-gold May 5, 2026
cf65665
fix: a little bit more
ilan-gold May 5, 2026
8f4ef78
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold May 15, 2026
a7b067d
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold May 16, 2026
6f7ad6a
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold May 18, 2026
e624939
perf: chan's parallel mean-var algorithm for dask
ilan-gold Jun 5, 2026
61332fd
fix: params
ilan-gold Jun 5, 2026
1df5fda
fix: iteration
ilan-gold Jun 5, 2026
9a70581
fix: zarr link
ilan-gold Jun 5, 2026
5313ea2
fix: `median` calculation skipped
ilan-gold Jun 5, 2026
e19a7d8
fix: no-batch-key accel
ilan-gold Jun 5, 2026
8482561
fix: don't run all benchmarks with dask
ilan-gold Jun 5, 2026
44606f0
Merge branch 'ig/chan_mean_var_main' into ig/two_pass_hvg_v3
ilan-gold Jun 5, 2026
17f706e
Merge branch 'main' into ig/chan_mean_var_main
ilan-gold Jun 5, 2026
a2b390b
chore: relnote
ilan-gold Jun 8, 2026
21f5ddc
perf: welford's algorithm for mean-var
ilan-gold Jun 8, 2026
514bd17
chore: relnote
ilan-gold Jun 8, 2026
48230af
Merge branch 'main' into ig/welford
ilan-gold Jun 9, 2026
c2cd368
Merge branch 'main' into ig/chan_mean_var_main
ilan-gold Jun 9, 2026
b71eb68
njit support for chan algorithm (#4153)
zboldyga Jun 11, 2026
471b989
Merge branch 'main' into ig/chan_mean_var_main
ilan-gold Jun 11, 2026
afc24a1
Merge branch 'main' into ig/welford
ilan-gold Jun 12, 2026
35b0ff6
Merge branch 'main' into ig/welford
ilan-gold Jun 15, 2026
f793415
Merge branch 'main' into ig/chan_mean_var_main
ilan-gold Jun 15, 2026
8a57e1c
Merge branch 'ig/welford' into ig/chan_mean_var_main
ilan-gold Jun 15, 2026
1f43088
chore: integrate welford's directly into chans
ilan-gold Jun 15, 2026
80a4a94
Apply suggestion from @ilan-gold
ilan-gold Jun 16, 2026
8405ff1
Update benchmarks/benchmarks/preprocessing_counts.py
ilan-gold Jun 17, 2026
cd73fe2
fix: correct chan unstable step usable
ilan-gold Jun 17, 2026
05daadb
chore: add cancelling test
ilan-gold Jun 17, 2026
0d25b48
Merge branch 'ig/welford' into ig/chan_mean_var_main
ilan-gold Jun 17, 2026
e437714
chore: finish sentence
ilan-gold Jun 17, 2026
ec72679
chore: relnote
ilan-gold Jun 17, 2026
9dcfcc7
chore: add context
ilan-gold Jun 17, 2026
4f074ed
Merge branch 'ig/welford' into ig/chan_mean_var_main
ilan-gold Jun 17, 2026
3fa0884
Merge branch 'main' into ig/welford
ilan-gold Jun 17, 2026
cd781f0
chore: bring in dask
ilan-gold Jun 17, 2026
9a6275b
Merge branch 'ig/welford' into ig/chan_mean_var_main
ilan-gold Jun 17, 2026
bd85e03
perf: less memory touches
ilan-gold Jun 18, 2026
4e0ff1a
Merge branch 'ig/welford' into ig/chan_mean_var_main
ilan-gold Jun 18, 2026
497809e
refactor: cleanup
ilan-gold Jun 22, 2026
25e6bfc
chore: csc benchmarks
ilan-gold Jun 22, 2026
3f4b1b4
Merge branch 'ig/welford' into ig/chan_mean_var_main
ilan-gold Jun 22, 2026
cd9ad03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 22, 2026
1a19312
Merge branch 'main' into ig/welford
ilan-gold Jun 23, 2026
c649555
Merge branch 'ig/welford' into ig/chan_mean_var_main
ilan-gold Jun 23, 2026
81ae72b
Merge branch 'main' into ig/welford
ilan-gold Jun 23, 2026
c7d4166
Merge branch 'ig/welford' into ig/chan_mean_var_main
flying-sheep Jun 23, 2026
9004cc0
fix: tests
ilan-gold Jun 23, 2026
cc5ac95
Merge branch 'ig/welford' of github.com:scverse/scanpy into ig/welford
ilan-gold Jun 23, 2026
eb03735
Merge branch 'main' into ig/welford
ilan-gold Jun 23, 2026
d1ad434
Merge branch 'ig/welford' into ig/chan_mean_var_main
ilan-gold Jun 23, 2026
ff0ac25
chore: spelling
ilan-gold Jun 23, 2026
6ddd745
Merge branch 'ig/welford' into ig/chan_mean_var_main
ilan-gold Jun 23, 2026
6ebc4b3
Merge branch 'main' into ig/chan_mean_var_main
ilan-gold Jun 23, 2026
3c3c5b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 23, 2026
65a46ea
chore: clean up `counts` key
ilan-gold Jun 23, 2026
88a0150
Merge branch 'ig/chan_mean_var_main' of github.com:scverse/scanpy int…
ilan-gold Jun 23, 2026
196e443
Merge branch 'main' into ig/chan_mean_var_main
ilan-gold Jun 23, 2026
c99d04d
fix: try no dask
ilan-gold Jun 24, 2026
31d42ba
fix: back to dask
ilan-gold Jun 24, 2026
83d8db7
Merge branch 'ig/chan_mean_var_main' into ig/two_pass_hvg_v3
ilan-gold Jun 24, 2026
7bf2db4
fix: no defaults
ilan-gold Jun 25, 2026
added47
Merge branch 'ig/chan_mean_var_main' into ig/two_pass_hvg_v3
ilan-gold Jun 25, 2026
06ecaa2
fix: var space
ilan-gold Jun 25, 2026
1302d26
chore: relnote
ilan-gold Jun 25, 2026
761f054
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold Jun 25, 2026
3c87db4
Merge branch 'main' into ig/two_pass_hvg_v3
ilan-gold Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/asv.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@
// "psutil": [""]
"pooch": [""],
"scikit-image": [""], // https://github.com/conda-forge/scikit-misc-feedstock/pull/29
// "scikit-misc": [""],
"scikit-misc": [""],
"dask": [""],
},

Expand Down
78 changes: 67 additions & 11 deletions benchmarks/benchmarks/preprocessing_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,16 @@
from typing import TYPE_CHECKING

import anndata as ad
import numpy as np
import zarr

import scanpy as sc

from ._utils import get_dataset, param_skipper

if TYPE_CHECKING:
from typing import Literal

from ._utils import Dataset, KeyX


Expand Down Expand Up @@ -47,17 +51,6 @@ def time_pca(self, *_) -> None:
def peakmem_pca(self, *_) -> None:
sc.pp.pca(self.adata, svd_solver="arpack")

def time_highly_variable_genes(self, *_) -> None:
# the default flavor runs on log-transformed data
sc.pp.highly_variable_genes(
self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
)

def peakmem_highly_variable_genes(self, *_) -> None:
sc.pp.highly_variable_genes(
self.adata, min_mean=0.0125, max_mean=3, min_disp=0.5
)

# regress_out is very slow for this dataset
@skip_when(dataset={"pbmc3k"})
def time_regress_out(self, *_) -> None:
Expand All @@ -72,3 +65,66 @@ def time_scale(self, *_) -> None:

def peakmem_scale(self, *_) -> None:
sc.pp.scale(self.adata, max_value=10)


class HVGSuite: # noqa: D101
params = (["seurat_v3", "cell_ranger", "seurat"], [True, False])
param_names = ("flavor", "use_dask")

def setup_cache(self) -> None:
"""Without this caching, asv was running several processes which meant the data was repeatedly downloaded."""
adata, _ = get_dataset("lung93k")
adata.write_zarr("lung93k.zarr")
obs = np.arange(adata.shape[0])
np.random.default_rng().shuffle(obs)
adata[obs].write_zarr("lung93k_shuffled.zarr")

def setup(
self,
flavor: Literal["seurat_v3", "cell_ranger", "seurat"],
use_dask: bool, # noqa: FBT001
) -> None:
if use_dask:
if flavor != "seurat_v3":
# This benchmark only really makes sense for seurat v3 as that has been optimized.
raise NotImplementedError()
z = zarr.open("lung93k_shuffled.zarr")
self.adata = ad.AnnData(
obs=ad.io.read_elem(z["obs"]),
var=ad.io.read_elem(z["var"]),
layers={
"counts": ad.experimental.read_elem_lazy(z["layers"]["counts"])
},
X=ad.experimental.read_elem_lazy(z["X"]),
)
# Times out on the benchmark machine with full dataset
self.adata = self.adata[
self.adata.obs["PatientNumber"].isin(["1", "2", "3"])
].copy()
else:
self.adata = ad.read_zarr("lung93k.zarr")
sc.pp.filter_genes(self.adata, min_cells=3)
self.flavor = flavor

def time_highly_variable_genes(self, *_) -> None:
# the default flavor runs on log-transformed data
sc.pp.highly_variable_genes(
self.adata,
min_mean=0.0125,
max_mean=3,
min_disp=0.5,
flavor=self.flavor,
batch_key="PatientNumber",
**({"layer": "counts"} if self.flavor == "seurat_v3" else {}),
)

def peakmem_highly_variable_genes(self, *_) -> None:
sc.pp.highly_variable_genes(
self.adata,
min_mean=0.0125,
max_mean=3,
min_disp=0.5,
flavor=self.flavor,
batch_key="PatientNumber",
**({"layer": "counts"} if self.flavor == "seurat_v3" else {}),
)
1 change: 1 addition & 0 deletions docs/release-notes/4013.perf.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{func}`scanpy.pp.highly_variable_genes` now does only two passes over the data sequentially for `seurat_v3` flavors, greatly reducing `dask` input usage time {smaller}`I Gold`
85 changes: 62 additions & 23 deletions src/scanpy/preprocessing/_highly_variable_genes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
raise_if_dask_feature_axis_chunked,
sanitize_anndata,
)
from ..get import _get_obs_rep
from ..get import _get_obs_rep, aggregate
from ._distributed import materialize_as_ndarray
from ._simple import filter_genes

Expand All @@ -36,7 +36,7 @@
@singledispatch
def clip_square_sum(
data_batch: np.ndarray, clip_val: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
) -> tuple[np.ndarray, np.ndarray] | tuple[DaskArray, DaskArray]:
"""Clip data_batch by clip_val.

Parameters
Expand Down Expand Up @@ -64,24 +64,19 @@ def clip_square_sum(


@clip_square_sum.register(DaskArray)
def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
def _(data_batch: DaskArray, clip_val: np.ndarray) -> tuple[DaskArray, DaskArray]:
n_blocks = data_batch.blocks.size

def sum_and_sum_squares_clipped_from_block(block):
return np.vstack(clip_square_sum(block, clip_val))[None, ...]

squared_batch_counts_sum, batch_counts_sum = (
data_batch
.map_blocks(
sum_and_sum_squares_clipped_from_block,
new_axis=(1,),
chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)),
meta=np.array([]),
dtype=np.float64,
)
.sum(axis=0)
.compute()
)
squared_batch_counts_sum, batch_counts_sum = data_batch.map_blocks(
sum_and_sum_squares_clipped_from_block,
new_axis=(1,),
chunks=((1,) * n_blocks, (2,), (data_batch.shape[1],)),
meta=np.array([]),
dtype=np.float64,
).sum(axis=0)
return squared_batch_counts_sum, batch_counts_sum


Expand Down Expand Up @@ -172,17 +167,56 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
batch_info = (
pd.Categorical(np.zeros(adata.shape[0], dtype=int))
if batch_key is None
else adata.obs[batch_key].to_numpy()
else adata.obs[batch_key]
)

norm_gene_vars = []

adata_agg = AnnData(
X=data,
var=pd.DataFrame(index=adata.var_names),
obs=pd.DataFrame(
index=adata.obs_names, data={"__hvg_v3_batch_info__": batch_info}
),
)
if batch_key is not None:
aggregated_mean_var = aggregate(
adata_agg, by="__hvg_v3_batch_info__", func=["mean", "var"]
)
mean_global, var_global = (
aggregated_mean_var.layers[l] for l in ["mean", "var"]
)
if isinstance(mean_global, DaskArray):
import dask.array as da

mean_global, var_global = da.compute(mean_global, var_global)
aggregated_mean_var.layers["mean"] = mean_global
aggregated_mean_var.layers["var"] = var_global
Comment on lines +185 to +193

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems a bit verbose for what it is, don’t we have a helper for that or am I thinking f-a-u?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What aspect of it is verbose?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating the intermediates. I think I got confused searching for where they are used after, only to realize they aren’t. But maybe that’s just me.

Would this work or can they be non-ndarrays?

Suggested change
mean_global, var_global = (
aggregated_mean_var.layers[l] for l in ["mean", "var"]
)
if isinstance(mean_global, DaskArray):
import dask.array as da
mean_global, var_global = da.compute(mean_global, var_global)
aggregated_mean_var.layers["mean"] = mean_global
aggregated_mean_var.layers["var"] = var_global
aggregated_mean_var.layers["mean"], aggregated_mean_var.layers["var"] = materialize_as_ndarray(
*(aggregated_mean_var.layers[l] for l in ["mean", "var"])
)

else:
aggregated_mean_var = AnnData(
var=pd.DataFrame(index=adata.var_names),
obs=pd.DataFrame(
index=np.array(["one"]), data={"__hvg_v3_batch_info__": np.array([0])}
),
layers={
"mean": df["means"].to_numpy().reshape((1, -1)),
"var": df["variances"].to_numpy().reshape((1, -1)),
},
)
batch_info = batch_info.to_numpy()
for b in np.unique(batch_info):
data_batch = data[batch_info == b]

mean, var = stats.mean_var(data_batch, axis=0, correction=1)
# These get computed anyway for loess
if isinstance(mean, DaskArray):
mean, var = mean.compute(), var.compute()
mean, var = (
aggregated_mean_var[
aggregated_mean_var.obs["__hvg_v3_batch_info__"] == b
].layers[l]
for l in ["mean", "var"]
)
if isinstance(mean, CSBase):
mean = mean.toarray()
mean = mean.ravel()
if isinstance(var, CSBase):
var = var.toarray()
var = var.ravel()
estimat_var = np.zeros(data.shape[1], dtype=np.float64)
if (not_const := var > 0).any():
y = np.log10(var[not_const])
Expand All @@ -204,8 +238,13 @@ def _highly_variable_genes_seurat_v3( # noqa: PLR0912, PLR0915
+ squared_batch_counts_sum
- 2 * batch_counts_sum * mean
)
norm_gene_vars.append(norm_gene_var.reshape(1, -1))
norm_gene_vars.append(norm_gene_var)
if any(isinstance(e, DaskArray) for e in norm_gene_vars):
import dask.array as da

norm_gene_vars = da.compute(*norm_gene_vars)

norm_gene_vars = [ngv.reshape(1, -1) for ngv in norm_gene_vars]
norm_gene_vars = np.concatenate(norm_gene_vars, axis=0)
# argsort twice gives ranks, small rank means most variable
ranked_norm_gene_vars = np.argsort(np.argsort(-norm_gene_vars, axis=1), axis=1)
Expand Down
Loading