Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from cudf_polars.experimental.explain import SerializablePlan
from cudf_polars.experimental.rapidsmpf.frontend.core import StreamingEngine
from cudf_polars.experimental.rapidsmpf.frontend.options import StreamingOptions

POLARS_VALIDATION_OPTIONS = {
"check_row_order": True,
"check_column_order": True,
Expand Down Expand Up @@ -544,7 +545,7 @@ def from_args(cls, args: argparse.Namespace) -> RunConfig:
duckdb_temp_dir=args.duckdb_temp_dir,
)

def serialize(self, engine: pl.GPUEngine | None) -> dict:
def serialize(self, engine: StreamingEngine | None) -> dict:
"""Serialize the run config to a dictionary."""
opts = self.streaming_options
result: dict[str, Any] = {
Expand Down Expand Up @@ -583,7 +584,21 @@ def serialize(self, engine: pl.GPUEngine | None) -> dict:
}
if engine is not None:
config_options = ConfigOptions.from_polars_engine(engine)
result["config_options"] = dataclasses.asdict(config_options)
# Drop non-serializable contexts.
config_options = dataclasses.replace(
config_options,
executor=dataclasses.replace(
config_options.executor,
spmd_context=None,
ray_context=None,
dask_context=None,
),
)
rapidsmpf_options = engine.rapidsmpf_options.get_strings()
result["config_options"] = {
"config_options": dataclasses.asdict(config_options),
"rapidsmpf_options": rapidsmpf_options,
}
return result

def summarize(self) -> None:
Expand Down Expand Up @@ -1060,6 +1075,7 @@ def _finalize_benchmark_run(
run_config: RunConfig,
validation_failures: list[int],
query_failures: list[tuple[int, int]],
engine: StreamingEngine | None,
) -> None:
"""Summarize, serialize, and exit after a benchmark run."""
if args.summarize:
Expand All @@ -1074,7 +1090,7 @@ def _finalize_benchmark_run(
)
else:
print("✅ All validated queries passed.")
args.output.write(json.dumps(run_config.serialize(engine=None)))
args.output.write(json.dumps(run_config.serialize(engine=engine)))
args.output.write("\n")
sys.exit(1 if (query_failures or validation_failures) else 0)

Expand Down Expand Up @@ -1133,7 +1149,9 @@ def _allgather_result(df: pl.DataFrame) -> pl.DataFrame:
run_config = _consolidate_logs(
run_config, engine=engine, gather_client_logs=False
)
_finalize_benchmark_run(args, run_config, validation_failures, query_failures)
_finalize_benchmark_run(
args, run_config, validation_failures, query_failures, engine=engine
)


def run_polars_ray(
Expand Down Expand Up @@ -1180,7 +1198,9 @@ def run_polars_ray(
run_config = dataclasses.replace(run_config, records=dict(records), plans=plans)
run_config = _consolidate_logs(run_config, engine=engine)

_finalize_benchmark_run(args, run_config, validation_failures, query_failures)
_finalize_benchmark_run(
args, run_config, validation_failures, query_failures, engine=engine
)


def run_polars_dask(
Expand Down Expand Up @@ -1240,7 +1260,9 @@ def run_polars_dask(
finally:
if dask_client is not None:
dask_client.close()
_finalize_benchmark_run(args, run_config, validation_failures, query_failures)
_finalize_benchmark_run(
args, run_config, validation_failures, query_failures, engine=engine
)


def setup_logging(query_id: int, iteration: int) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from collections.abc import Callable, MutableMapping
from concurrent.futures import ThreadPoolExecutor

import rapidsmpf.config
from rapidsmpf.communicator.communicator import Communicator
from rapidsmpf.memory.buffer_resource import BufferResource
from rapidsmpf.streaming.core.context import Context
Expand Down Expand Up @@ -150,6 +151,8 @@ class StreamingEngine(pl.GPUEngine):
when :meth:`shutdown` is called. If ``None``, an empty stack is created.
"""

rapidsmpf_options: rapidsmpf.config.Options

def __init__(
self,
*,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,9 +602,8 @@ def __init__(
"memory_resource_config", None
)

rapidsmpf_options_as_bytes = resolve_rapidsmpf_options(
rapidsmpf_options
).serialize()
self.rapidsmpf_options = resolve_rapidsmpf_options(rapidsmpf_options)
rapidsmpf_options_as_bytes = self.rapidsmpf_options.serialize()

# Unique identifier for this cluster instance; namespaces the per-worker
# attribute so multiple DaskEngine contexts can coexist on the same workers.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,8 @@ def __init__(
"memory_resource_config", None
)

rapidsmpf_options_as_bytes = resolve_rapidsmpf_options(
rapidsmpf_options
).serialize()
self.rapidsmpf_options = resolve_rapidsmpf_options(rapidsmpf_options)
rapidsmpf_options_as_bytes = self.rapidsmpf_options.serialize()

exit_stack = contextlib.ExitStack()
if not ray.is_initialized():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,8 @@ def __init__(
)
bind_to_gpu(hw_binding)

rapidsmpf_options = resolve_rapidsmpf_options(rapidsmpf_options)
self.rapidsmpf_options = resolve_rapidsmpf_options(rapidsmpf_options)

mr_config: MemoryResourceConfig | None = engine_options.get(
"memory_resource_config", None
)
Expand All @@ -361,12 +362,12 @@ def __init__(
comm = bootstrap.create_ucxx_comm(
progress_thread=ProgressThread(),
type=bootstrap.BackendType.AUTO,
options=rapidsmpf_options,
options=self.rapidsmpf_options,
)
else:
comm = single_communicator(
progress_thread=ProgressThread(),
options=rapidsmpf_options,
options=self.rapidsmpf_options,
)
# else: caller-provided comm; the caller retains ownership

Expand All @@ -384,7 +385,7 @@ def __init__(
# exit-stack holding a stale reference. ``_cleanup_ctx`` is
# registered instead — it shuts down whatever ``self._ctx`` is
# at engine-shutdown time (i.e. the latest reset's Context).
ctx = Context.from_options(comm.logger, mr, rapidsmpf_options)
ctx = Context.from_options(comm.logger, mr, self.rapidsmpf_options)
exit_stack.callback(self._cleanup_ctx)
self._comm: Communicator | None = comm
self._ctx: Context | None = ctx
Expand Down
Loading