Draft
84 commits
5861da9
Support streaming over expression
Matt711 Apr 14, 2026
66be8e9
Add fast path for scalar aggs case, support preserving the full oreder
Matt711 Apr 16, 2026
3b15601
add more tests, protect against OOMs when preserving the full order
Matt711 Apr 17, 2026
6b91f1b
check style
Matt711 Apr 17, 2026
ac73ede
add pre-shuffle test, and clarifying comments
Matt711 Apr 17, 2026
0c8e3de
add multi-rank test
Matt711 Apr 17, 2026
696bbc9
style
Matt711 Apr 17, 2026
fdc40c9
Merge branch 'main' into fea/polars/streaming-over
Matt711 Apr 17, 2026
e477439
reject non decomposable aggregations
Matt711 Apr 17, 2026
4a6eb6d
fix xfail condition for mg test
Matt711 Apr 17, 2026
0c48e1d
Merge branch 'main' into fea/polars/streaming-over
Matt711 Apr 17, 2026
0d6d6b1
Merge branch 'main' into fea/polars/streaming-over
Matt711 Apr 17, 2026
01d8777
Merge branch 'main' into fea/polars/streaming-over
Matt711 Apr 17, 2026
f1cf13c
Add Over IR node
Matt711 Apr 21, 2026
4ac5e65
decompose select with over expressions with mixed partition by keys
Matt711 Apr 21, 2026
c02eea9
move IR to type checking block
Matt711 Apr 21, 2026
edc704e
add Over IR node
Matt711 Apr 21, 2026
5d28fb9
merge conflict
Matt711 Apr 22, 2026
4fc4d25
merge conflict
Matt711 May 1, 2026
e1442f2
Merge branch 'main' of https://github.com/rapidsai/cudf into fea/pola…
Matt711 May 2, 2026
d632dd4
add _decompose_grouped_window_node
rjzamora May 4, 2026
a7a914a
Merge remote-tracking branch 'upstream/main' into streaming-over-rick
rjzamora May 4, 2026
6bee1b8
cull unnecessary code
rjzamora May 4, 2026
4ca3e3c
address review
Matt711 May 4, 2026
d690b1e
heavy revisions
rjzamora May 4, 2026
9dab468
Merge remote-tracking branch 'upstream/main' into fea/polars/streamin…
rjzamora May 4, 2026
f8db130
Merge branch 'streaming-over-rick' into fea/polars/streaming-over
rjzamora May 4, 2026
3893253
Merge branch 'main' into fea/polars/streaming-over
rjzamora May 4, 2026
ca1e808
fix non-scalar over() by absorbing passthrough cols into Over node
Matt711 May 4, 2026
ec93d63
fix non-scalar over() with duplicated=True input
Matt711 May 5, 2026
dc28164
Fix assertion failures in assert_tpch_result_equal due to float sort …
Matt711 May 5, 2026
c270bb1
remove comment
Matt711 May 5, 2026
f275788
update decompose condition for dynamic planning
Matt711 May 5, 2026
6051ac5
Fix scalar-over path: update _DECOMPOSABLE_AGG_NAMES, assert no-presh…
Matt711 May 5, 2026
b272f6c
Fix over_actor: remove stale key assertion, fix shuffle modulus to us…
Matt711 May 5, 2026
4d1b482
Add Over.do_evaluate and remove eval_ir indirection in over_actor
Matt711 May 5, 2026
85f77fd
Extract _allgather_and_broadcast and _shuffle_and_reassemble from ove…
Matt711 May 5, 2026
4c4a33f
Fix shuffle modulus: AllGather total size/count and compute from targ…
Matt711 May 5, 2026
563bac7
restoring _fuse_over_nodes
Matt711 May 6, 2026
dad1a79
merge conflict
Matt711 May 6, 2026
30e8826
update fixture name
Matt711 May 6, 2026
d86d0d8
More Polars plan optimizations for TPC-DS
Matt711 May 6, 2026
cb94dd4
Merge branch 'main' into imp/pdsds/more-pds-optimizations
Matt711 May 6, 2026
f9fdd46
add a second (reverse) shuffle for mg correctness
Matt711 May 6, 2026
c6a7b10
estimate the modulus
Matt711 May 6, 2026
3051291
remove xfail marker from test
Matt711 May 6, 2026
bd076c3
simplify skip condition in spmd over test
Matt711 May 6, 2026
12620c8
set allow_subset=True
Matt711 May 6, 2026
5f4dd3e
remove input_ir arg
Matt711 May 6, 2026
09a4586
simplify conditional statement
Matt711 May 6, 2026
ccaee76
TODO tracer API
Matt711 May 6, 2026
bc584d7
rename split* instead of boundaries
Matt711 May 6, 2026
a723dc8
use int32
Matt711 May 6, 2026
902a4cd
clean up docstring OriginStamps docstring
Matt711 May 6, 2026
338757a
Merge branch 'main' into fea/polars/streaming-over
Matt711 May 6, 2026
0a94101
fix upstream polars tests
Matt711 May 7, 2026
370fc21
Merge branch 'main' into fea/polars/streaming-over
Matt711 May 7, 2026
5b25eea
Merge branch 'main' into fea/polars/streaming-over
Matt711 May 7, 2026
7c81adf
address driveby nits
Matt711 May 7, 2026
68a7983
drop id() as dict key
Matt711 May 7, 2026
7975ff4
simplify _evaluate_window_with_stamps
Matt711 May 7, 2026
a6f397c
address more nits
Matt711 May 7, 2026
ea39bd7
docstrings & use names_to_indices
Matt711 May 7, 2026
658e4bf
simplify scalar-Over IR + cleanup
Matt711 May 7, 2026
d3b6562
colleect ir rewrite for scalar path ahead of time, comment improvemen…
Matt711 May 7, 2026
7b828be
add module doc string overviewing the algorithm, few smaller clean ups
Matt711 May 7, 2026
397d50f
more clean ups
Matt711 May 7, 2026
946582a
Merge branch 'main' into fea/polars/streaming-over
Matt711 May 7, 2026
f7687f1
oh yeah, dont use to_arrow
Matt711 May 7, 2026
beee7d9
merge conflict
Matt711 May 11, 2026
2b8b357
Merge branch 'main' of https://github.com/rapidsai/cudf into fea/pola…
Matt711 May 11, 2026
fc9f7bb
no allgather unecessarily
Matt711 May 11, 2026
a965b27
Merge PR #22191
quasiben May 11, 2026
64fe9f3
Merge PR #22378
quasiben May 11, 2026
ac4c9d1
Merge PR #22395
quasiben May 11, 2026
b1e57de
Merge upstream main into pds-ds-all
quasiben May 12, 2026
3a23609
update NormalizedPartitioning.from_keys call site
Matt711 May 12, 2026
c671992
Revert "remove comment"
TomAugspurger May 12, 2026
3c7fda9
Revert "Fix assertion failures in assert_tpch_result_equal due to flo…
TomAugspurger May 12, 2026
65827dc
Update floating-point handling
TomAugspurger May 12, 2026
9c66981
Merge branch 'main' into bug/pdsds/q64-validation
TomAugspurger May 12, 2026
6f29e67
merge conflict
Matt711 May 12, 2026
de156af
Merge PR #22191 updates
quasiben May 12, 2026
4f69818
Merge PR #22378 updates
quasiben May 12, 2026
64 changes: 49 additions & 15 deletions python/cudf_polars/cudf_polars/experimental/benchmarks/asserts.py
@@ -247,15 +247,33 @@ def assert_tpch_result_equal(

# We know that each dataframe is sorted on `sort_by` according to itself.
# Now we have some freedom to reorder the rows. We'll use this freedom to avoid
# any kind of sorting on floating-point columns, which introduces all sorts of
# fuzziness we don't want to deal with.
# any kind of fuzziness from sorting on floating-point columns.
#
# As long as we sort by the non-float columns *first*, we'll avoid any
# false positives / false negatives from comparing two tables that have the
# same values but happen to be in a different order. Sorting by floating-point
# columns *last* ensures that records that are close to each other appear in
# (roughly) the same order, such that Polars' approximate equality checks
# will allow them to be considered equal (or not, if they aren't actually close).
non_float_columns = [
col
for col in left.columns
if left.schema[col] not in (pl.Float32, pl.Float64)
]
left_sorted = left.sort(by=non_float_columns, nulls_last=nulls_last)
right_sorted = right.sort(by=non_float_columns, nulls_last=nulls_last)
float_columns = [
col for col in left.columns if left.schema[col] in (pl.Float32, pl.Float64)
]
grouped_sort_columns = [*non_float_columns, *float_columns]

def sort_for_comparison(df: pl.DataFrame) -> pl.DataFrame:
return (
df.sort(by=grouped_sort_columns, nulls_last=nulls_last)
if grouped_sort_columns
else df
)

left_sorted = sort_for_comparison(left)
right_sorted = sort_for_comparison(right)
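The grouped-sort idea above can be illustrated with a minimal stdlib sketch (hypothetical helper and data, plain Python rather than polars): ordering rows by the exact-typed columns first and the floating-point columns last makes two tables line up row-for-row even when their float values differ only by rounding noise.

```python
# Hypothetical stand-in for the diff's sort_for_comparison: sort rows by
# non-float key columns first, float columns last, so near-equal floats
# cannot scramble the row order between the two tables being compared.
def sort_for_comparison(rows, non_float_idx, float_idx):
    """Order rows by non-float columns first, float columns last."""
    return sorted(rows, key=lambda r: tuple(r[i] for i in (*non_float_idx, *float_idx)))

left = [("b", 2, 0.30000000000000004), ("a", 1, 0.1)]
right = [("a", 1, 0.1 + 1e-12), ("b", 2, 0.3)]

ls = sort_for_comparison(left, non_float_idx=(0, 1), float_idx=(2,))
rs = sort_for_comparison(right, non_float_idx=(0, 1), float_idx=(2,))

# Rows now pair up by their exact keys; only the float column needs an
# approximate (tolerance-based) check, as assert_frame_equal performs.
assert all(l[:2] == r[:2] for l, r in zip(ls, rs))
assert all(abs(l[2] - r[2]) < 1e-9 for l, r in zip(ls, rs))
```

In the real code this role is played by `polars.testing.assert_frame_equal`'s tolerance parameters; the sketch only shows why float columns must sort last.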

if limit is None or left.is_empty():
try:
@@ -320,8 +338,8 @@ def assert_tpch_result_equal(

try:
polars.testing.assert_frame_equal(
result_first.sort(by=non_float_columns, nulls_last=nulls_last),
expected_first.sort(by=non_float_columns, nulls_last=nulls_last),
sort_for_comparison(result_first),
sort_for_comparison(expected_first),
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e:
@@ -339,12 +357,8 @@

try:
polars.testing.assert_frame_equal(
result_ties.sort(non_float_columns, nulls_last=nulls_last).select(
by
),
expected_ties.sort(non_float_columns, nulls_last=nulls_last).select(
by
),
sort_for_comparison(result_ties).select(by),
sort_for_comparison(expected_ties).select(by),
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e:
@@ -354,11 +368,31 @@
) from e

else:
# no sort_by, just a straight comparison.
non_float_columns = [
col
for col in left.columns
if left.schema[col] not in (pl.Float32, pl.Float64)
]
float_columns = [
col for col in left.columns if left.schema[col] in (pl.Float32, pl.Float64)
]
grouped_sort_columns = [*non_float_columns, *float_columns]
left_sorted = (
left.sort(by=grouped_sort_columns, nulls_last=nulls_last)
if grouped_sort_columns
else left
)
right_sorted = (
right.sort(by=grouped_sort_columns, nulls_last=nulls_last)
if grouped_sort_columns
else right
)

# no sort_by, compare after grouped sort to ignore nondeterministic row order.
try:
polars.testing.assert_frame_equal(
left,
right,
left_sorted,
right_sorted,
**polars_kwargs, # type: ignore[arg-type]
Comment on lines 370 to 396
⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

check_row_order is neutralized when sort_by is empty

This branch always canonicalizes row order before comparison, so check_row_order=True no longer enforces original order and can hide ordering regressions.

Suggested fix
 else:
-    non_float_columns = [
-        col
-        for col in left.columns
-        if left.schema[col] not in (pl.Float32, pl.Float64)
-    ]
-    float_columns = [
-        col for col in left.columns if left.schema[col] in (pl.Float32, pl.Float64)
-    ]
-    grouped_sort_columns = [*non_float_columns, *float_columns]
-    left_sorted = (
-        left.sort(by=grouped_sort_columns, nulls_last=nulls_last)
-        if grouped_sort_columns
-        else left
-    )
-    right_sorted = (
-        right.sort(by=grouped_sort_columns, nulls_last=nulls_last)
-        if grouped_sort_columns
-        else right
-    )
+    if check_row_order:
+        left_sorted = left
+        right_sorted = right
+    else:
+        non_float_columns = [
+            col
+            for col in left.columns
+            if left.schema[col] not in (pl.Float32, pl.Float64)
+        ]
+        float_columns = [
+            col for col in left.columns if left.schema[col] in (pl.Float32, pl.Float64)
+        ]
+        grouped_sort_columns = [*non_float_columns, *float_columns]
+        left_sorted = (
+            left.sort(by=grouped_sort_columns, nulls_last=nulls_last)
+            if grouped_sort_columns
+            else left
+        )
+        right_sorted = (
+            right.sort(by=grouped_sort_columns, nulls_last=nulls_last)
+            if grouped_sort_columns
+            else right
+        )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@python/cudf_polars/cudf_polars/experimental/benchmarks/asserts.py` around
lines 370 - 396, The current else-branch always reorders rows via
grouped_sort_columns before comparing, which neutralizes check_row_order; update
the logic around left_sorted/right_sorted so that when check_row_order is True
you do NOT sort (keep left/right as-is) and only perform the grouped_sort when
check_row_order is False (preserving existing behavior when sort_by is empty);
adjust the variables used in the call to polars.testing.assert_frame_equal
(left_sorted/right_sorted) accordingly so the comparison respects the
check_row_order flag.
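The reviewer's point reduces to a small sketch (hypothetical names, not the real `assert_tpch_result_equal`): canonicalize row order only when the caller has not asked for row order to be verified.

```python
# Hypothetical comparison helper: sort both sides only when row order is
# NOT significant; otherwise compare rows exactly as given, so a reordering
# regression still fails the check.
def frames_match(left, right, check_row_order=False):
    """Compare two row lists; canonicalize order only if order is insignificant."""
    if not check_row_order:
        left, right = sorted(left), sorted(right)
    return left == right

a = [(1, "x"), (2, "y")]
b = [(2, "y"), (1, "x")]

assert frames_match(a, b, check_row_order=False)     # same rows, any order: equal
assert not frames_match(a, b, check_row_order=True)  # order differs: mismatch
```

The suggested fix above applies the same guard: `left_sorted`/`right_sorted` fall back to the unsorted frames whenever `check_row_order` is true.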

)
except AssertionError as e:
@@ -335,65 +335,66 @@ def polars_impl(run_config: RunConfig) -> QueryResult:
item = get_data(run_config.dataset_path, "item", run_config.suffix)
date_dim = get_data(run_config.dataset_path, "date_dim", run_config.suffix)

cross_items = build_cross_items(
store_sales, catalog_sales, web_sales, item, date_dim, year=year
)
average_sales = build_average_sales(
store_sales, catalog_sales, web_sales, date_dim, year=year
)

# week_dates is ≤7 rows (one calendar week), computed once as a 1-partition frame.
# Push the week filter into each channel before the UNION via a semi-join so that
# ~99% of rows (everything outside the target week) are dropped before the
# expensive cross_items join and groupby.
target_week = (
date_dim.filter(
(pl.col("d_year") == year + 1)
& (pl.col("d_moy") == 12)
& (pl.col("d_dom") == day)
)
.select("d_week_seq")
.unique()
)
week_dates = date_dim.join(target_week, on="d_week_seq").select("d_date_sk")

all_sales = pl.concat(
[
store_sales.select(
store_sales.join(
week_dates, left_on="ss_sold_date_sk", right_on="d_date_sk", how="semi"
).select(
[
pl.lit("store").alias("channel"),
pl.col("ss_item_sk").alias("item_sk"),
pl.col("ss_quantity").alias("quantity"),
pl.col("ss_list_price").alias("list_price"),
pl.col("ss_sold_date_sk").alias("date_sk"),
]
),
catalog_sales.select(
catalog_sales.join(
week_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi"
).select(
[
pl.lit("catalog").alias("channel"),
pl.col("cs_item_sk").alias("item_sk"),
pl.col("cs_quantity").alias("quantity"),
pl.col("cs_list_price").alias("list_price"),
pl.col("cs_sold_date_sk").alias("date_sk"),
]
),
web_sales.select(
web_sales.join(
week_dates, left_on="ws_sold_date_sk", right_on="d_date_sk", how="semi"
).select(
[
pl.lit("web").alias("channel"),
pl.col("ws_item_sk").alias("item_sk"),
pl.col("ws_quantity").alias("quantity"),
pl.col("ws_list_price").alias("list_price"),
pl.col("ws_sold_date_sk").alias("date_sk"),
]
),
]
)
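The semi-join push-down applied to each channel above can be sketched with plain Python (hypothetical toy data, not the TPC-DS tables): a semi join keeps only left rows whose key appears in the small right frame, without taking any right-side columns, so the filter is cheap to apply before the expensive joins.

```python
# Stdlib sketch of how="semi" against the tiny week_dates frame: drop
# fact-table rows outside the target week before any expensive work.
week_dates = {20261201, 20261202}  # stand-in for the <=7 qualifying d_date_sk values

store_sales = [
    {"ss_sold_date_sk": 20261201, "ss_quantity": 5},
    {"ss_sold_date_sk": 20260615, "ss_quantity": 9},  # outside the target week
]

# Semi-join semantics: membership filter on the join key; no right columns kept.
filtered = [r for r in store_sales if r["ss_sold_date_sk"] in week_dates]

assert filtered == [{"ss_sold_date_sk": 20261201, "ss_quantity": 5}]
```

In the diff this is `store_sales.join(week_dates, ..., how="semi")`, repeated per channel so the union is built from already-pruned inputs.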

cross_items = build_cross_items(
store_sales, catalog_sales, web_sales, item, date_dim, year=year
)
average_sales = build_average_sales(
store_sales, catalog_sales, web_sales, date_dim, year=year
)

# d_week_seq target is the same for all 3 channels; compute it once.
target_week = (
date_dim.filter(
(pl.col("d_year") == year + 1)
& (pl.col("d_moy") == 12)
& (pl.col("d_dom") == day)
)
.select("d_week_seq")
.unique()
)
week_dates = date_dim.join(target_week, on="d_week_seq").select("d_date_sk")

# Build y: all 3 channels in a single pipeline.
# cross_items and average_sales each appear once — no CSE needed.
# After group_by the frame is tiny, so the cross join with the 1-row
# average_sales frame is negligible even if Polars fuses it into an IEJoin.
y = (
all_sales.join(cross_items, left_on="item_sk", right_on="ss_item_sk")
.join(item, left_on="item_sk", right_on="i_item_sk")
.join(week_dates, left_on="date_sk", right_on="d_date_sk")
.group_by(["channel", "i_brand_id", "i_class_id", "i_category_id"])
.agg(
[
@@ -110,38 +110,103 @@ def polars_impl(run_config: RunConfig) -> QueryResult:
sort_by = {"i_item_id": False, "i_item_desc": False, "s_state": False}
limit = 100

store_sales_base = (
q1 = f"{year}Q1"
q1_q3 = [f"{year}Q1", f"{year}Q2", f"{year}Q3"]

# Pre-filter date_dim to only qualifying d_date_sk values.
d1_dates = date_dim.filter(pl.col("d_quarter_name") == q1).select("d_date_sk")
d_q3_dates = date_dim.filter(pl.col("d_quarter_name").is_in(q1_q3)).select(
"d_date_sk"
)

# store_returns has [6] partitions — at the broadcast limit. Filter it to Q1-Q3 dates
# first, then use the (customer, item) pairs it contains to pre-filter both store_sales
# and catalog_sales before those larger tables enter the expensive shuffle joins.
store_returns_filtered = store_returns.join(
d_q3_dates, left_on="sr_returned_date_sk", right_on="d_date_sk", how="semi"
).select(["sr_customer_sk", "sr_item_sk", "sr_ticket_number", "sr_return_quantity"])

# (customer, item) pairs present in any qualifying store return; stays at [6] partitions
# so broadcast is free. Polars will CACHE this shared subplan.
sr_customer_item = store_returns_filtered.select(["sr_customer_sk", "sr_item_sk"])

store_sales_filtered = (
store_sales.join(
date_dim, left_on="ss_sold_date_sk", right_on="d_date_sk", suffix="_d1"
d1_dates, left_on="ss_sold_date_sk", right_on="d_date_sk", how="semi"
)
.join(
sr_customer_item,
left_on=["ss_customer_sk", "ss_item_sk"],
right_on=["sr_customer_sk", "sr_item_sk"],
how="semi",
)
.select(
[
"ss_customer_sk",
"ss_item_sk",
"ss_store_sk",
"ss_ticket_number",
"ss_quantity",
]
)
.join(
item.select(["i_item_sk", "i_item_id", "i_item_desc"]),
left_on="ss_item_sk",
right_on="i_item_sk",
)
.join(
store.select(["s_store_sk", "s_state"]),
left_on="ss_store_sk",
right_on="s_store_sk",
)
.select(
[
"ss_customer_sk",
"ss_item_sk",
"ss_ticket_number",
"ss_quantity",
"i_item_id",
"i_item_desc",
"s_state",
]
)
.join(item, left_on="ss_item_sk", right_on="i_item_sk")
.join(store, left_on="ss_store_sk", right_on="s_store_sk")
.filter(pl.col("d_quarter_name") == f"{year}Q1")
)

store_returns_base = store_returns.join(
date_dim, left_on="sr_returned_date_sk", right_on="d_date_sk", suffix="_d2"
).filter(pl.col("d_quarter_name").is_in([f"{year}Q1", f"{year}Q2", f"{year}Q3"]))

catalog_sales_base = catalog_sales.join(
date_dim, left_on="cs_sold_date_sk", right_on="d_date_sk", suffix="_d3"
).filter(pl.col("d_quarter_name").is_in([f"{year}Q1", f"{year}Q2", f"{year}Q3"]))
catalog_sales_filtered = (
catalog_sales.join(
d_q3_dates, left_on="cs_sold_date_sk", right_on="d_date_sk", how="semi"
)
.join(
sr_customer_item,
left_on=["cs_bill_customer_sk", "cs_item_sk"],
right_on=["sr_customer_sk", "sr_item_sk"],
how="semi",
)
.select(["cs_bill_customer_sk", "cs_item_sk", "cs_quantity"])
)

return QueryResult(
frame=(
store_sales_base.join(
store_returns_base,
store_sales_filtered.join(
store_returns_filtered,
left_on=["ss_customer_sk", "ss_item_sk", "ss_ticket_number"],
right_on=["sr_customer_sk", "sr_item_sk", "sr_ticket_number"],
how="inner",
suffix="_sr",
)
.select(
[
"ss_customer_sk",
"ss_item_sk",
"ss_quantity",
"sr_return_quantity",
"i_item_id",
"i_item_desc",
"s_state",
]
)
.join(
catalog_sales_base,
catalog_sales_filtered,
left_on=["ss_customer_sk", "ss_item_sk"],
right_on=["cs_bill_customer_sk", "cs_item_sk"],
how="inner",
suffix="_cs",
)
.group_by(["i_item_id", "i_item_desc", "s_state"])
.agg(