Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,12 @@ private static CascadesContext newContext(Optional<CascadesContext> parent, Opti
StatementContext statementContext, Plan initPlan, CTEContext cteContext,
PhysicalProperties requireProperties, boolean isLeadingDisableJoinReorder,
CTEContext recursiveCteContext) {
return new CascadesContext(parent, subtree, statementContext, initPlan, null,
CascadesContext cascadesContext = new CascadesContext(parent, subtree, statementContext, initPlan, null,
cteContext, requireProperties, isLeadingDisableJoinReorder, recursiveCteContext);
if (parent.isPresent() && parent.get().getOuterScope().isPresent()) {
cascadesContext.setOuterScope(parent.get().getOuterScope().get());
}
return cascadesContext;
}

public CascadesContext getRoot() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,8 @@ public class Rewriter extends AbstractBatchJobExecutor {
),
// query rewrite support window, so add this rule here
custom(RuleType.AGG_SCALAR_SUBQUERY_TO_WINDOW_FUNCTION, AggScalarSubQueryToWindowFunction::new),
custom(RuleType.PULL_UP_CTE_ANCHOR, PullUpCteAnchor::new),
custom(RuleType.CTE_INLINE, CTEInline::new),
bottomUp(
new EliminateUselessPlanUnderApply(),
// CorrelateApplyToUnCorrelateApply and ApplyToJoin
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Alias;
Expand All @@ -33,6 +34,8 @@
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
Expand Down Expand Up @@ -234,8 +237,16 @@ private AnalyzedResult analyzeSubquery(SubqueryExpr expr) {
getScope().getSlots(), getScope().getAsteriskSlots());
subqueryContext.setOuterScope(subqueryScope);
subqueryContext.newAnalyzer().analyze();
return new AnalyzedResult((LogicalPlan) subqueryContext.getRewritePlan(),
Comment thread
starocean999 marked this conversation as resolved.
StatementContext statementContext = cascadesContext.getStatementContext();
subqueryContext.getRewritePlan().collect(LogicalCTEAnchor.class::isInstance).stream()
.forEach(cteAnchor -> statementContext.addToMustLineCTEs(((LogicalCTEAnchor) cteAnchor).getCteId()));
AnalyzedResult analyzedResult = new AnalyzedResult((LogicalPlan) subqueryContext.getRewritePlan(),
subqueryScope.getCorrelatedSlots());
if (analyzedResult.hasCorrelatedSlotsUnderCteProducer()) {
throw new AnalysisException(
"Unsupported correlated subquery in cte " + analyzedResult.getLogicalPlan());
}
return analyzedResult;
}

public Scope getScope() {
Expand Down Expand Up @@ -273,6 +284,12 @@ public boolean hasCorrelatedSlotsUnderAgg() {
ImmutableSet.copyOf(correlatedSlots), LogicalAggregate.class);
}

/**
 * Checks the analyzed subquery plan for correlated slots in combination with
 * {@link LogicalCTEProducer} nodes.
 *
 * @return {@code false} right away when this result carries no correlated slots;
 *         otherwise whatever {@code hasCorrelatedSlotsUnderNode} reports for the
 *         plan, the correlated slots and {@code LogicalCTEProducer.class}
 */
public boolean hasCorrelatedSlotsUnderCteProducer() {
    if (correlatedSlots.isEmpty()) {
        // nothing is correlated, so the plan does not need to be inspected
        return false;
    }
    ImmutableSet<Slot> outerSlots = ImmutableSet.copyOf(correlatedSlots);
    return hasCorrelatedSlotsUnderNode(logicalPlan, outerSlots, LogicalCTEProducer.class);
}

private static <T> boolean hasCorrelatedSlotsUnderNode(Plan rootPlan,
ImmutableSet<Slot> slots, Class<T> clazz) {
ArrayDeque<Plan> planQueue = new ArrayDeque<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.trees.copier.DeepCopierContext;
import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier;
Expand Down Expand Up @@ -86,6 +87,10 @@ public Plan visitLogicalCTEAnchor(LogicalCTEAnchor<? extends Plan, ? extends Pla
return false;
});
if (mustInlineCTEs.contains(cteAnchor.getCteId())) {
if (containsNondeterministicFunction((LogicalCTEProducer<?>) cteAnchor.left())) {
throw new AnalysisException(
"Inline CTE required; failed due to containing nondeterministic functions.");
}
// should inline
Plan root = cteAnchor.right().accept(this, (LogicalCTEProducer<?>) cteAnchor.left());
// process child
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@
import org.apache.doris.nereids.rules.rewrite.UnCorrelatedApplyFilter;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nullable;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
Expand All @@ -45,6 +48,7 @@
import org.junit.jupiter.api.Test;

import java.util.List;
import java.util.Set;

public class AnalyzeCTETest extends TestWithFeService implements MemoPatternMatchSupported {

Expand Down Expand Up @@ -87,6 +91,23 @@ public class AnalyzeCTETest extends TestWithFeService implements MemoPatternMatc

private final String cteWithDiffRelationId = "with s as (select * from supplier) select * from s as s1, s as s2";

// Correlated EXISTS and IN subqueries that declare a CTE ("picked") inside the subquery.
// The CTE body itself only reads supplier; the reference to the outer alias outer_s
// appears in the subquery's main SELECT, outside the CTE definition.
private final List<String> correlatedSubqueryWithCteSqls = ImmutableList.of(
"SELECT * FROM supplier outer_s WHERE EXISTS ("
+ "WITH picked AS (SELECT s_suppkey, s_region FROM supplier) "
+ "SELECT 1 FROM picked WHERE picked.s_region = outer_s.s_region)",
"SELECT * FROM supplier outer_s WHERE outer_s.s_suppkey IN ("
+ "WITH picked AS (SELECT s_suppkey, s_region FROM supplier) "
+ "SELECT picked.s_suppkey FROM picked WHERE picked.s_region = outer_s.s_region)"
);

// Scalar subquery that declares a CTE and does not reference the outer query at all.
private final String scalarSubqueryWithCteSql = "SELECT * FROM supplier outer_s WHERE outer_s.s_suppkey = ("
+ "WITH picked AS (SELECT s_suppkey FROM supplier WHERE s_nation = 'PERU') "
+ "SELECT min(picked.s_suppkey) FROM picked)";

// EXISTS subquery whose CTE body references the outer alias outer_s directly,
// i.e. the correlated slot sits inside the CTE definition itself.
private final String correlatedSlotUnderCteProducerSql = "SELECT * FROM supplier outer_s WHERE EXISTS ("
+ "WITH picked AS (SELECT s_suppkey FROM supplier WHERE s_region = outer_s.s_region) "
+ "SELECT 1 FROM picked)";

private final List<String> testSqls = ImmutableList.of(
multiCte, cteWithColumnAlias, cteConsumerInSubQuery, cteConsumerJoin, cteReferToAnotherOne, cteJoinSelf,
cteNested, cteInTheMiddle, cteWithDiffRelationId
Expand Down Expand Up @@ -174,6 +195,40 @@ public void testCTEWithAlias() {
);
}

/**
 * Analyzes every statement in {@code correlatedSubqueryWithCteSqls} and checks that the
 * analyzed plan holds exactly one correlated {@link LogicalApply} whose second child
 * contains a {@link LogicalCTEAnchor}.
 */
@Test
public void testCorrelatedSubqueryWithCte() throws Exception {
    for (String query : correlatedSubqueryWithCteSqls) {
        // reset the statement-scope id generator before analyzing each query
        StatementScopeIdGenerator.clear();
        Plan analyzed = PlanChecker.from(connectContext).analyze(query).getPlan();
        Set<Plan> applies = analyzed.collect(LogicalApply.class::isInstance);
        Assertions.assertEquals(1, applies.size(), query);

        // the single apply must be correlated and keep the CTE anchor on its subquery side
        LogicalApply<?, ?> onlyApply = (LogicalApply<?, ?>) applies.iterator().next();
        Assertions.assertTrue(onlyApply.isCorrelated(), query);
        Assertions.assertTrue(onlyApply.child(1).anyMatch(LogicalCTEAnchor.class::isInstance), query);
    }
}

/**
 * Analyzes {@code scalarSubqueryWithCteSql} and checks that at least one
 * {@link LogicalCTEAnchor} is present in the analyzed plan.
 */
@Test
public void testScalarSubqueryWithCte() throws Exception {
    Plan analyzed = PlanChecker.from(connectContext).analyze(scalarSubqueryWithCteSql).getPlan();
    Set<Plan> anchors = analyzed.collect(LogicalCTEAnchor.class::isInstance);
    Assertions.assertFalse(anchors.isEmpty());
}

/**
 * Analyzing {@code correlatedSlotUnderCteProducerSql} must fail with an
 * {@link AnalysisException} whose message mentions the unsupported correlated CTE.
 */
@Test
public void testCorrelatedSlotUnderCteProducerInSubquery() {
    AnalysisException thrown = Assertions.assertThrows(AnalysisException.class,
            () -> PlanChecker.from(connectContext).analyze(correlatedSlotUnderCteProducerSql),
            "Not throw expected exception.");
    String message = thrown.getMessage();
    Assertions.assertTrue(message.contains("Unsupported correlated subquery in cte"));
}

@Test
public void testCTEWithAnExistedTableOrViewName() {
PlanChecker.from(connectContext)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.nereids.NereidsPlanner;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.parser.NereidsParser;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.plans.commands.ExplainCommand;
Expand All @@ -29,6 +30,7 @@
import org.apache.doris.qe.OriginStatement;
import org.apache.doris.utframe.TestWithFeService;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

public class CTEInlineTest extends TestWithFeService implements MemoPatternMatchSupported {
Expand Down Expand Up @@ -81,4 +83,40 @@ public void recCteInline() {
).when(cte -> cte.getCteName().equals("yy"))
);
}

/**
 * Plans a recursive-CTE statement whose CTE "xx" contains {@code rand() > 0} in its
 * recursive branch and expects planning up to the REWRITTEN_PLAN explain level to fail
 * with an {@link AnalysisException} carrying the "Inline CTE required" message.
 */
@Test
public void recCteInlineRequireDeterministicProducer() {
// Statement under test: t1/t2/t3 are constant CTEs, xx and yy reference themselves,
// and yy additionally reads xx. The nondeterministic predicate lives in xx.
String sql = new StringBuilder().append("with recursive t1 as (\n").append(" select\n")
.append(" 1 as c1,\n").append(" 1 as c2\n").append("),\n").append("t2 as (\n")
.append(" select\n").append(" 2 as c1,\n").append(" 2 as c2\n").append("),\n")
.append("t3 as (\n").append(" select\n").append(" 3 as c1,\n").append(" 3 as c2\n")
.append("),\n").append("xx as (\n").append(" select\n").append(" c1,\n")
.append(" c2\n").append(" from\n").append(" t1\n").append(" union\n")
.append(" select\n").append(" t2.c1,\n").append(" t2.c2\n").append(" from\n")
.append(" t2,\n").append(" xx\n").append(" where\n")
.append(" t2.c1 = xx.c1\n").append(" and rand() > 0\n")
.append("),\n").append("yy as (\n").append(" select\n").append(" c1,\n")
.append(" c2\n").append(" from\n").append(" t3\n").append(" union\n")
.append(" select\n").append(" t3.c1,\n").append(" t3.c2\n").append(" from\n")
.append(" t3,\n").append(" yy,\n").append(" xx\n").append(" where\n")
.append(" t3.c1 = yy.c1\n").append(" and t3.c2 = xx.c1\n").append(")\n")
.append("select\n").append(" *\n").append("from\n").append(" yy y1,\n")
.append(" yy y2;")
.toString();
LogicalPlan unboundPlan = new NereidsParser().parseSingle(sql);
StatementContext statementContext = new StatementContext(connectContext,
new OriginStatement(sql, 0));
NereidsPlanner planner = new NereidsPlanner(statementContext);
// Remember the session flag so the finally block can restore it for other tests.
boolean originalNotEvalNondeterministicFunction = connectContext.notEvalNondeterministicFunction();
// NOTE(review): presumably this keeps rand() from being evaluated away before the
// inline check runs - confirm against the planner.
connectContext.setNotEvalNondeterministicFunction(true);
try {
AnalysisException exception = Assertions.assertThrows(AnalysisException.class,
() -> planner.planWithLock(unboundPlan, PhysicalProperties.ANY,
ExplainCommand.ExplainLevel.REWRITTEN_PLAN));
Assertions.assertTrue(exception.getMessage()
.contains("Inline CTE required; failed due to containing nondeterministic functions."));
} finally {
connectContext.setNotEvalNondeterministicFunction(originalNotEvalNondeterministicFunction);
}
}
}
10 changes: 10 additions & 0 deletions regression-test/data/query_p0/subquery/test_subquery_in_cte.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select_exists --
east 1

-- !select_scalar --
east 1

-- !select_in --
east 1

Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Regression test for CTEs declared inside subqueries: a WITH clause nested in an
// EXISTS, scalar, or IN subquery of the outer WHERE clause, including the case where
// the subquery's main SELECT is correlated with the outer query.

suite("test_subquery_in_cte") {

// Fixture: two users (1/east, 2/west) and two orders; only order 10 (user 1) has qty >= 2,
// so the "picked" CTE used by every query below yields the single uid 1.
multi_sql """
DROP TABLE IF EXISTS test_subquery_in_cte_users;
DROP TABLE IF EXISTS test_subquery_in_cte_orders;
CREATE TABLE test_subquery_in_cte_users (
user_id INT,
region VARCHAR(16)
)
DUPLICATE KEY(user_id)
DISTRIBUTED BY HASH(user_id) BUCKETS 1
PROPERTIES("replication_num"="1");
CREATE TABLE test_subquery_in_cte_orders (
order_id INT,
user_id INT,
qty INT
)
DUPLICATE KEY(order_id, user_id)
DISTRIBUTED BY HASH(order_id) BUCKETS 1
PROPERTIES("replication_num"="1");
INSERT INTO test_subquery_in_cte_users VALUES (1, 'east'), (2, 'west');
INSERT INTO test_subquery_in_cte_orders VALUES (10, 1, 2), (11, 2, 1);
"""
// Correlated EXISTS subquery: the CTE is uncorrelated, the correlation
// (p.uid = u.user_id) is in the subquery's main SELECT.
qt_select_exists """SELECT
u.region,
COUNT(*) AS user_cnt
FROM test_subquery_in_cte_users u
WHERE EXISTS (
WITH picked AS (
SELECT
o.user_id AS uid
FROM test_subquery_in_cte_orders o
WHERE o.qty >= 2
)
SELECT 1
FROM picked p
WHERE p.uid = u.user_id
)
GROUP BY u.region
ORDER BY 1, 2;
"""

// Uncorrelated scalar subquery: with the fixture above the CTE returns exactly one row.
qt_select_scalar """SELECT
u.region,
COUNT(*) AS user_cnt
FROM test_subquery_in_cte_users u
WHERE user_id = (
WITH picked AS (
SELECT
o.user_id AS uid
FROM test_subquery_in_cte_orders o
WHERE o.qty >= 2
)
SELECT uid
FROM picked p
)
GROUP BY u.region
ORDER BY 1, 2;
"""

// Uncorrelated IN subquery over the same CTE.
qt_select_in """SELECT
u.region,
COUNT(*) AS user_cnt
FROM test_subquery_in_cte_users u
WHERE user_id in (
WITH picked AS (
SELECT
o.user_id AS uid
FROM test_subquery_in_cte_orders o
WHERE o.qty >= 2
)
SELECT uid
FROM picked p
)
GROUP BY u.region
ORDER BY 1, 2;
"""
}
Loading