Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/cudf_polars/cudf_polars/experimental/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def _(
if (
Path(ir.path).exists()
and executor_options.sink_to_directory
and executor_options.cluster == Cluster.SINGLE
and executor_options.cluster == Cluster.DEFAULT_SINGLETON
):
# This lowering-time check can't be performed with the spmd / ray / dask
# clusters, which lower on each worker independently. There's a race condition
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES.
# SPDX-License-Identifier: Apache-2.0

"""RapidsMPF streaming-engine support."""

from __future__ import annotations

# Side-effect imports: each module registers
# ``@generate_ir_sub_network.register(...)`` handlers at import time so the
# dispatch table is populated before any query is evaluated.
import cudf_polars.experimental.rapidsmpf.collectives.shuffle
import cudf_polars.experimental.rapidsmpf.collectives.sort
import cudf_polars.experimental.rapidsmpf.groupby
import cudf_polars.experimental.rapidsmpf.io
import cudf_polars.experimental.rapidsmpf.join
import cudf_polars.experimental.rapidsmpf.repartition
import cudf_polars.experimental.rapidsmpf.union # noqa: F401

__all__: list[str] = []
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ async def _simple_top_or_bottom_k(
ir_context=ir_context,
)
)
chunk: TableChunk = await evaluate_batch(chunks, context, ir, ir_context=ir_context)
chunk: TableChunk
if chunks:
chunk = await evaluate_batch(chunks, context, ir, ir_context=ir_context)
else:
# This rank received no input partitions. Produce an empty chunk
# with the IR's output schema so the AllGather below still has
# something to insert (and other ranks don't deadlock waiting).
chunk = empty_table_chunk(ir, context, ir_context.get_cuda_stream())
Comment thread
wence- marked this conversation as resolved.
chunks.clear()

if comm.nranks > 1 and not metadata_in.duplicated:
Expand Down
270 changes: 28 additions & 242 deletions python/cudf_polars/cudf_polars/experimental/rapidsmpf/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,70 +4,41 @@

from __future__ import annotations

import contextlib
import dataclasses
import uuid
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from typing import TYPE_CHECKING, Any

from rapidsmpf.communicator.single import (
new_communicator as single_process_communicator,
)
from rapidsmpf.config import Options, get_environment_variables
from rapidsmpf.memory.buffer import MemoryType
from rapidsmpf.memory.buffer_resource import BufferResource, LimitAvailableMemory
from rapidsmpf.memory.pinned_memory_resource import PinnedMemoryResource
from rapidsmpf.progress_thread import ProgressThread
from rapidsmpf.rmm_resource_adaptor import RmmResourceAdaptor
from rapidsmpf.streaming.core.actor import (
run_actor_network,
)
from rapidsmpf.streaming.core.context import Context
from rapidsmpf.streaming.core.leaf_actor import pull_from_channel
from rapidsmpf.streaming.cudf.table_chunk import TableChunk

import pylibcudf as plc
import rmm

import cudf_polars.experimental.rapidsmpf.collectives.shuffle
import cudf_polars.experimental.rapidsmpf.collectives.sort
import cudf_polars.experimental.rapidsmpf.groupby
import cudf_polars.experimental.rapidsmpf.io
import cudf_polars.experimental.rapidsmpf.join
import cudf_polars.experimental.rapidsmpf.repartition
import cudf_polars.experimental.rapidsmpf.union
from cudf_polars.containers import DataFrame

import cudf_polars.dsl.tracing
from cudf_polars.dsl.ir import (
DataFrameScan,
IRExecutionContext,
Join,
Scan,
Union,
)
from cudf_polars.dsl.traversal import CachingVisitor, traversal
from cudf_polars.experimental.parallel import lower_ir_graph
from cudf_polars.experimental.rapidsmpf.collectives import ReserveOpIDs
from cudf_polars.experimental.rapidsmpf.dispatch import FanoutInfo
from cudf_polars.experimental.rapidsmpf.nodes import (
generate_ir_sub_network_wrapper,
metadata_drain_node,
)
from cudf_polars.experimental.rapidsmpf.tracing import log_query_plan
from cudf_polars.experimental.rapidsmpf.utils import empty_table_chunk
from cudf_polars.experimental.statistics import collect_statistics
from cudf_polars.utils.config import CUDAStreamPoolConfig
from cudf_polars.utils.config import SPMDContext

if TYPE_CHECKING:
from collections.abc import MutableMapping

