diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py b/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py index 98532df2a87..c35177782e3 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/rolling.py @@ -615,13 +615,17 @@ def _grouped_window_scan_setup( local = self._sorted_grouper(by_cols_for_scan) return order_index, by_cols_for_scan, local + # TODO: this is an ordered left-join that drops the join keys. + # Rename it and replace the manual scatter+gather with the Join IR's + # _reorder_maps helper (lifted to a shared utility) so the streaming + # and in-memory over paths share the same primitive. def _broadcast_agg_results( self, by_tbl: plc.Table, group_keys_tbl: plc.Table, value_tbls: list[plc.Table], - names: list[str], - dtypes: list[DataType], + names: Sequence[str], + dtypes: Sequence[DataType], stream: Stream, ) -> list[Column]: # We do a left-join between the input keys to group-keys diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index 47afd9b7e58..8f68b259954 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -38,6 +38,8 @@ ) if TYPE_CHECKING: + from collections.abc import Generator + from polars import GPUEngine from cudf_polars.typing import NodeTraverser @@ -90,6 +92,7 @@ def __init__(self, visitor: NodeTraverser, engine: GPUEngine): self.errors: list[Exception] = [] self._cache_nodes: dict[int, ir.Cache] = {} self._expr_context: ExecutionContext = ExecutionContext.FRAME + self._internal_name_gen: Generator[str, None, None] | None = None def translate_ir(self, *, n: int | None = None) -> ir.IR: """ @@ -247,6 +250,24 @@ def __exit__(self, *args: Any) -> None: self.translator._expr_context = self._prev +class set_internal_name_gen(AbstractContextManager[None]): + """Share one internal-name generator across sibling expression 
translations.""" + + __slots__ = ("_prev", "schema", "translator") + + def __init__(self, translator: Translator, schema: Schema) -> None: + self.translator = translator + self.schema = schema + self._prev: Generator[str, None, None] | None = None + + def __enter__(self) -> None: + self._prev = self.translator._internal_name_gen + self.translator._internal_name_gen = unique_names(self.schema) + + def __exit__(self, *args: Any) -> None: + self.translator._internal_name_gen = self._prev + + @singledispatch def _translate_ir(node: Any, translator: Translator, schema: Schema) -> ir.IR: raise NotImplementedError( @@ -362,9 +383,11 @@ def _( def _(node: plrs._ir_nodes.Select, translator: Translator, schema: Schema) -> ir.IR: with set_node(translator.visitor, node.input): inp = translator.translate_ir(n=None) - exprs = [ - translate_named_expr(translator, n=e, schema=inp.schema) for e in node.expr - ] + with set_internal_name_gen(translator, inp.schema): + exprs = [ + translate_named_expr(translator, n=e, schema=inp.schema) + for e in node.expr + ] return ir.Select(schema, exprs, node.should_broadcast, inp) @@ -478,9 +501,11 @@ def _(node: plrs._ir_nodes.Join, translator: Translator, schema: Schema) -> ir.I def _(node: plrs._ir_nodes.HStack, translator: Translator, schema: Schema) -> ir.IR: with set_node(translator.visitor, node.input): inp = translator.translate_ir(n=None) - exprs = [ - translate_named_expr(translator, n=e, schema=inp.schema) for e in node.exprs - ] + with set_internal_name_gen(translator, inp.schema): + exprs = [ + translate_named_expr(translator, n=e, schema=inp.schema) + for e in node.exprs + ] return ir.HStack(schema, exprs, node.should_broadcast, inp) @@ -830,7 +855,7 @@ def _( # pl.col("a").rolling(...) 
with set_expr_context(translator, ExecutionContext.ROLLING): agg = translator.translate_expr(n=node.function, schema=schema) - name_generator = unique_names(schema) + name_generator = translator._internal_name_gen or unique_names(schema) aggs, named_post_agg = decompose_single_agg( expr.NamedExpr(next(name_generator), agg), name_generator, @@ -875,7 +900,7 @@ def _( # not exposed until polars 1.39. with set_expr_context(translator, ExecutionContext.WINDOW): agg = translator.translate_expr(n=node.function, schema=schema) - name_gen = unique_names(schema) + name_gen = translator._internal_name_gen or unique_names(schema) aggs, post = decompose_single_agg( expr.NamedExpr(next(name_gen), agg), name_gen, diff --git a/python/cudf_polars/cudf_polars/dsl/utils/naming.py b/python/cudf_polars/cudf_polars/dsl/utils/naming.py index 65eedbb1495..018477f9177 100644 --- a/python/cudf_polars/cudf_polars/dsl/utils/naming.py +++ b/python/cudf_polars/cudf_polars/dsl/utils/naming.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: Apache-2.0 """Name generation utilities.""" @@ -7,11 +7,15 @@ from typing import TYPE_CHECKING +from cudf_polars.dsl.expr import NamedExpr + if TYPE_CHECKING: from collections.abc import Generator, Iterable + from cudf_polars.typing import Schema + -__all__ = ["unique_names"] +__all__ = ["names_to_indices", "unique_names"] def unique_names(names: Iterable[str]) -> Generator[str, None, None]: @@ -32,3 +36,28 @@ def unique_names(names: Iterable[str]) -> Generator[str, None, None]: while True: yield f"{prefix}{i}" i += 1 + + +def names_to_indices( + names: tuple[str | NamedExpr, ...], schema: Schema +) -> tuple[int, ...]: + """ + Return column indices for the given names in schema order. + + Accepts either column names (str) or NamedExpr, so it can be used with + e.g. 
ir.left_on, ir.right_on as well as plain name tuples. + + Parameters + ---------- + names + The names to get indices for. + schema + The schema to get indices from. + + Returns + ------- + The column indices for each name in schema order. + """ + keys = list(schema.keys()) + str_names = [n.name if isinstance(n, NamedExpr) else n for n in names] + return tuple(keys.index(n) for n in str_names) diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q14.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q14.py index 57946e3044d..d353c19102c 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q14.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q14.py @@ -335,65 +335,66 @@ def polars_impl(run_config: RunConfig) -> QueryResult: item = get_data(run_config.dataset_path, "item", run_config.suffix) date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix) + cross_items = build_cross_items( + store_sales, catalog_sales, web_sales, item, date_dim, year=year + ) + average_sales = build_average_sales( + store_sales, catalog_sales, web_sales, date_dim, year=year + ) + + # week_dates is ≤7 rows (one calendar week), computed once as a 1-partition frame. + # Push the week filter into each channel before the UNION via a semi-join so that + # ~99% of rows (everything outside the target week) are dropped before the + # expensive cross_items join and groupby. 
+ target_week = ( + date_dim.filter( + (pl.col("d_year") == year + 1) + & (pl.col("d_moy") == 12) + & (pl.col("d_dom") == day) + ) + .select("d_week_seq") + .unique() + ) + week_dates = date_dim.join(target_week, on="d_week_seq").select("d_date_sk") + all_sales = pl.concat( [ - store_sales.select( + store_sales.join( + week_dates, left_on="ss_sold_date_sk", right_on="d_date_sk", how="semi" + ).select( [ pl.lit("store").alias("channel"), pl.col("ss_item_sk").alias("item_sk"), pl.col("ss_quantity").alias("quantity"), pl.col("ss_list_price").alias("list_price"), - pl.col("ss_sold_date_sk").alias("date_sk"), ] ), - catalog_sales.select( + catalog_sales.join( + week_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi" + ).select( [ pl.lit("catalog").alias("channel"), pl.col("cs_item_sk").alias("item_sk"), pl.col("cs_quantity").alias("quantity"), pl.col("cs_list_price").alias("list_price"), - pl.col("cs_sold_date_sk").alias("date_sk"), ] ), - web_sales.select( + web_sales.join( + week_dates, left_on="ws_sold_date_sk", right_on="d_date_sk", how="semi" + ).select( [ pl.lit("web").alias("channel"), pl.col("ws_item_sk").alias("item_sk"), pl.col("ws_quantity").alias("quantity"), pl.col("ws_list_price").alias("list_price"), - pl.col("ws_sold_date_sk").alias("date_sk"), ] ), ] ) - cross_items = build_cross_items( - store_sales, catalog_sales, web_sales, item, date_dim, year=year - ) - average_sales = build_average_sales( - store_sales, catalog_sales, web_sales, date_dim, year=year - ) - - # d_week_seq target is the same for all 3 channels; compute it once. - target_week = ( - date_dim.filter( - (pl.col("d_year") == year + 1) - & (pl.col("d_moy") == 12) - & (pl.col("d_dom") == day) - ) - .select("d_week_seq") - .unique() - ) - week_dates = date_dim.join(target_week, on="d_week_seq").select("d_date_sk") - - # Build y: all 3 channels in a single pipeline. - # cross_items and average_sales each appear once — no CSE needed. 
- # After group_by the frame is tiny, so the cross join with the 1-row - # average_sales frame is negligible even if Polars fuses it into an IEJoin. y = ( all_sales.join(cross_items, left_on="item_sk", right_on="ss_item_sk") .join(item, left_on="item_sk", right_on="i_item_sk") - .join(week_dates, left_on="date_sk", right_on="d_date_sk") .group_by(["channel", "i_brand_id", "i_class_id", "i_category_id"]) .agg( [ diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q17.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q17.py index 997a1546d47..b26e23984e8 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q17.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q17.py @@ -110,38 +110,103 @@ def polars_impl(run_config: RunConfig) -> QueryResult: sort_by = {"i_item_id": False, "i_item_desc": False, "s_state": False} limit = 100 - store_sales_base = ( + q1 = f"{year}Q1" + q1_q3 = [f"{year}Q1", f"{year}Q2", f"{year}Q3"] + + # Pre-filter date_dim to only qualifying d_date_sk values. + d1_dates = date_dim.filter(pl.col("d_quarter_name") == q1).select("d_date_sk") + d_q3_dates = date_dim.filter(pl.col("d_quarter_name").is_in(q1_q3)).select( + "d_date_sk" + ) + + # store_returns has [6] partitions — at the broadcast limit. Filter it to Q1-Q3 dates + # first, then use the (customer, item) pairs it contains to pre-filter both store_sales + # and catalog_sales before those larger tables enter the expensive shuffle joins. + store_returns_filtered = store_returns.join( + d_q3_dates, left_on="sr_returned_date_sk", right_on="d_date_sk", how="semi" + ).select(["sr_customer_sk", "sr_item_sk", "sr_ticket_number", "sr_return_quantity"]) + + # (customer, item) pairs present in any qualifying store return; stays at [6] partitions + # so broadcast is free. Polars will CACHE this shared subplan. 
+ sr_customer_item = store_returns_filtered.select(["sr_customer_sk", "sr_item_sk"]) + + store_sales_filtered = ( store_sales.join( - date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk", suffix="_d1" + d1_dates, left_on="ss_sold_date_sk", right_on="d_date_sk", how="semi" + ) + .join( + sr_customer_item, + left_on=["ss_customer_sk", "ss_item_sk"], + right_on=["sr_customer_sk", "sr_item_sk"], + how="semi", + ) + .select( + [ + "ss_customer_sk", + "ss_item_sk", + "ss_store_sk", + "ss_ticket_number", + "ss_quantity", + ] + ) + .join( + item.select(["i_item_sk", "i_item_id", "i_item_desc"]), + left_on="ss_item_sk", + right_on="i_item_sk", + ) + .join( + store.select(["s_store_sk", "s_state"]), + left_on="ss_store_sk", + right_on="s_store_sk", + ) + .select( + [ + "ss_customer_sk", + "ss_item_sk", + "ss_ticket_number", + "ss_quantity", + "i_item_id", + "i_item_desc", + "s_state", + ] ) - .join(item, left_on="ss_item_sk", right_on="i_item_sk") - .join(store, left_on="ss_store_sk", right_on="s_store_sk") - .filter(pl.col("d_quarter_name") == f"{year}Q1") ) - store_returns_base = store_returns.join( - date_dim, left_on="sr_returned_date_sk", right_on="d_date_sk", suffix="_d2" - ).filter(pl.col("d_quarter_name").is_in([f"{year}Q1", f"{year}Q2", f"{year}Q3"])) - - catalog_sales_base = catalog_sales.join( - date_dim, left_on="cs_sold_date_sk", right_on="d_date_sk", suffix="_d3" - ).filter(pl.col("d_quarter_name").is_in([f"{year}Q1", f"{year}Q2", f"{year}Q3"])) + catalog_sales_filtered = ( + catalog_sales.join( + d_q3_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi" + ) + .join( + sr_customer_item, + left_on=["cs_bill_customer_sk", "cs_item_sk"], + right_on=["sr_customer_sk", "sr_item_sk"], + how="semi", + ) + .select(["cs_bill_customer_sk", "cs_item_sk", "cs_quantity"]) + ) return QueryResult( frame=( - store_sales_base.join( - store_returns_base, + store_sales_filtered.join( + store_returns_filtered, left_on=["ss_customer_sk", "ss_item_sk", 
"ss_ticket_number"], right_on=["sr_customer_sk", "sr_item_sk", "sr_ticket_number"], - how="inner", - suffix="_sr", + ) + .select( + [ + "ss_customer_sk", + "ss_item_sk", + "ss_quantity", + "sr_return_quantity", + "i_item_id", + "i_item_desc", + "s_state", + ] ) .join( - catalog_sales_base, + catalog_sales_filtered, left_on=["ss_customer_sk", "ss_item_sk"], right_on=["cs_bill_customer_sk", "cs_item_sk"], - how="inner", - suffix="_cs", ) .group_by(["i_item_id", "i_item_desc", "s_state"]) .agg( diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q18.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q18.py index 9c2b9f227ef..5ef6017150e 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q18.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q18.py @@ -121,10 +121,7 @@ def polars_impl(run_config: RunConfig) -> QueryResult: catalog_sales = get_data( run_config.dataset_path, "catalog_sales", run_config.suffix ) - customer_demographics_1 = get_data( - run_config.dataset_path, "customer_demographics", run_config.suffix - ) - customer_demographics_2 = get_data( + customer_demographics = get_data( run_config.dataset_path, "customer_demographics", run_config.suffix ) customer = get_data(run_config.dataset_path, "customer", run_config.suffix) @@ -134,30 +131,52 @@ def polars_impl(run_config: RunConfig) -> QueryResult: date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix) item = get_data(run_config.dataset_path, "item", run_config.suffix) + # Pre-filter each dimension table before joining against catalog_sales [45 partitions]. + # d_year not in GROUP BY — semi-join keeps only the date key in the pipeline. 
+ filtered_dates = date_dim.filter(pl.col("d_year") == year).select("d_date_sk") + filtered_cd1 = customer_demographics.filter( + (pl.col("cd_gender") == gen) & (pl.col("cd_education_status") == es) + ).select(["cd_demo_sk", "cd_dep_count"]) + filtered_customer = customer.filter(pl.col("c_birth_month").is_in(month)).select( + ["c_customer_sk", "c_current_cdemo_sk", "c_current_addr_sk", "c_birth_year"] + ) + filtered_addr = customer_address.filter(pl.col("ca_state").is_in(state)).select( + ["ca_address_sk", "ca_county", "ca_state", "ca_country"] + ) + base_query = ( - catalog_sales.join(date_dim, left_on="cs_sold_date_sk", right_on="d_date_sk") - .join(item, left_on="cs_item_sk", right_on="i_item_sk") + catalog_sales.select( + [ + "cs_sold_date_sk", + "cs_item_sk", + "cs_bill_cdemo_sk", + "cs_bill_customer_sk", + "cs_quantity", + "cs_list_price", + "cs_coupon_amt", + "cs_sales_price", + "cs_net_profit", + ] + ) .join( - customer_demographics_1, - left_on="cs_bill_cdemo_sk", - right_on="cd_demo_sk", - suffix="_cd1", + filtered_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi" + ) + .join( + item.select(["i_item_sk", "i_item_id"]), + left_on="cs_item_sk", + right_on="i_item_sk", ) - .join(customer, left_on="cs_bill_customer_sk", right_on="c_customer_sk") + .join(filtered_cd1, left_on="cs_bill_cdemo_sk", right_on="cd_demo_sk") .join( - customer_demographics_2, + filtered_customer, left_on="cs_bill_customer_sk", right_on="c_customer_sk" + ) + .join( + customer_demographics.select("cd_demo_sk"), left_on="c_current_cdemo_sk", right_on="cd_demo_sk", - suffix="_cd2", - ) - .join(customer_address, left_on="c_current_addr_sk", right_on="ca_address_sk") - .filter( - (pl.col("cd_gender") == gen) - & (pl.col("cd_education_status") == es) - & pl.col("c_birth_month").is_in(month) - & (pl.col("d_year") == year) - & pl.col("ca_state").is_in(state) + how="semi", ) + .join(filtered_addr, left_on="c_current_addr_sk", right_on="ca_address_sk") ) agg_exprs = [ diff --git 
a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q2.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q2.py index 998cccb6017..edf285e0d9a 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q2.py @@ -136,10 +136,6 @@ def polars_impl(run_config: RunConfig) -> QueryResult: ), ] ) - # Step 2: Create wswscs CTE equivalent (aggregate by week and day of week) - # First join with date_dim to get day names - wscs_with_dates = wscs.join(date_dim, left_on="sold_date_sk", right_on="d_date_sk") - # Create separate aggregations for each day to better control null handling days = ( "Sunday", "Monday", @@ -158,35 +154,26 @@ def polars_impl(run_config: RunConfig) -> QueryResult: "fri_sales", "sat_sales", ) - # Start with all week sequences - all_weeks = wscs_with_dates.select("d_week_seq").unique() - wswscs = all_weeks - + # Pre-filter date_dim to 4 years ([year-1, year, year+1, year+2]) to capture + # boundary weeks that span year transitions (e.g. from Dec 28 to Jan 3). Filtering to + # only [year, year+1] incorrectly excludes Dec days whose d_week_seq also + # appears in year's date_dim, producing null sales for those boundary weeks. 
+ date_dim_prefilter = date_dim.filter( + pl.col("d_year").is_in([year - 1, year, year + 1, year + 2]) + ).select(["d_date_sk", "d_week_seq", "d_day_name"]) wswscs = ( - wscs_with_dates.with_columns( + wscs.join(date_dim_prefilter, left_on="sold_date_sk", right_on="d_date_sk") + .group_by("d_week_seq") + .agg( [ pl.when(pl.col("d_day_name") == day) .then(pl.col("sales_price")) .otherwise(None) + .sum() .alias(name) for day, name in zip(days, day_cols, strict=True) ] ) - .group_by("d_week_seq") - .agg( - *(pl.col(name).sum().alias(name) for name in day_cols), - *(pl.col(name).count().alias(f"{name}_count") for name in day_cols), - ) - .with_columns( - [ - pl.when(pl.col(f"{name}_count") > 0) - .then(pl.col(name)) - .otherwise(None) - .alias(name) - for name in day_cols - ] - ) - .select(["d_week_seq", *day_cols]) ) # Step 3: Create year data (y subquery equivalent) diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q23.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q23.py index debe1b512b3..dfdcef33744 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q23.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q23.py @@ -105,10 +105,16 @@ def polars_impl(run_config: RunConfig) -> QueryResult: ) web_sales = get_data(run_config.dataset_path, "web_sales", run_config.suffix) + # Pre-filter date_dim to the 4-year window so the inner join with store_sales + # naturally excludes out-of-window records before the expensive item join. 
+ year_set = [year, year + 1, year + 2, year + 3] + date_dim_years = date_dim.filter(pl.col("d_year").is_in(year_set)) + frequent_ss_items = ( - store_sales.join(date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk") + store_sales.join( + date_dim_years, left_on="ss_sold_date_sk", right_on="d_date_sk" + ) .join(item, left_on="ss_item_sk", right_on="i_item_sk") - .filter(pl.col("d_year").is_in([year, year + 1, year + 2, year + 3])) .with_columns(pl.col("i_item_desc").str.slice(0, 30).alias("itemdesc")) .group_by(["itemdesc", "ss_item_sk", "d_date"]) .agg(pl.len().alias("cnt")) @@ -121,8 +127,11 @@ def polars_impl(run_config: RunConfig) -> QueryResult: # only valid because we know that the TPC-DS includes a foreign key here, so all # customers in store_sales _must_ be entries that exist somewhere in customer. store_sales.filter(pl.col("ss_customer_sk").is_not_null()) - .join(date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk") - .filter(pl.col("d_year").is_in([year, year + 1, year + 2, year + 3])) + .join( + date_dim_years.select("d_date_sk"), + left_on="ss_sold_date_sk", + right_on="d_date_sk", + ) .group_by("ss_customer_sk") .agg((pl.col("ss_quantity") * pl.col("ss_sales_price")).sum().alias("csales")) ) @@ -146,13 +155,12 @@ def polars_impl(run_config: RunConfig) -> QueryResult: (pl.col("d_year") == year) & (pl.col("d_moy") == month) ).select("d_date_sk") + # Join order: most selective filters first (date_target ~1.2%, frequent_ss_items, + # best_customers semi ~5%), then customer last — it's a non-filtering name lookup + # that only adds c_last_name/c_first_name, so running it on the already-reduced + # row set avoids the full catalog_sales/web_sales scan width. 
catalog_part = ( - catalog_sales.join( - customer.select(["c_customer_sk", "c_last_name", "c_first_name"]), - left_on="cs_bill_customer_sk", - right_on="c_customer_sk", - ) - .join(date_target, left_on="cs_sold_date_sk", right_on="d_date_sk") + catalog_sales.join(date_target, left_on="cs_sold_date_sk", right_on="d_date_sk") .join(frequent_ss_items, left_on="cs_item_sk", right_on="ss_item_sk") .join( best_customers, @@ -160,17 +168,17 @@ def polars_impl(run_config: RunConfig) -> QueryResult: right_on="ss_customer_sk", how="semi", ) + .join( + customer.select(["c_customer_sk", "c_last_name", "c_first_name"]), + left_on="cs_bill_customer_sk", + right_on="c_customer_sk", + ) .group_by(["c_last_name", "c_first_name"]) .agg((pl.col("cs_quantity") * pl.col("cs_list_price")).sum().alias("sales")) ) web_part = ( - web_sales.join( - customer.select(["c_customer_sk", "c_last_name", "c_first_name"]), - left_on="ws_bill_customer_sk", - right_on="c_customer_sk", - ) - .join(date_target, left_on="ws_sold_date_sk", right_on="d_date_sk") + web_sales.join(date_target, left_on="ws_sold_date_sk", right_on="d_date_sk") .join(frequent_ss_items, left_on="ws_item_sk", right_on="ss_item_sk") .join( best_customers, @@ -178,6 +186,11 @@ def polars_impl(run_config: RunConfig) -> QueryResult: right_on="ss_customer_sk", how="semi", ) + .join( + customer.select(["c_customer_sk", "c_last_name", "c_first_name"]), + left_on="ws_bill_customer_sk", + right_on="c_customer_sk", + ) .group_by(["c_last_name", "c_first_name"]) .agg((pl.col("ws_quantity") * pl.col("ws_list_price")).sum().alias("sales")) ) diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q25.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q25.py index 8006585a170..d431f5507b8 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q25.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q25.py @@ -96,17 +96,6 @@ def 
polars_impl(run_config: RunConfig) -> QueryResult: store = get_data(run_config.dataset_path, "store", run_config.suffix) item = get_data(run_config.dataset_path, "item", run_config.suffix) - d1, d2, d3 = [ - date_dim.clone().select( - [ - pl.col("d_date_sk").alias(f"{p}_date_sk"), - pl.col("d_moy").alias(f"{p}_moy"), - pl.col("d_year").alias(f"{p}_year"), - ] - ) - for p in ("d1", "d2", "d3") - ] - sort_by = { "i_item_id": False, "i_item_desc": False, @@ -114,31 +103,103 @@ def polars_impl(run_config: RunConfig) -> QueryResult: "s_store_name": False, } limit = 100 + + # d1: only April of the target year — very selective (~1/60 of date_dim rows). + # d2/d3: from April to October of the target year: same condition, one pre-filtered frame. + d1_dates = date_dim.filter( + (pl.col("d_moy") == 4) & (pl.col("d_year") == year) + ).select("d_date_sk") + d2_d3_dates = date_dim.filter( + pl.col("d_moy").is_between(4, 10) & (pl.col("d_year") == year) + ).select("d_date_sk") + + # store_returns [6] ≤ broadcast limit: filter to qualifying return dates first, + # then extract (customer, item) pairs to pre-filter ss and cs before shuffle joins. 
+ store_returns_filtered = store_returns.join( + d2_d3_dates, left_on="sr_returned_date_sk", right_on="d_date_sk", how="semi" + ).select(["sr_customer_sk", "sr_item_sk", "sr_ticket_number", "sr_net_loss"]) + sr_customer_item = store_returns_filtered.select(["sr_customer_sk", "sr_item_sk"]) + + store_sales_filtered = ( + store_sales.join( + d1_dates, left_on="ss_sold_date_sk", right_on="d_date_sk", how="semi" + ) + .join( + sr_customer_item, + left_on=["ss_customer_sk", "ss_item_sk"], + right_on=["sr_customer_sk", "sr_item_sk"], + how="semi", + ) + .select( + [ + "ss_customer_sk", + "ss_item_sk", + "ss_store_sk", + "ss_ticket_number", + "ss_net_profit", + ] + ) + .join( + item.select(["i_item_sk", "i_item_id", "i_item_desc"]), + left_on="ss_item_sk", + right_on="i_item_sk", + ) + .join( + store.select(["s_store_sk", "s_store_id", "s_store_name"]), + left_on="ss_store_sk", + right_on="s_store_sk", + ) + .select( + [ + "ss_customer_sk", + "ss_item_sk", + "ss_ticket_number", + "ss_net_profit", + "i_item_id", + "i_item_desc", + "s_store_id", + "s_store_name", + ] + ) + ) + + catalog_sales_filtered = ( + catalog_sales.join( + d2_d3_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi" + ) + .join( + sr_customer_item, + left_on=["cs_bill_customer_sk", "cs_item_sk"], + right_on=["sr_customer_sk", "sr_item_sk"], + how="semi", + ) + .select(["cs_bill_customer_sk", "cs_item_sk", "cs_net_profit"]) + ) + return QueryResult( frame=( - store_sales.join(d1, left_on="ss_sold_date_sk", right_on="d1_date_sk") - .join(item, left_on="ss_item_sk", right_on="i_item_sk") - .join(store, left_on="ss_store_sk", right_on="s_store_sk") - .join( - store_returns, + store_sales_filtered.join( + store_returns_filtered, left_on=["ss_customer_sk", "ss_item_sk", "ss_ticket_number"], right_on=["sr_customer_sk", "sr_item_sk", "sr_ticket_number"], ) - .join(d2, left_on="sr_returned_date_sk", right_on="d2_date_sk") + .select( + [ + "ss_customer_sk", + "ss_item_sk", + "ss_net_profit", + 
"sr_net_loss", + "i_item_id", + "i_item_desc", + "s_store_id", + "s_store_name", + ] + ) .join( - catalog_sales, + catalog_sales_filtered, left_on=["ss_customer_sk", "ss_item_sk"], right_on=["cs_bill_customer_sk", "cs_item_sk"], ) - .join(d3, left_on="cs_sold_date_sk", right_on="d3_date_sk") - .filter( - (pl.col("d1_moy") == 4) - & (pl.col("d1_year") == year) - & (pl.col("d2_moy").is_between(4, 10)) - & (pl.col("d2_year") == year) - & (pl.col("d3_moy").is_between(4, 10)) - & (pl.col("d3_year") == year) - ) .group_by(["i_item_id", "i_item_desc", "s_store_id", "s_store_name"]) .agg( [ diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q29.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q29.py index a39954745bd..6bdb7593929 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q29.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q29.py @@ -97,23 +97,6 @@ def polars_impl(run_config: RunConfig) -> QueryResult: store = get_data(run_config.dataset_path, "store", run_config.suffix) item = get_data(run_config.dataset_path, "item", run_config.suffix) - d1, d2 = [ - date_dim.clone().select( - [ - pl.col("d_date_sk").alias(f"{p}_date_sk"), - pl.col("d_moy").alias(f"{p}_moy"), - pl.col("d_year").alias(f"{p}_year"), - ] - ) - for p in ("d1", "d2") - ] - d3 = date_dim.clone().select( - [ - pl.col("d_date_sk").alias("d3_date_sk"), - pl.col("d_year").alias("d3_year"), - ] - ) - sort_by = { "i_item_id": False, "i_item_desc": False, @@ -121,30 +104,107 @@ def polars_impl(run_config: RunConfig) -> QueryResult: "s_store_name": False, } limit = 100 + + # d1: one specific month of the target year — most selective filter. + # d2: 4-month window of the target year. + # d3: 3-year window — less selective but still worth pushing before the cs shuffle join. 
+ d1_dates = date_dim.filter( + (pl.col("d_moy") == month) & (pl.col("d_year") == year) + ).select("d_date_sk") + d2_dates = date_dim.filter( + pl.col("d_moy").is_between(month, month + 3) & (pl.col("d_year") == year) + ).select("d_date_sk") + d3_dates = date_dim.filter( + pl.col("d_year").is_in([year, year + 1, year + 2]) + ).select("d_date_sk") + + # store_returns [6] ≤ broadcast limit: apply d2 date filter, then use + # (customer, item) pairs to pre-filter ss and cs before shuffle joins. + store_returns_filtered = store_returns.join( + d2_dates, left_on="sr_returned_date_sk", right_on="d_date_sk", how="semi" + ).select(["sr_customer_sk", "sr_item_sk", "sr_ticket_number", "sr_return_quantity"]) + sr_customer_item = store_returns_filtered.select(["sr_customer_sk", "sr_item_sk"]) + + store_sales_filtered = ( + store_sales.join( + d1_dates, left_on="ss_sold_date_sk", right_on="d_date_sk", how="semi" + ) + .join( + sr_customer_item, + left_on=["ss_customer_sk", "ss_item_sk"], + right_on=["sr_customer_sk", "sr_item_sk"], + how="semi", + ) + .select( + [ + "ss_customer_sk", + "ss_item_sk", + "ss_store_sk", + "ss_ticket_number", + "ss_quantity", + ] + ) + .join( + item.select(["i_item_sk", "i_item_id", "i_item_desc"]), + left_on="ss_item_sk", + right_on="i_item_sk", + ) + .join( + store.select(["s_store_sk", "s_store_id", "s_store_name"]), + left_on="ss_store_sk", + right_on="s_store_sk", + ) + .select( + [ + "ss_customer_sk", + "ss_item_sk", + "ss_ticket_number", + "ss_quantity", + "i_item_id", + "i_item_desc", + "s_store_id", + "s_store_name", + ] + ) + ) + + catalog_sales_filtered = ( + catalog_sales.join( + d3_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi" + ) + .join( + sr_customer_item, + left_on=["cs_bill_customer_sk", "cs_item_sk"], + right_on=["sr_customer_sk", "sr_item_sk"], + how="semi", + ) + .select(["cs_bill_customer_sk", "cs_item_sk", "cs_quantity"]) + ) + return QueryResult( frame=( - store_sales.join(d1, left_on="ss_sold_date_sk", 
right_on="d1_date_sk") - .join(item, left_on="ss_item_sk", right_on="i_item_sk") - .join(store, left_on="ss_store_sk", right_on="s_store_sk") - .join( - store_returns, + store_sales_filtered.join( + store_returns_filtered, left_on=["ss_customer_sk", "ss_item_sk", "ss_ticket_number"], right_on=["sr_customer_sk", "sr_item_sk", "sr_ticket_number"], ) - .join(d2, left_on="sr_returned_date_sk", right_on="d2_date_sk") + .select( + [ + "ss_customer_sk", + "ss_item_sk", + "ss_quantity", + "sr_return_quantity", + "i_item_id", + "i_item_desc", + "s_store_id", + "s_store_name", + ] + ) .join( - catalog_sales, + catalog_sales_filtered, left_on=["ss_customer_sk", "ss_item_sk"], right_on=["cs_bill_customer_sk", "cs_item_sk"], ) - .join(d3, left_on="cs_sold_date_sk", right_on="d3_date_sk") - .filter( - (pl.col("d1_moy") == month) - & (pl.col("d1_year") == year) - & (pl.col("d2_moy").is_between(month, month + 3)) - & (pl.col("d2_year") == year) - & (pl.col("d3_year").is_in([year, year + 1, year + 2])) - ) .group_by(["i_item_id", "i_item_desc", "s_store_id", "s_store_name"]) .agg( [ diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q43.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q43.py index b0c1023d655..cc885a05fe8 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q43.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q43.py @@ -91,10 +91,20 @@ def polars_impl(run_config: RunConfig) -> QueryResult: year = params["year"] gmt = params["gmt"] - # Load tables date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix) store_sales = get_data(run_config.dataset_path, "store_sales", run_config.suffix) store = get_data(run_config.dataset_path, "store", run_config.suffix) + + # Pre-filter lookup tables before joining against store_sales [58 partitions]. + # d_year not needed after filter; d_day_name drives the conditional agg columns. 
+ filtered_dates = date_dim.filter(pl.col("d_year") == year).select( + ["d_date_sk", "d_day_name"] + ) + # s_gmt_offset not needed after filter; keep group-by output columns. + filtered_store = store.filter(pl.col("s_gmt_offset") == gmt).select( + ["s_store_sk", "s_store_name", "s_store_id"] + ) + sort_by = { "s_store_name": False, "s_store_id": False, @@ -107,15 +117,13 @@ def polars_impl(run_config: RunConfig) -> QueryResult: "sat_sales": False, } limit = 100 - # Main query with joins and conditional aggregations return QueryResult( frame=( - store_sales.join(date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk") - .join(store, left_on="ss_store_sk", right_on="s_store_sk") - .filter((pl.col("s_gmt_offset") == gmt) & (pl.col("d_year") == year)) + store_sales.select(["ss_sold_date_sk", "ss_store_sk", "ss_sales_price"]) + .join(filtered_dates, left_on="ss_sold_date_sk", right_on="d_date_sk") + .join(filtered_store, left_on="ss_store_sk", right_on="s_store_sk") .with_columns( [ - # Pre-compute conditional sales amounts for each day pl.when(pl.col("d_day_name") == "Sunday") .then(pl.col("ss_sales_price")) .otherwise(0) diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q44.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q44.py index 204c9c90a76..bdff35af0f5 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q44.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q44.py @@ -90,87 +90,57 @@ def polars_impl(run_config: RunConfig) -> QueryResult: store_sk = params["store_sk"] - # Load tables store_sales = get_data(run_config.dataset_path, "store_sales", run_config.suffix) item = get_data(run_config.dataset_path, "item", run_config.suffix) - # Step 1: Calculate benchmark (average profit for store with null demographics) + # Benchmark: global mean profit for the store with null demographics — single row. 
+ # Use a constant-key equi-join instead of how="cross" so the streaming executor + # treats it as a broadcast join (1 row ≤ broadcast_join_limit) rather than a + # ConditionalJoin that falls back from multi-partition mode. benchmark = ( store_sales.filter( - (pl.col("ss_store_sk") == store_sk) & (pl.col("ss_cdemo_sk").is_null()) + (pl.col("ss_store_sk") == store_sk) & pl.col("ss_cdemo_sk").is_null() ) - .group_by("ss_store_sk") - .agg( - [ - pl.col("ss_net_profit").mean().alias("profit_mean"), - pl.col("ss_net_profit").count().alias("profit_count"), - ] - ) - .with_columns( - [ - pl.when(pl.col("profit_count") > 0) - .then(pl.col("profit_mean")) - .otherwise(None) - .alias("benchmark_profit") - ] - ) - .select("benchmark_profit") + .select(pl.col("ss_net_profit").mean().alias("benchmark_profit")) + .with_columns(pl.lit(1, dtype=pl.Int32).alias("_key")) ) - # Step 2: Calculate item-level average profits for store + # Item-level average profits, broadcast-joined with the 1-row benchmark. 
item_profits = ( store_sales.filter(pl.col("ss_store_sk") == store_sk) .group_by("ss_item_sk") - .agg( - [ - pl.col("ss_net_profit").mean().alias("profit_mean"), - pl.col("ss_net_profit").count().alias("profit_count"), - ] - ) - .with_columns( - [ - pl.when(pl.col("profit_count") > 0) - .then(pl.col("profit_mean")) - .otherwise(None) - .alias("avg(ss_net_profit)") - ] - ) - .drop(["profit_mean", "profit_count"]) - .join(benchmark, how="cross") - .filter(pl.col("avg(ss_net_profit)") > (0.9 * pl.col("benchmark_profit"))) + .agg(pl.col("ss_net_profit").mean().alias("avg_profit")) + .with_columns(pl.lit(1, dtype=pl.Int32).alias("_key")) + .join(benchmark, on="_key") + .filter(pl.col("avg_profit") > 0.9 * pl.col("benchmark_profit")) + .select(["ss_item_sk", "avg_profit"]) ) - # Step 3: Create ascending ranking (worst to best) ascending_rank = ( item_profits.with_columns( - [pl.col("avg(ss_net_profit)").rank(method="ordinal").alias("rnk")] + pl.col("avg_profit").rank(method="ordinal").alias("rnk") ) .filter(pl.col("rnk") < 11) .select(["ss_item_sk", "rnk"]) ) - # Step 4: Create descending ranking (best to worst) descending_rank = ( item_profits.with_columns( - [ - pl.col("avg(ss_net_profit)") - .rank(method="ordinal", descending=True) - .alias("rnk") - ] + pl.col("avg_profit").rank(method="ordinal", descending=True).alias("rnk") ) .filter(pl.col("rnk") < 11) .select(["ss_item_sk", "rnk"]) ) + item_cols = item.select(["i_item_sk", "i_product_name"]) sort_by = {"rnk": False} limit = 100 - # Step 5: Join rankings and get product names return QueryResult( frame=( ascending_rank.join(descending_rank, on="rnk", how="inner", suffix="_desc") - .join(item, left_on="ss_item_sk", right_on="i_item_sk", how="inner") + .join(item_cols, left_on="ss_item_sk", right_on="i_item_sk", how="inner") .join( - item, + item_cols, left_on="ss_item_sk_desc", right_on="i_item_sk", how="inner", diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q52.py 
b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q52.py index 2a8f74151b3..b391e49cda9 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q52.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q52.py @@ -69,15 +69,21 @@ def polars_impl(run_config: RunConfig) -> QueryResult: sort_by = {"d_year": False, "ext_price": True, "brand_id": False} limit = 100 + + # Pre-filter both lookup tables before joining against store_sales [87 partitions]. + # date_dim keeps d_year because it appears in the GROUP BY. + filtered_dates = date_dim.filter( + (pl.col("d_moy") == month) & (pl.col("d_year") == year) + ).select(["d_date_sk", "d_year"]) + filtered_item = item.filter(pl.col("i_manager_id") == manager_id).select( + ["i_item_sk", "i_brand", "i_brand_id"] + ) + return QueryResult( frame=( - store_sales.join(date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk") - .join(item, left_on="ss_item_sk", right_on="i_item_sk") - .filter( - (pl.col("i_manager_id") == manager_id) - & (pl.col("d_moy") == month) - & (pl.col("d_year") == year) - ) + store_sales.select(["ss_sold_date_sk", "ss_item_sk", "ss_ext_sales_price"]) + .join(filtered_dates, left_on="ss_sold_date_sk", right_on="d_date_sk") + .join(filtered_item, left_on="ss_item_sk", right_on="i_item_sk") .group_by(["d_year", "i_brand", "i_brand_id"]) .agg(pl.col("ss_ext_sales_price").sum().alias("ext_price")) .select( diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q53.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q53.py index 88a9e1992ee..dcc897de935 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q53.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q53.py @@ -101,25 +101,37 @@ def polars_impl(run_config: RunConfig) -> QueryResult: date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix) store = 
get_data(run_config.dataset_path, "store", run_config.suffix) month_seq_list = list(range(dms, dms + 12)) + + # Pre-filter lookup tables before joining against store_sales [87 partitions]. + # date_dim: keep d_qoy because it appears in the GROUP BY. + filtered_dates = date_dim.filter( + pl.col("d_month_seq").is_in(month_seq_list) + ).select(["d_date_sk", "d_qoy"]) + # item: apply both OR'd rule groups up front; only i_manufact_id needed after. + filtered_item = item.filter( + ( + pl.col("i_category").is_in(categories1) + & pl.col("i_class").is_in(classes1) + & pl.col("i_brand").is_in(brands1) + ) + | ( + pl.col("i_category").is_in(categories2) + & pl.col("i_class").is_in(classes2) + & pl.col("i_brand").is_in(brands2) + ) + ).select(["i_item_sk", "i_manufact_id"]) + grouped_data = ( - store_sales.join(item, left_on="ss_item_sk", right_on="i_item_sk") - .join(date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk") - .join(store, left_on="ss_store_sk", right_on="s_store_sk") - .filter(pl.col("d_month_seq").is_in(month_seq_list)) - .filter( - # First rule group - ( - (pl.col("i_category").is_in(categories1)) - & (pl.col("i_class").is_in(classes1)) - & (pl.col("i_brand").is_in(brands1)) - ) - | - # Second rule group - ( - (pl.col("i_category").is_in(categories2)) - & (pl.col("i_class").is_in(classes2)) - & (pl.col("i_brand").is_in(brands2)) - ) + store_sales.select( + ["ss_sold_date_sk", "ss_item_sk", "ss_store_sk", "ss_sales_price"] + ) + .join(filtered_item, left_on="ss_item_sk", right_on="i_item_sk") + .join(filtered_dates, left_on="ss_sold_date_sk", right_on="d_date_sk") + .join( + store.select("s_store_sk"), + left_on="ss_store_sk", + right_on="s_store_sk", + how="semi", ) .group_by(["i_manufact_id", "d_qoy"]) .agg([pl.col("ss_sales_price").sum().alias("sum_sales_raw")]) diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q55.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q55.py index 
e6cfbfef9e6..eb3d1eb019e 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q55.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q55.py @@ -65,15 +65,25 @@ def polars_impl(run_config: RunConfig) -> QueryResult: item = get_data(run_config.dataset_path, "item", run_config.suffix) sort_by = {"ext_price": True, "brand_id": False} limit = 100 + + # d_year not in GROUP BY so date filter can be a semi-join (no date columns needed). + filtered_dates = date_dim.filter( + (pl.col("d_moy") == month) & (pl.col("d_year") == year) + ).select("d_date_sk") + filtered_item = item.filter(pl.col("i_manager_id") == manager_id).select( + ["i_item_sk", "i_brand", "i_brand_id"] + ) + return QueryResult( frame=( - store_sales.join(date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk") - .join(item, left_on="ss_item_sk", right_on="i_item_sk") - .filter( - (pl.col("i_manager_id") == manager_id) - & (pl.col("d_moy") == month) - & (pl.col("d_year") == year) + store_sales.select(["ss_sold_date_sk", "ss_item_sk", "ss_ext_sales_price"]) + .join( + filtered_dates, + left_on="ss_sold_date_sk", + right_on="d_date_sk", + how="semi", ) + .join(filtered_item, left_on="ss_item_sk", right_on="i_item_sk") .group_by(["i_brand", "i_brand_id"]) .agg(pl.col("ss_ext_sales_price").sum().alias("ext_price")) .select( diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q63.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q63.py index 3e9b06cb553..5bcd47a8ddc 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q63.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q63.py @@ -89,42 +89,54 @@ def polars_impl(run_config: RunConfig) -> QueryResult: date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix) store = get_data(run_config.dataset_path, "store", run_config.suffix) - inner_query = ( - store_sales.join(item, 
left_on="ss_item_sk", right_on="i_item_sk") - .join(date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk") - .join(store, left_on="ss_store_sk", right_on="s_store_sk") - .filter( - pl.col("d_month_seq").is_in([dms + i for i in range(12)]) - & ( - ( - pl.col("i_category").is_in(["Books", "Children", "Electronics"]) - & pl.col("i_class").is_in( - ["personal", "portable", "reference", "self-help"] - ) - & pl.col("i_brand").is_in( - [ - "scholaramalgamalg #14", - "scholaramalgamalg #7", - "exportiunivamalg #9", - "scholaramalgamalg #9", - ] - ) - ) - | ( - pl.col("i_category").is_in(["Women", "Music", "Men"]) - & pl.col("i_class").is_in( - ["accessories", "classical", "fragrances", "pants"] - ) - & pl.col("i_brand").is_in( - [ - "amalgimporto #1", - "edu packscholar #1", - "exportiimporto #1", - "importoamalg #1", - ] - ) - ) + # Pre-filter both lookup tables before joining against store_sales [58 partitions]. + # item: apply both OR'd rule groups up front; only i_manager_id needed after. + # date_dim: keep d_moy because it appears in the GROUP BY. 
+ filtered_item = item.filter( + ( + pl.col("i_category").is_in(["Books", "Children", "Electronics"]) + & pl.col("i_class").is_in( + ["personal", "portable", "reference", "self-help"] + ) + & pl.col("i_brand").is_in( + [ + "scholaramalgamalg #14", + "scholaramalgamalg #7", + "exportiunivamalg #9", + "scholaramalgamalg #9", + ] + ) + ) + | ( + pl.col("i_category").is_in(["Women", "Music", "Men"]) + & pl.col("i_class").is_in( + ["accessories", "classical", "fragrances", "pants"] ) + & pl.col("i_brand").is_in( + [ + "amalgimporto #1", + "edu packscholar #1", + "exportiimporto #1", + "importoamalg #1", + ] + ) + ) + ).select(["i_item_sk", "i_manager_id"]) + filtered_dates = date_dim.filter( + pl.col("d_month_seq").is_in([dms + i for i in range(12)]) + ).select(["d_date_sk", "d_moy"]) + + inner_query = ( + store_sales.select( + ["ss_sold_date_sk", "ss_item_sk", "ss_store_sk", "ss_sales_price"] + ) + .join(filtered_item, left_on="ss_item_sk", right_on="i_item_sk") + .join(filtered_dates, left_on="ss_sold_date_sk", right_on="d_date_sk") + .join( + store.select("s_store_sk"), + left_on="ss_store_sk", + right_on="s_store_sk", + how="semi", ) .group_by(["i_manager_id", "d_moy"]) .agg([pl.col("ss_sales_price").sum().alias("sum_sales")]) diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q67.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q67.py index 0c6859ca55a..855c98a1b60 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q67.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q67.py @@ -214,11 +214,35 @@ def polars_impl(run_config: RunConfig) -> QueryResult: store = get_data(run_config.dataset_path, "store", run_config.suffix) item = get_data(run_config.dataset_path, "item", run_config.suffix) + # Pre-filter date_dim to the 12-month window before joining against store_sales [58]. + # d_month_seq not needed after filter; keep group-by output columns. 
+ filtered_dates = date_dim.filter( + pl.col("d_month_seq").is_between(dms, dms + 11) + ).select(["d_date_sk", "d_year", "d_qoy", "d_moy"]) + base_data = ( - store_sales.join(date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk") - .join(store, left_on="ss_store_sk", right_on="s_store_sk") - .join(item, left_on="ss_item_sk", right_on="i_item_sk") - .filter(pl.col("d_month_seq").is_between(dms, dms + 11)) + store_sales.select( + [ + "ss_sold_date_sk", + "ss_item_sk", + "ss_store_sk", + "ss_sales_price", + "ss_quantity", + ] + ) + .join(filtered_dates, left_on="ss_sold_date_sk", right_on="d_date_sk") + .join( + store.select(["s_store_sk", "s_store_id"]), + left_on="ss_store_sk", + right_on="s_store_sk", + ) + .join( + item.select( + ["i_item_sk", "i_category", "i_class", "i_brand", "i_product_name"] + ), + left_on="ss_item_sk", + right_on="i_item_sk", + ) .with_columns( (pl.col("ss_sales_price") * pl.col("ss_quantity")) .fill_null(0) diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q76.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q76.py index e05d9b28f65..8bcaf67e9b8 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q76.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q76.py @@ -106,10 +106,16 @@ def polars_impl(run_config: RunConfig) -> QueryResult: ) item = get_data(run_config.dataset_path, "item", run_config.suffix) date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix) + + # Project lookup tables to only the columns needed in each component. 
+ date_cols = date_dim.select(["d_date_sk", "d_year", "d_qoy"]) + item_cols = item.select(["i_item_sk", "i_category"]) + store_component = ( store_sales.filter(pl.col(nullcol_ss).is_null()) - .join(date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk") - .join(item, left_on="ss_item_sk", right_on="i_item_sk") + .select(["ss_sold_date_sk", "ss_item_sk", "ss_ext_sales_price"]) + .join(date_cols, left_on="ss_sold_date_sk", right_on="d_date_sk") + .join(item_cols, left_on="ss_item_sk", right_on="i_item_sk") .select( [ pl.lit("store").alias("channel"), @@ -123,8 +129,9 @@ def polars_impl(run_config: RunConfig) -> QueryResult: ) web_component = ( web_sales.filter(pl.col(nullcol_ws).is_null()) - .join(date_dim, left_on="ws_sold_date_sk", right_on="d_date_sk") - .join(item, left_on="ws_item_sk", right_on="i_item_sk") + .select(["ws_sold_date_sk", "ws_item_sk", "ws_ext_sales_price"]) + .join(date_cols, left_on="ws_sold_date_sk", right_on="d_date_sk") + .join(item_cols, left_on="ws_item_sk", right_on="i_item_sk") .select( [ pl.lit("web").alias("channel"), @@ -138,8 +145,9 @@ def polars_impl(run_config: RunConfig) -> QueryResult: ) catalog_component = ( catalog_sales.filter(pl.col(nullcol_cs).is_null()) - .join(date_dim, left_on="cs_sold_date_sk", right_on="d_date_sk") - .join(item, left_on="cs_item_sk", right_on="i_item_sk") + .select(["cs_sold_date_sk", "cs_item_sk", "cs_ext_sales_price"]) + .join(date_cols, left_on="cs_sold_date_sk", right_on="d_date_sk") + .join(item_cols, left_on="cs_item_sk", right_on="i_item_sk") .select( [ pl.lit("catalog").alias("channel"), @@ -166,10 +174,7 @@ def polars_impl(run_config: RunConfig) -> QueryResult: .agg( [ pl.len().cast(pl.Int64).alias("sales_cnt"), - pl.when(pl.col("ext_sales_price").count() > 0) - .then(pl.col("ext_sales_price").sum()) - .otherwise(None) - .alias("sales_amt"), + pl.col("ext_sales_price").sum().alias("sales_amt"), ] ) .select( diff --git 
a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q88.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q88.py index a086bb609cf..304e9404117 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q88.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q88.py @@ -142,81 +142,76 @@ def polars_impl(run_config: RunConfig) -> QueryResult: ) time_dim = get_data(run_config.dataset_path, "time_dim", run_config.suffix) store = get_data(run_config.dataset_path, "store", run_config.suffix) - hd_filter = ( + + # Pre-filter each small table before joining against store_sales [58 partitions]. + filtered_hdemo = household_demographics.filter( ((pl.col("hd_dep_count") == hd1) & (pl.col("hd_vehicle_count") <= hd1 + 2)) | ((pl.col("hd_dep_count") == hd2) & (pl.col("hd_vehicle_count") <= hd2 + 2)) | ((pl.col("hd_dep_count") == hd3) & (pl.col("hd_vehicle_count") <= hd3 + 2)) + ).select("hd_demo_sk") + filtered_store = store.filter(pl.col("s_store_name") == s_store_name).select( + "s_store_sk" ) - base_query = ( - store_sales.join( - time_dim, left_on="ss_sold_time_sk", right_on="t_time_sk", how="inner" - ) - .join( - household_demographics, - left_on="ss_hdemo_sk", - right_on="hd_demo_sk", - how="inner", - ) - .join(store, left_on="ss_store_sk", right_on="s_store_sk", how="inner") - .filter( - hd_filter - & ( - pl.col("s_store_name").is_not_null() - & (pl.col("s_store_name") == s_store_name) - ) + # Restrict time_dim to the union of all 8 slot conditions; every surviving row maps + # to exactly one bucket, so the downstream pl.when chain is exhaustive. 
+ filtered_time = time_dim.filter( + ((pl.col("t_hour") == 8) & (pl.col("t_minute") >= 30)) + | pl.col("t_hour").is_in([9, 10, 11]) + | ((pl.col("t_hour") == 12) & (pl.col("t_minute") < 30)) + ).select(["t_time_sk", "t_hour", "t_minute"]) + + bucket_names = [ + "h8_30_to_9", + "h9_to_9_30", + "h9_30_to_10", + "h10_to_10_30", + "h10_30_to_11", + "h11_to_11_30", + "h11_30_to_12", + "h12_to_12_30", + ] + + # Collapse the 58-partition store_sales pipeline to an 8-row bucket-count table first. + # The 8 conditional sums in the final select then operate on [1] partition, so even if + # the streaming executor creates separate sub-plans for each sum, each reads only the + # tiny CACHE'd group_by output rather than re-scanning store_sales. + counts_lf = ( + store_sales.select(["ss_sold_time_sk", "ss_hdemo_sk", "ss_store_sk"]) + .join(filtered_time, left_on="ss_sold_time_sk", right_on="t_time_sk") + .join(filtered_hdemo, left_on="ss_hdemo_sk", right_on="hd_demo_sk", how="semi") + .join(filtered_store, left_on="ss_store_sk", right_on="s_store_sk", how="semi") + .select( + pl.when((pl.col("t_hour") == 8) & (pl.col("t_minute") >= 30)) + .then(pl.lit(0)) + .when((pl.col("t_hour") == 9) & (pl.col("t_minute") < 30)) + .then(pl.lit(1)) + .when((pl.col("t_hour") == 9) & (pl.col("t_minute") >= 30)) + .then(pl.lit(2)) + .when((pl.col("t_hour") == 10) & (pl.col("t_minute") < 30)) + .then(pl.lit(3)) + .when((pl.col("t_hour") == 10) & (pl.col("t_minute") >= 30)) + .then(pl.lit(4)) + .when((pl.col("t_hour") == 11) & (pl.col("t_minute") < 30)) + .then(pl.lit(5)) + .when((pl.col("t_hour") == 11) & (pl.col("t_minute") >= 30)) + .then(pl.lit(6)) + .when((pl.col("t_hour") == 12) & (pl.col("t_minute") < 30)) + .then(pl.lit(7)) + .alias("bucket") ) + .group_by("bucket") + .agg(pl.len().cast(pl.Int64).alias("cnt")) ) + return QueryResult( - frame=base_query.select( + frame=counts_lf.select( [ - pl.when((pl.col("t_hour") == 8) & (pl.col("t_minute") >= 30)) - .then(1) - .otherwise(0) - .sum() - 
.cast(pl.Int64) - .alias("h8_30_to_9"), - pl.when((pl.col("t_hour") == 9) & (pl.col("t_minute") < 30)) - .then(1) - .otherwise(0) - .sum() - .cast(pl.Int64) - .alias("h9_to_9_30"), - pl.when((pl.col("t_hour") == 9) & (pl.col("t_minute") >= 30)) - .then(1) - .otherwise(0) - .sum() - .cast(pl.Int64) - .alias("h9_30_to_10"), - pl.when((pl.col("t_hour") == 10) & (pl.col("t_minute") < 30)) - .then(1) - .otherwise(0) - .sum() - .cast(pl.Int64) - .alias("h10_to_10_30"), - pl.when((pl.col("t_hour") == 10) & (pl.col("t_minute") >= 30)) - .then(1) - .otherwise(0) - .sum() - .cast(pl.Int64) - .alias("h10_30_to_11"), - pl.when((pl.col("t_hour") == 11) & (pl.col("t_minute") < 30)) - .then(1) - .otherwise(0) - .sum() - .cast(pl.Int64) - .alias("h11_to_11_30"), - pl.when((pl.col("t_hour") == 11) & (pl.col("t_minute") >= 30)) - .then(1) - .otherwise(0) - .sum() - .cast(pl.Int64) - .alias("h11_30_to_12"), - pl.when((pl.col("t_hour") == 12) & (pl.col("t_minute") < 30)) - .then(1) - .otherwise(0) + pl.when(pl.col("bucket") == i) + .then(pl.col("cnt")) + .otherwise(pl.lit(0).cast(pl.Int64)) .sum() - .cast(pl.Int64) - .alias("h12_to_12_30"), + .alias(name) + for i, name in enumerate(bucket_names) ] ), sort_by=[], diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q9.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q9.py index d42218179ea..142c94c3477 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q9.py @@ -102,57 +102,56 @@ def polars_impl(run_config: RunConfig) -> QueryResult: aggcelse = params["aggcelse"] rc = params["rc"] - # Load required tables store_sales = get_data(run_config.dataset_path, "store_sales", run_config.suffix) reason = get_data(run_config.dataset_path, "reason", run_config.suffix) - # Define bucket configurations: (min_qty, max_qty, count_threshold) - buckets = [ - (1, 20, rc[0]), - (21, 40, 
rc[1]), - (41, 60, rc[2]), - (61, 80, rc[3]), - (81, 100, rc[4]), - ] - - bucket_expressions = [] - for i, (min_qty, max_qty, _) in enumerate(buckets, 1): - condition = pl.col("ss_quantity").is_between(min_qty, max_qty, closed="both") - bucket_expressions.extend( - [ - condition.sum().alias(f"count_{i}"), - pl.when(condition) - .then(pl.col(aggcthen)) - .otherwise(None) - .mean() - .alias(f"avg_then_{i}"), - pl.when(condition) - .then(pl.col(aggcelse)) - .otherwise(None) - .mean() - .alias(f"avg_else_{i}"), - ] + thresholds = pl.LazyFrame({"bucket": [1, 2, 3, 4, 5], "threshold": list(rc)}) + + # Single scan: the 5 ss_quantity ranges are non-overlapping, so a group_by + # computes all counts and averages in one pass over store_sales. + stats = ( + store_sales.with_columns( + pl.when(pl.col("ss_quantity").is_between(1, 20)) + .then(pl.lit(1)) + .when(pl.col("ss_quantity").is_between(21, 40)) + .then(pl.lit(2)) + .when(pl.col("ss_quantity").is_between(41, 60)) + .then(pl.lit(3)) + .when(pl.col("ss_quantity").is_between(61, 80)) + .then(pl.lit(4)) + .when(pl.col("ss_quantity").is_between(81, 100)) + .then(pl.lit(5)) + .alias("bucket") ) - - combined_stats = store_sales.select(bucket_expressions) - - # Select appropriate value per bucket based on count threshold - bucket_values = [] - for i, (_min_qty, _max_qty, threshold) in enumerate(buckets, 1): - bucket = ( - pl.when(pl.col(f"count_{i}") > threshold) - .then(pl.col(f"avg_then_{i}")) - .otherwise(pl.col(f"avg_else_{i}")) - .alias(f"bucket{i}") + .filter(pl.col("bucket").is_not_null()) + .group_by("bucket") + .agg( + pl.len().alias("count"), + pl.col(aggcthen).mean().alias("avg_then"), + pl.col(aggcelse).mean().alias("avg_else"), + ) + .join(thresholds, on="bucket") + .select( + pl.col("bucket"), + pl.when(pl.col("count") > pl.col("threshold")) + .then(pl.col("avg_then")) + .otherwise(pl.col("avg_else")) + .alias("value"), ) - bucket_values.append(bucket) + .sort("bucket") + ) + + # Pivot 5 rows → 1 row with 5 named 
columns (operates on 5 rows, trivially fast) + wide = stats.select( + pl.col("value").filter(pl.col("bucket") == i).first().alias(f"bucket{i}") + for i in range(1, 6) + ) - # Create result DataFrame with one row (using reason table as in SQL) return QueryResult( frame=( reason.filter(pl.col("r_reason_sk") == 1) - .join(combined_stats, how="cross") - .select(bucket_values) + .join(wide, how="cross") + .select([f"bucket{i}" for i in range(1, 6)]) .limit(1) ), sort_by=[], diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q98.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q98.py index a70bf71269b..24442425911 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q98.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/pdsds_queries/q98.py @@ -76,17 +76,34 @@ def polars_impl(run_config: RunConfig) -> QueryResult: end_date_py = start_date_py + timedelta(days=30) start_date = pl.date(start_date_py.year, start_date_py.month, start_date_py.day) end_date = pl.date(end_date_py.year, end_date_py.month, end_date_py.day) + + # Pre-filter item to matching categories before joining against store_sales [58 partitions]. + filtered_item = item.filter( + pl.col("i_category").is_in(params["categories"]) + ).select( + [ + "i_item_sk", + "i_item_id", + "i_item_desc", + "i_current_price", + "i_class", + "i_category", + ] + ) + # Pre-filter date_dim to the 30-day window; d_date not needed after filter — semi-join. 
+ filtered_dates = date_dim.filter( + pl.col("d_date").is_between(start_date, end_date, closed="both") + ).select("d_date_sk") + return QueryResult( frame=( - store_sales.join( - item, left_on="ss_item_sk", right_on="i_item_sk", how="inner" - ) + store_sales.select(["ss_sold_date_sk", "ss_item_sk", "ss_ext_sales_price"]) + .join(filtered_item, left_on="ss_item_sk", right_on="i_item_sk") .join( - date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk", how="inner" - ) - .filter( - pl.col("i_category").is_in(params["categories"]) - & pl.col("d_date").is_between(start_date, end_date, closed="both") + filtered_dates, + left_on="ss_sold_date_sk", + right_on="d_date_sk", + how="semi", ) .group_by( [ diff --git a/python/cudf_polars/cudf_polars/experimental/expressions.py b/python/cudf_polars/cudf_polars/experimental/expressions.py index f578e3b0e07..ef94da02dea 100644 --- a/python/cudf_polars/cudf_polars/experimental/expressions.py +++ b/python/cudf_polars/cudf_polars/experimental/expressions.py @@ -44,6 +44,7 @@ from cudf_polars.dsl.expressions.base import Col, ExecutionContext, NamedExpr from cudf_polars.dsl.expressions.binaryop import BinOp from cudf_polars.dsl.expressions.literal import Literal +from cudf_polars.dsl.expressions.rolling import GroupedWindow from cudf_polars.dsl.expressions.ternary import Ternary from cudf_polars.dsl.expressions.unary import Cast, Len, UnaryFunction from cudf_polars.dsl.ir import Distinct, Empty, HConcat, Select @@ -51,6 +52,7 @@ CachingVisitor, ) from cudf_polars.experimental.base import PartitionInfo +from cudf_polars.experimental.over import _decompose_grouped_window_node from cudf_polars.experimental.repartition import Repartition from cudf_polars.experimental.utils import _dynamic_planning_on @@ -462,6 +464,10 @@ def _decompose_expr_node( ) (expr,) = columns return expr, input_ir, partition_info + elif isinstance(expr, GroupedWindow) and _dynamic_planning_on(config_options): + return _decompose_grouped_window_node( + expr, 
input_ir, partition_info, config_options, names=names + ) else: # This is an un-supported expression - raise. raise NotImplementedError( diff --git a/python/cudf_polars/cudf_polars/experimental/over.py b/python/cudf_polars/cudf_polars/experimental/over.py new file mode 100644 index 00000000000..abc90d35810 --- /dev/null +++ b/python/cudf_polars/cudf_polars/experimental/over.py @@ -0,0 +1,337 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +"""Over IR node for streaming window expressions.""" + +from __future__ import annotations + +import itertools +from collections import defaultdict +from typing import TYPE_CHECKING, ClassVar, cast + +from cudf_polars.dsl.expr import Agg, Col, Len, NamedExpr +from cudf_polars.dsl.ir import IR, GroupBy, Select +from cudf_polars.dsl.utils.naming import names_to_indices, unique_names +from cudf_polars.experimental.groupby import combine, decompose + +if TYPE_CHECKING: + from collections.abc import Generator, MutableMapping + + from cudf_polars.containers import DataFrame + from cudf_polars.dsl.expr import GroupedWindow + from cudf_polars.dsl.expressions.base import Expr + from cudf_polars.dsl.ir import IRExecutionContext + from cudf_polars.experimental.base import PartitionInfo + from cudf_polars.typing import Schema + from cudf_polars.utils.config import ConfigOptions + + +# Aggregations whose partial results can be combined. +_DECOMPOSABLE_AGG_NAMES: frozenset[str] = frozenset( + ("sum", "count", "mean", "min", "max", "std", "var") +) + + +def _build_over_groupby_irs( + gw_nodes: tuple[GroupedWindow, ...], + child_ir: IR, +) -> tuple[GroupBy, GroupBy, Select]: + """ + Build piecewise, reduction, and selection GroupBy IRs. + + Parameters + ---------- + gw_nodes + Top-level GroupedWindow nodes sharing the same partition-by keys; + all must be scalar (Agg/Len only in named_aggs). 
+ child_ir + Input IR feeding the Over node; defines the schema seen by the + per-chunk piecewise GroupBy. + + Returns + ------- + piecewise_ir + GroupBy IR that computes partial aggregates per chunk. + reduction_ir + GroupBy IR that reduces partial aggregates to a single result. + agg_select_ir + Select IR applied on top of the reduction. Carries any post- + aggregation expressions (e.g. division for mean); for fully + pass-through aggregations it is a Select of plain ``Col`` refs + of the same shape as the reduction output. + """ + gw = gw_nodes[0] + by_exprs = cast("list[Col]", list(gw.children[: gw.by_count])) + key_named_exprs = [NamedExpr(e.name, e) for e in by_exprs] + key_schema = {e.name: child_ir.schema[e.name] for e in by_exprs} + + all_scalar_named: list[NamedExpr] = [] + seen: set[str] = set() + for gw_node in gw_nodes: + reductions, unary_ops = gw_node._split_named_expr() + assert not any(unary_ops.values()), "unary window ops not allowed here" + for ne in reductions: + if ne.name in seen: + continue + all_scalar_named.append(ne) + seen.add(ne.name) + + name_gen = unique_names(child_ir.schema.keys()) + decompositions = [ + decompose(ne.name, ne.value, names=name_gen) for ne in all_scalar_named + ] + selection_exprs, piecewise_exprs, reduction_exprs, need_preshuffle = combine( + *decompositions + ) + assert not need_preshuffle, ( + "Scalar AllGather path does not support aggregations requiring pre-shuffle" + ) + + pwise_schema = dict(key_schema) | { + ne.name: ne.value.dtype for ne in piecewise_exprs + } + piecewise_ir = GroupBy( + pwise_schema, + key_named_exprs, + piecewise_exprs, + False, # noqa: FBT003 + None, + child_ir, + ) + + reduction_key_exprs = [ + NamedExpr(ne.name, Col(pwise_schema[ne.name], ne.name)) + for ne in key_named_exprs + ] + reduction_schema = { + ne.name: ne.value.dtype + for ne in itertools.chain(reduction_key_exprs, reduction_exprs) + } + reduction_ir = GroupBy( + reduction_schema, + reduction_key_exprs, + reduction_exprs, + 
False, # noqa: FBT003 + None, + piecewise_ir, + ) + + select_key_exprs = [ + NamedExpr(ne.name, Col(reduction_schema[ne.name], ne.name)) + for ne in key_named_exprs + ] + select_schema = { + ne.name: ne.value.dtype + for ne in itertools.chain(select_key_exprs, selection_exprs) + } + agg_select_ir = Select( + select_schema, + [*select_key_exprs, *selection_exprs], + False, # noqa: FBT003 + reduction_ir, + ) + + return piecewise_ir, reduction_ir, agg_select_ir + + +class Over(IR): + """Window over() IR node for the streaming runtime.""" + + __slots__ = ("exprs", "is_scalar", "key_indices") + _non_child: ClassVar[tuple[str, ...]] = ( + "schema", + "key_indices", + "is_scalar", + "exprs", + ) + _n_non_child_args: ClassVar[int] = 1 + key_indices: tuple[int, ...] + is_scalar: bool + exprs: tuple[NamedExpr, ...] + + def __init__( + self, + schema: Schema, + key_indices: tuple[int, ...], + is_scalar: bool, # noqa: FBT001 + exprs: tuple[NamedExpr, ...], + input_ir: IR, + ): + assert len(key_indices) > 0, "Over node requires at least one partition-by key" + self.schema = schema + self.key_indices = key_indices + self.is_scalar = is_scalar + self.exprs = exprs + self._non_child_args = (exprs,) + self.children = (input_ir,) + + @classmethod + def do_evaluate( + cls, + exprs: tuple[NamedExpr, ...], + df: DataFrame, + *, + context: IRExecutionContext, + ) -> DataFrame: + """Evaluate window expressions against df.""" + # At evaluation time Over is just a Select with should_broadcast=True; + # the window-specific work lives in the GroupedWindow expressions. 
+ return Select.do_evaluate(exprs, True, df, context=context) # noqa: FBT003 + + +def _is_scalar_grouped_window(expr: GroupedWindow) -> bool: + """Return True if this GroupedWindow can use the scalar broadcast path.""" + reductions, unary_ops = expr._split_named_expr() + if any(unary_ops.values()): + return False + if not all(isinstance(c, Col) for c in expr.children[: expr.by_count]): + return False + return all( + isinstance(ne.value, Len) + or (isinstance(ne.value, Agg) and ne.value.name in _DECOMPOSABLE_AGG_NAMES) + for ne in reductions + ) + + +def _extract_over_shuffle_indices( + expr: GroupedWindow, child_schema: Schema +) -> tuple[int, ...] | None: + """ + Return partition-by column indices in ``child_schema``, or None. + + Returns None when any partition-by expression is not a plain column + reference (the multi-partition path only supports Col keys today). + """ + by_children = expr.children[: expr.by_count] + if not all(isinstance(c, Col) for c in by_children): + return None + return names_to_indices( + tuple(cast("Col", c).name for c in by_children), child_schema + ) + + +def _decompose_grouped_window_node( + expr: GroupedWindow, + input_ir: IR, + partition_info: MutableMapping[IR, PartitionInfo], + config_options: ConfigOptions, + *, + names: Generator[str, None, None], +) -> tuple[Expr, IR, MutableMapping[IR, PartitionInfo]]: + """ + Build an Over IR node wrapping a single GroupedWindow expression. + + Every GroupedWindow becomes its own Over here; co-keyed Overs are + fused together later by select fusion so the actor evaluates all + window expressions in one pass. + + Returns + ------- + Expr + A ``Col`` referencing the Over node's output column, suitable + for substitution into the enclosing expression. + IR + The new ``Over`` IR node. + MutableMapping[IR, PartitionInfo] + ``partition_info`` augmented with an entry for the new node. 
+ """ + indices = _extract_over_shuffle_indices(expr, input_ir.schema) + if indices is None: + # TODO: support non-Col partition-by keys on the multi-partition + # paths. Today the hash shuffle layer rejects expression keys, and + # the scalar-aggregation broadcast path builds its piecewise + # groupby from Col by-children directly. Supporting expression + # keys would require lowering them to columns in the input first. + raise NotImplementedError( + "GroupedWindow with non-Col partition-by keys " + "is not supported for multiple partitions." + ) + is_scalar = _is_scalar_grouped_window(expr) + col_name = next(names) + over_node = Over( + {col_name: expr.dtype}, + indices, + is_scalar, + (NamedExpr(col_name, expr),), + input_ir, + ) + partition_info[over_node] = partition_info[input_ir] + return Col(expr.dtype, col_name), over_node, partition_info + + +def _fuse_over_nodes( + selections: list[Select], + partition_info: MutableMapping[IR, PartitionInfo], +) -> tuple[list[Select], MutableMapping[IR, PartitionInfo]]: + """ + Fuse per-expression Over nodes that share the same grouping key. + + Selects sharing the Over's input IR are absorbed into the merged Over + so the actor produces the full output schema in one shuffle pass. The + grouping key is ``(key_indices, is_scalar, input_ir)``. + + Returns + ------- + list[Select] + The rewritten selections: one merged ``Select`` per Over group, + followed by any selections that were neither part of an Over + group nor absorbed into one. + MutableMapping[IR, PartitionInfo] + ``partition_info`` augmented with entries for the merged Over + nodes and merged Select nodes introduced by the rewrite. 
+ """ + over_groups: defaultdict[ + tuple[tuple[int, ...], bool, IR], list[tuple[Select, Over]] + ] = defaultdict(list) + passthrough: list[Select] = [] + + for sel in selections: + child = sel.children[0] + if isinstance(child, Over): + input_ir = child.children[0] + over_groups[(child.key_indices, child.is_scalar, input_ir)].append( + (sel, child) + ) + else: + passthrough.append(sel) + + if not over_groups: + return selections, partition_info + + result: list[Select] = [] + for (key_indices, is_scalar, input_ir), group in over_groups.items(): + pi = partition_info[group[0][1]] + + absorbed: list[Select] = [] + remaining: list[Select] = [] + for s in passthrough: + (absorbed if s.children[0] == input_ir else remaining).append(s) + passthrough = remaining + + over_exprs = tuple( + itertools.chain( + *(s.exprs for s in absorbed), + *(over.exprs for _, over in group), + ) + ) + merged_over = Over( + {ne.name: ne.value.dtype for ne in over_exprs}, + key_indices, + is_scalar, + over_exprs, + input_ir, + ) + partition_info[merged_over] = pi + this_group = {*absorbed, *(sel for sel, _ in group)} + outer_exprs = tuple( + itertools.chain.from_iterable( + s.exprs for s in selections if s in this_group + ) + ) + outer_schema = {ne.name: ne.value.dtype for ne in outer_exprs} + + merged_sel = Select(outer_schema, outer_exprs, True, merged_over) # noqa: FBT003 + partition_info[merged_sel] = pi + result.append(merged_sel) + + result.extend(passthrough) + return result, partition_info diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/__init__.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/__init__.py index 2ea9af50c1b..7eedbf12bd4 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/__init__.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/__init__.py @@ -13,6 +13,7 @@ import cudf_polars.experimental.rapidsmpf.groupby import cudf_polars.experimental.rapidsmpf.io import cudf_polars.experimental.rapidsmpf.join +import 
cudf_polars.experimental.rapidsmpf.over import cudf_polars.experimental.rapidsmpf.repartition import cudf_polars.experimental.rapidsmpf.union # noqa: F401 diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/common.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/common.py index 0cb26c3689f..fa5cb995ac5 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/common.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/common.py @@ -14,6 +14,7 @@ from cudf_polars.dsl.traversal import traversal from cudf_polars.experimental.io import StreamingSink from cudf_polars.experimental.join import Join +from cudf_polars.experimental.over import Over from cudf_polars.experimental.repartition import Repartition from cudf_polars.experimental.shuffle import Shuffle @@ -101,6 +102,7 @@ def __init__( Sort, GroupBy, Distinct, + Over, ) self.collective_nodes: list[IR] = [ @@ -150,6 +152,16 @@ def __enter__(self) -> dict[IR, list[int]]: _get_new_collective_id(), _get_new_collective_id(), ] + elif isinstance(node, Over) and not node.is_scalar: + # Non-scalar Over needs 2 IDs: one for the size AllGather + + # forward shuffle (the AllGather completes before the forward + # shuffle starts, so they can share), and a separate ID for + # the return shuffle (which overlaps with the forward shuffle + # during extract+insert). 
+ self.collective_id_map[node] = [ + _get_new_collective_id(), + _get_new_collective_id(), + ] else: self.collective_id_map[node] = [_get_new_collective_id()] diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py index 20faf0a2b03..04cac3c0454 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py @@ -4,7 +4,6 @@ from __future__ import annotations -from collections import deque from typing import TYPE_CHECKING from rapidsmpf.shuffler import PartitionAssignment @@ -20,7 +19,7 @@ from cudf_polars.containers import DataFrame, DataType from cudf_polars.dsl.expr import Col, NamedExpr from cudf_polars.dsl.ir import Empty, Sort -from cudf_polars.dsl.utils.naming import unique_names +from cudf_polars.dsl.utils.naming import names_to_indices, unique_names from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager from cudf_polars.experimental.rapidsmpf.collectives.shuffle import ShuffleManager from cudf_polars.experimental.rapidsmpf.dispatch import generate_ir_sub_network @@ -30,6 +29,7 @@ ) from cudf_polars.experimental.rapidsmpf.utils import ( ChannelManager, + ChunkStore, allgather_reduce, chunk_to_frame, concat_batch, @@ -37,7 +37,6 @@ evaluate_batch, evaluate_chunk, gather_in_task_group, - names_to_indices, process_children, recv_metadata, replay_buffered_channel, @@ -53,8 +52,6 @@ from cudf_polars.utils.cuda_stream import get_joined_cuda_stream if TYPE_CHECKING: - from collections.abc import Generator - from rapidsmpf.communicator.communicator import Communicator from rapidsmpf.streaming.core.channel import Channel from rapidsmpf.streaming.core.context import Context @@ -66,23 +63,6 @@ from cudf_polars.utils.config import StreamingExecutor -class ChunkStore: - """Ordered spillable buffer for TableChunk messages.""" - - def 
__init__(self, ctx: Context) -> None: - self._mids: deque[int] = deque() - self._store = ctx.spillable_messages() - - def insert(self, msg: Message) -> None: - """Insert a message into the store.""" - self._mids.append(self._store.insert(msg)) - - def __iter__(self) -> Generator[Message, None, None]: - """Yield messages in insertion order, draining the store.""" - while self._mids: - yield self._store.extract(mid=self._mids.popleft()) - - async def _simple_top_or_bottom_k( context: Context, comm: Communicator, diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/core.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/core.py index 539ff10e7b6..8ec9c9faec5 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/core.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/core.py @@ -19,6 +19,7 @@ Union, ) from cudf_polars.dsl.traversal import CachingVisitor, traversal +from cudf_polars.experimental.over import Over from cudf_polars.experimental.rapidsmpf.dispatch import FanoutInfo from cudf_polars.experimental.rapidsmpf.nodes import ( generate_ir_sub_network_wrapper, @@ -169,12 +170,11 @@ def _mark_children_unbounded(node: IR) -> None: for node in traversal([ir]): if node in unbounded: _mark_children_unbounded(node) - elif isinstance(node, Union): - # Union processes children sequentially, so all children - # with multiple consumers need unbounded fanout - _mark_children_unbounded(node) - elif isinstance(node, Join): - # This may be a broadcast join + elif isinstance(node, (Union, Join, Over)): + # Union processes children sequentially; Join may broadcast one + # side; Over buffers (or samples-then-replays) its input before + # producing output. In every case the input source needs + # unbounded fanout so other consumers don't block it. _mark_children_unbounded(node) elif len(node.children) > 1: # Check if this node is doing any broadcasting. 
diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/groupby.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/groupby.py index 7eaaf3ecacd..fe1a2f7fb61 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/groupby.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/groupby.py @@ -15,7 +15,6 @@ from rapidsmpf.streaming.cudf.channel_metadata import ( ChannelMetadata, HashScheme, - Partitioning, ) from rapidsmpf.streaming.cudf.table_chunk import TableChunk @@ -26,7 +25,6 @@ from cudf_polars.dsl.ir import IR, Distinct, GroupBy, Select from cudf_polars.dsl.utils.naming import unique_names from cudf_polars.experimental.groupby import combine, decompose -from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager from cudf_polars.experimental.rapidsmpf.collectives.shuffle import ShuffleManager from cudf_polars.experimental.rapidsmpf.dispatch import ( generate_ir_sub_network, @@ -34,6 +32,8 @@ from cudf_polars.experimental.rapidsmpf.utils import ( ChannelManager, NormalizedPartitioning, + _make_hash_shuffle_metadata, + allgather_and_reduce, allgather_reduce, chunkwise_evaluate, empty_table_chunk, @@ -308,26 +308,13 @@ async def _tree_reduce( await send_metadata(ch_out, context, metadata_out) if need_allgather: - allgather = AllGatherManager(context, comm, collective_id) - with allgather.inserting() as inserter: - inserter.insert( - 0, - _enforce_schema( - aggregated, decomposed.reduction_ir.schema, context.br() - ), - ) - - stream = ir_context.get_cuda_stream() - aggregated = await evaluate_chunk( + aggregated = await allgather_and_reduce( context, - TableChunk.from_pylibcudf_table( - await allgather.extract_concatenated(stream), - stream, - exclusive_view=True, - br=context.br(), - ), + comm, + collective_id, + _enforce_schema(aggregated, decomposed.reduction_ir.schema, context.br()), decomposed.reduction_ir, - ir_context=ir_context, + ir_context, ) if decomposed.select_ir is not None: @@ -401,31 
+388,9 @@ async def _shuffle_reduce( options = Options(get_environment_variables()) shuffle_comm = single_comm(options, comm.progress_thread) shuffle_context = Context(shuffle_comm.logger, context.br(), options) - shuf_nranks = shuffle_comm.nranks - shuf_rank = shuffle_comm.rank - modulus = max(shuf_nranks, modulus) - - if shuf_nranks == 1: - inter_rank_scheme = ( - None - if metadata_in.partitioning is None - else metadata_in.partitioning.inter_rank - ) - local_scheme = HashScheme( - column_indices=decomposed.output_indices, modulus=modulus - ) - local_output_count = modulus - else: - inter_rank_scheme = HashScheme( - column_indices=decomposed.output_indices, modulus=modulus - ) - local_scheme = "inherit" - local_output_count = (modulus - shuf_rank + shuf_nranks - 1) // shuf_nranks - - metadata_out = ChannelMetadata( - local_count=local_output_count, - partitioning=Partitioning(inter_rank_scheme, local_scheme), - duplicated=metadata_in.duplicated, + modulus = max(shuffle_comm.nranks, modulus) + metadata_out = _make_hash_shuffle_metadata( + shuffle_comm, decomposed.output_indices, modulus, metadata_in ) await send_metadata(ch_out, context, metadata_out) diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py index 7798f1faa5d..7d2282ba032 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py @@ -5,7 +5,7 @@ from __future__ import annotations import asyncio -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Literal from rapidsmpf.memory.memory_reservation import opaque_memory_usage @@ -30,6 +30,7 @@ from cudf_polars.containers import DataFrame from cudf_polars.dsl.ir import IR, Join +from cudf_polars.dsl.utils.naming import names_to_indices from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager from 
cudf_polars.experimental.rapidsmpf.collectives.shuffle import _global_shuffle from cudf_polars.experimental.rapidsmpf.dispatch import ( @@ -39,13 +40,14 @@ from cudf_polars.experimental.rapidsmpf.utils import ( ChannelManager, NormalizedPartitioning, + TableSizeStats, _is_already_partitioned, + _sample_chunks, allgather_reduce, chunk_to_frame, empty_table_chunk, gather_in_task_group, maybe_remap_partitioning, - names_to_indices, process_children, recv_metadata, replay_buffered_channel, @@ -76,20 +78,6 @@ MAX_BROADCAST_ROWS = CUDF_ROW_LIMIT // 2 -@dataclass(frozen=True) -class JoinSideStats: - """Sampled chunks and aggregate size/row stats for one side of a join.""" - - chunks: dict[int, TableChunk] = field(default_factory=dict) - """The sampled chunks, keyed by sequence number.""" - total_size: int = 0 - """The total estimated size of the child table.""" - total_rows: int = 0 - """The total estimated number of rows in the child table.""" - total_chunks: int = 0 - """The total estimated number of chunks in the child table.""" - - @dataclass(frozen=True) class JoinStrategy: """Summary of sampling and strategy selection for a dynamic join.""" @@ -855,10 +843,10 @@ def _num_indices(partitioning: NormalizedPartitioning) -> int: async def _aggregate_estimates( context: Context, comm: Communicator, - left_sample: JoinSideStats, - right_sample: JoinSideStats, + left_sample: TableSizeStats, + right_sample: TableSizeStats, collective_ids: list[int], -) -> tuple[JoinSideStats, JoinSideStats]: +) -> tuple[TableSizeStats, TableSizeStats]: """Aggregate table-size and row estimates across ranks.""" # AllGather size, row, and chunk count estimates across ranks ( @@ -880,13 +868,13 @@ async def _aggregate_estimates( right_sample.total_chunks, ) - new_left_sample = JoinSideStats( + new_left_sample = TableSizeStats( chunks=left_sample.chunks, total_size=left_total, total_rows=left_total_rows, total_chunks=left_total_chunks, ) - new_right_sample = JoinSideStats( + new_right_sample = 
TableSizeStats( chunks=right_sample.chunks, total_size=right_total, total_rows=right_total_rows, @@ -904,8 +892,8 @@ async def _choose_strategy_from_samples( right_partitioning: NormalizedPartitioning, executor: StreamingExecutor, *, - left_sample: JoinSideStats, - right_sample: JoinSideStats, + left_sample: TableSizeStats, + right_sample: TableSizeStats, chunkwise: bool, tracer: ActorTracer | None, ) -> JoinStrategy: @@ -1043,59 +1031,6 @@ def _modulus(partitioning: NormalizedPartitioning) -> int | None: return max(large, min_shuffle_modulus) -async def _sample_chunks( - context: Context, - ch: Channel[TableChunk], - max_sample_chunks: int, - max_sample_bytes: int, - local_count: int, -) -> JoinSideStats: - """ - Sample chunks from a channel. - - Parameters - ---------- - context - The context. - ch - The channel to sample from. - max_sample_chunks - The maximum number of chunks to sample. - max_sample_bytes - The maximum number of bytes to sample. - local_count - The number of local chunks. - - Returns - ------- - The sampled chunks. 
- """ - sampled_chunks: dict[int, TableChunk] = {} - total_size = 0 - total_rows = 0 - for _ in range(max_sample_chunks): - msg = await ch.recv(context) - if msg is None: - break - chunk = TableChunk.from_message(msg, br=context.br()).make_available_and_spill( - context.br(), allow_overbooking=True - ) - sampled_chunks[msg.sequence_number] = chunk - total_size += chunk.data_alloc_size() - total_rows += chunk.shape[0] - if total_size >= max_sample_bytes: - break - if sampled_chunks: - total_size = int((total_size / len(sampled_chunks)) * local_count) - total_rows = int((total_rows / len(sampled_chunks)) * local_count) - return JoinSideStats( - chunks=sampled_chunks, - total_size=total_size, - total_rows=total_rows, - total_chunks=local_count, - ) - - async def _choose_strategy( context: Context, comm: Communicator, @@ -1108,7 +1043,7 @@ async def _choose_strategy( collective_ids: list[int], *, tracer: ActorTracer | None, -) -> tuple[JoinSideStats, JoinSideStats, JoinStrategy]: +) -> tuple[TableSizeStats, TableSizeStats, JoinStrategy]: """Sample both sides, aggregate estimates, and choose broadcast vs shuffle.""" nranks = comm.nranks left_partitioning = NormalizedPartitioning.from_keys( @@ -1125,8 +1060,8 @@ async def _choose_strategy( if left_partitioning.is_aligned_with(right_partitioning, context.br()): # We can use a chunkwise join chunkwise = True - left_sample = JoinSideStats(total_chunks=left_metadata.local_count) - right_sample = JoinSideStats(total_chunks=right_metadata.local_count) + left_sample = TableSizeStats(total_chunks=left_metadata.local_count) + right_sample = TableSizeStats(total_chunks=right_metadata.local_count) else: # Need to shuffle or broadcast - Use sampled data to choose a strategy chunkwise = False diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py new file mode 100644 index 00000000000..3df1a6c04c3 --- /dev/null +++ 
b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/over.py @@ -0,0 +1,820 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. +# SPDX-License-Identifier: Apache-2.0 +""" +Window ``over()`` actor for the RapidsMPF streaming runtime. + +Implements the ``group_to_rows`` ``WindowMapping`` only: each input row +receives the value computed for its group. Other mappings (``explode``, +``join``) are not supported. + +The actor picks one of three strategies at runtime based on the incoming +channel metadata and the shape of the windowed expressions. + +Chunkwise (already partitioned) + If the channel is already hash-partitioned on the over-keys (or any + prefix of them), every group is fully contained within one rank's + chunks. The window expression is correct on each chunk in isolation + and no cross-rank coordination is needed. + +Scalar broadcast (decomposable aggregations) + When every aggregation is decomposable, partial aggregates can be + combined associatively across ranks. Each rank computes per-chunk + partials, an AllGather collects them, a single reduction yields the + global aggregate per group, and each input chunk has those results + joined back onto its rows by the partition keys. Order is preserved + naturally: input chunks are buffered in receive order and emitted in + the same order after the global aggregate is known. + +Forward + return shuffle (non-decomposable aggregations) + For functions that need every row in a group visible at once, a hash + shuffle on the partition keys co-locates each group on one rank for + evaluation. After evaluation, a second shuffle routes each row back + to the rank that originally received it (output channels are + rank-local, so only the originating rank can emit), and the rows are + reassembled in input order using stamps that travel with the data + through both shuffles. 
+""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, ClassVar, cast + +from rapidsmpf.memory.memory_reservation import opaque_memory_usage +from rapidsmpf.shuffler import PartitionAssignment +from rapidsmpf.streaming.core.actor import define_actor +from rapidsmpf.streaming.core.message import Message +from rapidsmpf.streaming.cudf.channel_metadata import ChannelMetadata +from rapidsmpf.streaming.cudf.table_chunk import ( + TableChunk, + make_table_chunks_available_or_wait, +) + +import polars as pl + +import pylibcudf as plc + +from cudf_polars.containers import Column, DataFrame, DataType +from cudf_polars.dsl.expr import GroupedWindow +from cudf_polars.dsl.expressions.base import ExecutionContext +from cudf_polars.dsl.utils.naming import unique_names +from cudf_polars.dsl.utils.reshape import broadcast +from cudf_polars.experimental.over import Over, _build_over_groupby_irs +from cudf_polars.experimental.rapidsmpf.collectives.shuffle import ( + LocalRepartitioner, + ShuffleManager, +) +from cudf_polars.experimental.rapidsmpf.dispatch import generate_ir_sub_network +from cudf_polars.experimental.rapidsmpf.utils import ( + ChannelManager, + ChunkStore, + NormalizedPartitioning, + _evaluate_chunk_sync, + _sample_chunks, + allgather_and_reduce, + allgather_reduce, + chunk_to_frame, + chunkwise_evaluate, + empty_table_chunk, + evaluate_batch, + evaluate_chunk, + gather_in_task_group, + maybe_remap_partitioning, + process_children, + recv_metadata, + replay_buffered_channel, + send_metadata, + shutdown_on_error, +) + +if TYPE_CHECKING: + from rapidsmpf.communicator.communicator import Communicator + from rapidsmpf.memory.buffer_resource import BufferResource + from rapidsmpf.streaming.core.channel import Channel + from rapidsmpf.streaming.core.context import Context + + from rmm.pylibrmm.stream import Stream + + from cudf_polars.dsl.expr import Col + from cudf_polars.dsl.ir import IR, 
GroupBy, IRExecutionContext, Select + from cudf_polars.experimental.rapidsmpf.dispatch import SubNetGenerator + from cudf_polars.experimental.rapidsmpf.utils import TableSizeStats + + +@dataclass(frozen=True) +class _ScalarOverPlan: + """Pre-computed IR rewrites for the scalar Over path.""" + + gw_nodes: tuple[GroupedWindow, ...] + key_names: tuple[str, ...] + piecewise_ir: GroupBy + reduction_ir: GroupBy + agg_select_ir: Select + + +def _build_scalar_over_plan(ir: Over) -> _ScalarOverPlan: + """Pre-compute the IR rewrites needed by the scalar Over path.""" + gw_nodes = tuple(ne.value for ne in ir.exprs if isinstance(ne.value, GroupedWindow)) + # Lowering rejects non-Col partition-by keys, so every by-child here is a Col. + by_children = gw_nodes[0].children[: gw_nodes[0].by_count] + key_names = tuple(cast("Col", c).name for c in by_children) + piecewise_ir, reduction_ir, agg_select_ir = _build_over_groupby_irs( + gw_nodes, ir.children[0] + ) + return _ScalarOverPlan( + gw_nodes=gw_nodes, + key_names=key_names, + piecewise_ir=piecewise_ir, + reduction_ir=reduction_ir, + agg_select_ir=agg_select_ir, + ) + + +def _broadcast_gw_sync( + gw: GroupedWindow, + chunk_df: DataFrame, + global_agg_df: DataFrame, + key_names: tuple[str, ...], + stream: Stream, +) -> Any: + """Broadcast the global aggregate for one GroupedWindow back to row positions.""" + by_exprs = gw.children[: gw.by_count] + by_cols = broadcast( + *(b.evaluate(chunk_df) for b in by_exprs), + target_length=chunk_df.num_rows, + stream=stream, + ) + by_tbl = plc.Table([c.obj for c in by_cols]) + group_keys_tbl = global_agg_df.select(key_names).table + + out_names, out_dtypes = zip( + *((ne.name, ne.value.dtype) for ne in gw.named_aggs), strict=True + ) + value_tbls = [ + plc.Table([global_agg_df.column_map[ne.name].obj]) for ne in gw.named_aggs + ] + + broadcasted_cols = gw._broadcast_agg_results( + by_tbl, group_keys_tbl, value_tbls, out_names, out_dtypes, stream + ) + temp_df = DataFrame(broadcasted_cols, 
stream=stream) + return gw.post.value.evaluate(temp_df, context=ExecutionContext.FRAME) + + +def _evaluate_ir_broadcast_sync( + chunk: TableChunk, + ir: Over, + global_agg_df: DataFrame, + key_names: tuple[str, ...], + gw_nodes: tuple[GroupedWindow, ...], + ir_context: IRExecutionContext, + br: BufferResource, +) -> TableChunk: + """Map the per-group aggregate onto a chunk's rows to produce its Over output.""" + chunk_df = chunk_to_frame(chunk, ir.children[0]) + # global_agg_df and chunk_df may live on different streams (the former from + # the upstream allgather/reduction on ir_context's stream, the latter from + # the input message). Join them so the broadcast kernels read global_agg_df + # safely. + with ir_context.stream_ordered_after(chunk_df, global_agg_df) as stream: + chunk_df = DataFrame(chunk_df.columns, stream=stream) + global_agg_df = DataFrame(global_agg_df.columns, stream=stream) + + gw_results = { + gw: _broadcast_gw_sync(gw, chunk_df, global_agg_df, key_names, stream) + for gw in gw_nodes + } + + result_cols = [] + for ne in ir.exprs: + if isinstance(ne.value, GroupedWindow): + col = gw_results[ne.value].rename(ne.name) + else: + col = ne.evaluate(chunk_df, context=ExecutionContext.FRAME) + result_cols.append(col) + + return TableChunk.from_pylibcudf_table( + plc.Table([c.obj for c in result_cols]), + stream, + exclusive_view=True, + br=br, + ) + + +async def _evaluate_broadcast_chunk( + context: Context, + chunk: TableChunk, + ir: Over, + global_agg_df: DataFrame, + key_names: tuple[str, ...], + gw_nodes: tuple[GroupedWindow, ...], + ir_context: IRExecutionContext, +) -> TableChunk: + """Unspill the chunk and map the per-group aggregate onto its rows.""" + chunk, extra = await make_table_chunks_available_or_wait( + context, + chunk, + reserve_extra=chunk.data_alloc_size(), + net_memory_delta=0, + ) + with opaque_memory_usage(extra): + return await asyncio.to_thread( + _evaluate_ir_broadcast_sync, + chunk, + ir, + global_agg_df, + key_names, + 
gw_nodes, + ir_context, + context.br(), + ) + + +@dataclass(frozen=True) +class OriginStamps: + """ + Stamp column names that ride both shuffles for output reassembly. + + Parameters + ---------- + chunk_index + Column name for a dense rank-local 0..N-1 counter identifying which + input chunk a row came from. + position + Column name for the row's position within its input chunk. + rank + Column name for the originating rank. + """ + + chunk_index: str + position: str + rank: str + + dtype: ClassVar[DataType] = DataType(pl.Int32()) + + @property + def names(self) -> tuple[str, str, str]: + """Stamp column names, in the order they are appended to the table.""" + return (self.chunk_index, self.position, self.rank) + + +def _origin_stamps_for(ir: Over) -> OriginStamps: + """Pick three stamp column names that do not collide with the schema.""" + names = unique_names((*ir.children[0].schema.keys(), *ir.schema.keys())) + return OriginStamps(next(names), next(names), next(names)) + + +def _append_origin_stamps( + chunk: TableChunk, + chunk_index: int, + origin_rank: int, + stream: Stream, + br: Any, +) -> TableChunk: + """Append (chunk_index, position, rank) stamp columns to *chunk*.""" + table = chunk.table_view() + n_rows = table.num_rows() + int32 = plc.types.DataType(plc.TypeId.INT32) + chunk_index_col = plc.Column.from_scalar( + plc.Scalar.from_py(chunk_index, int32, stream=stream), n_rows, stream=stream + ) + rank_col = plc.Column.from_scalar( + plc.Scalar.from_py(origin_rank, int32, stream=stream), n_rows, stream=stream + ) + position_col = plc.filling.sequence( + n_rows, + plc.Scalar.from_py(0, int32, stream=stream), + plc.Scalar.from_py(1, int32, stream=stream), + stream=stream, + ) + return TableChunk.from_pylibcudf_table( + plc.Table([*table.columns(), chunk_index_col, position_col, rank_col]), + stream, + exclusive_view=False, + br=br, + ) + + +def _evaluate_window_with_stamps( + chunk: TableChunk, + ir: Over, + ir_context: IRExecutionContext, + stamps: 
OriginStamps, +) -> DataFrame: + """Evaluate *ir* on the un-stamped portion of *chunk*; reattach stamps after.""" + child_schema = ir.children[0].schema + stream = ir_context.get_cuda_stream() + columns = chunk.table_view().columns() + n_child = len(child_schema) + + input_df = DataFrame.from_table( + plc.Table(columns[:n_child]), + list(child_schema.keys()), + list(child_schema.values()), + stream, + ) + result = ir.do_evaluate(ir.exprs, input_df, context=ir_context) + stamp_cols = [ + Column(col, dtype=stamps.dtype, name=name) + for col, name in zip(columns[n_child:], stamps.names, strict=True) + ] + return result.with_columns(stamp_cols, stream=stream) + + +def _partition_by_origin_rank( + result: DataFrame, + num_ranks: int, + br: Any, +) -> tuple[TableChunk | None, list[int]]: + """ + Rearrange rows so partition i contains rows whose origin rank is i. + + Returns a chunk with the rank stamp dropped and the per-rank split + indices for direct insertion into the return shuffle. + """ + if result.table.num_rows() == 0: + return None, [] + + stream = result.stream + columns = result.table.columns() + rank_column = columns[-1] + payload = plc.Table(columns[:-1]) + + rearranged, offsets = plc.partitioning.partition( + payload, rank_column, num_ranks, stream=stream + ) + return ( + TableChunk.from_pylibcudf_table(rearranged, stream, exclusive_view=True, br=br), + list(offsets[1:-1]), + ) + + +async def _allgather_and_broadcast( + context: Context, + comm: Communicator, + ir: Over, + ir_context: IRExecutionContext, + ch_in: Channel[TableChunk], + ch_out: Channel[TableChunk], + metadata_in: ChannelMetadata, + tracer: Any, + collective_id: int, + plan: _ScalarOverPlan, +) -> None: + """Compute partial aggregates per chunk, AllGather globally, then broadcast to each chunk.""" + piecewise_ir = plan.piecewise_ir + reduction_ir = plan.reduction_ir + agg_select_ir = plan.agg_select_ir + + buffer = ChunkStore(context) + partial_aggs: list[TableChunk] = [] + + while (msg := 
await ch_in.recv(context)) is not None: + chunk = TableChunk.from_message(msg, br=context.br()) + chunk, extra = await make_table_chunks_available_or_wait( + context, + chunk, + reserve_extra=chunk.data_alloc_size(), + net_memory_delta=0, + ) + with opaque_memory_usage(extra): + partial = await asyncio.to_thread( + _evaluate_chunk_sync, + chunk, + piecewise_ir, + ir_context, + context.br(), + ) + partial_aggs.append(partial) + buffer.insert(Message(msg.sequence_number, chunk)) + + if partial_aggs: + local_agg = await evaluate_batch( + partial_aggs, context, reduction_ir, ir_context=ir_context + ) + else: + local_agg = empty_table_chunk( + reduction_ir, context, ir_context.get_cuda_stream() + ) + + # AllGather the locally-reduced partials (pre post-aggregation) so a + # single global reduction combines them; the post-aggregation step + # runs once after. + if comm.nranks > 1 and not metadata_in.duplicated: + global_agg = await allgather_and_reduce( + context, comm, collective_id, local_agg, reduction_ir, ir_context + ) + else: + global_agg = local_agg + + global_agg = await evaluate_chunk( + context, global_agg, agg_select_ir, ir_context=ir_context + ) + global_agg_df = chunk_to_frame(global_agg, agg_select_ir) + + metadata_out = ChannelMetadata( + local_count=metadata_in.local_count, + partitioning=maybe_remap_partitioning(ir, metadata_in.partitioning), + duplicated=metadata_in.duplicated, + ) + await send_metadata(ch_out, context, metadata_out) + + for msg in buffer: + result = await _evaluate_broadcast_chunk( + context, + TableChunk.from_message(msg, br=context.br()), + ir, + global_agg_df, + plan.key_names, + plan.gw_nodes, + ir_context, + ) + if tracer is not None: + tracer.add_chunk(table=result.table_view()) + await ch_out.send(context, Message(msg.sequence_number, result)) + + await ch_out.drain(context) + + +async def _choose_modulus( + context: Context, + comm: Communicator, + ch_in: Channel[TableChunk], + metadata_in: ChannelMetadata, + collective_id: 
int, + target_partition_size: int, + sample_chunk_count: int, +) -> tuple[TableSizeStats, int]: + """ + Sample input, AllGather size estimates, and derive the forward-shuffle modulus. + + Returns the sample (whose chunks must be replayed back to the consumer) + and the chosen number of forward-shuffle partitions. + """ + sample = await _sample_chunks( + context, + ch_in, + sample_chunk_count, + target_partition_size, + metadata_in.local_count, + ) + if comm.nranks > 1 and not metadata_in.duplicated: + total_bytes, total_count = await allgather_reduce( + context, comm, collective_id, sample.total_size, sample.total_chunks + ) + else: + total_bytes, total_count = sample.total_size, sample.total_chunks + modulus = min( + max(comm.nranks, total_bytes // max(1, target_partition_size)), + max(1, total_count), + ) + return sample, modulus + + +async def _distribute_by_group( + context: Context, + comm: Communicator, + forward_shuffle: ShuffleManager, + ch_in: Channel[TableChunk], + key_indices: tuple[int, ...], + ir_context: IRExecutionContext, + skip_insert: bool, # noqa: FBT001 +) -> list[int]: + """Stream chunks from *ch_in* into the forward shuffle with origin stamps.""" + # We already have the upstream metadata; signal we don't need the replay + # channel's copy. + await ch_in.shutdown_metadata(context) + + sequence_numbers: list[int] = [] + chunk_index = 0 + async with forward_shuffle.inserting() as inserter: + while (msg := await ch_in.recv(context)) is not None: + chunk = TableChunk.from_message( + msg, br=context.br() + ).make_available_and_spill(context.br(), allow_overbooking=True) + sequence_numbers.append(msg.sequence_number) + if not skip_insert: + # TODO: For duplicated input only rank 0 inserts here, and + # every row is stamped with origin_rank=0, so the return + # shuffle routes all output back to rank 0 and ranks + # 1..nranks-1 sit idle on emit. Slice the duplicated input + # across ranks (e.g. 
stripe by row index) and stamp each + # slice with its target origin rank to distribute emit work. + stamped = await asyncio.to_thread( + _append_origin_stamps, + chunk, + chunk_index, + comm.rank, + ir_context.get_cuda_stream(), + context.br(), + ) + inserter.insert_hash(stamped, key_indices) + chunk_index += 1 + return sequence_numbers + + +async def _evaluate_and_route_to_origin( + context: Context, + ir: Over, + ir_context: IRExecutionContext, + forward_shuffle: ShuffleManager, + return_shuffle: ShuffleManager, + num_ranks: int, + stamps: OriginStamps, +) -> None: + """Window-evaluate each local forward partition, then ship rows back to their origin.""" + async with return_shuffle.inserting() as inserter: + for partition_id in forward_shuffle.local_partitions(): + stream = ir_context.get_cuda_stream() + extracted = forward_shuffle.extract_chunk(partition_id, stream) + if extracted.num_rows() == 0: + continue + partition = TableChunk.from_pylibcudf_table( + extracted, stream, exclusive_view=True, br=context.br() + ) + evaluated = await asyncio.to_thread( + _evaluate_window_with_stamps, partition, ir, ir_context, stamps + ) + routed, splits = await asyncio.to_thread( + _partition_by_origin_rank, evaluated, num_ranks, context.br() + ) + if routed is not None: + inserter.insert_split(routed, splits) + + +async def _reassemble_input_chunks( + context: Context, + ch_out: Channel[TableChunk], + ir_context: IRExecutionContext, + return_shuffle: ShuffleManager, + sequence_numbers: list[int], + ir: Over, + tracer: Any, +) -> None: + """Emit one output chunk per input chunk, in original order.""" + n_chunks = len(sequence_numbers) + if n_chunks == 0: + return + + n_exprs = len(ir.exprs) + chunk_index_column = n_exprs + + # TODO: thread ir_context through repartition_by_index so each + # PackedData piece moves on its own pool stream rather than sharing one. 
+ local = LocalRepartitioner(return_shuffle, local_count=n_chunks) + await local.repartition_by_index( + partition_col=chunk_index_column, stream=ir_context.get_cuda_stream() + ) + + for chunk_index, sequence_number in zip( + local.local_partitions(), sequence_numbers, strict=True + ): + # Distinct stream per chunk so downstream work on different + # chunks can overlap on the GPU. + stream = ir_context.get_cuda_stream() + tbl = local.extract_chunk(chunk_index, stream) + if tbl.num_rows() == 0: + chunk = empty_table_chunk(ir, context, stream) + else: + sorted_tbl = plc.sorting.stable_sort_by_key( + tbl, + plc.Table([tbl.columns()[n_exprs]]), + [plc.types.Order.ASCENDING], + [plc.types.NullOrder.AFTER], + stream=stream, + ) + chunk = TableChunk.from_pylibcudf_table( + plc.Table(sorted_tbl.columns()[:n_exprs]), + stream, + exclusive_view=True, + br=context.br(), + ) + if tracer is not None: + tracer.add_chunk(table=chunk.table_view()) + await ch_out.send(context, Message(sequence_number, chunk)) + + +async def _shuffle_and_reassemble( + context: Context, + comm: Communicator, + ir: Over, + ir_context: IRExecutionContext, + ch_in: Channel[TableChunk], + ch_out: Channel[TableChunk], + metadata_in: ChannelMetadata, + tracer: Any, + size_collective_id: int, + forward_shuffle_collective_id: int, + return_shuffle_collective_id: int, + target_partition_size: int, + sample_chunk_count: int, +) -> None: + """Hash-shuffle by partition keys, evaluate, then route rows back to their origin rank.""" + stamps = _origin_stamps_for(ir) + + metadata_out = ChannelMetadata( + local_count=metadata_in.local_count, + partitioning=maybe_remap_partitioning(ir, metadata_in.partitioning), + duplicated=False, + ) + await send_metadata(ch_out, context, metadata_out) + + skip_insert = metadata_in.duplicated and comm.rank != 0 + + sample, forward_modulus = await _choose_modulus( + context, + comm, + ch_in, + metadata_in, + size_collective_id, + target_partition_size, + sample_chunk_count, + ) + + 
forward_shuffle = ShuffleManager( + context, comm, forward_modulus, forward_shuffle_collective_id + ) + return_shuffle = ShuffleManager( + context, + comm, + comm.nranks, + return_shuffle_collective_id, + partition_assignment=PartitionAssignment.CONTIGUOUS, + ) + + ch_replay = context.create_channel() + sequence_numbers, _ = await gather_in_task_group( + _distribute_by_group( + context, + comm, + forward_shuffle, + ch_replay, + ir.key_indices, + ir_context, + skip_insert, + ), + replay_buffered_channel( + context, ch_replay, ch_in, sample.chunks, metadata_in, trace_ir=ir + ), + ) + + await _evaluate_and_route_to_origin( + context, + ir, + ir_context, + forward_shuffle, + return_shuffle, + comm.nranks, + stamps, + ) + await _reassemble_input_chunks( + context, ch_out, ir_context, return_shuffle, sequence_numbers, ir, tracer + ) + + await ch_out.drain(context) + + +@define_actor() +async def over_actor( + context: Context, + comm: Communicator, + ir: Over, + ir_context: IRExecutionContext, + ch_out: Channel[TableChunk], + ch_in: Channel[TableChunk], + collective_ids: list[int], + target_partition_size: int, + sample_chunk_count: int, + scalar_plan: _ScalarOverPlan | None, +) -> None: + """ + Streaming actor for window ``over()`` expressions. + + Parameters + ---------- + context + The rapidsmpf context. + comm + The communicator. + ir + The Over IR node. + ir_context + The IR execution context. + ch_out + The output channel. + ch_in + The input channel. + collective_ids + Collective IDs reserved for this operation. Scalar Over nodes receive + one ID (AllGather); non-scalar nodes receive two (one shared by the + size AllGather and forward Shuffle, plus a separate one for the + return Shuffle which overlaps with the forward extract). + target_partition_size + Target output partition size in bytes, used to compute the shuffle + modulus for the non-scalar path. 
+ sample_chunk_count + Maximum number of input chunks to sample when estimating the shuffle + modulus on the non-scalar path. + scalar_plan + Pre-computed IR rewrites for the scalar Over path, built at planning + time. ``None`` for non-scalar Over nodes. + """ + async with shutdown_on_error( + context, ch_in, ch_out, trace_ir=ir, ir_context=ir_context + ) as tracer: + metadata_in = await recv_metadata(ch_in, context) + + partitioning = NormalizedPartitioning.from_keys( + metadata_in.partitioning, + comm.nranks, + keys=ir.key_indices, + allow_subset=True, + ) + if partitioning: + metadata_out = ChannelMetadata( + local_count=metadata_in.local_count, + partitioning=maybe_remap_partitioning(ir, metadata_in.partitioning), + duplicated=metadata_in.duplicated, + ) + await chunkwise_evaluate( + context, + ir, + ir_context, + ch_out, + ch_in, + metadata_out, + tracer=tracer, + ) + return + + if ir.is_scalar: + assert scalar_plan is not None + await _allgather_and_broadcast( + context, + comm, + ir, + ir_context, + ch_in, + ch_out, + metadata_in, + tracer, + collective_ids[0], + scalar_plan, + ) + else: + await _shuffle_and_reassemble( + context, + comm, + ir, + ir_context, + ch_in, + ch_out, + metadata_in, + tracer, + # collective_ids[0] is reused for the size AllGather and the + # forward shuffle (sequential, no overlap); collective_ids[1] + # is the return shuffle, which overlaps with forward extract. 
+ size_collective_id=collective_ids[0], + forward_shuffle_collective_id=collective_ids[0], + return_shuffle_collective_id=collective_ids[1], + target_partition_size=target_partition_size, + sample_chunk_count=sample_chunk_count, + ) + + +@generate_ir_sub_network.register(Over) +def _( + ir: Over, rec: SubNetGenerator +) -> tuple[dict[IR, list[Any]], dict[IR, ChannelManager]]: + executor = rec.state["config_options"].executor + actors, channels = process_children(ir, rec) + channels[ir] = ChannelManager(rec.state["context"]) + collective_ids = list(rec.state["collective_id_map"].get(ir, [])) + if not ir.is_scalar and executor.dynamic_planning is None: + raise ValueError( + "Non-scalar over() requires dynamic planning to size the forward " + "shuffle. Enable it via StreamingExecutor(dynamic_planning=...) " + "or the --dynamic-planning CLI flag." + ) + sample_chunk_count = ( + executor.dynamic_planning.sample_chunk_count + if executor.dynamic_planning is not None + else 0 + ) + scalar_plan = _build_scalar_over_plan(ir) if ir.is_scalar else None + actors[ir] = [ + over_actor( + rec.state["context"], + rec.state["comm"], + ir, + rec.state["ir_context"], + channels[ir].reserve_input_slot(), + channels[ir.children[0]].reserve_output_slot(), + collective_ids, + executor.target_partition_size, + sample_chunk_count, + scalar_plan, + ) + ] + return actors, channels diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/tracing.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/tracing.py index ec24bfab3b7..c9f39cab068 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/tracing.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/tracing.py @@ -71,6 +71,11 @@ def add_chunk(self, *, table: plc.Table | None = None) -> None: table The table to record. 
""" + # TODO: replace this API with one that takes a TableChunk directly, + # using TableChunk.shape[0] for the row count, and consider providing a + # helper that logs and sends a chunk in one call so the + # ``tracer.add_chunk(...) + ch_out.send(...)`` pattern doesn't have to + # be duplicated across every actor. if table is not None: # pragma: no cover; Covered by rapidsmpf tests self.row_count = (self.row_count or 0) + table.num_rows() self.chunk_count += 1 diff --git a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py index 6724b75e0ec..f5859d5040f 100644 --- a/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py +++ b/python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py @@ -10,8 +10,9 @@ import operator import struct import time +from collections import deque from contextlib import asynccontextmanager -from dataclasses import dataclass +from dataclasses import dataclass, field from functools import reduce from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast @@ -39,11 +40,20 @@ from cudf_polars.dsl.expr import Col, NamedExpr from cudf_polars.dsl.ir import Cache, Filter, GroupBy, HStack, Join, Projection, Select from cudf_polars.dsl.tracing import Scope +from cudf_polars.dsl.utils.naming import names_to_indices +from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager from cudf_polars.experimental.utils import _concat from cudf_polars.utils.dtypes import make_empty_column if TYPE_CHECKING: - from collections.abc import AsyncIterator, Callable, Coroutine, Iterator, Sequence + from collections.abc import ( + AsyncIterator, + Callable, + Coroutine, + Generator, + Iterator, + Sequence, + ) from rapidsmpf.communicator.communicator import Communicator from rapidsmpf.memory.buffer_resource import BufferResource @@ -63,6 +73,23 @@ PartitioningScheme: TypeAlias = InterRankScheme | Literal["inherit"] +class ChunkStore: + """Ordered 
spillable buffer for TableChunk messages.""" + + def __init__(self, ctx: Context) -> None: + self._mids: deque[int] = deque() + self._store = ctx.spillable_messages() + + def insert(self, msg: Message) -> None: + """Insert a message into the store.""" + self._mids.append(self._store.insert(msg)) + + def __iter__(self) -> Generator[Message, None, None]: + """Yield messages in insertion order, draining the store.""" + while self._mids: + yield self._store.extract(mid=self._mids.popleft()) + + @contextlib.contextmanager def set_memory_resource(mr: rmm.mr.DeviceMemoryResource) -> Iterator[None]: """ @@ -354,6 +381,54 @@ async def recv_metadata(ch: Channel[TableChunk], ctx: Context) -> ChannelMetadat return ChannelMetadata.from_message(msg) +def _make_hash_shuffle_metadata( + comm: Communicator, + key_indices: tuple[int, ...], + modulus: int, + metadata_in: ChannelMetadata, +) -> ChannelMetadata: + """ + Build output ChannelMetadata for a hash shuffle by key_indices. + + Parameters + ---------- + comm + The communicator. + key_indices + Column indices to hash-partition on. + modulus + Number of output partitions (must be >= comm.nranks). + metadata_in + Input channel metadata (used for duplicated flag and, on a + single-rank run, to preserve the existing inter-rank scheme). + + Returns + ------- + ChannelMetadata + Ready to pass to send_metadata. 
+ """ + nranks = comm.nranks + if nranks == 1: + inter_rank_scheme = ( + None + if metadata_in.partitioning is None + else metadata_in.partitioning.inter_rank + ) + local_scheme: HashScheme | str = HashScheme( + column_indices=key_indices, modulus=modulus + ) + local_output_count = modulus + else: + inter_rank_scheme = HashScheme(column_indices=key_indices, modulus=modulus) + local_scheme = "inherit" + local_output_count = (modulus - comm.rank + nranks - 1) // nranks + return ChannelMetadata( + local_count=local_output_count, + partitioning=Partitioning(inter_rank_scheme, local_scheme), + duplicated=metadata_in.duplicated, + ) + + def _evaluate_chunk_sync( chunk: TableChunk, ir: IR, @@ -434,6 +509,49 @@ async def evaluate_chunk( return chunk +async def allgather_and_reduce( + context: Context, + comm: Communicator, + collective_id: int, + local_chunk: TableChunk, + reduce_ir: IR, + ir_context: IRExecutionContext, +) -> TableChunk: + """ + AllGather ``local_chunk`` across ranks and apply ``reduce_ir`` to the result. + + Parameters + ---------- + context + The rapidsmpf streaming context. + comm + The communicator. + collective_id + Collective operation ID for the AllGather. + local_chunk + The locally-reduced chunk this rank contributes. + reduce_ir + IR node applied to the concatenated AllGather output. + ir_context + The IR execution context. + + Returns + ------- + The chunk produced by evaluating ``reduce_ir`` on the gathered result. 
+ """ + allgather = AllGatherManager(context, comm, collective_id) + with allgather.inserting() as inserter: + inserter.insert(0, local_chunk) + stream = ir_context.get_cuda_stream() + concat_chunk = TableChunk.from_pylibcudf_table( + await allgather.extract_concatenated(stream), + stream, + exclusive_view=True, + br=context.br(), + ) + return await evaluate_chunk(context, concat_chunk, reduce_ir, ir_context=ir_context) + + async def concat_batch( batch: list[TableChunk], context: Context, @@ -610,29 +728,71 @@ def indices_to_names(indices: tuple[int, ...], schema: Schema) -> tuple[str, ... return tuple(keys[i] for i in indices) -def names_to_indices( - names: tuple[str | NamedExpr, ...], schema: Schema -) -> tuple[int, ...]: - """ - Return column indices for the given names in schema order. +@dataclass(frozen=True) +class TableSizeStats: + """Sampled chunks and aggregate size/row stats for a table channel.""" + + chunks: dict[int, TableChunk] = field(default_factory=dict) + """The sampled chunks, keyed by sequence number.""" + total_size: int = 0 + """The total estimated size of the table in bytes.""" + total_rows: int = 0 + """The total estimated number of rows in the table.""" + total_chunks: int = 0 + """The total estimated number of chunks in the table.""" - Accepts either column names (str) or NamedExpr, so it can be used with - e.g. ir.left_on, ir.right_on as well as plain name tuples. + +async def _sample_chunks( + context: Context, + ch: Channel[TableChunk], + max_sample_chunks: int, + max_sample_bytes: int, + local_count: int, +) -> TableSizeStats: + """ + Sample chunks from a channel and extrapolate to a per-rank size estimate. Parameters ---------- - names - The names to get indices for. - schema - The schema to get indices from. + context + The context. + ch + The channel to sample from. + max_sample_chunks + The maximum number of chunks to sample. + max_sample_bytes + The maximum number of bytes to sample. 
+ local_count + The expected number of local chunks (used for extrapolation). Returns ------- - The column indices for each name in schema order. + Sampled chunks and the extrapolated total size/rows for this rank. """ - keys = list(schema.keys()) - str_names = [n.name if isinstance(n, NamedExpr) else n for n in names] - return tuple(keys.index(n) for n in str_names) + sampled_chunks: dict[int, TableChunk] = {} + total_size = 0 + total_rows = 0 + for _ in range(max_sample_chunks): + msg = await ch.recv(context) + if msg is None: + break + chunk = TableChunk.from_message(msg, br=context.br()).make_available_and_spill( + context.br(), allow_overbooking=True + ) + sampled_chunks[msg.sequence_number] = chunk + total_size += chunk.data_alloc_size() + total_rows += chunk.shape[0] + if total_size >= max_sample_bytes: + break + if sampled_chunks: + total_size = int((total_size / len(sampled_chunks)) * local_count) + total_rows = int((total_rows / len(sampled_chunks)) * local_count) + return TableSizeStats( + chunks=sampled_chunks, + total_size=total_size, + total_rows=total_rows, + total_chunks=local_count, + ) async def replay_buffered_channel( diff --git a/python/cudf_polars/cudf_polars/experimental/select.py b/python/cudf_polars/cudf_polars/experimental/select.py index 606741d587d..87c34db1b1c 100644 --- a/python/cudf_polars/cudf_polars/experimental/select.py +++ b/python/cudf_polars/cudf_polars/experimental/select.py @@ -20,6 +20,7 @@ decompose_expr_graph, make_expr_decomposer, ) +from cudf_polars.experimental.over import _fuse_over_nodes from cudf_polars.experimental.repartition import Repartition from cudf_polars.experimental.utils import ( _contains_unsupported_fill_strategy, @@ -166,6 +167,7 @@ def decompose_select( # Concatenate partial selections new_ir: Select | HConcat + selections, partition_info = _fuse_over_nodes(selections, partition_info) selections, partition_info = _fuse_simple_reductions( selections, partition_info, diff --git 
a/python/cudf_polars/tests/experimental/test_rolling.py b/python/cudf_polars/tests/experimental/test_rolling.py index bc1bbbc40eb..63695e65700 100644 --- a/python/cudf_polars/tests/experimental/test_rolling.py +++ b/python/cudf_polars/tests/experimental/test_rolling.py @@ -49,6 +49,213 @@ def test_rolling_datetime(request, engine): assert_gpu_result_equal(q, engine=engine) +@pytest.mark.parametrize( + "expr", + [ + pl.col("x").sum().over("g"), + pl.len().over("g"), + pl.col("x").sum().over("g", "g2"), + pl.col("x").sum().over("g_null"), + pl.col("x").sum().over("g", order_by="s"), + pl.col("x").rank(method="dense", descending=True).over("g"), + pl.col("x").rank(method="min").over("g", "g2"), + pl.col("x").cum_sum().over("g", order_by="s"), + pl.when((pl.col("x") % 2) == 0) + .then(None) + .otherwise(pl.col("x")) + .fill_null(strategy="forward") + .over("g", order_by="s"), + ], + ids=[ + "single_key_sum", + "len_over", + "multi_key", + "null_keys", + "order_by", + "rank_dense", + "rank_min_multi_key", + "cum_sum_order_by", + "fill_null_forward", + ], +) +def test_over_select(engine, expr): + df = pl.LazyFrame( + { + "g": [1, 1, 2, 2, 2, 1], + "x": [1, 2, 3, 4, 5, 6], + "g2": ["a", "b", "a", "b", "a", "b"], + "g_null": [1, None, 1, None, 2, 1], + "s": [6, 5, 4, 3, 2, 1], + } + ) + assert_gpu_result_equal(df.select(expr), engine=engine, check_row_order=True) + + +def test_over_with_columns(engine): + df = pl.LazyFrame( + { + "g": [1, 1, 2, 2, 2, 1], + "x": [1, 2, 3, 4, 5, 6], + } + ) + assert_gpu_result_equal( + df.with_columns(pl.col("x").sum().over("g")), + engine=engine, + check_row_order=True, + ) + + +def test_over_colliding_internal_agg_names(engine): + df = pl.LazyFrame( + { + "category": ["A", "A", "B", "B", "C"], + "value": [20, 30, 15, 40, 35], + } + ) + q = df.select( + pl.col("category"), + pl.col("value"), + pl.col("value").sum().over("category").alias("cat_sum"), + pl.col("value").mean().over("category").alias("cat_avg"), + ).sort("category", "value") + 
assert_gpu_result_equal(q, engine=engine, check_row_order=True) + + +@pytest.mark.parametrize( + "expr", + [ + pl.col("x").sum().over(pl.col("g") % 2), + pl.col("x").sum().over("g", pl.col("x") % 2), + ], + ids=["noncol_key", "mixed_col_and_expr_key"], +) +@pytest.mark.parametrize( + "engine", + [{"executor_options": {"max_rows_per_partition": 2}}], + indirect=True, +) +def test_over_noncol_key_fallback(request, engine, expr) -> None: + # Non-Col and mixed Col/expr partition-by keys are not yet supported for + # multi-partition streaming and should fall back to single-partition. + if not isinstance(engine, SPMDEngine): + # On Dask/Ray the fallback warning fires on worker processes and is + # invisible to ``pytest.warns``. + request.applymarker( + pytest.mark.xfail( + reason="https://github.com/rapidsai/cudf/issues/22405", + strict=False, + ) + ) + df = pl.LazyFrame( + { + "g": [1, 1, 2, 2, 2, 1], + "x": [1, 2, 3, 4, 5, 6], + } + ) + with pytest.warns(UserWarning, match=r"not supported for multiple partitions"): + assert_gpu_result_equal(df.select(expr), engine=engine) + + +@pytest.mark.parametrize( + "engine", + [{"executor_options": {"max_rows_per_partition": 2}}], + indirect=True, +) +def test_over_mixed_keys(engine) -> None: + # Multiple over expressions with different partition-by keys are decomposed + # into separate Over nodes (one per key group) and combined with HConcat. 
+ df = pl.LazyFrame( + { + "g": [1, 1, 2, 2, 2, 1], + "g2": ["a", "b", "a", "b", "a", "b"], + "x": [1, 2, 3, 4, 5, 6], + } + ) + q = df.select( + pl.col("x").sum().over("g").alias("s_g"), + pl.col("x").sum().over("g2").alias("s_g2"), + ) + assert_gpu_result_equal(q, engine=engine, check_row_order=False) + + +@pytest.mark.parametrize( + "expr", + [ + pl.col("x").sum().over("g"), + pl.len().over("g"), + pl.col("x").rank(method="dense").over("g"), + pl.col("x").cum_sum().over("g", order_by="s"), + ], + ids=["scalar_sum", "scalar_len", "nonscalar_rank", "nonscalar_cum_sum"], +) +@pytest.mark.parametrize( + "engine", + [{"executor_options": {"max_rows_per_partition": 1}}], + indirect=True, +) +def test_over_many_partitions(engine, expr) -> None: + # max_rows_per_partition=1 forces one chunk per row, exercising the AllGather + # (scalar broadcast) and sort-and-split (non-scalar) paths across many partitions. + df = pl.LazyFrame( + { + "g": [1, 1, 2, 2, 2, 1], + "x": [1, 2, 3, 4, 5, 6], + "s": [6, 5, 4, 3, 2, 1], + } + ) + assert_gpu_result_equal(df.select(expr), engine=engine, check_row_order=True) + + +@pytest.mark.parametrize( + "expr", + [ + pl.col("x").sum().over("g"), + pl.col("x").rank(method="dense").over("g"), + ], + ids=["scalar_sum", "nonscalar_rank"], +) +@pytest.mark.parametrize( + "engine", + [{"executor_options": {"max_rows_per_partition": 2}}], + indirect=True, +) +def test_over_empty_input(engine, expr) -> None: + df = pl.LazyFrame( + { + "g": pl.Series([], dtype=pl.Int64), + "x": pl.Series([], dtype=pl.Int64), + } + ) + assert_gpu_result_equal(df.select(expr), engine=engine, check_row_order=True) + + +@pytest.mark.parametrize( + "expr", + [ + pl.col("x").sum().over("g"), + pl.col("x").rank(method="dense").over("g"), + ], + ids=["scalar_sum", "nonscalar_rank"], +) +@pytest.mark.parametrize( + "engine", + [{"executor_options": {"max_rows_per_partition": 3, "broadcast_join_limit": -1}}], + indirect=True, +) +def test_over_already_partitioned(engine, expr) 
-> None: +    # broadcast_join_limit=-1 makes the broadcast threshold negative, disabling +    # broadcast joins entirely. Therefore, we should already be shuffled on "g" +    # after the join. The over("g") actor should detect this and evaluate chunkwise +    # without a second shuffle. +    left = pl.LazyFrame({"g": [1, 1, 2, 2, 2, 1], "x": [1, 2, 3, 4, 5, 6]}) +    right = pl.LazyFrame({"g": [1, 2], "y": [10, 20]}) +    assert_gpu_result_equal( +        left.join(right, on="g").with_columns(expr), +        engine=engine, +        check_row_order=False, +    ) + + +def test_over_in_filter_unsupported(request, streaming_engine_factory) -> None: +    engine = streaming_engine_factory( +        StreamingOptions(max_rows_per_partition=1, fallback_mode="warn"), diff --git a/python/cudf_polars/tests/experimental/test_spmd.py b/python/cudf_polars/tests/experimental/test_spmd.py index fabaeedbc78..a61268cf387 100644 --- a/python/cudf_polars/tests/experimental/test_spmd.py +++ b/python/cudf_polars/tests/experimental/test_spmd.py @@ -393,3 +393,146 @@ def test_reset_rejects_construction_time_engine_options( ) with pytest.raises(ValueError, match="memory_resource_config"): engine._reset(engine_options={"memory_resource_config": None}) + + +# Group keys probed with num_partitions=2, nranks=2, ROUND_ROBIN: +# _SAME_RANK_KEYS[r] hashes to partition r: data stays on its origin rank. +# _CROSS_RANK_KEYS[r] hashes to partition 1-r: data is fully shuffled away. +# num_partitions=2 = max(nranks=2, local_count=1). local_count=1 requires +# max_rows_per_partition >= the number of rows per rank (3 here), so we use 4.
+_SAME_RANK_KEYS = [ + 0, + 3, +] # g=0 hashes to partition 0 (rank 0); g=3 hashes to partition 1 (rank 1) +_CROSS_RANK_KEYS = [ + 3, + 0, +] # g=3 hashes to partition 1 (rank 1); g=0 hashes to partition 0 (rank 0) + + +@pytest.mark.parametrize( + "expr,is_scalar", + [ + (pl.col("x").sum().over("g").alias("result"), True), + (pl.col("x").rank(method="dense").over("g").alias("result"), False), + ], + ids=["scalar_sum", "nonscalar_rank"], +) +@pytest.mark.parametrize( + "cross_rank", + [False, True], + ids=["same_rank", "cross_rank"], +) +def test_over_multirank( + request: pytest.FixtureRequest, + comm: Communicator, + expr: pl.Expr, + is_scalar: bool, # noqa: FBT001 + cross_rank: bool, # noqa: FBT001 +) -> None: + """over() correctness in multi-rank SPMD mode, same-rank and cross-rank cases. + + same_rank: group keys hash to the origin rank's own partition (happy path). + cross_rank: group keys hash to the other rank's partition, exercising the + bug where row_idx spaces are rank-local so Phase 2 fills the wrong + accumulated slots and each rank receives the other rank's data. + + max_rows_per_partition=4 keeps all 3 rows in one chunk (local_count=1), + so num_partitions=max(nranks=2, 1)=2, matching the probed key assignments. + """ + with SPMDEngine( + comm=comm, + executor_options={"max_rows_per_partition": 4, "dynamic_planning": {}}, + ) as engine: + rank = engine.rank + nranks = engine.nranks + if nranks != 2: + request.applymarker( + pytest.mark.skip( + reason="key assignments are probed for exactly 2 ranks" + ) + ) + keys = _CROSS_RANK_KEYS if cross_rank else _SAME_RANK_KEYS + g = keys[rank] + xs = [rank * 3 + 1, rank * 3 + 2, rank * 3 + 3] + lf = pl.LazyFrame({"g": [g, g, g], "x": xs}) + local_result = lf.select(pl.col("g"), pl.col("x"), expr).collect(engine=engine) + + # Each rank must get back its OWN rows (not another rank's). 
+ assert local_result["g"].unique().to_list() == [g], ( + f"rank {rank}: expected only group {g} in output, " + f"got {local_result['g'].unique().to_list()}" + ) + + with reserve_op_id() as op_id: + global_result = allgather_polars_dataframe( + engine=engine, local_df=local_result, op_id=op_id + ) + + assert global_result.shape == (3 * nranks, 3) + for r in range(nranks): + grp_g = keys[r] + grp = global_result.filter(pl.col("g") == grp_g).sort("x") + assert grp.shape == (3, 3), f"rank {r} group has wrong row count" + expected_xs = [r * 3 + 1, r * 3 + 2, r * 3 + 3] + assert grp["x"].to_list() == expected_xs + if is_scalar: + assert grp["result"].to_list() == [sum(expected_xs)] * 3 + else: + assert grp["result"].to_list() == [1, 2, 3] + + +def test_over_nonscalar_duplicated_input( + request: pytest.FixtureRequest, + comm: Communicator, +) -> None: + """Non-scalar over() on duplicated=True input produces correct row count and values. + + group_by() AllGathers its result onto all ranks (duplicated=True). The + non-scalar over() path must output duplicated=False and only insert rows on + rank 0, otherwise all ranks insert the same rows (N-fold overcounting) and + the downstream Repartition skips AllGather. + + max_rows_per_partition=10 keeps all 3 rows in one chunk (local_count=1), + so modulus=max(nranks=2, 1)=2, matching the _SAME_RANK_KEYS assignments. 
+ """ + with SPMDEngine( + comm=comm, + executor_options={"max_rows_per_partition": 10, "dynamic_planning": {}}, + ) as engine: + rank = engine.rank + nranks = engine.nranks + if nranks != 2: + request.applymarker( + pytest.mark.skip( + reason="key assignments are probed for exactly 2 ranks" + ) + ) + + coarse_g = _SAME_RANK_KEYS[rank] + fine_gs = [rank * 3 + 1, rank * 3 + 2, rank * 3 + 3] + xs = [rank * 30 + 10, rank * 30 + 20, rank * 30 + 30] + lf = pl.LazyFrame({"fine_g": fine_gs, "coarse_g": [coarse_g] * 3, "x": xs}) + local_result = ( + lf.group_by("fine_g", "coarse_g") + .agg(pl.col("x").first()) + .with_columns( + pl.col("x").rank(method="dense").over("coarse_g").alias("rank_x") + ) + .collect(engine=engine) + ) + + with reserve_op_id() as op_id: + global_result = allgather_polars_dataframe( + engine=engine, local_df=local_result, op_id=op_id + ) + + assert global_result.shape == (3 * nranks, 4) + for r in range(nranks): + cg = _SAME_RANK_KEYS[r] + grp = global_result.filter(pl.col("coarse_g") == cg).sort("x") + assert grp.shape == (3, 4), f"coarse_g={cg}: wrong row count" + assert grp["rank_x"].to_list() == [1, 2, 3], ( + f"coarse_g={cg}: expected dense ranks [1, 2, 3] " + f"but got {grp['rank_x'].to_list()}" + )