Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 83 additions & 2 deletions java/src/main/java/ai/rapids/cudf/Aggregation.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*
*/
Expand Down Expand Up @@ -61,7 +61,9 @@ enum Kind {
MERGE_TDIGEST(33), // This can take a delta argument for accuracy level
HISTOGRAM(34),
MERGE_HISTOGRAM(35),
BITWISE_AGG(36);
BITWISE_AGG(36),
ORDERING_FOR_MIN_BY(37),
VALUE_FOR_MIN_BY(38);

final int nativeId;

Expand Down Expand Up @@ -1047,12 +1049,91 @@ static BitXorAggregation bitXor() {
return new BitXorAggregation();
}

static Aggregation orderingForMinBy(long multiInputId) {
return new OrderingForMinByAgg(multiInputId);
}

static Aggregation valueForMinBy(long multiInputId) {
return new ValueForMinByAgg(multiInputId);
}

/**
* Multi-input aggregation for ordering column in min_by operation.
*/
private static final class OrderingForMinByAgg extends Aggregation {
private final long multiInputId;

OrderingForMinByAgg(long multiInputId) {
super(Kind.ORDERING_FOR_MIN_BY);
this.multiInputId = multiInputId;
}

@Override
long createNativeInstance() {
return Aggregation.createMultiInputAgg(kind.nativeId, multiInputId);
}

@Override
public int hashCode() {
return 31 * kind.hashCode() + Long.hashCode(multiInputId);
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other instanceof OrderingForMinByAgg) {
OrderingForMinByAgg o = (OrderingForMinByAgg) other;
return o.multiInputId == this.multiInputId;
}
return false;
}
}

/**
* Multi-input aggregation for value column in min_by operation.
*/
private static final class ValueForMinByAgg extends Aggregation {
private final long multiInputId;

ValueForMinByAgg(long multiInputId) {
super(Kind.VALUE_FOR_MIN_BY);
this.multiInputId = multiInputId;
}

@Override
long createNativeInstance() {
return Aggregation.createMultiInputAgg(kind.nativeId, multiInputId);
}

@Override
public int hashCode() {
return 31 * kind.hashCode() + Long.hashCode(multiInputId);
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
} else if (other instanceof ValueForMinByAgg) {
ValueForMinByAgg o = (ValueForMinByAgg) other;
return o.multiInputId == this.multiInputId;
}
return false;
}
}

/**
* Create one of the aggregations that only needs a kind, no other parameters. This does not
* work for all types and for code safety reasons each kind is added separately.
*/
private static native long createNoParamAgg(int kind);

/**
* Create a multi-input aggregation with a correlation ID.
*/
private static native long createMultiInputAgg(int kind, long multiInputId);

/**
* Create an nth aggregation.
*/
Expand Down
22 changes: 21 additions & 1 deletion java/src/main/java/ai/rapids/cudf/GroupByAggregation.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/*
*
* SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*
*/
Expand Down Expand Up @@ -148,6 +148,26 @@ public static GroupByAggregation standardDeviation() {
return new GroupByAggregation(Aggregation.standardDeviation());
}

/**
* Create an ordering column aggregation for min_by operation.
* This must be paired with a valueForMinBy aggregation using the same multiInputId.
* @param multiInputId Correlation ID to pair with the corresponding value aggregation
* @return GroupByAggregation for the ordering column in min_by
*/
public static GroupByAggregation orderingForMinBy(long multiInputId) {
return new GroupByAggregation(Aggregation.orderingForMinBy(multiInputId));
}

/**
* Create a value column aggregation for min_by operation.
* This must be paired with an orderingForMinBy aggregation using the same multiInputId.
* @param multiInputId Correlation ID to pair with the corresponding ordering aggregation
* @return GroupByAggregation for the value column in min_by
*/
public static GroupByAggregation valueForMinBy(long multiInputId) {
return new GroupByAggregation(Aggregation.valueForMinBy(multiInputId));
}

/**
* Standard deviation aggregation.
* @param ddof delta degrees of freedom. The divisor used in calculation of std is
Expand Down
30 changes: 30 additions & 0 deletions java/src/main/java/ai/rapids/cudf/MultiInputIds.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*
*/

