From 7241556d8b550e22eed2341287812ea373dc1cb2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 22 Jan 2018 15:21:09 -0800 Subject: [PATCH] [SPARK-22389][SQL] data source v2 partitioning reporting interface ## What changes were proposed in this pull request? a new interface which allows data source to report partitioning and avoid shuffle at Spark side. The design is pretty like the internal distribution/partitioing framework. Spark defines a `Distribution` interfaces and several concrete implementations, and ask the data source to report a `Partitioning`, the `Partitioning` should tell Spark if it can satisfy a `Distribution` or not. ## How was this patch tested? new test Author: Wenchen Fan Closes #20201 from cloud-fan/partition-reporting. (cherry picked from commit 51eb750263dd710434ddb60311571fa3dcec66eb) Signed-off-by: gatorsmile --- .../plans/physical/partitioning.scala | 2 +- .../v2/reader/ClusteredDistribution.java | 38 ++++++ .../sql/sources/v2/reader/Distribution.java | 39 +++++++ .../sql/sources/v2/reader/Partitioning.java | 46 ++++++++ .../v2/reader/SupportsReportPartitioning.java | 33 ++++++ .../v2/DataSourcePartitioning.scala | 56 +++++++++ .../datasources/v2/DataSourceV2ScanExec.scala | 9 ++ .../v2/JavaPartitionAwareDataSource.java | 110 ++++++++++++++++++ .../sql/sources/v2/DataSourceV2Suite.scala | 79 +++++++++++++ 9 files changed, 411 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 0189bd73c56bf..4d9a9925fe3ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -153,7 +153,7 @@ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution { * 1. number of partitions. * 2. if it can satisfy a given distribution. */ -sealed trait Partitioning { +trait Partitioning { /** Returns the number of partitions that the data is split across */ val numPartitions: Int diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java new file mode 100644 index 0000000000000..7346500de45b6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A concrete implementation of {@link Distribution}. Represents a distribution where records that + * share the same values for the {@link #clusteredColumns} will be produced by the same + * {@link ReadTask}. + */ +@InterfaceStability.Evolving +public class ClusteredDistribution implements Distribution { + + /** + * The names of the clustered columns. Note that they are order insensitive. + */ + public final String[] clusteredColumns; + + public ClusteredDistribution(String[] clusteredColumns) { + this.clusteredColumns = clusteredColumns; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java new file mode 100644 index 0000000000000..a6201a222f541 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface to represent data distribution requirement, which specifies how the records should + * be distributed among the {@link ReadTask}s that are returned by + * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with + * the data ordering inside one partition(the output records of a single {@link ReadTask}). + * + * The instance of this interface is created and provided by Spark, then consumed by + * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to + * implement this interface, but need to catch as more concrete implementations of this interface + * as possible in {@link Partitioning#satisfy(Distribution)}. + * + * Concrete implementations until now: + *
    + *
  • {@link ClusteredDistribution}
  • + *
