Skip to content
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/4076.feat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add preset support to the `key_added` parameter of {func}`scanpy.tl.diffmap` and {func}`scanpy.tl.draw_graph` {smaller}`P Angerer`
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ filterwarnings = [
"ignore:FNV hashing is not implemented in Numba.*:UserWarning",
# we want to see and eventually fix these
"default::numba.core.errors.NumbaPerformanceWarning",
"default:.*TSNE.*random.*to.*pca:FutureWarning", # we should set init=obsm["X_pca"] or so
# we should set init=obsm["X_pca"] or so
"default:.*TSNE.*random.*to.*pca:FutureWarning",
# igraph vs leidenalg warning
"ignore:The `igraph` implementation of leiden clustering:UserWarning",
# everybody uses this zarr 3 feature, including us, XArray, lots of data out there …
Expand Down
22 changes: 21 additions & 1 deletion src/scanpy/_settings/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,14 @@ class HVGPreset(NamedTuple):
return_df: bool


class PcaPreset(NamedTuple):
class BasicEmbeddingPreset(NamedTuple):
    """Preset values for embedding tools whose only preset-controlled option is `key_added`."""

    # Key under which the tool stores its results.
    # `None` means "use the tool's legacy default keys"
    # (for PCA these are "X_pca" / "PCs" / "pca").
    key_added: str | None


# These presets currently share one shape; replace the aliases with dedicated classes once they diverge.
PcaPreset = DiffmapPreset = DrawGraphPreset = BasicEmbeddingPreset


class RankGenesGroupsPreset(NamedTuple):
method: DETest
mask_var: str | None
Expand Down Expand Up @@ -193,6 +197,22 @@ def pca() -> Mapping[Preset, PcaPreset]:
Preset.ScanpyV2Preview: PcaPreset(key_added="pca"),
}

@preset_property
def diffmap() -> Mapping[Preset, DiffmapPreset]:
    """Settings for :func:`~scanpy.tl.diffmap`."""  # noqa: D401
    # ScanpyV1 leaves `key_added` unset; the V2 preview stores results under "diffmap".
    v1_settings = DiffmapPreset(key_added=None)
    v2_preview_settings = DiffmapPreset(key_added="diffmap")
    return {
        Preset.ScanpyV1: v1_settings,
        Preset.ScanpyV2Preview: v2_preview_settings,
    }

@preset_property
def draw_graph() -> Mapping[Preset, DrawGraphPreset]:
    """Settings for :func:`~scanpy.tl.draw_graph`."""  # noqa: D401
    # ScanpyV1 leaves `key_added` unset; the V2 preview uses a template string.
    # NOTE(review): "{layout}" is stored literally here; presumably the caller
    # substitutes the layout name — confirm in `draw_graph`.
    v1_settings = DrawGraphPreset(key_added=None)
    v2_preview_settings = DrawGraphPreset(key_added="graph_{layout}")
    return {
        Preset.ScanpyV1: v1_settings,
        Preset.ScanpyV2Preview: v2_preview_settings,
    }

@preset_property
def rank_genes_groups() -> Mapping[Preset, RankGenesGroupsPreset]:
"""Correlation method for :func:`~scanpy.tl.rank_genes_groups`."""
Expand Down
13 changes: 13 additions & 0 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
"NeighborsView",
"_choose_graph",
"_doc_params",
"_existing_preset_keys",
"_numba_thread_limit",
"_resolve_axis",
"annotate_doc_types",
Expand Down Expand Up @@ -573,6 +574,18 @@ def get_literal_vals(typ: UnionType | TypeAliasType | Any) -> KeysView[Any]:
raise TypeError(msg)


def _existing_preset_keys[T: tuple[str, ...]](
adata: AnnData, keys: Callable[..., T]
) -> T | None:
from .._settings import Preset

for preset in (Preset.ScanpyV1, Preset.ScanpyV2Preview):
obsm_key, *rest = keys(preset)
if obsm_key in adata.obsm:
return (obsm_key, *rest)
return None


# --------------------------------------------------------------------------------
# Others
# --------------------------------------------------------------------------------
Expand Down
25 changes: 15 additions & 10 deletions src/scanpy/experimental/pp/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from anndata import AnnData

