From f9c913263219f5e8a375542994142645dd0f6c6a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 6 Feb 2018 12:43:45 -0800 Subject: [PATCH] [SPARK-23315][SQL] failed to get output from canonicalized data source v2 related plans ## What changes were proposed in this pull request? `DataSourceV2Relation` keeps a `fullOutput` and resolves the real output on demand by column name lookup. i.e. ``` lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => fullOutput.find(_.name == name).get } ``` This will be broken after we canonicalize the plan, because all attribute names become "None", see https://github.com/apache/spark/blob/v2.3.0-rc1/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Canonicalize.scala#L42 To fix this, `DataSourceV2Relation` should just keep `output`, and update the `output` when doing column pruning. ## How was this patch tested? a new test case Author: Wenchen Fan Closes #20485 from cloud-fan/canonicalize. (cherry picked from commit b96a083b1c6ff0d2c588be9499b456e1adce97dc) Signed-off-by: gatorsmile --- .../v2/DataSourceReaderHolder.scala | 12 +++----- .../datasources/v2/DataSourceV2Relation.scala | 8 ++--- .../datasources/v2/DataSourceV2ScanExec.scala | 4 +-- .../v2/PushDownOperatorsToDataSource.scala | 29 +++++++++++++------ .../sql/sources/v2/DataSourceV2Suite.scala | 20 ++++++++++++- 5 files changed, 48 insertions(+), 25 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala index 6460c97abe344..81219e9771bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.Objects -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.sources.v2.reader._ /** @@ -28,9 +28,9 @@ import org.apache.spark.sql.sources.v2.reader._ trait DataSourceReaderHolder { /** - * The full output of the data source reader, without column pruning. + * The output of the data source reader, w.r.t. column pruning. */ - def fullOutput: Seq[AttributeReference] + def output: Seq[Attribute] /** * The held data source reader. @@ -46,7 +46,7 @@ trait DataSourceReaderHolder { case s: SupportsPushDownFilters => s.pushedFilters().toSet case _ => Nil } - Seq(fullOutput, reader.getClass, reader.readSchema(), filters) + Seq(output, reader.getClass, filters) } def canEqual(other: Any): Boolean @@ -61,8 +61,4 @@ trait DataSourceReaderHolder { override def hashCode(): Int = { metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) } - - lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => - fullOutput.find(_.name == name).get - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index eebfa29f91b99..38f6b15224788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( - fullOutput: Seq[AttributeReference], + output: Seq[AttributeReference], reader: DataSourceReader) extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder { @@ -37,7 +37,7 @@ case class DataSourceV2Relation( } override def newInstance(): DataSourceV2Relation = { - copy(fullOutput = fullOutput.map(_.newInstance())) + copy(output = output.map(_.newInstance())) } } @@ -46,8 +46,8 @@ case class DataSourceV2Relation( * to the non-streaming relation. */ class StreamingDataSourceV2Relation( - fullOutput: Seq[AttributeReference], - reader: DataSourceReader) extends DataSourceV2Relation(fullOutput, reader) { + output: Seq[AttributeReference], + reader: DataSourceReader) extends DataSourceV2Relation(output, reader) { override def isStreaming: Boolean = true } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index df469af2c262a..7d9581be4db89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -35,14 +35,12 @@ import org.apache.spark.sql.types.StructType * Physical plan node for scanning data from a data source. */ case class DataSourceV2ScanExec( - fullOutput: Seq[AttributeReference], + output: Seq[AttributeReference], @transient reader: DataSourceReader) extends LeafExecNode with DataSourceReaderHolder with ColumnarBatchScan { override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] - override def producedAttributes: AttributeSet = AttributeSet(fullOutput) - override def outputPartitioning: physical.Partitioning = reader match { case s: SupportsReportPartitioning => new DataSourcePartitioning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala index 566a48394f02e..1ca6cbf061b4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -81,33 +81,44 @@ object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHel // TODO: add more push down rules. - pushDownRequiredColumns(filterPushed, filterPushed.outputSet) + val columnPruned = pushDownRequiredColumns(filterPushed, filterPushed.outputSet) // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. - RemoveRedundantProject(filterPushed) + RemoveRedundantProject(columnPruned) } // TODO: nested fields pruning - private def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: AttributeSet): Unit = { + private def pushDownRequiredColumns( + plan: LogicalPlan, requiredByParent: AttributeSet): LogicalPlan = { plan match { - case Project(projectList, child) => + case p @ Project(projectList, child) => val required = projectList.flatMap(_.references) - pushDownRequiredColumns(child, AttributeSet(required)) + p.copy(child = pushDownRequiredColumns(child, AttributeSet(required))) - case Filter(condition, child) => + case f @ Filter(condition, child) => val required = requiredByParent ++ condition.references - pushDownRequiredColumns(child, required) + f.copy(child = pushDownRequiredColumns(child, required)) case relation: DataSourceV2Relation => relation.reader match { case reader: SupportsPushDownRequiredColumns => + // TODO: Enable the below assert after we make `DataSourceV2Relation` immutable. Fow now + // it's possible that the mutable reader being updated by someone else, and we need to + // always call `reader.pruneColumns` here to correct it. + // assert(relation.output.toStructType == reader.readSchema(), + // "Schema of data source reader does not match the relation plan.") + val requiredColumns = relation.output.filter(requiredByParent.contains) reader.pruneColumns(requiredColumns.toStructType) - case _ => + val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap + val newOutput = reader.readSchema().map(_.name).map(nameToAttr) + relation.copy(output = newOutput) + + case _ => relation } // TODO: there may be more operators that can be used to calculate the required columns. We // can add more and more in the future. - case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.outputSet)) + case _ => plan.mapChildren(c => pushDownRequiredColumns(c, c.outputSet)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index e0e034d734f27..6ad0e5f79bc40 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,7 +24,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanExec +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2ScanExec} import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ @@ -297,6 +297,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val reader4 = getReader(q4) assert(reader4.requiredSchema.fieldNames === Seq("i")) } + + test("SPARK-23315: get output from canonicalized data source v2 related plans") { + def checkCanonicalizedOutput(df: DataFrame, numOutput: Int): Unit = { + val logical = df.queryExecution.optimizedPlan.collect { + case d: DataSourceV2Relation => d + }.head + assert(logical.canonicalized.output.length == numOutput) + + val physical = df.queryExecution.executedPlan.collect { + case d: DataSourceV2ScanExec => d + }.head + assert(physical.canonicalized.output.length == numOutput) + } + + val df = spark.read.format(classOf[AdvancedDataSourceV2].getName).load() + checkCanonicalizedOutput(df, 2) + checkCanonicalizedOutput(df.select('i), 1) + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport {