Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions cpp/src/groupby/sort/group_merge_m2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,23 @@ struct merge_fn {
if (partial_n == 0) { continue; }
auto const partial_avg = d_means[idx];
auto const partial_m2 = d_M2s[idx];
auto const new_n = n + partial_n;
auto const delta = partial_avg - avg;
m2 += partial_m2 + delta * delta * n * partial_n / new_n;
avg = (avg * n + partial_avg * partial_n) / new_n;
n = new_n;

// Merging an empty accumulator with a non-empty partial is an identity operation. Running
// the generic formula for this case can evaluate inf * 0 and turn extreme finite partials
// into NaN.
if (n == 0) {
Copy link
Copy Markdown
Contributor

@pmattione-nvidia pmattione-nvidia May 7, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but what if the input mean is literally infinity, or a NaN? Then it should return NaN, right? You should also check std::isfinite() here. Or am I misunderstanding what merge M2 is trying to do?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Walking through the cases:

partial_avg = +Inf (with partial_n > 0): the identity branch propagates avg = +Inf and m2 = partial_m2 as-is. The old generic path produced NaN here via delta * delta * n * partial_n / new_n = +Inf * +Inf * 0 * partial_n / new_n, which evaluates Inf * 0 = NaN — the same Inf * 0 = NaN side effect this PR is fixing. Propagating +Inf preserves the upstream "overflowed" signal; coercing to NaN would discard it.

partial_avg = NaN: identity sets avg = NaN; any subsequent merge step propagates NaN through the generic formula (NaN ⊕ anything = NaN). Final result is NaN regardless of partial position, as expected.

In practice Spark's CentralMomentAgg doesn't emit (count, +Inf, m2_finite) partials — Welford hits +Inf - +Inf = NaN on the first overflowing row, so the partial becomes (count, NaN, NaN). So the "+Inf avg" case really only shows up for direct callers of MERGE_M2 with hand-crafted partials, and for those propagation is strictly more informative than coercion.

I pushed 5d917711 (now 071266d after rebase) with regression tests pinning these semantics: NanMeanFirstPartial, InfMeanFirstPartial, and NanMeanMergedWithFinite for both INT64 and FLOAT64 count types. Let me know if there's a Spark scenario where NaN coercion is actually wanted — I'm not seeing one.

n = partial_n;
avg = partial_avg;
m2 = partial_m2;
continue;
}

auto const new_n = n + partial_n;
auto const delta = partial_avg - avg;
auto const delta_n = delta / new_n;
m2 += partial_m2 + delta * delta_n * n * partial_n;
avg += delta_n * partial_n;
n = new_n;
}

return {n, avg, m2};
Expand Down
160 changes: 160 additions & 0 deletions cpp/tests/groupby/merge_m2_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#include <cudf/table/table_view.hpp>
#include <cudf/unary.hpp>

#include <limits>

using namespace cudf::test::iterators;

namespace {
Expand Down Expand Up @@ -87,11 +89,119 @@ auto merge_M2(vcol_views const& keys_cols, vcol_views const& values_cols)
auto result = gb_obj.aggregate(requests);
return std::pair(std::move(result.first->release()[0]), std::move(result.second[0].results[0]));
}

template <typename CountType>
void test_extreme_finite_first_partial()
{
auto const key = keys_col<int32_t>{1};
auto counts = cudf::test::fixed_width_column_wrapper<CountType>{CountType{1}};
auto means = means_col<double>{std::numeric_limits<double>::max()};
auto m2s = M2s_col<double>{0.0};
auto const vals = structs_col{counts, means, m2s};

auto const [out_key, out_vals] = merge_M2({key}, {vals});

auto const expected_keys = keys_col<int32_t>{1};
auto expected_counts = cudf::test::fixed_width_column_wrapper<CountType>{CountType{1}};
auto expected_means = means_col<double>{std::numeric_limits<double>::max()};
auto expected_m2s = M2s_col<double>{0.0};
auto const expected_values = structs_col{expected_counts, expected_means, expected_m2s};
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_keys, *out_key, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_values, *out_vals, verbosity);
}

