python/cudf/cudf/core/groupby/groupby.py (82 changes: 76 additions & 6 deletions)
@@ -449,6 +449,35 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         return False


+def _collect_series_key_column_names(obj, by) -> list[Hashable]:
+    """Identify, for each Series grouping key in ``by``, the name of the
+    corresponding column in ``obj`` whose underlying column object is
+    identical to the Series' column. Returns a list (one entry per Series
+    key, in order) of column names or ``None``. Non-Series keys produce no
+    entry. The check uses object identity to mirror pandas' behavior of
+    excluding such columns from aggregation values.
+
+    Only applies when ``obj`` is a DataFrame: for Series inputs, the single
+    column *is* the value column, so identity-based exclusion would empty
+    the aggregation result.
+    """
+    import cudf
+
+    result: list[Hashable] = []
+    if not isinstance(obj, cudf.DataFrame):
+        return result
+    by_list = by if isinstance(by, list) else [by]
+    for key in by_list:
+        if isinstance(key, cudf.Series):
+            matched = None
+            for col_name, col in obj._column_labels_and_values:
+                if col is key._column:
+                    matched = col_name
+                    break
+            result.append(matched)
+    return result
+
+
 class GroupBy(Serializable, Reducible, Scannable):
     obj: Series | DataFrame

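For orientation, a minimal sketch of the behavior this helper enables, assuming cudf's usual zero-copy column access (so ``df["a"]._column`` is the very object the frame holds); the frame and values below are invented:

    import cudf

    df = cudf.DataFrame({"a": [1, 1, 2], "b": [10.0, 20.0, 30.0]})

    # ``df["a"]`` shares its underlying column object with the frame, so
    # the identity check matches and "a" is excluded from the aggregated
    # values, leaving only "b" (mirroring pandas).
    print(df.groupby(df["a"]).sum())

    # A freshly built Series with equal values is a *different* column
    # object, so nothing is excluded and "a" stays among the values.
    print(df.groupby(cudf.Series([1, 1, 2], name="key")).sum())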
@@ -523,6 +552,11 @@ def __init__(
         dropna : bool, optional
             If True (default), do not include the "null" group.
         """
+        # Determine which column names in `obj` correspond to the grouping
+        # key Series by column identity (mirrors pandas' behavior).
+        # Must be done before ``nans_to_nulls`` which breaks identity.
+        by_series_col_names = _collect_series_key_column_names(obj, by)
+
         if get_option("mode.pandas_compatible"):
             obj = obj.nans_to_nulls()
         self.obj = obj
@@ -537,7 +571,9 @@ def __init__(
             self._by._obj = self.obj
             self.grouping = self._by
         else:
-            self.grouping = _Grouping(obj, self._by, level)
+            self.grouping = _Grouping(
+                obj, self._by, level, by_series_col_names
+            )

         self._groupby_manager = _GroupByContextManager(
             self.grouping, self._dropna
@@ -702,7 +738,8 @@ def size(self) -> Series:
             .groupby(self.grouping, sort=self._sort, dropna=self._dropna)
             .agg("size")
         )
-        if isinstance(getattr(self.obj, "dtype", None), pd.ArrowDtype):
+        obj_dtype = getattr(self.obj, "dtype", None)
+        if isinstance(obj_dtype, pd.ArrowDtype):
             # TODO: Remove once groupby.agg preserves pandas extension dtypes.
             arrow_dtype = pd.ArrowDtype(pa.int64())
             if isinstance(result, Series):
@@ -713,6 +750,23 @@ def size(self) -> Series:
                 result._data["size"] = ColumnBase.create(
                     result._data["size"].plc_column, arrow_dtype
                 )
+        elif (
+            isinstance(obj_dtype, pd.StringDtype)
+            and obj_dtype.storage == "pyarrow"
+            and obj_dtype.na_value is pd.NA
+        ):
+            # Series.groupby.size() on ``string[pyarrow]`` returns Int64.
+            int64_dtype = pd.Int64Dtype()
+            if isinstance(result, Series):
+                result = Series._from_column(
+                    ColumnBase.create(result._column.plc_column, int64_dtype),
+                    name=result.name,
+                    index=result.index,
+                )
+            elif "size" in result._column_names:
+                result._data["size"] = ColumnBase.create(
+                    result._data["size"].plc_column, int64_dtype
+                )
         if not self._as_index:
             result = result.rename("size").reset_index()
         return result
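A hedged illustration of the dtype this new branch targets (the Series below is invented): for NA-backed ``string[pyarrow]`` input, pandas reports a masked ``Int64`` size rather than an Arrow-backed int64, which is exactly the cast reproduced above:

    import pandas as pd

    s = pd.Series(["x", "x", "y"], dtype=pd.StringDtype(storage="pyarrow"))
    # Per the comment in the branch above, pandas gives Int64 here,
    # not ArrowDtype(int64()).
    print(s.groupby(s).size().dtype)  # Int64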
@@ -1083,9 +1137,14 @@ def agg(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
                 )
                 # Override for specific aggregation types that need dtype adjustments
                 if agg_kind in {"COUNT", "SIZE", "ARGMIN", "ARGMAX"}:
-                    cast_dtype = get_dtype_of_same_kind(
-                        orig_dtype, np.dtype(np.int64)
-                    )
+                    if isinstance(orig_dtype, pd.StringDtype):
+                        cast_dtype = np.dtype(np.int64)
+                    else:
+                        cast_dtype = get_dtype_of_same_kind(
+                            orig_dtype, np.dtype(np.int64)
+                        )
+                elif agg_kind == "NUNIQUE":
+                    cast_dtype = np.dtype(np.int64)
                 elif (
                     (
                         isinstance(agg_name, str)
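A small sketch of the pandas behavior this ``StringDtype`` special case matches (data invented): count-like results on a pandas string column come back as plain NumPy int64, so wrapping them in a same-kind extension dtype would be wrong:

    import pandas as pd

    df = pd.DataFrame(
        {
            "key": [1, 1, 2],
            "s": pd.array(["x", None, "y"], dtype="string[pyarrow]"),
        }
    )
    # count() of a StringDtype column is plain int64 in pandas, matching
    # the ``np.dtype(np.int64)`` branch above.
    print(df.groupby("key")["s"].count().dtype)  # int64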
@@ -3464,14 +3523,18 @@ def __init__(


 class _Grouping(Serializable):
-    def __init__(self, obj, by=None, level=None):
+    def __init__(self, obj, by=None, level=None, series_key_column_names=None):
         self._obj = obj
         self._key_columns = []
         self.names = []

         # Need to keep track of named key columns
         # to support `as_index=False` correctly
         self._named_columns = []
+        # For each Series-typed grouping key (in order), the name of the
+        # ``obj`` column that the Series' underlying column is identical
+        # to (or ``None`` if the Series is unrelated to any column).
+        self._series_key_column_names = list(series_key_column_names or [])
         self._handle_by_or_level(by, level)

         if len(obj) and not len(self._key_columns):
@@ -3553,6 +3616,13 @@ def _handle_series(self, by):
         by = by._align_to_index(self._obj.index, how="right")
         self._key_columns.append(by._column)
         self.names.append(by.name)
+        # Mirror pandas: if the grouping Series' underlying column was one
+        # of the obj's columns (identity checked before any transformation),
+        # exclude that column name from value columns during aggregation.
+        if self._series_key_column_names:
+            col_name = self._series_key_column_names.pop(0)
+            if col_name is not None:
+                self._named_columns.append(col_name)

     def _handle_index(self, by):
         self._key_columns.extend(by._columns)
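Because ``_handle_series`` runs once per Series key, in the same order the keys appeared in ``by``, the precomputed names pair up with those calls first-in first-out, hence the ``pop(0)``; a toy model of that hand-off (names invented):

    # One precomputed entry per Series key: the first key matched column
    # "a" by identity, the second key matched nothing.
    series_key_column_names = ["a", None]
    named_columns = []
    for _ in range(2):  # one _handle_series call per Series key
        col_name = series_key_column_names.pop(0)
        if col_name is not None:
            named_columns.append(col_name)
    print(named_columns)  # ['a']: "a" is dropped from the value columns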