[CALCITE-7620] Result of FILTER clause in window functions is incorrect #5040
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2504,6 +2504,11 @@ private RexNode convertOver(Blackboard bb, SqlNode node) { | |
| SqlCall call = (SqlCall) node; | ||
| bb.getValidator().deriveType(bb.scope, call); | ||
| SqlCall aggCall = call.operand(0); | ||
| @Nullable SqlNode filter = null; | ||
|
Member
I rarely see the need for @Nullable in code.
Member
Author
This is to satisfy the |
||
| if (aggCall.getKind() == SqlKind.FILTER) { | ||
| filter = aggCall.operand(1); | ||
| aggCall = aggCall.operand(0); | ||
| } | ||
| boolean ignoreNulls = false; | ||
| switch (aggCall.getKind()) { | ||
| case IGNORE_NULLS: | ||
|
|
@@ -2515,6 +2520,22 @@ private RexNode convertOver(Blackboard bb, SqlNode node) { | |
| default: | ||
| break; | ||
| } | ||
| if (filter != null) { | ||
| final SqlOperator op = aggCall.getOperator(); | ||
| if (op instanceof SqlAggFunction | ||
| && !((SqlAggFunction) op).requiresOver()) { | ||
| // FILTER on a windowed aggregate can be implemented by wrapping the | ||
| // aggregate arguments in CASE expressions, because true aggregates | ||
| // ignore NULL inputs. This does not work for window value functions | ||
| // (FIRST_VALUE, LAST_VALUE, NTH_VALUE, LEAD, LAG, etc.) which do not | ||
| // ignore NULL inputs. | ||
| aggCall = applyFilterToAggArgs(aggCall, filter); | ||
| bb.getValidator().deriveType(bb.scope, aggCall); | ||
| } else { | ||
| throw new UnsupportedOperationException( | ||
| "FILTER clause is not supported for window function " + op.getName()); | ||
| } | ||
| } | ||
|
|
||
| SqlNode windowOrRef = call.operand(1); | ||
| final SqlWindow window = | ||
|
|
@@ -2609,6 +2630,47 @@ private RexNode convertOver(Blackboard bb, SqlNode node) { | |
| } | ||
| } | ||
|
|
||
| /** | ||
| * Applies a FILTER clause to the arguments of an aggregate call by wrapping | ||
| * each argument in a CASE expression. For example, | ||
| * {@code SUM(sal) FILTER (WHERE comm IS NOT NULL)} becomes | ||
| * {@code SUM(CASE WHEN comm IS NOT NULL THEN sal END)}. | ||
|
Member
I'm unsure if the semantics of functions like FIRST_VALUE, LAST_VALUE, NTH_VALUE, LEAD, and LAG are correct after the rewrite.
Member
Author
Your concern is valid. The rewrite is only semantically correct for true aggregate functions like SUM, COUNT, AVG, MIN, and MAX, because those functions ignore NULL inputs. It does not hold for window value functions such as FIRST_VALUE, LAST_VALUE, NTH_VALUE, LEAD, and LAG, which do not ignore NULL inputs; PostgreSQL, for example, raises an error when FILTER is applied to such a function. This change therefore rejects FILTER on those functions rather than rewriting them, and based on the tests conducted so far, I believe there are no issues.
Member
Could you add a related test?
Member
Author
Done |
||
| * | ||
| * <p>This transformation preserves the semantics of the FILTER clause for | ||
| * windowed aggregates: rows that do not satisfy the filter contribute NULL | ||
| * and are ignored by the aggregate function. | ||
| */ | ||
| private static SqlCall applyFilterToAggArgs(SqlCall aggCall, SqlNode filter) { | ||
| final SqlOperator op = aggCall.getOperator(); | ||
| final List<SqlNode> operands = aggCall.getOperandList(); | ||
| final SqlParserPos pos = aggCall.getParserPosition(); | ||
| final SqlLiteral quantifier = aggCall.getFunctionQuantifier(); | ||
| final List<SqlNode> newOperands = new ArrayList<>(operands.size()); | ||
| if (op == SqlStdOperatorTable.COUNT | ||
| && operands.size() == 1 | ||
| && operands.get(0) instanceof SqlIdentifier | ||
| && ((SqlIdentifier) operands.get(0)).isStar()) { | ||
| // COUNT(*) FILTER (WHERE x) => COUNT(CASE WHEN x THEN 0 END) | ||
| newOperands.add( | ||
| new SqlCase(pos, null, SqlNodeList.of(filter), | ||
| SqlNodeList.of(SqlLiteral.createExactNumeric("0", pos)), | ||
| SqlLiteral.createNull(pos))); | ||
| } else { | ||
| for (SqlNode operand : operands) { | ||
| if (operand instanceof SqlIdentifier | ||
| && ((SqlIdentifier) operand).isStar()) { | ||
| newOperands.add(operand); | ||
| } else { | ||
| newOperands.add( | ||
| new SqlCase(pos, null, SqlNodeList.of(filter), | ||
| SqlNodeList.of(operand), | ||
| SqlLiteral.createNull(pos))); | ||
| } | ||
| } | ||
| } | ||
| return op.createCall(quantifier, pos, newOperands); | ||
| } | ||
|
|
||
| protected void convertFrom( | ||
| Blackboard bb, | ||
| @Nullable SqlNode from) { | ||
|
|
||
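The CASE rewrite in `applyFilterToAggArgs`, and the reason the review thread restricts it to true aggregates, can be modelled without Calcite. The following is a minimal plain-Java sketch, not the converter's code: the class `FilterRewriteSketch`, the `Emp` row type, and every method name are hypothetical. It treats `SUM` as an aggregate that skips NULL inputs and `FIRST_VALUE` as a function that returns its first input whether or not it is NULL, and `firstPassing` is only an assumed reading of what a row-dropping FILTER would return.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Predicate;

/** Plain-Java model of the CASE rewrite; none of these names are Calcite APIs. */
class FilterRewriteSketch {
  /** Hypothetical row: sal is the aggregated column, comm drives the filter. */
  static final class Emp {
    final Integer sal;
    final Integer comm;

    Emp(Integer sal, Integer comm) {
      this.sal = sal;
      this.comm = comm;
    }
  }

  /** Models CASE WHEN filter THEN sal END: failing rows contribute NULL. */
  static List<Integer> caseWrap(List<Emp> rows, Predicate<Emp> filter) {
    List<Integer> out = new ArrayList<>();
    for (Emp e : rows) {
      out.add(filter.test(e) ? e.sal : null);
    }
    return out;
  }

  /** Models SUM: NULL inputs are skipped; all-NULL input yields NULL. */
  static Integer sum(List<Integer> values) {
    Integer total = null;
    for (Integer v : values) {
      if (v != null) {
        total = (total == null ? 0 : total) + v;
      }
    }
    return total;
  }

  /** Reference behaviour of SUM(sal) FILTER (WHERE ...): drop rows, then sum. */
  static Integer sumWithFilter(List<Emp> rows, Predicate<Emp> filter) {
    List<Integer> kept = new ArrayList<>();
    for (Emp e : rows) {
      if (filter.test(e)) {
        kept.add(e.sal);
      }
    }
    return sum(kept);
  }

  /** Models FIRST_VALUE without IGNORE NULLS: the first input, NULL or not. */
  static Integer firstValue(List<Integer> values) {
    return values.isEmpty() ? null : values.get(0);
  }

  /** Assumed row-dropping reading of a filter: sal of the first passing row. */
  static Integer firstPassing(List<Emp> rows, Predicate<Emp> filter) {
    for (Emp e : rows) {
      if (filter.test(e)) {
        return e.sal;
      }
    }
    return null;
  }

  public static void main(String[] args) {
    List<Emp> rows = Arrays.asList(
        new Emp(800, null), new Emp(1600, 300), new Emp(1250, 500));
    Predicate<Emp> commNotNull = e -> e.comm != null;
    // For SUM the rewritten form agrees with the row-dropping form.
    System.out.println(sumWithFilter(rows, commNotNull));
    System.out.println(sum(caseWrap(rows, commNotNull)));
    // For FIRST_VALUE the two forms disagree when the first row is filtered out.
    System.out.println(firstPassing(rows, commNotNull));
    System.out.println(firstValue(caseWrap(rows, commNotNull)));
  }
}
```

In this model the two `SUM` forms agree, while the `FIRST_VALUE` forms diverge as soon as the first row of the frame fails the filter, which is consistent with the PR throwing `UnsupportedOperationException` for operators whose `requiresOver()` is true.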
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2430,11 +2430,12 @@ EnumerableCalc(expr#0..5=[{inputs}], expr#6=[RAND()], expr#7=[CAST($t6):INTEGER | |
| EnumerableSort(sort0=[$2], dir0=[ASC]) | ||
| EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], SAL=[$t5], DEPTNO=[$t7]) | ||
| EnumerableTableScan(table=[[scott, EMP]]) | ||
| EnumerableSort(sort0=[$1], dir0=[ASC]) | ||
| EnumerableCalc(expr#0..1=[{inputs}], expr#2=[false], expr#3=[1], expr#4=[<=($t1, $t3)], cs=[$t2], DEPTNO=[$t0], rn=[$t1], $condition=[$t4]) | ||
| EnumerableWindow(window#0=[window(partition {0} rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])], constants=[[false]]) | ||
| EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0]) | ||
| EnumerableTableScan(table=[[scott, DEPT]]) | ||
| EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[<=($t2, $t3)], proj#0..2=[{exprs}], $condition=[$t4]) | ||
| EnumerableSort(sort0=[$1], dir0=[ASC]) | ||
| EnumerableCalc(expr#0..1=[{inputs}], expr#2=[false], cs=[$t2], DEPTNO=[$t0], rn=[$t1]) | ||
| EnumerableWindow(window#0=[window(partition {0} rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])], constants=[[false]]) | ||
| EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0]) | ||
| EnumerableTableScan(table=[[scott, DEPT]]) | ||
| !plan | ||
| !} | ||
|
|
||
|
|
@@ -2540,11 +2541,12 @@ EnumerableCalc(expr#0..5=[{inputs}], expr#6=[NOT($t3)], expr#7=[IS NOT NULL($t3) | |
| EnumerableSort(sort0=[$2], dir0=[ASC]) | ||
| EnumerableCalc(expr#0..7=[{inputs}], EMPNO=[$t0], SAL=[$t5], DEPTNO=[$t7]) | ||
| EnumerableTableScan(table=[[scott, EMP]]) | ||
| EnumerableSort(sort0=[$1], dir0=[ASC]) | ||
| EnumerableCalc(expr#0..1=[{inputs}], expr#2=[false], expr#3=[1], expr#4=[<=($t1, $t3)], cs=[$t2], DEPTNO=[$t0], rn=[$t1], $condition=[$t4]) | ||
| EnumerableWindow(window#0=[window(partition {0} rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])], constants=[[false]]) | ||
| EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0]) | ||
| EnumerableTableScan(table=[[scott, DEPT]]) | ||
| EnumerableCalc(expr#0..2=[{inputs}], expr#3=[1], expr#4=[<=($t2, $t3)], proj#0..2=[{exprs}], $condition=[$t4]) | ||
|
Contributor
Curious about the motivation behind this update?
Member
Author
This fix altered the sorting metadata derivation for the window operator, resulting in a legitimate change to the execution plan structure of two existing !plan tests, but the computation results remain unchanged. Specific changes: Previously, the collation derivation for After the fix, Both are logically equivalent; only the relative positions of |
||
| EnumerableSort(sort0=[$1], dir0=[ASC]) | ||
| EnumerableCalc(expr#0..1=[{inputs}], expr#2=[false], cs=[$t2], DEPTNO=[$t0], rn=[$t1]) | ||
| EnumerableWindow(window#0=[window(partition {0} rows between UNBOUNDED PRECEDING and CURRENT ROW aggs [ROW_NUMBER()])], constants=[[false]]) | ||
| EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0]) | ||
| EnumerableTableScan(table=[[scott, DEPT]]) | ||
| !plan | ||
| !} | ||
|
|
||
|
|
||
The reason for modifying `RelMdCollation.window` is that the original window collation derivation was too optimistic: it led the optimizer to believe that the window output retained the input order, and so to delete the top-level Sort by mistake.

The original implementation had the following problem:

Previously, `RelMdCollation.window` directly returned `mq.collations(input)`, meaning "the window operator will preserve the order of the input rows as is." However, the actual implementation of `EnumerableWindow` first groups the rows by the `PARTITION BY` key using `SortedMultiMap`, and then sorts them within each group by the window `ORDER BY` key. Therefore, the input order is not preserved; the global output order is `PARTITION BY` keys + `ORDER BY` keys, not simply the input order.

This caused the top-level `Sort` to be incorrectly optimized away.

When `order by empno` is written in the SQL, if the window also happens to be sorted by `empno`, the optimizer will mistakenly assume that the window output is globally ordered, thus deleting the top-level `EnumerableSort`. The resulting output is grouped by `deptno`, not sorted by `empno`.
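The ordering effect described above can be reproduced in a small simulation. This is a sketch under stated assumptions, not Calcite code: `WindowOrderSketch`, `Row`, and `windowOutputEmpnos` are hypothetical names, and a `TreeMap` of lists stands in for the partition grouping that the comment attributes to `SortedMultiMap`. The input is already sorted by `empno`; the output is emitted partition by partition, each partition sorted by `empno`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Simulation of partition-then-sort output order; not Calcite's implementation. */
class WindowOrderSketch {
  /** Hypothetical input row with the two columns that matter here. */
  static final class Row {
    final int empno;
    final int deptno;

    Row(int empno, int deptno) {
      this.empno = empno;
      this.deptno = deptno;
    }
  }

  /**
   * Groups rows by deptno (the PARTITION BY key), sorts each group by empno
   * (the window ORDER BY key), and returns the empno values in emission order.
   */
  static List<Integer> windowOutputEmpnos(List<Row> input) {
    // TreeMap is an assumption standing in for the real grouping structure.
    Map<Integer, List<Row>> partitions = new TreeMap<>();
    for (Row r : input) {
      partitions.computeIfAbsent(r.deptno, k -> new ArrayList<>()).add(r);
    }
    List<Integer> out = new ArrayList<>();
    for (List<Row> partition : partitions.values()) {
      partition.sort(Comparator.comparingInt((Row r) -> r.empno));
      for (Row r : partition) {
        out.add(r.empno);
      }
    }
    return out;
  }

  public static void main(String[] args) {
    // Input arrives sorted by empno, with departments interleaved.
    List<Row> input = Arrays.asList(
        new Row(1, 20), new Row(2, 10), new Row(3, 20), new Row(4, 10));
    System.out.println(windowOutputEmpnos(input));
  }
}
```

With departments interleaved in the input, the emitted `empno` sequence is ordered within each `deptno` group but not globally, so a collation derived from the input alone would wrongly justify removing a top-level sort on `empno`.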