diff --git a/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_median_where_string_filter/out.sql b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_median_where_string_filter/out.sql new file mode 100644 index 000000000000..4fe90af3cabc --- /dev/null +++ b/ibis/backends/bigquery/tests/unit/snapshots/test_compiler/test_approx_median_where_string_filter/out.sql @@ -0,0 +1,3 @@ +SELECT + APPROX_QUANTILES(IF(`t0`.`string_col` > 'a', `t0`.`float_col`, NULL), 2 IGNORE NULLS)[1] AS `ApproxMedian_float_col_Greater_string_col_'a'` +FROM `functional_alltypes` AS `t0` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/unit/test_compiler.py b/ibis/backends/bigquery/tests/unit/test_compiler.py index 7f7a3ec43875..4eb1d27e743b 100644 --- a/ibis/backends/bigquery/tests/unit/test_compiler.py +++ b/ibis/backends/bigquery/tests/unit/test_compiler.py @@ -558,6 +558,12 @@ def test_approx(alltypes, agg, where, snapshot): snapshot.assert_match(to_sql(expr), "out.sql") +def test_approx_median_where_string_filter(alltypes, snapshot): + t = alltypes + expr = t.float_col.approx_median(where=t.string_col > "a") + snapshot.assert_match(to_sql(expr), "out.sql") + + @pytest.mark.parametrize("funcname", ["bit_and", "bit_or", "bit_xor"]) @pytest.mark.parametrize( "where", diff --git a/ibis/backends/sql/compilers/bigquery/__init__.py b/ibis/backends/sql/compilers/bigquery/__init__.py index 37375a36d031..d83da12e3ed1 100644 --- a/ibis/backends/sql/compilers/bigquery/__init__.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -360,8 +360,47 @@ def visit_GeoSimplify(self, op, *, arg, tolerance, preserve_collapsed): ) return self.f.st_simplify(arg, tolerance) + def _visit_approx_quantile_helper(self, op, *, arg, where): + # BigQuery syntax is `APPROX_QUANTILES(col, resolution)` to return + # `resolution + 1` quantiles array. To handle this, we compute the + # resolution ourselves then restructure the output array as needed. + # To avoid excessive resolution we arbitrarily cap it at 100,000 - + # since these are approximate quantiles anyway this seems fine. + + quantiles = util.promote_list(op.quantile.value) + fracs = [decimal.Decimal(str(q)).as_integer_ratio() for q in quantiles] + resolution = min(math.lcm(*(den for _, den in fracs)), 100_000) + indices = [(num * resolution) // den for num, den in fracs] + + if where is not None: + arg = self.if_(where, arg, NULL) + + if not op.arg.dtype.is_floating(): + arg = self.cast(arg, dt.float64) + + array = self.f.approx_quantiles( + arg, sge.IgnoreNulls(this=sge.convert(resolution)) + ) + if isinstance(op, (ops.ApproxQuantile, ops.ApproxMedian)): + return array[indices[0]] + + if indices == list(range(resolution + 1)): + return array + else: + return sge.Array(expressions=[array[i] for i in indices]) + + def visit_ApproxQuantile(self, op, *, arg, quantile, where): + if not isinstance(op.quantile, ops.Literal): + raise com.UnsupportedOperationError( + "quantile must be a literal in BigQuery" + ) + return self._visit_approx_quantile_helper(op, arg=arg, where=where) + def visit_ApproxMedian(self, op, *, arg, where): - return self.agg.approx_quantiles(arg, 2, where=where)[self.f.offset(1)] + new_op = ops.ApproxQuantile(arg=op.arg, quantile=0.5, where=op.where) + return self._visit_approx_quantile_helper(new_op, arg=arg, where=where) + + visit_ApproxMultiQuantile = visit_ApproxQuantile def visit_Pi(self, op): return self.f.acos(-1) @@ -397,41 +436,6 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by): return sge.GroupConcat(this=arg, separator=sep) - def visit_ApproxQuantile(self, op, *, arg, quantile, where): - if not isinstance(op.quantile, ops.Literal): - raise com.UnsupportedOperationError( - "quantile must be a literal in BigQuery" - ) - - # BigQuery syntax is `APPROX_QUANTILES(col, resolution)` to return - # `resolution + 1` quantiles array. To handle this, we compute the - # resolution ourselves then restructure the output array as needed. - # To avoid excessive resolution we arbitrarily cap it at 100,000 - - # since these are approximate quantiles anyway this seems fine. - quantiles = util.promote_list(op.quantile.value) - fracs = [decimal.Decimal(str(q)).as_integer_ratio() for q in quantiles] - resolution = min(math.lcm(*(den for _, den in fracs)), 100_000) - indices = [(num * resolution) // den for num, den in fracs] - - if where is not None: - arg = self.if_(where, arg, NULL) - - if not op.arg.dtype.is_floating(): - arg = self.cast(arg, dt.float64) - - array = self.f.approx_quantiles( - arg, sge.IgnoreNulls(this=sge.convert(resolution)) - ) - if isinstance(op, ops.ApproxQuantile): - return array[indices[0]] - - if indices == list(range(resolution + 1)): - return array - else: - return sge.Array(expressions=[array[i] for i in indices]) - - visit_ApproxMultiQuantile = visit_ApproxQuantile - def visit_FloorDivide(self, op, *, left, right): return self.cast(self.f.floor(self.f.ieee_divide(left, right)), op.dtype) diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index dab220909ab1..ca8c5609a898 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -1183,6 +1183,7 @@ def test_corr_cov( ) +@pytest.mark.parametrize("filtered", [False, True]) @pytest.mark.notimpl( ["mysql", "singlestoredb", "sqlite", "mssql", "druid"], raises=com.OperationNotDefinedError, @@ -1194,8 +1195,9 @@ def test_corr_cov( # Ref: https://materialize.com/docs/transform-data/patterns/percentiles/ raises=com.OperationNotDefinedError, ) -def test_approx_median(alltypes): - expr = alltypes.double_col.approx_median() +def test_approx_median(alltypes, filtered): + where = alltypes.int_col <= 100 if filtered else None + expr = alltypes.double_col.approx_median(where=where) result = expr.execute() assert isinstance(result, float)