Move InsertInto rules into Analyzer to fix error messages.
rdblue committed Jul 24, 2019
1 parent 8449d66 commit 97dc04c
Showing 2 changed files with 136 additions and 124 deletions.
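
The rule added below moves INSERT INTO resolution out of DataSourceResolution (the deletions in the second file) and into a ResolveInsertInto rule registered in the Analyzer's Resolution batch, so failures are reported with the Analyzer's error messages. A minimal sketch of the statements the rule handles and the v2 write plans they become, assuming a SparkSession named spark and a v2 table t with columns a, b and an identity partition column p (all names illustrative, not part of this commit):

// Illustrative only: plan shapes follow the ResolveInsertInto rule below.
spark.sql("INSERT INTO t SELECT a, b, p FROM src")
// -> AppendData.byPosition(relation, query)

spark.sql("INSERT OVERWRITE TABLE t SELECT a, b, p FROM src")
// -> OverwritePartitionsDynamic.byPosition when
//    spark.sql.sources.partitionOverwriteMode = DYNAMIC;
//    otherwise OverwriteByExpression.byPosition with a Literal(true) filter

spark.sql("INSERT OVERWRITE TABLE t PARTITION (p = 1) SELECT a, b FROM src")
// -> every partition column is statically bound, so OverwriteByExpression
//    with the filter p = CAST('1' AS <type of p>), in either mode

spark.sql("INSERT OVERWRITE TABLE t PARTITION (p = 1) IF NOT EXISTS SELECT a, b FROM src")
// -> AnalysisException: IF NOT EXISTS is not supported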
@@ -25,6 +25,8 @@ import scala.util.Random

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog, TableChange}
import org.apache.spark.sql.catalog.v2.expressions.{FieldReference, IdentityTransform}
import org.apache.spark.sql.catalog.v2.utils.CatalogV2Util.loadTable
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.encoders.OuterScopes
@@ -34,12 +36,14 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement}
import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, InsertIntoStatement}
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
import org.apache.spark.sql.sources.v2.Table
import org.apache.spark.sql.types._

/**
@@ -167,6 +171,7 @@ class Analyzer(
Batch("Resolution", fixedPoint,
ResolveTableValuedFunctions ::
ResolveAlterTable ::
ResolveInsertInto ::
ResolveTables ::
ResolveRelations ::
ResolveReferences ::
@@ -757,6 +762,136 @@
}
}

object ResolveInsertInto extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoStatement(
UnresolvedRelation(CatalogObjectIdentifier(Some(tableCatalog), ident)), _, _, _, _)
if i.query.resolved =>
loadTable(tableCatalog, ident)
.map(DataSourceV2Relation.create)
.map(relation => {
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw new AnalysisException(
s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}")
}

val partCols = partitionColumnNames(relation.table)
validatePartitionSpec(partCols, i.partitionSpec)

val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
val query = addStaticPartitionColumns(relation, i.query, staticPartitions)
val dynamicPartitionOverwrite = partCols.size > staticPartitions.size &&
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC

if (!i.overwrite) {
AppendData.byPosition(relation, query)
} else if (dynamicPartitionOverwrite) {
OverwritePartitionsDynamic.byPosition(relation, query)
} else {
OverwriteByExpression.byPosition(
relation, query, staticDeleteExpression(relation, staticPartitions))
}
})
.getOrElse(i)

case i @ InsertIntoStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _)
if i.query.resolved =>
InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists)
}

private def partitionColumnNames(table: Table): Seq[String] = {
// get partition column names. in v2, partition columns are columns that are stored using an
// identity partition transform because the partition values and the column values are
// identical. otherwise, partition values are produced by transforming one or more source
// columns and cannot be set directly in a query's PARTITION clause.
table.partitioning.flatMap {
case IdentityTransform(FieldReference(Seq(name))) => Some(name)
case _ => None
}
}

private def validatePartitionSpec(
partitionColumnNames: Seq[String],
partitionSpec: Map[String, Option[String]]): Unit = {
// check that each partition name is a partition column. otherwise, it is not valid
partitionSpec.keySet.foreach { partitionName =>
partitionColumnNames.find(name => conf.resolver(name, partitionName)) match {
case Some(_) =>
case None =>
throw new AnalysisException(
s"PARTITION clause cannot contain a non-partition column name: $partitionName")
}
}
}

private def addStaticPartitionColumns(
relation: DataSourceV2Relation,
query: LogicalPlan,
staticPartitions: Map[String, String]): LogicalPlan = {

if (staticPartitions.isEmpty) {
query

} else {
// add any static value as a literal column
val withStaticPartitionValues = {
// for each static name, find the column name it will replace and check for unknowns.
val outputNameToStaticName = staticPartitions.keySet.map(staticName =>
relation.output.find(col => conf.resolver(col.name, staticName)) match {
case Some(attr) =>
attr.name -> staticName
case _ =>
throw new AnalysisException(
s"Cannot add static value for unknown column: $staticName")
}).toMap

val queryColumns = query.output.iterator

// for each output column, add the static value as a literal, or use the next input
// column. this does not fail if input columns are exhausted and adds remaining columns
// at the end. both cases will be caught by ResolveOutputRelation and will fail the
// query with a helpful error message.
relation.output.flatMap { col =>
outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match {
case Some(staticValue) =>
Some(Alias(Cast(Literal(staticValue), col.dataType), col.name)())
case _ if queryColumns.hasNext =>
Some(queryColumns.next)
case _ =>
None
}
} ++ queryColumns
}

Project(withStaticPartitionValues, query)
}
}

private def staticDeleteExpression(
relation: DataSourceV2Relation,
staticPartitions: Map[String, String]): Expression = {
if (staticPartitions.isEmpty) {
Literal(true)
} else {
staticPartitions.map { case (name, value) =>
relation.output.find(col => conf.resolver(col.name, name)) match {
case Some(attr) =>
// the delete expression must reference the table's column names, but these attributes
// are not available when CheckAnalysis runs because the relation is not a child of
// the logical operation. instead, expressions are resolved after
// ResolveOutputRelation runs, using the query's column names that will match the
// table names at that point. because resolution happens after a future rule, create
// an UnresolvedAttribute.
EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType))
case None =>
throw new AnalysisException(s"Unknown static partition column: $name")
}
}.reduce(And)
}
}
}

/**
* Resolve ALTER TABLE statements that use a DSv2 catalog.
*
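The positional splice performed by addStaticPartitionColumns above is easiest to see in isolation. Only identity-transform partition columns may be set in a PARTITION clause (per partitionColumnNames above): the stored partition value equals the column value, unlike a derived transform such as days(ts). A self-contained plain-Scala sketch with illustrative names and string placeholders; the real rule emits Alias(Cast(Literal(value), col.dataType), col.name)() and matches names with conf.resolver:

object StaticSpliceSketch extends App {
  // Table columns in order; p is bound statically by the PARTITION clause.
  val tableCols = Seq("a", "b", "p")
  val queryCols = Seq("q1", "q2")          // the SELECT's output, consumed by position
  val static    = Map("p" -> "1")

  val remaining = queryCols.iterator
  val spliced = tableCols.flatMap { col =>
    static.get(col) match {
      case Some(v)                   => Some(s"CAST('$v' AS <type of $col>) AS $col") // static literal
      case None if remaining.hasNext => Some(remaining.next()) // next query column, in order
      case None                      => None // inputs exhausted; ResolveOutputRelation reports it
    }
  } ++ remaining                           // leftover query columns surface the same way

  println(spliced.mkString(", "))          // q1, q2, CAST('1' AS <type of p>) AS p
}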
@@ -166,39 +166,6 @@ case class DataSourceResolution(
case DataSourceV2Relation(CatalogTableAsV2(catalogTable), _, _) =>
UnresolvedCatalogRelation(catalogTable)

case i @ InsertIntoStatement(UnresolvedRelation(CatalogObjectIdentifier(Some(catalog), ident)),
_, _, _, _) if i.query.resolved =>
loadTable(catalog, ident)
.map(DataSourceV2Relation.create)
.map(relation => {
// ifPartitionNotExists is append with validation, but validation is not supported
if (i.ifPartitionNotExists) {
throw new AnalysisException(
s"Cannot write, IF NOT EXISTS is not supported for table: ${relation.table.name}")
}

val partCols = partitionColumnNames(relation.table)
validatePartitionSpec(partCols, i.partitionSpec)

val staticPartitions = i.partitionSpec.filter(_._2.isDefined).mapValues(_.get)
val query = addStaticPartitionColumns(relation, i.query, staticPartitions)
val dynamicPartitionOverwrite = partCols.size > staticPartitions.size &&
conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC

if (!i.overwrite) {
AppendData.byPosition(relation, query)
} else if (dynamicPartitionOverwrite) {
OverwritePartitionsDynamic.byPosition(relation, query)
} else {
OverwriteByExpression.byPosition(
relation, query, staticDeleteExpression(relation, staticPartitions))
}
})
.getOrElse(i)

case i @ InsertIntoStatement(UnresolvedRelation(AsTableIdentifier(_)), _, _, _, _)
if i.query.resolved =>
InsertIntoTable(i.table, i.partitionSpec, i.query, i.overwrite, i.ifPartitionNotExists)
}

object V1WriteProvider {
@@ -387,94 +354,4 @@ case class DataSourceResolution(
nullable = true,
builder.build())
}

private def partitionColumnNames(table: Table): Seq[String] = {
// get partition column names. in v2, partition columns are columns that are stored using an
// identity partition transform because the partition values and the column values are
// identical. otherwise, partition values are produced by transforming one or more source
// columns and cannot be set directly in a query's PARTITION clause.
table.partitioning.flatMap {
case IdentityTransform(FieldReference(Seq(name))) => Some(name)
case _ => None
}
}

private def validatePartitionSpec(
partitionColumnNames: Seq[String],
partitionSpec: Map[String, Option[String]]): Unit = {
// check that each partition name is a partition column. otherwise, it is not valid
partitionSpec.keySet.foreach { partitionName =>
partitionColumnNames.find(name => conf.resolver(name, partitionName)) match {
case Some(_) =>
case None =>
throw new AnalysisException(
s"PARTITION clause cannot contain a non-partition column name: $partitionName")
}
}
}

private def addStaticPartitionColumns(
relation: DataSourceV2Relation,
query: LogicalPlan,
staticPartitions: Map[String, String]): LogicalPlan = {

if (staticPartitions.isEmpty) {
query

} else {
// add any static value as a literal column
val withStaticPartitionValues = {
// for each static name, find the column name it will replace and check for unknowns.
val outputNameToStaticName = staticPartitions.keySet.map(staticName =>
relation.output.find(col => conf.resolver(col.name, staticName)) match {
case Some(attr) =>
attr.name -> staticName
case _ =>
throw new AnalysisException(
s"Cannot add static value for unknown column: $staticName")
}).toMap

val queryColumns = query.output.iterator

// for each output column, add the static value as a literal, or use the next input
// column. this does not fail if input columns are exhausted and adds remaining columns
// at the end. both cases will be caught by ResolveOutputRelation and will fail the
// query with a helpful error message.
relation.output.flatMap { col =>
outputNameToStaticName.get(col.name).flatMap(staticPartitions.get) match {
case Some(staticValue) =>
Some(Alias(Cast(Literal(staticValue), col.dataType), col.name)())
case _ if queryColumns.hasNext =>
Some(queryColumns.next)
case _ =>
None
}
} ++ queryColumns
}

Project(withStaticPartitionValues, query)
}
}

private def staticDeleteExpression(
relation: DataSourceV2Relation,
staticPartitions: Map[String, String]): Expression = {
if (staticPartitions.isEmpty) {
Literal(true)
} else {
staticPartitions.map { case (name, value) =>
relation.output.find(col => conf.resolver(col.name, name)) match {
case Some(attr) =>
// the delete expression must reference the table's column names, but these attributes
// are not available when CheckAnalysis runs because the relation is not a child of the
// logical operation. instead, expressions are resolved after ResolveOutputRelation
// runs, using the query's column names that will match the table names at that point.
// because resolution happens after a future rule, create an UnresolvedAttribute.
EqualTo(UnresolvedAttribute(attr.name), Cast(Literal(value), attr.dataType))
case None =>
throw new AnalysisException(s"Unknown static partition column: $name")
}
}.reduce(And)
}
}
}
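
The deletions above are the counterparts of the rules added to the Analyzer; the logic moved without change. As a reference for the overwrite filter that staticDeleteExpression builds, a string-level sketch with illustrative values (the real code folds EqualTo(UnresolvedAttribute(name), Cast(Literal(value), dataType)) with And):

val static = Map("p" -> "1", "region" -> "us")   // illustrative static PARTITION values
val deleteFilter =
  if (static.isEmpty) "true"                     // empty spec overwrites everything: Literal(true)
  else static.map { case (col, v) => s"$col = CAST('$v' AS <type of $col>)" }.mkString(" AND ")
// deleteFilter: p = CAST('1' AS <type of p>) AND region = CAST('us' AS <type of region>)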
