From 0e21e692154ad52daa3927338551167718e1d859 Mon Sep 17 00:00:00 2001 From: Stamatis Zampetakis Date: Mon, 15 Jun 2026 11:15:27 +0200 Subject: [PATCH] [CALCITE-7604] Add rule to pull up GROUP BY above JOIN --- .../rel/rules/AggregateJoinTransposeRule.java | 2 +- .../apache/calcite/rel/rules/CoreRules.java | 5 + .../rel/rules/JoinAggregateTransposeRule.java | 210 ++++++++++++++++++ .../test/JoinAggregateTransposeRuleTest.java | 168 ++++++++++++++ .../test/JoinAggregateTransposeRuleTest.xml | 202 +++++++++++++++++ .../test/resources/sql/join-agg-transpose.iq | 78 +++++++ 6 files changed, 664 insertions(+), 1 deletion(-) create mode 100644 core/src/main/java/org/apache/calcite/rel/rules/JoinAggregateTransposeRule.java create mode 100644 core/src/test/java/org/apache/calcite/test/JoinAggregateTransposeRuleTest.java create mode 100644 core/src/test/resources/org/apache/calcite/test/JoinAggregateTransposeRuleTest.xml create mode 100644 core/src/test/resources/sql/join-agg-transpose.iq diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java index 934b7633a9f7..c56f61ba1388 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java @@ -129,7 +129,7 @@ public AggregateJoinTransposeRule(Class aggregateClass, allowFunctions); } - private static boolean isAggregateSupported(Aggregate aggregate, + static boolean isAggregateSupported(Aggregate aggregate, boolean allowFunctions) { if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) { return false; diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java index 873d04016f6f..bc21c41dd366 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java +++ 
b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java @@ -654,6 +654,11 @@ private CoreRules() {} public static final JoinExtractFilterRule JOIN_EXTRACT_FILTER = JoinExtractFilterRule.Config.DEFAULT.toRule(); + /** Rule that pulls an {@link Aggregate} from the left input of a + * {@link Join} to above the join (group-by pull up). */ + public static final JoinAggregateTransposeRule JOIN_AGGREGATE_TRANSPOSE = + JoinAggregateTransposeRule.Config.DEFAULT.toRule(); + /** Rule that matches a {@link LogicalJoin} whose inputs are * {@link LogicalProject}s, and pulls the project expressions up. */ public static final JoinProjectTransposeRule JOIN_PROJECT_BOTH_TRANSPOSE = diff --git a/core/src/main/java/org/apache/calcite/rel/rules/JoinAggregateTransposeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/JoinAggregateTransposeRule.java new file mode 100644 index 000000000000..99cf46141d7c --- /dev/null +++ b/core/src/main/java/org/apache/calcite/rel/rules/JoinAggregateTransposeRule.java @@ -0,0 +1,210 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.calcite.rel.rules; + +import org.apache.calcite.plan.RelOptRuleCall; +import org.apache.calcite.plan.RelRule; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.core.Aggregate; +import org.apache.calcite.rel.core.Join; +import org.apache.calcite.rel.core.JoinInfo; +import org.apache.calcite.rel.core.JoinRelType; +import org.apache.calcite.rel.metadata.RelMetadataQuery; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexUtil; +import org.apache.calcite.tools.RelBuilder; +import org.apache.calcite.util.ImmutableBitSet; +import org.apache.calcite.util.mapping.MappingType; +import org.apache.calcite.util.mapping.Mappings; + +import org.immutables.value.Value; + +import java.util.ArrayList; +import java.util.List; + +import static org.apache.calcite.rel.rules.AggregateJoinTransposeRule.isAggregateSupported; + +/** + * Planner rule that pulls an + * {@link org.apache.calcite.rel.core.Aggregate} + * from below a {@link org.apache.calcite.rel.core.Join} to above it. + * + *

Before + *


+ * SELECT s.sales
+ * FROM (SELECT ss_sold_date_sk, SUM(ss_sales_price) AS sales
+ *       FROM store_sales
+ *       GROUP BY ss_sold_date_sk) s
+ * JOIN date_dim d
+ *   ON s.ss_sold_date_sk = d.d_date_sk
+ * WHERE d.d_year = 2000
+ * 
+ * + *

After + *


+ * SELECT SUM(ss_sales_price) AS sales
+ * FROM store_sales s
+ * JOIN date_dim d
+ *   ON s.ss_sold_date_sk = d.d_date_sk
+ * WHERE d.d_year = 2000
+ * GROUP BY s.ss_sold_date_sk
+ * 
+ * + *

This rule implements the simplest form of group-by pull up transformation + * described in the following papers: + * + *

+ * + *

The papers contain additional variants ("lazy" aggregation) not currently + * implemented. + * + * @see CoreRules#JOIN_AGGREGATE_TRANSPOSE + */ +@Value.Enclosing +public class JoinAggregateTransposeRule + extends RelRule + implements TransformationRule { + + protected JoinAggregateTransposeRule(Config config) { + super(config); + } + + @Override public final boolean matches(RelOptRuleCall call) { + final Join join = call.rel(0); + final Aggregate left = call.rel(1); + final RelNode right = call.rel(2); + final JoinInfo info = join.analyzeCondition(); + final RelMetadataQuery mq = call.getMetadataQuery(); + + // Only handle INNER equijoins with simple aggregates for now. + // Join keys on the agg side must reference only group-by columns + // (ensures row elimination removes whole groups, not partial) + ImmutableBitSet groupOutput = ImmutableBitSet.range(left.getGroupCount()); + return join.getJoinType() == JoinRelType.INNER + && info.isEqui() + // We could potentially relax the check for the supported functions + // in this rule. I opted to keep things more constrained for now + // in case we decide to extend this rule for lazy aggregation. 
+ && isAggregateSupported(left, true) + && groupOutput.contains(info.leftSet()) + // The right side must be unique on its join keys (no row duplication) + && Boolean.TRUE.equals(mq.areColumnsUnique(right, info.rightSet())); + } + + @Override public void onMatch(RelOptRuleCall call) { + final Join join = call.rel(0); + final Aggregate left = call.rel(1); + final RelNode aggInput = left.getInput(); + final RelNode right = join.getRight(); + + // Build the transformation + final int rawFieldCount = aggInput.getRowType().getFieldCount(); + final int leftFields = left.getRowType().getFieldCount(); + final int rightFields = right.getRowType().getFieldCount(); + final List groupList = left.getGroupSet().toList(); + + // Remap join condition: replace references to left output columns + // with references to raw aggInput columns in the new join layout. + // Old join: [agg output (leftFields) | other (rightFields)] + // New join: [aggInput (rawFieldCount) | other (rightFields)] + final int oldJoinWidth = join.getRowType().getFieldCount(); + final int newJoinWidth = rawFieldCount + rightFields; + + final Mappings.TargetMapping condMapping = + Mappings.create(MappingType.FUNCTION, oldJoinWidth, newJoinWidth); + // Agg output positions 0..groupCount-1 -> raw aggInput column positions + for (int i = 0; i < groupList.size(); i++) { + condMapping.set(i, groupList.get(i)); + } + // Other-side columns shift: from leftFields+j to rawFieldCount+j + for (int j = 0; j < rightFields; j++) { + condMapping.set(leftFields + j, rawFieldCount + j); + } + final RexNode newCondition = RexUtil.apply(condMapping, join.getCondition()); + + // Build new join + final RelBuilder relBuilder = call.builder(); + relBuilder.push(aggInput).push(right); + relBuilder.join(JoinRelType.INNER, newCondition); + + // Build new left above the join. + // New group-by set: original group columns (at their raw positions in + // aggInput) plus all other-side columns (to preserve them). 
+ final ImmutableBitSet.Builder newGroupSetBuilder = ImmutableBitSet.builder(); + for (int col : groupList) { + newGroupSetBuilder.set(col); + } + for (int j = 0; j < rightFields; j++) { + newGroupSetBuilder.set(rawFieldCount + j); + } + final ImmutableBitSet newGroupSet = newGroupSetBuilder.build(); + + relBuilder.aggregate(relBuilder.groupKey(newGroupSet), left.getAggCallList()); + + // Add project to restore original join output column order. + // Original output: [group(left_cols), agg_calls, right_cols] + // New output: [group(left_cols, right_cols), agg_calls] + + // Create a mapping between the input (source) and the output (target) + // columns of the new aggregate. For example: + // + // Aggregate: Aggregate(group=[{7, 9, 10}]) + // Mapping: { 7 -> 0, 9 -> 1, 10 -> 2 } + final Mappings.TargetMapping newGroupMap = Mappings.target(newGroupSet.toList(), newJoinWidth); + + final List projects = new ArrayList<>(); + // Group-by columns of original left + for (int col : groupList) { + int pos = newGroupMap.getTarget(col); + projects.add(relBuilder.field(pos)); + } + // Aggregate call results + int aggCallBase = newGroupSet.cardinality(); + for (int k = 0; k < left.getAggCallList().size(); k++) { + projects.add(relBuilder.field(aggCallBase + k)); + } + // Right-side columns + for (int j = 0; j < rightFields; j++) { + int pos = newGroupMap.getTarget(rawFieldCount + j); + projects.add(relBuilder.field(pos)); + } + + relBuilder.project(projects, join.getRowType().getFieldNames()); + + call.transformTo(relBuilder.build()); + } + + /** Rule configuration. 
*/ + @Value.Immutable + public interface Config extends RelRule.Config { + Config DEFAULT = ImmutableJoinAggregateTransposeRule.Config.of() + .withOperandSupplier(b0 -> + b0.operand(Join.class).inputs( + b1 -> b1.operand(Aggregate.class).anyInputs(), + b2 -> b2.operand(RelNode.class).anyInputs())) + .withDescription("JoinAggregateTransposeRule"); + + @Override default JoinAggregateTransposeRule toRule() { + return new JoinAggregateTransposeRule(this); + } + } +} diff --git a/core/src/test/java/org/apache/calcite/test/JoinAggregateTransposeRuleTest.java b/core/src/test/java/org/apache/calcite/test/JoinAggregateTransposeRuleTest.java new file mode 100644 index 000000000000..97f784522a61 --- /dev/null +++ b/core/src/test/java/org/apache/calcite/test/JoinAggregateTransposeRuleTest.java @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to you under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.calcite.test; + +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.plan.hep.HepProgram; +import org.apache.calcite.prepare.CalciteCatalogReader; +import org.apache.calcite.rel.rules.CoreRules; +import org.apache.calcite.schema.SchemaPlus; +import org.apache.calcite.tools.Frameworks; + +import com.google.common.collect.ImmutableList; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for {@link org.apache.calcite.rel.rules.JoinAggregateTransposeRule}. + * + *

Relevant tickets: + *

+ */ +class JoinAggregateTransposeRuleTest { + + private static RelOptFixture fixture() { + // Use SCOTT schema to keep unit and end-to-end (Quidem) tests aligned. + SchemaPlus rootSchema = Frameworks.createRootSchema(true); + CalciteAssert.addSchema(rootSchema, CalciteAssert.SchemaSpec.SCOTT); + CalciteConnectionConfig config = + CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false"); + return RelOptFixture.DEFAULT + .withCatalogReaderFactory( + (typeFactory, caseSensitive) -> + new CalciteCatalogReader( + CalciteSchema.from(rootSchema), + ImmutableList.of("SCOTT"), + typeFactory, + config)) + .withDiffRepos(DiffRepository.lookup(JoinAggregateTransposeRuleTest.class)); + } + + private static RelOptFixture sql(String sql) { + return fixture().sql(sql); + } + + /** + * Tests that the rule can pull the group by from the left side of the join + * in the trivial case where there are no aggregate functions. + */ + @Test void testPullGroupByWithoutAggregateFunctions() { + final String sql = "select g.deptno\n" + + "from (select deptno from emp group by deptno) g\n" + + "join dept d on g.deptno = d.deptno"; + sql(sql).withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE).check(); + } + + /** + * Tests that the rule can pull the group by from the left side of the join + * when that is a simple aggregation. + */ + @Test void testPullGroupByFromLeftWithSimpleAggregation() { + final String sql = "select g.deptno, g.total_sal, d.dname\n" + + "from (select deptno, sum(sal) as total_sal\n" + + " from emp group by deptno) g\n" + + "join dept d on g.deptno = d.deptno"; + sql(sql).withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE).check(); + } + + /** + * Tests that the rule can pull the group by from the left side of the join + * when that is a simple aggregation with multiple aggregate functions. 
+ */ + @Test void testPullGroupByFromLeftWithSimpleAggregationMultipleFunctions() { + final String sql = "select g.deptno, g.total_sal, g.low_sal, g.high_sal, d.dname\n" + "from (select deptno, sum(sal) as total_sal, min(sal) as low_sal, max(sal) as high_sal\n" + " from emp group by deptno) g\n" + "join dept d on g.deptno = d.deptno"; + sql(sql).withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE).check(); + } + + /** + * Tests that the rule can pull the group by from the right side of the join + * by exploiting the join commutativity. Demonstrates that there is no + * need for implementing separate rule/logic for pulling the group by from the + * right side of the join. + */ + @Test void testPullGroupByFromRightWithSimpleAggregation() { + final String sql = "select g.deptno, g.total_sal, d.dname\n" + "from dept d\n" + "join (select deptno, sum(sal) as total_sal\n" + " from emp group by deptno) g\n" + " on g.deptno = d.deptno"; + HepProgram program = HepProgram.builder() + // Without a limit here the commute rule would keep flipping the + // join indefinitely causing stack overflow. + .addMatchLimit(1) + .addRuleInstance(CoreRules.JOIN_COMMUTE) + .addMatchLimit(HepProgram.MATCH_UNTIL_FIXPOINT) + .addRuleInstance(CoreRules.JOIN_AGGREGATE_TRANSPOSE) + .build(); + sql(sql) + .withProgram(program) + .check(); + } + + /** + * Tests that the rule can pull the group by from the left side of the join + * even when there is a filtering on the right side. + */ + @Test void testPullGroupByFromLeftWithFilterOnRight() { + final String sql = "select g.deptno, g.total_sal, d.dname\n" + "from (select deptno, sum(sal) as total_sal\n" + " from emp group by deptno) g\n" + "join dept d on g.deptno = d.deptno\n" + "where d.dname = 'RESEARCH'"; + sql(sql) + .withPreRule(CoreRules.FILTER_INTO_JOIN) + .withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE) + .check(); + } + + /** + * Tests that the rule does not apply when the join is not an equijoin.
+ */ + @Test void testNoPullGroupByAboveNonEquiJoin() { + final String sql = "select g.deptno\n" + + "from (select deptno from emp group by deptno) g\n" + + "join dept d on g.deptno > d.deptno"; + sql(sql).withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE).checkUnchanged(); + } + + /** + * Tests that the rule does not apply when the right join keys are not unique. + * The job column is not unique thus the rule bails out. + */ + @Test void testNoPullGroupByWhenRightJoinKeysNotUnique() { + final String sql = "select g.job, g.cnt\n" + + "from (select job, count(*) as cnt\n" + + " from emp group by job) g\n" + + "join emp e on g.job = e.job"; + sql(sql).withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE).checkUnchanged(); + } + + + @AfterAll static void checkActualAndReferenceFiles() { + fixture().diffRepos.checkActualAndReferenceFiles(); + } +} diff --git a/core/src/test/resources/org/apache/calcite/test/JoinAggregateTransposeRuleTest.xml b/core/src/test/resources/org/apache/calcite/test/JoinAggregateTransposeRuleTest.xml new file mode 100644 index 000000000000..8c7c7c59640e --- /dev/null +++ b/core/src/test/resources/org/apache/calcite/test/JoinAggregateTransposeRuleTest.xml @@ -0,0 +1,202 @@ + + + + + + d.deptno]]> + + + ($0, $1)], joinType=[inner]) + LogicalAggregate(group=[{0}]) + LogicalProject(DEPTNO=[$7]) + LogicalTableScan(table=[[scott, EMP]]) + LogicalTableScan(table=[[scott, DEPT]]) +]]> + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/core/src/test/resources/sql/join-agg-transpose.iq b/core/src/test/resources/sql/join-agg-transpose.iq new file mode 100644 index 000000000000..3cccb32da2a9 --- /dev/null +++ b/core/src/test/resources/sql/join-agg-transpose.iq @@ -0,0 +1,78 @@ +# join-agg-transpose.iq - [CALCITE-7604] Add rule to pull up GROUP BY above join +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to you under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +!use post +!set outputformat mysql +!use scott + +!set hep-rules " ++FILTER_INTO_JOIN, ++JOIN_AGGREGATE_TRANSPOSE" + +# Tests the rule can pull the group by from the left side of the join +# in the trivial case where there are no aggregate functions + +select g.deptno +from (select deptno from emp group by deptno) g +join dept d on g.deptno = d.deptno; ++--------+ +| DEPTNO | ++--------+ +| 10 | +| 20 | +| 30 | ++--------+ +(3 rows) + +!ok + +EnumerableCalc(expr#0..1=[{inputs}], DEPTNO=[$t0]) + EnumerableAggregate(group=[{0, 1}]) + EnumerableHashJoin(condition=[=($0, $1)], joinType=[inner]) + EnumerableCalc(expr#0..7=[{inputs}], DEPTNO=[$t7]) + EnumerableTableScan(table=[[scott, EMP]]) + EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0]) + EnumerableTableScan(table=[[scott, DEPT]]) +!plan + +# Tests the rule can pull the group by from the left side of the join +# even when there is a filtering on the right side. 
+ +select g.deptno, g.total_sal, d.dname +from (select deptno, sum(sal) as total_sal + from emp group by deptno) g +join dept d on g.deptno = d.deptno +where d.dname = 'RESEARCH'; ++--------+-----------+----------+ +| DEPTNO | TOTAL_SAL | DNAME | ++--------+-----------+----------+ +| 20 | 10875.00 | RESEARCH | ++--------+-----------+----------+ +(1 row) + +!ok + +EnumerableCalc(expr#0..3=[{inputs}], DEPTNO=[$t0], TOTAL_SAL=[$t3], DNAME=[$t2]) + EnumerableAggregate(group=[{0, 2, 3}], TOTAL_SAL=[SUM($1)]) + EnumerableHashJoin(condition=[=($0, $2)], joinType=[inner]) + EnumerableCalc(expr#0..7=[{inputs}], DEPTNO=[$t7], SAL=[$t5]) + EnumerableTableScan(table=[[scott, EMP]]) + EnumerableCalc(expr#0..2=[{inputs}], expr#3=['RESEARCH':VARCHAR(14)], expr#4=[=($t1, $t3)], proj#0..1=[{exprs}], $condition=[$t4]) + EnumerableTableScan(table=[[scott, DEPT]]) +!plan + +# End join-agg-transpose.iq