template <typename CountType>
void test_extreme_finite_merged_partials()
{
auto const keys = keys_col<int32_t>{1, 1};
auto counts = cudf::test::fixed_width_column_wrapper<CountType>{CountType{1}, CountType{1}};
auto means = means_col<double>{std::numeric_limits<double>::max(), 0.0};
auto m2s = M2s_col<double>{0.0, 0.0};
auto const vals = structs_col{counts, means, m2s};

auto const [out_keys, out_vals] = merge_M2({keys}, {vals});

auto const expected_keys = keys_col<int32_t>{1};
auto expected_counts = cudf::test::fixed_width_column_wrapper<CountType>{CountType{2}};
auto expected_means = means_col<double>{std::numeric_limits<double>::max() / 2};
auto expected_m2s = M2s_col<double>{std::numeric_limits<double>::infinity()};
auto const expected_values = structs_col{expected_counts, expected_means, expected_m2s};
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_keys, *out_keys, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_values, *out_vals, verbosity);
}

// A non-finite mean in a partial must propagate, not be coerced. This pins the
// identity-branch behavior: a single non-empty partial with NaN/Inf statistics
// should be preserved as-is, since coercing to NaN would discard upstream
// signal (e.g. an upstream overflow that already produced +Inf).
template <typename CountType>
void test_nan_mean_first_partial()
{
auto const key = keys_col<int32_t>{1};
auto counts = cudf::test::fixed_width_column_wrapper<CountType>{CountType{1}};
auto means = means_col<double>{std::numeric_limits<double>::quiet_NaN()};
auto m2s = M2s_col<double>{std::numeric_limits<double>::quiet_NaN()};
auto const vals = structs_col{counts, means, m2s};

auto const [out_key, out_vals] = merge_M2({key}, {vals});

auto const expected_keys = keys_col<int32_t>{1};
auto expected_counts = cudf::test::fixed_width_column_wrapper<CountType>{CountType{1}};
auto expected_means = means_col<double>{std::numeric_limits<double>::quiet_NaN()};
auto expected_m2s = M2s_col<double>{std::numeric_limits<double>::quiet_NaN()};
auto const expected_values = structs_col{expected_counts, expected_means, expected_m2s};
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_keys, *out_key, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_values, *out_vals, verbosity);
}

template <typename CountType>
void test_inf_mean_first_partial()
{
auto const key = keys_col<int32_t>{1};
auto counts = cudf::test::fixed_width_column_wrapper<CountType>{CountType{1}};
auto means = means_col<double>{std::numeric_limits<double>::infinity()};
auto m2s = M2s_col<double>{0.0};
auto const vals = structs_col{counts, means, m2s};

auto const [out_key, out_vals] = merge_M2({key}, {vals});

auto const expected_keys = keys_col<int32_t>{1};
auto expected_counts = cudf::test::fixed_width_column_wrapper<CountType>{CountType{1}};
auto expected_means = means_col<double>{std::numeric_limits<double>::infinity()};
auto expected_m2s = M2s_col<double>{0.0};
auto const expected_values = structs_col{expected_counts, expected_means, expected_m2s};
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_keys, *out_key, verbosity);
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_values, *out_vals, verbosity);
}

/**
 * @brief Merge a NaN partial with a finite partial under the same key, in both input orders, and
 * check that the merged mean and M2 are NaN either way.
 *
 * The two partials are (count = 10, mean = NaN, M2 = NaN) and (count = 10, mean = 5, M2 = 20).
 * Running the NaN partial first and then second exercises the "regardless of partial order"
 * property instead of only the NaN-first ordering. The expected output in both cases is
 * (count = 20, mean = NaN, M2 = NaN).
 *
 * @tparam CountType Element type of the count column in the partial struct
 */
