diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateMinMaxToLimitRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateMinMaxToLimitRule.java index 51c29961769b..cd6d5b03f711 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateMinMaxToLimitRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateMinMaxToLimitRule.java @@ -21,6 +21,7 @@ import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.core.Aggregate; import org.apache.calcite.rel.core.AggregateCall; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rex.RexNode; import org.apache.calcite.rex.RexUtil; import org.apache.calcite.sql.SqlKind; @@ -99,6 +100,10 @@ protected AggregateMinMaxToLimitRule(Config config) { isDesc ? builder.desc(r) : r) .build()); + final RelDataType aggCallType = aggCall.getType(); + if (!subQuery.getType().equals(aggCallType)) { + subQuery = builder.getRexBuilder().makeCast(aggCallType, subQuery, false, false); + } newProjects.add(subQuery); } diff --git a/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java b/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java index f7a6ce552fe5..f49a778fb793 100644 --- a/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java +++ b/core/src/main/java/org/apache/calcite/rel/rules/SubQueryRemoveRule.java @@ -141,7 +141,11 @@ private static RexNode rewriteScalarQuery(RexSubQuery e, Set vari builder.field(0))); } builder.join(JoinRelType.LEFT, builder.literal(true), variablesSet); - return field(builder, inputCount, offset); + final RexNode ref = field(builder, inputCount, offset); + if (ref.getType().equals(e.getType())) { + return ref; + } + return builder.getRexBuilder().makeCast(e.getType(), ref, false, false); } /** diff --git a/core/src/main/java/org/apache/calcite/rex/RexSubQuery.java b/core/src/main/java/org/apache/calcite/rex/RexSubQuery.java index 5a9cad2a970b..c3363b9f81c6 100644 --- a/core/src/main/java/org/apache/calcite/rex/RexSubQuery.java +++ b/core/src/main/java/org/apache/calcite/rex/RexSubQuery.java @@ -108,15 +108,19 @@ public static RexSubQuery unique(RelNode rel) { ImmutableList.of(), rel); } - /** Creates a scalar sub-query. */ + /** Creates a scalar sub-query. + * + *

The expression's type is the single output column's type, as returned by + * {@link RelDataTypeFactory#copyType(RelDataType)} + * (see [CALCITE-2901]). + */ public static RexSubQuery scalar(RelNode rel) { final List fieldList = rel.getRowType().getFieldList(); if (fieldList.size() != 1) { throw new IllegalArgumentException(); } final RelDataTypeFactory typeFactory = rel.getCluster().getTypeFactory(); - final RelDataType type = - typeFactory.createTypeWithNullability(fieldList.get(0).getType(), true); + final RelDataType type = typeFactory.copyType(fieldList.get(0).getType()); return new RexSubQuery(type, SqlStdOperatorTable.SCALAR_QUERY, ImmutableList.of(), rel); } diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java index 9e32e0563d96..fcafdd0b4ccb 100644 --- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java +++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java @@ -1522,7 +1522,7 @@ private static RelDataType multivalentStringWithSepSumPrecision( } final RelDataType firstColType = fieldType.getType(); final RelDataTypeFactory typeFactory = opBinding.getTypeFactory(); - return typeFactory.createTypeWithNullability(firstColType, true); + return typeFactory.copyType(firstColType); }; /** diff --git a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java index 1e2b29880864..c2f47ca1977c 100644 --- a/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java +++ b/core/src/main/java/org/apache/calcite/sql/validate/SqlValidatorImpl.java @@ -5628,8 +5628,8 @@ private void handleScalarSubQuery(SqlSelect parentSelect, assert type instanceof RelRecordType; RelRecordType rec = (RelRecordType) type; - RelDataType nodeType = rec.getFieldList().get(0).getType(); - nodeType = typeFactory.createTypeWithNullability(nodeType, true); + RelDataType nodeType = + typeFactory.copyType(rec.getFieldList().get(0).getType()); fieldList.add(alias, nodeType); } diff --git a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java index 4d14e16937ef..53ca6432a92a 100644 --- a/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java +++ b/core/src/main/java/org/apache/calcite/sql2rel/SqlToRelConverter.java @@ -5920,6 +5920,16 @@ && isConvertedSubq(rex)) { SqlStdOperatorTable.IS_NOT_NULL, fieldAccess); } + if (kind == SqlKind.SCALAR_QUERY) { + fieldAccess = + StandardConvertletTable.castToValidatedType( + expr.getParserPosition(), + expr, + fieldAccess, + validator(), + rexBuilder, + false); + } return fieldAccess; case OVER: diff --git a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java index 4e2825bb3b62..3a343bd6b653 100644 --- a/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java +++ b/core/src/test/java/org/apache/calcite/rel/rel2sql/RelToSqlConverterTest.java @@ -6617,31 +6617,31 @@ private void checkLiteral2(String expression, String expected) { + "UNION ALL\n" + "SELECT NULL) END AS `$f0`\n" + "FROM `foodmart`.`product`) AS `t0` ON TRUE\n" - + "WHERE `product`.`net_weight` > CAST(`t0`.`$f0` AS DOUBLE)"; + + "WHERE `product`.`net_weight` > CAST(CAST(`t0`.`$f0` AS SIGNED) AS DOUBLE)"; final String expectedPostgresql = "SELECT \"product\".\"product_class_id\" AS \"C\"\n" + "FROM \"foodmart\".\"product\"\n" + "LEFT JOIN (SELECT CASE COUNT(*) WHEN 0 THEN NULL WHEN 1 THEN MIN(\"product_class_id\") ELSE (SELECT CAST(NULL AS INTEGER)\n" + "UNION ALL\n" + "SELECT CAST(NULL AS INTEGER)) END AS \"$f0\"\n" + "FROM \"foodmart\".\"product\") AS \"t0\" ON TRUE\n" - + "WHERE \"product\".\"net_weight\" > CAST(\"t0\".\"$f0\" AS DOUBLE PRECISION)"; + + "WHERE \"product\".\"net_weight\" > CAST(CAST(\"t0\".\"$f0\" AS INTEGER) AS DOUBLE PRECISION)"; final String expectedHsqldb = "SELECT product.product_class_id AS C\n" + "FROM foodmart.product\n" + "LEFT JOIN (SELECT CASE COUNT(*) WHEN 0 THEN NULL WHEN 1 THEN MIN(product_class_id) ELSE ((VALUES 0E0)\n" + "UNION ALL\n" + "(VALUES 0E0)) END AS $f0\n" + "FROM foodmart.product) AS t0 ON TRUE\n" - + "WHERE product.net_weight > CAST(t0.$f0 AS DOUBLE)"; + + "WHERE product.net_weight > CAST(CAST(t0.$f0 AS INTEGER) AS DOUBLE)"; final String expectedSpark = "SELECT `product`.`product_class_id` `C`\n" + "FROM `foodmart`.`product`\n" + "LEFT JOIN (SELECT CASE COUNT(*) WHEN 0 THEN NULL WHEN 1 THEN MIN(`product_class_id`) ELSE RAISE_ERROR('more than one value in agg SINGLE_VALUE') END `$f0`\n" + "FROM `foodmart`.`product`) `t0` ON TRUE\n" - + "WHERE `product`.`net_weight` > CAST(`t0`.`$f0` AS DOUBLE)"; + + "WHERE `product`.`net_weight` > CAST(CAST(`t0`.`$f0` AS INTEGER) AS DOUBLE)"; final String expectedHive = "SELECT `product`.`product_class_id` `C`\n" + "FROM `foodmart`.`product`\n" + "LEFT JOIN (SELECT CASE COUNT(*) WHEN 0 THEN NULL WHEN 1 THEN MIN(`product_class_id`) ELSE ASSERT_TRUE(FALSE) END `$f0`\n" + "FROM `foodmart`.`product`) `t0` ON TRUE\n" - + "WHERE `product`.`net_weight` > CAST(`t0`.`$f0` AS DOUBLE)"; + + "WHERE `product`.`net_weight` > CAST(CAST(`t0`.`$f0` AS INT) AS DOUBLE)"; sql(query) .withConfig(c -> c.withExpand(true)) .withMysql().ok(expectedMysql) diff --git a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java index 240f9a36ae02..690a2bd91473 100644 --- a/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java +++ b/core/src/test/java/org/apache/calcite/sql2rel/RelDecorrelatorTest.java @@ -339,7 +339,7 @@ public static Frameworks.ConfigBuilder config() { program.run(cluster.getPlanner(), originalRel, cluster.traitSet(), Collections.emptyList(), Collections.emptyList()); final String planBefore = "" - + "LogicalProject(DEPTNO=[$0], A=[$3])\n" + + "LogicalProject(DEPTNO=[$0], A=[CAST($3):CHAR(7) NOT NULL])\n" + " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" + " LogicalProject(EXPR=[CASE(>($0, 10.00), 'VIP ', 'Regular')])\n" @@ -356,7 +356,7 @@ public static Frameworks.ConfigBuilder config() { // Verify plan final String planAfter = "" - + "LogicalProject(DEPTNO=[$0], A=[$3])\n" + + "LogicalProject(DEPTNO=[$0], A=[CAST($3):CHAR(7) NOT NULL])\n" + " LogicalJoin(condition=[=($0, $4)], joinType=[left])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" + " LogicalProject(EXPR=[CASE(>($2, 10.00), 'VIP ', 'Regular')], DEPTNO=[$0])\n" @@ -500,7 +500,7 @@ public static Frameworks.ConfigBuilder config() { Collections.emptyList(), Collections.emptyList()); final String planBefore = "" + "LogicalSort(sort0=[$0], dir0=[ASC])\n" - + " LogicalProject(DNAME=[$1], C=[$3])\n" + + " LogicalProject(DNAME=[$1], C=[CAST($3):BIGINT NOT NULL])\n" + " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" + " LogicalAggregate(group=[{}], EXPR$0=[COUNT()])\n" @@ -556,7 +556,7 @@ public static Frameworks.ConfigBuilder config() { // LogicalTableScan(table=[[scott, EMP]]) final String planAfter = "" + "LogicalSort(sort0=[$0], dir0=[ASC])\n" - + " LogicalProject(DNAME=[$1], C=[$7])\n" + + " LogicalProject(DNAME=[$1], C=[CAST($7):BIGINT NOT NULL])\n" + " LogicalJoin(condition=[AND(=($0, $5), =($4, $6))], joinType=[left])\n" + " LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], DEPTNO0=[$0], $f4=[*($0, 100)])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" @@ -627,14 +627,15 @@ public static Frameworks.ConfigBuilder config() { program.run(cluster.getPlanner(), originalRel, cluster.traitSet(), Collections.emptyList(), Collections.emptyList()); final String planBefore = "" - + "LogicalCorrelate(correlation=[$cor1], joinType=[left], requiredColumns=[{0}])\n" - + " LogicalTableScan(table=[[scott, DEPT]])\n" - + " LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n" - + " LogicalProject(EXPR$0=[$1])\n" - + " LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])\n" - + " LogicalProject($f0=[$cor1.DEPTNO])\n" - + " LogicalFilter(condition=[=($7, $cor1.DEPTNO)])\n" - + " LogicalTableScan(table=[[scott, EMP]])\n"; + + "LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], NUM_DEPT_GROUPS=[CAST($3):BIGINT NOT NULL])\n" + + " LogicalCorrelate(correlation=[$cor1], joinType=[left], requiredColumns=[{0}])\n" + + " LogicalTableScan(table=[[scott, DEPT]])\n" + + " LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n" + + " LogicalProject(EXPR$0=[$1])\n" + + " LogicalAggregate(group=[{0}], EXPR$0=[COUNT()])\n" + + " LogicalProject($f0=[$cor1.DEPTNO])\n" + + " LogicalFilter(condition=[=($7, $cor1.DEPTNO)])\n" + + " LogicalTableScan(table=[[scott, EMP]])\n"; assertThat(before, hasTree(planBefore)); // Decorrelate without any rules, just "purely" decorrelation algorithm on RelDecorrelator @@ -643,7 +644,7 @@ public static Frameworks.ConfigBuilder config() { RuleSets.ofList(Collections.emptyList())); // Verify plan final String planAfter = "" - + "LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], $f1=[$4])\n" + + "LogicalProject(DEPTNO=[$0], DNAME=[$1], LOC=[$2], NUM_DEPT_GROUPS=[CAST($4):BIGINT NOT NULL])\n" + " LogicalJoin(condition=[=($0, $3)], joinType=[left])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" + " LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n" @@ -1180,7 +1181,7 @@ public static Frameworks.ConfigBuilder config() { program.run(cluster.getPlanner(), parsedRel, cluster.traitSet(), Collections.emptyList(), Collections.emptyList()); final String planOriginal = "" - + "LogicalProject(EXPR$0=[ROW($8, $1)])\n" + + "LogicalProject(EXPR$0=[ROW(CAST($8):TINYINT NOT NULL, $1)])\n" + " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{7}])\n" + " LogicalTableScan(table=[[scott, EMP]])\n" + " LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n" @@ -1192,7 +1193,7 @@ public static Frameworks.ConfigBuilder config() { // Default decorrelate final RelNode decorrelatedDefault = RelDecorrelator.decorrelateQuery(original, builder); final String planDecorrelatedDefault = "" - + "LogicalProject(EXPR$0=[ROW($8, $1)])\n" + + "LogicalProject(EXPR$0=[ROW(CAST($8):TINYINT NOT NULL, $1)])\n" + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], DEPTNO8=[$8])\n" + " LogicalJoin(condition=[=($8, $7)], joinType=[left])\n" + " LogicalTableScan(table=[[scott, EMP]])\n" @@ -1221,7 +1222,7 @@ public static Frameworks.ConfigBuilder config() { final RelNode decorrelatedNoRules = RelDecorrelator.decorrelateQuery(original, builder, noRules); final String planDecorrelatedNoRules = "" - + "LogicalProject(EXPR$0=[ROW($9, $1)])\n" + + "LogicalProject(EXPR$0=[ROW(CAST($9):TINYINT NOT NULL, $1)])\n" + " LogicalJoin(condition=[=($7, $8)], joinType=[left])\n" + " LogicalTableScan(table=[[scott, EMP]])\n" + " LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n" @@ -1538,7 +1539,7 @@ public static Frameworks.ConfigBuilder config() { final String planBefore = "" + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n" + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n" - + " LogicalFilter(condition=[=($0, CAST($8):SMALLINT)])\n" + + " LogicalFilter(condition=[=($0, CAST($8):SMALLINT NOT NULL)])\n" + " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{1}])\n" + " LogicalTableScan(table=[[scott, EMP]])\n" + " LogicalAggregate(group=[{}], agg#0=[SINGLE_VALUE($0)])\n" @@ -1557,7 +1558,7 @@ public static Frameworks.ConfigBuilder config() { final String planAfter = "" + "LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7])\n" + " LogicalProject(EMPNO=[$0], ENAME=[$1], JOB=[$2], MGR=[$3], HIREDATE=[$4], SAL=[$5], COMM=[$6], DEPTNO=[$7], ENAME0=[$8], $f1=[CAST($9):TINYINT])\n" - + " LogicalJoin(condition=[AND(=($1, $8), =($0, CAST($9):SMALLINT))], joinType=[inner])\n" + + " LogicalJoin(condition=[AND(=($1, $8), =($0, CAST($9):SMALLINT NOT NULL))], joinType=[inner])\n" + " LogicalTableScan(table=[[scott, EMP]])\n" + " LogicalAggregate(group=[{0}], agg#0=[SINGLE_VALUE($1)])\n" + " LogicalProject(ENAME=[$3], DEPTNO=[$0])\n" @@ -1700,7 +1701,7 @@ public static Frameworks.ConfigBuilder config() { program.run(cluster.getPlanner(), originalRel, cluster.traitSet(), Collections.emptyList(), Collections.emptyList()); final String planBefore = "" - + "LogicalProject(DNAME=[$1], MATCHED_SUBORDINATE_COUNT=[$3])\n" + + "LogicalProject(DNAME=[$1], MATCHED_SUBORDINATE_COUNT=[CAST($3):BIGINT NOT NULL])\n" + " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" + " LogicalAggregate(group=[{}], EXPR$0=[COUNT($0)])\n" @@ -1717,7 +1718,7 @@ public static Frameworks.ConfigBuilder config() { RelDecorrelator.decorrelateQuery(before, builder, RuleSets.ofList(Collections.emptyList()), RuleSets.ofList(Collections.emptyList())); final String planAfter = "" - + "LogicalProject(DNAME=[$1], MATCHED_SUBORDINATE_COUNT=[$4])\n" + + "LogicalProject(DNAME=[$1], MATCHED_SUBORDINATE_COUNT=[CAST($4):BIGINT NOT NULL])\n" + " LogicalJoin(condition=[=($0, $3)], joinType=[left])\n" + " LogicalTableScan(table=[[scott, DEPT]])\n" + " LogicalProject(_cor_$cor0_0=[$0], EXPR$0=[CASE(IS NOT NULL($2), $2, 0)])\n" @@ -2332,7 +2333,7 @@ public static Frameworks.ConfigBuilder config() { // Verify decorrelation produced a valid plan (no Correlate nodes) final String planAfter = "" - + "LogicalProject(DEPTNO=[$0], EXPR$1=[$4])\n" + + "LogicalProject(DEPTNO=[$0], EXPR$1=[CAST($4):BIGINT NOT NULL])\n" + " LogicalJoin(condition=[AND(IS NOT DISTINCT FROM($0, $2), IS NOT DISTINCT FROM($1, $3))], joinType=[left])\n" + " LogicalAggregate(group=[{0}], S=[SUM($1)])\n" + " LogicalProject(DEPTNO=[$7], SAL=[$5])\n" diff --git a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java index e53f1027f78d..3f4f5511251a 100644 --- a/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java +++ b/core/src/test/java/org/apache/calcite/test/RelBuilderTest.java @@ -53,6 +53,7 @@ import org.apache.calcite.rex.RexFieldCollation; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; +import org.apache.calcite.rex.RexSubQuery; import org.apache.calcite.rex.RexWindowBounds; import org.apache.calcite.runtime.CalciteException; import org.apache.calcite.schema.SchemaPlus; @@ -76,6 +77,7 @@ import org.apache.calcite.sql.type.SqlOperandMetadata; import org.apache.calcite.sql.type.SqlTypeFamily; import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; import org.apache.calcite.sql.validate.SqlUserDefinedTableFunction; import org.apache.calcite.test.schemata.hr.HrSchema; import org.apache.calcite.tools.Frameworks; @@ -968,6 +970,102 @@ private void project1(int value, SqlTypeName sqlTypeName, String message, String + " LogicalValues(tuples=[[]])\n"); } + /** Test case for + * [CALCITE-2901] + * RexSubQuery.scalar needs to allow specifying a different nullability value + * instead of the hard coded "true" value. */ + @Test void testScalarSubQueryPreservesColumnNullability() { + final RelDataTypeFactory tf = + RelBuilder.create(config().build()).getCluster().getTypeFactory(); + + // INTEGER NOT NULL literal row (historical bug: scalar type was forced nullable). + final RelBuilder b0 = RelBuilder.create(config().build()); + assertScalarSubQueryTypeMatchesColumn(b0.values(new String[]{"x"}, 42).build()); + + // Nullable INTEGER explicit row type. + final RelBuilder b1 = RelBuilder.create(config().build()); + final RelDataType nullableInt = + tf.createTypeWithNullability(tf.createSqlType(SqlTypeName.INTEGER), true); + final RelDataType nullableRow = tf.builder().add("y", nullableInt).build(); + assertScalarSubQueryTypeMatchesColumn(b1.values(nullableRow, 1).build()); + + // DECIMAL(10, 2) NOT NULL — precision and scale must survive copyType. + final RelBuilder b2 = RelBuilder.create(config().build()); + final RelDataType dec = + tf.createTypeWithNullability(tf.createSqlType(SqlTypeName.DECIMAL, 10, 2), false); + final RelDataType decRow = tf.builder().add("d", dec).build(); + assertScalarSubQueryTypeMatchesColumn( + b2.values(decRow, new BigDecimal("123.45")).build()); + + // CHAR(5) NOT NULL. + final RelBuilder b3 = RelBuilder.create(config().build()); + final RelDataType ch = + tf.createTypeWithNullability(tf.createSqlType(SqlTypeName.CHAR, 5), false); + final RelDataType chRow = tf.builder().add("c", ch).build(); + assertScalarSubQueryTypeMatchesColumn(b3.values(chRow, "abcde").build()); + + // VARCHAR(10) NULLABLE with a non-null cell — column type stays nullable. + final RelBuilder b4 = RelBuilder.create(config().build()); + final RelDataType vc = + tf.createTypeWithNullability(tf.createSqlType(SqlTypeName.VARCHAR, 10), true); + final RelDataType vcRow = tf.builder().add("s", vc).build(); + assertScalarSubQueryTypeMatchesColumn(b4.values(vcRow, "abc").build()); + + // BOOLEAN NOT NULL literal. + final RelBuilder b5 = RelBuilder.create(config().build()); + assertScalarSubQueryTypeMatchesColumn(b5.values(new String[]{"f"}, true).build()); + + // REAL NOT NULL. + final RelBuilder b6 = RelBuilder.create(config().build()); + final RelDataType realNn = + tf.createTypeWithNullability(tf.createSqlType(SqlTypeName.REAL), false); + final RelDataType realRow = tf.builder().add("r", realNn).build(); + assertScalarSubQueryTypeMatchesColumn(b6.values(realRow, 1.25F).build()); + + // Zero-row VALUES: only declared row type exists; nullability must still match. + final RelBuilder b7 = RelBuilder.create(config().build()); + final RelDataType nnInt = + tf.createTypeWithNullability(tf.createSqlType(SqlTypeName.INTEGER), false); + final RelDataType emptyRow = tf.builder().add("k", nnInt).build(); + assertScalarSubQueryTypeMatchesColumn(b7.values(emptyRow).build()); + + // Single column projected from a catalog table (EMPNO is NOT NULL in SCOTT). + final RelBuilder b8 = RelBuilder.create(config().build()); + assertScalarSubQueryTypeMatchesColumn( + b8.scan("EMP").project(b8.field("EMPNO")).build()); + } + + /** Companion to {@link #testScalarSubQueryPreservesColumnNullability}: arity validation. */ + @Test void testScalarSubQueryRequiresExactlyOneOutputColumn() { + final RelBuilder b = RelBuilder.create(config().build()); + final RelNode twoCols = + b.scan("EMP").project(b.field("EMPNO"), b.field("ENAME")).build(); + final IllegalArgumentException ex2 = + assertThrows(IllegalArgumentException.class, () -> RexSubQuery.scalar(twoCols)); + assertThat(ex2.getMessage(), nullValue()); + + final RelBuilder b2 = RelBuilder.create(config().build()); + final RelNode twoLiteralCols = b2.values(new String[]{"a", "b"}, 1, 2).build(); + final IllegalArgumentException exLit = + assertThrows(IllegalArgumentException.class, () -> RexSubQuery.scalar(twoLiteralCols)); + assertThat(exLit.getMessage(), nullValue()); + + final RelBuilder b3 = RelBuilder.create(config().build()); + final RelNode threeCols = + b3.scan("EMP").project(b3.field("EMPNO"), b3.field("ENAME"), b3.field("JOB")).build(); + assertThrows(IllegalArgumentException.class, () -> RexSubQuery.scalar(threeCols)); + } + + /** Asserts {@link RexSubQuery#scalar} matches the lone column type (CALCITE-2901). */ + private static void assertScalarSubQueryTypeMatchesColumn(RelNode singleColumnRel) { + final RelDataType columnType = + singleColumnRel.getRowType().getFieldList().get(0).getType(); + final RelDataType scalarType = RexSubQuery.scalar(singleColumnRel).getType(); + final RelDataTypeFactory tf = singleColumnRel.getCluster().getTypeFactory(); + assertThat(SqlTypeUtil.equalSansNullability(tf, columnType, scalarType), is(true)); + assertThat(scalarType.isNullable(), is(columnType.isNullable())); + } + @Test void testProjectBloat() { final Function f = b -> b.scan("EMP") diff --git a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java index 9bdd3177ff29..4c1a836f28df 100644 --- a/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java +++ b/core/src/test/java/org/apache/calcite/test/SqlValidatorTest.java @@ -555,7 +555,7 @@ static SqlOperatorTable operatorTableFor(SqlLibrary library) { @Test void testBinaryStringFails() { expr("select x'ffee'='abc' from (values(true))") - .columnType("BOOLEAN"); + .columnType("BOOLEAN NOT NULL"); sql("select ^x'ffee'='abc'^ from (values(true))") .withTypeCoercion(false) .fails("(?s).*Cannot apply '=' to arguments of type " @@ -2134,7 +2134,7 @@ void testLikeAndSimilarFails() { .columnType("RecordType(INTEGER NOT NULL EXPR$0, VARCHAR(20) NOT NULL EXPR$1) NOT NULL"); sql("select row((select deptno from dept where dept.deptno = emp.deptno), emp.ename)\n" + "from emp") - .columnType("RecordType(INTEGER EXPR$0, VARCHAR(20) NOT NULL EXPR$1) NOT NULL"); + .columnType("RecordType(INTEGER NOT NULL EXPR$0, VARCHAR(20) NOT NULL EXPR$1) NOT NULL"); sql("select ROW^(x'12') <> ROW(0.01)^") .fails("Cannot apply '<>' to arguments of type.*"); // Test cases for [CALCITE-6911] https://issues.apache.org/jira/browse/CALCITE-6911 @@ -9980,18 +9980,17 @@ void testGroupExpressionEquivalenceParams() { + "INTEGER HISAL\\)>\\)'\\. Supported form\\(s\\): " + "'\\$SCALAR_QUERY\\(\\)'"); - // Note that X is a field (not a record) and is nullable even though - // EMP.NAME is NOT NULL. + // Scalar subquery type copies the select column type (including nullability). sql("SELECT ename,(select name from dept where deptno=1) FROM emp") - .type("RecordType(VARCHAR(20) NOT NULL ENAME, VARCHAR(10) EXPR$1) NOT NULL"); + .type("RecordType(VARCHAR(20) NOT NULL ENAME, VARCHAR(10) NOT NULL EXPR$1) NOT NULL"); // scalar subqery inside AS operator sql("SELECT ename,(select name from dept where deptno=1) as X FROM emp") - .type("RecordType(VARCHAR(20) NOT NULL ENAME, VARCHAR(10) X) NOT NULL"); + .type("RecordType(VARCHAR(20) NOT NULL ENAME, VARCHAR(10) NOT NULL X) NOT NULL"); // scalar subqery inside + operator sql("SELECT ename, 1 + (select deptno from dept where deptno=1) as X FROM emp") - .type("RecordType(VARCHAR(20) NOT NULL ENAME, INTEGER X) NOT NULL"); + .type("RecordType(VARCHAR(20) NOT NULL ENAME, INTEGER NOT NULL X) NOT NULL"); // scalar sub-query inside WHERE sql("select * from emp where (select true from dept)").ok(); diff --git a/core/src/test/java/org/apache/calcite/test/TableFunctionTest.java b/core/src/test/java/org/apache/calcite/test/TableFunctionTest.java index e34845c1e650..2bc64c9f0820 100644 --- a/core/src/test/java/org/apache/calcite/test/TableFunctionTest.java +++ b/core/src/test/java/org/apache/calcite/test/TableFunctionTest.java @@ -123,7 +123,7 @@ private CalciteAssert.AssertThat with() { + "from (values (2), (4)) as t (x)"; ResultSet resultSet = connection.createStatement().executeQuery(sql); assertThat(CalciteAssert.toString(resultSet), - equalTo("X=2; EXPR$1=null\nX=4; EXPR$1=null\n")); + equalTo("X=2; EXPR$1=0\nX=4; EXPR$1=0\n")); } } diff --git a/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java b/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java index e4f521d0bb2d..d2847355f81e 100644 --- a/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java +++ b/core/src/test/java/org/apache/calcite/test/TypeCoercionTest.java @@ -219,12 +219,12 @@ private static ImmutableList combine( sql("select '1' from (values(true)) union values 2") .type("RecordType(VARCHAR NOT NULL EXPR$0) NOT NULL"); sql("select (select 1+2 from (values true)) tt from (values(true)) union values '2'") - .type("RecordType(VARCHAR TT) NOT NULL"); + .type("RecordType(VARCHAR NOT NULL TT) NOT NULL"); // union with star sql("select * from (values(1, '3')) union select * from (values('2', 4))") .type("RecordType(VARCHAR NOT NULL EXPR$0, VARCHAR NOT NULL EXPR$1) NOT NULL"); sql("select 1 from (values(true)) union values (select '1' from (values (true)) as tt)") - .type("RecordType(VARCHAR EXPR$0) NOT NULL"); + .type("RecordType(VARCHAR NOT NULL EXPR$0) NOT NULL"); // union with func sql("select LOCALTIME from (values(true)) union values '1'") .type("RecordType(VARCHAR NOT NULL LOCALTIME) NOT NULL"); @@ -702,16 +702,16 @@ private static ImmutableList combine( .fails("(?s).*Cannot apply.*"); // smallint int double expr("select t1_smallint||t1_int||t1_double from t1") - .columnType("VARCHAR"); + .columnType("VARCHAR NOT NULL"); // boolean float smallint expr("select t1_boolean||t1_real||t1_smallint from t1") - .columnType("VARCHAR"); + .columnType("VARCHAR NOT NULL"); // decimal expr("select t1_decimal||t1_varchar20 from t1") - .columnType("VARCHAR"); + .columnType("VARCHAR NOT NULL"); // date timestamp expr("select t1_timestamp||t1_date from t1") - .columnType("VARCHAR"); + .columnType("VARCHAR NOT NULL"); } /** Test case for {@link TypeCoercion#querySourceCoercion}. */ diff --git a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml index 29e539a0f7b5..b5b181dc361c 100644 --- a/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml +++ b/core/src/test/resources/org/apache/calcite/test/RelOptRulesTest.xml @@ -915,15 +915,15 @@ LogicalAggregate(group=[{}], EXPR$0=[MIN($0)], EXPR$1=[MAX($0)]) @@ -2197,7 +2197,7 @@ LogicalAggregate(group=[{}], EXPR$0=[MAX($0)]) (10, $2)]) + LogicalFilter(condition=[>(10, CAST($2):INTEGER NOT NULL)]) LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}]) LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) LogicalSort(sort0=[$0], dir0=[ASC], fetch=[1]) @@ -2649,7 +2649,7 @@ LogicalSort(sort0=[$0], dir0=[DESC-nulls-last], fetch=[1]) (10, $2)]) + LogicalFilter(condition=[>(10, CAST($2):INTEGER NOT NULL)]) LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}]) LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) LogicalSort(sort0=[$0], dir0=[DESC-nulls-last], fetch=[1]) @@ -2691,7 +2691,7 @@ LogicalSort(sort0=[$0], dir0=[DESC], fetch=[1]) (10, $2)]) + LogicalFilter(condition=[>(10, CAST($2):INTEGER NOT NULL)]) LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}]) LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) LogicalSort(sort0=[$0], dir0=[DESC], fetch=[1]) @@ -2734,7 +2734,7 @@ LogicalProject(SAL=[$0]) (10, $2)]) + LogicalFilter(condition=[>(10, CAST($2):INTEGER NOT NULL)]) LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}]) LogicalTableScan(table=[[CATALOG, SALES, DEPT]]) LogicalProject(SAL=[$0]) @@ -2775,7 +2775,7 @@ LogicalSort(sort0=[$0], dir0=[ASC], fetch=[1]) ($5, 1000)]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) @@ -7217,7 +7217,7 @@ QUALIFY rank_val = (SELECT COUNT(*) FROM emp)]]> ($7, $9)]) + LogicalFilter(condition=[>($7, CAST($9):INTEGER NOT NULL)]) LogicalJoin(condition=[true], joinType=[left]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) LogicalSort(sort0=[$0], dir0=[ASC], fetch=[1]) @@ -8550,7 +8550,7 @@ where deptno > (values 10)]]> ($7, $9)]) + LogicalFilter(condition=[>($7, CAST($9):INTEGER NOT NULL)]) LogicalJoin(condition=[true], joinType=[left]) LogicalTableScan(table=[[CATALOG, SALES, EMP]]) LogicalValues(tuples=[[{ 10 }]]) @@ -10352,7 +10352,7 @@ from emp]]> 10; !ok !if (use_old_decorr) { -EnumerableCalc(expr#0..3=[{inputs}], expr#4=[IS NULL($t3)], expr#5=[0:BIGINT], expr#6=[CASE($t4, $t5, $t3)], EMPNO=[$t0], $f1=[$t6]) +EnumerableCalc(expr#0..3=[{inputs}], expr#4=[IS NULL($t3)], expr#5=[0:BIGINT], expr#6=[CASE($t4, $t5, $t3)], expr#7=[CAST($t6):BIGINT NOT NULL], EMPNO=[$t0], $f1=[$t7]) EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($1, $2)], joinType=[left]) EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 2)], expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], EMPNO=[$t0], MGR=[$t3], $condition=[$t10]) EnumerableTableScan(table=[[scott, EMP]]) @@ -388,7 +388,7 @@ EnumerableCalc(expr#0..3=[{inputs}], expr#4=[IS NULL($t3)], expr#5=[0:BIGINT], e !} !if (use_new_decorr) { -EnumerableCalc(expr#0..3=[{inputs}], EMPNO=[$t0], $f1=[$t3]) +EnumerableCalc(expr#0..3=[{inputs}], expr#4=[CAST($t3):BIGINT NOT NULL], EMPNO=[$t0], $f1=[$t4]) EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($1, $2)], joinType=[left]) EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 2)], expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], EMPNO=[$t0], MGR=[$t3], $condition=[$t10]) EnumerableTableScan(table=[[scott, EMP]]) @@ -437,7 +437,7 @@ WHERE sal > 10; !ok !if (use_old_decorr) { -EnumerableCalc(expr#0..9=[{inputs}], expr#10=[IS NULL($t9)], expr#11=[0:BIGINT], expr#12=[CASE($t10, $t11, $t9)], EMPNO=[$t0], $f1=[$t12]) +EnumerableCalc(expr#0..9=[{inputs}], expr#10=[IS NULL($t9)], expr#11=[0:BIGINT], expr#12=[CASE($t10, $t11, $t9)], expr#13=[CAST($t12):BIGINT NOT NULL], EMPNO=[$t0], $f1=[$t13]) EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($3, $8)], joinType=[left]) EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 2)], expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], proj#0..7=[{exprs}], $condition=[$t10]) EnumerableTableScan(table=[[scott, EMP]]) @@ -456,7 +456,7 @@ EnumerableCalc(expr#0..9=[{inputs}], expr#10=[IS NULL($t9)], expr#11=[0:BIGINT], !} !if (use_new_decorr) { -EnumerableCalc(expr#0..10=[{inputs}], EMPNO=[$t0], $f1=[$t10]) +EnumerableCalc(expr#0..10=[{inputs}], expr#11=[CAST($t10):BIGINT NOT NULL], EMPNO=[$t0], $f1=[$t11]) EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($3, $8)], joinType=[left]) EnumerableCalc(expr#0..7=[{inputs}], expr#8=[CAST($t5):DECIMAL(12, 2)], expr#9=[10.00:DECIMAL(12, 2)], expr#10=[>($t8, $t9)], proj#0..7=[{exprs}], $condition=[$t10]) EnumerableTableScan(table=[[scott, EMP]]) diff --git a/core/src/test/resources/sql/sub-query.iq b/core/src/test/resources/sql/sub-query.iq index 4b699d14b85f..2f9a104c9bfd 100644 --- a/core/src/test/resources/sql/sub-query.iq +++ b/core/src/test/resources/sql/sub-query.iq @@ -4490,7 +4490,7 @@ select (select (1, 2)); !ok !if (use_old_decorr) { -EnumerableCalc(expr#0=[{inputs}], expr#1=[1], expr#2=[2], expr#3=[ROW($t1, $t2)], expr#4=[CAST($t3):RecordType(INTEGER EXPR$0, INTEGER EXPR$1)], EXPR$0=[$t4]) +EnumerableCalc(expr#0=[{inputs}], expr#1=[1], expr#2=[2], expr#3=[ROW($t1, $t2)], EXPR$0=[$t3]) EnumerableValues(tuples=[[{ 0 }]]) !plan !} @@ -5866,7 +5866,7 @@ select * from emp where deptno <> (select count(deptno) from dept where dept.dep !ok !if (use_old_decorr) { -EnumerableCalc(expr#0..4=[{inputs}], expr#5=[IS NULL($t4)], expr#6=[CAST($t1):BIGINT], expr#7=[0:BIGINT], expr#8=[<>($t6, $t7)], expr#9=[AND($t5, $t8)], expr#10=[<>($t6, $t4)], expr#11=[OR($t9, $t10)], proj#0..2=[{exprs}], $condition=[$t11]) +EnumerableCalc(expr#0..4=[{inputs}], expr#5=[CAST($t1):BIGINT], expr#6=[IS NULL($t4)], expr#7=[0:BIGINT], expr#8=[CASE($t6, $t7, $t4)], expr#9=[CAST($t8):BIGINT NOT NULL], expr#10=[<>($t5, $t9)], proj#0..2=[{exprs}], $condition=[$t10]) EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($1, $3)], joinType=[left]) EnumerableValues(tuples=[[{ 'Jane ', 10, 'F' }, { 'Bob ', 10, 'M' }, { 'Eric ', 20, 'M' }, { 'Susan', 30, 'F' }, { 'Alice', 30, 'F' }, { 'Adam ', 50, 'M' }, { 'Eve ', 50, 'F' }, { 'Grace', 60, 'F' }, { 'Wilma', null, 'F' }]]) EnumerableCalc(expr#0..2=[{inputs}], expr#3=[IS NOT NULL($t2)], expr#4=[0], expr#5=[CASE($t3, $t2, $t4)], DEPTNO=[$t0], EXPR$0=[$t5]) @@ -5894,7 +5894,7 @@ select * from emp where deptno <> (select count(deptno) + 10 from dept where de !ok !if (use_old_decorr) { -EnumerableCalc(expr#0..4=[{inputs}], expr#5=[CAST($t1):BIGINT], expr#6=[IS NULL($t4)], expr#7=[0:BIGINT], expr#8=[CASE($t6, $t7, $t4)], expr#9=[10], expr#10=[+($t8, $t9)], expr#11=[<>($t5, $t10)], proj#0..2=[{exprs}], $condition=[$t11]) +EnumerableCalc(expr#0..4=[{inputs}], expr#5=[CAST($t1):BIGINT], expr#6=[IS NULL($t4)], expr#7=[0:BIGINT], expr#8=[CASE($t6, $t7, $t4)], expr#9=[10], expr#10=[+($t8, $t9)], expr#11=[CAST($t10):BIGINT NOT NULL], expr#12=[<>($t5, $t11)], proj#0..2=[{exprs}], $condition=[$t12]) EnumerableHashJoin(condition=[IS NOT DISTINCT FROM($1, $3)], joinType=[left]) EnumerableValues(tuples=[[{ 'Jane ', 10, 'F' }, { 'Bob ', 10, 'M' }, { 'Eric ', 20, 'M' }, { 'Susan', 30, 'F' }, { 'Alice', 30, 'F' }, { 'Adam ', 50, 'M' }, { 'Eve ', 50, 'F' }, { 'Grace', 60, 'F' }, { 'Wilma', null, 'F' }]]) EnumerableCalc(expr#0=[{inputs}], expr#1=[0], expr#2=[CAST($t1):BIGINT NOT NULL], DEPTNO=[$t0], $f1=[$t2]) @@ -6217,7 +6217,7 @@ FROM dept; | 10 | ACCOUNTING | NEW YORK | 3 | | 20 | RESEARCH | DALLAS | 5 | | 30 | SALES | CHICAGO | 6 | -| 40 | OPERATIONS | BOSTON | | +| 40 | OPERATIONS | BOSTON | 0 | +--------+------------+----------+-----------------+ (4 rows) @@ -6242,7 +6242,7 @@ FROM dept; | 10 | ACCOUNTING | NEW YORK | 1 | | 20 | RESEARCH | DALLAS | 1 | | 30 | SALES | CHICAGO | 1 | -| 40 | OPERATIONS | BOSTON | | +| 40 | OPERATIONS | BOSTON | 0 | +--------+------------+----------+-----------------+ (4 rows) @@ -9660,7 +9660,7 @@ from emp as e; !ok !if (use_old_decorr) { -EnumerableCalc(expr#0..6=[{inputs}], expr#7=[IS NULL($t6)], expr#8=[0:BIGINT], expr#9=[CASE($t7, $t8, $t6)], ENAME=[$t1], C=[$t9]) +EnumerableCalc(expr#0..6=[{inputs}], expr#7=[IS NULL($t6)], expr#8=[0:BIGINT], expr#9=[CASE($t7, $t8, $t6)], expr#10=[CAST($t9):BIGINT NOT NULL], ENAME=[$t1], C=[$t10]) EnumerableHashJoin(condition=[AND(IS NOT DISTINCT FROM($2, $4), =($3, $5))], joinType=[left]) EnumerableCalc(expr#0..7=[{inputs}], expr#8=[IS NULL($t6)], proj#0..1=[{exprs}], COMM=[$t6], $f3=[$t8]) EnumerableTableScan(table=[[scott, EMP]]) diff --git a/plus/src/test/java/org/apache/calcite/adapter/tpch/TpchTest.java b/plus/src/test/java/org/apache/calcite/adapter/tpch/TpchTest.java index 39aa2f3f16e0..4b79a3b299f6 100644 --- a/plus/src/test/java/org/apache/calcite/adapter/tpch/TpchTest.java +++ b/plus/src/test/java/org/apache/calcite/adapter/tpch/TpchTest.java @@ -1054,7 +1054,7 @@ private CalciteAssert.AssertThat with() { program.run(cluster.getPlanner(), originalRel, cluster.traitSet(), Collections.emptyList(), Collections.emptyList()); final String planBefore = "" - + "LogicalProject(EXPR$0=[$5], EXPR$1=[$6])\n" + + "LogicalProject(EXPR$0=[CAST($5):BIGINT NOT NULL], EXPR$1=[CAST($6):BIGINT NOT NULL])\n" + " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n" + " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n" + " LogicalTableScan(table=[[TPCH, PARTSUPP]])\n" @@ -1072,7 +1072,7 @@ private CalciteAssert.AssertThat with() { RuleSets.ofList(Collections.emptyList()), RuleSets.ofList(Collections.emptyList())); final String planAfter = "" - + "LogicalProject(EXPR$0=[$5], EXPR$1=[$8])\n" + + "LogicalProject(EXPR$0=[CAST($5):BIGINT NOT NULL], EXPR$1=[CAST($8):BIGINT NOT NULL])\n" + " LogicalJoin(condition=[IS NOT DISTINCT FROM($6, $7)], joinType=[left])\n" + " LogicalProject(PS_PARTKEY=[$0], PS_SUPPKEY=[$1], PS_AVAILQTY=[$2], PS_SUPPLYCOST=[$3], PS_COMMENT=[$4], EXPR$0=[$6], $f6=[+($0, 1)])\n" + " LogicalJoin(condition=[IS NOT DISTINCT FROM($0, $5)], joinType=[left])\n" @@ -1169,7 +1169,7 @@ private CalciteAssert.AssertThat with() { final String planBefore = "" + "LogicalProject(C_CUSTKEY=[$0], C_NAME=[$1], C_ADDRESS=[$2], C_NATIONKEY=[$3], C_PHONE=[$4], C_ACCTBAL=[$5], C_MKTSEGMENT=[$6], C_COMMENT=[$7])\n" + " LogicalProject(C_CUSTKEY=[$0], C_NAME=[$1], C_ADDRESS=[$2], C_NATIONKEY=[$3], C_PHONE=[$4], C_ACCTBAL=[$5], C_MKTSEGMENT=[$6], C_COMMENT=[$7])\n" - + " LogicalFilter(condition=[AND(=(CAST($6):VARCHAR, 'AUTOMOBILE'), >($8, 5))])\n" + + " LogicalFilter(condition=[AND(=(CAST($6):VARCHAR, 'AUTOMOBILE'), >(CAST($8):BIGINT NOT NULL, 5))])\n" + " LogicalCorrelate(correlation=[$cor0], joinType=[left], requiredColumns=[{0}])\n" + " LogicalTableScan(table=[[TPCH, CUSTOMER]])\n" + " LogicalAggregate(group=[{}], EXPR$0=[COUNT()])\n" diff --git a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java index 080a704a8cc8..b5700d7991fd 100644 --- a/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java +++ b/testkit/src/main/java/org/apache/calcite/test/SqlOperatorTest.java @@ -3502,7 +3502,7 @@ static void checkOverlaps(OverlapChecker c) { final SqlOperatorFixture f0 = fixture(); // This query does not fail if checked arithmetic is not used f0.checkScalar("SELECT -CAST(-32768 AS SMALLINT)", - "-32768", "SMALLINT"); + "-32768", "SMALLINT NOT NULL"); // The last two queries should fail in any conformance level // because the value "32768" cannot be represented as a SMALLINT f0.checkFails("SELECT CAST(32768 AS SMALLINT)", @@ -16419,6 +16419,8 @@ void testTimestampDiff(boolean coercionEnabled) { * ANY/SOME, ALL operators should support collection expressions. */ @Test void testQuantifyCollectionOperators() { final SqlOperatorFixture f = fixture(); + final String someScalarQuantifyBooleanType = "BOOLEAN"; + final String allScalarQuantifyBooleanType = "BOOLEAN NOT NULL"; QUANTIFY_OPERATORS.forEach(operator -> f.setFor(operator, SqlOperatorFixture.VmName.EXPAND)); Function2 checkBoolean = (sql, result) -> { @@ -16470,7 +16472,7 @@ void testTimestampDiff(boolean coercionEnabled) { "BOOLEAN", isNullValue()); f.check("SELECT (SELECT * FROM UNNEST(ARRAY[3]) LIMIT 1) = " + "some(x.t) FROM (SELECT ARRAY[1,2,3,null] as t) as x", - "BOOLEAN", true); + someScalarQuantifyBooleanType, true); checkNull.apply("1 = all (COLLECTION[1,1,null])"); @@ -16505,7 +16507,7 @@ void testTimestampDiff(boolean coercionEnabled) { "BOOLEAN", isNullValue()); f.check("SELECT (SELECT * FROM UNNEST(ARRAY[3]) LIMIT 1) = " + "all(x.t) FROM (SELECT ARRAY[3,3] as t) as x", - "BOOLEAN", true); + allScalarQuantifyBooleanType, true); } @Test void testQuantifyOperatorsWithTypeException() {