diff --git a/python/cudf/cudf/core/groupby/groupby.py b/python/cudf/cudf/core/groupby/groupby.py
index 0ad04470a5f..239352bd99b 100644
--- a/python/cudf/cudf/core/groupby/groupby.py
+++ b/python/cudf/cudf/core/groupby/groupby.py
@@ -1247,13 +1247,26 @@ def _reduce(
 
             The numeric_only, min_count
         """
-        if min_count != 0:
-            raise NotImplementedError(
-                "min_count parameter is not implemented yet"
-            )
+        # A zero, negative or unset ``min_count`` imposes no minimum.
+        enforce_min_count = bool(min_count) and min_count > 0
         if numeric_only:
+            if enforce_min_count:
+                # This path returns before any masking is applied, so fail
+                # loudly instead of silently ignoring ``min_count``.
+                raise NotImplementedError(
+                    "min_count parameter is not implemented yet "
+                    "for numeric_only=True"
+                )
             return self._reduce_numeric_only(op)
-        return self.agg(op)
+        result = self.agg(op)
+        if enforce_min_count:
+            # Null out entries whose per-group "count" aggregation falls
+            # below ``min_count``.
+            # NOTE(review): assumes ``counts`` has the same index/columns as
+            # ``result``; confirm for ``as_index=False`` groupbys.
+            counts = self.agg("count")
+            result = result.where(counts >= min_count, None)
+        return result
 
     def _scan(self, op: str, *args, **kwargs):
         """{op_name} for each group."""
diff --git a/python/cudf/cudf/tests/groupby/test_reductions.py b/python/cudf/cudf/tests/groupby/test_reductions.py
index fc664bae59a..aa82f6ae026 100644
--- a/python/cudf/cudf/tests/groupby/test_reductions.py
+++ b/python/cudf/cudf/tests/groupby/test_reductions.py
@@ -1189,3 +1189,28 @@ def test_string_groupby_key_index():
     got = gdf.groupby("a", sort=True).count()
 
     assert_eq(expect, got, check_dtype=False)
+
+
+@pytest.mark.parametrize("op", ["sum", "min", "max", "first", "last"])
+@pytest.mark.parametrize("min_count", [0, 1, 2, 3, 5])
+def test_groupby_reduce_min_count(op, min_count):
+    pdf = pd.DataFrame(
+        {"a": [1, 1, 2, 2, 3], "b": [1.0, 2.0, 3.0, np.nan, 5.0]}
+    )
+    gdf = cudf.from_pandas(pdf)
+    with cudf.option_context("mode.pandas_compatible", True):
+        got = getattr(gdf.groupby("a"), op)(min_count=min_count)
+    expect = getattr(pdf.groupby("a"), op)(min_count=min_count)
+    assert_eq(expect, got)
+
+
+@pytest.mark.parametrize("min_count", [0, 2, 3])
+def test_groupby_series_reduce_min_count(min_count):
+    psr = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])
+    pkeys = pd.Series([1, 1, 2, 2, 3])
+    gsr = cudf.from_pandas(psr)
+    gkeys = cudf.from_pandas(pkeys)
+    with cudf.option_context("mode.pandas_compatible", True):
+        got = gsr.groupby(gkeys).sum(min_count=min_count)
+    expect = psr.groupby(pkeys).sum(min_count=min_count)
+    assert_eq(expect, got)