+ */ +@InterfaceStability.Evolving +public interface Distribution {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java new file mode 100644 index 0000000000000..199e45d4a02ab --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * An interface to represent the output data partitioning for a data source, which is returned by + * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a + * snapshot. Once created, it should be deterministic and always report the same number of + * partitions and the same "satisfy" result for a certain distribution. + */ +@InterfaceStability.Evolving +public interface Partitioning { + + /** + * Returns the number of partitions(i.e., {@link ReadTask}s) the data source outputs. + */ + int numPartitions(); + + /** + * Returns true if this partitioning can satisfy the given distribution, which means Spark does + * not need to shuffle the output data of this data source for some certain operations. + * + * Note that, Spark may add new concrete implementations of {@link Distribution} in new releases. + * This method should be aware of it and always return false for unrecognized distributions. It's + * recommended to check every Spark new release and support new distributions if possible, to + * avoid shuffle at Spark side for more cases. + */ + boolean satisfy(Distribution distribution); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java new file mode 100644 index 0000000000000..f786472ccf345 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.reader; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this + * interface to report data partitioning and try to avoid shuffle at Spark side. + */ +@InterfaceStability.Evolving +public interface SupportsReportPartitioning { + + /** + * Returns the output data partitioning that this reader guarantees. + */ + Partitioning outputPartitioning(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala new file mode 100644 index 0000000000000..943d0100aca56 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression} +import org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning} + +/** + * An adapter from public data source partitioning to catalyst internal `Partitioning`. + */ +class DataSourcePartitioning( + partitioning: Partitioning, + colNames: AttributeMap[String]) extends physical.Partitioning { + + override val numPartitions: Int = partitioning.numPartitions() + + override def satisfies(required: physical.Distribution): Boolean = { + super.satisfies(required) || { + required match { + case d: physical.ClusteredDistribution if isCandidate(d.clustering) => + val attrs = d.clustering.map(_.asInstanceOf[Attribute]) + partitioning.satisfy( + new ClusteredDistribution(attrs.map { a => + val name = colNames.get(a) + assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output") + name.get + }.toArray)) + + case _ => false + } + } + } + + private def isCandidate(clustering: Seq[Expression]): Boolean = { + clustering.forall { + case a: Attribute => colNames.contains(a) + case _ => false + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index beb66738732be..69d871df3e1dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader._ @@ -42,6 +43,14 @@ case class DataSourceV2ScanExec( override def producedAttributes: AttributeSet = AttributeSet(fullOutput) + override def outputPartitioning: physical.Partitioning = reader match { + case s: SupportsReportPartitioning => + new DataSourcePartitioning( + s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name))) + + case _ => super.outputPartitioning + } + private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match { case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks() case _ => diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java new file mode 100644 index 0000000000000..806d0bcd93f18 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package test.org.apache.spark.sql.sources.v2; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +import org.apache.spark.sql.sources.v2.DataSourceV2; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.ReadSupport; +import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.types.StructType; + +public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport { + + class Reader implements DataSourceV2Reader, SupportsReportPartitioning { + private final StructType schema = new StructType().add("a", "int").add("b", "int"); + + @Override + public StructType readSchema() { + return schema; + } + + @Override + public List> createReadTasks() { + return java.util.Arrays.asList( + new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}), + new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2})); + } + + @Override + public Partitioning outputPartitioning() { + return new MyPartitioning(); + } + } + + static class MyPartitioning implements Partitioning { + + @Override + public int numPartitions() { + return 2; + } + + @Override + public boolean satisfy(Distribution distribution) { + if (distribution instanceof ClusteredDistribution) { + String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns; + return Arrays.asList(clusteredCols).contains("a"); + } + + return false; + } + } + + static class SpecificReadTask implements ReadTask, DataReader { + private int[] i; + private int[] j; + private int current = -1; + + SpecificReadTask(int[] i, int[] j) { + assert i.length == j.length; + this.i = i; + this.j = j; + } + + @Override + public boolean next() throws IOException { + current += 1; + return current < i.length; + } + + @Override + public Row get() { + return new GenericRow(new Object[] {i[current], j[current]}); + } + + @Override + public void close() throws IOException { + + } + + @Override + public DataReader createDataReader() { + return this; + } + } + + @Override + public DataSourceV2Reader createReader(DataSourceV2Options options) { + return new Reader(); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 0ca29524c6d05..0620693b35d16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -24,6 +24,7 @@ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.sources.{Filter, GreaterThan} import org.apache.spark.sql.sources.v2.reader._ @@ -95,6 +96,40 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test("partitioning reporting") { + import org.apache.spark.sql.functions.{count, sum} + Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2))) + + val groupByColA = df.groupBy('a).agg(sum('b)) + checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4))) + assert(groupByColA.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColAB = df.groupBy('a, 'b).agg(count("*")) + checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2))) + assert(groupByColAB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isEmpty) + + val groupByColB = df.groupBy('b).agg(sum('a)) + checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5))) + assert(groupByColB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isDefined) + + val groupByAPlusB = df.groupBy('a + 'b).agg(count("*")) + checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1))) + assert(groupByAPlusB.queryExecution.executedPlan.collectFirst { + case e: ShuffleExchangeExec => e + }.isDefined) + } + } + } + test("simple writable data source") { // TODO: java implementation. Seq(classOf[SimpleWritableDataSource]).foreach { cls => @@ -365,3 +400,47 @@ class BatchReadTask(start: Int, end: Int) override def close(): Unit = batch.close() } + +class PartitionAwareDataSource extends DataSourceV2 with ReadSupport { + + class Reader extends DataSourceV2Reader with SupportsReportPartitioning { + override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int") + + override def createReadTasks(): JList[ReadTask[Row]] = { + // Note that we don't have same value of column `a` across partitions. + java.util.Arrays.asList( + new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)), + new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2))) + } + + override def outputPartitioning(): Partitioning = new MyPartitioning + } + + class MyPartitioning extends Partitioning { + override def numPartitions(): Int = 2 + + override def satisfy(distribution: Distribution): Boolean = distribution match { + case c: ClusteredDistribution => c.clusteredColumns.contains("a") + case _ => false + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader +} + +class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] { + assert(i.length == j.length) + + private var current = -1 + + override def createDataReader(): DataReader[Row] = this + + override def next(): Boolean = { + current += 1 + current < i.length + } + + override def get(): Row = Row(i(current), j(current)) + + override def close(): Unit = {} +}