diff --git a/java/src/main/java/ai/rapids/cudf/Aggregation.java b/java/src/main/java/ai/rapids/cudf/Aggregation.java index 1babf17798f..1cf02945c1e 100644 --- a/java/src/main/java/ai/rapids/cudf/Aggregation.java +++ b/java/src/main/java/ai/rapids/cudf/Aggregation.java @@ -1,6 +1,6 @@ /* * - * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 * */ @@ -61,7 +61,9 @@ enum Kind { MERGE_TDIGEST(33), // This can take a delta argument for accuracy level HISTOGRAM(34), MERGE_HISTOGRAM(35), - BITWISE_AGG(36); + BITWISE_AGG(36), + ORDERING_FOR_MIN_BY(37), + VALUE_FOR_MIN_BY(38); final int nativeId; @@ -1047,12 +1049,91 @@ static BitXorAggregation bitXor() { return new BitXorAggregation(); } + static Aggregation orderingForMinBy(long multiInputId) { + return new OrderingForMinByAgg(multiInputId); + } + + static Aggregation valueForMinBy(long multiInputId) { + return new ValueForMinByAgg(multiInputId); + } + + /** + * Multi-input aggregation for ordering column in min_by operation. + */ + private static final class OrderingForMinByAgg extends Aggregation { + private final long multiInputId; + + OrderingForMinByAgg(long multiInputId) { + super(Kind.ORDERING_FOR_MIN_BY); + this.multiInputId = multiInputId; + } + + @Override + long createNativeInstance() { + return Aggregation.createMultiInputAgg(kind.nativeId, multiInputId); + } + + @Override + public int hashCode() { + return 31 * kind.hashCode() + Long.hashCode(multiInputId); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other instanceof OrderingForMinByAgg) { + OrderingForMinByAgg o = (OrderingForMinByAgg) other; + return o.multiInputId == this.multiInputId; + } + return false; + } + } + + /** + * Multi-input aggregation for value column in min_by operation. + */ + private static final class ValueForMinByAgg extends Aggregation { + private final long multiInputId; + + ValueForMinByAgg(long multiInputId) { + super(Kind.VALUE_FOR_MIN_BY); + this.multiInputId = multiInputId; + } + + @Override + long createNativeInstance() { + return Aggregation.createMultiInputAgg(kind.nativeId, multiInputId); + } + + @Override + public int hashCode() { + return 31 * kind.hashCode() + Long.hashCode(multiInputId); + } + + @Override + public boolean equals(Object other) { + if (this == other) { + return true; + } else if (other instanceof ValueForMinByAgg) { + ValueForMinByAgg o = (ValueForMinByAgg) other; + return o.multiInputId == this.multiInputId; + } + return false; + } + } + /** * Create one of the aggregations that only needs a kind, no other parameters. This does not * work for all types and for code safety reasons each kind is added separately. */ private static native long createNoParamAgg(int kind); + /** + * Create a multi-input aggregation with a correlation ID. + */ + private static native long createMultiInputAgg(int kind, long multiInputId); + /** * Create an nth aggregation. */ diff --git a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java index e616b82d7fe..400cff0739b 100644 --- a/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java +++ b/java/src/main/java/ai/rapids/cudf/GroupByAggregation.java @@ -1,6 +1,6 @@ /* * - * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 * */ @@ -148,6 +148,26 @@ public static GroupByAggregation standardDeviation() { return new GroupByAggregation(Aggregation.standardDeviation()); } + /** + * Create an ordering column aggregation for min_by operation. + * This must be paired with a valueForMinBy aggregation using the same multiInputId. + * @param multiInputId Correlation ID to pair with the corresponding value aggregation + * @return GroupByAggregation for the ordering column in min_by + */ + public static GroupByAggregation orderingForMinBy(long multiInputId) { + return new GroupByAggregation(Aggregation.orderingForMinBy(multiInputId)); + } + + /** + * Create a value column aggregation for min_by operation. + * This must be paired with an orderingForMinBy aggregation using the same multiInputId. + * @param multiInputId Correlation ID to pair with the corresponding ordering aggregation + * @return GroupByAggregation for the value column in min_by + */ + public static GroupByAggregation valueForMinBy(long multiInputId) { + return new GroupByAggregation(Aggregation.valueForMinBy(multiInputId)); + } + /** * Standard deviation aggregation. * @param ddof delta degrees of freedom. The divisor used in calculation of std is diff --git a/java/src/main/java/ai/rapids/cudf/MultiInputIds.java b/java/src/main/java/ai/rapids/cudf/MultiInputIds.java new file mode 100644 index 00000000000..dfefd27cd5a --- /dev/null +++ b/java/src/main/java/ai/rapids/cudf/MultiInputIds.java @@ -0,0 +1,30 @@ +/* + * + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package ai.rapids.cudf; + +import java.util.concurrent.atomic.AtomicLong; + +/** + * Utility class for generating unique correlation IDs for multi-input aggregations. + * These IDs are used to correlate multiple role-tagged aggregation instances that + * belong to the same logical multi-input operation (e.g., min_by, max_by). + */ +public final class MultiInputIds { + private static final AtomicLong COUNTER = new AtomicLong(); + + /** + * Generate the next unique multi-input correlation ID. + * @return a unique long value to correlate multi-input aggregation roles + */ + public static long next() { + return COUNTER.incrementAndGet(); + } + + private MultiInputIds() { + } +} diff --git a/java/src/main/native/src/AggregationJni.cpp b/java/src/main/native/src/AggregationJni.cpp index 24bed3a36eb..3d8931e1d42 100644 --- a/java/src/main/native/src/AggregationJni.cpp +++ b/java/src/main/native/src/AggregationJni.cpp @@ -1,8 +1,9 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ +#include "MultiInputAggregation.hpp" #include "cudf_jni_apis.hpp" #include @@ -301,6 +302,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createMergeSetsAgg(JNIEn JNI_CATCH(env, 0); } +JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createMultiInputAgg(JNIEnv* env, + jclass class_object, + jint kind, + jlong multi_input_id) +{ + JNI_TRY + { + cudf::jni::auto_set_device(env); + JNI_ARG_CHECK( + env, cudf::jni::is_multi_input_role(kind), "invalid multi-input aggregation kind", 0); + auto agg = std::make_unique( + static_cast(kind), multi_input_id); + return reinterpret_cast(agg.release()); + } + JNI_CATCH(env, 0); +} + JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createHostUDFAgg(JNIEnv* env, jclass class_object, jlong udf_native_handle) diff --git a/java/src/main/native/src/MultiInputAggregation.hpp b/java/src/main/native/src/MultiInputAggregation.hpp new file mode 100644 index 00000000000..da5758f24ad --- /dev/null +++ b/java/src/main/native/src/MultiInputAggregation.hpp @@ -0,0 +1,69 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +#include +#include +#include + +namespace cudf::jni { + +/** + * @brief JNI-private roles used to pair columns for multi-input aggregations. + * + * These values must match Aggregation.Kind native ids in Java. + */ +enum class multi_input_role : int32_t { + ORDERING_FOR_MIN_BY = 37, + VALUE_FOR_MIN_BY = 38, +}; + +inline bool is_multi_input_role(int32_t kind) +{ + return kind == static_cast(multi_input_role::ORDERING_FOR_MIN_BY) || + kind == static_cast(multi_input_role::VALUE_FOR_MIN_BY); +} + +/** + * @brief JNI holder for role-tagged multi-input groupby aggregations. + * + * This aggregation is never dispatched to libcudf directly. TableJni.cpp consumes + * it, pairs it with siblings sharing the same id, and creates the actual libcudf + * aggregation requests. + */ +class multi_input_aggregation final : public cudf::groupby_aggregation { + public: + multi_input_role const role; + int64_t const multi_input_id; + + multi_input_aggregation(multi_input_role role, int64_t multi_input_id) + : cudf::aggregation{cudf::aggregation::Kind::ARGMIN}, role{role}, multi_input_id{multi_input_id} + { + } + + [[nodiscard]] std::unique_ptr clone() const override + { + return std::make_unique(*this); + } + + [[nodiscard]] bool is_equal(cudf::aggregation const& other) const override + { + auto const* other_multi_input = dynamic_cast(&other); + return other_multi_input != nullptr && role == other_multi_input->role && + multi_input_id == other_multi_input->multi_input_id; + } + + [[nodiscard]] size_t do_hash() const override + { + auto const role_hash = std::hash{}(static_cast(role)); + auto const id_hash = std::hash{}(multi_input_id); + return role_hash ^ (id_hash + 0x9e3779b97f4a7c15ULL + (role_hash << 6) + (role_hash >> 2)); + } +}; + +} // namespace cudf::jni diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp index 65484bbb508..5bf2558fc20 100644 --- a/java/src/main/native/src/TableJni.cpp +++ b/java/src/main/native/src/TableJni.cpp @@ -3,6 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ +#include "MultiInputAggregation.hpp" #include "csv_chunked_writer.hpp" #include "cudf_jni_apis.hpp" #include "dtype_utils.hpp" @@ -55,6 +56,8 @@ #include #include +#include +#include namespace cudf { namespace jni { @@ -1091,6 +1094,8 @@ void set_nullable(ArrowSchema* schema) } // namespace cudf using cudf::jni::convert_table_for_return; +using cudf::jni::multi_input_aggregation; +using cudf::jni::multi_input_role; using cudf::jni::ptr_as_jlong; using cudf::jni::release_as_jlong; @@ -3896,38 +3901,106 @@ Java_ai_rapids_cudf_Table_groupByAggregate(JNIEnv* env, // as we go. std::vector requests; + struct min_by_group { + cudf::column_view ordering_col; + cudf::column_view value_col; + int ordering_slot = -1; + int value_slot = -1; + int request_index = -1; + }; + + std::unordered_map min_by_groups; + std::vector> regular_result_slots; + int previous_index = -1; for (int i = 0; i < n_values.size(); i++) { - cudf::groupby::aggregation_request req; - int col_index = n_values[i]; - cudf::groupby_aggregation* agg = dynamic_cast(n_agg_instances[i]); JNI_ARG_CHECK( env, agg != nullptr, "aggregation is not an instance of groupby_aggregation", nullptr); + + if (auto* multi_agg = dynamic_cast(agg)) { + auto& group = min_by_groups[multi_agg->multi_input_id]; + if (multi_agg->role == multi_input_role::ORDERING_FOR_MIN_BY) { + JNI_ARG_CHECK(env, + group.ordering_slot == -1, + "multiple ordering columns for min_by multi-input id", + nullptr); + group.ordering_col = n_input_table->column(n_values[i]); + group.ordering_slot = i; + } else if (multi_agg->role == multi_input_role::VALUE_FOR_MIN_BY) { + JNI_ARG_CHECK(env, + group.value_slot == -1, + "multiple value columns for min_by multi-input id", + nullptr); + group.value_col = n_input_table->column(n_values[i]); + group.value_slot = i; + } else { + JNI_ARG_CHECK(env, false, "unsupported multi-input aggregation role", nullptr); + } + continue; + } + + int col_index = n_values[i]; std::unique_ptr cloned( dynamic_cast(agg->clone().release())); - if (col_index == previous_index) { - requests.back().aggregations.push_back(std::move(cloned)); - } else { + if (requests.empty() || col_index != previous_index) { + cudf::groupby::aggregation_request req; req.values = n_input_table->column(col_index); - req.aggregations.push_back(std::move(cloned)); requests.push_back(std::move(req)); + previous_index = col_index; } - previous_index = col_index; + + int const request_index = static_cast(requests.size() - 1); + int const aggregation_index = static_cast(requests.back().aggregations.size()); + requests.back().aggregations.push_back(std::move(cloned)); + regular_result_slots.emplace_back(i, request_index, aggregation_index); + } + + for (auto& group_entry : min_by_groups) { + auto& group = group_entry.second; + JNI_ARG_CHECK(env, + group.ordering_slot != -1, + "missing ordering column for min_by multi-input id", + nullptr); + JNI_ARG_CHECK( + env, group.value_slot != -1, "missing value column for min_by multi-input id", nullptr); + + cudf::groupby::aggregation_request req; + req.values = group.ordering_col; + req.aggregations.push_back(cudf::make_argmin_aggregation()); + group.request_index = static_cast(requests.size()); + requests.push_back(std::move(req)); } std::pair, std::vector> result = grouper.aggregate(requests); - std::vector> result_columns; - int agg_result_size = result.second.size(); - for (int agg_result_index = 0; agg_result_index < agg_result_size; agg_result_index++) { - int col_agg_size = result.second[agg_result_index].results.size(); - for (int col_agg_index = 0; col_agg_index < col_agg_size; col_agg_index++) { - result_columns.push_back(std::move(result.second[agg_result_index].results[col_agg_index])); - } + std::vector> result_columns(n_values.size()); + + for (auto const& [slot, request_index, aggregation_index] : regular_result_slots) { + result_columns[slot] = std::move(result.second[request_index].results[aggregation_index]); + } + + for (auto const& group_entry : min_by_groups) { + auto const& group = group_entry.second; + auto const indices = result.second[group.request_index].results.front()->view(); + cudf::column_view indices_no_nulls(cudf::data_type{cudf::type_to_id()}, + indices.size(), + static_cast(indices.data()), + nullptr, + 0); + + auto gathered = cudf::gather(cudf::table_view{{group.value_col, group.ordering_col}}, + indices_no_nulls, + indices.nullable() ? cudf::out_of_bounds_policy::NULLIFY + : cudf::out_of_bounds_policy::DONT_CHECK, + cudf::negative_index_policy::NOT_ALLOWED); + auto cols = gathered->release(); + result_columns[group.value_slot] = std::move(cols[0]); + result_columns[group.ordering_slot] = std::move(cols[1]); } + return convert_table_for_return(env, result.first, std::move(result_columns)); } JNI_CATCH(env, NULL); diff --git a/java/src/test/java/ai/rapids/cudf/MultiInputAggregationTest.java b/java/src/test/java/ai/rapids/cudf/MultiInputAggregationTest.java new file mode 100644 index 00000000000..fc85a2bcd34 --- /dev/null +++ b/java/src/test/java/ai/rapids/cudf/MultiInputAggregationTest.java @@ -0,0 +1,176 @@ +/* + * + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + * + */ + +package ai.rapids.cudf; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Test class for multi-input aggregations (min_by, max_by, etc.) + */ +public class MultiInputAggregationTest { + + /** + * Test min_by aggregation using role-tagged columns instead of a struct wrapper. + */ + @Test + public void testMinBy() { + try (ColumnVector keys = ColumnVector.fromInts(1, 1, 2, 2, 3); + ColumnVector ordering = ColumnVector.fromInts(5, 3, 8, 6, 9); + ColumnVector values = ColumnVector.fromStrings("a", "b", "c", "d", "e"); + Table input = new Table(keys, ordering, values)) { + long minById = MultiInputIds.next(); + + try (Table result = input.groupBy(0).aggregate( + GroupByAggregation.orderingForMinBy(minById).onColumn(1), + GroupByAggregation.valueForMinBy(minById).onColumn(2)); + Table sorted = result.orderBy(OrderByArg.asc(0)); + ColumnVector expectedKeys = ColumnVector.fromInts(1, 2, 3); + ColumnVector expectedOrdering = ColumnVector.fromInts(3, 6, 9); + ColumnVector expectedValues = ColumnVector.fromStrings("b", "d", "e")) { + assertEquals(3, sorted.getNumberOfColumns()); + assertEquals(3, sorted.getRowCount()); + AssertUtils.assertColumnsAreEqual(expectedKeys, sorted.getColumn(0)); + AssertUtils.assertColumnsAreEqual(expectedOrdering, sorted.getColumn(1)); + AssertUtils.assertColumnsAreEqual(expectedValues, sorted.getColumn(2)); + } + } + } + + /** + * Test min_by with unsorted keys and value role before ordering role. + */ + @Test + public void testMinByWithUnsortedKeysAndValueFirst() { + try (ColumnVector keys = ColumnVector.fromInts(2, 1, 2, 1, 3, 2); + ColumnVector ordering = ColumnVector.fromInts(8, 5, 6, 3, 9, 4); + ColumnVector values = ColumnVector.fromStrings("c", "a", "d", "b", "e", "f"); + Table input = new Table(keys, ordering, values)) { + long minById = MultiInputIds.next(); + + try (Table result = input.groupBy(0).aggregate( + GroupByAggregation.valueForMinBy(minById).onColumn(2), + GroupByAggregation.orderingForMinBy(minById).onColumn(1)); + Table sorted = result.orderBy(OrderByArg.asc(0)); + ColumnVector expectedKeys = ColumnVector.fromInts(1, 2, 3); + ColumnVector expectedValues = ColumnVector.fromStrings("b", "f", "e"); + ColumnVector expectedOrdering = ColumnVector.fromInts(3, 4, 9)) { + assertEquals(3, sorted.getNumberOfColumns()); + assertEquals(3, sorted.getRowCount()); + AssertUtils.assertColumnsAreEqual(expectedKeys, sorted.getColumn(0)); + AssertUtils.assertColumnsAreEqual(expectedValues, sorted.getColumn(1)); + AssertUtils.assertColumnsAreEqual(expectedOrdering, sorted.getColumn(2)); + } + } + } + + /** + * Test that multiple min_by operations with different correlation IDs work in one groupby. + */ + @Test + public void testMultipleMinByOperations() { + try (ColumnVector keys = ColumnVector.fromInts(1, 1, 2, 2); + ColumnVector ordering1 = ColumnVector.fromInts(5, 3, 8, 6); + ColumnVector values1 = ColumnVector.fromStrings("a", "b", "c", "d"); + ColumnVector ordering2 = ColumnVector.fromInts(10, 20, 15, 25); + ColumnVector values2 = ColumnVector.fromStrings("w", "x", "y", "z"); + Table input = new Table(keys, ordering1, values1, ordering2, values2)) { + long minById1 = MultiInputIds.next(); + long minById2 = MultiInputIds.next(); + + try (Table result = input.groupBy(0).aggregate( + GroupByAggregation.orderingForMinBy(minById1).onColumn(1), + GroupByAggregation.valueForMinBy(minById1).onColumn(2), + GroupByAggregation.orderingForMinBy(minById2).onColumn(3), + GroupByAggregation.valueForMinBy(minById2).onColumn(4)); + Table sorted = result.orderBy(OrderByArg.asc(0)); + ColumnVector expectedKeys = ColumnVector.fromInts(1, 2); + ColumnVector expectedOrdering1 = ColumnVector.fromInts(3, 6); + ColumnVector expectedValues1 = ColumnVector.fromStrings("b", "d"); + ColumnVector expectedOrdering2 = ColumnVector.fromInts(10, 15); + ColumnVector expectedValues2 = ColumnVector.fromStrings("w", "y")) { + assertEquals(5, sorted.getNumberOfColumns()); + assertEquals(2, sorted.getRowCount()); + AssertUtils.assertColumnsAreEqual(expectedKeys, sorted.getColumn(0)); + AssertUtils.assertColumnsAreEqual(expectedOrdering1, sorted.getColumn(1)); + AssertUtils.assertColumnsAreEqual(expectedValues1, sorted.getColumn(2)); + AssertUtils.assertColumnsAreEqual(expectedOrdering2, sorted.getColumn(3)); + AssertUtils.assertColumnsAreEqual(expectedValues2, sorted.getColumn(4)); + } + } + } + + /** + * Test result placement when multi-input aggregations are mixed with regular aggregations. + */ + @Test + public void testMinByWithRegularAggregations() { + try (ColumnVector keys = ColumnVector.fromInts(1, 1, 2, 2); + ColumnVector ordering = ColumnVector.fromInts(5, 3, 8, 6); + ColumnVector values = ColumnVector.fromStrings("a", "b", "c", "d"); + ColumnVector amounts = ColumnVector.fromInts(10, 20, 30, 40); + Table input = new Table(keys, ordering, values, amounts)) { + long minById = MultiInputIds.next(); + + try (Table result = input.groupBy(0).aggregate( + GroupByAggregation.sum().onColumn(3), + GroupByAggregation.orderingForMinBy(minById).onColumn(1), + GroupByAggregation.valueForMinBy(minById).onColumn(2), + GroupByAggregation.max().onColumn(3)); + Table sorted = result.orderBy(OrderByArg.asc(0)); + ColumnVector expectedKeys = ColumnVector.fromInts(1, 2); + ColumnVector expectedSums = ColumnVector.fromLongs(30, 70); + ColumnVector expectedOrdering = ColumnVector.fromInts(3, 6); + ColumnVector expectedValues = ColumnVector.fromStrings("b", "d"); + ColumnVector expectedMax = ColumnVector.fromInts(20, 40)) { + assertEquals(5, sorted.getNumberOfColumns()); + assertEquals(2, sorted.getRowCount()); + AssertUtils.assertColumnsAreEqual(expectedKeys, sorted.getColumn(0)); + AssertUtils.assertColumnsAreEqual(expectedSums, sorted.getColumn(1)); + AssertUtils.assertColumnsAreEqual(expectedOrdering, sorted.getColumn(2)); + AssertUtils.assertColumnsAreEqual(expectedValues, sorted.getColumn(3)); + AssertUtils.assertColumnsAreEqual(expectedMax, sorted.getColumn(4)); + } + } + } + + /** + * Test that correlation IDs are unique across multiple calls. + */ + @Test + public void testMultiInputIdsUniqueness() { + long id1 = MultiInputIds.next(); + long id2 = MultiInputIds.next(); + long id3 = MultiInputIds.next(); + + assertNotEquals(id1, id2); + assertNotEquals(id2, id3); + assertNotEquals(id1, id3); + assertTrue(id2 > id1); + assertTrue(id3 > id2); + } + + /** + * Test that role-tagged aggregations compare by both role and correlation ID. + */ + @Test + public void testMinByAggregationEquality() { + long id1 = MultiInputIds.next(); + long id2 = MultiInputIds.next(); + + assertEquals(GroupByAggregation.valueForMinBy(id1), GroupByAggregation.valueForMinBy(id1)); + assertEquals(GroupByAggregation.orderingForMinBy(id1), + GroupByAggregation.orderingForMinBy(id1)); + assertNotEquals(GroupByAggregation.valueForMinBy(id1), GroupByAggregation.valueForMinBy(id2)); + assertNotEquals(GroupByAggregation.orderingForMinBy(id1), + GroupByAggregation.orderingForMinBy(id2)); + assertNotEquals(GroupByAggregation.valueForMinBy(id1), + GroupByAggregation.orderingForMinBy(id1)); + } +}