From f51a30f0aa9f7236e6051c1cef04427f8771cf7f Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 20 Apr 2026 17:20:16 +0200 Subject: [PATCH 01/15] WIP base without `X_` --- pyproject.toml | 3 +- src/scanpy/_settings/presets.py | 22 +++++++++- src/scanpy/experimental/pp/_normalization.py | 25 +++++++----- src/scanpy/experimental/pp/_recipes.py | 22 +++++----- src/scanpy/neighbors/__init__.py | 20 ++-------- src/scanpy/tools/_diffmap.py | 18 +++++++-- src/scanpy/tools/_dpt.py | 12 ++++-- src/scanpy/tools/_draw_graph.py | 42 +++++++++++++++----- 8 files changed, 109 insertions(+), 55 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b1e586eab7..8450b7984a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -272,7 +272,8 @@ filterwarnings = [ "ignore:FNV hashing is not implemented in Numba.*:UserWarning", # we want to see and eventually fix these "default::numba.core.errors.NumbaPerformanceWarning", - "default:.*TSNE.*random.*to.*pca:FutureWarning", # we should set init=obsm["X_pca"] or so + # we should set init=obsm["X_pca"] or so + "default:.*TSNE.*random.*to.*pca:FutureWarning", # matplotlib <3.11 uses old pyparsing APIs "ignore::pyparsing.warnings.PyparsingDeprecationWarning", # igraph vs leidenalg warning diff --git a/src/scanpy/_settings/presets.py b/src/scanpy/_settings/presets.py index bef0280b39..5b2d828fdd 100644 --- a/src/scanpy/_settings/presets.py +++ b/src/scanpy/_settings/presets.py @@ -74,10 +74,14 @@ class HVGPreset(NamedTuple): return_df: bool -class PcaPreset(NamedTuple): +class BasicEmbeddingPreset(NamedTuple): key_added: str | None +# replace once they diverge +PcaPreset = DiffmapPreset = DrawGraphPreset = BasicEmbeddingPreset + + class RankGenesGroupsPreset(NamedTuple): method: DETest mask_var: str | None @@ -181,6 +185,22 @@ def pca() -> Mapping[Preset, PcaPreset]: Preset.ScanpyV2Preview: PcaPreset(key_added="pca"), } + @preset_property + def diffmap() -> Mapping[Preset, DiffmapPreset]: + """Settings for :func:`~scanpy.tl.diffmap`.""" # noqa: D401 + return { + Preset.ScanpyV1: DiffmapPreset(key_added=None), + Preset.ScanpyV2Preview: DiffmapPreset(key_added="diffmap"), + } + + @preset_property + def draw_graph() -> Mapping[Preset, DrawGraphPreset]: + """Settings for :func:`~scanpy.tl.draw_graph`.""" # noqa: D401 + return { + Preset.ScanpyV1: DrawGraphPreset(key_added=None), + Preset.ScanpyV2Preview: DrawGraphPreset(key_added="graph_{layout}"), + } + @preset_property def rank_genes_groups() -> Mapping[Preset, RankGenesGroupsPreset]: """Correlation method for :func:`~scanpy.tl.rank_genes_groups`.""" diff --git a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index 551df22be4..142dd6620d 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -7,6 +7,7 @@ from anndata import AnnData from ... import logging as logg +from ... import settings from ..._compat import CSBase, warn from ..._settings import Default from ..._utils import _doc_params, check_nonnegative_integers, view_to_actual @@ -207,19 +208,23 @@ def normalize_pearson_residuals_pca( `.uns['pearson_residuals_normalization']['clip']` The used value of the clipping parameter. - `.obsm['X_pca']` + `.obsm[kwargs_pca.get('key_added', 'X_pca')]` PCA representation of data after gene selection (if applicable) and Pearson residual normalization. - `.varm['PCs']` + `.varm[kwargs_pca.get('key_added', 'PCs')]` The principal components containing the loadings. When `inplace=True` and `mask_var is not None`, this will contain empty rows for the genes not selected. - `.uns['pca']['variance_ratio']` + `.uns[kwargs_pca.get('key_added', 'pca')]['variance_ratio']` Ratio of explained variance. - `.uns['pca']['variance']` + `.uns[kwargs_pca.get('key_added', 'pca')]['variance']` Explained variance, equivalent to the eigenvalues of the covariance matrix. """ + key_added = kwargs_pca.get("key_added", settings.preset.pca.key_added) + key_obsm, key_varm, key_uns = ( + ("X_pca", "PCs", "pca") if key_added is None else [key_added] * 3 + ) if isinstance(mask_var, Default): mask_var = "highly_variable" if "highly_variable" in adata.var else None mask_var = _check_mask(adata, mask_var, "var") @@ -236,19 +241,19 @@ def normalize_pearson_residuals_pca( adata_pca, theta=theta, clip=clip, check_values=check_values ) pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca) - n_comps = adata_pca.obsm["X_pca"].shape[1] # might be None + n_comps = adata_pca.obsm[key_obsm].shape[1] # might be None if inplace: norm_settings = adata_pca.uns["pearson_residuals_normalization"] norm_dict = dict(**norm_settings, pearson_residuals_df=adata_pca.to_df()) if mask_var is not None: - adata.varm["PCs"] = np.zeros(shape=(adata.n_vars, n_comps)) - adata.varm["PCs"][mask_var] = adata_pca.varm["PCs"] + adata.varm[key_varm] = np.zeros(shape=(adata.n_vars, n_comps)) + adata.varm[key_varm][mask_var] = adata_pca.varm[key_varm] else: - adata.varm["PCs"] = adata_pca.varm["PCs"] - adata.uns["pca"] = adata_pca.uns["pca"] + adata.varm[key_varm] = adata_pca.varm[key_varm] + adata.uns[key_uns] = adata_pca.uns[key_uns] adata.uns["pearson_residuals_normalization"] = norm_dict - adata.obsm["X_pca"] = adata_pca.obsm["X_pca"] + adata.obsm[key_obsm] = adata_pca.obsm[key_obsm] return None else: return adata_pca diff --git a/src/scanpy/experimental/pp/_recipes.py b/src/scanpy/experimental/pp/_recipes.py index 61db8e8a5e..a1b5634b0e 100644 --- a/src/scanpy/experimental/pp/_recipes.py +++ b/src/scanpy/experimental/pp/_recipes.py @@ -5,7 +5,7 @@ import numpy as np -from ... import experimental +from ... import experimental, settings from ..._utils import _doc_params from ..._utils.random import _accepts_legacy_random_state from ...experimental._docs import ( @@ -103,18 +103,22 @@ def recipe_pearson_residuals( # noqa: PLR0913 `.uns['pearson_residuals_normalization']['clip']` The used value of the clipping parameter. - `.obsm['X_pca']` + `.obsm[kwargs_pca.get('key_added', 'X_pca')]` PCA representation of data after gene selection and Pearson residual normalization. - `.varm['PCs']` + `.varm[kwargs_pca.get('key_added', 'PCs')]` The principal components containing the loadings. When `inplace=True` this will contain empty rows for the genes not selected during HVG selection. - `.uns['pca']['variance_ratio']` + `.uns[kwargs_pca.get('key_added', 'pca')]['variance_ratio']` Ratio of explained variance. - `.uns['pca']['variance']` + `.uns[kwargs_pca.get('key_added', 'pca')]['variance']` Explained variance, equivalent to the eigenvalues of the covariance matrix. """ + key_added = kwargs_pca.get("key_added", settings.preset.pca.key_added) + key_obsm, key_varm, key_uns = ( + ("X_pca", "PCs", "pca") if key_added is None else [key_added] * 3 + ) hvg_args = dict( flavor="pearson_residuals", n_top_genes=n_top_genes, @@ -145,11 +149,11 @@ def recipe_pearson_residuals( # noqa: PLR0913 **normalization_param, pearson_residuals_df=adata_pca.to_df() ) - adata.uns["pca"] = adata_pca.uns["pca"] - adata.varm["PCs"] = np.zeros(shape=(adata.n_vars, n_comps)) - adata.varm["PCs"][adata.var["highly_variable"]] = adata_pca.varm["PCs"] + adata.uns[key_uns] = adata_pca.uns[key_uns] + adata.varm[key_varm] = np.zeros(shape=(adata.n_vars, n_comps)) + adata.varm[key_varm][adata.var["highly_variable"]] = adata_pca.varm[key_varm] adata.uns["pearson_residuals_normalization"] = normalization_dict - adata.obsm["X_pca"] = adata_pca.obsm["X_pca"] + adata.obsm[key_obsm] = adata_pca.obsm[key_obsm] return None else: return adata_pca, hvg diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 239822e3f2..9c584373fb 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -320,20 +320,6 @@ class FlatTree(NamedTuple): # noqa: D101 indices: None -def _backwards_compat_get_full_x_diffmap(adata: AnnData) -> np.ndarray: - if "X_diffmap0" in adata.obs: - return np.c_[adata.obs["X_diffmap0"].values[:, None], adata.obsm["X_diffmap"]] - else: - return adata.obsm["X_diffmap"] - - -def _backwards_compat_get_full_eval(adata: AnnData): - if "X_diffmap0" in adata.obs: - return np.r_[1, adata.uns["diffmap_evals"]] - else: - return adata.uns["diffmap_evals"] - - def _make_forest_dict(forest): d = {} props = ("hyperplanes", "offsets", "children", "indices") @@ -482,9 +468,9 @@ def count_nonzero(a: np.ndarray | CSRBase) -> int: self._connected_components = connected_components(self._connectivities) self._number_connected_components = self._connected_components[0] - if "X_diffmap" in adata.obsm: - self._eigen_values = _backwards_compat_get_full_eval(adata) - self._eigen_basis = _backwards_compat_get_full_x_diffmap(adata) + if dm := (adata.obsm.get("diffmap") or adata.obsm.get("X_diffmap")): + self._eigen_values = adata.uns["diffmap_evals"] + self._eigen_basis = dm if n_dcs is not None: if n_dcs > len(self._eigen_values): msg = ( diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 90f2976837..0f61348d5c 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -5,6 +5,7 @@ import numpy as np from .._docs import doc_rng +from .._settings import Default, settings from .._utils import _doc_params from .._utils.random import _accepts_legacy_random_state from ._dpt import _diffmap @@ -22,6 +23,7 @@ def diffmap( n_comps: int = 15, *, neighbors_key: str | None = None, + key_added: str | None | Default = Default(preset=("diffmap", "key_added")), rng: SeedLike | RNGLike | None = None, copy: bool = False, ) -> AnnData | None: @@ -55,6 +57,8 @@ def diffmap( .obsp[.uns[neighbors_key]['connectivities_key']] and .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances, respectively. + key_added + Control where the embedding and eigenvalues are stored. {rng} copy Return a copy instead of writing to adata. @@ -63,11 +67,11 @@ def diffmap( ------- Returns `None` if `copy=False`, else returns an `AnnData` object. Sets the following fields: - `adata.obsm['X_diffmap']` : :class:`numpy.ndarray` (dtype `float`) + `adata.obsm['X_diffmap' | key_added]` : :class:`numpy.ndarray` (dtype `float`) Diffusion map representation of data, which is the right eigen basis of the transition matrix with eigenvectors as columns. - `adata.uns['diffmap_evals']` : :class:`numpy.ndarray` (dtype `float`) + `adata.uns['diffmap_evals' | key_added]` : :class:`numpy.ndarray` (dtype `float`) Array of size (number of eigen vectors). Eigenvalues of transition matrix. @@ -82,6 +86,8 @@ def diffmap( rng = np.random.default_rng(rng) if neighbors_key is None: neighbors_key = "neighbors" + if isinstance(key_added, Default): + key_added = settings.preset.diffmap.key_added if neighbors_key not in adata.uns: msg = "You need to run `pp.neighbors` first to compute a neighborhood graph." @@ -90,5 +96,11 @@ def diffmap( msg = "Provide any value greater than 2 for `n_comps`. " raise ValueError(msg) adata = adata.copy() if copy else adata - _diffmap(adata, n_comps=n_comps, neighbors_key=neighbors_key, rng=rng) + _diffmap( + adata, + n_comps=n_comps, + neighbors_key=neighbors_key, + key_added=key_added, + rng=rng, + ) return adata if copy else None diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 74d7fe66ae..a53b984d60 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -22,21 +22,25 @@ def _diffmap( n_comps: int = 15, *, neighbors_key: str | None, + key_added: str | None, rng: np.random.Generator, ) -> None: + obsm_key, uns_key = ( + ("X_diffmap", "diffmap_evals") if key_added is None else ((key_added,) * 2) + ) start = logg.info(f"computing Diffusion Maps using {n_comps=}(=n_dcs)") dpt = DPT(adata, neighbors_key=neighbors_key) dpt.compute_transitions() dpt.compute_eigen(n_comps=n_comps, rng=rng) - adata.obsm["X_diffmap"] = dpt.eigen_basis - adata.uns["diffmap_evals"] = dpt.eigen_values + adata.obsm[obsm_key] = dpt.eigen_basis + adata.uns[uns_key] = dpt.eigen_values logg.info( " finished", time=start, deep=( "added\n" - " 'X_diffmap', diffmap coordinates (adata.obsm)\n" - " 'diffmap_evals', eigenvalues of transition matrix (adata.uns)" + f" {obsm_key!r}, diffmap coordinates (adata.obsm)\n" + f" {uns_key!r}, eigenvalues of transition matrix (adata.uns)" ), ) diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index bdbb37d0be..339668a00c 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -5,6 +5,9 @@ import numpy as np +from scanpy._compat import warn +from scanpy._settings import Default + from .. import _utils from .. import logging as logg from .._docs import doc_rng @@ -42,10 +45,12 @@ def draw_graph( # noqa: PLR0913 rng: SeedLike | RNGLike | None = None, n_jobs: int | None = None, adjacency: CSBase | None = None, - key_added_ext: str | None = None, + key_added: str | Default = Default(preset=("draw_graph", "key_added")), neighbors_key: str | None = None, obsp: str | None = None, copy: bool = False, + # deprecated + key_added_ext: str | None = None, **kwds, ) -> AnnData | None: """Force-directed graph drawing :cite:p:`Islam2011,Jacomy2014,Chippada2018`. @@ -86,10 +91,10 @@ def draw_graph( # noqa: PLR0913 Applies to layouts with random initialization like `'fr'`. adjacency Sparse adjacency matrix of the graph, defaults to neighbors connectivities. - key_added_ext - By default, append `layout`. + key_added + Template for the key. If `None`, use `'X_draw_graph_{layout}'` for `obsm` (replacing `'{layout}'` with the passed `layout`). proceed - Continue computation, starting off with 'X_draw_graph_`layout`'. + Continue computation, starting off with `f'X_draw_graph_{layout}'`. init_pos `'paga'`/`True`, `None`/`False`, or any valid 2d-`.obsm` key. Use precomputed coordinates for initialization. @@ -113,7 +118,7 @@ def draw_graph( # noqa: PLR0913 ------- Returns `None` if `copy=False`, else returns an `AnnData` object. Sets the following fields: - `adata.obsm['X_draw_graph_[layout | key_added_ext]']` : :class:`numpy.ndarray` (dtype `float`) + `adata.obsm[('X_draw_graph_{layout}' | key_added).format(layout=layout)]` : :class:`numpy.ndarray` (dtype `float`) Coordinates of graph layout. E.g. for `layout='fa'` (the default), the field is called `'X_draw_graph_fa'`. `key_added_ext` overwrites `layout`. `adata.uns['draw_graph']`: :class:`dict` @@ -121,6 +126,7 @@ def draw_graph( # noqa: PLR0913 """ start = logg.info(f"drawing single-cell graph using layout {layout!r}") + key_obsm, key_uns = _get_keys_added(key_added, layout, key_added_ext) rng = np.random.default_rng(rng) meta_random_state = ( dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {} @@ -161,18 +167,34 @@ def draw_graph( # noqa: PLR0913 else: ig_layout = g.layout(layout, **kwds) positions = np.array(ig_layout.coords) - adata.uns["draw_graph"] = {} - adata.uns["draw_graph"]["params"] = dict(layout=layout, **meta_random_state) - key_added = f"X_draw_graph_{key_added_ext or layout}" - adata.obsm[key_added] = positions + adata.uns[key_uns] = {} + adata.uns[key_uns]["params"] = dict(layout=layout, **meta_random_state) + adata.obsm[key_obsm] = positions logg.info( " finished", time=start, - deep=f"added\n {key_added!r}, graph_drawing coordinates (adata.obsm)", + deep="added" + f"\n {key_obsm!r}, draw_graph coordinates (adata.obsm)" + f"\n {key_uns!r}, draw_graph parameters (adata.uns)", ) return adata if copy else None +def _get_keys_added( + key_added: str | Default, layout: str, key_added_ext: str | None +) -> tuple[str, str]: + if key_added_ext is not None: + msg = "Passing `key_added_ext` is deprecated, use `key_added`’s template functionality instead." + warn(msg, category=FutureWarning) + suffix = key_added_ext + else: + suffix = layout + if isinstance(key_added, Default): + return f"X_draw_graph_{suffix}", "draw_graph" + key_added = key_added.format(layout=suffix) + return key_added, key_added + + def fa2_positions( adjacency: CSBase | np.ndarray, init_coords: np.ndarray, **kwds ) -> list[tuple[float, float]]: From b7f6bbaa7fa68db2fdd9a91c414cca44ef51e50b Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 11 May 2026 16:26:25 +0200 Subject: [PATCH 02/15] fix draw_graph docstring --- src/scanpy/tools/_draw_graph.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index 339668a00c..be5321e06d 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -92,20 +92,20 @@ def draw_graph( # noqa: PLR0913 adjacency Sparse adjacency matrix of the graph, defaults to neighbors connectivities. key_added - Template for the key. If `None`, use `'X_draw_graph_{layout}'` for `obsm` (replacing `'{layout}'` with the passed `layout`). + Template for the key. If `None`, uses `f'X_draw_graph_{{layout}}'` for `obsm`. proceed - Continue computation, starting off with `f'X_draw_graph_{layout}'`. + Continue computation, starting off with `f'X_draw_graph_{{layout}}'`. init_pos `'paga'`/`True`, `None`/`False`, or any valid 2d-`.obsm` key. Use precomputed coordinates for initialization. If `False`/`None` (the default), initialize randomly. neighbors_key - If not specified, draw_graph looks at .obsp['connectivities'] for connectivities + If not specified, draw_graph looks at `.obsp['connectivities']` for connectivities (default storage place for pp.neighbors). If specified, draw_graph looks at - .obsp[.uns[neighbors_key]['connectivities_key']] for connectivities. + `.obsp[.uns[neighbors_key]['connectivities_key']]` for connectivities. obsp - Use .obsp[obsp] as adjacency. You can't specify both + Use `.obsp[obsp]` as adjacency. You can't specify both `obsp` and `neighbors_key` at the same time. copy Return a copy instead of writing to adata. @@ -118,7 +118,7 @@ def draw_graph( # noqa: PLR0913 ------- Returns `None` if `copy=False`, else returns an `AnnData` object. Sets the following fields: - `adata.obsm[('X_draw_graph_{layout}' | key_added).format(layout=layout)]` : :class:`numpy.ndarray` (dtype `float`) + `adata.obsm[(f'X_draw_graph_{{layout}}' | key_added).format(layout=layout)]` : :class:`numpy.ndarray` (dtype `float`) Coordinates of graph layout. E.g. for `layout='fa'` (the default), the field is called `'X_draw_graph_fa'`. `key_added_ext` overwrites `layout`. `adata.uns['draw_graph']`: :class:`dict` From 9e0464aa439ff045a8afd95d0ae5e88eab99796d Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Mon, 11 May 2026 16:47:27 +0200 Subject: [PATCH 03/15] _get_pca_or_small_x --- src/scanpy/tools/_utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py index 4190907e27..8761c6d15d 100644 --- a/src/scanpy/tools/_utils.py +++ b/src/scanpy/tools/_utils.py @@ -19,8 +19,8 @@ def _choose_representation( adata: AnnData, *, - use_rep: str | None = None, - n_pcs: int | None = None, + use_rep: str | None, + n_pcs: int | None, silent: bool = False, ) -> np.ndarray | CSRBase: # TODO: what else? verbosity = settings.verbosity @@ -55,12 +55,12 @@ def _get_pca_or_small_x(adata: AnnData, n_pcs: int | None) -> np.ndarray | CSRBa logg.info(" using data matrix X directly") return adata.X - if "X_pca" in adata.obsm: - if n_pcs is not None and n_pcs > adata.obsm["X_pca"].shape[1]: - msg = "`X_pca` does not have enough PCs. Rerun `sc.pp.pca` with adjusted `n_comps`." + if key := next((b for b in ["X_pca", "pca"] if b in adata.obsm), None): + if n_pcs is not None and n_pcs > adata.obsm[key].shape[1]: + msg = f"adata.obsm[{key!r}] does not have enough PCs. Rerun `sc.pp.pca` with adjusted `n_comps`." raise ValueError(msg) - x = adata.obsm["X_pca"][:, :n_pcs] - logg.info(f" using 'X_pca' with n_pcs = {x.shape[1]}") + x = adata.obsm[key][:, :n_pcs] + logg.info(f" using {key!r} with n_pcs = {x.shape[1]}") return x from ..preprocessing import pca @@ -73,7 +73,7 @@ def _get_pca_or_small_x(adata: AnnData, n_pcs: int | None) -> np.ndarray | CSRBa warn(msg, UserWarning) n_pcs_pca = n_pcs if n_pcs is not None else settings.N_PCS pca(adata, n_comps=n_pcs_pca) - return adata.obsm["X_pca"] + return adata.obsm[settings.preset.pca.key_added or "X_pca"] def get_init_pos_from_paga( From 22f537900d74678e5d756ad366ec6c8c9fdce6a8 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 23 Jun 2026 15:31:48 +0200 Subject: [PATCH 04/15] add tests --- src/scanpy/tools/_draw_graph.py | 8 +++--- tests/test_embedding.py | 44 +++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index e4553bb70f..dcfe9cb3cf 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -6,7 +6,7 @@ import numpy as np from scanpy._compat import warn -from scanpy._settings import Default +from scanpy._settings import Default, settings from .. import _utils from .. import logging as logg @@ -45,7 +45,7 @@ def draw_graph( # noqa: PLR0913 rng: SeedLike | RNGLike | None = None, n_jobs: int | None = None, adjacency: CSBase | None = None, - key_added: str | Default = Default(preset=("draw_graph", "key_added")), + key_added: str | None | Default = Default(preset=("draw_graph", "key_added")), neighbors_key: str | None = None, obsp: str | None = None, copy: bool = False, @@ -181,7 +181,7 @@ def draw_graph( # noqa: PLR0913 def _get_keys_added( - key_added: str | Default, layout: str, key_added_ext: str | None + key_added: str | None | Default, layout: str, key_added_ext: str | None ) -> tuple[str, str]: if key_added_ext is not None: msg = "Passing `key_added_ext` is deprecated, use `key_added`’s template functionality instead." @@ -190,6 +190,8 @@ def _get_keys_added( else: suffix = layout if isinstance(key_added, Default): + key_added = settings.preset.draw_graph.key_added + if key_added is None: return f"X_draw_graph_{suffix}", "draw_graph" key_added = key_added.format(layout=suffix) return key_added, key_added diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 436a72db0e..0b5b09ff2e 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -7,6 +7,7 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_raises import scanpy as sc +from scanpy._settings import Default from testing.scanpy._helpers.data import pbmc68k_reduced from testing.scanpy._pytest.marks import needs @@ -109,3 +110,46 @@ def test_diffmap( assert_array_equal(d1, d2) with subtests.test("different embedding"): assert_raises(AssertionError, assert_array_equal, d1, d3) + + +@pytest.mark.parametrize( + ("key_added", "key_obsm", "key_uns"), + [ + pytest.param(None, "X_diffmap", "diffmap_evals", id="None"), + pytest.param("custom_key", "custom_key", "custom_key", id="custom_key"), + pytest.param(sc.Preset.ScanpyV1, "X_diffmap", "diffmap_evals", id="v1"), + pytest.param(sc.Preset.ScanpyV2Preview, "diffmap", "diffmap", id="v2"), + ], +) +def test_diffmap_key_added( + key_added: str | None | Default | sc.Preset, key_obsm: str, key_uns: str +) -> None: + pbmc = pbmc68k_reduced()[:300, :100].copy() + if isinstance(key_added, sc.Preset): + sc.settings.preset = key_added + key_added = Default() + adata = sc.tl.diffmap(pbmc, key_added=key_added, copy=True) + assert key_obsm in adata.obsm + assert key_uns in adata.uns + + +@needs.igraph +@pytest.mark.parametrize( + ("key_added", "key_obsm", "key_uns"), + [ + pytest.param(None, "X_draw_graph_fr", "draw_graph", id="None"), + pytest.param("custom_{layout}", "custom_fr", "custom_fr", id="custom_template"), + pytest.param(sc.Preset.ScanpyV1, "X_draw_graph_fr", "draw_graph", id="v1"), + pytest.param(sc.Preset.ScanpyV2Preview, "graph_fr", "graph_fr", id="v2"), + ], +) +def test_draw_graph_key_added( + key_added: str | None | Default | sc.Preset, key_obsm: str, key_uns: str +) -> None: + pbmc = pbmc68k_reduced()[:100, :100].copy() + if isinstance(key_added, sc.Preset): + sc.settings.preset = key_added + key_added = Default() + adata = sc.tl.draw_graph(pbmc, layout="fr", key_added=key_added, copy=True) + assert key_obsm in adata.obsm + assert key_uns in adata.uns From 0bfcbe19bd0a953b3ca551d1a856fb07969cbf41 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 23 Jun 2026 15:39:56 +0200 Subject: [PATCH 05/15] relnote --- docs/release-notes/4076.feat.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 docs/release-notes/4076.feat.md diff --git a/docs/release-notes/4076.feat.md b/docs/release-notes/4076.feat.md new file mode 100644 index 0000000000..9a4938e420 --- /dev/null +++ b/docs/release-notes/4076.feat.md @@ -0,0 +1 @@ +Add preset support to {func}`scanpy.tl.diffmap`’s and {func}`scanpy.tl.draw_graph`’s `key_added` parameter {smaller}`P Angerer` From a390628a46a9c0bf7015904e5e3c2e742657f3a8 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 23 Jun 2026 15:47:20 +0200 Subject: [PATCH 06/15] fix tests --- src/scanpy/neighbors/__init__.py | 4 +++- src/scanpy/tools/_dpt.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index da80150e88..3464ed21c6 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -470,7 +470,9 @@ def count_nonzero(a: np.ndarray | CSRBase) -> int: self._connected_components = connected_components(self._connectivities) self._number_connected_components = self._connected_components[0] - if dm := (adata.obsm.get("diffmap") or adata.obsm.get("X_diffmap")): + if ( + dm := (adata.obsm.get("diffmap") or adata.obsm.get("X_diffmap")) + ) is not None: self._eigen_values = adata.uns["diffmap_evals"] self._eigen_basis = dm if n_dcs is not None: diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index a53b984d60..893aaad259 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -151,7 +151,9 @@ def dpt( "Trying to run `tl.dpt` without prior call of `tl.diffmap`. " "Falling back to `tl.diffmap` with default parameters." ) - _diffmap(adata, neighbors_key=neighbors_key, rng=_LegacyRng(0)) + from ._diffmap import diffmap + + diffmap(adata, neighbors_key=neighbors_key, rng=_LegacyRng(0)) # start with the actual computation dpt = DPT( adata, From 900a4d7e4373f01a8b91a975290e9158f770fff2 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 23 Jun 2026 16:23:52 +0200 Subject: [PATCH 07/15] fix tests for min job --- tests/test_embedding.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 0b5b09ff2e..f1830aaedb 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -118,7 +118,11 @@ def test_diffmap( pytest.param(None, "X_diffmap", "diffmap_evals", id="None"), pytest.param("custom_key", "custom_key", "custom_key", id="custom_key"), pytest.param(sc.Preset.ScanpyV1, "X_diffmap", "diffmap_evals", id="v1"), - pytest.param(sc.Preset.ScanpyV2Preview, "diffmap", "diffmap", id="v2"), + pytest.param( + *(sc.Preset.ScanpyV2Preview, "diffmap", "diffmap"), + marks=[needs.igraph, needs.skmisc], + id="v2", + ), ], ) def test_diffmap_key_added( @@ -140,7 +144,11 @@ def test_diffmap_key_added( pytest.param(None, "X_draw_graph_fr", "draw_graph", id="None"), pytest.param("custom_{layout}", "custom_fr", "custom_fr", id="custom_template"), pytest.param(sc.Preset.ScanpyV1, "X_draw_graph_fr", "draw_graph", id="v1"), - pytest.param(sc.Preset.ScanpyV2Preview, "graph_fr", "graph_fr", id="v2"), + pytest.param( + *(sc.Preset.ScanpyV2Preview, "graph_fr", "graph_fr"), + marks=needs.skmisc, + id="v2", + ), ], ) def test_draw_graph_key_added( From e0773012a7e38f8ea35241eb403b2e10c9e3d108 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 23 Jun 2026 18:10:51 +0200 Subject: [PATCH 08/15] deal with defaults better --- src/scanpy/_utils/__init__.py | 12 ++++++++ src/scanpy/neighbors/__init__.py | 18 ++++++++---- src/scanpy/plotting/_tools/scatterplots.py | 6 ++-- src/scanpy/preprocessing/_pca/__init__.py | 22 +++++++++++---- src/scanpy/tools/_diffmap.py | 4 +-- src/scanpy/tools/_dpt.py | 33 +++++++++++++++++----- src/scanpy/tools/_utils.py | 14 ++++----- 7 files changed, 78 insertions(+), 31 deletions(-) diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index f0d67b550d..f7faa0a1fa 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -34,6 +34,7 @@ from .. import logging as logg from .._compat import CSBase, DaskArray, SpBase, _CSArray, pkg_version, warn +from .._settings import Preset from ._numba import _numba_thread_limit if TYPE_CHECKING: @@ -58,6 +59,7 @@ "NeighborsView", "_choose_graph", "_doc_params", + "_existing_preset_keys", "_numba_thread_limit", "_resolve_axis", "annotate_doc_types", @@ -573,6 +575,16 @@ def get_literal_vals(typ: UnionType | TypeAliasType | Any) -> KeysView[Any]: raise TypeError(msg) +def _existing_preset_keys[T: tuple[str, ...]]( + adata: AnnData, keys: Callable[..., T] +) -> T | None: + for preset in (Preset.ScanpyV1, Preset.ScanpyV2Preview): + obsm_key, *rest = keys(preset) + if obsm_key in adata.obsm: + return (obsm_key, *rest) + return None + + # -------------------------------------------------------------------------------- # Others # -------------------------------------------------------------------------------- diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 3464ed21c6..a9c838b949 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -18,7 +18,7 @@ from .._compat import CSBase, CSRBase, SpBase, pkg_version, warn from .._docs import doc_rng from .._settings import settings -from .._utils import NeighborsView, _doc_params, get_literal_vals +from .._utils import NeighborsView, _doc_params, _existing_preset_keys, get_literal_vals from .._utils.random import ( _accepts_legacy_random_state, _legacy_random_state, @@ -418,6 +418,7 @@ def __init__( # noqa: PLR0912, PLR0915 *, n_dcs: int | None = None, neighbors_key: str | None = None, + diffmap_key: str | None = None, ) -> None: self._adata = adata self._init_iroot() @@ -470,11 +471,16 @@ def count_nonzero(a: np.ndarray | CSRBase) -> int: self._connected_components = connected_components(self._connectivities) self._number_connected_components = self._connected_components[0] - if ( - dm := (adata.obsm.get("diffmap") or adata.obsm.get("X_diffmap")) - ) is not None: - self._eigen_values = adata.uns["diffmap_evals"] - self._eigen_basis = dm + + from ..tools._dpt import _diffmap_keys + + if keys := ( + (diffmap_key, diffmap_key) + if diffmap_key + else _existing_preset_keys(adata, _diffmap_keys) + ): + self._eigen_values = adata.uns[keys[1]] + self._eigen_basis = adata.obsm[keys[0]] if n_dcs is not None: if n_dcs > len(self._eigen_values): msg = ( diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index e18305140f..2fc04d190c 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -17,9 +17,11 @@ from matplotlib.markers import MarkerStyle from scverse_misc import Deprecation, deprecated +from scanpy.preprocessing._pca import _pca_keys + from ... import logging as logg from ..._settings import Default, settings -from ..._utils import _doc_params, sanitize_anndata +from ..._utils import _doc_params, _existing_preset_keys, sanitize_anndata from ..._utils._doctests import doctest_internet from ...get import _check_mask from .. import _utils @@ -925,7 +927,7 @@ def pca( return embedding( adata, "pca", show=show, return_fig=return_fig, save=save, **kwargs ) - if "pca" not in adata.obsm and "X_pca" not in adata.obsm: + if not _existing_preset_keys(adata, _pca_keys): msg = ( f"Could not find entry in `obsm` for 'pca'.\n" f"Available keys are: {list(adata.obsm.keys())}." diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 4f7a267e15..52741a2c3a 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -8,7 +8,7 @@ from ... import logging as logg from ..._compat import CSBase, DaskArray, warn from ..._docs import doc_rng -from ..._settings import Default, settings +from ..._settings import Default, Preset, settings from ..._utils import _doc_params, get_literal_vals, is_backed_type from ..._utils.random import _accepts_legacy_random_state, _legacy_random_state from ...get import _check_mask, _get_obs_rep @@ -47,6 +47,20 @@ type SvdSolver = SvdSolvDaskML | SvdSolvSkearn | SvdSolvPCACustom +def _pca_keys( + key_added: str | None | Default | Preset = Default(), +) -> tuple[str, str, str]: + if isinstance(key_added, Default): + key_added = settings.preset + if isinstance(key_added, Preset): + key_added = key_added.pca.key_added + return ( + ("X_pca", "PCs", "pca") + if key_added is None + else (key_added, key_added, key_added) + ) + + @_doc_params(mask_var=doc_mask_var, rng=doc_rng) @_accepts_legacy_random_state(0) def pca( # noqa: PLR0912, PLR0913, PLR0915 @@ -203,8 +217,6 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 # Current chunking implementation relies on pca being called on X msg = "Cannot use `layer`/`obsm` and `chunked` at the same time." raise NotImplementedError(msg) - if isinstance(key_added, Default): - key_added = settings.preset.pca.key_added # chunked calculation is not randomized, anyways if svd_solver in {"auto", "randomized"} and not chunked: @@ -337,9 +349,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 x_pca = x_pca.astype(dtype) if return_anndata: - key_obsm, key_varm, key_uns = ( - ("X_pca", "PCs", "pca") if key_added is None else [key_added] * 3 - ) + key_obsm, key_varm, key_uns = _pca_keys(key_added) adata.obsm[key_obsm] = x_pca if obsm: diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 0f61348d5c..0d03ee9bdc 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -5,7 +5,7 @@ import numpy as np from .._docs import doc_rng -from .._settings import Default, settings +from .._settings import Default from .._utils import _doc_params from .._utils.random import _accepts_legacy_random_state from ._dpt import _diffmap @@ -86,8 +86,6 @@ def diffmap( rng = np.random.default_rng(rng) if neighbors_key is None: neighbors_key = "neighbors" - if isinstance(key_added, Default): - key_added = settings.preset.diffmap.key_added if neighbors_key not in adata.uns: msg = "You need to run `pp.neighbors` first to compute a neighborhood graph." diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 893aaad259..f97cc98eef 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -7,8 +7,11 @@ import scipy as sp from natsort import natsorted +from scanpy._utils.random import _LegacyRng +from scanpy.tools._utils import _existing_preset_keys + from .. import logging as logg -from .._utils.random import _LegacyRng +from .._settings import Default, Preset, settings from ..neighbors import Neighbors, OnFlySymMatrix if TYPE_CHECKING: @@ -17,17 +20,25 @@ from anndata import AnnData +def _diffmap_keys(key_added: str | None | Default | Preset) -> tuple[str, str]: + if isinstance(key_added, Default): + key_added = settings.preset + if isinstance(key_added, Preset): + key_added = key_added.diffmap.key_added + return ( + ("X_diffmap", "diffmap_evals") if key_added is None else (key_added, key_added) + ) + + def _diffmap( adata: AnnData, n_comps: int = 15, *, neighbors_key: str | None, - key_added: str | None, + key_added: str | None | Default, rng: np.random.Generator, ) -> None: - obsm_key, uns_key = ( - ("X_diffmap", "diffmap_evals") if key_added is None else ((key_added,) * 2) - ) + obsm_key, uns_key = _diffmap_keys(key_added) start = logg.info(f"computing Diffusion Maps using {n_comps=}(=n_dcs)") dpt = DPT(adata, neighbors_key=neighbors_key) dpt.compute_transitions() @@ -53,6 +64,7 @@ def dpt( min_group_size: float = 0.01, allow_kendall_tau_shift: bool = True, neighbors_key: str | None = None, + diffmap_key: str | None = None, copy: bool = False, ) -> AnnData | None: """Infer progression of cells through geodesic distance along the graph :cite:p:`Haghverdi2016,Wolf2019`. @@ -110,6 +122,9 @@ def dpt( .obsp[.uns[neighbors_key]['connectivities_key']] and .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances, respectively. + diffmap_key + If specified, dpt looks in .obsm[diffmap_key] for diffmap coordinates, + otherwise int the default place. copy Copy instance before computation and return a copy. Otherwise, perform computation inplace and return `None`. @@ -146,7 +161,7 @@ def dpt( " adata.uns['iroot'] = root_cell_index\n" " adata.var['xroot'] = adata[root_cell_name, :].X" ) - if "X_diffmap" not in adata.obsm: + if not diffmap_key and not _existing_preset_keys(adata, _diffmap_keys): logg.warning( "Trying to run `tl.dpt` without prior call of `tl.diffmap`. " "Falling back to `tl.diffmap` with default parameters." @@ -162,6 +177,7 @@ def dpt( n_branchings=n_branchings, allow_kendall_tau_shift=allow_kendall_tau_shift, neighbors_key=neighbors_key, + diffmap_key=diffmap_key, ) start = logg.info(f"computing Diffusion Pseudotime using {n_dcs=}") if n_branchings > 1: @@ -222,8 +238,11 @@ def __init__( n_branchings: int = 0, allow_kendall_tau_shift: bool = False, neighbors_key: str | None = None, + diffmap_key: str | None = None, ): - super().__init__(adata, n_dcs=n_dcs, neighbors_key=neighbors_key) + super().__init__( + adata, n_dcs=n_dcs, neighbors_key=neighbors_key, diffmap_key=diffmap_key + ) self.flavor = "haghverdi16" self.n_branchings = n_branchings self.min_group_size = ( diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py index b85a654587..a50b1c945e 100644 --- a/src/scanpy/tools/_utils.py +++ b/src/scanpy/tools/_utils.py @@ -7,7 +7,7 @@ from .. import logging as logg from .._compat import warn from .._settings import settings -from .._utils import _choose_graph +from .._utils import _choose_graph, _existing_preset_keys if TYPE_CHECKING: from anndata import AnnData @@ -51,12 +51,14 @@ def _choose_representation( def _get_pca_or_small_x(adata: AnnData, n_pcs: int | None) -> np.ndarray | CSRBase: + from ..preprocessing._pca import _pca_keys, pca + if adata.n_vars <= settings.N_PCS: logg.info(" using data matrix X directly") return adata.X - pca_key = next((k for k in ("pca", "X_pca") if k in adata.obsm), None) - if pca_key is not None: + if keys := _existing_preset_keys(adata, _pca_keys): + pca_key, *_ = keys if n_pcs is not None and n_pcs > adata.obsm[pca_key].shape[1]: msg = f"`adata.obsm[{pca_key!r}]` does not have enough PCs. Rerun `sc.pp.pca` with adjusted `n_comps`." raise ValueError(msg) @@ -64,17 +66,15 @@ def _get_pca_or_small_x(adata: AnnData, n_pcs: int | None) -> np.ndarray | CSRBa logg.info(f" using {pca_key!r} with n_pcs = {x.shape[1]}") return x - from ..preprocessing import pca - msg = ( f"You’re trying to run this on {adata.n_vars} dimensions of `.X`, " - "if you really want this, set `use_rep='X'`.\n " + "if you really want this, set `use_rep=’X’`.\n " "Falling back to preprocessing with `sc.pp.pca` and default params." ) warn(msg, UserWarning) n_pcs_pca = n_pcs if n_pcs is not None else settings.N_PCS pca(adata, n_comps=n_pcs_pca) - return adata.obsm[settings.preset.pca.key_added or "X_pca"] + return adata.obsm[_pca_keys(adata)[0]] def get_init_pos_from_paga( From e6e9647386ad8eed937997cdf4f1a34bfb784a72 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 23 Jun 2026 18:11:51 +0200 Subject: [PATCH 09/15] fix circular dep --- src/scanpy/_utils/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index f7faa0a1fa..fbef37afa4 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -34,7 +34,6 @@ from .. import logging as logg from .._compat import CSBase, DaskArray, SpBase, _CSArray, pkg_version, warn -from .._settings import Preset from ._numba import _numba_thread_limit if TYPE_CHECKING: @@ -578,6 +577,8 @@ def get_literal_vals(typ: UnionType | TypeAliasType | Any) -> KeysView[Any]: def _existing_preset_keys[T: tuple[str, ...]]( adata: AnnData, keys: Callable[..., T] ) -> T | None: + from .._settings import Preset + for preset in (Preset.ScanpyV1, Preset.ScanpyV2Preview): obsm_key, *rest = keys(preset) if obsm_key in adata.obsm: From d38194cb510e99ea3ca90d687798bc9eed80974c Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 26 Jun 2026 09:00:42 +0200 Subject: [PATCH 10/15] fix _get_pca_or_small_x --- src/scanpy/tools/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py index a50b1c945e..b705cd0f3c 100644 --- a/src/scanpy/tools/_utils.py +++ b/src/scanpy/tools/_utils.py @@ -74,7 +74,7 @@ def _get_pca_or_small_x(adata: AnnData, n_pcs: int | None) -> np.ndarray | CSRBa warn(msg, UserWarning) n_pcs_pca = n_pcs if n_pcs is not None else settings.N_PCS pca(adata, n_comps=n_pcs_pca) - return adata.obsm[_pca_keys(adata)[0]] + return adata.obsm[_pca_keys()[0]] def get_init_pos_from_paga( From df597e363ab80224924615c8f434357378c54728 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 26 Jun 2026 10:30:04 +0200 Subject: [PATCH 11/15] f,ix draw _graph --- src/scanpy/_utils/__init__.py | 12 ++++++++---- src/scanpy/plotting/_tools/scatterplots.py | 11 +++++------ src/scanpy/tools/_draw_graph.py | 19 +++++++++++-------- 3 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index fbef37afa4..a8f29e38c4 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -39,7 +39,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable, KeysView, Mapping from pathlib import Path - from typing import Any + from typing import Any, Concatenate from anndata import AnnData from igraph import Graph @@ -47,6 +47,7 @@ from pandas._typing import Dtype as PdDtype from .._compat import CSRBase + from .._settings import Preset from ..neighbors import NeighborsParams, RPForestDict type _MemoryArray = NDArray | CSBase @@ -574,13 +575,16 @@ def get_literal_vals(typ: UnionType | TypeAliasType | Any) -> KeysView[Any]: raise TypeError(msg) -def _existing_preset_keys[T: tuple[str, ...]]( - adata: AnnData, keys: Callable[..., T] +def _existing_preset_keys[**P, T: tuple[str, ...]]( + adata: AnnData, + keys: Callable[Concatenate[Preset, P], T], + *args: P.args, + **kw: P.kwargs, ) -> T | None: from .._settings import Preset for preset in (Preset.ScanpyV1, Preset.ScanpyV2Preview): - obsm_key, *rest = keys(preset) + obsm_key, *rest = keys(preset, *args, **kw) if obsm_key in adata.obsm: return (obsm_key, *rest) return None diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index 2fc04d190c..8924d96918 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -18,6 +18,7 @@ from scverse_misc import Deprecation, deprecated from scanpy.preprocessing._pca import _pca_keys +from scanpy.tools._draw_graph import _draw_graph_keys from ... import logging as logg from ..._settings import Default, settings @@ -854,12 +855,10 @@ def draw_graph( """ if layout is None: layout = str(adata.uns["draw_graph"]["params"]["layout"]) - basis = f"draw_graph_{layout}" - if f"X_{basis}" not in adata.obsm: - msg = f"Did not find {basis} in adata.obs. Did you compute layout {layout}?" - raise ValueError(msg) - - return embedding(adata, basis, **kwargs) + if keys := _existing_preset_keys(adata, _draw_graph_keys, layout): + return embedding(adata, keys[0], **kwargs) + msg = f"Did not find `adata.obsm['draw_graph_{layout}']`. Did you compute layout {layout}?" + raise ValueError(msg) @_wraps_plot_scatter diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index dcfe9cb3cf..ace03145a0 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -5,12 +5,11 @@ import numpy as np -from scanpy._compat import warn -from scanpy._settings import Default, settings - from .. import _utils from .. import logging as logg +from .._compat import warn from .._docs import doc_rng +from .._settings import Default, Preset, settings from .._utils import _choose_graph, _doc_params, get_literal_vals from .._utils.random import ( _accepts_legacy_random_state, @@ -126,7 +125,8 @@ def draw_graph( # noqa: PLR0913 """ start = logg.info(f"drawing single-cell graph using layout {layout!r}") - key_obsm, key_uns = _get_keys_added(key_added, layout, key_added_ext) + layout = coerce_fa2_layout(layout) + key_obsm, key_uns = _draw_graph_keys(key_added, layout, key_added_ext) rng = np.random.default_rng(rng) meta_random_state = ( dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {} @@ -151,7 +151,6 @@ def draw_graph( # noqa: PLR0913 else: _if_legacy_apply_global(rng) init_coords = rng.random((adjacency.shape[0], 2)) - layout = coerce_fa2_layout(layout) # actual drawing if layout == "fa": positions = np.array(fa2_positions(adjacency, init_coords, **kwds)) @@ -180,8 +179,10 @@ def draw_graph( # noqa: PLR0913 return adata if copy else None -def _get_keys_added( - key_added: str | None | Default, layout: str, key_added_ext: str | None +def _draw_graph_keys( + key_added: str | None | Default | Preset, + layout: str, + key_added_ext: str | None = None, ) -> tuple[str, str]: if key_added_ext is not None: msg = "Passing `key_added_ext` is deprecated, use `key_added`’s template functionality instead." @@ -190,7 +191,9 @@ def _get_keys_added( else: suffix = layout if isinstance(key_added, Default): - key_added = settings.preset.draw_graph.key_added + key_added = settings.preset + if isinstance(key_added, Preset): + key_added = key_added.draw_graph.key_added if key_added is None: return f"X_draw_graph_{suffix}", "draw_graph" key_added = key_added.format(layout=suffix) From 10b74801aebf52f28063fa70aaf04d9ff7b77ae5 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 26 Jun 2026 10:44:40 +0200 Subject: [PATCH 12/15] ignore instead --- tests/test_sim.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/test_sim.py b/tests/test_sim.py index ac96966cb1..fb1fe9b8e7 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -6,9 +6,8 @@ import scanpy as sc +@pytest.mark.filterwarnings("ignore:.*Observation names are not unique:UserWarning") def test_sim_toggleswitch() -> None: - with pytest.warns(UserWarning, match=r"Observation names are not unique"): - adata_sim = sc.tl.sim("toggleswitch") - with pytest.warns(UserWarning, match=r"Observation names are not unique"): - adata_ds = sc.datasets.toggleswitch() + adata_sim = sc.tl.sim("toggleswitch") + adata_ds = sc.datasets.toggleswitch() np.allclose(adata_sim.X, adata_ds.X, np.finfo(np.float32).eps) From ca3126ff27f8f89d5793d933a024c5712111a0f7 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 26 Jun 2026 11:44:39 +0200 Subject: [PATCH 13/15] fix a few more --- src/scanpy/_settings/presets.py | 10 +++++++- src/scanpy/external/exporting.py | 16 +++++++++---- src/scanpy/external/tl/_trimap.py | 10 ++++---- src/scanpy/plotting/_tools/__init__.py | 3 ++- src/scanpy/plotting/_tools/paga.py | 4 ++-- src/scanpy/plotting/_tools/scatterplots.py | 28 ++++++++++++---------- src/scanpy/plotting/_utils.py | 12 ++++------ src/scanpy/tools/_ingest.py | 16 +++++++------ src/scanpy/tools/_umap.py | 16 ++++++++++--- 9 files changed, 72 insertions(+), 43 deletions(-) diff --git a/src/scanpy/_settings/presets.py b/src/scanpy/_settings/presets.py index 54f5f34d9d..bd6a38cb3a 100644 --- a/src/scanpy/_settings/presets.py +++ b/src/scanpy/_settings/presets.py @@ -81,7 +81,7 @@ class BasicEmbeddingPreset(NamedTuple): # replace once they diverge -PcaPreset = DiffmapPreset = DrawGraphPreset = BasicEmbeddingPreset +PcaPreset = UmapPreset = DiffmapPreset = DrawGraphPreset = BasicEmbeddingPreset class RankGenesGroupsPreset(NamedTuple): @@ -197,6 +197,14 @@ def pca() -> Mapping[Preset, PcaPreset]: Preset.ScanpyV2Preview: PcaPreset(key_added="pca"), } + @preset_property + def umap() -> Mapping[Preset, UmapPreset]: + """Settings for :func:`~scanpy.tl.umap`.""" # noqa: D401 + return { + Preset.ScanpyV1: DiffmapPreset(key_added=None), + Preset.ScanpyV2Preview: DiffmapPreset(key_added="umap"), + } + @preset_property def diffmap() -> Mapping[Preset, DiffmapPreset]: """Settings for :func:`~scanpy.tl.diffmap`.""" # noqa: D401 diff --git a/src/scanpy/external/exporting.py b/src/scanpy/external/exporting.py index 55d39eb2b2..076153b822 100644 --- a/src/scanpy/external/exporting.py +++ b/src/scanpy/external/exporting.py @@ -14,7 +14,9 @@ from fast_array_utils.stats import mean_var from pandas.api.types import CategoricalDtype -from .._utils import NeighborsView +from .._utils import NeighborsView, _existing_preset_keys +from ..preprocessing._pca import _pca_keys +from ..tools._draw_graph import _draw_graph_keys if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -81,8 +83,12 @@ def spring_project( # noqa: PLR0912, PLR0915 if embedding_method not in adata.obsm: if f"X_{embedding_method}" in adata.obsm: embedding_method = f"X_{embedding_method}" - elif embedding_method in adata.uns: - embedding_method = f"X_{embedding_method}_{adata.uns[embedding_method]['params']['layout']}" + elif embedding_method in {"graph", "draw_graph"} and ( + keys := _existing_preset_keys( + adata, _draw_graph_keys, adata.uns[embedding_method]["params"]["layout"] + ) + ): + embedding_method = keys[0] else: msg = f"Run the specified embedding method `{embedding_method}` first." raise ValueError(msg) @@ -222,10 +228,10 @@ def spring_project( # noqa: PLR0912, PLR0915 ) # Write some useful intermediates, if they exist - if "X_pca" in adata.obsm: + if keys := _existing_preset_keys(adata, _pca_keys): np.savez_compressed( subplot_dir / "intermediates.npz", - Epca=adata.obsm["X_pca"], + Epca=adata.obsm[keys[0]], total_counts=total_counts, ) diff --git a/src/scanpy/external/tl/_trimap.py b/src/scanpy/external/tl/_trimap.py index 8ccd08f1f3..922c792a8d 100644 --- a/src/scanpy/external/tl/_trimap.py +++ b/src/scanpy/external/tl/_trimap.py @@ -7,7 +7,9 @@ from ... import logging as logg from ..._compat import CSBase from ..._settings import settings +from ..._utils import _existing_preset_keys from ..._utils._doctests import doctest_needs +from ...preprocessing._pca import _pca_keys if TYPE_CHECKING: from typing import Literal @@ -100,9 +102,9 @@ def trimap( # noqa: PLR0913 verbosity = settings.verbosity if verbose is None else verbose verbose = verbosity if isinstance(verbosity, bool) else verbosity > 0 - if "X_pca" in adata.obsm: - n_dim_pca = adata.obsm["X_pca"].shape[1] - x = adata.obsm["X_pca"][:, : min(n_dim_pca, 100)] + if keys := _existing_preset_keys(adata, _pca_keys): + n_dim_pca = adata.obsm[keys[0]].shape[1] + x = adata.obsm[keys[0]][:, : min(n_dim_pca, 100)] else: x = adata.X if isinstance(x, CSBase): @@ -111,7 +113,7 @@ def trimap( # noqa: PLR0913 "use a dense matrix or apply pca first." ) raise ValueError(msg) - logg.warning("`X_pca` not found. Run `sc.pp.pca` first for speedup.") + logg.warning("`pca`/`X_pca` not found. Run `sc.pp.pca` first for speedup.") x_trimap = TRIMAP( n_dims=n_components, n_inliers=n_inliers, diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py index e41f4ea908..3052589b56 100644 --- a/src/scanpy/plotting/_tools/__init__.py +++ b/src/scanpy/plotting/_tools/__init__.py @@ -29,6 +29,7 @@ ) from .._utils import ( _deprecated_scale, + _get_basis, savefig_or_show, timeseries, timeseries_as_heatmap, @@ -1495,7 +1496,7 @@ def embedding_density( # noqa: PLR0912, PLR0913, PLR0915 if groupby is not None: key += f"_{groupby}" - if f"X_{basis}" not in adata.obsm: + if _get_basis(adata, basis) is None: msg = ( f"Cannot find the embedded representation `adata.obsm['X_{basis}']`. " "Compute the embedding first." diff --git a/src/scanpy/plotting/_tools/paga.py b/src/scanpy/plotting/_tools/paga.py index d5100f1338..f8cfb6bd25 100644 --- a/src/scanpy/plotting/_tools/paga.py +++ b/src/scanpy/plotting/_tools/paga.py @@ -129,7 +129,7 @@ def paga_compare( # noqa: PLR0912, PLR0913 else: basis = "umap" - from .scatterplots import _components_to_dimensions, _get_basis, embedding + from .scatterplots import _components_to_dimensions, _get_basis_arr, embedding embedding( adata, @@ -156,7 +156,7 @@ def paga_compare( # noqa: PLR0912, PLR0913 if pos is None: if color == adata.uns["paga"]["groups"]: # TODO: Use dimensions here - _basis = _get_basis(adata, basis) + _basis = _get_basis_arr(adata, basis) dims = _components_to_dimensions( components=components, dimensions=None, total_dims=_basis.shape[1] )[0] diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index 8924d96918..007f112283 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -17,14 +17,13 @@ from matplotlib.markers import MarkerStyle from scverse_misc import Deprecation, deprecated -from scanpy.preprocessing._pca import _pca_keys -from scanpy.tools._draw_graph import _draw_graph_keys - from ... import logging as logg from ..._settings import Default, settings from ..._utils import _doc_params, _existing_preset_keys, sanitize_anndata from ..._utils._doctests import doctest_internet from ...get import _check_mask +from ...preprocessing._pca import _pca_keys +from ...tools._draw_graph import _draw_graph_keys from .. import _utils from .._docs import ( doc_adata_color_etc, @@ -33,7 +32,13 @@ doc_scatter_spatial, doc_show_save_ax, ) -from .._utils import _obs_vector_compat, check_colornorm, check_projection, circles +from .._utils import ( + _get_basis, + _obs_vector_compat, + check_colornorm, + check_projection, + circles, +) if TYPE_CHECKING: from collections.abc import Callable, Collection, Mapping @@ -150,7 +155,7 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915 check_projection(projection) sanitize_anndata(adata) - basis_values = _get_basis(adata, basis) + basis_values = _get_basis_arr(adata, basis) dimensions = _components_to_dimensions( components, dimensions, projection=projection, total_dims=basis_values.shape[1] ) @@ -1203,15 +1208,12 @@ def _add_categorical_legend( # noqa: PLR0913 ax.legend(loc=legend_loc, fontsize=legend_fontsize) -def _get_basis(adata: AnnData, basis: str) -> np.ndarray: +def _get_basis_arr(adata: AnnData, basis: str) -> np.ndarray: """Get array for basis from anndata. Just tries to add 'X_'.""" - if basis in adata.obsm: - return adata.obsm[basis] - elif f"X_{basis}" in adata.obsm: - return adata.obsm[f"X_{basis}"] - else: - msg = f"Could not find {basis!r} or 'X_{basis}' in .obsm" - raise KeyError(msg) + if basis_key := _get_basis(adata, basis): + return adata.obsm[basis_key] + msg = f"Could not find {basis!r} or 'X_{basis}' in .obsm" + raise KeyError(msg) def _get_color_source_vector( diff --git a/src/scanpy/plotting/_utils.py b/src/scanpy/plotting/_utils.py index 23cad1c111..1c4aba91ee 100644 --- a/src/scanpy/plotting/_utils.py +++ b/src/scanpy/plotting/_utils.py @@ -1027,14 +1027,12 @@ def fix_kwds(kwds_dict, **kwargs): return kwargs -def _get_basis(adata: AnnData, basis: str): +def _get_basis(adata: AnnData, basis: str) -> str | None: if basis in adata.obsm: - basis_key = basis - - elif f"X_{basis}" in adata.obsm: - basis_key = f"X_{basis}" - - return basis_key + return basis + if f"X_{basis}" in adata.obsm: + return f"X_{basis}" + return None def check_colornorm(vmin=None, vmax=None, vcenter=None, norm=None): diff --git a/src/scanpy/tools/_ingest.py b/src/scanpy/tools/_ingest.py index d85d6b89bb..3b9a035d8c 100644 --- a/src/scanpy/tools/_ingest.py +++ b/src/scanpy/tools/_ingest.py @@ -14,11 +14,13 @@ from .._utils import ( NeighborsView, _doc_params, + _existing_preset_keys, raise_not_implemented_error_if_backed_type, ) from .._utils._doctests import doctest_skipif from .._utils.random import _legacy_random_state, _LegacyRng from ..neighbors import FlatTree +from ..tools._umap import _umap_keys if TYPE_CHECKING: from collections.abc import Generator, Iterable @@ -248,10 +250,10 @@ def _get_rng(self, params: dict[str, object]) -> np.random.Generator | None: random_state = params.get("random_state", 0) return _LegacyRng(random_state) - def _init_umap(self, adata: AnnData) -> None: + def _init_umap(self, adata: AnnData, obsm_key: str, uns_key: str) -> None: from umap import UMAP - rng = self._get_rng(adata.uns["umap"]["params"]) + rng = self._get_rng(adata.uns[uns_key]["params"]) self._umap = UMAP( metric=self._metric, random_state=_legacy_random_state(rng), @@ -265,7 +267,7 @@ def _init_umap(self, adata: AnnData) -> None: self._umap._validate_parameters() - self._umap.embedding_ = adata.obsm["X_umap"] + self._umap.embedding_ = adata.obsm[obsm_key] self._umap._sparse_data = isinstance(self._rep, CSBase) self._umap._small_data = self._rep.shape[0] < 4096 self._umap._metric_kwds = self._metric_kwds @@ -275,8 +277,8 @@ def _init_umap(self, adata: AnnData) -> None: self._umap._knn_search_index = self._nnd_idx - self._umap._a = adata.uns["umap"]["params"]["a"] - self._umap._b = adata.uns["umap"]["params"]["b"] + self._umap._a = adata.uns[uns_key]["params"]["a"] + self._umap._b = adata.uns[uns_key]["params"]["b"] self._umap._input_hash = None @@ -388,8 +390,8 @@ def __init__( ) raise ValueError(msg) - if "X_umap" in adata.obsm: - self._init_umap(adata) + if keys := _existing_preset_keys(adata, _umap_keys): + self._init_umap(adata, *keys) self._obsm = None self._obs = None diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index 316c6c8aed..cd17ebd8be 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -7,7 +7,7 @@ from .. import logging as logg from .._docs import doc_rng -from .._settings import settings +from .._settings import Default, Preset, settings from .._utils import NeighborsView, _doc_params from .._utils.random import ( _accepts_legacy_random_state, @@ -27,6 +27,16 @@ type _InitPos = Literal["paga", "spectral", "random"] +def _umap_keys( + key_added: str | None | Default | Preset = Default(), +) -> tuple[str, str]: + if isinstance(key_added, Default): + key_added = settings.preset + if isinstance(key_added, Preset): + key_added = key_added.umap.key_added + return ("X_umap", "umap") if key_added is None else (key_added, key_added) + + @_accepts_legacy_random_state(0) @_doc_params(rng=doc_rng) def umap( # noqa: PLR0913 @@ -44,7 +54,7 @@ def umap( # noqa: PLR0913 a: float | None = None, b: float | None = None, method: Literal["umap"] = "umap", - key_added: str | None = None, + key_added: str | None | Default = Default(preset=("umap", "key_added")), neighbors_key: str = "neighbors", copy: bool = False, ) -> AnnData | None: @@ -144,7 +154,7 @@ def umap( # noqa: PLR0913 rng = np.random.default_rng(rng) adata = adata.copy() if copy else adata - key_obsm, key_uns = ("X_umap", "umap") if key_added is None else [key_added] * 2 + key_obsm, key_uns = _umap_keys(key_added) if neighbors_key is None: # backwards compat neighbors_key = "neighbors" From 75939dfaa4bceb62c1f54790641147be1a152f5e Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 26 Jun 2026 11:58:15 +0200 Subject: [PATCH 14/15] refactor basis key --- src/scanpy/_utils/__init__.py | 9 +++++++++ src/scanpy/plotting/_tools/__init__.py | 7 +++---- src/scanpy/plotting/_tools/scatterplots.py | 17 ++++++++--------- src/scanpy/plotting/_utils.py | 14 +++----------- src/scanpy/tools/_embedding_density.py | 4 ++-- src/scanpy/tools/_ingest.py | 11 ++++++----- 6 files changed, 31 insertions(+), 31 deletions(-) diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index a8f29e38c4..68033aec71 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -60,6 +60,7 @@ "_choose_graph", "_doc_params", "_existing_preset_keys", + "_get_basis", "_numba_thread_limit", "_resolve_axis", "annotate_doc_types", @@ -575,6 +576,14 @@ def get_literal_vals(typ: UnionType | TypeAliasType | Any) -> KeysView[Any]: raise TypeError(msg) +def _get_basis_key(adata: AnnData, basis: str) -> str | None: + if basis in adata.obsm: + return basis + if f"X_{basis}" in adata.obsm: + return f"X_{basis}" + return None + + def _existing_preset_keys[**P, T: tuple[str, ...]]( adata: AnnData, keys: Callable[Concatenate[Preset, P], T], diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py index 3052589b56..bba09e85ce 100644 --- a/src/scanpy/plotting/_tools/__init__.py +++ b/src/scanpy/plotting/_tools/__init__.py @@ -13,7 +13,7 @@ from ... import logging as logg from ..._settings import Default, settings -from ..._utils import _doc_params, sanitize_anndata, with_cat_dtype +from ..._utils import _doc_params, _get_basis_key, sanitize_anndata, with_cat_dtype from ..._utils.random import _LegacyRng from ...get import obs_df, rank_genes_groups_df from .._anndata import ranking @@ -29,7 +29,6 @@ ) from .._utils import ( _deprecated_scale, - _get_basis, savefig_or_show, timeseries, timeseries_as_heatmap, @@ -1496,9 +1495,9 @@ def embedding_density( # noqa: PLR0912, PLR0913, PLR0915 if groupby is not None: key += f"_{groupby}" - if _get_basis(adata, basis) is None: + if _get_basis_key(adata, basis) is None: msg = ( - f"Cannot find the embedded representation `adata.obsm['X_{basis}']`. " + f"Cannot find the embedded representation `adata.obsm[{basis!r} | 'X_{basis}']`. " "Compute the embedding first." ) raise ValueError(msg) diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index 007f112283..a7a83a517e 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -19,7 +19,12 @@ from ... import logging as logg from ..._settings import Default, settings -from ..._utils import _doc_params, _existing_preset_keys, sanitize_anndata +from ..._utils import ( + _doc_params, + _existing_preset_keys, + _get_basis_key, + sanitize_anndata, +) from ..._utils._doctests import doctest_internet from ...get import _check_mask from ...preprocessing._pca import _pca_keys @@ -32,13 +37,7 @@ doc_scatter_spatial, doc_show_save_ax, ) -from .._utils import ( - _get_basis, - _obs_vector_compat, - check_colornorm, - check_projection, - circles, -) +from .._utils import _obs_vector_compat, check_colornorm, check_projection, circles if TYPE_CHECKING: from collections.abc import Callable, Collection, Mapping @@ -1210,7 +1209,7 @@ def _add_categorical_legend( # noqa: PLR0913 def _get_basis_arr(adata: AnnData, basis: str) -> np.ndarray: """Get array for basis from anndata. Just tries to add 'X_'.""" - if basis_key := _get_basis(adata, basis): + if basis_key := _get_basis_key(adata, basis): return adata.obsm[basis_key] msg = f"Could not find {basis!r} or 'X_{basis}' in .obsm" raise KeyError(msg) diff --git a/src/scanpy/plotting/_utils.py b/src/scanpy/plotting/_utils.py index 1c4aba91ee..c43c3fea97 100644 --- a/src/scanpy/plotting/_utils.py +++ b/src/scanpy/plotting/_utils.py @@ -18,7 +18,7 @@ from .. import logging as logg from .._compat import warn from .._settings import Default, settings -from .._utils import NeighborsView +from .._utils import NeighborsView, _get_basis_key from . import palettes if TYPE_CHECKING: @@ -567,7 +567,7 @@ def plot_edges(axs, adata, basis, edges_width, edges_color, *, neighbors_key=Non raise ValueError(msg) neighbors = NeighborsView(adata, neighbors_key) g = nx.Graph(neighbors["connectivities"]) - basis_key = _get_basis(adata, basis) + basis_key = _get_basis_key(adata, basis) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -602,7 +602,7 @@ def plot_arrows(axs, adata, basis, arrows_kwds=None): "Prefer using `scv.pl.velocity_embedding` to `arrows=True`." ) - basis_key = _get_basis(adata, basis) + basis_key = _get_basis_key(adata, basis) x = adata.obsm[basis_key] v = adata.obsm[f"{v_prefix}_{basis}"] for ax in axs: @@ -1027,14 +1027,6 @@ def fix_kwds(kwds_dict, **kwargs): return kwargs -def _get_basis(adata: AnnData, basis: str) -> str | None: - if basis in adata.obsm: - return basis - if f"X_{basis}" in adata.obsm: - return f"X_{basis}" - return None - - def check_colornorm(vmin=None, vmax=None, vcenter=None, norm=None): from matplotlib.colors import Normalize diff --git a/src/scanpy/tools/_embedding_density.py b/src/scanpy/tools/_embedding_density.py index 37e2bcf606..ff823becb4 100644 --- a/src/scanpy/tools/_embedding_density.py +++ b/src/scanpy/tools/_embedding_density.py @@ -7,7 +7,7 @@ import numpy as np from .. import logging as logg -from .._utils import sanitize_anndata +from .._utils import _get_basis_key, sanitize_anndata if TYPE_CHECKING: from collections.abc import Sequence @@ -125,7 +125,7 @@ def embedding_density( # noqa: PLR0912 if basis == "fa": basis = "draw_graph_fa" - if f"X_{basis}" not in adata.obsm: + if _get_basis_key(adata, basis) is None: msg = ( "Cannot find the embedded representation " f"`adata.obsm['X_{basis}']`. Compute the embedding first." diff --git a/src/scanpy/tools/_ingest.py b/src/scanpy/tools/_ingest.py index 3b9a035d8c..f231093f1e 100644 --- a/src/scanpy/tools/_ingest.py +++ b/src/scanpy/tools/_ingest.py @@ -20,6 +20,7 @@ from .._utils._doctests import doctest_skipif from .._utils.random import _legacy_random_state, _LegacyRng from ..neighbors import FlatTree +from ..preprocessing._pca import _pca_keys from ..tools._umap import _umap_keys if TYPE_CHECKING: @@ -556,10 +557,10 @@ def to_adata_joint( self._obsm["rep"], )) - if "X_umap" in self._obsm: - adata.uns["umap"] = self._adata_ref.uns["umap"] - if "X_pca" in self._obsm: - adata.uns["pca"] = self._adata_ref.uns["pca"] - adata.varm["PCs"] = self._adata_ref.varm["PCs"] + if keys := _existing_preset_keys(self._adata_ref, _umap_keys): + adata.uns[keys[1]] = self._adata_ref.uns[keys[1]] + if keys := _existing_preset_keys(self._adata_ref, _pca_keys): + adata.varm[keys[1]] = self._adata_ref.varm[keys[1]] + adata.uns[keys[2]] = self._adata_ref.uns[keys[2]] return adata From 8dee00eedb4b5790d01a35b8003b0280604c99e3 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Fri, 26 Jun 2026 12:06:22 +0200 Subject: [PATCH 15/15] rest --- src/scanpy/_settings/presets.py | 12 +++++++++++- src/scanpy/experimental/pp/_normalization.py | 6 ++---- src/scanpy/experimental/pp/_recipes.py | 5 ++--- src/scanpy/tools/_tsne.py | 14 +++++++++++--- 4 files changed, 26 insertions(+), 11 deletions(-) diff --git a/src/scanpy/_settings/presets.py b/src/scanpy/_settings/presets.py index bd6a38cb3a..c82d256147 100644 --- a/src/scanpy/_settings/presets.py +++ b/src/scanpy/_settings/presets.py @@ -81,7 +81,9 @@ class BasicEmbeddingPreset(NamedTuple): # replace once they diverge -PcaPreset = UmapPreset = DiffmapPreset = DrawGraphPreset = BasicEmbeddingPreset +PcaPreset = UmapPreset = TsnePreset = DiffmapPreset = DrawGraphPreset = ( + BasicEmbeddingPreset +) class RankGenesGroupsPreset(NamedTuple): @@ -205,6 +207,14 @@ def umap() -> Mapping[Preset, UmapPreset]: Preset.ScanpyV2Preview: DiffmapPreset(key_added="umap"), } + @preset_property + def tsne() -> Mapping[Preset, UmapPreset]: + """Settings for :func:`~scanpy.tl.tsne`.""" # noqa: D401 + return { + Preset.ScanpyV1: DiffmapPreset(key_added=None), + Preset.ScanpyV2Preview: DiffmapPreset(key_added="tsne"), + } + @preset_property def diffmap() -> Mapping[Preset, DiffmapPreset]: """Settings for :func:`~scanpy.tl.diffmap`.""" # noqa: D401 diff --git a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index 142dd6620d..49c516eaba 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -23,7 +23,7 @@ ) from ...get import _check_mask, _get_obs_rep, _set_obs_rep from ...preprocessing._docs import doc_mask_var -from ...preprocessing._pca import pca +from ...preprocessing._pca import _pca_keys, pca if TYPE_CHECKING: from collections.abc import Mapping @@ -222,9 +222,7 @@ def normalize_pearson_residuals_pca( """ key_added = kwargs_pca.get("key_added", settings.preset.pca.key_added) - key_obsm, key_varm, key_uns = ( - ("X_pca", "PCs", "pca") if key_added is None else [key_added] * 3 - ) + key_obsm, key_varm, key_uns = _pca_keys(key_added) if isinstance(mask_var, Default): mask_var = "highly_variable" if "highly_variable" in adata.var else None mask_var = _check_mask(adata, mask_var, "var") diff --git a/src/scanpy/experimental/pp/_recipes.py b/src/scanpy/experimental/pp/_recipes.py index a1b5634b0e..f0f63d0069 100644 --- a/src/scanpy/experimental/pp/_recipes.py +++ b/src/scanpy/experimental/pp/_recipes.py @@ -17,6 +17,7 @@ doc_pca_chunk, ) from ...preprocessing import pca +from ...preprocessing._pca import _pca_keys if TYPE_CHECKING: from collections.abc import Mapping @@ -116,9 +117,7 @@ def recipe_pearson_residuals( # noqa: PLR0913 """ key_added = kwargs_pca.get("key_added", settings.preset.pca.key_added) - key_obsm, key_varm, key_uns = ( - ("X_pca", "PCs", "pca") if key_added is None else [key_added] * 3 - ) + key_obsm, key_varm, key_uns = _pca_keys(key_added) hvg_args = dict( flavor="pearson_residuals", n_top_genes=n_top_genes, diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index 54fb28f436..81fcfe314e 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -5,7 +5,7 @@ from .. import logging as logg from .._compat import warn from .._docs import doc_rng -from .._settings import settings +from .._settings import Default, Preset, settings from .._utils import _doc_params, raise_not_implemented_error_if_backed_type from .._utils.random import _accepts_legacy_random_state, _legacy_random_state from ..neighbors._doc import doc_n_pcs, doc_use_rep @@ -17,6 +17,14 @@ from .._utils.random import RNGLike, SeedLike +def _tsne_keys(key_added: str | None | Default | Preset) -> tuple[str, str]: + if isinstance(key_added, Default): + key_added = settings.preset + if isinstance(key_added, Preset): + key_added = key_added.tsne.key_added + return ("X_tsne", "tsne") if key_added is None else (key_added, key_added) + + @_accepts_legacy_random_state(0) @_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep, rng=doc_rng) def tsne( # noqa: PLR0913 @@ -32,7 +40,7 @@ def tsne( # noqa: PLR0913 rng: SeedLike | RNGLike | None = None, use_fast_tsne: bool = False, n_jobs: int | None = None, - key_added: str | None = None, + key_added: str | None | Default = Default(preset=("tsne", "key_added")), copy: bool = False, ) -> AnnData | None: r"""t-SNE :cite:p:`vanDerMaaten2008,Amir2013,Pedregosa2011`. @@ -102,6 +110,7 @@ def tsne( # noqa: PLR0913 """ start = logg.info("computing tSNE") + key_obsm, key_uns = _tsne_keys(key_added) adata = adata.copy() if copy else adata x = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) raise_not_implemented_error_if_backed_type(x, "tsne") @@ -157,7 +166,6 @@ def tsne( # noqa: PLR0913 use_rep=use_rep, n_components=n_components, ) - key_uns, key_obsm = ("tsne", "X_tsne") if key_added is None else [key_added] * 2 adata.obsm[key_obsm] = x_tsne # annotate samples with tSNE coordinates adata.uns[key_uns] = dict(params={k: v for k, v in params.items() if v is not None})