diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
index 934b7633a9f..c56f61ba138 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateJoinTransposeRule.java
@@ -129,7 +129,7 @@ public AggregateJoinTransposeRule(Class extends Aggregate> aggregateClass,
allowFunctions);
}
- private static boolean isAggregateSupported(Aggregate aggregate,
+ static boolean isAggregateSupported(Aggregate aggregate,
boolean allowFunctions) {
if (!allowFunctions && !aggregate.getAggCallList().isEmpty()) {
return false;
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
index 873d04016f6..bc21c41dd36 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/CoreRules.java
@@ -654,6 +654,11 @@ private CoreRules() {}
public static final JoinExtractFilterRule JOIN_EXTRACT_FILTER =
JoinExtractFilterRule.Config.DEFAULT.toRule();
+ /** Rule that pulls an {@link Aggregate} from the left input of a
+ * {@link Join} to above the join (group-by pull up). */
+ public static final JoinAggregateTransposeRule JOIN_AGGREGATE_TRANSPOSE =
+ JoinAggregateTransposeRule.Config.DEFAULT.toRule();
+
/** Rule that matches a {@link LogicalJoin} whose inputs are
* {@link LogicalProject}s, and pulls the project expressions up. */
public static final JoinProjectTransposeRule JOIN_PROJECT_BOTH_TRANSPOSE =
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/JoinAggregateTransposeRule.java b/core/src/main/java/org/apache/calcite/rel/rules/JoinAggregateTransposeRule.java
new file mode 100644
index 00000000000..99cf46141d7
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/rel/rules/JoinAggregateTransposeRule.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.rel.rules;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.Join;
+import org.apache.calcite.rel.core.JoinInfo;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.ImmutableBitSet;
+import org.apache.calcite.util.mapping.MappingType;
+import org.apache.calcite.util.mapping.Mappings;
+
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.apache.calcite.rel.rules.AggregateJoinTransposeRule.isAggregateSupported;
+
+/**
+ * Planner rule that pulls an
+ * {@link org.apache.calcite.rel.core.Aggregate}
+ * from below a {@link org.apache.calcite.rel.core.Join} to above it.
+ *
+ *
Before
+ *
+ * SELECT s.sales
+ * FROM (SELECT ss_sold_date_sk, SUM(ss_sales_price) AS sales
+ * FROM store_sales
+ * GROUP BY ss_sold_date_sk) s
+ * JOIN date_dim d
+ * ON s.ss_sold_date_sk = d.d_date_sk
+ * WHERE d.d_year = 2000
+ *
+ *
+ * After
+ *
+ * SELECT SUM(ss_sales_price) AS sales
+ * FROM store_sales s
+ * JOIN date_dim d
+ * ON s.ss_sold_date_sk = d.d_date_sk
+ * WHERE d.d_year = 2000
+ * GROUP BY s.ss_sold_date_sk
+ *
+ *
+ * This rule implements the simplest form of group-by pull up transformation
+ * described in the following papers:
+ *
+ *
+ * - Weipeng P. Yan, and Per-Ake Larson. "Interchanging the order of grouping and join". Technical
+ * Report CS 95-09, Dept. of Computer Science, University of Waterloo, Canada, 1995.
+ * - Weipeng P. Yan, and Per-Ake Larson. "Eager Aggregation and Lazy Aggregation." Proceedings
+ * of the 21th International Conference on Very Large Data Bases. 1995.
+ *
+ *
+ * The papers contain additional variants ("lazy" aggregation) not currently
+ * implemented.
+ *
+ * @see CoreRules#JOIN_AGGREGATE_TRANSPOSE
+ */
+@Value.Enclosing
+public class JoinAggregateTransposeRule
+ extends RelRule
+ implements TransformationRule {
+
+ protected JoinAggregateTransposeRule(Config config) {
+ super(config);
+ }
+
+ @Override public final boolean matches(RelOptRuleCall call) {
+ final Join join = call.rel(0);
+ final Aggregate left = call.rel(1);
+ final RelNode right = call.rel(2);
+ final JoinInfo info = join.analyzeCondition();
+ final RelMetadataQuery mq = call.getMetadataQuery();
+
+ // Only handle INNER equijoins with simple aggregates for now.
+ // Join keys on the agg side must reference only group-by columns
+ // (ensures row elimination removes whole groups, not partial)
+ ImmutableBitSet groupOutput = ImmutableBitSet.range(left.getGroupCount());
+ return join.getJoinType() == JoinRelType.INNER
+ && info.isEqui()
+ // We could potentially relax the check for the supported functions
+ // in this rule. I opted to keep things more constrained for now
+ // in case we decide to extend this rule for lazy aggregation.
+ && isAggregateSupported(left, true)
+ && groupOutput.contains(info.leftSet())
+ // The right side must be unique on its join keys (no row duplication)
+ && Boolean.TRUE.equals(mq.areColumnsUnique(right, info.rightSet()));
+ }
+
+ @Override public void onMatch(RelOptRuleCall call) {
+ final Join join = call.rel(0);
+ final Aggregate left = call.rel(1);
+ final RelNode aggInput = left.getInput();
+ final RelNode right = join.getRight();
+
+ // Build the transformation
+ final int rawFieldCount = aggInput.getRowType().getFieldCount();
+ final int leftFields = left.getRowType().getFieldCount();
+ final int rightFields = right.getRowType().getFieldCount();
+ final List groupList = left.getGroupSet().toList();
+
+ // Remap join condition: replace references to left output columns
+ // with references to raw aggInput columns in the new join layout.
+ // Old join: [agg output (leftFields) | other (rightFields)]
+ // New join: [aggInput (rawFieldCount) | other (rightFields)]
+ final int oldJoinWidth = join.getRowType().getFieldCount();
+ final int newJoinWidth = rawFieldCount + rightFields;
+
+ final Mappings.TargetMapping condMapping =
+ Mappings.create(MappingType.FUNCTION, oldJoinWidth, newJoinWidth);
+ // Agg output positions 0..groupCount-1 -> raw aggInput column positions
+ for (int i = 0; i < groupList.size(); i++) {
+ condMapping.set(i, groupList.get(i));
+ }
+ // Other-side columns shift: from leftFields+j to rawFieldCount+j
+ for (int j = 0; j < rightFields; j++) {
+ condMapping.set(leftFields + j, rawFieldCount + j);
+ }
+ final RexNode newCondition = RexUtil.apply(condMapping, join.getCondition());
+
+ // Build new join
+ final RelBuilder relBuilder = call.builder();
+ relBuilder.push(aggInput).push(right);
+ relBuilder.join(JoinRelType.INNER, newCondition);
+
+ // Build new left above the join.
+ // New group-by set: original group columns (at their raw positions in
+ // aggInput) plus all other-side columns (to preserve them).
+ final ImmutableBitSet.Builder newGroupSetBuilder = ImmutableBitSet.builder();
+ for (int col : groupList) {
+ newGroupSetBuilder.set(col);
+ }
+ for (int j = 0; j < rightFields; j++) {
+ newGroupSetBuilder.set(rawFieldCount + j);
+ }
+ final ImmutableBitSet newGroupSet = newGroupSetBuilder.build();
+
+ relBuilder.aggregate(relBuilder.groupKey(newGroupSet), left.getAggCallList());
+
+ // Add project to restore original join output column order.
+ // Original output: [group(left_cols), agg_calls, right_cols]
+ // New output: [group(left_cols, right_cols), agg_calls]
+
+ // Create a mapping between the input (source) and the output (target)
+ // columns of the new aggregate. For example:
+ //
+ // Aggregate: Aggregate(group=[{7, 9, 10}])
+ // Mapping: { 7 -> 0, 9 -> 1, 10 -> 2 }
+ final Mappings.TargetMapping newGroupMap = Mappings.target(newGroupSet.toList(), newJoinWidth);
+
+ final List projects = new ArrayList<>();
+ // Group-by columns of original left
+ for (int col : groupList) {
+ int pos = newGroupMap.getTarget(col);
+ projects.add(relBuilder.field(pos));
+ }
+ // Aggregate call results
+ int aggCallBase = newGroupSet.cardinality();
+ for (int k = 0; k < left.getAggCallList().size(); k++) {
+ projects.add(relBuilder.field(aggCallBase + k));
+ }
+ // Right-side columns
+ for (int j = 0; j < rightFields; j++) {
+ int pos = newGroupMap.getTarget(rawFieldCount + j);
+ projects.add(relBuilder.field(pos));
+ }
+
+ relBuilder.project(projects, join.getRowType().getFieldNames());
+
+ call.transformTo(relBuilder.build());
+ }
+
+ /** Rule configuration. */
+ @Value.Immutable
+ public interface Config extends RelRule.Config {
+ Config DEFAULT = ImmutableJoinAggregateTransposeRule.Config.of()
+ .withOperandSupplier(b0 ->
+ b0.operand(Join.class).inputs(
+ b1 -> b1.operand(Aggregate.class).anyInputs(),
+ b2 -> b2.operand(RelNode.class).anyInputs()))
+ .withDescription("JoinAggregateTransposeRule");
+
+ @Override default JoinAggregateTransposeRule toRule() {
+ return new JoinAggregateTransposeRule(this);
+ }
+ }
+}
diff --git a/core/src/test/java/org/apache/calcite/test/JoinAggregateTransposeRuleTest.java b/core/src/test/java/org/apache/calcite/test/JoinAggregateTransposeRuleTest.java
new file mode 100644
index 00000000000..97f784522a6
--- /dev/null
+++ b/core/src/test/java/org/apache/calcite/test/JoinAggregateTransposeRuleTest.java
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.test;
+
+import org.apache.calcite.config.CalciteConnectionConfig;
+import org.apache.calcite.config.CalciteConnectionProperty;
+import org.apache.calcite.jdbc.CalciteSchema;
+import org.apache.calcite.plan.hep.HepProgram;
+import org.apache.calcite.prepare.CalciteCatalogReader;
+import org.apache.calcite.rel.rules.CoreRules;
+import org.apache.calcite.schema.SchemaPlus;
+import org.apache.calcite.tools.Frameworks;
+
+import com.google.common.collect.ImmutableList;
+
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Unit tests for {@link org.apache.calcite.rel.rules.JoinAggregateTransposeRule}.
+ *
+ * Relevant tickets:
+ *
+ */
+class JoinAggregateTransposeRuleTest {
+
+ private static RelOptFixture fixture() {
+ // Use SCOTT schema to keep unit and end-to-end (Quidem) tests aligned.
+ SchemaPlus rootSchema = Frameworks.createRootSchema(true);
+ CalciteAssert.addSchema(rootSchema, CalciteAssert.SchemaSpec.SCOTT);
+ CalciteConnectionConfig config =
+ CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false");
+ return RelOptFixture.DEFAULT
+ .withCatalogReaderFactory(
+ (typeFactory, caseSensitive) ->
+ new CalciteCatalogReader(
+ CalciteSchema.from(rootSchema),
+ ImmutableList.of("SCOTT"),
+ typeFactory,
+ config))
+ .withDiffRepos(DiffRepository.lookup(JoinAggregateTransposeRuleTest.class));
+ }
+
+ private static RelOptFixture sql(String sql) {
+ return fixture().sql(sql);
+ }
+
+ /**
+ * Tests that the rule can pull the group by from the left side of the join
+ * in the trivial case where there are no aggregate functions.
+ */
+ @Test void testPullGroupByWithoutAggregateFunctions() {
+ final String sql = "select g.deptno\n"
+ + "from (select deptno from emp group by deptno) g\n"
+ + "join dept d on g.deptno = d.deptno";
+ sql(sql).withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE).check();
+ }
+
+ /**
+ * Tests that the rule can pull the group by from the left side of the join
+ * when that is a simple aggregation.
+ */
+ @Test void testPullGroupByFromLeftWithSimpleAggregation() {
+ final String sql = "select g.deptno, g.total_sal, d.dname\n"
+ + "from (select deptno, sum(sal) as total_sal\n"
+ + " from emp group by deptno) g\n"
+ + "join dept d on g.deptno = d.deptno";
+ sql(sql).withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE).check();
+ }
+
+ /**
+ * Tests that the rule can pull the group by from the left side of the join
+ * when that is a simple aggregation with multiple aggregate functions.
+ */
+ @Test void testPullGroupByFromLeftWithSimpleAggregationMultipleFunctions() {
+ final String sql = "select g.deptno, g.total_sal, g.low_sal, g.high_sal, d.dname\n"
+ + "from (select deptno, sum(sal) as total_sal, min(sal) as low_sal, max(sal) as high_sal\n"
+ + " from emp group by deptno) g\n"
+ + "join dept d on g.deptno = d.deptno";
+ sql(sql).withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE).check();
+ }
+
+ /**
+ * Tests that the rule can pull the group by from the right side of the join
+ * by exploiting the join commutativity. Demonstrates that there is no
+ * need for implementing separate rule/logic for pulling the group by from the
+ * right side of the join.
+ */
+ @Test void testPullGroupByFromRightWithSimpleAggregation() {
+ final String sql = "select g.deptno, g.total_sal, d.dname\n"
+ + "from dept d\n"
+ + "join (select deptno, sum(sal) as total_sal\n"
+ + " from emp group by deptno) g\n"
+ + " on g.deptno = d.deptno";
+ HepProgram program = HepProgram.builder()
+ // Without a limit here the commute rule would keep flipping the
+ // join indefinitely causing stack overflow.
+ .addMatchLimit(1)
+ .addRuleInstance(CoreRules.JOIN_COMMUTE)
+ .addMatchLimit(HepProgram.MATCH_UNTIL_FIXPOINT)
+ .addRuleInstance(CoreRules.JOIN_AGGREGATE_TRANSPOSE)
+ .build();
+ sql(sql)
+ .withProgram(program)
+ .check();
+ }
+
+ /**
+ * Tests that the rule can pull the group by from the left side of the join
+ * even when there is a filtering on the right side.
+ */
+ @Test void testPullGroupByFromLeftWithFilterOnRight() {
+ final String sql = "select g.deptno, g.total_sal, d.dname\n"
+ + "from (select deptno, sum(sal) as total_sal\n"
+ + " from emp group by deptno) g\n"
+ + "join dept d on g.deptno = d.deptno\n"
+ + "where d.dname = 'RESEARCH'";
+ sql(sql)
+ .withPreRule(CoreRules.FILTER_INTO_JOIN)
+ .withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE)
+ .check();
+ }
+
+ /**
+ * Tests that the rules not apply when the join is not an equijoin.
+ */
+ @Test void testNoPullGroupByAboveNonEquiJoin() {
+ final String sql = "select g.deptno\n"
+ + "from (select deptno from emp group by deptno) g\n"
+ + "join dept d on g.deptno > d.deptno";
+ sql(sql).withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE).checkUnchanged();
+ }
+
+ /**
+ * Tests that the rule does not apply when the right join keys are not unique.
+ * The job column is not unique thus the rule bails out.
+ */
+ @Test void testNoPullGroupByWhenRightJoinKeysNotUnique() {
+ final String sql = "select g.job, g.cnt\n"
+ + "from (select job, count(*) as cnt\n"
+ + " from emp group by job) g\n"
+ + "join emp e on g.job = e.job";
+ sql(sql).withRule(CoreRules.JOIN_AGGREGATE_TRANSPOSE).checkUnchanged();
+ }
+
+
+ @AfterAll static void checkActualAndReferenceFiles() {
+ fixture().diffRepos.checkActualAndReferenceFiles();
+ }
+}
diff --git a/core/src/test/resources/org/apache/calcite/test/JoinAggregateTransposeRuleTest.xml b/core/src/test/resources/org/apache/calcite/test/JoinAggregateTransposeRuleTest.xml
new file mode 100644
index 00000000000..8c7c7c59640
--- /dev/null
+++ b/core/src/test/resources/org/apache/calcite/test/JoinAggregateTransposeRuleTest.xml
@@ -0,0 +1,202 @@
+
+
+
+
+
+ d.deptno]]>
+
+
+ ($0, $1)], joinType=[inner])
+ LogicalAggregate(group=[{0}])
+ LogicalProject(DEPTNO=[$7])
+ LogicalTableScan(table=[[scott, EMP]])
+ LogicalTableScan(table=[[scott, DEPT]])
+]]>
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/core/src/test/resources/sql/join-agg-transpose.iq b/core/src/test/resources/sql/join-agg-transpose.iq
new file mode 100644
index 00000000000..3cccb32da2a
--- /dev/null
+++ b/core/src/test/resources/sql/join-agg-transpose.iq
@@ -0,0 +1,78 @@
+# join-agg-transpose.iq - [CALCITE-7604] Add rule to pull up GROUP BY above join
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+!use post
+!set outputformat mysql
+!use scott
+
+!set hep-rules "
++FILTER_INTO_JOIN,
++JOIN_AGGREGATE_TRANSPOSE"
+
+# Tests the rule can pull the group by from the left side of the join
+# in the trivial case where there are no aggregate functions
+
+select g.deptno
+from (select deptno from emp group by deptno) g
+join dept d on g.deptno = d.deptno;
++--------+
+| DEPTNO |
++--------+
+| 10 |
+| 20 |
+| 30 |
++--------+
+(3 rows)
+
+!ok
+
+EnumerableCalc(expr#0..1=[{inputs}], DEPTNO=[$t0])
+ EnumerableAggregate(group=[{0, 1}])
+ EnumerableHashJoin(condition=[=($0, $1)], joinType=[inner])
+ EnumerableCalc(expr#0..7=[{inputs}], DEPTNO=[$t7])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableCalc(expr#0..2=[{inputs}], DEPTNO=[$t0])
+ EnumerableTableScan(table=[[scott, DEPT]])
+!plan
+
+# Tests the rule can pull the group by from the left side of the join
+# even when there is a filtering on the right side.
+
+select g.deptno, g.total_sal, d.dname
+from (select deptno, sum(sal) as total_sal
+ from emp group by deptno) g
+join dept d on g.deptno = d.deptno
+where d.dname = 'RESEARCH';
++--------+-----------+----------+
+| DEPTNO | TOTAL_SAL | DNAME |
++--------+-----------+----------+
+| 20 | 10875.00 | RESEARCH |
++--------+-----------+----------+
+(1 row)
+
+!ok
+
+EnumerableCalc(expr#0..3=[{inputs}], DEPTNO=[$t0], TOTAL_SAL=[$t3], DNAME=[$t2])
+ EnumerableAggregate(group=[{0, 2, 3}], TOTAL_SAL=[SUM($1)])
+ EnumerableHashJoin(condition=[=($0, $2)], joinType=[inner])
+ EnumerableCalc(expr#0..7=[{inputs}], DEPTNO=[$t7], SAL=[$t5])
+ EnumerableTableScan(table=[[scott, EMP]])
+ EnumerableCalc(expr#0..2=[{inputs}], expr#3=['RESEARCH':VARCHAR(14)], expr#4=[=($t1, $t3)], proj#0..1=[{exprs}], $condition=[$t4])
+ EnumerableTableScan(table=[[scott, DEPT]])
+!plan
+
+# End join-agg-transpose.iq