Skip to content

Commit 3368f33

Browse files
committed
fix(bigquery): ApproxMedian with where clause generates invalid call to ApproxQuantile
1 parent d603320 commit 3368f33

4 files changed

Lines changed: 53 additions & 38 deletions

File tree

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
SELECT
2+
APPROX_QUANTILES(IF(`t0`.`string_col` > 'a', `t0`.`float_col`, NULL), 2 IGNORE NULLS)[1] AS `ApproxMedian_float_col_Greater_string_col_'a'`
3+
FROM `functional_alltypes` AS `t0`

ibis/backends/bigquery/tests/unit/test_compiler.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,6 +558,12 @@ def test_approx(alltypes, agg, where, snapshot):
558558
snapshot.assert_match(to_sql(expr), "out.sql")
559559

560560

561+
def test_approx_median_where_string_filter(alltypes, snapshot):
562+
t = alltypes
563+
expr = t.float_col.approx_median(where=t.string_col > "a")
564+
snapshot.assert_match(to_sql(expr), "out.sql")
565+
566+
561567
@pytest.mark.parametrize("funcname", ["bit_and", "bit_or", "bit_xor"])
562568
@pytest.mark.parametrize(
563569
"where",

ibis/backends/sql/compilers/bigquery/__init__.py

Lines changed: 40 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,47 @@ def visit_GeoSimplify(self, op, *, arg, tolerance, preserve_collapsed):
360360
)
361361
return self.f.st_simplify(arg, tolerance)
362362

363+
def _visit_approx_quantile_helper(self, op, *, arg, where):
364+
# BigQuery syntax is `APPROX_QUANTILES(col, resolution)`, which returns
365+
# an array of `resolution + 1` quantiles. To handle this, we compute the
366+
# resolution ourselves then restructure the output array as needed.
367+
# To avoid excessive resolution we arbitrarily cap it at 100,000 -
368+
# since these are approximate quantiles anyway this seems fine.
369+
370+
quantiles = util.promote_list(op.quantile.value)
371+
fracs = [decimal.Decimal(str(q)).as_integer_ratio() for q in quantiles]
372+
resolution = min(math.lcm(*(den for _, den in fracs)), 100_000)
373+
indices = [(num * resolution) // den for num, den in fracs]
374+
375+
if where is not None:
376+
arg = self.if_(where, arg, NULL)
377+
378+
if not op.arg.dtype.is_floating():
379+
arg = self.cast(arg, dt.float64)
380+
381+
array = self.f.approx_quantiles(
382+
arg, sge.IgnoreNulls(this=sge.convert(resolution))
383+
)
384+
if isinstance(op, (ops.ApproxQuantile, ops.ApproxMedian)):
385+
return array[indices[0]]
386+
387+
if indices == list(range(resolution + 1)):
388+
return array
389+
else:
390+
return sge.Array(expressions=[array[i] for i in indices])
391+
392+
def visit_ApproxQuantile(self, op, *, arg, quantile, where):
393+
if not isinstance(op.quantile, ops.Literal):
394+
raise com.UnsupportedOperationError(
395+
"quantile must be a literal in BigQuery"
396+
)
397+
return self._visit_approx_quantile_helper(op, arg=arg, where=where)
398+
363399
def visit_ApproxMedian(self, op, *, arg, where):
364-
return self.agg.approx_quantiles(arg, 2, where=where)[self.f.offset(1)]
400+
new_op = ops.ApproxQuantile(arg=op.arg, quantile=0.5, where=op.where)
401+
return self._visit_approx_quantile_helper(new_op, arg=arg, where=where)
402+
403+
visit_ApproxMultiQuantile = visit_ApproxQuantile
365404

366405
def visit_Pi(self, op):
367406
return self.f.acos(-1)
@@ -397,41 +436,6 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):
397436

398437
return sge.GroupConcat(this=arg, separator=sep)
399438

400-
def visit_ApproxQuantile(self, op, *, arg, quantile, where):
401-
if not isinstance(op.quantile, ops.Literal):
402-
raise com.UnsupportedOperationError(
403-
"quantile must be a literal in BigQuery"
404-
)
405-
406-
# BigQuery syntax is `APPROX_QUANTILES(col, resolution)` to return
407-
# `resolution + 1` quantiles array. To handle this, we compute the
408-
# resolution ourselves then restructure the output array as needed.
409-
# To avoid excessive resolution we arbitrarily cap it at 100,000 -
410-
# since these are approximate quantiles anyway this seems fine.
411-
quantiles = util.promote_list(op.quantile.value)
412-
fracs = [decimal.Decimal(str(q)).as_integer_ratio() for q in quantiles]
413-
resolution = min(math.lcm(*(den for _, den in fracs)), 100_000)
414-
indices = [(num * resolution) // den for num, den in fracs]
415-
416-
if where is not None:
417-
arg = self.if_(where, arg, NULL)
418-
419-
if not op.arg.dtype.is_floating():
420-
arg = self.cast(arg, dt.float64)
421-
422-
array = self.f.approx_quantiles(
423-
arg, sge.IgnoreNulls(this=sge.convert(resolution))
424-
)
425-
if isinstance(op, ops.ApproxQuantile):
426-
return array[indices[0]]
427-
428-
if indices == list(range(resolution + 1)):
429-
return array
430-
else:
431-
return sge.Array(expressions=[array[i] for i in indices])
432-
433-
visit_ApproxMultiQuantile = visit_ApproxQuantile
434-
435439
def visit_FloorDivide(self, op, *, left, right):
436440
return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype)
437441

ibis/backends/tests/test_aggregation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,6 +1183,7 @@ def test_corr_cov(
11831183
)
11841184

11851185

1186+
@pytest.mark.parametrize("filtered", [False, True])
11861187
@pytest.mark.notimpl(
11871188
["mysql", "singlestoredb", "sqlite", "mssql", "druid"],
11881189
raises=com.OperationNotDefinedError,
@@ -1194,8 +1195,9 @@ def test_corr_cov(
11941195
# Ref: https://materialize.com/docs/transform-data/patterns/percentiles/
11951196
raises=com.OperationNotDefinedError,
11961197
)
1197-
def test_approx_median(alltypes):
1198-
expr = alltypes.double_col.approx_median()
1198+
def test_approx_median(alltypes, filtered):
1199+
where = alltypes.int_col <= 100 if filtered else None
1200+
expr = alltypes.double_col.approx_median(where=where)
11991201
result = expr.execute()
12001202
assert isinstance(result, float)
12011203

0 commit comments

Comments
 (0)