diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e55cdfedd3234..021fb26bf751f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -25,6 +25,8 @@ import scala.util.Random
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog, TableChange}
+import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform}
+import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util.loadTable
 import org.apache.spark.sql.catalyst._
 import org.apache.spark.sql.catalyst.catalog._
 import org.apache.spark.sql.catalyst.encoders.OuterScopes
@@ -34,12 +36,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement}
+import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, InsertIntoStatement}
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.catalyst.trees.TreeNodeRef
 import org.apache.spark.sql.catalyst.util.toPrettySQL
 import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
+import org.apache.spark.sql.sources.v2.Table
 import org.apache.spark.sql.types._
 
 /**
@@ -167,6 +171,7 @@ class Analyzer(
     Batch("Resolution", fixedPoint,
       ResolveTableValuedFunctions ::
       ResolveAlterTable ::
+      ResolveInsertInto ::
       ResolveTables ::
       ResolveRelations ::
       ResolveReferences ::
@@ -757,6 +762,136 @@ class Analyzer(
     }
   }
 
+  object ResolveInsertInto extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case i @ InsertIntoStatement(
+          UnresolvedRelation(CatalogObjectIdentifier(Some(tableCatalog), ident)), _, _, _, _)
+          if i.query.resolved =>
+        loadTable(tableCatalog, ident)
+            .map(DataSourceV2Relation.create)
+            .map(relation => {
+              // ifPartitionNotExists is append with validation, but validation is not supported
+              if (i.ifPartitionNotExists) {
+                throw new AnalysisException(
+                  s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}")
+              }
+
+              val partCols = partitionColumnNames(relation.table)
+              validatePartitionSpec(partCols, i.partitionSpec)
+
+              val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
+              val query = addStaticPartitionColumns(relation, i.query, staticPartitions)
+              val dynamicPartitionOverwrite = partCols.size > staticPartitions.size &&
+                  conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
+
+              if (!i.overwrite) {
+                AppendData.byPosition(relation, query)
+              } else if (dynamicPartitionOverwrite) {
+                OverwritePartitionsDynamic.byPosition(relation, query)
+              } else {
+                OverwriteByExpression.byPosition(
+                  relation, query, staticDeleteExpression(relation, staticPartitions))
+              }
+            })
+            .getOrElse(i)
+
+      case i @ InsertIntoStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _)
+          if i.query.resolved =>
+        InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists)
+    }
+
+    private def partitionColumnNames(table: Table): Seq[String] = {
+      // get partition column names. in v2, partition columns are columns that are stored using an
+      // identity partition transform because the partition values and the column values are
+      // identical. otherwise, partition values are produced by transforming one or more source
+      // columns and cannot be set directly in a query's PARTITION clause.
+      table.partitioning.flatMap {
+        case IdentityTransform(FieldReference(Seq(name))) => Some(name)
+        case _ => None
+      }
+    }
+
+    private def validatePartitionSpec(
+        partitionColumnNames: Seq[String],
+        partitionSpec: Map[String, Option[String]]): Unit = {
+      // check that each partition name is a partition column. otherwise, it is not valid
+      partitionSpec.keySet.foreach { partitionName =>
+        partitionColumnNames.find(name => conf.resolver(name, partitionName)) match {
+          case Some(_) =>
+          case None =>
+            throw new AnalysisException(
+              s"PARTITION clause cannot contain a non-partition column name: $partitionName")
+        }
+      }
+    }
+
+    private def addStaticPartitionColumns(
+        relation: DataSourceV2Relation,
+        query: LogicalPlan,
+        staticPartitions: Map[String, String]): LogicalPlan = {
+
+      if (staticPartitions.isEmpty) {
+        query
+
+      } else {
+        // add any static value as a literal column
+        val withStaticPartitionValues = {
+          // for each static name, find the column name it will replace and check for unknowns.
+          val outputNameToStaticName = staticPartitions.keySet.map(staticName =>
+            relation.output.find(col => conf.resolver(col.name, staticName)) match {
+              case Some(attr) =>
+                attr.name -> staticName
+              case _ =>
+                throw new AnalysisException(
+                  s"Cannot add static value for unknown column: $staticName")
+            }).toMap
+
+          val queryColumns = query.output.iterator
+
+          // for each output column, add the static value as a literal, or use the next input
+          // column. this does not fail if input columns are exhausted and adds remaining columns
+          // at the end. both cases will be caught by ResolveOutputRelation and will fail the
+          // query with a helpful error message.
+          relation.output.flatMap { col =>
+            outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match {
+              case Some(staticValue) =>
+                Some(Alias(Cast(Literal(staticValue), col.dataType), col.name)())
+              case _ if queryColumns.hasNext =>
+                Some(queryColumns.next)
+              case _ =>
+                None
+            }
+          } ++ queryColumns
+        }
+
+        Project(withStaticPartitionValues, query)
+      }
+    }
+
+    private def staticDeleteExpression(
+        relation: DataSourceV2Relation,
+        staticPartitions: Map[String, String]): Expression = {
+      if (staticPartitions.isEmpty) {
+        Literal(true)
+      } else {
+        staticPartitions.map { case (name, value) =>
+          relation.output.find(col => conf.resolver(col.name, name)) match {
+            case Some(attr) =>
+              // the delete expression must reference the table's column names, but these attributes
+              // are not available when CheckAnalysis runs because the relation is not a child of
+              // the logical operation. instead, expressions are resolved after
+              // ResolveOutputRelation runs, using the query's column names that will match the
+              // table names at that point. because resolution happens after a future rule, create
+              // an UnresolvedAttribute.
+              EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType))
+            case None =>
+              throw new AnalysisException(s"Unknown static partition column: $name")
+          }
+        }.reduce(And)
+      }
+    }
+  }
+
   /**
    * Resolve ALTER TABLE statements that use a DSv2 catalog.
    *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
index 7293d5af26646..7c99d24dbdfd3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala
@@ -166,39 +166,6 @@ case class DataSourceResolution(
     case DataSourceV2Relation(CatalogTableAsV2(catalogTable), _, _) =>
       UnresolvedCatalogRelation(catalogTable)
 
-    case i @ InsertIntoStatement(UnresolvedRelation(CatalogObjectIdentifier(Some(catalog), ident)),
-        _, _, _, _) if i.query.resolved =>
-      loadTable(catalog, ident)
-          .map(DataSourceV2Relation.create)
-          .map(relation => {
-            // ifPartitionNotExists is append with validation, but validation is not supported
-            if (i.ifPartitionNotExists) {
-              throw new AnalysisException(
-                s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}")
-            }
-
-            val partCols = partitionColumnNames(relation.table)
-            validatePartitionSpec(partCols, i.partitionSpec)
-
-            val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
-            val query = addStaticPartitionColumns(relation, i.query, staticPartitions)
-            val dynamicPartitionOverwrite = partCols.size > staticPartitions.size &&
-                conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC
-
-            if (!i.overwrite) {
-              AppendData.byPosition(relation, query)
-            } else if (dynamicPartitionOverwrite) {
-              OverwritePartitionsDynamic.byPosition(relation, query)
-            } else {
-              OverwriteByExpression.byPosition(
-                relation, query, staticDeleteExpression(relation, staticPartitions))
-            }
-          })
-          .getOrElse(i)
-
-    case i @ InsertIntoStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _)
-        if i.query.resolved =>
-      InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists)
   }
 
   object V1WriteProvider {
@@ -387,94 +354,4 @@ case class DataSourceResolution(
         nullable = true,
         builder.build())
   }
-
-  private def partitionColumnNames(table: Table): Seq[String] = {
-    // get partition column names. in v2, partition columns are columns that are stored using an
-    // identity partition transform because the partition values and the column values are
-    // identical. otherwise, partition values are produced by transforming one or more source
-    // columns and cannot be set directly in a query's PARTITION clause.
-    table.partitioning.flatMap {
-      case IdentityTransform(FieldReference(Seq(name))) => Some(name)
-      case _ => None
-    }
-  }
-
-  private def validatePartitionSpec(
-      partitionColumnNames: Seq[String],
-      partitionSpec: Map[String, Option[String]]): Unit = {
-    // check that each partition name is a partition column. otherwise, it is not valid
-    partitionSpec.keySet.foreach { partitionName =>
-      partitionColumnNames.find(name => conf.resolver(name, partitionName)) match {
-        case Some(_) =>
-        case None =>
-          throw new AnalysisException(
-            s"PARTITION clause cannot contain a non-partition column name: $partitionName")
-      }
-    }
-  }
-
-  private def addStaticPartitionColumns(
-      relation: DataSourceV2Relation,
-      query: LogicalPlan,
-      staticPartitions: Map[String, String]): LogicalPlan = {
-
-    if (staticPartitions.isEmpty) {
-      query
-
-    } else {
-      // add any static value as a literal column
-      val withStaticPartitionValues = {
-        // for each static name, find the column name it will replace and check for unknowns.
-        val outputNameToStaticName = staticPartitions.keySet.map(staticName =>
-          relation.output.find(col => conf.resolver(col.name, staticName)) match {
-            case Some(attr) =>
-              attr.name -> staticName
-            case _ =>
-              throw new AnalysisException(
-                s"Cannot add static value for unknown column: $staticName")
-          }).toMap
-
-        val queryColumns = query.output.iterator
-
-        // for each output column, add the static value as a literal, or use the next input
-        // column. this does not fail if input columns are exhausted and adds remaining columns
-        // at the end. both cases will be caught by ResolveOutputRelation and will fail the
-        // query with a helpful error message.
-        relation.output.flatMap { col =>
-          outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match {
-            case Some(staticValue) =>
-              Some(Alias(Cast(Literal(staticValue), col.dataType), col.name)())
-            case _ if queryColumns.hasNext =>
-              Some(queryColumns.next)
-            case _ =>
-              None
-          }
-        } ++ queryColumns
-      }
-
-      Project(withStaticPartitionValues, query)
-    }
-  }
-
-  private def staticDeleteExpression(
-      relation: DataSourceV2Relation,
-      staticPartitions: Map[String, String]): Expression = {
-    if (staticPartitions.isEmpty) {
-      Literal(true)
-    } else {
-      staticPartitions.map { case (name, value) =>
-        relation.output.find(col => conf.resolver(col.name, name)) match {
-          case Some(attr) =>
-            // the delete expression must reference the table's column names, but these attributes
-            // are not available when CheckAnalysis runs because the relation is not a child of the
-            // logical operation. instead, expressions are resolved after ResolveOutputRelation
-            // runs, using the query's column names that will match the table names at that point.
-            // because resolution happens after a future rule, create an UnresolvedAttribute.
-            EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType))
-          case None =>
-            throw new AnalysisException(s"Unknown static partition column: $name")
-        }
-      }.reduce(And)
-    }
-  }
 }
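
Reviewer note: the branch structure of the new ResolveInsertInto rule maps each INSERT variant onto one of the three v2 write plans. A minimal sketch of how each branch is reached, assuming a hypothetical v2 catalog testcat with a table testcat.t (id bigint, data string) partitioned by identity(p) and a source view src — these names are illustrative only and not part of this patch:

// `spark` is an active SparkSession with `testcat` registered as a v2 catalog
// (hypothetical setup, not provided by this patch).

// No OVERWRITE: resolves to AppendData.byPosition(relation, query).
spark.sql("INSERT INTO testcat.t SELECT id, data, p FROM src")

// OVERWRITE with every partition column static: resolves to
// OverwriteByExpression.byPosition(relation, query, <p = 'a'>).
spark.sql("INSERT OVERWRITE TABLE testcat.t PARTITION (p = 'a') SELECT id, data FROM src")

// OVERWRITE with a partition column left dynamic while
// spark.sql.sources.partitionOverwriteMode is DYNAMIC: resolves to
// OverwritePartitionsDynamic.byPosition(relation, query).
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
spark.sql("INSERT OVERWRITE TABLE testcat.t SELECT id, data, p FROM src")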
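The static-partition handling is the subtle part: addStaticPartitionColumns splices the PARTITION values into the query as literal columns, and staticDeleteExpression turns the same values into the overwrite condition. A sketch of the expressions produced for the PARTITION (p = 'a') example above, built only from catalyst classes this patch already references; the second partition column q is likewise hypothetical:

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Cast, EqualTo, Literal}
import org.apache.spark.sql.types.StringType

// addStaticPartitionColumns appends the static value to the query output as a
// literal, cast to the table column's type and aliased to the column's name.
val staticColumn = Alias(Cast(Literal("a"), StringType), "p")()

// staticDeleteExpression builds the OverwriteByExpression condition from the
// same values. It deliberately uses UnresolvedAttribute rather than the
// relation's attributes, so the name resolves against the query output after
// ResolveOutputRelation has run.
val deleteExpr = EqualTo(UnresolvedAttribute("p"), Cast(Literal("a"), StringType))

// With several static partitions, the per-column equalities are reduced with And.
val combined = And(deleteExpr, EqualTo(UnresolvedAttribute("q"), Cast(Literal("b"), StringType)))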