diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java index e46d2729899cdf..27995afa656b51 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java @@ -222,8 +222,12 @@ private static CascadesContext newContext(Optional parent, Opti StatementContext statementContext, Plan initPlan, CTEContext cteContext, PhysicalProperties requireProperties, boolean isLeadingDisableJoinReorder, CTEContext recursiveCteContext) { - return new CascadesContext(parent, subtree, statementContext, initPlan, null, + CascadesContext cascadesContext = new CascadesContext(parent, subtree, statementContext, initPlan, null, cteContext, requireProperties, isLeadingDisableJoinReorder, recursiveCteContext); + if (parent.isPresent() && parent.get().getOuterScope().isPresent()) { + cascadesContext.setOuterScope(parent.get().getOuterScope().get()); + } + return cascadesContext; } public CascadesContext getRoot() { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index d078f4c3fe02fb..0dd75a645ddaf2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -466,6 +466,8 @@ public class Rewriter extends AbstractBatchJobExecutor { ), // query rewrite support window, so add this rule here custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION, AggScalarSubQueryToWindowFunction::new), + custom(RuleType.PULL_UP_CTE_ANCHOR, PullUpCteAnchor::new), + custom(RuleType.CTE_INLINE, CTEInline::new), bottomUp( new EliminateUselessPlanUnderApply(), // CorrelateApplyToUnCorrelateApply and ApplyToJoin diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubExprAnalyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubExprAnalyzer.java index 3db749ab5ed785..bf0ea28c4e89f3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubExprAnalyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/SubExprAnalyzer.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.rules.analysis; import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.analyzer.Scope; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Alias; @@ -33,6 +34,8 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; @@ -234,8 +237,16 @@ private AnalyzedResult analyzeSubquery(SubqueryExpr expr) { getScope().getSlots(), getScope().getAsteriskSlots()); subqueryContext.setOuterScope(subqueryScope); subqueryContext.newAnalyzer().analyze(); - return new AnalyzedResult((LogicalPlan) subqueryContext.getRewritePlan(), + StatementContext statementContext = cascadesContext.getStatementContext(); + subqueryContext.getRewritePlan().collect(LogicalCTEAnchor.class::isInstance).stream() + .forEach(cteAnchor -> statementContext.addToMustLineCTEs(((LogicalCTEAnchor) cteAnchor).getCteId())); + AnalyzedResult analyzedResult = new AnalyzedResult((LogicalPlan) subqueryContext.getRewritePlan(), subqueryScope.getCorrelatedSlots()); + if (analyzedResult.hasCorrelatedSlotsUnderCteProducer()) { + throw new AnalysisException( + "Unsupported correlated subquery in cte " + analyzedResult.getLogicalPlan()); + } + return analyzedResult; } public Scope getScope() { @@ -273,6 +284,12 @@ public boolean hasCorrelatedSlotsUnderAgg() { ImmutableSet.copyOf(correlatedSlots), LogicalAggregate.class); } + public boolean hasCorrelatedSlotsUnderCteProducer() { + return correlatedSlots.isEmpty() ? false + : hasCorrelatedSlotsUnderNode(logicalPlan, + ImmutableSet.copyOf(correlatedSlots), LogicalCTEProducer.class); + } + private static boolean hasCorrelatedSlotsUnderNode(Plan rootPlan, ImmutableSet slots, Class clazz) { ArrayDeque planQueue = new ArrayDeque<>(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CTEInline.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CTEInline.java index 9983c1062da58a..a15b9f02fc1202 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CTEInline.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CTEInline.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.trees.copier.DeepCopierContext; import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier; @@ -86,6 +87,10 @@ public Plan visitLogicalCTEAnchor(LogicalCTEAnchor) cteAnchor.left())) { + throw new AnalysisException( + "Inline CTE required; failed due to containing nondeterministic functions."); + } // should inline Plan root = cteAnchor.right().accept(this, (LogicalCTEProducer) cteAnchor.left()); // process child diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java index 95db4890049dd9..d24f61f1c46cef 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/AnalyzeCTETest.java @@ -31,7 +31,10 @@ import org.apache.doris.nereids.rules.rewrite.UnCorrelatedApplyFilter; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.commands.ExplainCommand; +import org.apache.doris.nereids.trees.plans.logical.LogicalApply; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan; import org.apache.doris.nereids.util.MemoPatternMatchSupported; @@ -45,6 +48,7 @@ import org.junit.jupiter.api.Test; import java.util.List; +import java.util.Set; public class AnalyzeCTETest extends TestWithFeService implements MemoPatternMatchSupported { @@ -87,6 +91,23 @@ public class AnalyzeCTETest extends TestWithFeService implements MemoPatternMatc private final String cteWithDiffRelationId = "with s as (select * from supplier) select * from s as s1, s as s2"; + private final List correlatedSubqueryWithCteSqls = ImmutableList.of( + "SELECT * FROM supplier outer_s WHERE EXISTS (" + + "WITH picked AS (SELECT s_suppkey, s_region FROM supplier) " + + "SELECT 1 FROM picked WHERE picked.s_region = outer_s.s_region)", + "SELECT * FROM supplier outer_s WHERE outer_s.s_suppkey IN (" + + "WITH picked AS (SELECT s_suppkey, s_region FROM supplier) " + + "SELECT picked.s_suppkey FROM picked WHERE picked.s_region = outer_s.s_region)" + ); + + private final String scalarSubqueryWithCteSql = "SELECT * FROM supplier outer_s WHERE outer_s.s_suppkey = (" + + "WITH picked AS (SELECT s_suppkey FROM supplier WHERE s_nation = 'PERU') " + + "SELECT min(picked.s_suppkey) FROM picked)"; + + private final String correlatedSlotUnderCteProducerSql = "SELECT * FROM supplier outer_s WHERE EXISTS (" + + "WITH picked AS (SELECT s_suppkey FROM supplier WHERE s_region = outer_s.s_region) " + + "SELECT 1 FROM picked)"; + private final List testSqls = ImmutableList.of( multiCte, cteWithColumnAlias, cteConsumerInSubQuery, cteConsumerJoin, cteReferToAnotherOne, cteJoinSelf, cteNested, cteInTheMiddle, cteWithDiffRelationId @@ -174,6 +195,40 @@ public void testCTEWithAlias() { ); } + @Test + public void testCorrelatedSubqueryWithCte() throws Exception { + for (String sql : correlatedSubqueryWithCteSqls) { + StatementScopeIdGenerator.clear(); + Plan plan = PlanChecker.from(connectContext) + .analyze(sql) + .getPlan(); + Set applyNodes = plan.collect(LogicalApply.class::isInstance); + + Assertions.assertEquals(1, applyNodes.size(), sql); + LogicalApply apply = (LogicalApply) applyNodes.iterator().next(); + Assertions.assertTrue(apply.isCorrelated(), sql); + Assertions.assertTrue(apply.child(1).anyMatch(LogicalCTEAnchor.class::isInstance), sql); + } + } + + @Test + public void testScalarSubqueryWithCte() throws Exception { + Plan plan = PlanChecker.from(connectContext) + .analyze(scalarSubqueryWithCteSql) + .getPlan(); + Set cteAnchors = plan.collect(LogicalCTEAnchor.class::isInstance); + + Assertions.assertFalse(cteAnchors.isEmpty()); + } + + @Test + public void testCorrelatedSlotUnderCteProducerInSubquery() { + AnalysisException exception = Assertions.assertThrows(AnalysisException.class, + () -> PlanChecker.from(connectContext).analyze(correlatedSlotUnderCteProducerSql), + "Not throw expected exception."); + Assertions.assertTrue(exception.getMessage().contains("Unsupported correlated subquery in cte")); + } + @Test public void testCTEWithAnExistedTableOrViewName() { PlanChecker.from(connectContext) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CTEInlineTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CTEInlineTest.java index e9767ef524bd75..61ceb31ad04fd7 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CTEInlineTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CTEInlineTest.java @@ -19,6 +19,7 @@ import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.StatementContext; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.trees.plans.commands.ExplainCommand; @@ -29,6 +30,7 @@ import org.apache.doris.qe.OriginStatement; import org.apache.doris.utframe.TestWithFeService; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; public class CTEInlineTest extends TestWithFeService implements MemoPatternMatchSupported { @@ -81,4 +83,40 @@ public void recCteInline() { ).when(cte -> cte.getCteName().equals("yy")) ); } + + @Test + public void recCteInlineRequireDeterministicProducer() { + String sql = new StringBuilder().append("with recursive t1 as (\n").append(" select\n") + .append(" 1 as c1,\n").append(" 1 as c2\n").append("),\n").append("t2 as (\n") + .append(" select\n").append(" 2 as c1,\n").append(" 2 as c2\n").append("),\n") + .append("t3 as (\n").append(" select\n").append(" 3 as c1,\n").append(" 3 as c2\n") + .append("),\n").append("xx as (\n").append(" select\n").append(" c1,\n") + .append(" c2\n").append(" from\n").append(" t1\n").append(" union\n") + .append(" select\n").append(" t2.c1,\n").append(" t2.c2\n").append(" from\n") + .append(" t2,\n").append(" xx\n").append(" where\n") + .append(" t2.c1 = xx.c1\n").append(" and rand() > 0\n") + .append("),\n").append("yy as (\n").append(" select\n").append(" c1,\n") + .append(" c2\n").append(" from\n").append(" t3\n").append(" union\n") + .append(" select\n").append(" t3.c1,\n").append(" t3.c2\n").append(" from\n") + .append(" t3,\n").append(" yy,\n").append(" xx\n").append(" where\n") + .append(" t3.c1 = yy.c1\n").append(" and t3.c2 = xx.c1\n").append(")\n") + .append("select\n").append(" *\n").append("from\n").append(" yy y1,\n") + .append(" yy y2;") + .toString(); + LogicalPlan unboundPlan = new NereidsParser().parseSingle(sql); + StatementContext statementContext = new StatementContext(connectContext, + new OriginStatement(sql, 0)); + NereidsPlanner planner = new NereidsPlanner(statementContext); + boolean originalNotEvalNondeterministicFunction = connectContext.notEvalNondeterministicFunction(); + connectContext.setNotEvalNondeterministicFunction(true); + try { + AnalysisException exception = Assertions.assertThrows(AnalysisException.class, + () -> planner.planWithLock(unboundPlan, PhysicalProperties.ANY, + ExplainCommand.ExplainLevel.REWRITTEN_PLAN)); + Assertions.assertTrue(exception.getMessage() + .contains("Inline CTE required; failed due to containing nondeterministic functions.")); + } finally { + connectContext.setNotEvalNondeterministicFunction(originalNotEvalNondeterministicFunction); + } + } } diff --git a/regression-test/data/query_p0/subquery/test_subquery_in_cte.out b/regression-test/data/query_p0/subquery/test_subquery_in_cte.out new file mode 100644 index 00000000000000..b74978d3d6f52b --- /dev/null +++ b/regression-test/data/query_p0/subquery/test_subquery_in_cte.out @@ -0,0 +1,10 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select_exists -- +east 1 + +-- !select_scalar -- +east 1 + +-- !select_in -- +east 1 + diff --git a/regression-test/suites/query_p0/subquery/test_subquery_in_cte.groovy b/regression-test/suites/query_p0/subquery/test_subquery_in_cte.groovy new file mode 100644 index 00000000000000..08a8ee0a7dc8e6 --- /dev/null +++ b/regression-test/suites/query_p0/subquery/test_subquery_in_cte.groovy @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Regression test for duplicate RelationId bug in simple CASE WHEN with subqueries. +// The bug caused "groupExpression already exists in memo" error because the simple +// case value (a subquery) was duplicated into multiple EqualTo nodes at parse time. + +suite("test_subquery_in_cte") { + + multi_sql """ + DROP TABLE IF EXISTS test_subquery_in_cte_users; + DROP TABLE IF EXISTS test_subquery_in_cte_orders; + CREATE TABLE test_subquery_in_cte_users ( + user_id INT, + region VARCHAR(16) + ) + DUPLICATE KEY(user_id) + DISTRIBUTED BY HASH(user_id) BUCKETS 1 + PROPERTIES("replication_num"="1"); + CREATE TABLE test_subquery_in_cte_orders ( + order_id INT, + user_id INT, + qty INT + ) + DUPLICATE KEY(order_id, user_id) + DISTRIBUTED BY HASH(order_id) BUCKETS 1 + PROPERTIES("replication_num"="1"); + INSERT INTO test_subquery_in_cte_users VALUES (1, 'east'), (2, 'west'); + INSERT INTO test_subquery_in_cte_orders VALUES (10, 1, 2), (11, 2, 1); + """ + qt_select_exists """SELECT + u.region, + COUNT(*) AS user_cnt + FROM test_subquery_in_cte_users u + WHERE EXISTS ( + WITH picked AS ( + SELECT + o.user_id AS uid + FROM test_subquery_in_cte_orders o + WHERE o.qty >= 2 + ) + SELECT 1 + FROM picked p + WHERE p.uid = u.user_id + ) + GROUP BY u.region + ORDER BY 1, 2; + """ + + qt_select_scalar """SELECT + u.region, + COUNT(*) AS user_cnt + FROM test_subquery_in_cte_users u + WHERE user_id = ( + WITH picked AS ( + SELECT + o.user_id AS uid + FROM test_subquery_in_cte_orders o + WHERE o.qty >= 2 + ) + SELECT uid + FROM picked p + ) + GROUP BY u.region + ORDER BY 1, 2; + """ + + qt_select_in """SELECT + u.region, + COUNT(*) AS user_cnt + FROM test_subquery_in_cte_users u + WHERE user_id in ( + WITH picked AS ( + SELECT + o.user_id AS uid + FROM test_subquery_in_cte_orders o + WHERE o.qty >= 2 + ) + SELECT uid + FROM picked p + ) + GROUP BY u.region + ORDER BY 1, 2; + """ +}