Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions java/src/main/java/ai/rapids/cudf/DistinctHashJoin.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

package ai.rapids.cudf;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* This class represents a reusable hash table built from distinct join keys from the right-side
* table for a join operation. The resulting handle can be reused across a series of left probe
* tables when the right-side join keys are guaranteed to be distinct.
*/
public class DistinctHashJoin implements AutoCloseable {
  static {
    NativeDepsLoader.loadNativeDeps();
  }

  private static final Logger log = LoggerFactory.getLogger(DistinctHashJoin.class);

  /**
   * Cleaner that owns the native distinct hash join handle and the internal copy of the
   * build-side key table, releasing both exactly once (either on close or via leak cleanup).
   */
  private static class DistinctHashJoinCleaner extends MemoryCleaner.Cleaner {
    private Table buildKeys;
    private long nativeHandle;

    DistinctHashJoinCleaner(Table buildKeys, long nativeHandle) {
      this.buildKeys = buildKeys;
      this.nativeHandle = nativeHandle;
      addRef();
    }

    @Override
    protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) {
      long origAddress = nativeHandle;
      boolean neededCleanup = nativeHandle != 0;
      if (neededCleanup) {
        try {
          destroy(nativeHandle);
        } finally {
          nativeHandle = 0;
          // Close the build keys in the finally so the column resources are released
          // even if destroying the native handle throws.
          if (buildKeys != null) {
            buildKeys.close();
            buildKeys = null;
          }
        }
        if (logErrorIfNotClean) {
          // Parameterized logging avoids concatenation when the level is disabled.
          log.error("A DISTINCT HASH TABLE WAS LEAKED (ID: {} {})",
              id, Long.toHexString(origAddress));
        }
      }
      return neededCleanup;
    }

    @Override
    public boolean isClean() {
      return nativeHandle == 0;
    }
  }

  private final DistinctHashJoinCleaner cleaner;
  private final boolean compareNulls;
  private boolean isClosed = false;

  /**
   * Construct a reusable distinct hash table from the join key columns from the right-side table.
   * The build key rows must be distinct.
   *
   * @param buildKeys table view containing the join keys for the right-side join table
   * @param compareNulls true if null key values should match otherwise false
   */
  public DistinctHashJoin(Table buildKeys, boolean compareNulls) {
    this.compareNulls = compareNulls;
    // Take an internal reference to the key columns so the caller's table can be
    // closed independently of this hash table's lifetime.
    Table buildTable = new Table(buildKeys.getColumns());
    try {
      long handle = create(buildTable.getNativeView(), compareNulls);
      this.cleaner = new DistinctHashJoinCleaner(buildTable, handle);
      MemoryCleaner.register(this, cleaner);
    } catch (Throwable t) {
      // Construction failed before the cleaner took ownership; release our copy here.
      try {
        buildTable.close();
      } catch (Throwable t2) {
        t.addSuppressed(t2);
      }
      throw t;
    }
  }

  @Override
  public synchronized void close() {
    cleaner.delRef();
    if (isClosed) {
      cleaner.logRefCountDebug("double free " + this);
      throw new IllegalStateException("Close called too many times " + this);
    }
    cleaner.clean(false);
    isClosed = true;
  }

  /** Get the number of join key columns for the table used to generate the hash table. */
  public long getNumberOfColumns() {
    return cleaner.buildKeys.getNumberOfColumns();
  }

  /** Returns true if the hash table was built to match on nulls otherwise false. */
  public boolean getCompareNulls() {
    return compareNulls;
  }

  /** Address of the native distinct hash join object, or 0 after cleanup. */
  long getNativeView() {
    return cleaner.nativeHandle;
  }

  private static native long create(long tableView, boolean nullEqual);
  private static native void destroy(long handle);
}
4 changes: 4 additions & 0 deletions java/src/main/java/ai/rapids/cudf/MemoryCleaner.java
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ static void register(HashJoin hashJoin, Cleaner cleaner) {
all.put(cleaner.id, new CleanerWeakReference(hashJoin, cleaner, collected, true));
}

static void register(DistinctHashJoin hashJoin, Cleaner cleaner) {
all.put(cleaner.id, new CleanerWeakReference(hashJoin, cleaner, collected, true));
}

