diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/PreValidateReWriter.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/PreValidateReWriter.scala
index 8fc91d4a849a5..4f75fad4d7c9d 100644
--- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/PreValidateReWriter.scala
+++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/calcite/PreValidateReWriter.scala
@@ -21,8 +21,9 @@ package org.apache.flink.table.planner.calcite
import org.apache.flink.sql.parser.SqlProperty
import org.apache.flink.sql.parser.dml.RichSqlInsert
import org.apache.flink.table.api.ValidationException
-import org.apache.flink.table.planner.calcite.PreValidateReWriter.appendPartitionAndNullsProjects
+import org.apache.flink.table.planner.calcite.PreValidateReWriter.{appendPartitionAndNullsProjects, notSupported}
import org.apache.flink.table.planner.plan.schema.{CatalogSourceTable, FlinkPreparingTableBase, LegacyCatalogSourceTable}
+import org.apache.flink.util.Preconditions.checkArgument
import org.apache.calcite.plan.RelOptTable
import org.apache.calcite.prepare.CalciteCatalogReader
@@ -33,7 +34,7 @@ import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.sql.parser.SqlParserPos
import org.apache.calcite.sql.util.SqlBasicVisitor
import org.apache.calcite.sql.validate.{SqlValidatorException, SqlValidatorTable, SqlValidatorUtil}
-import org.apache.calcite.sql.{SqlCall, SqlIdentifier, SqlKind, SqlLiteral, SqlNode, SqlNodeList, SqlSelect, SqlUtil}
+import org.apache.calcite.sql.{SqlCall, SqlIdentifier, SqlKind, SqlLiteral, SqlNode, SqlNodeList, SqlOrderBy, SqlSelect, SqlUtil}
import org.apache.calcite.util.Static.RESOURCE
import java.util
@@ -50,16 +51,11 @@ class PreValidateReWriter(
call match {
case r: RichSqlInsert
if r.getStaticPartitions.nonEmpty || r.getTargetColumnList != null => r.getSource match {
- case select: SqlSelect =>
- appendPartitionAndNullsProjects(r, validator, typeFactory, select, r.getStaticPartitions)
- case values: SqlCall if values.getKind == SqlKind.VALUES =>
- val newSource = appendPartitionAndNullsProjects(r, validator, typeFactory, values,
- r.getStaticPartitions)
+ case call: SqlCall =>
+ val newSource = appendPartitionAndNullsProjects(
+ r, validator, typeFactory, call, r.getStaticPartitions)
r.setOperand(2, newSource)
- case source =>
- throw new ValidationException(
-              s"INSERT INTO <table> PARTITION [(COLUMN LIST)] statement only support "
- + s"SELECT and VALUES clause for now, '$source' is not supported yet.")
+ case source => throw new ValidationException(notSupported(source))
}
case _ =>
}
@@ -67,7 +63,14 @@ class PreValidateReWriter(
}
object PreValidateReWriter {
+
//~ Tools ------------------------------------------------------------------
+
+ private def notSupported(source: SqlNode): String = {
+    s"INSERT INTO <table> PARTITION [(COLUMN LIST)] statement only support " +
+ s"SELECT, VALUES, SET_QUERY AND ORDER BY clause for now, '$source' is not supported yet."
+ }
+
/**
* Append the static partitions and unspecified columns to the data source projection list.
* The columns are appended to the corresponding positions.
@@ -108,7 +111,6 @@ object PreValidateReWriter {
typeFactory: RelDataTypeFactory,
source: SqlCall,
partitions: SqlNodeList): SqlCall = {
- assert(source.getKind == SqlKind.SELECT || source.getKind == SqlKind.VALUES)
val calciteCatalogReader = validator.getCatalogReader.unwrap(classOf[CalciteCatalogReader])
val names = sqlInsert.getTargetTable.asInstanceOf[SqlIdentifier].names
val table = calciteCatalogReader.getTable(names)
@@ -185,11 +187,49 @@ object PreValidateReWriter {
}
}
- source match {
- case select: SqlSelect =>
- rewriteSelect(validator, select, targetRowType, assignedFields, targetPosition)
- case values: SqlCall if values.getKind == SqlKind.VALUES =>
- rewriteValues(values, targetRowType, assignedFields, targetPosition)
+ rewriteSqlCall(validator, source, targetRowType, assignedFields, targetPosition)
+ }
+
+ private def rewriteSqlCall(
+ validator: FlinkCalciteSqlValidator,
+ call: SqlCall,
+ targetRowType: RelDataType,
+ assignedFields: util.LinkedHashMap[Integer, SqlNode],
+ targetPosition: util.List[Int]): SqlCall = {
+
+ def rewrite(node: SqlNode): SqlCall = {
+ checkArgument(node.isInstanceOf[SqlCall], node)
+ rewriteSqlCall(
+ validator,
+ node.asInstanceOf[SqlCall],
+ targetRowType,
+ assignedFields,
+ targetPosition)
+ }
+
+ call.getKind match {
+ case SqlKind.SELECT =>
+ rewriteSelect(
+ validator, call.asInstanceOf[SqlSelect], targetRowType, assignedFields, targetPosition)
+ case SqlKind.VALUES =>
+ rewriteValues(call, targetRowType, assignedFields, targetPosition)
+ case kind if SqlKind.SET_QUERY.contains(kind) =>
+ call.getOperandList.zipWithIndex.foreach {
+ case (operand, index) => call.setOperand(index, rewrite(operand))
+ }
+ call
+ case SqlKind.ORDER_BY =>
+ val operands = call.getOperandList
+ new SqlOrderBy(
+ call.getParserPosition,
+ rewrite(operands.get(0)),
+ operands.get(1).asInstanceOf[SqlNodeList],
+ operands.get(2),
+ operands.get(3))
+      // Not supported:
+ // case SqlKind.WITH =>
+ // case SqlKind.EXPLICIT_TABLE =>
+ case _ => throw new ValidationException(notSupported(call))
}
}
diff --git a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/common/PartialInsertTest.xml b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/common/PartialInsertTest.xml
index b900cb65d118a..aa9aa501e26c0 100644
--- a/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/common/PartialInsertTest.xml
+++ b/flink-table/flink-table-planner-blink/src/test/resources/org/apache/flink/table/planner/plan/common/PartialInsertTest.xml
@@ -35,6 +35,51 @@ Sink(table=[default_catalog.default_database.sink], fields=[a, b, c, d, e, f, g])
+- GroupAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+- Exchange(distribution=[hash[a, b, c, d, e]])
+- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
@@ -62,48 +107,432 @@ Sink(table=[default_catalog.default_database.partitioned_sink], fields=[a, c, d, e, f, g])
]]>
(sum_vcol_marker, 0)])
+ +- GroupAggregate(groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, SUM_RETRACT(vcol_marker) AS sum_vcol_marker])
+ +- Exchange(distribution=[hash[a, c, d, e, f, g]])
+ +- Union(all=[true], union=[a, c, d, e, f, g, vcol_marker])
+ :- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, 1:BIGINT AS vcol_marker])
+ : +- GroupAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ : +- Exchange(distribution=[hash[a, b, c, d, e]])
+ : +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+ +- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, -1:BIGINT AS vcol_marker])
+ +- GroupAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ +- Exchange(distribution=[hash[a, b, c, d, e]])
+ +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
]]>
(sum_vcol_marker, 0)])
+ +- HashAggregate(isMerge=[true], groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, Final_SUM(sum$0) AS sum_vcol_marker])
+ +- Exchange(distribution=[hash[a, c, d, e, f, g]])
+ +- LocalHashAggregate(groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, Partial_SUM(vcol_marker) AS sum$0])
+ +- Union(all=[true], union=[a, c, d, e, f, g, vcol_marker])
+ :- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, 1:BIGINT AS vcol_marker])
+ : +- HashAggregate(isMerge=[true], groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ : +- Exchange(distribution=[hash[a, b, c, d, e]])
+ : +- LocalHashAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ : +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+ +- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, -1:BIGINT AS vcol_marker])
+ +- HashAggregate(isMerge=[true], groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ +- Exchange(distribution=[hash[a, b, c, d, e]])
+ +- LocalHashAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ (vcol_left_cnt, vcol_right_cnt), vcol_right_cnt, vcol_left_cnt) AS $f0, a, c, d, e, f, g], where=[AND(>=(vcol_left_cnt, 1), >=(vcol_right_cnt, 1))])
+ +- GroupAggregate(groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, COUNT_RETRACT(vcol_left_marker) AS vcol_left_cnt, COUNT_RETRACT(vcol_right_marker) AS vcol_right_cnt])
+ +- Exchange(distribution=[hash[a, c, d, e, f, g]])
+ +- Union(all=[true], union=[a, c, d, e, f, g, vcol_left_marker, vcol_right_marker])
+ :- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, true AS vcol_left_marker, null:BOOLEAN AS vcol_right_marker])
+ : +- GroupAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ : +- Exchange(distribution=[hash[a, b, c, d, e]])
+ : +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+ +- Calc(select=[a, c, d, e, CAST(456:BIGINT) AS f, CAST(789) AS g, null:BOOLEAN AS vcol_left_marker, true AS vcol_right_marker])
+ +- GroupAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ +- Exchange(distribution=[hash[a, b, c, d, e]])
+ +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
+ (vcol_left_cnt, vcol_right_cnt), vcol_right_cnt, vcol_left_cnt) AS $f0, a, c, d, e, f, g], where=[AND(>=(vcol_left_cnt, 1), >=(vcol_right_cnt, 1))])
+ +- HashAggregate(isMerge=[true], groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, Final_COUNT(count$0) AS vcol_left_cnt, Final_COUNT(count$1) AS vcol_right_cnt])
+ +- Exchange(distribution=[hash[a, c, d, e, f, g]])
+ +- LocalHashAggregate(groupBy=[a, c, d, e, f, g], select=[a, c, d, e, f, g, Partial_COUNT(vcol_left_marker) AS count$0, Partial_COUNT(vcol_right_marker) AS count$1])
+ +- Union(all=[true], union=[a, c, d, e, f, g, vcol_left_marker, vcol_right_marker])
+ :- Calc(select=[a, c, d, e, CAST(123:BIGINT) AS f, CAST(456) AS g, true AS vcol_left_marker, null:BOOLEAN AS vcol_right_marker])
+ : +- HashAggregate(isMerge=[true], groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ : +- Exchange(distribution=[hash[a, b, c, d, e]])
+ : +- LocalHashAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ : +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+ +- Calc(select=[a, c, d, e, CAST(456:BIGINT) AS f, CAST(789) AS g, null:BOOLEAN AS vcol_left_marker, true AS vcol_right_marker])
+ +- HashAggregate(isMerge=[true], groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ +- Exchange(distribution=[hash[a, b, c, d, e]])
+ +- LocalHashAggregate(groupBy=[a, b, c, d, e], select=[a, b, c, d, e])
+ +- LegacyTableSourceScan(table=[[default_catalog, default_database, MyTable, source: [TestTableSource(a, b, c, d, e)]]], fields=[a, b, c, d, e])
+]]>
diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/common/PartialInsertTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/common/PartialInsertTest.scala
index fb6eac13e2cd0..b16ce64919cc0 100644
--- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/common/PartialInsertTest.scala
+++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/common/PartialInsertTest.scala
@@ -73,6 +73,49 @@ class PartialInsertTest(isBatch: Boolean) extends TableTestBase {
util.verifyRelPlanInsert("INSERT INTO partitioned_sink (e,a,g,f,c,d) " +
"SELECT e,a,456,123,c,d FROM MyTable GROUP BY a,b,c,d,e")
}
+
+ @Test
+ def testPartialInsertWithUnion(): Unit = {
+ testPartialInsertWithSetOperator("UNION")
+ }
+
+ @Test
+ def testPartialInsertWithUnionAll(): Unit = {
+ testPartialInsertWithSetOperator("UNION ALL")
+ }
+
+ @Test
+ def testPartialInsertWithIntersectAll(): Unit = {
+ testPartialInsertWithSetOperator("INTERSECT ALL")
+ }
+
+ @Test
+ def testPartialInsertWithExceptAll(): Unit = {
+ testPartialInsertWithSetOperator("EXCEPT ALL")
+ }
+
+ private def testPartialInsertWithSetOperator(operator: String): Unit = {
+ util.verifyRelPlanInsert("INSERT INTO partitioned_sink (e,a,g,f,c,d) " +
+ "SELECT e,a,456,123,c,d FROM MyTable GROUP BY a,b,c,d,e " +
+ operator + " " +
+ "SELECT e,a,789,456,c,d FROM MyTable GROUP BY a,b,c,d,e ")
+ }
+
+ @Test
+ def testPartialInsertWithUnionAllNested(): Unit = {
+ util.verifyRelPlanInsert("INSERT INTO partitioned_sink (e,a,g,f,c,d) " +
+ "SELECT e,a,456,123,c,d FROM MyTable GROUP BY a,b,c,d,e " +
+ "UNION ALL " +
+ "SELECT e,a,789,456,c,d FROM MyTable GROUP BY a,b,c,d,e " +
+ "UNION ALL " +
+ "SELECT e,a,123,456,c,d FROM MyTable GROUP BY a,b,c,d,e ")
+ }
+
+ @Test
+ def testPartialInsertWithOrderBy(): Unit = {
+ util.verifyRelPlanInsert("INSERT INTO partitioned_sink (e,a,g,f,c,d) " +
+ "SELECT e,a,456,123,c,d FROM MyTable ORDER BY a,e,c,d")
+ }
}
object PartialInsertTest {
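
As a usage-level illustration (not part of this patch), the following sketch drives the newly accepted statement shapes through the public SQL API. The connectors, schemas, and paths are placeholders chosen only so the snippet is self-contained; in practice the sink must be a partition-aware connector such as filesystem or Hive.

    import org.apache.flink.table.api.{EnvironmentSettings, TableEnvironment}

    object PartitionedPartialInsertExample {
      def main(args: Array[String]): Unit = {
        // Batch mode keeps the ORDER BY example valid without time attributes.
        val tEnv = TableEnvironment.create(
          EnvironmentSettings.newInstance().inBatchMode().build())

        tEnv.executeSql(
          "CREATE TABLE src (a INT, b STRING) " +
            "WITH ('connector' = 'datagen', 'number-of-rows' = '10')")
        tEnv.executeSql(
          "CREATE TABLE dst (a INT, b STRING, p STRING) PARTITIONED BY (p) " +
            "WITH ('connector' = 'filesystem', 'path' = '/tmp/dst', 'format' = 'csv')")

        // Before this change, PreValidateReWriter rejected a set query as the source of
        // an INSERT that also has a static partition and a partial column list.
        tEnv.executeSql(
          "INSERT INTO dst PARTITION (p = '2021') (b) " +
            "SELECT b FROM src UNION ALL SELECT b FROM src").await()

        // An ORDER BY wrapped around the source is rewritten as well.
        tEnv.executeSql(
          "INSERT INTO dst PARTITION (p = '2021') (b) " +
            "SELECT b FROM src ORDER BY b").await()
      }
    }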