Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/4076.feat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add preset support to {func}`scanpy.tl.diffmap`’s and {func}`scanpy.tl.draw_graph`’s `key_added` parameter {smaller}`P Angerer`
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ filterwarnings = [
"ignore:FNV hashing is not implemented in Numba.*:UserWarning",
# we want to see and eventually fix these
"default::numba.core.errors.NumbaPerformanceWarning",
"default:.*TSNE.*random.*to.*pca:FutureWarning", # we should set init=obsm["X_pca"] or so
# we should set init=obsm["X_pca"] or so
"default:.*TSNE.*random.*to.*pca:FutureWarning",
# igraph vs leidenalg warning
"ignore:The `igraph` implementation of leiden clustering:UserWarning",
# everybody uses this zarr 3 feature, including us, XArray, lots of data out there …
Expand Down
40 changes: 39 additions & 1 deletion src/scanpy/_settings/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,16 @@ class HVGPreset(NamedTuple):
return_df: bool


class PcaPreset(NamedTuple):
class BasicEmbeddingPreset(NamedTuple):
key_added: str | None


# replace once they diverge
PcaPreset = UmapPreset = TsnePreset = DiffmapPreset = DrawGraphPreset = (
BasicEmbeddingPreset
)


class RankGenesGroupsPreset(NamedTuple):
method: DETest
mask_var: str | None
Expand Down Expand Up @@ -193,6 +199,38 @@ def pca() -> Mapping[Preset, PcaPreset]:
Preset.ScanpyV2Preview: PcaPreset(key_added="pca"),
}

@preset_property
def umap() -> Mapping[Preset, UmapPreset]:
"""Settings for :func:`~scanpy.tl.umap`.""" # noqa: D401
return {
Preset.ScanpyV1: DiffmapPreset(key_added=None),
Preset.ScanpyV2Preview: DiffmapPreset(key_added="umap"),
}

@preset_property
def tsne() -> Mapping[Preset, UmapPreset]:
"""Settings for :func:`~scanpy.tl.tsne`.""" # noqa: D401
return {
Preset.ScanpyV1: DiffmapPreset(key_added=None),
Preset.ScanpyV2Preview: DiffmapPreset(key_added="tsne"),
}

@preset_property
def diffmap() -> Mapping[Preset, DiffmapPreset]:
"""Settings for :func:`~scanpy.tl.diffmap`.""" # noqa: D401
return {
Preset.ScanpyV1: DiffmapPreset(key_added=None),
Preset.ScanpyV2Preview: DiffmapPreset(key_added="diffmap"),
}

@preset_property
def draw_graph() -> Mapping[Preset, DrawGraphPreset]:
"""Settings for :func:`~scanpy.tl.draw_graph`.""" # noqa: D401
return {
Preset.ScanpyV1: DrawGraphPreset(key_added=None),
Preset.ScanpyV2Preview: DrawGraphPreset(key_added="graph_{layout}"),
}

@preset_property
def rank_genes_groups() -> Mapping[Preset, RankGenesGroupsPreset]:
"""Correlation method for :func:`~scanpy.tl.rank_genes_groups`."""
Expand Down
28 changes: 27 additions & 1 deletion src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, KeysView, Mapping
from pathlib import Path
from typing import Any
from typing import Any, Concatenate

from anndata import AnnData
from igraph import Graph
from numpy.typing import ArrayLike, NDArray
from pandas._typing import Dtype as PdDtype

from .._compat import CSRBase
from .._settings import Preset
from ..neighbors import NeighborsParams, RPForestDict

type _MemoryArray = NDArray | CSBase
Expand All @@ -58,6 +59,8 @@
"NeighborsView",
"_choose_graph",
"_doc_params",
"_existing_preset_keys",
"_get_basis",
"_numba_thread_limit",
"_resolve_axis",
"annotate_doc_types",
Expand Down Expand Up @@ -573,6 +576,29 @@ def get_literal_vals(typ: UnionType | TypeAliasType | Any) -> KeysView[Any]:
raise TypeError(msg)


