Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,13 +1247,31 @@ def _reduce(

The numeric_only, min_count
"""
if min_count != 0:
raise NotImplementedError(
"min_count parameter is not implemented yet"
)
if numeric_only:
return self._reduce_numeric_only(op)
return self.agg(op)
result = self.agg(op)
if min_count and min_count > 0:
# Mask result rows where the per-group non-null count is less
# than ``min_count``.
counts = self.agg("count")
from cudf.core.dataframe import DataFrame
from cudf.core.series import Series

if isinstance(result, DataFrame):
for col_name in result._column_names:
if col_name not in counts._column_names:
continue
count_col = counts._data[col_name]
mask = count_col < min_count
result[col_name] = result[col_name].where(
~Series._from_column(mask), None
)
Comment thread
galipremsagar marked this conversation as resolved.
Outdated
elif isinstance(result, Series):
count_series = (
counts if isinstance(counts, Series) else counts.iloc[:, 0]
)
result = result.where(count_series >= min_count, None)
return result

def _scan(self, op: str, *args, **kwargs):
"""{op_name} for each group."""
Expand Down
25 changes: 25 additions & 0 deletions python/cudf/cudf/tests/groupby/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,3 +1189,28 @@ def test_string_groupby_key_index():
got = gdf.groupby("a", sort=True).count()

assert_eq(expect, got, check_dtype=False)


@pytest.mark.parametrize("op", ["sum", "min", "max", "first", "last"])
@pytest.mark.parametrize("min_count", [0, 1, 2, 3, 5])
def test_groupby_reduce_min_count(op, min_count):
pdf = pd.DataFrame(
{"a": [1, 1, 2, 2, 3], "b": [1.0, 2.0, 3.0, np.nan, 5.0]}
)
gdf = cudf.from_pandas(pdf)
with cudf.option_context("mode.pandas_compatible", True):
got = getattr(gdf.groupby("a"), op)(min_count=min_count)
expect = getattr(pdf.groupby("a"), op)(min_count=min_count)
assert_eq(expect, got)


@pytest.mark.parametrize("min_count", [0, 2, 3])
def test_groupby_series_reduce_min_count(min_count):
psr = pd.Series([1.0, 2.0, 3.0, 4.0, 5.0])
pkeys = pd.Series([1, 1, 2, 2, 3])
gsr = cudf.from_pandas(psr)
gkeys = cudf.from_pandas(pkeys)
with cudf.option_context("mode.pandas_compatible", True):
got = gsr.groupby(gkeys).sum(min_count=min_count)
expect = psr.groupby(pkeys).sum(min_count=min_count)
assert_eq(expect, got)
Loading