diff --git a/java/src/main/java/ai/rapids/cudf/DistinctHashJoin.java b/java/src/main/java/ai/rapids/cudf/DistinctHashJoin.java
new file mode 100644
index 00000000000..05812404237
--- /dev/null
+++ b/java/src/main/java/ai/rapids/cudf/DistinctHashJoin.java
@@ -0,0 +1,134 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package ai.rapids.cudf;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * This class represents a reusable hash table built from distinct join keys from the right-side
+ * table for a join operation. The resulting handle can be reused across a series of left probe
+ * tables when the right-side join keys are guaranteed to be distinct.
+ */
+public class DistinctHashJoin implements AutoCloseable {
+  static {
+    NativeDepsLoader.loadNativeDeps();
+  }
+
+  private static final Logger log = LoggerFactory.getLogger(DistinctHashJoin.class);
+
+  /** Owns the native hash table handle and the build key table it was created from. */
+  private static class DistinctHashJoinCleaner extends MemoryCleaner.Cleaner {
+    private Table buildKeys;
+    private long nativeHandle;
+
+    DistinctHashJoinCleaner(Table buildKeys, long nativeHandle) {
+      this.buildKeys = buildKeys;
+      this.nativeHandle = nativeHandle;
+      addRef();
+    }
+
+    @Override
+    protected synchronized boolean cleanImpl(boolean logErrorIfNotClean) {
+      long origAddress = nativeHandle;
+      boolean neededCleanup = nativeHandle != 0;
+      if (neededCleanup) {
+        try {
+          destroy(nativeHandle);
+        } finally {
+          // Release the build keys even if destroy() throws so the table is never stranded.
+          nativeHandle = 0;
+          try {
+            buildKeys.close();
+          } finally {
+            buildKeys = null;
+          }
+        }
+        if (logErrorIfNotClean) {
+          log.error("A DISTINCT HASH TABLE WAS LEAKED (ID: " + id + " " +
+              Long.toHexString(origAddress) + ")");
+        }
+      }
+      return neededCleanup;
+    }
+
+    @Override
+    public boolean isClean() {
+      return nativeHandle == 0;
+    }
+  }
+
+  private final DistinctHashJoinCleaner cleaner;
+  private final boolean compareNulls;
+  // Captured at construction because close() drops the cleaner's reference to the build keys.
+  private final long numberOfColumns;
+  private boolean isClosed = false;
+
+  /**
+   * Construct a reusable distinct hash table from the join key columns from the right-side table.
+   * The build key rows must be distinct.
+   *
+   * @param buildKeys table view containing the join keys for the right-side join table
+   * @param compareNulls true if null key values should match otherwise false
+   */
+  public DistinctHashJoin(Table buildKeys, boolean compareNulls) {
+    this.compareNulls = compareNulls;
+    Table buildTable = new Table(buildKeys.getColumns());
+    long handle = 0;
+    try {
+      this.numberOfColumns = buildTable.getNumberOfColumns();
+      handle = create(buildTable.getNativeView(), compareNulls);
+      this.cleaner = new DistinctHashJoinCleaner(buildTable, handle);
+      MemoryCleaner.register(this, cleaner);
+    } catch (Throwable t) {
+      // Nothing else will release the native hash table or the table copy on this path.
+      if (handle != 0) {
+        try {
+          destroy(handle);
+        } catch (Throwable t2) {
+          t.addSuppressed(t2);
+        }
+      }
+      try {
+        buildTable.close();
+      } catch (Throwable t2) {
+        t.addSuppressed(t2);
+      }
+      throw t;
+    }
+  }
+
+  @Override
+  public synchronized void close() {
+    cleaner.delRef();
+    if (isClosed) {
+      cleaner.logRefCountDebug("double free " + this);
+      throw new IllegalStateException("Close called too many times " + this);
+    }
+    cleaner.clean(false);
+    isClosed = true;
+  }
+
+  /**
+   * Get the number of join key columns for the table used to generate the hash table.
+   * The count is recorded at construction, so it can still be read after {@link #close()}.
+   */
+  public long getNumberOfColumns() {
+    return numberOfColumns;
+  }
+
+  /** Returns true if the hash table was built to match on nulls otherwise false. */
+  public boolean getCompareNulls() {
+    return compareNulls;
+  }
+
+  long getNativeView() {
+    return cleaner.nativeHandle;
+  }
+
+  private static native long create(long tableView, boolean nullEqual);
+  private static native void destroy(long handle);
+}
diff --git a/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java b/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java
index a38336f893c..81046a12516 100644
--- a/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java
+++ b/java/src/main/java/ai/rapids/cudf/MemoryCleaner.java
@@ -374,6 +374,10 @@ static void register(HashJoin hashJoin, Cleaner cleaner) {
     all.put(cleaner.id, new CleanerWeakReference(hashJoin, cleaner, collected, true));
   }
 
