Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 46 additions & 18 deletions python/cudf_polars/cudf_polars/experimental/benchmarks/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,15 +250,16 @@ def assert_tpch_result_equal(
details={"error": str(e)},
) from e

# We know that each dataframe is sorted on `sort_by` according to itself.
# Now we have some freedom to reorder the rows. We'll use this freedom to avoid
# any kind of sorting on floating-point columns, which introduces all sorts of
# fuzziness we don't want to deal with.
non_float_columns = [
col
for col in left.columns
if left.schema[col] not in (pl.Float32, pl.Float64)
]
float_columns = [
col for col in left.columns if left.schema[col] in (pl.Float32, pl.Float64)
]
# If sorting on the non-float columns alone leaves ties, the comparison is retried with the float columns appended as tie-breakers.
all_sort_columns = non_float_columns + float_columns
left_sorted = left.sort(by=non_float_columns, nulls_last=nulls_last)
right_sorted = right.sort(by=non_float_columns, nulls_last=nulls_last)

Expand All @@ -269,10 +270,17 @@ def assert_tpch_result_equal(
right_sorted,
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e:
raise ValidationError(
message="Result mismatch", details={"error": str(e)}
) from e
except AssertionError:
try:
polars.testing.assert_frame_equal(
left.sort(by=all_sort_columns, nulls_last=nulls_last),
right.sort(by=all_sort_columns, nulls_last=nulls_last),
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e2:
raise ValidationError(
message="Result mismatch", details={"error": str(e2)}
) from e2

else:
# Handle the .sort_by(...).head(n) case; First, split the data into two parts
Expand Down Expand Up @@ -329,11 +337,19 @@ def assert_tpch_result_equal(
expected_first.sort(by=non_float_columns, nulls_last=nulls_last),
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e:
raise ValidationError(
message="Result mismatch in non-ties part",
details={"error": str(e)},
) from e
except AssertionError:
# Non-float sort left ambiguous ties; retry with float columns as secondary key
try:
polars.testing.assert_frame_equal(
result_first.sort(by=all_sort_columns, nulls_last=nulls_last),
expected_first.sort(by=all_sort_columns, nulls_last=nulls_last),
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e2:
raise ValidationError(
message="Result mismatch in non-ties part",
details={"error": str(e2)},
) from e2

# We already know that the lengths match (we've validated that the
# *total* lengths match and the non-ties lengths match, so this rump
Expand All @@ -352,11 +368,23 @@ def assert_tpch_result_equal(
),
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e:
raise ValidationError(
message="Result mismatch in ties part",
details={"error": str(e)},
) from e
except AssertionError:
# Non-float sort left ambiguous ties; retry with float columns as secondary key
try:
polars.testing.assert_frame_equal(
result_ties.sort(
all_sort_columns, nulls_last=nulls_last
).select(by),
expected_ties.sort(
all_sort_columns, nulls_last=nulls_last
).select(by),
**polars_kwargs, # type: ignore[arg-type]
)
except AssertionError as e2:
raise ValidationError(
message="Result mismatch in ties part",
details={"error": str(e2)},
) from e2

else:
# no sort_by, just a straight comparison.
Expand Down
Loading