diff --git a/python/cudf_polars/cudf_polars/experimental/benchmarks/asserts.py b/python/cudf_polars/cudf_polars/experimental/benchmarks/asserts.py index e8c80e480cd..34eda03f20e 100644 --- a/python/cudf_polars/cudf_polars/experimental/benchmarks/asserts.py +++ b/python/cudf_polars/cudf_polars/experimental/benchmarks/asserts.py @@ -250,15 +250,16 @@ def assert_tpch_result_equal( details={"error": str(e)}, ) from e - # We know that each dataframe is sorted on `sort_by` according to itself. - # Now we have some freedom to reorder the rows. We'll use this freedom to avoid - # any kind of sorting on floating-point columns, which introduces all sorts of - # fuzziness we don't want to deal with. non_float_columns = [ col for col in left.columns if left.schema[col] not in (pl.Float32, pl.Float64) ] + float_columns = [ + col for col in left.columns if left.schema[col] in (pl.Float32, pl.Float64) + ] + # if non-float sort leaves ambiguity, use float as tie-breaker + all_sort_columns = non_float_columns + float_columns left_sorted = left.sort(by=non_float_columns, nulls_last=nulls_last) right_sorted = right.sort(by=non_float_columns, nulls_last=nulls_last) @@ -269,10 +270,17 @@ def assert_tpch_result_equal( right_sorted, **polars_kwargs, # type: ignore[arg-type] ) - except AssertionError as e: - raise ValidationError( - message="Result mismatch", details={"error": str(e)} - ) from e + except AssertionError: + try: + polars.testing.assert_frame_equal( + left.sort(by=all_sort_columns, nulls_last=nulls_last), + right.sort(by=all_sort_columns, nulls_last=nulls_last), + **polars_kwargs, # type: ignore[arg-type] + ) + except AssertionError as e2: + raise ValidationError( + message="Result mismatch", details={"error": str(e2)} + ) from e2 else: # Handle the .sort_by(...).head(n) case; First, split the data into two parts @@ -329,11 +337,19 @@ def assert_tpch_result_equal( expected_first.sort(by=non_float_columns, nulls_last=nulls_last), **polars_kwargs, # type: ignore[arg-type] ) - except AssertionError as e: - raise ValidationError( - message="Result mismatch in non-ties part", - details={"error": str(e)}, - ) from e + except AssertionError: + # Non-float sort left ambiguous ties; retry with float columns as secondary key + try: + polars.testing.assert_frame_equal( + result_first.sort(by=all_sort_columns, nulls_last=nulls_last), + expected_first.sort(by=all_sort_columns, nulls_last=nulls_last), + **polars_kwargs, # type: ignore[arg-type] + ) + except AssertionError as e2: + raise ValidationError( + message="Result mismatch in non-ties part", + details={"error": str(e2)}, + ) from e2 # We already know that the lengths match (we've validated that the # *total* lengths match and the non-ties lengths match, so this rump @@ -352,11 +368,23 @@ def assert_tpch_result_equal( ), **polars_kwargs, # type: ignore[arg-type] ) - except AssertionError as e: - raise ValidationError( - message="Result mismatch in ties part", - details={"error": str(e)}, - ) from e + except AssertionError: + # Non-float sort left ambiguous ties; retry with float columns + try: + polars.testing.assert_frame_equal( + result_ties.sort( + all_sort_columns, nulls_last=nulls_last + ).select(by), + expected_ties.sort( + all_sort_columns, nulls_last=nulls_last + ).select(by), + **polars_kwargs, # type: ignore[arg-type] + ) + except AssertionError as e2: + raise ValidationError( + message="Result mismatch in ties part", + details={"error": str(e2)}, + ) from e2 else: # no sort_by, just a straight comparison.