static void register(KeyRemapping keyRemapping, Cleaner cleaner) {
all.put(cleaner.id, new CleanerWeakReference(keyRemapping, cleaner, collected, true));
}
Expand Down
54 changes: 54 additions & 0 deletions java/src/main/java/ai/rapids/cudf/Table.java
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,9 @@ private static native long[] leftJoinGatherMaps(long leftKeys, long rightKeys,
private static native long[] leftDistinctJoinGatherMap(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

private static native long[] leftDistinctHashJoinGatherMap(long leftTable,
long rightDistinctHashJoin) throws CudfException;

private static native long leftJoinRowCount(long leftTable, long rightHashJoin) throws CudfException;

private static native long[] leftHashJoinGatherMaps(long leftTable, long rightHashJoin) throws CudfException;
Expand All @@ -661,6 +664,9 @@ private static native long[] innerJoinGatherMaps(long leftKeys, long rightKeys,
private static native long[] innerDistinctJoinGatherMaps(long leftKeys, long rightKeys,
boolean compareNullsEqual) throws CudfException;

private static native long[] innerDistinctHashJoinGatherMaps(long table,
long distinctHashJoin) throws CudfException;

private static native long innerJoinRowCount(long table, long hashJoin) throws CudfException;

private static native long[] innerHashJoinGatherMaps(long table, long hashJoin) throws CudfException;
Expand Down Expand Up @@ -2958,6 +2964,31 @@ public GatherMap leftDistinctJoinGatherMap(Table rightKeys, boolean compareNulls
return buildSingleJoinGatherMap(gatherMapData);
}

/**
* Computes a gather map that can be used to manifest the result of a left equi-join between
* two tables where the right table is guaranteed to not contain any duplicated join keys.
* The left table can be used as-is to produce the left table columns resulting from the join,
* i.e.: left table ordering is preserved in the join result, so no gather map is required for
* the left table. The resulting gather map can be applied to the right table to produce the
* right table columns resulting from the join. It is assumed this table instance holds the
* key columns from the left table, and the {@link DistinctHashJoin} argument has been
* constructed from the key columns from the right table.
*
* It is the responsibility of the caller to close the resulting gather map instance.
*
* @param rightHash hash table built from distinct join key columns from the right table
* @return right table gather map
*/
public GatherMap leftDistinctJoinGatherMap(DistinctHashJoin rightHash) {
if (getNumberOfColumns() != rightHash.getNumberOfColumns()) {
throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() +
"rightKeys: " + rightHash.getNumberOfColumns());
}
long[] gatherMapData =
leftDistinctHashJoinGatherMap(getNativeView(), rightHash.getNativeView());
return buildSingleJoinGatherMap(gatherMapData);
}

/**
* Computes the number of rows resulting from a left equi-join between two tables.
* It is assumed this table instance holds the key columns from the left table, and the
Expand Down Expand Up @@ -3226,6 +3257,29 @@ public GatherMap[] innerDistinctJoinGatherMaps(Table rightKeys, boolean compareN
return buildJoinGatherMaps(gatherMapData);
}

/**
* Computes the gather maps that can be used to manifest the result of an inner equi-join between
* two tables where the right table is guaranteed to not contain any duplicated join keys. It is
* assumed this table instance holds the key columns from the left table, and the
* {@link DistinctHashJoin} argument has been constructed from the key columns from the right
* table. Two {@link GatherMap} instances will be returned that can be used to gather the left
* and right tables, respectively, to produce the result of the inner join.
*
* It is the responsibility of the caller to close the resulting gather map instances.
*
* @param rightHash hash table built from distinct join key columns from the right table
* @return left and right table gather maps
*/
public GatherMap[] innerJoinGatherMaps(DistinctHashJoin rightHash) {
if (getNumberOfColumns() != rightHash.getNumberOfColumns()) {
throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() +
"rightKeys: " + rightHash.getNumberOfColumns());
}
long[] gatherMapData =
innerDistinctHashJoinGatherMaps(getNativeView(), rightHash.getNativeView());
return buildJoinGatherMaps(gatherMapData);
}

/**
* Computes the number of rows resulting from an inner equi-join between two tables.
* @param otherHash hash table built from join key columns from the other table
Expand Down
1 change: 1 addition & 0 deletions java/src/main/native/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ add_library(
src/ContiguousTableJni.cpp
src/DataSourceHelperJni.cpp
src/DeletionVectorJni.cpp
src/DistinctHashJoinJni.cpp
src/HashJoinJni.cpp
src/HostMemoryBufferNativeUtilsJni.cpp
src/KeyRemappingJni.cpp
Expand Down
46 changes: 46 additions & 0 deletions java/src/main/native/src/DistinctHashJoinJni.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/

#include "cudf_jni_apis.hpp"

#include <cudf/join/distinct_hash_join.hpp>
#include <cudf/table/table_view.hpp>

extern "C" {

// Builds a reusable cudf::distinct_hash_join from the build-side key table and
// returns an owning pointer to it as a jlong (0 on error or null input).
JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_DistinctHashJoin_create(
  JNIEnv* env, jclass, jlong j_build_keys, jboolean j_nulls_equal)
{
  JNI_NULL_CHECK(env, j_build_keys, "build keys table is null", 0);

  JNI_TRY
  {
    cudf::jni::auto_set_device(env);

    // j_build_keys is the address of a cudf::table_view created on the Java side.
    auto const build_keys = reinterpret_cast<cudf::table_view const*>(j_build_keys);
    auto const nulls_equal =
      j_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;

    // Ownership transfers to the Java caller, which must later pass the handle
    // back to destroy() to reclaim it.
    auto handle = std::make_unique<cudf::distinct_hash_join>(*build_keys, nulls_equal);
    return cudf::jni::release_as_jlong(handle);
  }
  JNI_CATCH(env, 0);
}

// Destroys the native distinct hash join object previously returned by create().
JNIEXPORT void JNICALL Java_ai_rapids_cudf_DistinctHashJoin_destroy(
  JNIEnv* env, jclass, jlong j_handle)
{
  JNI_NULL_CHECK(env, j_handle, "distinct hash join handle is null", );

  JNI_TRY
  {
    cudf::jni::auto_set_device(env);
    // Reclaim ownership that was transferred to the Java side by create().
    delete reinterpret_cast<cudf::distinct_hash_join*>(j_handle);
  }
  JNI_CATCH(env, );
}

} // extern "C"
30 changes: 30 additions & 0 deletions java/src/main/native/src/TableJni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3012,6 +3012,21 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftDistinctJoinGatherMap
});
}

// Probes the prebuilt distinct hash join with the left key table and returns the
// right-side gather map packaged for Java.
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftDistinctHashJoinGatherMap(
  JNIEnv* env, jclass, jlong j_left_table, jlong j_right_hash_join)
{
  JNI_NULL_CHECK(env, j_left_table, "left table is null", NULL);
  JNI_NULL_CHECK(env, j_right_hash_join, "right distinct hash join is null", NULL);
  JNI_TRY
  {
    cudf::jni::auto_set_device(env);
    auto const probe_keys = reinterpret_cast<cudf::table_view const*>(j_left_table);
    auto const join       = reinterpret_cast<cudf::distinct_hash_join const*>(j_right_hash_join);
    return cudf::jni::gather_map_to_java(env, join->left_join(*probe_keys));
  }
  JNI_CATCH(env, NULL);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_leftJoinRowCount(JNIEnv* env,
jclass,
jlong j_left_table,
Expand Down Expand Up @@ -3226,6 +3241,21 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerDistinctJoinGatherMa
});
}

// Probes the prebuilt distinct hash join with the left key table and returns the
// pair of inner-join gather maps packaged for Java.
JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerDistinctHashJoinGatherMaps(
  JNIEnv* env, jclass, jlong j_left_table, jlong j_right_hash_join)
{
  JNI_NULL_CHECK(env, j_left_table, "left table is null", NULL);
  JNI_NULL_CHECK(env, j_right_hash_join, "right distinct hash join is null", NULL);
  JNI_TRY
  {
    cudf::jni::auto_set_device(env);
    auto const probe_keys = reinterpret_cast<cudf::table_view const*>(j_left_table);
    auto const join       = reinterpret_cast<cudf::distinct_hash_join const*>(j_right_hash_join);
    return cudf::jni::gather_maps_to_java(env, join->inner_join(*probe_keys));
  }
  JNI_CATCH(env, NULL);
}

JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_innerJoinRowCount(JNIEnv* env,
jclass,
jlong j_left_table,
Expand Down
Loading
Loading