from ... import logging as logg
from ... import settings
from ..._compat import CSBase, warn
from ..._settings import Default
from ..._utils import _doc_params, check_nonnegative_integers, view_to_actual
Expand Down Expand Up @@ -207,19 +208,23 @@ def normalize_pearson_residuals_pca(
`.uns['pearson_residuals_normalization']['clip']`
The used value of the clipping parameter.

`.obsm['X_pca']`
`.obsm[kwargs_pca.get('key_added', 'X_pca')]`
PCA representation of data after gene selection (if applicable) and Pearson
residual normalization.
`.varm['PCs']`
`.varm[kwargs_pca.get('key_added', 'PCs')]`
The principal components containing the loadings. When `inplace=True` and
`mask_var is not None`, this will contain empty rows for the genes not
selected.
`.uns['pca']['variance_ratio']`
`.uns[kwargs_pca.get('key_added', 'pca')]['variance_ratio']`
Ratio of explained variance.
`.uns['pca']['variance']`
`.uns[kwargs_pca.get('key_added', 'pca')]['variance']`
Explained variance, equivalent to the eigenvalues of the covariance matrix.

"""
key_added = kwargs_pca.get("key_added", settings.preset.pca.key_added)
key_obsm, key_varm, key_uns = (
("X_pca", "PCs", "pca") if key_added is None else [key_added] * 3
)
if isinstance(mask_var, Default):
mask_var = "highly_variable" if "highly_variable" in adata.var else None
mask_var = _check_mask(adata, mask_var, "var")
Expand All @@ -236,19 +241,19 @@ def normalize_pearson_residuals_pca(
adata_pca, theta=theta, clip=clip, check_values=check_values
)
pca(adata_pca, n_comps=n_comps, rng=rng, **kwargs_pca)
n_comps = adata_pca.obsm["X_pca"].shape[1] # might be None
n_comps = adata_pca.obsm[key_obsm].shape[1] # might be None

if inplace:
norm_settings = adata_pca.uns["pearson_residuals_normalization"]
norm_dict = dict(**norm_settings, pearson_residuals_df=adata_pca.to_df())
if mask_var is not None:
adata.varm["PCs"] = np.zeros(shape=(adata.n_vars, n_comps))
adata.varm["PCs"][mask_var] = adata_pca.varm["PCs"]
adata.varm[key_varm] = np.zeros(shape=(adata.n_vars, n_comps))
adata.varm[key_varm][mask_var] = adata_pca.varm[key_varm]
else:
adata.varm["PCs"] = adata_pca.varm["PCs"]
adata.uns["pca"] = adata_pca.uns["pca"]
adata.varm[key_varm] = adata_pca.varm[key_varm]
adata.uns[key_uns] = adata_pca.uns[key_uns]
adata.uns["pearson_residuals_normalization"] = norm_dict
adata.obsm["X_pca"] = adata_pca.obsm["X_pca"]
adata.obsm[key_obsm] = adata_pca.obsm[key_obsm]
return None
else:
return adata_pca
22 changes: 13 additions & 9 deletions src/scanpy/experimental/pp/_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import numpy as np

from ... import experimental
from ... import experimental, settings
from ..._utils import _doc_params
from ..._utils.random import _accepts_legacy_random_state
from ...experimental._docs import (
Expand Down Expand Up @@ -103,18 +103,22 @@ def recipe_pearson_residuals( # noqa: PLR0913
`.uns['pearson_residuals_normalization']['clip']`
The used value of the clipping parameter.

`.obsm['X_pca']`
`.obsm[kwargs_pca.get('key_added', 'X_pca')]`
PCA representation of data after gene selection and Pearson residual
normalization.
`.varm['PCs']`
`.varm[kwargs_pca.get('key_added', 'PCs')]`
The principal components containing the loadings. When `inplace=True` this
will contain empty rows for the genes not selected during HVG selection.
`.uns['pca']['variance_ratio']`
`.uns[kwargs_pca.get('key_added', 'pca')]['variance_ratio']`
Ratio of explained variance.
`.uns['pca']['variance']`
`.uns[kwargs_pca.get('key_added', 'pca')]['variance']`
Explained variance, equivalent to the eigenvalues of the covariance matrix.

"""
key_added = kwargs_pca.get("key_added", settings.preset.pca.key_added)
key_obsm, key_varm, key_uns = (
("X_pca", "PCs", "pca") if key_added is None else [key_added] * 3
)
hvg_args = dict(
flavor="pearson_residuals",
n_top_genes=n_top_genes,
Expand Down Expand Up @@ -145,11 +149,11 @@ def recipe_pearson_residuals( # noqa: PLR0913
**normalization_param, pearson_residuals_df=adata_pca.to_df()
)

adata.uns["pca"] = adata_pca.uns["pca"]
adata.varm["PCs"] = np.zeros(shape=(adata.n_vars, n_comps))
adata.varm["PCs"][adata.var["highly_variable"]] = adata_pca.varm["PCs"]
adata.uns[key_uns] = adata_pca.uns[key_uns]
adata.varm[key_varm] = np.zeros(shape=(adata.n_vars, n_comps))
adata.varm[key_varm][adata.var["highly_variable"]] = adata_pca.varm[key_varm]
adata.uns["pearson_residuals_normalization"] = normalization_dict
adata.obsm["X_pca"] = adata_pca.obsm["X_pca"]
adata.obsm[key_obsm] = adata_pca.obsm[key_obsm]
return None
else:
return adata_pca, hvg
30 changes: 12 additions & 18 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .._compat import CSBase, CSRBase, SpBase, pkg_version, warn
from .._docs import doc_rng
from .._settings import settings
from .._utils import NeighborsView, _doc_params, get_literal_vals
from .._utils import NeighborsView, _doc_params, _existing_preset_keys, get_literal_vals
from .._utils.random import (
_accepts_legacy_random_state,
_legacy_random_state,
Expand Down Expand Up @@ -322,20 +322,6 @@ class FlatTree(NamedTuple): # noqa: D101
indices: None


def _backwards_compat_get_full_x_diffmap(adata: AnnData) -> np.ndarray:
    """Return the diffusion map basis, including a separately stored first column.

    If `adata.obs` has an `"X_diffmap0"` entry, its values are prepended as the
    first column of `adata.obsm["X_diffmap"]`; otherwise the stored array is
    returned as is.
    """
    if "X_diffmap0" not in adata.obs:
        return adata.obsm["X_diffmap"]
    # Turn the 1-d column into shape (n, 1) so it can be stacked column-wise.
    leading_column = adata.obs["X_diffmap0"].values[:, None]
    return np.c_[leading_column, adata.obsm["X_diffmap"]]


def _backwards_compat_get_full_eval(adata: AnnData):
    """Return the diffusion map eigenvalues, including a leading 1 for old layouts.

    If `adata.obs` has an `"X_diffmap0"` entry, a `1` is prepended to
    `adata.uns["diffmap_evals"]`; otherwise the stored values are returned as is.
    """
    if "X_diffmap0" not in adata.obs:
        return adata.uns["diffmap_evals"]
    return np.r_[1, adata.uns["diffmap_evals"]]


def _make_forest_dict(forest):
d = {}
props = ("hyperplanes", "offsets", "children", "indices")
Expand Down Expand Up @@ -432,6 +418,7 @@ def __init__( # noqa: PLR0912, PLR0915
*,
n_dcs: int | None = None,
neighbors_key: str | None = None,
diffmap_key: str | None = None,
) -> None:
self._adata = adata
self._init_iroot()
Expand Down Expand Up @@ -484,9 +471,16 @@ def count_nonzero(a: np.ndarray | CSRBase) -> int:

self._connected_components = connected_components(self._connectivities)
self._number_connected_components = self._connected_components[0]
if "X_diffmap" in adata.obsm:
self._eigen_values = _backwards_compat_get_full_eval(adata)
self._eigen_basis = _backwards_compat_get_full_x_diffmap(adata)

from ..tools._dpt import _diffmap_keys

if keys := (
(diffmap_key, diffmap_key)
if diffmap_key
else _existing_preset_keys(adata, _diffmap_keys)
):
self._eigen_values = adata.uns[keys[1]]
self._eigen_basis = adata.obsm[keys[0]]
if n_dcs is not None:
if n_dcs > len(self._eigen_values):
msg = (
Expand Down
6 changes: 4 additions & 2 deletions src/scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
from matplotlib.markers import MarkerStyle
from scverse_misc import Deprecation, deprecated

from scanpy.preprocessing._pca import _pca_keys

from ... import logging as logg
from ..._settings import Default, settings
from ..._utils import _doc_params, sanitize_anndata
from ..._utils import _doc_params, _existing_preset_keys, sanitize_anndata
from ..._utils._doctests import doctest_internet
from ...get import _check_mask
from .. import _utils
Expand Down Expand Up @@ -925,7 +927,7 @@ def pca(
return embedding(
adata, "pca", show=show, return_fig=return_fig, save=save, **kwargs
)
if "pca" not in adata.obsm and "X_pca" not in adata.obsm:
if not _existing_preset_keys(adata, _pca_keys):
msg = (
f"Could not find entry in `obsm` for 'pca'.\n"
f"Available keys are: {list(adata.obsm.keys())}."
Expand Down
22 changes: 16 additions & 6 deletions src/scanpy/preprocessing/_pca/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ... import logging as logg
from ..._compat import CSBase, DaskArray, warn
from ..._docs import doc_rng
from ..._settings import Default, settings
from ..._settings import Default, Preset, settings
from ..._utils import _doc_params, get_literal_vals, is_backed_type
from ..._utils.random import _accepts_legacy_random_state, _legacy_random_state
from ...get import _check_mask, _get_obs_rep
Expand Down Expand Up @@ -47,6 +47,20 @@
type SvdSolver = SvdSolvDaskML | SvdSolvSkearn | SvdSolvPCACustom


def _pca_keys(
    key_added: str | None | Default | Preset = Default(),
) -> tuple[str, str, str]:
    """Resolve `key_added` to the `(obsm, varm, uns)` keys used for PCA results.

    Parameters
    ----------
    key_added
        A `Default` instance is replaced by the current `settings.preset`.
        A `Preset` is replaced by that preset's `pca.key_added`.
        `None` selects the legacy keys; any other value is used for all three.

    Returns
    -------
    `("X_pca", "PCs", "pca")` when the resolved key is `None`,
    otherwise the resolved key repeated three times.
    """
    resolved = settings.preset if isinstance(key_added, Default) else key_added
    if isinstance(resolved, Preset):
        resolved = resolved.pca.key_added
    if resolved is None:
        return ("X_pca", "PCs", "pca")
    return (resolved,) * 3


@_doc_params(mask_var=doc_mask_var, rng=doc_rng)
@_accepts_legacy_random_state(0)
def pca( # noqa: PLR0912, PLR0913, PLR0915
Expand Down Expand Up @@ -203,8 +217,6 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915
# Current chunking implementation relies on pca being called on X
msg = "Cannot use `layer`/`obsm` and `chunked` at the same time."
raise NotImplementedError(msg)
if isinstance(key_added, Default):
key_added = settings.preset.pca.key_added

# chunked calculation is not randomized, anyways
if svd_solver in {"auto", "randomized"} and not chunked:
Expand Down Expand Up @@ -337,9 +349,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915
x_pca = x_pca.astype(dtype)

if return_anndata:
key_obsm, key_varm, key_uns = (
("X_pca", "PCs", "pca") if key_added is None else [key_added] * 3
)
key_obsm, key_varm, key_uns = _pca_keys(key_added)
adata.obsm[key_obsm] = x_pca

if obsm:
Expand Down
16 changes: 13 additions & 3 deletions src/scanpy/tools/_diffmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

from .._docs import doc_rng
from .._settings import Default
from .._utils import _doc_params
from .._utils.random import _accepts_legacy_random_state
from ._dpt import _diffmap
Expand All @@ -22,6 +23,7 @@ def diffmap(
n_comps: int = 15,
*,
neighbors_key: str | None = None,
key_added: str | None | Default = Default(preset=("diffmap", "key_added")),
rng: SeedLike | RNGLike | None = None,
copy: bool = False,
) -> AnnData | None:
Expand Down Expand Up @@ -55,6 +57,8 @@ def diffmap(
.obsp[.uns[neighbors_key]['connectivities_key']] and
.obsp[.uns[neighbors_key]['distances_key']] for connectivities and distances,
respectively.
key_added
Control where the embedding and eigenvalues are stored.
{rng}
copy
Return a copy instead of writing to adata.
Expand All @@ -63,11 +67,11 @@ def diffmap(
-------
Returns `None` if `copy=False`, else returns an `AnnData` object. Sets the following fields:

`adata.obsm['X_diffmap']` : :class:`numpy.ndarray` (dtype `float`)
`adata.obsm['X_diffmap' | key_added]` : :class:`numpy.ndarray` (dtype `float`)
Diffusion map representation of data, which is the right eigen basis of
the transition matrix with eigenvectors as columns.

`adata.uns['diffmap_evals']` : :class:`numpy.ndarray` (dtype `float`)
`adata.uns['diffmap_evals' | key_added]` : :class:`numpy.ndarray` (dtype `float`)
Array of size (number of eigen vectors).
Eigenvalues of transition matrix.

Expand All @@ -90,5 +94,11 @@ def diffmap(
msg = "Provide any value greater than 2 for `n_comps`. "
raise ValueError(msg)
adata = adata.copy() if copy else adata
_diffmap(adata, n_comps=n_comps, neighbors_key=neighbors_key, rng=rng)
_diffmap(
adata,
n_comps=n_comps,
neighbors_key=neighbors_key,
key_added=key_added,
rng=rng,
)
return adata if copy else None
Loading
Loading