diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 0189bd73c56bf..4d9a9925fe3ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -153,7 +153,7 @@ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution {
* 1. number of partitions.
* 2. if it can satisfy a given distribution.
*/
-sealed trait Partitioning {
+trait Partitioning {
/** Returns the number of partitions that the data is split across */
val numPartitions: Int
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java
new file mode 100644
index 0000000000000..7346500de45b6
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * A concrete implementation of {@link Distribution}. Represents a distribution where records that
+ * share the same values of the {@link #clusteredColumns} will be produced by the same
+ * {@link ReadTask}.
+ */
+@InterfaceStability.Evolving
+public class ClusteredDistribution implements Distribution {
+
+ /**
+ * The names of the clustered columns. Note that the column ordering is not significant.
+ */
+ public final String[] clusteredColumns;
+
+ public ClusteredDistribution(String[] clusteredColumns) {
+ this.clusteredColumns = clusteredColumns;
+ }
+}
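
Because `clusteredColumns` is order insensitive, a `satisfy` implementation should compare the columns as a set rather than positionally. A minimal hypothetical sketch (not part of this patch; it assumes only the classes added in this PR on the classpath):

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

import org.apache.spark.sql.sources.v2.reader.ClusteredDistribution;

// Hypothetical demo, not part of this patch: `clusteredColumns` carries plain
// column names, and their ordering is not significant.
public class ClusteredDistributionDemo {
  public static void main(String[] args) {
    ClusteredDistribution dist = new ClusteredDistribution(new String[]{"b", "a"});
    // Compare as a set: {"b", "a"} and {"a", "b"} describe the same clustering.
    Set<String> cols = new HashSet<>(Arrays.asList(dist.clusteredColumns));
    System.out.println(cols.contains("a")); // true
  }
}
```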
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java
new file mode 100644
index 0000000000000..a6201a222f541
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An interface to represent a data distribution requirement, which specifies how the records
+ * should be distributed among the {@link ReadTask}s that are returned by
+ * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with
+ * the data ordering inside one partition (the output records of a single {@link ReadTask}).
+ *
+ * Instances of this interface are created and provided by Spark, then consumed by
+ * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to
+ * implement this interface, but do need to handle as many concrete implementations of it as
+ * possible in {@link Partitioning#satisfy(Distribution)}.
+ *
+ * Concrete implementations so far:
+ *
+ * - {@link ClusteredDistribution}
+ *
+ */
+@InterfaceStability.Evolving
+public interface Distribution {}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java
new file mode 100644
index 0000000000000..199e45d4a02ab
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An interface to represent the output data partitioning of a data source, which is returned by
+ * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a
+ * snapshot: once created, it should be deterministic and always report the same number of
+ * partitions and the same "satisfy" result for a given distribution.
+ */
+@InterfaceStability.Evolving
+public interface Partitioning {
+
+ /**
+ * Returns the number of partitions (i.e., {@link ReadTask}s) the data source outputs.
+ */
+ int numPartitions();
+
+ /**
+ * Returns true if this partitioning can satisfy the given distribution, which means Spark does
+ * not need to shuffle the output data of this data source for certain operations.
+ *
+ * Note that Spark may add new concrete implementations of {@link Distribution} in future
+ * releases. This method should always return false for unrecognized distributions, so that it
+ * stays correct when new ones appear. It's recommended to check each new Spark release and
+ * support the new distributions if possible, so that Spark can avoid shuffles in more cases.
+ */
+ boolean satisfy(Distribution distribution);
+}
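
To make the contract concrete, here is a minimal sketch (not part of this patch; the class and its fields are hypothetical) of a `Partitioning` for a source whose rows are clustered by one column:

```java
import java.util.Arrays;

import org.apache.spark.sql.sources.v2.reader.ClusteredDistribution;
import org.apache.spark.sql.sources.v2.reader.Distribution;
import org.apache.spark.sql.sources.v2.reader.Partitioning;

// Hypothetical sketch: all rows sharing a value of `column` live in one ReadTask.
public class ByColumnPartitioning implements Partitioning {
  private final String column;
  private final int numPartitions;

  public ByColumnPartitioning(String column, int numPartitions) {
    this.column = column;
    this.numPartitions = numPartitions;
  }

  @Override
  public int numPartitions() {
    return numPartitions;
  }

  @Override
  public boolean satisfy(Distribution distribution) {
    // Any required clustering that includes `column` is already satisfied,
    // since rows with equal values of `column` never span two ReadTasks.
    if (distribution instanceof ClusteredDistribution) {
      String[] cols = ((ClusteredDistribution) distribution).clusteredColumns;
      return Arrays.asList(cols).contains(column);
    }
    // Unrecognized distributions (possibly added in future Spark releases)
    // must conservatively be rejected.
    return false;
  }
}
```

Returning false by default keeps the source correct when Spark introduces new `Distribution` implementations; the only cost is an extra shuffle.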
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
new file mode 100644
index 0000000000000..f786472ccf345
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * A mix-in interface for {@link DataSourceV2Reader}. Data source readers can implement this
+ * interface to report data partitioning, so that Spark can avoid shuffling the data when possible.
+ */
+@InterfaceStability.Evolving
+public interface SupportsReportPartitioning {
+
+ /**
+ * Returns the output data partitioning that this reader guarantees.
+ */
+ Partitioning outputPartitioning();
+}
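
The wiring then looks like the following minimal sketch (hypothetical; `ByColumnPartitioning` is the sketch above, and `createReadTasks()` is elided). A complete, runnable version is `JavaPartitionAwareDataSource` further down in this patch:

```java
import java.util.Collections;
import java.util.List;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.sources.v2.reader.DataSourceV2Reader;
import org.apache.spark.sql.sources.v2.reader.Partitioning;
import org.apache.spark.sql.sources.v2.reader.ReadTask;
import org.apache.spark.sql.sources.v2.reader.SupportsReportPartitioning;
import org.apache.spark.sql.types.StructType;

// Hypothetical sketch of a reader that reports its partitioning to Spark.
public class ClusteredReader implements DataSourceV2Reader, SupportsReportPartitioning {
  @Override
  public StructType readSchema() {
    return new StructType().add("a", "int").add("b", "int");
  }

  @Override
  public List<ReadTask<Row>> createReadTasks() {
    return Collections.emptyList(); // elided: one task per partition of the data
  }

  @Override
  public Partitioning outputPartitioning() {
    // Promise: rows sharing a value of "a" are produced by the same ReadTask.
    return new ByColumnPartitioning("a", 2);
  }
}
```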
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala
new file mode 100644
index 0000000000000..943d0100aca56
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
+import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning}
+
+/**
+ * An adapter from a data source's public `Partitioning` to the Catalyst internal `Partitioning`.
+ */
+class DataSourcePartitioning(
+ partitioning: Partitioning,
+ colNames: AttributeMap[String]) extends physical.Partitioning {
+
+ override val numPartitions: Int = partitioning.numPartitions()
+
+ override def satisfies(required: physical.Distribution): Boolean = {
+ super.satisfies(required) || {
+ required match {
+ case d: physical.ClusteredDistribution if isCandidate(d.clustering) =>
+ val attrs = d.clustering.map(_.asInstanceOf[Attribute])
+ partitioning.satisfy(
+ new ClusteredDistribution(attrs.map { a =>
+ val name = colNames.get(a)
+ assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output")
+ name.get
+ }.toArray))
+
+ case _ => false
+ }
+ }
+ }
+
+ private def isCandidate(clustering: Seq[Expression]): Boolean = {
+ clustering.forall {
+ case a: Attribute => colNames.contains(a)
+ case _ => false
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index beb66738732be..69d871df3e1dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical
import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
import org.apache.spark.sql.execution.streaming.continuous._
import org.apache.spark.sql.sources.v2.reader._
@@ -42,6 +43,14 @@ case class DataSourceV2ScanExec(
override def producedAttributes: AttributeSet = AttributeSet(fullOutput)
+ override def outputPartitioning: physical.Partitioning = reader match {
+ case s: SupportsReportPartitioning =>
+ new DataSourcePartitioning(
+ s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))
+
+ case _ => super.outputPartitioning
+ }
+
private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match {
case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks()
case _ =>
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
new file mode 100644
index 0000000000000..806d0bcd93f18
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.ReadSupport;
+import org.apache.spark.sql.sources.v2.reader.*;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport {
+
+ class Reader implements DataSourceV2Reader, SupportsReportPartitioning {
+ private final StructType schema = new StructType().add("a", "int").add("b", "int");
+
+ @Override
+ public StructType readSchema() {
+ return schema;
+ }
+
+ @Override
+ public List<ReadTask<Row>> createReadTasks() {
+ return java.util.Arrays.asList(
+ new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
+ new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
+ }
+
+ @Override
+ public Partitioning outputPartitioning() {
+ return new MyPartitioning();
+ }
+ }
+
+ static class MyPartitioning implements Partitioning {
+
+ @Override
+ public int numPartitions() {
+ return 2;
+ }
+
+ @Override
+ public boolean satisfy(Distribution distribution) {
+ if (distribution instanceof ClusteredDistribution) {
+ String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns;
+ return Arrays.asList(clusteredCols).contains("a");
+ }
+
+ return false;
+ }
+ }
+
+ static class SpecificReadTask implements ReadTask<Row>, DataReader<Row> {
+ private int[] i;
+ private int[] j;
+ private int current = -1;
+
+ SpecificReadTask(int[] i, int[] j) {
+ assert i.length == j.length;
+ this.i = i;
+ this.j = j;
+ }
+
+ @Override
+ public boolean next() throws IOException {
+ current += 1;
+ return current < i.length;
+ }
+
+ @Override
+ public Row get() {
+ return new GenericRow(new Object[] {i[current], j[current]});
+ }
+
+ @Override
+ public void close() throws IOException {
+
+ }
+
+ @Override
+ public DataReader<Row> createDataReader() {
+ return this;
+ }
+ }
+
+ @Override
+ public DataSourceV2Reader createReader(DataSourceV2Options options) {
+ return new Reader();
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index 0ca29524c6d05..0620693b35d16 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -24,6 +24,7 @@ import test.org.apache.spark.sql.sources.v2._
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
import org.apache.spark.sql.sources.{Filter, GreaterThan}
import org.apache.spark.sql.sources.v2.reader._
@@ -95,6 +96,40 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
}
}
+ test("partitioning reporting") {
+ import org.apache.spark.sql.functions.{count, sum}
+ Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls =>
+ withClue(cls.getName) {
+ val df = spark.read.format(cls.getName).load()
+ checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2)))
+
+ val groupByColA = df.groupBy('a).agg(sum('b))
+ checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4)))
+ assert(groupByColA.queryExecution.executedPlan.collectFirst {
+ case e: ShuffleExchangeExec => e
+ }.isEmpty)
+
+ val groupByColAB = df.groupBy('a, 'b).agg(count("*"))
+ checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2)))
+ assert(groupByColAB.queryExecution.executedPlan.collectFirst {
+ case e: ShuffleExchangeExec => e
+ }.isEmpty)
+
+ val groupByColB = df.groupBy('b).agg(sum('a))
+ checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
+ assert(groupByColB.queryExecution.executedPlan.collectFirst {
+ case e: ShuffleExchangeExec => e
+ }.isDefined)
+
+ val groupByAPlusB = df.groupBy('a + 'b).agg(count("*"))
+ checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
+ assert(groupByAPlusB.queryExecution.executedPlan.collectFirst {
+ case e: ShuffleExchangeExec => e
+ }.isDefined)
+ }
+ }
+ }
+
test("simple writable data source") {
// TODO: java implementation.
Seq(classOf[SimpleWritableDataSource]).foreach { cls =>
@@ -365,3 +400,47 @@ class BatchReadTask(start: Int, end: Int)
override def close(): Unit = batch.close()
}
+
+class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {
+
+ class Reader extends DataSourceV2Reader with SupportsReportPartitioning {
+ override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int")
+
+ override def createReadTasks(): JList[ReadTask[Row]] = {
+ // Note that no value of column `a` appears in more than one partition.
+ java.util.Arrays.asList(
+ new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)),
+ new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2)))
+ }
+
+ override def outputPartitioning(): Partitioning = new MyPartitioning
+ }
+
+ class MyPartitioning extends Partitioning {
+ override def numPartitions(): Int = 2
+
+ override def satisfy(distribution: Distribution): Boolean = distribution match {
+ case c: ClusteredDistribution => c.clusteredColumns.contains("a")
+ case _ => false
+ }
+ }
+
+ override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
+}
+
+class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] {
+ assert(i.length == j.length)
+
+ private var current = -1
+
+ override def createDataReader(): DataReader[Row] = this
+
+ override def next(): Boolean = {
+ current += 1
+ current < i.length
+ }
+
+ override def get(): Row = Row(i(current), j(current))
+
+ override def close(): Unit = {}
+}