[SPARK-23203][SQL] DataSourceV2: Use immutable logical plans.
SPARK-23203: DataSourceV2 should use immutable catalyst trees instead of wrapping a mutable DataSourceV2Reader. This commit updates DataSourceV2Relation and consolidates much of the DataSourceV2 API requirements for the read path in it. Instead of wrapping a reader that changes, the relation lazily produces a reader from its configuration.
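
To illustrate the shape of that change, here is a minimal sketch of the pattern, assuming a source that implements `ReadSupport`; the `SimpleV2Relation` name and the trimmed-down field list are illustrative only, and the real relation added by this commit appears in the diff below:

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport}
import org.apache.spark.sql.sources.v2.reader.DataSourceReader
import org.apache.spark.sql.types.StructType

// Illustrative only: the relation is defined entirely by immutable configuration.
// The reader is derived lazily from that configuration instead of being stored as
// a mutable constructor argument, so copying or transforming the plan never
// carries reader state along with it.
case class SimpleV2Relation(
    source: DataSourceV2,
    options: Map[String, String]) {

  // Assumes the source implements ReadSupport; the real relation also handles
  // ReadSupportWithSchema and reports helpful errors for unreadable sources.
  lazy val reader: DataSourceReader =
    source.asInstanceOf[ReadSupport]
      .createReader(new DataSourceOptions(options.asJava))

  // Schema questions are answered by building a reader on demand.
  def schema: StructType = reader.readSchema()
}
```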

This commit also updates the predicate and projection push-down. Instead of the implementation from SPARK-22197, this reuses the rule matching from the Hive and DataSource read paths (using `PhysicalOperation`) and copies most of the implementation of `SparkPlanner.pruneFilterProject`, with updates for DataSourceV2. Reusing the implementation from the other read paths should mean fewer regressions and less code to maintain.
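
As a rough sketch of what `PhysicalOperation`-based matching looks like for the new relation (a simplified one-off rewrite, not the exact rule in this commit; it assumes the `projection`, `filters`, and `unsupportedFilters` members of the relation shown in the diff below):

```scala
import org.apache.spark.sql.catalyst.expressions.{And, AttributeSet}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation

// Illustrative only: push filters and a pruned projection into the relation by
// matching the project/filter operators above it, then re-plan whatever the
// source could not handle as ordinary Filter/Project nodes.
// (A real rule also needs a guard so it does not push the same filters again
// when it runs on later passes.)
object SketchPushDown {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
    case PhysicalOperation(project, filters, relation: DataSourceV2Relation) =>
      // the projection handed to the reader must also cover the filter columns
      val required =
        AttributeSet(project.flatMap(_.references) ++ filters.flatMap(_.references))
      val pruned = relation.output.filter(required.contains)

      val pushed = relation.copy(projection = pruned, filters = Some(filters))

      // anything the reader did not accept stays in the plan as a Filter
      val withFilter = pushed.unsupportedFilters
        .reduceOption(And)
        .map(Filter(_, pushed))
        .getOrElse(pushed)

      Project(project, withFilter)
  }
}
```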

The new push-down rules also support the following edge cases (a usage sketch follows the list):

* The output of DataSourceV2Relation should be what is returned by the reader, in case the reader can only partially satisfy the requested schema projection
* The requested projection passed to the DataSourceV2Reader should include filter columns
* The push-down rule may be run more than once if filters are not pushed through projections
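
A hedged usage sketch of those edge cases; the `com.example.v2source` format name and the column names are hypothetical, and `explain` is only there to show where the pushed filters and projection end up:

```scala
import org.apache.spark.sql.SparkSession

object PushDownEdgeCases {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("push-down-sketch")
      .getOrCreate()
    import spark.implicits._

    val df = spark.read
      .format("com.example.v2source")       // hypothetical DataSourceV2 implementation
      .load()
      .where($"event_time" > "2018-01-01")  // filter column that is not selected below
      .select($"id", $"value")

    // * The projection pushed to the reader includes `event_time`, because the
    //   filter needs it even though it is not in the final select.
    // * If the reader can only partially prune columns, the relation's output is
    //   whatever schema the reader reports, not the requested projection.
    // * If a filter is not pushed through the projection on the first pass, the
    //   push-down rule may match and run again on a later pass.
    df.explain(true)
  }
}
```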

Tested with the existing push-down and read tests.

Author: Ryan Blue <blue@apache.org>

Closes apache#20387 from rdblue/SPARK-22386-push-down-immutable-trees.

(cherry picked from commit aadf953)

Conflicts:
	external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSourceSuite.scala
rdblue authored and jzhuge committed Aug 27, 2019
1 parent 2c8f258 commit 7e3f6a4
Showing 10 changed files with 281 additions and 198 deletions.
@@ -17,20 +17,9 @@

package org.apache.spark.sql.kafka010

import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

import org.scalatest.time.SpanSugar._
import scala.collection.mutable
import scala.util.Random

import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Dataset, ForeachWriter, Row}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.{StreamTest, Trigger}
import org.apache.spark.sql.test.{SharedSQLContext, TestSparkSession}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.streaming.Trigger

// Run tests in KafkaSourceSuiteBase in continuous execution mode.
class KafkaContinuousSourceSuite extends KafkaSourceSuiteBase with KafkaContinuousTest
@@ -71,7 +60,7 @@ class KafkaContinuousSourceTopicDeletionSuite extends KafkaContinuousTest {
eventually(timeout(streamingTimeout)) {
assert(
query.lastExecution.logical.collectFirst {
case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r
}.exists { r =>
// Ensure the new topic is present and the old topic is gone.
r.knownPartitions.exists(_.topic == topic2)
@@ -21,7 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming.StreamExecution
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.streaming.Trigger
@@ -47,7 +47,7 @@ trait KafkaContinuousTest extends KafkaSourceTest {
eventually(timeout(streamingTimeout)) {
assert(
query.lastExecution.logical.collectFirst {
case DataSourceV2Relation(_, r: KafkaContinuousReader) => r
case StreamingDataSourceV2Relation(_, r: KafkaContinuousReader) => r
}.exists(_.knownPartitions.size == newCount),
s"query never reconfigured to $newCount partitions")
}
@@ -34,7 +34,7 @@ import org.scalatest.time.SpanSugar._

import org.apache.spark.SparkContext
import org.apache.spark.sql.{Dataset, ForeachWriter}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
import org.apache.spark.sql.functions.{count, window}
@@ -117,9 +117,10 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext {
} ++ (query.get.lastExecution match {
case null => Seq()
case e => e.logical.collect {
case DataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
case StreamingDataSourceV2Relation(_, reader: KafkaContinuousReader) => reader
}
})

if (sources.isEmpty) {
throw new Exception(
"Could not find Kafka source in the StreamExecution logical plan to add data to")
41 changes: 9 additions & 32 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.datasources.jdbc._
import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils
import org.apache.spark.sql.sources.v2._
import org.apache.spark.sql.sources.v2.{DataSourceV2, ReadSupport, ReadSupportWithSchema}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String

@@ -189,39 +189,16 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {

val cls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf)
if (classOf[DataSourceV2].isAssignableFrom(cls)) {
val ds = cls.newInstance()
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
ds = ds.asInstanceOf[DataSourceV2],
conf = sparkSession.sessionState.conf)
val options = new DataSourceOptions((sessionOptions ++ extraOptions).asJava)

// Streaming also uses the data source V2 API. So it may be that the data source implements
// v2, but has no v2 implementation for batch reads. In that case, we fall back to loading
// the dataframe as a v1 source.
val reader = (ds, userSpecifiedSchema) match {
case (ds: ReadSupportWithSchema, Some(schema)) =>
ds.createReader(schema, options)

case (ds: ReadSupport, None) =>
ds.createReader(options)

case (ds: ReadSupportWithSchema, None) =>
throw new AnalysisException(s"A schema needs to be specified when using $ds.")

case (ds: ReadSupport, Some(schema)) =>
val reader = ds.createReader(options)
if (reader.readSchema() != schema) {
throw new AnalysisException(s"$ds does not allow user-specified schemas.")
}
reader

case _ => null // fall back to v1
}
val ds = cls.newInstance().asInstanceOf[DataSourceV2]
if (ds.isInstanceOf[ReadSupport] || ds.isInstanceOf[ReadSupportWithSchema]) {
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
ds = ds, conf = sparkSession.sessionState.conf)
Dataset.ofRows(sparkSession, DataSourceV2Relation.create(
ds, sessionOptions ++ extraOptions.toMap,
userSpecifiedSchema = userSpecifiedSchema))

if (reader == null) {
loadV1Source(paths: _*)
} else {
Dataset.ofRows(sparkSession, DataSourceV2Relation(reader))
loadV1Source(paths: _*)
}
} else {
loadV1Source(paths: _*)
@@ -17,17 +17,80 @@

package org.apache.spark.sql.execution.datasources.v2

import scala.collection.JavaConverters._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.sources.v2.reader._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources.{DataSourceRegister, Filter}
import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, ReadSupport, ReadSupportWithSchema}
import org.apache.spark.sql.sources.v2.reader.{DataSourceReader, SupportsPushDownCatalystFilters, SupportsPushDownFilters, SupportsPushDownRequiredColumns, SupportsReportStatistics}
import org.apache.spark.sql.types.StructType

case class DataSourceV2Relation(
output: Seq[AttributeReference],
reader: DataSourceReader)
extends LeafNode with MultiInstanceRelation with DataSourceReaderHolder {
source: DataSourceV2,
options: Map[String, String],
projection: Seq[AttributeReference],
filters: Option[Seq[Expression]] = None,
userSpecifiedSchema: Option[StructType] = None) extends LeafNode with MultiInstanceRelation {

import DataSourceV2Relation._

override def simpleString: String = {
s"DataSourceV2Relation(source=${source.name}, " +
s"schema=[${output.map(a => s"$a ${a.dataType.simpleString}").mkString(", ")}], " +
s"filters=[${pushedFilters.mkString(", ")}], options=$options)"
}

override lazy val schema: StructType = reader.readSchema()

override lazy val output: Seq[AttributeReference] = {
// use the projection attributes to avoid assigning new ids. fields that are not projected
// will be assigned new ids, which is okay because they are not projected.
val attrMap = projection.map(a => a.name -> a).toMap
schema.map(f => attrMap.getOrElse(f.name,
AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()))
}

private lazy val v2Options: DataSourceOptions = makeV2Options(options)

lazy val (
reader: DataSourceReader,
unsupportedFilters: Seq[Expression],
pushedFilters: Seq[Expression]) = {
val newReader = userSpecifiedSchema match {
case Some(s) =>
source.asReadSupportWithSchema.createReader(s, v2Options)
case _ =>
source.asReadSupport.createReader(v2Options)
}

DataSourceV2Relation.pushRequiredColumns(newReader, projection.toStructType)

override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation]
val (remainingFilters, pushedFilters) = filters match {
case Some(filterSeq) =>
DataSourceV2Relation.pushFilters(newReader, filterSeq)
case _ =>
(Nil, Nil)
}

(newReader, remainingFilters, pushedFilters)
}

override def doCanonicalize(): LogicalPlan = {
val c = super.doCanonicalize().asInstanceOf[DataSourceV2Relation]

// override output with canonicalized output to avoid attempting to configure a reader
val canonicalOutput: Seq[AttributeReference] = this.output
.map(a => QueryPlan.normalizeExprId(a, projection))

new DataSourceV2Relation(c.source, c.options, c.projection) {
override lazy val output: Seq[AttributeReference] = canonicalOutput
}
}

override def computeStats(): Statistics = reader match {
case r: SupportsReportStatistics =>
@@ -37,22 +37,100 @@ case class DataSourceV2Relation(
}

override def newInstance(): DataSourceV2Relation = {
copy(output = output.map(_.newInstance()))
// projection is used to maintain id assignment.
// if projection is not set, use output so the copy is not equal to the original
copy(projection = projection.map(_.newInstance()))
}
}

/**
* A specialization of DataSourceV2Relation with the streaming bit set to true. Otherwise identical
* to the non-streaming relation.
*/
class StreamingDataSourceV2Relation(
case class StreamingDataSourceV2Relation(
output: Seq[AttributeReference],
reader: DataSourceReader) extends DataSourceV2Relation(output, reader) {
reader: DataSourceReader)
extends LeafNode with DataSourceReaderHolder with MultiInstanceRelation {
override def isStreaming: Boolean = true

override def canEqual(other: Any): Boolean = other.isInstanceOf[StreamingDataSourceV2Relation]

override def newInstance(): LogicalPlan = copy(output = output.map(_.newInstance()))

override def computeStats(): Statistics = reader match {
case r: SupportsReportStatistics =>
Statistics(sizeInBytes = r.getStatistics.sizeInBytes().orElse(conf.defaultSizeInBytes))
case _ =>
Statistics(sizeInBytes = conf.defaultSizeInBytes)
}
}

object DataSourceV2Relation {
def apply(reader: DataSourceReader): DataSourceV2Relation = {
new DataSourceV2Relation(reader.readSchema().toAttributes, reader)
private implicit class SourceHelpers(source: DataSourceV2) {
def asReadSupport: ReadSupport = {
source match {
case support: ReadSupport =>
support
case _: ReadSupportWithSchema =>
// this method is only called if there is no user-supplied schema. if there is no
// user-supplied schema and ReadSupport was not implemented, throw a helpful exception.
throw new AnalysisException(s"Data source requires a user-supplied schema: $name")
case _ =>
throw new AnalysisException(s"Data source is not readable: $name")
}
}

def asReadSupportWithSchema: ReadSupportWithSchema = {
source match {
case support: ReadSupportWithSchema =>
support
case _: ReadSupport =>
throw new AnalysisException(
s"Data source does not support user-supplied schema: $name")
case _ =>
throw new AnalysisException(s"Data source is not readable: $name")
}
}

def name: String = {
source match {
case registered: DataSourceRegister =>
registered.shortName()
case _ =>
source.getClass.getSimpleName
}
}
}

private def makeV2Options(options: Map[String, String]): DataSourceOptions = {
new DataSourceOptions(options.asJava)
}

private def schema(
source: DataSourceV2,
v2Options: DataSourceOptions,
userSchema: Option[StructType]): StructType = {
val reader = userSchema match {
// TODO: remove this case because it is confusing for users
case Some(s) if !source.isInstanceOf[ReadSupportWithSchema] =>
val reader = source.asReadSupport.createReader(v2Options)
if (reader.readSchema() != s) {
throw new AnalysisException(s"${source.name} does not allow user-specified schemas.")
}
reader
case Some(s) =>
source.asReadSupportWithSchema.createReader(s, v2Options)
case _ =>
source.asReadSupport.createReader(v2Options)
}
reader.readSchema()
}

def create(
source: DataSourceV2,
options: Map[String, String],
filters: Option[Seq[Expression]] = None,
userSpecifiedSchema: Option[StructType] = None): DataSourceV2Relation = {
val projection = schema(source, makeV2Options(options), userSpecifiedSchema).toAttributes
DataSourceV2Relation(source, options, projection, filters,
// if the source does not implement ReadSupportWithSchema, then the userSpecifiedSchema must
// be equal to the reader's schema. the schema method enforces this. because the user schema
// and the reader's schema are identical, drop the user schema.
if (source.isInstanceOf[ReadSupportWithSchema]) userSpecifiedSchema else None)
}

private def pushRequiredColumns(reader: DataSourceReader, struct: StructType): Unit = {
reader match {
case projectionSupport: SupportsPushDownRequiredColumns =>
projectionSupport.pruneColumns(struct)
case _ =>
}
}

private def pushFilters(
reader: DataSourceReader,
filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
reader match {
case catalystFilterSupport: SupportsPushDownCatalystFilters =>
(
catalystFilterSupport.pushCatalystFilters(filters.toArray),
catalystFilterSupport.pushedCatalystFilters()
)

case filterSupport: SupportsPushDownFilters =>
// A map from original Catalyst expressions to corresponding translated data source
// filters. If a predicate is not in this map, it means it cannot be pushed down.
val translatedMap: Map[Expression, Filter] = filters.flatMap { p =>
DataSourceStrategy.translateFilter(p).map(f => p -> f)
}.toMap

// Catalyst predicate expressions that cannot be converted to data source filters.
val nonConvertiblePredicates = filters.filterNot(translatedMap.contains)

// Data source filters that cannot be pushed down. An unhandled filter means
// the data source cannot guarantee the rows returned can pass the filter.
// As a result we must return it so Spark can plan an extra filter operator.
val unhandledFilters = filterSupport.pushFilters(translatedMap.values.toArray).toSet
val (unhandledPredicates, pushedPredicates) = translatedMap.partition { case (_, f) =>
unhandledFilters.contains(f)
}

(nonConvertiblePredicates ++ unhandledPredicates.keys, pushedPredicates.keys.toSeq)

case _ => (filters, Nil)
}
}
}
@@ -23,8 +23,11 @@ import org.apache.spark.sql.execution.SparkPlan

object DataSourceV2Strategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case DataSourceV2Relation(output, reader) =>
DataSourceV2ScanExec(output, reader) :: Nil
case relation: DataSourceV2Relation =>
DataSourceV2ScanExec(relation.output, relation.reader) :: Nil

case relation: StreamingDataSourceV2Relation =>
DataSourceV2ScanExec(relation.output, relation.reader) :: Nil

case WriteToDataSourceV2(writer, query) =>
WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil
