Skip to content
59 changes: 37 additions & 22 deletions python/cudf_polars/cudf_polars/experimental/benchmarks/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,31 @@ def assert_tpch_result_equal(
right = right.with_columns(*float_casts)
left = left.with_columns(*float_casts)

non_float_columns = [
col for col in left.columns if left.schema[col] not in (pl.Float32, pl.Float64)
]
float_columns = [
col for col in left.columns if left.schema[col] in (pl.Float32, pl.Float64)
]
grouped_sort_columns = [*non_float_columns, *float_columns]

def sort_for_comparison(df: pl.DataFrame) -> pl.DataFrame:
    """Return ``df`` in a canonical row order suitable for comparison.

    Rows are sorted on ``grouped_sort_columns`` from the enclosing scope
    (every non-float column first, then every float column), passing the
    enclosing ``nulls_last`` setting through to the sort. If there are no
    columns to sort on, ``df`` is returned as-is.
    """
    # When `sort_by` is given, each dataframe is already known to be sorted
    # on it according to itself, which leaves some freedom to reorder rows.
    # That freedom is used to avoid fuzziness from sorting on floating-point
    # columns.
    #
    # Sorting by the non-float columns *first* avoids false positives and
    # false negatives when two tables hold the same values but happen to be
    # in a different order. Sorting by the floating-point columns *last*
    # puts records that are close to each other in (roughly) the same order,
    # so that polars' approximate equality checks can treat them as equal
    # (or not, if they aren't actually close).
    if not grouped_sort_columns:
        return df
    return df.sort(by=grouped_sort_columns, nulls_last=nulls_last)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if sort_by:
by, descending = list(zip(*sort_by, strict=True))

Expand Down Expand Up @@ -245,17 +270,8 @@ def assert_tpch_result_equal(
details={"error": str(e)},
) from e

# We know that each dataframe is sorted on `sort_by` according to itself.
# Now we have some freedom to reorder the rows. We'll use this freedom to avoid
# any kind of sorting on floating-point columns, which introduces all sorts of
# fuzziness we don't want to deal with.
non_float_columns = [
col
for col in left.columns
if left.schema[col] not in (pl.Float32, pl.Float64)
]
left_sorted = left.sort(by=non_float_columns, nulls_last=nulls_last)
right_sorted = right.sort(by=non_float_columns, nulls_last=nulls_last)
left_sorted = sort_for_comparison(left)
right_sorted = sort_for_comparison(right)

if limit is None or left.is_empty():
try:
Expand Down Expand Up @@ -320,8 +336,8 @@ def assert_tpch_result_equal(

try:
polars.testing.assert_frame_equal(
result_first.sort(by=non_float_columns, nulls_last=nulls_last),
expected_first.sort(by=non_float_columns, nulls_last=nulls_last),
sort_for_comparison(result_first),
sort_for_comparison(expected_first),
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e:
Expand All @@ -339,12 +355,8 @@ def assert_tpch_result_equal(

try:
polars.testing.assert_frame_equal(
result_ties.sort(non_float_columns, nulls_last=nulls_last).select(
by
),
expected_ties.sort(non_float_columns, nulls_last=nulls_last).select(
by
),
sort_for_comparison(result_ties).select(by),
sort_for_comparison(expected_ties).select(by),
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e:
Expand All @@ -354,11 +366,14 @@ def assert_tpch_result_equal(
) from e

else:
# no sort_by, just a straight comparison.
left_sorted = sort_for_comparison(left)
right_sorted = sort_for_comparison(right)

# no sort_by, compare after grouped sort to ignore nondeterministic row order.
try:
polars.testing.assert_frame_equal(
left,
right,
left_sorted,
right_sorted,
Comment thread
TomAugspurger marked this conversation as resolved.
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e:
Expand Down
32 changes: 32 additions & 0 deletions python/cudf_polars/tests/testing/test_asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,3 +474,35 @@ def test_assert_tpch_result_equal_sort_keys_raises_not_sorted() -> None:
sort_keys=sort_keys,
nulls_last=True,
)


@pytest.mark.parametrize("sort_by", [[("a", True)], []])
@pytest.mark.parametrize("drop_columns", [[], ["b"], ["a", "b"]])
def test_assert_tpch_result_equal_grouped_float_sort(
    sort_by: list[tuple[str, bool]], drop_columns: list[str]
) -> None:
    """Check that float values reordered within a non-float group still compare equal.

    Regression test for https://github.com/rapidsai/cudf/issues/22129.
    Each frame has identical non-float columns ``a`` and ``b``; only the
    order (and, for the failing case, the magnitude) of float column ``c``
    differs. ``drop_columns`` removes some of the non-float columns so the
    case with no non-float columns at all is exercised too.
    """

    def build_frame(c_values: list[float]) -> pl.DataFrame:
        # Same non-float data every time; only column "c" varies per frame.
        frame = pl.DataFrame({"a": [1, 1, 1], "b": [2, 2, 2], "c": c_values})
        if drop_columns:
            frame = frame.drop(drop_columns)
        return frame

    # Column "a" cannot be a sort key once it has been dropped.
    effective_sort_by = [] if "a" in drop_columns else sort_by

    left = build_frame([1.0, 2.0, 3.0])

    # 2.999 is within abs_tol of 3.0, so these frames match once rows are
    # lined up by the float column.
    right = build_frame([1.0, 2.999, 2.0])
    assert_tpch_result_equal(
        left, right, sort_by=effective_sort_by, abs_tol=0.01, check_exact=False
    )

    # 2.90 differs from 3.0 by more than abs_tol, so this must be rejected.
    right_different = build_frame([1.0, 2.90, 2.0])
    with pytest.raises(ValidationError, match="Result mismatch"):
        assert_tpch_result_equal(
            left,
            right_different,
            sort_by=effective_sort_by,
            abs_tol=0.01,
            check_exact=False,
        )
Loading