@@ -116,6 +116,8 @@ object GpuMetric extends Logging {
val BUILD_DATA_SIZE = "buildDataSize"
val BUILD_TIME = "buildTime"
val STREAM_TIME = "streamTime"
val BUILD_SIDE_CACHE_BUILDS = "buildSideCacheBuilds"
val BUILD_SIDE_CACHE_HITS = "buildSideCacheHits"
val NUM_TASKS_FALL_BACKED = "numTasksFallBacked"
val NUM_TASKS_REPARTITIONED = "numTasksRepartitioned"
val NUM_TASKS_SKIPPED_AGG = "numTasksSkippedAgg"
@@ -173,6 +175,8 @@ object GpuMetric extends Logging {
val DESCRIPTION_BUILD_DATA_SIZE = "build side size"
val DESCRIPTION_BUILD_TIME = "build time"
val DESCRIPTION_STREAM_TIME = "stream time"
val DESCRIPTION_BUILD_SIDE_CACHE_BUILDS = "cached build side builds"
val DESCRIPTION_BUILD_SIDE_CACHE_HITS = "cached build side hits"
val DESCRIPTION_NUM_TASKS_FALL_BACKED = "number of sort fallback tasks"
val DESCRIPTION_NUM_TASKS_REPARTITIONED = "number of tasks repartitioned for agg"
val DESCRIPTION_NUM_TASKS_SKIPPED_AGG = "number of tasks skipped aggregation"
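For context, a hedged sketch of how these counters would typically be exposed and bumped by a join exec node, assuming the plugin's usual createMetric(level, description) helper, MODERATE_LEVEL, and gpuLongMetric lookup (the actual wiring in the join execs may differ):

import com.nvidia.spark.rapids.GpuMetric._

// Expose the counters in the exec node's metric map.
override lazy val additionalMetrics: Map[String, GpuMetric] = Map(
  BUILD_SIDE_CACHE_BUILDS -> createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_BUILDS),
  BUILD_SIDE_CACHE_HITS -> createMetric(MODERATE_LEVEL, DESCRIPTION_BUILD_SIDE_CACHE_HITS))

// Bump them wherever the build-side cache is consulted.
if (cacheHit) gpuLongMetric(BUILD_SIDE_CACHE_HITS) += 1
else gpuLongMetric(BUILD_SIDE_CACHE_BUILDS) += 1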
@@ -264,7 +264,7 @@ case class GpuShuffledHashJoinExec(
}
// doJoin will close singleBatch
doJoin(singleBatch, maybeBufferedStreamIter, joinOptions,
- numOutputRows, numOutputBatches, opTime, joinTime)
+ numOutputRows, numOutputBatches, opTime, joinTime, enableBuildSideReuse = false)
case Right(builtBatchIter) =>
// For big joins, where the build data cannot fit into a single batch.
val sizeBuildIter = builtBatchIter.map { cb =>
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2025, NVIDIA CORPORATION.
+ * Copyright (c) 2025-2026, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -393,6 +393,9 @@ object NvtxRegistry {
val BUILD_JOIN_TABLE: NvtxId = NvtxId("build join table", NvtxColor.GREEN,
"Building hash table for join operation")

val BROADCAST_HASH_TABLE_BUILD: NvtxId = NvtxId("broadcast hash table build",
NvtxColor.GREEN, "Building cuDF hash table for broadcast hash join")

// Window operations
val WINDOW: NvtxId = NvtxId("window", NvtxColor.CYAN,
"Computing window function results")
@@ -780,6 +783,7 @@ object NvtxRegistry {
register(EXISTENCE_JOIN_SCATTER_MAP)
register(EXISTENCE_JOIN_BATCH)
register(BUILD_JOIN_TABLE)
register(BROADCAST_HASH_TABLE_BUILD)
register(WINDOW)
register(RUNNING_WINDOW)
register(DOUBLE_BATCHED_WINDOW_PRE)
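The new id labels the cuDF hash table build so it shows up on profiler timelines. A minimal sketch of the underlying NVTX pattern, using cuDF's Java NvtxRange directly rather than the plugin's NvtxId wrapper:

import ai.rapids.cudf.{NvtxColor, NvtxRange}

val range = new NvtxRange("broadcast hash table build", NvtxColor.GREEN)
try {
  // ... build the cuDF hash table here ...
} finally {
  range.close() // end the range so the profiler records the build's true extent
}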
13 changes: 13 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
@@ -782,6 +782,17 @@ val GPU_COREDUMP_PIPE_PATTERN = conf("spark.rapids.gpu.coreDump.pipePattern")
.checkValues(JoinBuildSideSelection.values.map(_.toString))
.createWithDefault(JoinBuildSideSelection.AUTO.toString)

val BROADCAST_HASH_TABLE_REUSE =
conf("spark.rapids.sql.join.broadcastHashTable.reuse")
.doc("Enable reuse of the broadcast-side hash table for broadcast hash joins. " +
"When enabled, the hash table is built once per broadcast and shared across all " +
"stream batches within a task and across all tasks that consume the same broadcast " +
"on an executor. Reuse pins the physical build side to the broadcast side for the " +
"lifetime of each cached join, overriding the dynamic build-side selection " +
s"heuristic configured by ${JOIN_BUILD_SIDE.key}.")
.booleanConf
.createWithDefault(false)

val LOG_JOIN_CARDINALITY = conf("spark.rapids.sql.join.logCardinality")
.doc("Enable logging of join cardinality statistics to help diagnose performance issues. " +
"When enabled, logs task context, key data types, join condition, row counts, and " +
@@ -3280,6 +3291,8 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val joinGathererSizeEstimateThreshold: Double = get(JOIN_GATHERER_SIZE_ESTIMATE_THRESHOLD)

lazy val broadcastHashTableReuse: Boolean = get(BROADCAST_HASH_TABLE_REUSE)

/**
* Get join options based on the current configuration.
* @param targetSize the target batch size in bytes to use for the join
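The flag is off by default. A minimal way to try it from a Spark session, using the standard Spark conf API:

spark.conf.set("spark.rapids.sql.join.broadcastHashTable.reuse", "true")

Once enabled, every cached join keeps the broadcast side as its physical build side, so the dynamic build-side heuristic above no longer applies to those joins.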
@@ -277,12 +277,19 @@ trait SpillableHandle extends StoreHandle with Logging {

}

/**
* Contract for handles tracked by the device spill store (see [[SpillableHandle]]).
*/
trait DeviceStoreHandle extends SpillableHandle {
def releaseSpilled(): Unit
}

/**
* Spillable handles that can be materialized on the device.
* @tparam T an auto closeable subclass. `dev` tracks an instance of this object,
* on the device.
*/
- trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle {
+ trait DeviceSpillableHandle[T <: AutoCloseable] extends DeviceStoreHandle {
private[spill] var dev: Option[T]

private[spill] override def spillable: Boolean = synchronized {
@@ -305,7 +312,7 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle {
* free a device buffer that the worker thread isn't done with).
* See https://github.com/NVIDIA/spark-rapids/issues/8610 for more info.
*/
- def releaseSpilled(): Unit = {
+ override def releaseSpilled(): Unit = {
releaseDeviceResource()
}

@@ -319,6 +326,192 @@ trait DeviceSpillableHandle[T <: AutoCloseable] extends SpillableHandle {
}
}

object SharedRecomputableDeviceHandle {
/**
* A scoped lease that pins a shared recomputable device object.
* Holding a lease prevents the object from being closed.
*
* @param handle the handle that issued this lease
* @param resource the pinned device object
*/
final class Lease[T <: AutoCloseable] private[spill] (
handle: SharedRecomputableDeviceHandle[T],
val resource: T) extends AutoCloseable {
private[this] var closed = false

override def close(): Unit = synchronized {
if (closed) {
throw new IllegalStateException("Close called too many times on recomputable handle lease")
}
closed = true
handle.releasePin()
}
}

def apply[T <: AutoCloseable](
approxSizeInBytes: Long,
initialValue: T)(
rebuild: => T): SharedRecomputableDeviceHandle[T] = {
val handle = new SharedRecomputableDeviceHandle(approxSizeInBytes, initialValue, () => rebuild)
SpillFramework.stores.deviceStore.track(handle)
handle
}
}

/**
* Handle for a device-only object that is shared amongst threads and is cheap to recompute.
*
* When this handle is selected for spilling, it does not copy anything to host or disk. Instead
* it marks the current device object as evicted and returns `approxSizeInBytes` so the spill
* framework accounts for the freed device memory. The actual close of the evicted object is
* deferred to `releaseSpilled` after device synchronization.
*
* The protected device object may not expose cuDF-style reference counting (cuDF hash
* tables, for example, do not). Instead we maintain a pin count on the handle, and callers
* must pin the object through `acquire`. The object is spillable only when the pin count is 0.
*
* The `rebuild` function is used to recreate the device object after it has been evicted. It is
* called by the first thread that calls `acquire` after a successful spill.
*/
class SharedRecomputableDeviceHandle[T <: AutoCloseable] private[spill] (
override val approxSizeInBytes: Long,
initialValue: T,
rebuild: () => T) extends DeviceStoreHandle with Logging {
import SharedRecomputableDeviceHandle.Lease

private[spill] var dev: Option[T] = Some(initialValue)
private[this] var pendingRelease: Seq[T] = Seq.empty
private[this] var pinCount: Int = 0
private[this] var rebuilding: Boolean = false

private[spill] override def spillable: Boolean = synchronized {
super.spillable && dev.isDefined && pinCount == 0
}

/**
* Acquire a lease on the shared recomputable device object:
* - If `dev` is defined, increment the pin count and return a lease on it.
* - If `dev` is missing and another thread is rebuilding, wait for build completion.
* - If `dev` is missing and no other thread is rebuilding, set `rebuilding=true` (exclusively)
* and rebuild the object. After rebuild, set `dev`, notify waiters, and track the handle.
*/
def acquire(): Lease[T] = {
var materialized: Option[T] = None
var shouldBuild = false
while (materialized.isEmpty) {
shouldBuild = synchronized {
if (closed) {
throw new IllegalStateException("attempting to materialize a closed handle")
} else if (dev.isDefined) {
pinCount += 1
materialized = dev
false
} else if (rebuilding) {
wait()
false
} else {
rebuilding = true
true
}
}

if (shouldBuild) {
var rebuilt: Option[T] = None
try {
rebuilt = Some(rebuild())
var shouldTrack = false
synchronized {
rebuilding = false
if (closed) {
notifyAll()
throw new IllegalStateException("attempting to materialize a closed handle")
}
dev = rebuilt
pinCount += 1
materialized = rebuilt
shouldTrack = true
notifyAll()
}
if (shouldTrack) {
SpillFramework.stores.deviceStore.track(this)
}
} catch {
case t: Throwable =>
// Rebuild failed; release any rebuilt object, clear rebuilding,
// and notify waiters before rethrowing
rebuilt.foreach(_.close())
synchronized {
rebuilding = false
notifyAll()
}
throw t
}
}
}
new Lease(this, materialized.get)
}

private[spill] def releasePin(): Unit = synchronized {
if (pinCount <= 0) {
throw new IllegalStateException("releasePin called without a matching acquire")
}
pinCount -= 1
}

override def spill(): Long = {
var evicted: Option[T] = None
val thisThreadSpills = synchronized {
if (!closed && dev.isDefined && pinCount == 0 && !spilling) {
spilling = true
evicted = dev
dev = None
true
} else {
false
}
}
if (thisThreadSpills) {
SpillFramework.removeFromDeviceStore(this)
var shouldClose = false
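// Nothing is copied to host or disk here; the evicted object is only queued on
// pendingRelease so releaseSpilled can close it after device synchronization.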
executeSpill {
synchronized {
pendingRelease = pendingRelease ++ evicted.toSeq
spilling = false
shouldClose = closed
}
0L
}
if (shouldClose) {
doClose()
}
approxSizeInBytes
} else {
0L
}
}

override def releaseSpilled(): Unit = {
val toClose = synchronized {
val release = pendingRelease
pendingRelease = Seq.empty
release
}
toClose.safeClose()
}

override def doClose(): Unit = {
SpillFramework.removeFromDeviceStore(this)
val toClose = synchronized {
val current = dev
val release = pendingRelease
dev = None
pendingRelease = Seq.empty
current.toSeq ++ release.toSeq
}
toClose.safeClose()
}
}

/**
* Spillable handles that can be materialized on the host.
* @tparam T an auto closeable subclass. `host` tracks an instance of this object,
@@ -1739,7 +1932,7 @@ class SpillableHostStore(val maxSize: Option[Long] = None)
override protected def spillNvtxRange: NvtxId = NvtxRegistry.DISK_SPILL
}

- class SpillableDeviceStore extends SpillableStore[DeviceSpillableHandle[_]] {
+ class SpillableDeviceStore extends SpillableStore[DeviceStoreHandle] {
override protected def spillNvtxRange: NvtxId = NvtxRegistry.DEVICE_SPILL

override def postSpill(plan: SpillPlan): Unit = {
@@ -2152,7 +2345,7 @@ object SpillFramework extends Logging {
// if the stores have already shut down, we don't want to create them here
// so we use `storesInternal` directly in these remove functions.

- private[spill] def removeFromDeviceStore(handle: DeviceSpillableHandle[_]): Unit = {
+ private[spill] def removeFromDeviceStore(handle: DeviceStoreHandle): Unit = {
synchronized {
Option(storesInternal).map(_.deviceStore)
}.foreach(_.remove(handle))
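To illustrate the intended calling pattern, a hedged sketch (buildHashTable and probe are hypothetical stand-ins, and withResource is assumed to be the plugin's Arm helper): the handle is created once per broadcast, and each consumer pins the table through a lease while it works:

import com.nvidia.spark.rapids.Arm.withResource

// Create the handle; the device store tracks it and may evict the table
// under memory pressure whenever the pin count is 0.
val handle = SharedRecomputableDeviceHandle(approxSizeInBytes, buildHashTable()) {
  buildHashTable() // rebuilt lazily by the first acquire() after an eviction
}

// Pin the table for the duration of the probe; closing the lease unpins it.
withResource(handle.acquire()) { lease =>
  probe(lease.resource)
}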