template <typename CountType>
void test_nan_mean_merged_with_finite()
{
  auto constexpr nan_value = std::numeric_limits<double>::quiet_NaN();

  // Merges two 10-element partials given in the stated order and checks for a NaN result.
  auto const merge_and_check =
    [](double first_mean, double first_m2, double second_mean, double second_m2) {
      auto constexpr expected_nan = std::numeric_limits<double>::quiet_NaN();

      auto const keys = keys_col<int32_t>{1, 1};
      auto counts =
        cudf::test::fixed_width_column_wrapper<CountType>{CountType{10}, CountType{10}};
      auto means      = means_col<double>{first_mean, second_mean};
      auto m2s        = M2s_col<double>{first_m2, second_m2};
      auto const vals = structs_col{counts, means, m2s};

      auto const [out_keys, out_vals] = merge_M2({keys}, {vals});

      auto const expected_keys = keys_col<int32_t>{1};
      auto expected_counts = cudf::test::fixed_width_column_wrapper<CountType>{CountType{20}};
      auto expected_means  = means_col<double>{expected_nan};
      auto expected_m2s    = M2s_col<double>{expected_nan};
      auto const expected_values = structs_col{expected_counts, expected_means, expected_m2s};
      CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_keys, *out_keys, verbosity);
      CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(expected_values, *out_vals, verbosity);
    };

  // NaN partial listed first, finite partial second.
  merge_and_check(nan_value, nan_value, 5.0, 20.0);
  // Finite partial listed first, NaN partial second.
  merge_and_check(5.0, 20.0, nan_value, nan_value);
}
} // namespace

// Fixture for the typed merge-M2 tests; instantiated once per entry of TestTypes below.
template <typename T>
struct GroupbyMergeM2TypedTest : cudf::test::BaseFixture {};

// Untyped fixture used by the TEST_F cases that exercise extreme / non-finite partials.
struct GroupbyMergeM2ExtremeTest : cudf::test::BaseFixture {};

// Signed integer widths plus the floating-point types provided by cudf::test.
using TestTypes = cudf::test::Concat<cudf::test::Types<int8_t, int16_t, int32_t, int64_t>,
                                     cudf::test::FloatingPointTypes>;
TYPED_TEST_SUITE(GroupbyMergeM2TypedTest, TestTypes);
Expand Down Expand Up @@ -145,6 +255,56 @@ TYPED_TEST(GroupbyMergeM2TypedTest, EmptyInput)
CUDF_TEST_EXPECT_COLUMNS_EQUIVALENT(vals, *out_vals, verbosity);
}

// Each scenario below runs twice: once with an int64_t count column and once with a double one.

// Single partial whose mean is the largest finite double.
TEST_F(GroupbyMergeM2ExtremeTest, ExtremeFiniteFirstPartialInt64Count)
{
  using count_type = int64_t;
  test_extreme_finite_first_partial<count_type>();
}

TEST_F(GroupbyMergeM2ExtremeTest, ExtremeFiniteFirstPartialDoubleCount)
{
  using count_type = double;
  test_extreme_finite_first_partial<count_type>();
}

// Two partials: largest finite double merged with zero.
TEST_F(GroupbyMergeM2ExtremeTest, ExtremeFiniteMergedPartialsInt64Count)
{
  using count_type = int64_t;
  test_extreme_finite_merged_partials<count_type>();
}

TEST_F(GroupbyMergeM2ExtremeTest, ExtremeFiniteMergedPartialsDoubleCount)
{
  using count_type = double;
  test_extreme_finite_merged_partials<count_type>();
}

// Single partial with NaN mean and NaN M2.
TEST_F(GroupbyMergeM2ExtremeTest, NanMeanFirstPartialInt64Count)
{
  using count_type = int64_t;
  test_nan_mean_first_partial<count_type>();
}

TEST_F(GroupbyMergeM2ExtremeTest, NanMeanFirstPartialDoubleCount)
{
  using count_type = double;
  test_nan_mean_first_partial<count_type>();
}

// Single partial with a +infinity mean.
TEST_F(GroupbyMergeM2ExtremeTest, InfMeanFirstPartialInt64Count)
{
  using count_type = int64_t;
  test_inf_mean_first_partial<count_type>();
}

TEST_F(GroupbyMergeM2ExtremeTest, InfMeanFirstPartialDoubleCount)
{
  using count_type = double;
  test_inf_mean_first_partial<count_type>();
}

// NaN partial merged with a finite partial.
TEST_F(GroupbyMergeM2ExtremeTest, NanMeanMergedWithFiniteInt64Count)
{
  using count_type = int64_t;
  test_nan_mean_merged_with_finite<count_type>();
}

TEST_F(GroupbyMergeM2ExtremeTest, NanMeanMergedWithFiniteDoubleCount)
{
  using count_type = double;
  test_nan_mean_merged_with_finite<count_type>();
}

TYPED_TEST(GroupbyMergeM2TypedTest, SimpleInput)
{
using T = TypeParam;
Expand Down
Loading