from rapidsmpf.communicator.communicator import Communicator
from rapidsmpf.streaming.core.channel import Channel
from rapidsmpf.streaming.core.context import Context
from rapidsmpf.streaming.core.leaf_actor import DeferredMessages
from rapidsmpf.streaming.cudf.channel_metadata import ChannelMetadata
from rapidsmpf.streaming.cudf.table_chunk import TableChunk

import polars as pl

from cudf_polars.dsl.ir import IR
from cudf_polars.dsl.ir import IR, IRExecutionContext
from cudf_polars.experimental.base import PartitionInfo, StatsCollector
from cudf_polars.experimental.parallel import ConfigOptions
from cudf_polars.experimental.rapidsmpf.dispatch import (
Expand Down Expand Up @@ -99,13 +70,32 @@ def evaluate_logical_plan(
-------
The output DataFrame and metadata collector.
"""
query_id = uuid.uuid4()
# For default_singleton, inject the process-wide DefaultSingletonEngine instance
# into config_options before treating it as a regular SPMDEngine.
if config_options.executor.cluster == "default_singleton":
from cudf_polars.experimental.rapidsmpf.frontend.default_singleton_engine import (
DefaultSingletonEngine,
)
Comment thread
madsbk marked this conversation as resolved.

engine = DefaultSingletonEngine.get_or_create()
config_options = dataclasses.replace(
config_options,
executor=dataclasses.replace(
config_options.executor,
spmd_context=SPMDContext(
comm=engine.comm,
context=engine.context,
py_executor=engine.py_executor,
),
),
)

query_id = uuid.uuid4()
with cudf_polars.dsl.tracing.bound_contextvars(
cudf_polars_query_id=str(query_id),
):
match config_options.executor.cluster:
case "spmd":
case "spmd" | "default_singleton":
from cudf_polars.experimental.rapidsmpf.frontend.spmd import (
evaluate_pipeline_spmd_mode,
)
Expand Down Expand Up @@ -138,216 +128,12 @@ def evaluate_logical_plan(
collect_metadata=collect_metadata,
query_id=query_id,
)
case "single":
# Single-process execution: lower and run locally.
stats = collect_statistics(ir, config_options)
ir, partition_info = lower_ir_graph(ir, config_options, stats)
with ReserveOpIDs(ir, config_options) as collective_id_map:
log_query_plan(ir, config_options)
result, metadata_collector = evaluate_pipeline(
ir,
partition_info,
config_options,
stats,
collective_id_map,
single_process_communicator(Options(), ProgressThread()),
collect_metadata=collect_metadata,
query_id=query_id,
)
case other:
raise ValueError(f"Unknown cluster mode: {other}")

return result, metadata_collector


def evaluate_pipeline(
    ir: IR,
    partition_info: MutableMapping[IR, PartitionInfo],
    config_options: ConfigOptions[StreamingExecutor],
    stats: StatsCollector,
    collective_id_map: dict[IR, list[int]],
    comm: Communicator,
    rmpf_context: Context | None = None,
    *,
    collect_metadata: bool = False,
    query_id: uuid.UUID,
) -> tuple[pl.DataFrame, list[ChannelMetadata] | None]:
    """
    Build and evaluate a RapidsMPF streaming pipeline.

    Parameters
    ----------
    ir
        The IR node.
    partition_info
        The partition information.
    config_options
        The configuration options.
    stats
        The statistics collector.
    collective_id_map
        The mapping of IR nodes to lists of collective IDs.
    comm
        The communicator describing the participating processes.
    rmpf_context
        The RapidsMPF context. If ``None``, a fresh local context (and
        buffer resource) is created for single-process execution and torn
        down before this function returns.
    collect_metadata
        Whether to collect runtime metadata.
    query_id
        A unique identifier for the query.

    Returns
    -------
    The output DataFrame and metadata collector (``None`` when
    ``collect_metadata`` is False).
    """
    _original_mr: Any = None
    use_stream_pool = False
    if rmpf_context is not None:
        # Using "distributed" mode: the caller owns the context, so wrap it
        # in a nullcontext to keep a uniform `with` block below.
        # Always use the RapidsMPF stream pool for now.
        br = rmpf_context.br()
        use_stream_pool = True
        rmpf_context_manager = contextlib.nullcontext(rmpf_context)
    else:
        # Using "single" mode.
        # Create a new local RapidsMPF context. We temporarily swap the
        # current RMM resource for a tracking adaptor; the original is
        # restored after the context is torn down (see end of function).
        _original_mr = rmm.mr.get_current_device_resource()
        mr = RmmResourceAdaptor(_original_mr)
        rmm.mr.set_current_device_resource(mr)
        memory_available: MutableMapping[MemoryType, LimitAvailableMemory] | None = None
        single_spill_device = config_options.executor.client_device_threshold
        # A threshold in (0, 1) is interpreted as a fraction of total device
        # memory; 0 or >= 1 disables the spill limit entirely.
        if 0.0 < single_spill_device < 1.0:
            total_memory = rmm.mr.available_device_memory()[1]
            memory_available = {
                MemoryType.DEVICE: LimitAvailableMemory(
                    mr, limit=int(total_memory * single_spill_device)
                )
            }

        options = Options(
            {
                # By default, set the number of streaming threads to the max
                # number of IO threads. The user may override this with an
                # environment variable (i.e. RAPIDSMPF_NUM_STREAMING_THREADS)
                "num_streaming_threads": str(
                    max(config_options.executor.max_io_threads, 1)
                )
            }
            | get_environment_variables()
        )
        pinned_mr = (
            PinnedMemoryResource.make_if_available()
            if config_options.executor.spill_to_pinned_memory
            else None
        )
        stream_pool = (
            config_options.cuda_stream_policy.build()
            if isinstance(config_options.cuda_stream_policy, CUDAStreamPoolConfig)
            else None
        )
        use_stream_pool = stream_pool is not None
        br = BufferResource(
            mr,
            pinned_mr=pinned_mr,
            memory_available=memory_available,
            stream_pool=stream_pool,
        )
        rmpf_context_manager = Context(comm.logger, br, options)

    with rmpf_context_manager as rmpf_context:
        # Create the IR execution context
        if use_stream_pool:
            ir_context = IRExecutionContext(
                get_cuda_stream=rmpf_context.get_stream_from_pool, query_id=query_id
            )
        else:
            ir_context = IRExecutionContext(query_id=query_id)

        # Generate network nodes
        assert rmpf_context is not None, "RapidsMPF context must be defined."
        metadata_collector: list[ChannelMetadata] | None = (
            [] if collect_metadata else None
        )
        nodes, output = generate_network(
            rmpf_context,
            comm,
            ir,
            partition_info,
            config_options,
            stats,
            ir_context=ir_context,
            collective_id_map=collective_id_map,
            metadata_collector=metadata_collector,
        )

        try:
            # Run the network
            with ThreadPoolExecutor(
                max_workers=config_options.executor.num_py_executors,
                thread_name_prefix="cpse",
            ) as executor:
                run_actor_network(actors=nodes, py_executor=executor)

            # Extract/return the concatenated result.
            # Keep chunks alive until after concatenation to prevent
            # use-after-free with stream-ordered allocations
            messages = output.release()
            chunks = [
                TableChunk.from_message(msg, br=br).make_available_and_spill(
                    br, allow_overbooking=True
                )
                for msg in messages
            ]
            dfs: list[DataFrame] = []
            if chunks:
                col_names = list(ir.schema.keys())
                col_dtypes = list(ir.schema.values())
                dfs = [
                    DataFrame.from_table(
                        chunk.table_view(), col_names, col_dtypes, chunk.stream
                    )
                    for chunk in chunks
                ]
                if len(dfs) == 1:
                    df = dfs[0]
                else:
                    # Concatenate on a stream ordered after every chunk's
                    # stream so the copy cannot race the producers.
                    with ir_context.stream_ordered_after(*dfs) as stream:
                        df = DataFrame.from_table(
                            plc.concatenate.concatenate(
                                [d.table for d in dfs], stream=stream
                            ),
                            col_names,
                            col_dtypes,
                            stream,
                        )
            else:
                # No chunks received - create an empty DataFrame with correct schema
                stream = ir_context.get_cuda_stream()
                chunk = empty_table_chunk(ir, rmpf_context, stream)
                df = DataFrame.from_table(
                    chunk.table_view(),
                    list(ir.schema.keys()),
                    list(ir.schema.values()),
                    stream,
                )

            result = df.to_polars()

            # Now we need to drop *all* GPU data. This ensures that no cudaFreeAsync runs
            # before the Context, which ultimately contains the rmm MR, goes out of scope.
            del messages, chunks, dfs, df
        finally:
            # Ensure these are dropped even if a node raises
            # an exception in run_actor_network
            del nodes, output

    # Restore the initial RMM memory resource
    if _original_mr is not None:
        rmm.mr.set_current_device_resource(_original_mr)

    return result, metadata_collector


def determine_fanout_nodes(
ir: IR,
partition_info: MutableMapping[IR, PartitionInfo],
Expand Down
Loading
Loading