diff --git a/docs/release-notes/4076.feat.md b/docs/release-notes/4076.feat.md new file mode 100644 index 0000000000..9a4938e420 --- /dev/null +++ b/docs/release-notes/4076.feat.md @@ -0,0 +1 @@ +Add preset support to {func}`scanpy.tl.diffmap`’s and {func}`scanpy.tl.draw_graph`’s `key_added` parameter {smaller}`P Angerer` diff --git a/pyproject.toml b/pyproject.toml index 16cf136060..bfcfd5555a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -273,7 +273,8 @@ filterwarnings = [ "ignore:FNV hashing is not implemented in Numba.*:UserWarning", # we want to see and eventually fix these "default::numba.core.errors.NumbaPerformanceWarning", - "default:.*TSNE.*random.*to.*pca:FutureWarning", # we should set init=obsm["X_pca"] or so + # we should set init=obsm["X_pca"] or so + "default:.*TSNE.*random.*to.*pca:FutureWarning", # igraph vs leidenalg warning "ignore:The `igraph` implementation of leiden clustering:UserWarning", # everybody uses this zarr 3 feature, including us, XArray, lots of data out there … diff --git a/src/scanpy/_settings/presets.py b/src/scanpy/_settings/presets.py index 0119e821c1..c82d256147 100644 --- a/src/scanpy/_settings/presets.py +++ b/src/scanpy/_settings/presets.py @@ -76,10 +76,16 @@ class HVGPreset(NamedTuple): return_df: bool -class PcaPreset(NamedTuple): +class BasicEmbeddingPreset(NamedTuple): key_added: str | None +# replace once they diverge +PcaPreset = UmapPreset = TsnePreset = DiffmapPreset = DrawGraphPreset = ( + BasicEmbeddingPreset +) + + class RankGenesGroupsPreset(NamedTuple): method: DETest mask_var: str | None @@ -193,6 +199,38 @@ def pca() -> Mapping[Preset, PcaPreset]: Preset.ScanpyV2Preview: PcaPreset(key_added="pca"), } + @preset_property + def umap() -> Mapping[Preset, UmapPreset]: + """Settings for :func:`~scanpy.tl.umap`.""" # noqa: D401 + return { + Preset.ScanpyV1: DiffmapPreset(key_added=None), + Preset.ScanpyV2Preview: DiffmapPreset(key_added="umap"), + } + + @preset_property + def tsne() -> Mapping[Preset, UmapPreset]: + """Settings for :func:`~scanpy.tl.tsne`.""" # noqa: D401 + return { + Preset.ScanpyV1: DiffmapPreset(key_added=None), + Preset.ScanpyV2Preview: DiffmapPreset(key_added="tsne"), + } + + @preset_property + def diffmap() -> Mapping[Preset, DiffmapPreset]: + """Settings for :func:`~scanpy.tl.diffmap`.""" # noqa: D401 + return { + Preset.ScanpyV1: DiffmapPreset(key_added=None), + Preset.ScanpyV2Preview: DiffmapPreset(key_added="diffmap"), + } + + @preset_property + def draw_graph() -> Mapping[Preset, DrawGraphPreset]: + """Settings for :func:`~scanpy.tl.draw_graph`.""" # noqa: D401 + return { + Preset.ScanpyV1: DrawGraphPreset(key_added=None), + Preset.ScanpyV2Preview: DrawGraphPreset(key_added="graph_{layout}"), + } + @preset_property def rank_genes_groups() -> Mapping[Preset, RankGenesGroupsPreset]: """Correlation method for :func:`~scanpy.tl.rank_genes_groups`.""" diff --git a/src/scanpy/_utils/__init__.py b/src/scanpy/_utils/__init__.py index f0d67b550d..68033aec71 100644 --- a/src/scanpy/_utils/__init__.py +++ b/src/scanpy/_utils/__init__.py @@ -39,7 +39,7 @@ if TYPE_CHECKING: from collections.abc import Callable, Iterable, KeysView, Mapping from pathlib import Path - from typing import Any + from typing import Any, Concatenate from anndata import AnnData from igraph import Graph @@ -47,6 +47,7 @@ from pandas._typing import Dtype as PdDtype from .._compat import CSRBase + from .._settings import Preset from ..neighbors import NeighborsParams, RPForestDict type _MemoryArray = NDArray | CSBase @@ -58,6 +59,8 @@ "NeighborsView", "_choose_graph", "_doc_params", + "_existing_preset_keys", + "_get_basis", "_numba_thread_limit", "_resolve_axis", "annotate_doc_types", @@ -573,6 +576,29 @@ def get_literal_vals(typ: UnionType | TypeAliasType | Any) -> KeysView[Any]: raise TypeError(msg) +def _get_basis_key(adata: AnnData, basis: str) -> str | None: + if basis in adata.obsm: + return basis + if f"X_{basis}" in adata.obsm: + return f"X_{basis}" + return None + + +def _existing_preset_keys[**P, T: tuple[str, ...]]( + adata: AnnData, + keys: Callable[Concatenate[Preset, P], T], + *args: P.args, + **kw: P.kwargs, +) -> T | None: + from .._settings import Preset + + for preset in (Preset.ScanpyV1, Preset.ScanpyV2Preview): + obsm_key, *rest = keys(preset, *args, **kw) + if obsm_key in adata.obsm: + return (obsm_key, *rest) + return None + + # -------------------------------------------------------------------------------- # Others # -------------------------------------------------------------------------------- diff --git a/src/scanpy/experimental/pp/_normalization.py b/src/scanpy/experimental/pp/_normalization.py index 551df22be4..49c516eaba 100644 --- a/src/scanpy/experimental/pp/_normalization.py +++ b/src/scanpy/experimental/pp/_normalization.py @@ -7,6 +7,7 @@ from anndata import AnnData from ... import logging as logg +from ... import settings from ..._compat import CSBase, warn from ..._settings import Default from ..._utils import _doc_params, check_nonnegative_integers, view_to_actual @@ -22,7 +23,7 @@ ) from ...get import _check_mask, _get_obs_rep, _set_obs_rep from ...preprocessing._docs import doc_mask_var -from ...preprocessing._pca import pca +from ...preprocessing._pca import _pca_keys, pca if TYPE_CHECKING: from collections.abc import Mapping @@ -207,19 +208,21 @@ def normalize_pearson_residuals_pca( `.uns['pearson_residuals_normalization']['clip']` The used value of the clipping parameter. - `.obsm['X_pca']` + `.obsm[kwargs_pca.get('key_added', 'X_pca')]` PCA representation of data after gene selection (if applicable) and Pearson residual normalization. - `.varm['PCs']` + `.varm[kwargs_pca.get('key_added', 'PCs')]` The principal components containing the loadings. When `inplace=True` and `mask_var is not None`, this will contain empty rows for the genes not selected. - `.uns['pca']['variance_ratio']` + `.uns[kwargs_pca.get('key_added', 'pca')]['variance_ratio']` Ratio of explained variance. - `.uns['pca']['variance']` + `.uns[kwargs_pca.get('key_added', 'pca')]['variance']` Explained variance, equivalent to the eigenvalues of the covariance matrix. """ + key_added = kwargs_pca.get("key_added", settings.preset.pca.key_added) + key_obsm, key_varm, key_uns = _pca_keys(key_added) if isinstance(mask_var, Default): mask_var = "highly_variable" if "highly_variable" in adata.var else None mask_var = _check_mask(adata, mask_var, "var") @@ -236,19 +239,19 @@ def normalize_pearson_residuals_pca( adata_pca, theta=theta, clip=clip, check_values=check_values ) pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca) - n_comps = adata_pca.obsm["X_pca"].shape[1] # might be None + n_comps = adata_pca.obsm[key_obsm].shape[1] # might be None if inplace: norm_settings = adata_pca.uns["pearson_residuals_normalization"] norm_dict = dict(**norm_settings, pearson_residuals_df=adata_pca.to_df()) if mask_var is not None: - adata.varm["PCs"] = np.zeros(shape=(adata.n_vars, n_comps)) - adata.varm["PCs"][mask_var] = adata_pca.varm["PCs"] + adata.varm[key_varm] = np.zeros(shape=(adata.n_vars, n_comps)) + adata.varm[key_varm][mask_var] = adata_pca.varm[key_varm] else: - adata.varm["PCs"] = adata_pca.varm["PCs"] - adata.uns["pca"] = adata_pca.uns["pca"] + adata.varm[key_varm] = adata_pca.varm[key_varm] + adata.uns[key_uns] = adata_pca.uns[key_uns] adata.uns["pearson_residuals_normalization"] = norm_dict - adata.obsm["X_pca"] = adata_pca.obsm["X_pca"] + adata.obsm[key_obsm] = adata_pca.obsm[key_obsm] return None else: return adata_pca diff --git a/src/scanpy/experimental/pp/_recipes.py b/src/scanpy/experimental/pp/_recipes.py index 61db8e8a5e..f0f63d0069 100644 --- a/src/scanpy/experimental/pp/_recipes.py +++ b/src/scanpy/experimental/pp/_recipes.py @@ -5,7 +5,7 @@ import numpy as np -from ... import experimental +from ... import experimental, settings from ..._utils import _doc_params from ..._utils.random import _accepts_legacy_random_state from ...experimental._docs import ( @@ -17,6 +17,7 @@ doc_pca_chunk, ) from ...preprocessing import pca +from ...preprocessing._pca import _pca_keys if TYPE_CHECKING: from collections.abc import Mapping @@ -103,18 +104,20 @@ def recipe_pearson_residuals( # noqa: PLR0913 `.uns['pearson_residuals_normalization']['clip']` The used value of the clipping parameter. - `.obsm['X_pca']` + `.obsm[kwargs_pca.get('key_added', 'X_pca')]` PCA representation of data after gene selection and Pearson residual normalization. - `.varm['PCs']` + `.varm[kwargs_pca.get('key_added', 'PCs')]` The principal components containing the loadings. When `inplace=True` this will contain empty rows for the genes not selected during HVG selection. - `.uns['pca']['variance_ratio']` + `.uns[kwargs_pca.get('key_added', 'pca')]['variance_ratio']` Ratio of explained variance. - `.uns['pca']['variance']` + `.uns[kwargs_pca.get('key_added', 'pca')]['variance']` Explained variance, equivalent to the eigenvalues of the covariance matrix. """ + key_added = kwargs_pca.get("key_added", settings.preset.pca.key_added) + key_obsm, key_varm, key_uns = _pca_keys(key_added) hvg_args = dict( flavor="pearson_residuals", n_top_genes=n_top_genes, @@ -145,11 +148,11 @@ def recipe_pearson_residuals( # noqa: PLR0913 **normalization_param, pearson_residuals_df=adata_pca.to_df() ) - adata.uns["pca"] = adata_pca.uns["pca"] - adata.varm["PCs"] = np.zeros(shape=(adata.n_vars, n_comps)) - adata.varm["PCs"][adata.var["highly_variable"]] = adata_pca.varm["PCs"] + adata.uns[key_uns] = adata_pca.uns[key_uns] + adata.varm[key_varm] = np.zeros(shape=(adata.n_vars, n_comps)) + adata.varm[key_varm][adata.var["highly_variable"]] = adata_pca.varm[key_varm] adata.uns["pearson_residuals_normalization"] = normalization_dict - adata.obsm["X_pca"] = adata_pca.obsm["X_pca"] + adata.obsm[key_obsm] = adata_pca.obsm[key_obsm] return None else: return adata_pca, hvg diff --git a/src/scanpy/external/exporting.py b/src/scanpy/external/exporting.py index 55d39eb2b2..076153b822 100644 --- a/src/scanpy/external/exporting.py +++ b/src/scanpy/external/exporting.py @@ -14,7 +14,9 @@ from fast_array_utils.stats import mean_var from pandas.api.types import CategoricalDtype -from .._utils import NeighborsView +from .._utils import NeighborsView, _existing_preset_keys +from ..preprocessing._pca import _pca_keys +from ..tools._draw_graph import _draw_graph_keys if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -81,8 +83,12 @@ def spring_project( # noqa: PLR0912, PLR0915 if embedding_method not in adata.obsm: if f"X_{embedding_method}" in adata.obsm: embedding_method = f"X_{embedding_method}" - elif embedding_method in adata.uns: - embedding_method = f"X_{embedding_method}_{adata.uns[embedding_method]['params']['layout']}" + elif embedding_method in {"graph", "draw_graph"} and ( + keys := _existing_preset_keys( + adata, _draw_graph_keys, adata.uns[embedding_method]["params"]["layout"] + ) + ): + embedding_method = keys[0] else: msg = f"Run the specified embedding method `{embedding_method}` first." raise ValueError(msg) @@ -222,10 +228,10 @@ def spring_project( # noqa: PLR0912, PLR0915 ) # Write some useful intermediates, if they exist - if "X_pca" in adata.obsm: + if keys := _existing_preset_keys(adata, _pca_keys): np.savez_compressed( subplot_dir / "intermediates.npz", - Epca=adata.obsm["X_pca"], + Epca=adata.obsm[keys[0]], total_counts=total_counts, ) diff --git a/src/scanpy/external/tl/_trimap.py b/src/scanpy/external/tl/_trimap.py index 8ccd08f1f3..922c792a8d 100644 --- a/src/scanpy/external/tl/_trimap.py +++ b/src/scanpy/external/tl/_trimap.py @@ -7,7 +7,9 @@ from ... import logging as logg from ..._compat import CSBase from ..._settings import settings +from ..._utils import _existing_preset_keys from ..._utils._doctests import doctest_needs +from ...preprocessing._pca import _pca_keys if TYPE_CHECKING: from typing import Literal @@ -100,9 +102,9 @@ def trimap( # noqa: PLR0913 verbosity = settings.verbosity if verbose is None else verbose verbose = verbosity if isinstance(verbosity, bool) else verbosity > 0 - if "X_pca" in adata.obsm: - n_dim_pca = adata.obsm["X_pca"].shape[1] - x = adata.obsm["X_pca"][:, : min(n_dim_pca, 100)] + if keys := _existing_preset_keys(adata, _pca_keys): + n_dim_pca = adata.obsm[keys[0]].shape[1] + x = adata.obsm[keys[0]][:, : min(n_dim_pca, 100)] else: x = adata.X if isinstance(x, CSBase): @@ -111,7 +113,7 @@ def trimap( # noqa: PLR0913 "use a dense matrix or apply pca first." ) raise ValueError(msg) - logg.warning("`X_pca` not found. Run `sc.pp.pca` first for speedup.") + logg.warning("`pca`/`X_pca` not found. Run `sc.pp.pca` first for speedup.") x_trimap = TRIMAP( n_dims=n_components, n_inliers=n_inliers, diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py index 7bc2470df3..a9c838b949 100644 --- a/src/scanpy/neighbors/__init__.py +++ b/src/scanpy/neighbors/__init__.py @@ -18,7 +18,7 @@ from .._compat import CSBase, CSRBase, SpBase, pkg_version, warn from .._docs import doc_rng from .._settings import settings -from .._utils import NeighborsView, _doc_params, get_literal_vals +from .._utils import NeighborsView, _doc_params, _existing_preset_keys, get_literal_vals from .._utils.random import ( _accepts_legacy_random_state, _legacy_random_state, @@ -322,20 +322,6 @@ class FlatTree(NamedTuple): # noqa: D101 indices: None -def _backwards_compat_get_full_x_diffmap(adata: AnnData) -> np.ndarray: - if "X_diffmap0" in adata.obs: - return np.c_[adata.obs["X_diffmap0"].values[:, None], adata.obsm["X_diffmap"]] - else: - return adata.obsm["X_diffmap"] - - -def _backwards_compat_get_full_eval(adata: AnnData): - if "X_diffmap0" in adata.obs: - return np.r_[1, adata.uns["diffmap_evals"]] - else: - return adata.uns["diffmap_evals"] - - def _make_forest_dict(forest): d = {} props = ("hyperplanes", "offsets", "children", "indices") @@ -432,6 +418,7 @@ def __init__( # noqa: PLR0912, PLR0915 *, n_dcs: int | None = None, neighbors_key: str | None = None, + diffmap_key: str | None = None, ) -> None: self._adata = adata self._init_iroot() @@ -484,9 +471,16 @@ def count_nonzero(a: np.ndarray | CSRBase) -> int: self._connected_components = connected_components(self._connectivities) self._number_connected_components = self._connected_components[0] - if "X_diffmap" in adata.obsm: - self._eigen_values = _backwards_compat_get_full_eval(adata) - self._eigen_basis = _backwards_compat_get_full_x_diffmap(adata) + + from ..tools._dpt import _diffmap_keys + + if keys := ( + (diffmap_key, diffmap_key) + if diffmap_key + else _existing_preset_keys(adata, _diffmap_keys) + ): + self._eigen_values = adata.uns[keys[1]] + self._eigen_basis = adata.obsm[keys[0]] if n_dcs is not None: if n_dcs > len(self._eigen_values): msg = ( diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py index e41f4ea908..bba09e85ce 100644 --- a/src/scanpy/plotting/_tools/__init__.py +++ b/src/scanpy/plotting/_tools/__init__.py @@ -13,7 +13,7 @@ from ... import logging as logg from ..._settings import Default, settings -from ..._utils import _doc_params, sanitize_anndata, with_cat_dtype +from ..._utils import _doc_params, _get_basis_key, sanitize_anndata, with_cat_dtype from ..._utils.random import _LegacyRng from ...get import obs_df, rank_genes_groups_df from .._anndata import ranking @@ -1495,9 +1495,9 @@ def embedding_density( # noqa: PLR0912, PLR0913, PLR0915 if groupby is not None: key += f"_{groupby}" - if f"X_{basis}" not in adata.obsm: + if _get_basis_key(adata, basis) is None: msg = ( - f"Cannot find the embedded representation `adata.obsm['X_{basis}']`. " + f"Cannot find the embedded representation `adata.obsm[{basis!r} | 'X_{basis}']`. " "Compute the embedding first." ) raise ValueError(msg) diff --git a/src/scanpy/plotting/_tools/paga.py b/src/scanpy/plotting/_tools/paga.py index d5100f1338..f8cfb6bd25 100644 --- a/src/scanpy/plotting/_tools/paga.py +++ b/src/scanpy/plotting/_tools/paga.py @@ -129,7 +129,7 @@ def paga_compare( # noqa: PLR0912, PLR0913 else: basis = "umap" - from .scatterplots import _components_to_dimensions, _get_basis, embedding + from .scatterplots import _components_to_dimensions, _get_basis_arr, embedding embedding( adata, @@ -156,7 +156,7 @@ def paga_compare( # noqa: PLR0912, PLR0913 if pos is None: if color == adata.uns["paga"]["groups"]: # TODO: Use dimensions here - _basis = _get_basis(adata, basis) + _basis = _get_basis_arr(adata, basis) dims = _components_to_dimensions( components=components, dimensions=None, total_dims=_basis.shape[1] )[0] diff --git a/src/scanpy/plotting/_tools/scatterplots.py b/src/scanpy/plotting/_tools/scatterplots.py index e18305140f..a7a83a517e 100644 --- a/src/scanpy/plotting/_tools/scatterplots.py +++ b/src/scanpy/plotting/_tools/scatterplots.py @@ -19,9 +19,16 @@ from ... import logging as logg from ..._settings import Default, settings -from ..._utils import _doc_params, sanitize_anndata +from ..._utils import ( + _doc_params, + _existing_preset_keys, + _get_basis_key, + sanitize_anndata, +) from ..._utils._doctests import doctest_internet from ...get import _check_mask +from ...preprocessing._pca import _pca_keys +from ...tools._draw_graph import _draw_graph_keys from .. import _utils from .._docs import ( doc_adata_color_etc, @@ -147,7 +154,7 @@ def embedding( # noqa: PLR0912, PLR0913, PLR0915 check_projection(projection) sanitize_anndata(adata) - basis_values = _get_basis(adata, basis) + basis_values = _get_basis_arr(adata, basis) dimensions = _components_to_dimensions( components, dimensions, projection=projection, total_dims=basis_values.shape[1] ) @@ -852,12 +859,10 @@ def draw_graph( """ if layout is None: layout = str(adata.uns["draw_graph"]["params"]["layout"]) - basis = f"draw_graph_{layout}" - if f"X_{basis}" not in adata.obsm: - msg = f"Did not find {basis} in adata.obs. Did you compute layout {layout}?" - raise ValueError(msg) - - return embedding(adata, basis, **kwargs) + if keys := _existing_preset_keys(adata, _draw_graph_keys, layout): + return embedding(adata, keys[0], **kwargs) + msg = f"Did not find `adata.obsm['draw_graph_{layout}']`. Did you compute layout {layout}?" + raise ValueError(msg) @_wraps_plot_scatter @@ -925,7 +930,7 @@ def pca( return embedding( adata, "pca", show=show, return_fig=return_fig, save=save, **kwargs ) - if "pca" not in adata.obsm and "X_pca" not in adata.obsm: + if not _existing_preset_keys(adata, _pca_keys): msg = ( f"Could not find entry in `obsm` for 'pca'.\n" f"Available keys are: {list(adata.obsm.keys())}." @@ -1202,15 +1207,12 @@ def _add_categorical_legend( # noqa: PLR0913 ax.legend(loc=legend_loc, fontsize=legend_fontsize) -def _get_basis(adata: AnnData, basis: str) -> np.ndarray: +def _get_basis_arr(adata: AnnData, basis: str) -> np.ndarray: """Get array for basis from anndata. Just tries to add 'X_'.""" - if basis in adata.obsm: - return adata.obsm[basis] - elif f"X_{basis}" in adata.obsm: - return adata.obsm[f"X_{basis}"] - else: - msg = f"Could not find {basis!r} or 'X_{basis}' in .obsm" - raise KeyError(msg) + if basis_key := _get_basis_key(adata, basis): + return adata.obsm[basis_key] + msg = f"Could not find {basis!r} or 'X_{basis}' in .obsm" + raise KeyError(msg) def _get_color_source_vector( diff --git a/src/scanpy/plotting/_utils.py b/src/scanpy/plotting/_utils.py index 23cad1c111..c43c3fea97 100644 --- a/src/scanpy/plotting/_utils.py +++ b/src/scanpy/plotting/_utils.py @@ -18,7 +18,7 @@ from .. import logging as logg from .._compat import warn from .._settings import Default, settings -from .._utils import NeighborsView +from .._utils import NeighborsView, _get_basis_key from . import palettes if TYPE_CHECKING: @@ -567,7 +567,7 @@ def plot_edges(axs, adata, basis, edges_width, edges_color, *, neighbors_key=Non raise ValueError(msg) neighbors = NeighborsView(adata, neighbors_key) g = nx.Graph(neighbors["connectivities"]) - basis_key = _get_basis(adata, basis) + basis_key = _get_basis_key(adata, basis) with warnings.catch_warnings(): warnings.simplefilter("ignore") @@ -602,7 +602,7 @@ def plot_arrows(axs, adata, basis, arrows_kwds=None): "Prefer using `scv.pl.velocity_embedding` to `arrows=True`." ) - basis_key = _get_basis(adata, basis) + basis_key = _get_basis_key(adata, basis) x = adata.obsm[basis_key] v = adata.obsm[f"{v_prefix}_{basis}"] for ax in axs: @@ -1027,16 +1027,6 @@ def fix_kwds(kwds_dict, **kwargs): return kwargs -def _get_basis(adata: AnnData, basis: str): - if basis in adata.obsm: - basis_key = basis - - elif f"X_{basis}" in adata.obsm: - basis_key = f"X_{basis}" - - return basis_key - - def check_colornorm(vmin=None, vmax=None, vcenter=None, norm=None): from matplotlib.colors import Normalize diff --git a/src/scanpy/preprocessing/_pca/__init__.py b/src/scanpy/preprocessing/_pca/__init__.py index 4f7a267e15..52741a2c3a 100644 --- a/src/scanpy/preprocessing/_pca/__init__.py +++ b/src/scanpy/preprocessing/_pca/__init__.py @@ -8,7 +8,7 @@ from ... import logging as logg from ..._compat import CSBase, DaskArray, warn from ..._docs import doc_rng -from ..._settings import Default, settings +from ..._settings import Default, Preset, settings from ..._utils import _doc_params, get_literal_vals, is_backed_type from ..._utils.random import _accepts_legacy_random_state, _legacy_random_state from ...get import _check_mask, _get_obs_rep @@ -47,6 +47,20 @@ type SvdSolver = SvdSolvDaskML | SvdSolvSkearn | SvdSolvPCACustom +def _pca_keys( + key_added: str | None | Default | Preset = Default(), +) -> tuple[str, str, str]: + if isinstance(key_added, Default): + key_added = settings.preset + if isinstance(key_added, Preset): + key_added = key_added.pca.key_added + return ( + ("X_pca", "PCs", "pca") + if key_added is None + else (key_added, key_added, key_added) + ) + + @_doc_params(mask_var=doc_mask_var, rng=doc_rng) @_accepts_legacy_random_state(0) def pca( # noqa: PLR0912, PLR0913, PLR0915 @@ -203,8 +217,6 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 # Current chunking implementation relies on pca being called on X msg = "Cannot use `layer`/`obsm` and `chunked` at the same time." raise NotImplementedError(msg) - if isinstance(key_added, Default): - key_added = settings.preset.pca.key_added # chunked calculation is not randomized, anyways if svd_solver in {"auto", "randomized"} and not chunked: @@ -337,9 +349,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915 x_pca = x_pca.astype(dtype) if return_anndata: - key_obsm, key_varm, key_uns = ( - ("X_pca", "PCs", "pca") if key_added is None else [key_added] * 3 - ) + key_obsm, key_varm, key_uns = _pca_keys(key_added) adata.obsm[key_obsm] = x_pca if obsm: diff --git a/src/scanpy/tools/_diffmap.py b/src/scanpy/tools/_diffmap.py index 90f2976837..0d03ee9bdc 100644 --- a/src/scanpy/tools/_diffmap.py +++ b/src/scanpy/tools/_diffmap.py @@ -5,6 +5,7 @@ import numpy as np from .._docs import doc_rng +from .._settings import Default from .._utils import _doc_params from .._utils.random import _accepts_legacy_random_state from ._dpt import _diffmap @@ -22,6 +23,7 @@ def diffmap( n_comps: int = 15, *, neighbors_key: str | None = None, + key_added: str | None | Default = Default(preset=("diffmap", "key_added")), rng: SeedLike | RNGLike | None = None, copy: bool = False, ) -> AnnData | None: @@ -55,6 +57,8 @@ def diffmap( .obsp[.uns[neighbors_key]['connectivities_key']] and .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances, respectively. + key_added + Control where the embedding and eigenvalues are stored. {rng} copy Return a copy instead of writing to adata. @@ -63,11 +67,11 @@ def diffmap( ------- Returns `None` if `copy=False`, else returns an `AnnData` object. Sets the following fields: - `adata.obsm['X_diffmap']` : :class:`numpy.ndarray` (dtype `float`) + `adata.obsm['X_diffmap' | key_added]` : :class:`numpy.ndarray` (dtype `float`) Diffusion map representation of data, which is the right eigen basis of the transition matrix with eigenvectors as columns. - `adata.uns['diffmap_evals']` : :class:`numpy.ndarray` (dtype `float`) + `adata.uns['diffmap_evals' | key_added]` : :class:`numpy.ndarray` (dtype `float`) Array of size (number of eigen vectors). Eigenvalues of transition matrix. @@ -90,5 +94,11 @@ def diffmap( msg = "Provide any value greater than 2 for `n_comps`. " raise ValueError(msg) adata = adata.copy() if copy else adata - _diffmap(adata, n_comps=n_comps, neighbors_key=neighbors_key, rng=rng) + _diffmap( + adata, + n_comps=n_comps, + neighbors_key=neighbors_key, + key_added=key_added, + rng=rng, + ) return adata if copy else None diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 74d7fe66ae..f97cc98eef 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -7,8 +7,11 @@ import scipy as sp from natsort import natsorted +from scanpy._utils.random import _LegacyRng +from scanpy.tools._utils import _existing_preset_keys + from .. import logging as logg -from .._utils.random import _LegacyRng +from .._settings import Default, Preset, settings from ..neighbors import Neighbors, OnFlySymMatrix if TYPE_CHECKING: @@ -17,26 +20,38 @@ from anndata import AnnData +def _diffmap_keys(key_added: str | None | Default | Preset) -> tuple[str, str]: + if isinstance(key_added, Default): + key_added = settings.preset + if isinstance(key_added, Preset): + key_added = key_added.diffmap.key_added + return ( + ("X_diffmap", "diffmap_evals") if key_added is None else (key_added, key_added) + ) + + def _diffmap( adata: AnnData, n_comps: int = 15, *, neighbors_key: str | None, + key_added: str | None | Default, rng: np.random.Generator, ) -> None: + obsm_key, uns_key = _diffmap_keys(key_added) start = logg.info(f"computing Diffusion Maps using {n_comps=}(=n_dcs)") dpt = DPT(adata, neighbors_key=neighbors_key) dpt.compute_transitions() dpt.compute_eigen(n_comps=n_comps, rng=rng) - adata.obsm["X_diffmap"] = dpt.eigen_basis - adata.uns["diffmap_evals"] = dpt.eigen_values + adata.obsm[obsm_key] = dpt.eigen_basis + adata.uns[uns_key] = dpt.eigen_values logg.info( " finished", time=start, deep=( "added\n" - " 'X_diffmap', diffmap coordinates (adata.obsm)\n" - " 'diffmap_evals', eigenvalues of transition matrix (adata.uns)" + f" {obsm_key!r}, diffmap coordinates (adata.obsm)\n" + f" {uns_key!r}, eigenvalues of transition matrix (adata.uns)" ), ) @@ -49,6 +64,7 @@ def dpt( min_group_size: float = 0.01, allow_kendall_tau_shift: bool = True, neighbors_key: str | None = None, + diffmap_key: str | None = None, copy: bool = False, ) -> AnnData | None: """Infer progression of cells through geodesic distance along the graph :cite:p:`Haghverdi2016,Wolf2019`. @@ -106,6 +122,9 @@ def dpt( .obsp[.uns[neighbors_key]['connectivities_key']] and .obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances, respectively. + diffmap_key + If specified, dpt looks in .obsm[diffmap_key] for diffmap coordinates, + otherwise int the default place. copy Copy instance before computation and return a copy. Otherwise, perform computation inplace and return `None`. @@ -142,12 +161,14 @@ def dpt( " adata.uns['iroot'] = root_cell_index\n" " adata.var['xroot'] = adata[root_cell_name, :].X" ) - if "X_diffmap" not in adata.obsm: + if not diffmap_key and not _existing_preset_keys(adata, _diffmap_keys): logg.warning( "Trying to run `tl.dpt` without prior call of `tl.diffmap`. " "Falling back to `tl.diffmap` with default parameters." ) - _diffmap(adata, neighbors_key=neighbors_key, rng=_LegacyRng(0)) + from ._diffmap import diffmap + + diffmap(adata, neighbors_key=neighbors_key, rng=_LegacyRng(0)) # start with the actual computation dpt = DPT( adata, @@ -156,6 +177,7 @@ def dpt( n_branchings=n_branchings, allow_kendall_tau_shift=allow_kendall_tau_shift, neighbors_key=neighbors_key, + diffmap_key=diffmap_key, ) start = logg.info(f"computing Diffusion Pseudotime using {n_dcs=}") if n_branchings > 1: @@ -216,8 +238,11 @@ def __init__( n_branchings: int = 0, allow_kendall_tau_shift: bool = False, neighbors_key: str | None = None, + diffmap_key: str | None = None, ): - super().__init__(adata, n_dcs=n_dcs, neighbors_key=neighbors_key) + super().__init__( + adata, n_dcs=n_dcs, neighbors_key=neighbors_key, diffmap_key=diffmap_key + ) self.flavor = "haghverdi16" self.n_branchings = n_branchings self.min_group_size = ( diff --git a/src/scanpy/tools/_draw_graph.py b/src/scanpy/tools/_draw_graph.py index ae209599d1..ace03145a0 100644 --- a/src/scanpy/tools/_draw_graph.py +++ b/src/scanpy/tools/_draw_graph.py @@ -7,7 +7,9 @@ from .. import _utils from .. import logging as logg +from .._compat import warn from .._docs import doc_rng +from .._settings import Default, Preset, settings from .._utils import _choose_graph, _doc_params, get_literal_vals from .._utils.random import ( _accepts_legacy_random_state, @@ -42,10 +44,12 @@ def draw_graph( # noqa: PLR0913 rng: SeedLike | RNGLike | None = None, n_jobs: int | None = None, adjacency: CSBase | None = None, - key_added_ext: str | None = None, + key_added: str | None | Default = Default(preset=("draw_graph", "key_added")), neighbors_key: str | None = None, obsp: str | None = None, copy: bool = False, + # deprecated + key_added_ext: str | None = None, **kwds, ) -> AnnData | None: """Force-directed graph drawing :cite:p:`Islam2011,Jacomy2014,Chippada2018`. @@ -86,21 +90,21 @@ def draw_graph( # noqa: PLR0913 Applies to layouts with random initialization like `'fr'`. adjacency Sparse adjacency matrix of the graph, defaults to neighbors connectivities. - key_added_ext - By default, append `layout`. + key_added + Template for the key. If `None`, uses `f'X_draw_graph_{{layout}}'` for `obsm`. proceed - Continue computation, starting off with 'X_draw_graph_`layout`'. + Continue computation, starting off with `f'X_draw_graph_{{layout}}'`. init_pos `'paga'`/`True`, `None`/`False`, or any valid 2d-`.obsm` key. Use precomputed coordinates for initialization. If `False`/`None` (the default), initialize randomly. neighbors_key - If not specified, draw_graph looks at .obsp['connectivities'] for connectivities + If not specified, draw_graph looks at `.obsp['connectivities']` for connectivities (default storage place for pp.neighbors). If specified, draw_graph looks at - .obsp[.uns[neighbors_key]['connectivities_key']] for connectivities. + `.obsp[.uns[neighbors_key]['connectivities_key']]` for connectivities. obsp - Use .obsp[obsp] as adjacency. You can't specify both + Use `.obsp[obsp]` as adjacency. You can't specify both `obsp` and `neighbors_key` at the same time. copy Return a copy instead of writing to adata. @@ -113,7 +117,7 @@ def draw_graph( # noqa: PLR0913 ------- Returns `None` if `copy=False`, else returns an `AnnData` object. Sets the following fields: - `adata.obsm['X_draw_graph_[layout | key_added_ext]']` : :class:`numpy.ndarray` (dtype `float`) + `adata.obsm[(f'X_draw_graph_{{layout}}' | key_added).format(layout=layout)]` : :class:`numpy.ndarray` (dtype `float`) Coordinates of graph layout. E.g. for `layout='fa'` (the default), the field is called `'X_draw_graph_fa'`. `key_added_ext` overwrites `layout`. `adata.uns['draw_graph']`: :class:`dict` @@ -121,6 +125,8 @@ def draw_graph( # noqa: PLR0913 """ start = logg.info(f"drawing single-cell graph using layout {layout!r}") + layout = coerce_fa2_layout(layout) + key_obsm, key_uns = _draw_graph_keys(key_added, layout, key_added_ext) rng = np.random.default_rng(rng) meta_random_state = ( dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {} @@ -145,7 +151,6 @@ def draw_graph( # noqa: PLR0913 else: _if_legacy_apply_global(rng) init_coords = rng.random((adjacency.shape[0], 2)) - layout = coerce_fa2_layout(layout) # actual drawing if layout == "fa": positions = np.array(fa2_positions(adjacency, init_coords, **kwds)) @@ -161,18 +166,40 @@ def draw_graph( # noqa: PLR0913 else: ig_layout = g.layout(layout, **kwds) positions = np.array(ig_layout.coords) - adata.uns["draw_graph"] = {} - adata.uns["draw_graph"]["params"] = dict(layout=layout, **meta_random_state) - key_added = f"X_draw_graph_{key_added_ext or layout}" - adata.obsm[key_added] = positions + adata.uns[key_uns] = {} + adata.uns[key_uns]["params"] = dict(layout=layout, **meta_random_state) + adata.obsm[key_obsm] = positions logg.info( " finished", time=start, - deep=f"added\n {key_added!r}, graph_drawing coordinates (adata.obsm)", + deep="added" + f"\n {key_obsm!r}, draw_graph coordinates (adata.obsm)" + f"\n {key_uns!r}, draw_graph parameters (adata.uns)", ) return adata if copy else None +def _draw_graph_keys( + key_added: str | None | Default | Preset, + layout: str, + key_added_ext: str | None = None, +) -> tuple[str, str]: + if key_added_ext is not None: + msg = "Passing `key_added_ext` is deprecated, use `key_added`’s template functionality instead." + warn(msg, category=FutureWarning) + suffix = key_added_ext + else: + suffix = layout + if isinstance(key_added, Default): + key_added = settings.preset + if isinstance(key_added, Preset): + key_added = key_added.draw_graph.key_added + if key_added is None: + return f"X_draw_graph_{suffix}", "draw_graph" + key_added = key_added.format(layout=suffix) + return key_added, key_added + + def fa2_positions( adjacency: CSBase | np.ndarray, init_coords: np.ndarray, **kwds ) -> list[tuple[float, float]]: diff --git a/src/scanpy/tools/_embedding_density.py b/src/scanpy/tools/_embedding_density.py index 37e2bcf606..ff823becb4 100644 --- a/src/scanpy/tools/_embedding_density.py +++ b/src/scanpy/tools/_embedding_density.py @@ -7,7 +7,7 @@ import numpy as np from .. import logging as logg -from .._utils import sanitize_anndata +from .._utils import _get_basis_key, sanitize_anndata if TYPE_CHECKING: from collections.abc import Sequence @@ -125,7 +125,7 @@ def embedding_density( # noqa: PLR0912 if basis == "fa": basis = "draw_graph_fa" - if f"X_{basis}" not in adata.obsm: + if _get_basis_key(adata, basis) is None: msg = ( "Cannot find the embedded representation " f"`adata.obsm['X_{basis}']`. Compute the embedding first." diff --git a/src/scanpy/tools/_ingest.py b/src/scanpy/tools/_ingest.py index d85d6b89bb..f231093f1e 100644 --- a/src/scanpy/tools/_ingest.py +++ b/src/scanpy/tools/_ingest.py @@ -14,11 +14,14 @@ from .._utils import ( NeighborsView, _doc_params, + _existing_preset_keys, raise_not_implemented_error_if_backed_type, ) from .._utils._doctests import doctest_skipif from .._utils.random import _legacy_random_state, _LegacyRng from ..neighbors import FlatTree +from ..preprocessing._pca import _pca_keys +from ..tools._umap import _umap_keys if TYPE_CHECKING: from collections.abc import Generator, Iterable @@ -248,10 +251,10 @@ def _get_rng(self, params: dict[str, object]) -> np.random.Generator | None: random_state = params.get("random_state", 0) return _LegacyRng(random_state) - def _init_umap(self, adata: AnnData) -> None: + def _init_umap(self, adata: AnnData, obsm_key: str, uns_key: str) -> None: from umap import UMAP - rng = self._get_rng(adata.uns["umap"]["params"]) + rng = self._get_rng(adata.uns[uns_key]["params"]) self._umap = UMAP( metric=self._metric, random_state=_legacy_random_state(rng), @@ -265,7 +268,7 @@ def _init_umap(self, adata: AnnData) -> None: self._umap._validate_parameters() - self._umap.embedding_ = adata.obsm["X_umap"] + self._umap.embedding_ = adata.obsm[obsm_key] self._umap._sparse_data = isinstance(self._rep, CSBase) self._umap._small_data = self._rep.shape[0] < 4096 self._umap._metric_kwds = self._metric_kwds @@ -275,8 +278,8 @@ def _init_umap(self, adata: AnnData) -> None: self._umap._knn_search_index = self._nnd_idx - self._umap._a = adata.uns["umap"]["params"]["a"] - self._umap._b = adata.uns["umap"]["params"]["b"] + self._umap._a = adata.uns[uns_key]["params"]["a"] + self._umap._b = adata.uns[uns_key]["params"]["b"] self._umap._input_hash = None @@ -388,8 +391,8 @@ def __init__( ) raise ValueError(msg) - if "X_umap" in adata.obsm: - self._init_umap(adata) + if keys := _existing_preset_keys(adata, _umap_keys): + self._init_umap(adata, *keys) self._obsm = None self._obs = None @@ -554,10 +557,10 @@ def to_adata_joint( self._obsm["rep"], )) - if "X_umap" in self._obsm: - adata.uns["umap"] = self._adata_ref.uns["umap"] - if "X_pca" in self._obsm: - adata.uns["pca"] = self._adata_ref.uns["pca"] - adata.varm["PCs"] = self._adata_ref.varm["PCs"] + if keys := _existing_preset_keys(self._adata_ref, _umap_keys): + adata.uns[keys[1]] = self._adata_ref.uns[keys[1]] + if keys := _existing_preset_keys(self._adata_ref, _pca_keys): + adata.varm[keys[1]] = self._adata_ref.varm[keys[1]] + adata.uns[keys[2]] = self._adata_ref.uns[keys[2]] return adata diff --git a/src/scanpy/tools/_tsne.py b/src/scanpy/tools/_tsne.py index 54fb28f436..81fcfe314e 100644 --- a/src/scanpy/tools/_tsne.py +++ b/src/scanpy/tools/_tsne.py @@ -5,7 +5,7 @@ from .. import logging as logg from .._compat import warn from .._docs import doc_rng -from .._settings import settings +from .._settings import Default, Preset, settings from .._utils import _doc_params, raise_not_implemented_error_if_backed_type from .._utils.random import _accepts_legacy_random_state, _legacy_random_state from ..neighbors._doc import doc_n_pcs, doc_use_rep @@ -17,6 +17,14 @@ from .._utils.random import RNGLike, SeedLike +def _tsne_keys(key_added: str | None | Default | Preset) -> tuple[str, str]: + if isinstance(key_added, Default): + key_added = settings.preset + if isinstance(key_added, Preset): + key_added = key_added.tsne.key_added + return ("X_tsne", "tsne") if key_added is None else (key_added, key_added) + + @_accepts_legacy_random_state(0) @_doc_params(doc_n_pcs=doc_n_pcs, use_rep=doc_use_rep, rng=doc_rng) def tsne( # noqa: PLR0913 @@ -32,7 +40,7 @@ def tsne( # noqa: PLR0913 rng: SeedLike | RNGLike | None = None, use_fast_tsne: bool = False, n_jobs: int | None = None, - key_added: str | None = None, + key_added: str | None | Default = Default(preset=("tsne", "key_added")), copy: bool = False, ) -> AnnData | None: r"""t-SNE :cite:p:`vanDerMaaten2008,Amir2013,Pedregosa2011`. @@ -102,6 +110,7 @@ def tsne( # noqa: PLR0913 """ start = logg.info("computing tSNE") + key_obsm, key_uns = _tsne_keys(key_added) adata = adata.copy() if copy else adata x = _choose_representation(adata, use_rep=use_rep, n_pcs=n_pcs) raise_not_implemented_error_if_backed_type(x, "tsne") @@ -157,7 +166,6 @@ def tsne( # noqa: PLR0913 use_rep=use_rep, n_components=n_components, ) - key_uns, key_obsm = ("tsne", "X_tsne") if key_added is None else [key_added] * 2 adata.obsm[key_obsm] = x_tsne # annotate samples with tSNE coordinates adata.uns[key_uns] = dict(params={k: v for k, v in params.items() if v is not None}) diff --git a/src/scanpy/tools/_umap.py b/src/scanpy/tools/_umap.py index 316c6c8aed..cd17ebd8be 100644 --- a/src/scanpy/tools/_umap.py +++ b/src/scanpy/tools/_umap.py @@ -7,7 +7,7 @@ from .. import logging as logg from .._docs import doc_rng -from .._settings import settings +from .._settings import Default, Preset, settings from .._utils import NeighborsView, _doc_params from .._utils.random import ( _accepts_legacy_random_state, @@ -27,6 +27,16 @@ type _InitPos = Literal["paga", "spectral", "random"] +def _umap_keys( + key_added: str | None | Default | Preset = Default(), +) -> tuple[str, str]: + if isinstance(key_added, Default): + key_added = settings.preset + if isinstance(key_added, Preset): + key_added = key_added.umap.key_added + return ("X_umap", "umap") if key_added is None else (key_added, key_added) + + @_accepts_legacy_random_state(0) @_doc_params(rng=doc_rng) def umap( # noqa: PLR0913 @@ -44,7 +54,7 @@ def umap( # noqa: PLR0913 a: float | None = None, b: float | None = None, method: Literal["umap"] = "umap", - key_added: str | None = None, + key_added: str | None | Default = Default(preset=("umap", "key_added")), neighbors_key: str = "neighbors", copy: bool = False, ) -> AnnData | None: @@ -144,7 +154,7 @@ def umap( # noqa: PLR0913 rng = np.random.default_rng(rng) adata = adata.copy() if copy else adata - key_obsm, key_uns = ("X_umap", "umap") if key_added is None else [key_added] * 2 + key_obsm, key_uns = _umap_keys(key_added) if neighbors_key is None: # backwards compat neighbors_key = "neighbors" diff --git a/src/scanpy/tools/_utils.py b/src/scanpy/tools/_utils.py index 80083a7023..b705cd0f3c 100644 --- a/src/scanpy/tools/_utils.py +++ b/src/scanpy/tools/_utils.py @@ -7,7 +7,7 @@ from .. import logging as logg from .._compat import warn from .._settings import settings -from .._utils import _choose_graph +from .._utils import _choose_graph, _existing_preset_keys if TYPE_CHECKING: from anndata import AnnData @@ -19,8 +19,8 @@ def _choose_representation( adata: AnnData, *, - use_rep: str | None = None, - n_pcs: int | None = None, + use_rep: str | None, + n_pcs: int | None, silent: bool = False, ) -> np.ndarray | CSRBase: # TODO: what else? verbosity = settings.verbosity @@ -51,30 +51,30 @@ def _choose_representation( def _get_pca_or_small_x(adata: AnnData, n_pcs: int | None) -> np.ndarray | CSRBase: + from ..preprocessing._pca import _pca_keys, pca + if adata.n_vars <= settings.N_PCS: logg.info(" using data matrix X directly") return adata.X - pca_key = next((k for k in ("pca", "X_pca") if k in adata.obsm), None) - if pca_key is not None: + if keys := _existing_preset_keys(adata, _pca_keys): + pca_key, *_ = keys if n_pcs is not None and n_pcs > adata.obsm[pca_key].shape[1]: - msg = f"`{pca_key}` does not have enough PCs. Rerun `sc.pp.pca` with adjusted `n_comps`." + msg = f"`adata.obsm[{pca_key!r}]` does not have enough PCs. Rerun `sc.pp.pca` with adjusted `n_comps`." raise ValueError(msg) x = adata.obsm[pca_key][:, :n_pcs] logg.info(f" using {pca_key!r} with n_pcs = {x.shape[1]}") return x - from ..preprocessing import pca - msg = ( f"You’re trying to run this on {adata.n_vars} dimensions of `.X`, " - "if you really want this, set `use_rep='X'`.\n " + "if you really want this, set `use_rep=’X’`.\n " "Falling back to preprocessing with `sc.pp.pca` and default params." ) warn(msg, UserWarning) n_pcs_pca = n_pcs if n_pcs is not None else settings.N_PCS pca(adata, n_comps=n_pcs_pca) - return adata.obsm[settings.preset.pca.key_added or "X_pca"] + return adata.obsm[_pca_keys()[0]] def get_init_pos_from_paga( diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 436a72db0e..f1830aaedb 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -7,6 +7,7 @@ from numpy.testing import assert_array_almost_equal, assert_array_equal, assert_raises import scanpy as sc +from scanpy._settings import Default from testing.scanpy._helpers.data import pbmc68k_reduced from testing.scanpy._pytest.marks import needs @@ -109,3 +110,54 @@ def test_diffmap( assert_array_equal(d1, d2) with subtests.test("different embedding"): assert_raises(AssertionError, assert_array_equal, d1, d3) + + +@pytest.mark.parametrize( + ("key_added", "key_obsm", "key_uns"), + [ + pytest.param(None, "X_diffmap", "diffmap_evals", id="None"), + pytest.param("custom_key", "custom_key", "custom_key", id="custom_key"), + pytest.param(sc.Preset.ScanpyV1, "X_diffmap", "diffmap_evals", id="v1"), + pytest.param( + *(sc.Preset.ScanpyV2Preview, "diffmap", "diffmap"), + marks=[needs.igraph, needs.skmisc], + id="v2", + ), + ], +) +def test_diffmap_key_added( + key_added: str | None | Default | sc.Preset, key_obsm: str, key_uns: str +) -> None: + pbmc = pbmc68k_reduced()[:300, :100].copy() + if isinstance(key_added, sc.Preset): + sc.settings.preset = key_added + key_added = Default() + adata = sc.tl.diffmap(pbmc, key_added=key_added, copy=True) + assert key_obsm in adata.obsm + assert key_uns in adata.uns + + +@needs.igraph +@pytest.mark.parametrize( + ("key_added", "key_obsm", "key_uns"), + [ + pytest.param(None, "X_draw_graph_fr", "draw_graph", id="None"), + pytest.param("custom_{layout}", "custom_fr", "custom_fr", id="custom_template"), + pytest.param(sc.Preset.ScanpyV1, "X_draw_graph_fr", "draw_graph", id="v1"), + pytest.param( + *(sc.Preset.ScanpyV2Preview, "graph_fr", "graph_fr"), + marks=needs.skmisc, + id="v2", + ), + ], +) +def test_draw_graph_key_added( + key_added: str | None | Default | sc.Preset, key_obsm: str, key_uns: str +) -> None: + pbmc = pbmc68k_reduced()[:100, :100].copy() + if isinstance(key_added, sc.Preset): + sc.settings.preset = key_added + key_added = Default() + adata = sc.tl.draw_graph(pbmc, layout="fr", key_added=key_added, copy=True) + assert key_obsm in adata.obsm + assert key_uns in adata.uns diff --git a/tests/test_sim.py b/tests/test_sim.py index ac96966cb1..fb1fe9b8e7 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -6,9 +6,8 @@ import scanpy as sc +@pytest.mark.filterwarnings("ignore:.*Observation names are not unique:UserWarning") def test_sim_toggleswitch() -> None: - with pytest.warns(UserWarning, match=r"Observation names are not unique"): - adata_sim = sc.tl.sim("toggleswitch") - with pytest.warns(UserWarning, match=r"Observation names are not unique"): - adata_ds = sc.datasets.toggleswitch() + adata_sim = sc.tl.sim("toggleswitch") + adata_ds = sc.datasets.toggleswitch() np.allclose(adata_sim.X, adata_ds.X, np.finfo(np.float32).eps)