def _get_basis_key(adata: AnnData, basis: str) -> str | None:
if basis in adata.obsm:
return basis
if f"X_{basis}" in adata.obsm:
return f"X_{basis}"
return None


def _existing_preset_keys[**P, T: tuple[str, ...]](
adata: AnnData,
keys: Callable[Concatenate[Preset, P], T],
*args: P.args,
**kw: P.kwargs,
) -> T | None:
from .._settings import Preset

for preset in (Preset.ScanpyV1, Preset.ScanpyV2Preview):
obsm_key, *rest = keys(preset, *args, **kw)
if obsm_key in adata.obsm:
return (obsm_key, *rest)
return None


# --------------------------------------------------------------------------------
# Others
# --------------------------------------------------------------------------------
Expand Down
25 changes: 14 additions & 11 deletions src/scanpy/experimental/pp/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from anndata import AnnData

from ... import logging as logg
from ... import settings
from ..._compat import CSBase, warn
from ..._settings import Default
from ..._utils import _doc_params, check_nonnegative_integers, view_to_actual
Expand All @@ -22,7 +23,7 @@
)
from ...get import _check_mask, _get_obs_rep, _set_obs_rep
from ...preprocessing._docs import doc_mask_var
from ...preprocessing._pca import pca
from ...preprocessing._pca import _pca_keys, pca

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -207,19 +208,21 @@ def normalize_pearson_residuals_pca(
`.uns['pearson_residuals_normalization']['clip']`
The used value of the clipping parameter.

`.obsm['X_pca']`
`.obsm[kwargs_pca.get('key_added', 'X_pca')]`
PCA representation of data after gene selection (if applicable) and Pearson
residual normalization.
`.varm['PCs']`
`.varm[kwargs_pca.get('key_added', 'PCs')]`
The principal components containing the loadings. When `inplace=True` and
`mask_var is not None`, this will contain empty rows for the genes not
selected.
`.uns['pca']['variance_ratio']`
`.uns[kwargs_pca.get('key_added', 'pca')]['variance_ratio']`
Ratio of explained variance.
`.uns['pca']['variance']`
`.uns[kwargs_pca.get('key_added', 'pca')]['variance']`
Explained variance, equivalent to the eigenvalues of the covariance matrix.

"""
key_added = kwargs_pca.get("key_added", settings.preset.pca.key_added)
key_obsm, key_varm, key_uns = _pca_keys(key_added)
if isinstance(mask_var, Default):
mask_var = "highly_variable" if "highly_variable" in adata.var else None
mask_var = _check_mask(adata, mask_var, "var")
Expand All @@ -236,19 +239,19 @@ def normalize_pearson_residuals_pca(
adata_pca, theta=theta, clip=clip, check_values=check_values
)
pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca)
n_comps = adata_pca.obsm["X_pca"].shape[1] # might be None
n_comps = adata_pca.obsm[key_obsm].shape[1] # might be None

if inplace:
norm_settings = adata_pca.uns["pearson_residuals_normalization"]
norm_dict = dict(**norm_settings, pearson_residuals_df=adata_pca.to_df())
if mask_var is not None:
adata.varm["PCs"] = np.zeros(shape=(adata.n_vars, n_comps))
adata.varm["PCs"][mask_var] = adata_pca.varm["PCs"]
adata.varm[key_varm] = np.zeros(shape=(adata.n_vars, n_comps))
adata.varm[key_varm][mask_var] = adata_pca.varm[key_varm]
else:
adata.varm["PCs"] = adata_pca.varm["PCs"]
adata.uns["pca"] = adata_pca.uns["pca"]
adata.varm[key_varm] = adata_pca.varm[key_varm]
adata.uns[key_uns] = adata_pca.uns[key_uns]
adata.uns["pearson_residuals_normalization"] = norm_dict
adata.obsm["X_pca"] = adata_pca.obsm["X_pca"]
adata.obsm[key_obsm] = adata_pca.obsm[key_obsm]
return None
else:
return adata_pca
21 changes: 12 additions & 9 deletions src/scanpy/experimental/pp/_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from ... import experimental
from ... import experimental, settings
from ..._utils import _doc_params
from ..._utils.random import _accepts_legacy_random_state
from ...experimental._docs import (
Expand All @@ -17,6 +17,7 @@
doc_pca_chunk,
)
from ...preprocessing import pca
from ...preprocessing._pca import _pca_keys

if TYPE_CHECKING:
from collections.abc import Mapping
Expand Down Expand Up @@ -103,18 +104,20 @@ def recipe_pearson_residuals( # noqa: PLR0913
`.uns['pearson_residuals_normalization']['clip']`
The used value of the clipping parameter.

`.obsm['X_pca']`
`.obsm[kwargs_pca.get('key_added', 'X_pca')]`
PCA representation of data after gene selection and Pearson residual
normalization.
`.varm['PCs']`
`.varm[kwargs_pca.get('key_added', 'PCs')]`
The principal components containing the loadings. When `inplace=True` this
will contain empty rows for the genes not selected during HVG selection.
`.uns['pca']['variance_ratio']`
`.uns[kwargs_pca.get('key_added', 'pca')]['variance_ratio']`
Ratio of explained variance.
`.uns['pca']['variance']`
`.uns[kwargs_pca.get('key_added', 'pca')]['variance']`
Explained variance, equivalent to the eigenvalues of the covariance matrix.

"""
key_added = kwargs_pca.get("key_added", settings.preset.pca.key_added)
key_obsm, key_varm, key_uns = _pca_keys(key_added)
hvg_args = dict(
flavor="pearson_residuals",
n_top_genes=n_top_genes,
Expand Down Expand Up @@ -145,11 +148,11 @@ def recipe_pearson_residuals( # noqa: PLR0913
**normalization_param, pearson_residuals_df=adata_pca.to_df()
)

adata.uns["pca"] = adata_pca.uns["pca"]
adata.varm["PCs"] = np.zeros(shape=(adata.n_vars, n_comps))
adata.varm["PCs"][adata.var["highly_variable"]] = adata_pca.varm["PCs"]
adata.uns[key_uns] = adata_pca.uns[key_uns]
adata.varm[key_varm] = np.zeros(shape=(adata.n_vars, n_comps))
adata.varm[key_varm][adata.var["highly_variable"]] = adata_pca.varm[key_varm]
adata.uns["pearson_residuals_normalization"] = normalization_dict
adata.obsm["X_pca"] = adata_pca.obsm["X_pca"]
adata.obsm[key_obsm] = adata_pca.obsm[key_obsm]
return None
else:
return adata_pca, hvg
16 changes: 11 additions & 5 deletions src/scanpy/external/exporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from fast_array_utils.stats import mean_var
from pandas.api.types import CategoricalDtype

from .._utils import NeighborsView
from .._utils import NeighborsView, _existing_preset_keys
from ..preprocessing._pca import _pca_keys
from ..tools._draw_graph import _draw_graph_keys

if TYPE_CHECKING:
from collections.abc import Iterable, Mapping
Expand Down Expand Up @@ -81,8 +83,12 @@ def spring_project( # noqa: PLR0912, PLR0915
if embedding_method not in adata.obsm:
if f"X_{embedding_method}" in adata.obsm:
embedding_method = f"X_{embedding_method}"
elif embedding_method in adata.uns:
embedding_method = f"X_{embedding_method}_{adata.uns[embedding_method]['params']['layout']}"
elif embedding_method in {"graph", "draw_graph"} and (
keys := _existing_preset_keys(
adata, _draw_graph_keys, adata.uns[embedding_method]["params"]["layout"]
)
):
embedding_method = keys[0]
else:
msg = f"Run the specified embedding method `{embedding_method}` first."
raise ValueError(msg)
Expand Down Expand Up @@ -222,10 +228,10 @@ def spring_project( # noqa: PLR0912, PLR0915
)

# Write some useful intermediates, if they exist
if "X_pca" in adata.obsm:
if keys := _existing_preset_keys(adata, _pca_keys):
np.savez_compressed(
subplot_dir / "intermediates.npz",
Epca=adata.obsm["X_pca"],
Epca=adata.obsm[keys[0]],
total_counts=total_counts,
)

Expand Down
10 changes: 6 additions & 4 deletions src/scanpy/external/tl/_trimap.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
from ... import logging as logg
from ..._compat import CSBase
from ..._settings import settings
from ..._utils import _existing_preset_keys
from ..._utils._doctests import doctest_needs
from ...preprocessing._pca import _pca_keys

if TYPE_CHECKING:
from typing import Literal
Expand Down Expand Up @@ -100,9 +102,9 @@ def trimap( # noqa: PLR0913
verbosity = settings.verbosity if verbose is None else verbose
verbose = verbosity if isinstance(verbosity, bool) else verbosity > 0

if "X_pca" in adata.obsm:
n_dim_pca = adata.obsm["X_pca"].shape[1]
x = adata.obsm["X_pca"][:, : min(n_dim_pca, 100)]
if keys := _existing_preset_keys(adata, _pca_keys):
n_dim_pca = adata.obsm[keys[0]].shape[1]
x = adata.obsm[keys[0]][:, : min(n_dim_pca, 100)]
else:
x = adata.X
if isinstance(x, CSBase):
Expand All @@ -111,7 +113,7 @@ def trimap( # noqa: PLR0913
"use a dense matrix or apply pca first."
)
raise ValueError(msg)
logg.warning("`X_pca` not found. Run `sc.pp.pca` first for speedup.")
logg.warning("`pca`/`X_pca` not found. Run `sc.pp.pca` first for speedup.")
x_trimap = TRIMAP(
n_dims=n_components,
n_inliers=n_inliers,
Expand Down
30 changes: 12 additions & 18 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .._compat import CSBase, CSRBase, SpBase, pkg_version, warn
from .._docs import doc_rng
from .._settings import settings
from .._utils import NeighborsView, _doc_params, get_literal_vals
from .._utils import NeighborsView, _doc_params, _existing_preset_keys, get_literal_vals
from .._utils.random import (
_accepts_legacy_random_state,
_legacy_random_state,
Expand Down Expand Up @@ -322,20 +322,6 @@ class FlatTree(NamedTuple): # noqa: D101
indices: None


def _backwards_compat_get_full_x_diffmap(adata: AnnData) -> np.ndarray:
if "X_diffmap0" in adata.obs:
return np.c_[adata.obs["X_diffmap0"].values[:, None], adata.obsm["X_diffmap"]]
else:
return adata.obsm["X_diffmap"]


def _backwards_compat_get_full_eval(adata: AnnData):
if "X_diffmap0" in adata.obs:
return np.r_[1, adata.uns["diffmap_evals"]]
else:
return adata.uns["diffmap_evals"]


def _make_forest_dict(forest):
d = {}
props = ("hyperplanes", "offsets", "children", "indices")
Expand Down Expand Up @@ -432,6 +418,7 @@ def __init__( # noqa: PLR0912, PLR0915
*,
n_dcs: int | None = None,
neighbors_key: str | None = None,
diffmap_key: str | None = None,
) -> None:
self._adata = adata
self._init_iroot()
Expand Down Expand Up @@ -484,9 +471,16 @@ def count_nonzero(a: np.ndarray | CSRBase) -> int:

self._connected_components = connected_components(self._connectivities)
self._number_connected_components = self._connected_components[0]
if "X_diffmap" in adata.obsm:
self._eigen_values = _backwards_compat_get_full_eval(adata)
self._eigen_basis = _backwards_compat_get_full_x_diffmap(adata)

from ..tools._dpt import _diffmap_keys

if keys := (
(diffmap_key, diffmap_key)
if diffmap_key
else _existing_preset_keys(adata, _diffmap_keys)
):
self._eigen_values = adata.uns[keys[1]]
self._eigen_basis = adata.obsm[keys[0]]
if n_dcs is not None:
if n_dcs > len(self._eigen_values):
msg = (
Expand Down
Loading
Loading