Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/4184.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix single-observation (or `n_obs == dof`) aggregated variance in {func}`scanpy.get.aggregate` {smaller}`Z Boldyga`
17 changes: 11 additions & 6 deletions src/scanpy/get/_aggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from scipy import sparse
from sklearn.utils.sparsefuncs import csc_median_axis_0

from scanpy._compat import CSBase, CSRBase, DaskArray
from scanpy._compat import CSBase, CSRBase, DaskArray, warn

from .._utils import _resolve_axis, get_literal_vals
from ._kernels import (
Expand Down Expand Up @@ -148,7 +148,12 @@ def mean_var(self, dof: int = 1) -> tuple[np.ndarray, np.ndarray]:
mean_var_csr if isinstance(self.data, CSRBase) else mean_var_csc
)(self.indicator_matrix, self.data)
if dof != 0:
var_ *= (group_counts / np.maximum(group_counts - dof, 1))[:, np.newaxis]
denom = np.where(group_counts > dof, group_counts - dof, np.nan)
which_nan = np.isnan(denom)
if which_nan.any():
msg = f"Group counts matches dof, resulting var for groups {self.groupby.categories[which_nan].to_list()} will be nan"
warn(msg, RuntimeWarning)
var_ *= (group_counts / denom)[:, np.newaxis]
return mean_, var_

def median(self) -> Array:
Expand Down Expand Up @@ -427,7 +432,7 @@ def per_block_row(
counts = combined[0]
mean_ = combined[1]
m2 = combined[2]
denom = counts - dof if dof > 0 else counts
denom = da.where(counts > dof, counts - dof, np.nan) if dof > 0 else counts
return MeanVarDict(mean=mean_, var=m2 / denom)


Expand All @@ -452,9 +457,9 @@ def _block_moments(
return out

agg = Aggregate(groupby=by, data=data, mask=mask)
mean_, var_ = agg.mean_var()
# M2 is the variance times the counts directly minus correction.
m2 = var_ * (counts - 1)[:, np.newaxis]
mean_, var_ = agg.mean_var(dof=0)
# M2 (sum of squared deviations) is the population variance times the count.
m2 = var_ * counts[:, np.newaxis]
out[1, :] = mean_
out[2, :] = m2
return out
Expand Down
53 changes: 53 additions & 0 deletions tests/test_aggregated.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import nullcontext
from typing import TYPE_CHECKING

import anndata as ad
Expand Down Expand Up @@ -630,3 +631,55 @@ def test_var_no_catastrophic_cancellation(array_type) -> None:
if isinstance(result, DaskArray):
result = result.compute()
np.testing.assert_allclose(result, expected, rtol=1e-4)


@needs.dask
@pytest.mark.parametrize("dof", [0, 1, 4])
def test_aggregate_var_group_matches_dof(dof: int) -> None:
# Guards that a one-observation (or n_obs==dof) group's variance is nan (not 0), and that a
# group split into single-observation chunks keeps its correct variance
# rather than being corrupted to nan by the dask per-chunk combine.
import dask.array as da

run_size = max(1, dof)
x = np.array(([1.0] * run_size) + ([2.0] * run_size) + ([3.0] * run_size)).reshape(
3 * run_size, 1
)
obs = pd.DataFrame(
{
"group": pd.Categorical(
(["a"] * run_size) + (["b"] * run_size) + (["a"] * run_size)
)
},
index=[f"cell_{i}" for i in range(x.shape[0])],
)
with (
pytest.warns(RuntimeWarning, match=r".*groups \['b'\] will be nan.*")
if dof > 0
else nullcontext()
):
in_memory = sc.get.aggregate(
ad.AnnData(X=x, obs=obs), "group", "var", dof=dof
).layers["var"]
# tests chunks that contain run_size item i.e., chunk matches dof
dask_x = da.from_array(x, chunks=(run_size, -1))
dask = (
sc.get
.aggregate(ad.AnnData(X=dask_x, obs=obs), "group", "var", dof=dof)
.layers["var"]
.compute()
)
# equal_nan=True by default
np.testing.assert_allclose(in_memory, dask)
var = dict(zip(["a", "b"], in_memory[:, 0], strict=True))
for cat in ["a", "b"]:
with (
pytest.warns(
RuntimeWarning, match=r"((Degrees of freedom.*)|(.*invalid value.*))"
)
if dof > 0 and cat == "b"
else nullcontext()
):
np.testing.assert_equal(
var[cat], np.var(x[(obs["group"] == cat).to_numpy()], ddof=dof)
)
Loading