package ai.rapids.cudf;

import java.util.concurrent.atomic.AtomicLong;

/**
* Utility class for generating unique correlation IDs for multi-input aggregations.
* These IDs are used to correlate multiple role-tagged aggregation instances that
* belong to the same logical multi-input operation (e.g., min_by, max_by).
*/
public final class MultiInputIds {
private static final AtomicLong COUNTER = new AtomicLong();

/**
* Generate the next unique multi-input correlation ID.
* @return a unique long value to correlate multi-input aggregation roles
*/
public static long next() {
return COUNTER.incrementAndGet();
}

private MultiInputIds() {
}
}
20 changes: 19 additions & 1 deletion java/src/main/native/src/AggregationJni.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2020-2025, NVIDIA CORPORATION.
* SPDX-FileCopyrightText: Copyright (c) 2020-2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include "MultiInputAggregation.hpp"
#include "cudf_jni_apis.hpp"

#include <cudf/aggregation.hpp>
Expand Down Expand Up @@ -301,6 +302,23 @@ JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createMergeSetsAgg(JNIEn
JNI_CATCH(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createMultiInputAgg(JNIEnv* env,
jclass class_object,
jint kind,
jlong multi_input_id)
{
JNI_TRY
{
cudf::jni::auto_set_device(env);
JNI_ARG_CHECK(
env, cudf::jni::is_multi_input_role(kind), "invalid multi-input aggregation kind", 0);
auto agg = std::make_unique<cudf::jni::multi_input_aggregation>(
static_cast<cudf::jni::multi_input_role>(kind), multi_input_id);
return reinterpret_cast<jlong>(agg.release());
}
JNI_CATCH(env, 0);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Aggregation_createHostUDFAgg(JNIEnv* env,
jclass class_object,
jlong udf_native_handle)
Expand Down
69 changes: 69 additions & 0 deletions java/src/main/native/src/MultiInputAggregation.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#pragma once

#include <cudf/aggregation.hpp>

#include <cstdint>
#include <functional>
#include <memory>

namespace cudf::jni {

/**
* @brief JNI-private roles used to pair columns for multi-input aggregations.
*
* These values must match Aggregation.Kind native ids in Java.
*/
enum class multi_input_role : int32_t {
ORDERING_FOR_MIN_BY = 37,
VALUE_FOR_MIN_BY = 38,
};

inline bool is_multi_input_role(int32_t kind)
{
return kind == static_cast<int32_t>(multi_input_role::ORDERING_FOR_MIN_BY) ||
kind == static_cast<int32_t>(multi_input_role::VALUE_FOR_MIN_BY);
}

/**
* @brief JNI holder for role-tagged multi-input groupby aggregations.
*
* This aggregation is never dispatched to libcudf directly. TableJni.cpp consumes
* it, pairs it with siblings sharing the same id, and creates the actual libcudf
* aggregation requests.
*/
class multi_input_aggregation final : public cudf::groupby_aggregation {
public:
multi_input_role const role;
int64_t const multi_input_id;

multi_input_aggregation(multi_input_role role, int64_t multi_input_id)
: cudf::aggregation{cudf::aggregation::Kind::ARGMIN}, role{role}, multi_input_id{multi_input_id}
{
}

[[nodiscard]] std::unique_ptr<cudf::aggregation> clone() const override
{
return std::make_unique<multi_input_aggregation>(*this);
}

[[nodiscard]] bool is_equal(cudf::aggregation const& other) const override
{
auto const* other_multi_input = dynamic_cast<multi_input_aggregation const*>(&other);
return other_multi_input != nullptr && role == other_multi_input->role &&
multi_input_id == other_multi_input->multi_input_id;
}

[[nodiscard]] size_t do_hash() const override
{
auto const role_hash = std::hash<int32_t>{}(static_cast<int32_t>(role));
auto const id_hash = std::hash<int64_t>{}(multi_input_id);
return role_hash ^ (id_hash + 0x9e3779b97f4a7c15ULL + (role_hash << 6) + (role_hash >> 2));
}
};

} // namespace cudf::jni
Loading
Loading