Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlKind;
Expand Down Expand Up @@ -99,6 +100,10 @@ protected AggregateMinMaxToLimitRule(Config config) {
isDesc ? builder.desc(r) : r)
.build());

final RelDataType aggCallType = aggCall.getType();
if (!subQuery.getType().equals(aggCallType)) {
subQuery = builder.getRexBuilder().makeCast(aggCallType, subQuery, false, false);
}
newProjects.add(subQuery);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,11 @@ private static RexNode rewriteScalarQuery(RexSubQuery e, Set<CorrelationId> vari
builder.field(0)));
}
builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet);
return field(builder, inputCount, offset);
final RexNode ref = field(builder, inputCount, offset);
if (ref.getType().equals(e.getType())) {
return ref;
}
return builder.getRexBuilder().makeCast(e.getType(), ref, false, false);
}

/**
Expand Down
10 changes: 7 additions & 3 deletions core/src/main/java/org/apache/calcite/rex/RexSubQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,19 @@ public static RexSubQuery unique(RelNode rel) {
ImmutableList.of(), rel);
}

/** Creates a scalar sub-query. */
/** Creates a scalar sub-query.
*
* <p>The expression's type is the single output column's type, as returned by
* {@link RelDataTypeFactory#copyType(RelDataType)}
* (see <a href="https://issues.apache.org/jira/browse/CALCITE-2901">[CALCITE-2901]</a>).
*/
public static RexSubQuery scalar(RelNode rel) {
final List<RelDataTypeField> fieldList = rel.getRowType().getFieldList();
if (fieldList.size() != 1) {
throw new IllegalArgumentException();
}
final RelDataTypeFactory typeFactory = rel.getCluster().getTypeFactory();
final RelDataType type =
typeFactory.createTypeWithNullability(fieldList.get(0).getType(), true);
final RelDataType type = typeFactory.copyType(fieldList.get(0).getType());
return new RexSubQuery(type, SqlStdOperatorTable.SCALAR_QUERY,
ImmutableList.of(), rel);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1522,7 +1522,7 @@ private static RelDataType multivalentStringWithSepSumPrecision(
}
final RelDataType firstColType = fieldType.getType();
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
return typeFactory.createTypeWithNullability(firstColType, true);
return typeFactory.copyType(firstColType);
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5628,8 +5628,8 @@ private void handleScalarSubQuery(SqlSelect parentSelect,
assert type instanceof RelRecordType;
RelRecordType rec = (RelRecordType) type;

RelDataType nodeType = rec.getFieldList().get(0).getType();
nodeType = typeFactory.createTypeWithNullability(nodeType, true);
RelDataType nodeType =
typeFactory.copyType(rec.getFieldList().get(0).getType());
fieldList.add(alias, nodeType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5920,6 +5920,16 @@ && isConvertedSubq(rex)) {
SqlStdOperatorTable.IS_NOT_NULL,
fieldAccess);
}
if (kind == SqlKind.SCALAR_QUERY) {
fieldAccess =
StandardConvertletTable.castToValidatedType(
expr.getParserPosition(),
expr,
fieldAccess,
validator(),
rexBuilder,
false);
}
return fieldAccess;

case OVER:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6617,31 +6617,31 @@ private void checkLiteral2(String expression, String expected) {
+ "UNION ALL\n"
+ "SELECT NULL) END AS `$f0`\n"
+ "FROM `foodmart`.`product`) AS `t0` ON TRUE\n"
+ "WHERE `product`.`net_weight` > CAST(`t0`.`$f0` AS DOUBLE)";
+ "WHERE `product`.`net_weight` > CAST(CAST(`t0`.`$f0` AS SIGNED) AS DOUBLE)";

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this case necessary?
If not, can you avoid inserting it?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for taking the time to review my changes! Agreed that we should not rely on later cast cleanup. I’ll go through the changed casts and drop any that aren’t needed for correctness.

For the CAST(CAST($f0 AS …) AS DOUBLE) case in RelToSqlConverterTest: I’ll trace where each cast is introduced and try to simplify to a single cast (or align types earlier) if semantics stay the same; if two steps are still required I’ll note why in a short comment or reply here.

final String expectedPostgresql = "SELECT \"product\".\"product_class_id\" AS \"C\"\n"
+ "FROM \"foodmart\".\"product\"\n"
+ "LEFT JOIN (SELECT CASE COUNT(*) WHEN 0 THEN NULL WHEN 1 THEN MIN(\"product_class_id\") ELSE (SELECT CAST(NULL AS INTEGER)\n"
+ "UNION ALL\n"
+ "SELECT CAST(NULL AS INTEGER)) END AS \"$f0\"\n"
+ "FROM \"foodmart\".\"product\") AS \"t0\" ON TRUE\n"
+ "WHERE \"product\".\"net_weight\" > CAST(\"t0\".\"$f0\" AS DOUBLE PRECISION)";
+ "WHERE \"product\".\"net_weight\" > CAST(CAST(\"t0\".\"$f0\" AS INTEGER) AS DOUBLE PRECISION)";
final String expectedHsqldb = "SELECT product.product_class_id AS C\n"
+ "FROM foodmart.product\n"
+ "LEFT JOIN (SELECT CASE COUNT(*) WHEN 0 THEN NULL WHEN 1 THEN MIN(product_class_id) ELSE ((VALUES 0E0)\n"
+ "UNION ALL\n"
+ "(VALUES 0E0)) END AS $f0\n"
+ "FROM foodmart.product) AS t0 ON TRUE\n"
+ "WHERE product.net_weight > CAST(t0.$f0 AS DOUBLE)";
+ "WHERE product.net_weight > CAST(CAST(t0.$f0 AS INTEGER) AS DOUBLE)";
final String expectedSpark = "SELECT `product`.`product_class_id` `C`\n"
+ "FROM `foodmart`.`product`\n"
+ "LEFT JOIN (SELECT CASE COUNT(*) WHEN 0 THEN NULL WHEN 1 THEN MIN(`product_class_id`) ELSE RAISE_ERROR('more than one value in agg SINGLE_VALUE') END `$f0`\n"
+ "FROM `foodmart`.`product`) `t0` ON TRUE\n"
+ "WHERE `product`.`net_weight` > CAST(`t0`.`$f0` AS DOUBLE)";
+ "WHERE `product`.`net_weight` > CAST(CAST(`t0`.`$f0` AS INTEGER) AS DOUBLE)";
final String expectedHive = "SELECT `product`.`product_class_id` `C`\n"
+ "FROM `foodmart`.`product`\n"
+ "LEFT JOIN (SELECT CASE COUNT(*) WHEN 0 THEN NULL WHEN 1 THEN MIN(`product_class_id`) ELSE ASSERT_TRUE(FALSE) END `$f0`\n"
+ "FROM `foodmart`.`product`) `t0` ON TRUE\n"
+ "WHERE `product`.`net_weight` > CAST(`t0`.`$f0` AS DOUBLE)";
+ "WHERE `product`.`net_weight` > CAST(CAST(`t0`.`$f0` AS INT) AS DOUBLE)";
sql(query)
.withConfig(c -> c.withExpand(true))
.withMysql().ok(expectedMysql)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ public static Frameworks.ConfigBuilder config() {
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
Collections.emptyList(), Collections.emptyList());
final String planBefore = ""
+ "LogicalProject(DEPTNO=[$0], A=[$3])\n"
+ "LogicalProject(DEPTNO=[$0], A=[CAST($3):CHAR(7) NOT NULL])\n"
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalProject(EXPR=[CASE(>($0, 10.00), 'VIP ', 'Regular')])\n"
Expand All @@ -356,7 +356,7 @@ public static Frameworks.ConfigBuilder config() {

// Verify plan
final String planAfter = ""
+ "LogicalProject(DEPTNO=[$0], A=[$3])\n"
+ "LogicalProject(DEPTNO=[$0], A=[CAST($3):CHAR(7) NOT NULL])\n"
+ " LogicalJoin(condition=[=($0, $4)], joinType=[left])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalProject(EXPR=[CASE(>($2, 10.00), 'VIP ', 'Regular')], DEPTNO=[$0])\n"
Expand Down Expand Up @@ -500,7 +500,7 @@ public static Frameworks.ConfigBuilder config() {
Collections.emptyList(), Collections.emptyList());
final String planBefore = ""
+ "LogicalSort(sort0=[$0], dir0=[ASC])\n"
+ " LogicalProject(DNAME=[$1], C=[$3])\n"
+ " LogicalProject(DNAME=[$1], C=[CAST($3):BIGINT NOT NULL])\n"
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalAggregate(group=[{}], EXPR$0=[COUNT()])\n"
Expand Down Expand Up @@ -556,7 +556,7 @@ public static Frameworks.ConfigBuilder config() {
// LogicalTableScan(table=[[scott, EMP]])
final String planAfter = ""
+ "LogicalSort(sort0=[$0], dir0=[ASC])\n"
+ " LogicalProject(DNAME=[$1], C=[$7])\n"
+ " LogicalProject(DNAME=[$1], C=[CAST($7):BIGINT NOT NULL])\n"
+ " LogicalJoin(condition=[AND(=($0, $5), =($4, $6))], joinType=[left])\n"
+ " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], DEPTNO0=[$0], $f4=[*($0, 100)])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
Expand Down Expand Up @@ -627,14 +627,15 @@ public static Frameworks.ConfigBuilder config() {
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
Collections.emptyList(), Collections.emptyList());
final String planBefore = ""
+ "LogicalCorrelate(correlation=[$cor1], joinType=[left], requiredColumns=[{0}])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n"
+ " LogicalProject(EXPR$0=[$1])\n"
+ " LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])\n"
+ " LogicalProject($f0=[$cor1.DEPTNO])\n"
+ " LogicalFilter(condition=[=($7, $cor1.DEPTNO)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
+ "LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], NUM_DEPT_GROUPS=[CAST($3):BIGINT NOT NULL])\n"
+ " LogicalCorrelate(correlation=[$cor1], joinType=[left], requiredColumns=[{0}])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n"
+ " LogicalProject(EXPR$0=[$1])\n"
+ " LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])\n"
+ " LogicalProject($f0=[$cor1.DEPTNO])\n"
+ " LogicalFilter(condition=[=($7, $cor1.DEPTNO)])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n";
assertThat(before, hasTree(planBefore));

// Decorrelate without any rules, just "purely" decorrelation algorithm on RelDecorrelator
Expand All @@ -643,7 +644,7 @@ public static Frameworks.ConfigBuilder config() {
RuleSets.ofList(Collections.emptyList()));
// Verify plan
final String planAfter = ""
+ "LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], $f1=[$4])\n"
+ "LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], NUM_DEPT_GROUPS=[CAST($4):BIGINT NOT NULL])\n"
+ " LogicalJoin(condition=[=($0, $3)], joinType=[left])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n"
Expand Down Expand Up @@ -1180,7 +1181,7 @@ public static Frameworks.ConfigBuilder config() {
program.run(cluster.getPlanner(), parsedRel, cluster.traitSet(),
Collections.emptyList(), Collections.emptyList());
final String planOriginal = ""
+ "LogicalProject(EXPR$0=[ROW($8, $1)])\n"
+ "LogicalProject(EXPR$0=[ROW(CAST($8):TINYINT NOT NULL, $1)])\n"
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n"
Expand All @@ -1192,7 +1193,7 @@ public static Frameworks.ConfigBuilder config() {
// Default decorrelate
final RelNode decorrelatedDefault = RelDecorrelator.decorrelateQuery(original, builder);
final String planDecorrelatedDefault = ""
+ "LogicalProject(EXPR$0=[ROW($8, $1)])\n"
+ "LogicalProject(EXPR$0=[ROW(CAST($8):TINYINT NOT NULL, $1)])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], DEPTNO8=[$8])\n"
+ " LogicalJoin(condition=[=($8, $7)], joinType=[left])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
Expand Down Expand Up @@ -1221,7 +1222,7 @@ public static Frameworks.ConfigBuilder config() {
final RelNode decorrelatedNoRules =
RelDecorrelator.decorrelateQuery(original, builder, noRules);
final String planDecorrelatedNoRules = ""
+ "LogicalProject(EXPR$0=[ROW($9, $1)])\n"
+ "LogicalProject(EXPR$0=[ROW(CAST($9):TINYINT NOT NULL, $1)])\n"
+ " LogicalJoin(condition=[=($7, $8)], joinType=[left])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n"
Expand Down Expand Up @@ -1538,7 +1539,7 @@ public static Frameworks.ConfigBuilder config() {
final String planBefore = ""
+ "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
+ " LogicalFilter(condition=[=($0, CAST($8):SMALLINT)])\n"
+ " LogicalFilter(condition=[=($0, CAST($8):SMALLINT NOT NULL)])\n"
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{1}])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n"
Expand All @@ -1557,7 +1558,7 @@ public static Frameworks.ConfigBuilder config() {
final String planAfter = ""
+ "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n"
+ " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], ENAME0=[$8], $f1=[CAST($9):TINYINT])\n"
+ " LogicalJoin(condition=[AND(=($1, $8), =($0, CAST($9):SMALLINT))], joinType=[inner])\n"
+ " LogicalJoin(condition=[AND(=($1, $8), =($0, CAST($9):SMALLINT NOT NULL))], joinType=[inner])\n"
+ " LogicalTableScan(table=[[scott, EMP]])\n"
+ " LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n"
+ " LogicalProject(ENAME=[$3], DEPTNO=[$0])\n"
Expand Down Expand Up @@ -1700,7 +1701,7 @@ public static Frameworks.ConfigBuilder config() {
program.run(cluster.getPlanner(), originalRel, cluster.traitSet(),
Collections.emptyList(), Collections.emptyList());
final String planBefore = ""
+ "LogicalProject(DNAME=[$1], MATCHED_SUBORDINATE_COUNT=[$3])\n"
+ "LogicalProject(DNAME=[$1], MATCHED_SUBORDINATE_COUNT=[CAST($3):BIGINT NOT NULL])\n"
+ " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)])\n"
Expand All @@ -1717,7 +1718,7 @@ public static Frameworks.ConfigBuilder config() {
RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()),
RuleSets.ofList(Collections.emptyList()));
final String planAfter = ""
+ "LogicalProject(DNAME=[$1], MATCHED_SUBORDINATE_COUNT=[$4])\n"
+ "LogicalProject(DNAME=[$1], MATCHED_SUBORDINATE_COUNT=[CAST($4):BIGINT NOT NULL])\n"
+ " LogicalJoin(condition=[=($0, $3)], joinType=[left])\n"
+ " LogicalTableScan(table=[[scott, DEPT]])\n"
+ " LogicalProject(_cor_$cor0_0=[$0], EXPR$0=[CASE(IS NOT NULL($2), $2, 0)])\n"
Expand Down Expand Up @@ -2332,7 +2333,7 @@ public static Frameworks.ConfigBuilder config() {

// Verify decorrelation produced a valid plan (no Correlate nodes)
final String planAfter = ""
+ "LogicalProject(DEPTNO=[$0], EXPR$1=[$4])\n"
+ "LogicalProject(DEPTNO=[$0], EXPR$1=[CAST($4):BIGINT NOT NULL])\n"
+ " LogicalJoin(condition=[AND(IS NOT DISTINCT FROM($0, $2), IS NOT DISTINCT FROM($1, $3))], joinType=[left])\n"
+ " LogicalAggregate(group=[{0}], S=[SUM($1)])\n"
+ " LogicalProject(DEPTNO=[$7], SAL=[$5])\n"
Expand Down
Loading
Loading