+  static void register(DistinctHashJoin hashJoin, Cleaner cleaner) {
+    all.put(cleaner.id, new CleanerWeakReference(hashJoin, cleaner, collected, true));
+  }
+
   static void register(KeyRemapping keyRemapping, Cleaner cleaner) {
     all.put(cleaner.id, new CleanerWeakReference(keyRemapping, cleaner, collected, true));
   }
diff --git a/java/src/main/java/ai/rapids/cudf/Table.java b/java/src/main/java/ai/rapids/cudf/Table.java
index a6c4264202d..2d1212928ab 100644
--- a/java/src/main/java/ai/rapids/cudf/Table.java
+++ b/java/src/main/java/ai/rapids/cudf/Table.java
@@ -648,6 +648,9 @@ private static native long[] leftJoinGatherMaps(long leftKeys, long rightKeys,
   private static native long[] leftDistinctJoinGatherMap(long leftKeys, long rightKeys,
                                                          boolean compareNullsEqual) throws CudfException;
 
+  private static native long[] leftDistinctHashJoinGatherMap(long leftTable,
+                                                             long rightDistinctHashJoin) throws CudfException;
+
   private static native long leftJoinRowCount(long leftTable, long rightHashJoin) throws CudfException;
 
   private static native long[] leftHashJoinGatherMaps(long leftTable, long rightHashJoin) throws CudfException;
@@ -661,6 +664,9 @@ private static native long[] innerJoinGatherMaps(long leftKeys, long rightKeys,
   private static native long[] innerDistinctJoinGatherMaps(long leftKeys, long rightKeys,
                                                            boolean compareNullsEqual) throws CudfException;
 
+  private static native long[] innerDistinctHashJoinGatherMaps(long table,
+                                                               long distinctHashJoin) throws CudfException;
+
   private static native long innerJoinRowCount(long table, long hashJoin) throws CudfException;
 
   private static native long[] innerHashJoinGatherMaps(long table, long hashJoin) throws CudfException;
@@ -2958,6 +2964,32 @@ public GatherMap leftDistinctJoinGatherMap(Table rightKeys, boolean compareNulls
     return buildSingleJoinGatherMap(gatherMapData);
   }
 
+  /**
+   * Computes a gather map that can be used to manifest the result of a left equi-join between
+   * two tables where the right table is guaranteed to not contain any duplicated join keys.
+   * The left table can be used as-is to produce the left table columns resulting from the join,
+   * i.e.: left table ordering is preserved in the join result, so no gather map is required for
+   * the left table. The resulting gather map can be applied to the right table to produce the
+   * right table columns resulting from the join. It is assumed this table instance holds the
+   * key columns from the left table, and the {@link DistinctHashJoin} argument has been
+   * constructed from the key columns from the right table.
+   *
+   * It is the responsibility of the caller to close the resulting gather map instance.
+   *
+   * @param rightHash hash table built from distinct join key columns from the right table
+   * @return right table gather map
+   * @throws IllegalArgumentException if the key column counts of the two sides differ
+   */
+  public GatherMap leftDistinctJoinGatherMap(DistinctHashJoin rightHash) {
+    if (getNumberOfColumns() != rightHash.getNumberOfColumns()) {
+      throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() +
+          " rightKeys: " + rightHash.getNumberOfColumns());
+    }
+    long[] gatherMapData =
+        leftDistinctHashJoinGatherMap(getNativeView(), rightHash.getNativeView());
+    return buildSingleJoinGatherMap(gatherMapData);
+  }
+
   /**
    * Computes the number of rows resulting from a left equi-join between two tables.
    * It is assumed this table instance holds the key columns from the left table, and the
@@ -3226,6 +3258,30 @@ public GatherMap[] innerDistinctJoinGatherMaps(Table rightKeys, boolean compareN
     return buildJoinGatherMaps(gatherMapData);
   }
 
+  /**
+   * Computes the gather maps that can be used to manifest the result of an inner equi-join between
+   * two tables where the right table is guaranteed to not contain any duplicated join keys. It is
+   * assumed this table instance holds the key columns from the left table, and the
+   * {@link DistinctHashJoin} argument has been constructed from the key columns from the right
+   * table. Two {@link GatherMap} instances will be returned that can be used to gather the left
+   * and right tables, respectively, to produce the result of the inner join.
+   *
+   * It is the responsibility of the caller to close the resulting gather map instances.
+   *
+   * @param rightHash hash table built from distinct join key columns from the right table
+   * @return left and right table gather maps
+   * @throws IllegalArgumentException if the key column counts of the two sides differ
+   */
+  public GatherMap[] innerJoinGatherMaps(DistinctHashJoin rightHash) {
+    if (getNumberOfColumns() != rightHash.getNumberOfColumns()) {
+      throw new IllegalArgumentException("Column count mismatch, this: " + getNumberOfColumns() +
+          " rightKeys: " + rightHash.getNumberOfColumns());
+    }
+    long[] gatherMapData =
+        innerDistinctHashJoinGatherMaps(getNativeView(), rightHash.getNativeView());
+    return buildJoinGatherMaps(gatherMapData);
+  }
+
   /**
    * Computes the number of rows resulting from an inner equi-join between two tables.
    * @param otherHash hash table built from join key columns from the other table
diff --git a/java/src/main/native/CMakeLists.txt b/java/src/main/native/CMakeLists.txt
index 1e7df3802b9..1d5a3130038 100644
--- a/java/src/main/native/CMakeLists.txt
+++ b/java/src/main/native/CMakeLists.txt
@@ -164,6 +164,7 @@ add_library(
   src/ContiguousTableJni.cpp
   src/DataSourceHelperJni.cpp
   src/DeletionVectorJni.cpp
+  src/DistinctHashJoinJni.cpp
   src/HashJoinJni.cpp
   src/HostMemoryBufferNativeUtilsJni.cpp
   src/KeyRemappingJni.cpp
diff --git a/java/src/main/native/src/DistinctHashJoinJni.cpp b/java/src/main/native/src/DistinctHashJoinJni.cpp
new file mode 100644
index 00000000000..8fff57b1c6f
--- /dev/null
+++ b/java/src/main/native/src/DistinctHashJoinJni.cpp
@@ -0,0 +1,49 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#include "cudf_jni_apis.hpp"
+
+// NOTE(review): the two include targets below and the template arguments in this file were
+// missing from the patch text under review and have been reconstructed from how the handle is
+// created and used - confirm them against the tree before merging.
+#include <cudf/join/distinct_hash_join.hpp>
+#include <cudf/table/table_view.hpp>
+
+extern "C" {
+
+JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_DistinctHashJoin_create(
+  JNIEnv* env, jclass, jlong j_build_keys, jboolean j_nulls_equal)
+{
+  JNI_NULL_CHECK(env, j_build_keys, "build keys table is null", 0);
+
+  JNI_TRY
+  {
+    cudf::jni::auto_set_device(env);
+
+    auto const build_keys = reinterpret_cast<cudf::table_view const*>(j_build_keys);
+    auto const nulls_equal =
+      j_nulls_equal ? cudf::null_equality::EQUAL : cudf::null_equality::UNEQUAL;
+
+    auto handle = std::make_unique<cudf::distinct_hash_join>(*build_keys, nulls_equal);
+    return cudf::jni::release_as_jlong(handle);
+  }
+  JNI_CATCH(env, 0);
+}
+
+JNIEXPORT void JNICALL Java_ai_rapids_cudf_DistinctHashJoin_destroy(
+  JNIEnv* env, jclass, jlong j_handle)
+{
+  JNI_NULL_CHECK(env, j_handle, "distinct hash join handle is null", );
+
+  JNI_TRY
+  {
+    cudf::jni::auto_set_device(env);
+    auto handle = reinterpret_cast<cudf::distinct_hash_join*>(j_handle);
+    delete handle;
+  }
+  JNI_CATCH(env, );
+}
+
+}  // extern "C"
diff --git a/java/src/main/native/src/TableJni.cpp b/java/src/main/native/src/TableJni.cpp
index 65484bbb508..102bdbc3573 100644
--- a/java/src/main/native/src/TableJni.cpp
+++ b/java/src/main/native/src/TableJni.cpp
@@ -3012,6 +3012,21 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftDistinctJoinGatherMap
     });
 }
 
+JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_leftDistinctHashJoinGatherMap(
+  JNIEnv* env, jclass, jlong j_left_table, jlong j_right_hash_join)
+{
+  JNI_NULL_CHECK(env, j_left_table, "left table is null", NULL);
+  JNI_NULL_CHECK(env, j_right_hash_join, "right distinct hash join is null", NULL);
+  JNI_TRY
+  {
+    cudf::jni::auto_set_device(env);
+    auto left_table = reinterpret_cast<cudf::table_view const*>(j_left_table);
+    auto hash_join  = reinterpret_cast<cudf::distinct_hash_join const*>(j_right_hash_join);
+    return cudf::jni::gather_map_to_java(env, hash_join->left_join(*left_table));
+  }
+  JNI_CATCH(env, NULL);
+}
+
 JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_leftJoinRowCount(JNIEnv* env,
                                                                    jclass,
                                                                    jlong j_left_table,
@@ -3226,6 +3241,21 @@ JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerDistinctJoinGatherMa
     });
 }
 
+JNIEXPORT jlongArray JNICALL Java_ai_rapids_cudf_Table_innerDistinctHashJoinGatherMaps(
+  JNIEnv* env, jclass, jlong j_left_table, jlong j_right_hash_join)
+{
+  JNI_NULL_CHECK(env, j_left_table, "left table is null", NULL);
+  JNI_NULL_CHECK(env, j_right_hash_join, "right distinct hash join is null", NULL);
+  JNI_TRY
+  {
+    cudf::jni::auto_set_device(env);
+    auto left_table = reinterpret_cast<cudf::table_view const*>(j_left_table);
+    auto hash_join  = reinterpret_cast<cudf::distinct_hash_join const*>(j_right_hash_join);
+    return cudf::jni::gather_maps_to_java(env, hash_join->inner_join(*left_table));
+  }
+  JNI_CATCH(env, NULL);
+}
+
 JNIEXPORT jlong JNICALL Java_ai_rapids_cudf_Table_innerJoinRowCount(JNIEnv* env,
                                                                     jclass,
                                                                     jlong j_left_table,
diff --git a/java/src/test/java/ai/rapids/cudf/DistinctHashJoinTest.java b/java/src/test/java/ai/rapids/cudf/DistinctHashJoinTest.java
new file mode 100644
index 00000000000..4f5c9b68aff
--- /dev/null
+++ b/java/src/test/java/ai/rapids/cudf/DistinctHashJoinTest.java
@@ -0,0 +1,118 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package ai.rapids.cudf;
+
+import org.junit.jupiter.api.Test;
+
+import java.util.AbstractMap;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class DistinctHashJoinTest {
+  private static final int GATHER_MAP_SENTINEL = Integer.MIN_VALUE;
+
+  @SafeVarargs
+  private static Set<Map.Entry<Integer, Integer>> pairSet(Map.Entry<Integer, Integer>... pairs) {
+    return new HashSet<>(Arrays.asList(pairs));
+  }
+
+  private static Map.Entry<Integer, Integer> pair(Integer left, Integer right) {
+    return new AbstractMap.SimpleEntry<>(left, right);
+  }
+
+  private static Set<Map.Entry<Integer, Integer>> gatherMapPairToSet(
+      HostColumnVector leftMap, HostColumnVector rightMap) {
+    Set<Map.Entry<Integer, Integer>> result = new HashSet<>();
+    for (int i = 0; i < leftMap.getRowCount(); i++) {
+      int leftVal = leftMap.getInt(i);
+      int rightVal = rightMap.getInt(i);
+      result.add(pair(leftVal == GATHER_MAP_SENTINEL ? null : leftVal,
+          rightVal == GATHER_MAP_SENTINEL ? null : rightVal));
+    }
+    return result;
+  }
+
+  private static List<Integer> gatherMapToList(HostColumnVector gatherMap) {
+    Integer[] result = new Integer[(int) gatherMap.getRowCount()];
+    for (int i = 0; i < gatherMap.getRowCount(); i++) {
+      int index = gatherMap.getInt(i);
+      result[i] = index == GATHER_MAP_SENTINEL ? null : index;
+    }
+    return Arrays.asList(result);
+  }
+
+  @Test
+  void testInnerJoinGatherMapsCanBeReusedAcrossProbeTables() {
+    try (ColumnVector buildKeys = ColumnVector.fromInts(0, 1, 2, 3);
+         Table buildTable = new Table(buildKeys);
+         DistinctHashJoin hashJoin = new DistinctHashJoin(buildTable, true);
+         ColumnVector probe1Keys = ColumnVector.fromInts(1, 2, 4);
+         Table probe1Table = new Table(probe1Keys);
+         ColumnVector probe2Keys = ColumnVector.fromInts(3, 0, 5);
+         Table probe2Table = new Table(probe2Keys)) {
+      assertGatherPairs(probe1Table.innerJoinGatherMaps(hashJoin),
+          pairSet(pair(0, 1), pair(1, 2)));
+      assertGatherPairs(probe2Table.innerJoinGatherMaps(hashJoin),
+          pairSet(pair(0, 3), pair(1, 0)));
+    }
+  }
+
+  @Test
+  void testLeftJoinGatherMapCanBeReusedAcrossProbeTables() {
+    try (ColumnVector buildKeys = ColumnVector.fromInts(0, 1, 2, 3);
+         Table buildTable = new Table(buildKeys);
+         DistinctHashJoin hashJoin = new DistinctHashJoin(buildTable, true);
+         ColumnVector probe1Keys = ColumnVector.fromInts(1, 4, 0);
+         Table probe1Table = new Table(probe1Keys);
+         ColumnVector probe2Keys = ColumnVector.fromBoxedInts(null, 2, 8);
+         Table probe2Table = new Table(probe2Keys)) {
+      assertGatherIndices(probe1Table.leftDistinctJoinGatherMap(hashJoin),
+          Arrays.asList(1, null, 0));
+      assertGatherIndices(probe2Table.leftDistinctJoinGatherMap(hashJoin),
+          Arrays.asList(null, 2, null));
+    }
+  }
+
+  @Test
+  void testInnerJoinRespectsNullEquality() {
+    try (ColumnVector buildKeys = ColumnVector.fromBoxedInts(null, 1, 2);
+         Table buildTable = new Table(buildKeys);
+         DistinctHashJoin hashJoin = new DistinctHashJoin(buildTable, false);
+         ColumnVector probeKeys = ColumnVector.fromBoxedInts(null, 2);
+         Table probeTable = new Table(probeKeys)) {
+      assertGatherPairs(probeTable.innerJoinGatherMaps(hashJoin), pairSet(pair(1, 2)));
+    }
+  }
+
+  private static void assertGatherPairs(GatherMap[] gatherMaps,
+      Set<Map.Entry<Integer, Integer>> expected) {
+    // The column views are resources too; closing them here keeps the test from leaking them.
+    try (GatherMap leftMap = gatherMaps[0];
+         GatherMap rightMap = gatherMaps[1];
+         ColumnView leftView = leftMap.toColumnView(0, (int) leftMap.getRowCount());
+         ColumnView rightView = rightMap.toColumnView(0, (int) rightMap.getRowCount());
+         HostColumnVector leftHost = leftView.copyToHost();
+         HostColumnVector rightHost = rightView.copyToHost()) {
+      assertEquals(expected.size(), leftMap.getRowCount());
+      assertEquals(expected.size(), rightMap.getRowCount());
+      assertEquals(expected, gatherMapPairToSet(leftHost, rightHost));
+    }
+  }
+
+  private static void assertGatherIndices(GatherMap gatherMap, List<Integer> expected) {
+    try (GatherMap map = gatherMap;
+         ColumnView view = map.toColumnView(0, (int) map.getRowCount());
+         HostColumnVector host = view.copyToHost()) {
+      assertEquals(expected.size(), map.getRowCount());
+      assertEquals(expected, gatherMapToList(host));
+    }
+  }
+}