[SPARK-29248][SQL] Pass in number of partitions to WriteBuilder #25945

Closed. Wants to merge 4 commits; the diff below shows the changes from 1 commit.

@@ -55,6 +55,16 @@ default WriteBuilder withInputDataSchema(StructType schema) {
return this;
}

/**
* Passes the number of partitions of the input data from Spark to the data source.
*
* @return a new builder with the `numPartitions`. By default it returns `this`, which means the
* given `numPartitions` is ignored. Please override this method to take the `numPartitions`.
*/
default WriteBuilder withNumPartitions(int numPartitions) {

Contributor:

I'm OK with the approach here, but I just want to share a few thoughts about how to make the API better. The use case is: there is some additional information (input schema, numPartitions, etc.) that Spark should always provide, and an implementation only needs to write extra code if it needs to access that information.

With the current API, we can:

  1. add more additional information in future versions without breaking backward compatibility;
  2. let users override withNumPartitions and the other methods only if they need to access the additional information.

But there is one drawback: we have to take extra care to make sure the additional information is actually provided by Spark. It's better to guarantee this at compile time.

I think we can improve this API a little bit. For Table#newWriteBuilder, we can define it as

WriteBuilder newWriteBuilder(CaseInsensitiveStringMap options, WriteInfo info);

where WriteInfo is an interface providing the additional information:

interface WriteInfo {
  String queryId();
  StructType inputDataSchema();
  ...
}

WriteInfo is implemented by Spark and called by data source implementations, so we can add more methods in future versions without breaking backward compatibility. It also guarantees at compile time that Spark always provides the additional information.

If you guys think it makes sense, we can do it in a followup.
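
To make the trade-off concrete, here is a minimal Scala sketch of how the proposed shape could look from the connector side. It is only an illustration: numPartitions() on WriteInfo and every My*-prefixed name are assumptions, not part of this PR; only queryId() and inputDataSchema() come from the proposal above.

import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Sketch of the proposed carrier of write-time information. Spark would
// implement this interface, so every connector is guaranteed to receive all
// of it; this is the compile-time guarantee discussed above.
trait WriteInfo {
  def queryId(): String
  def inputDataSchema(): StructType
  def numPartitions(): Int // hypothetical addition, mirroring this PR
}

// Hypothetical connector-side builder: it reads everything it needs from
// WriteInfo up front instead of relying on optional with* calls.
class MyWriteBuilder(info: WriteInfo) {
  private val schema: StructType = info.inputDataSchema()
  private val numPartitions: Int = info.numPartitions()
  private val queryId: String = info.queryId()
}

// Matches the proposed Table#newWriteBuilder(options, info) signature.
class MyTable {
  def newWriteBuilder(options: CaseInsensitiveStringMap, info: WriteInfo): MyWriteBuilder =
    new MyWriteBuilder(info)
}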

Contributor (Author):

I agree with you that the WriteInfo approach has better compile-time guarantees. I actually started implementing the change that way, but then felt it was maybe too big a change and that I should focus on numPartitions.

I'm happy to change it in a followup PR, if that works for everyone.

Contributor:

If we want to take this approach, then let's do it now, before a release. Otherwise we should stick with the original implementation, which adds an additional method, because that is a compatible change.

Contributor:

@edrevo, can you implement this approach here? Adding numPartitions is really a small change; we can make improving this API the main focus of this PR.

return this;
}

/**
* Returns a {@link BatchWrite} to write data to batch source. By default this method throws
* exception, data sources must overwrite this method to provide an implementation, if the
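
For context, here is a hedged connector-side sketch of how an implementation could opt in to the new hook. The class name, the -1 sentinel, and the idea of sizing output by the incoming parallelism are illustrative only; the import path is the one used by Spark 3.0 and is assumed here.

import org.apache.spark.sql.connector.write.{BatchWrite, WriteBuilder} // package assumed (Spark 3.0 layout)
import org.apache.spark.sql.types.StructType

class CountingWriteBuilder extends WriteBuilder {
  private var schema: StructType = _
  private var numPartitions: Int = -1 // -1 means Spark did not provide it

  override def withInputDataSchema(schema: StructType): WriteBuilder = {
    this.schema = schema
    this
  }

  // Opt in to the new hook; builders that ignore it inherit the default no-op.
  override def withNumPartitions(numPartitions: Int): WriteBuilder = {
    this.numPartitions = numPartitions
    this
  }

  override def buildForBatch(): BatchWrite = {
    // A real connector would construct its BatchWrite here and could use
    // numPartitions to, say, pre-size output buckets. Left unimplemented in
    // this sketch.
    throw new UnsupportedOperationException(
      s"sketch only: would write $numPartitions partitions with schema $schema")
  }
}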

@@ -86,6 +86,7 @@ case class CreateTableAsSelectExec(
case table: SupportsWrite =>
val writeBuilder = table.newWriteBuilder(writeOptions)
.withInputDataSchema(schema)
.withNumPartitions(rdd.getNumPartitions)
.withQueryId(UUID.randomUUID().toString)

writeBuilder match {

@@ -181,6 +182,7 @@ case class ReplaceTableAsSelectExec(
case table: SupportsWrite =>
val writeBuilder = table.newWriteBuilder(writeOptions)
.withInputDataSchema(schema)
.withNumPartitions(rdd.getNumPartitions)
.withQueryId(UUID.randomUUID().toString)

writeBuilder match {

@@ -332,11 +334,13 @@ case class WriteToDataSourceV2Exec(
trait BatchWriteHelper {
def table: SupportsWrite
def query: SparkPlan
def rdd: RDD[InternalRow]
def writeOptions: CaseInsensitiveStringMap

def newWriteBuilder(): WriteBuilder = {
table.newWriteBuilder(writeOptions)
.withInputDataSchema(query.schema)
.withNumPartitions(rdd.getNumPartitions)
.withQueryId(UUID.randomUUID().toString)
}
}

@@ -347,34 +351,38 @@ trait BatchWriteHelper
trait V2TableWriteExec extends UnaryExecNode {
def query: SparkPlan

lazy val rdd: RDD[InternalRow] = {
val tempRdd = query.execute()
// SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
// partition rdd to make sure we at least set up one write task to write the metadata.
if (tempRdd.partitions.length == 0) {
sparkContext.parallelize(Array.empty[InternalRow], 1)
} else {
tempRdd
}
}

var commitProgress: Option[StreamWriterCommitProgress] = None

override def child: SparkPlan = query
override def output: Seq[Attribute] = Nil

protected def writeWithV2(batchWrite: BatchWrite): RDD[InternalRow] = {
val writerFactory = batchWrite.createBatchWriterFactory()
val useCommitCoordinator = batchWrite.useCommitCoordinator
val rdd = query.execute()
// SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single
// partition rdd to make sure we at least set up one write task to write the metadata.
val rddWithNonEmptyPartitions = if (rdd.partitions.length == 0) {
sparkContext.parallelize(Array.empty[InternalRow], 1)
} else {
rdd
}
val messages = new Array[WriterCommitMessage](rddWithNonEmptyPartitions.partitions.length)
val messages = new Array[WriterCommitMessage](rdd.partitions.length)
val totalNumRowsAccumulator = new LongAccumulator()

val writerFactory = batchWrite.createBatchWriterFactory()

logInfo(s"Start processing data source write support: $batchWrite. " +
s"The input RDD has ${messages.length} partitions.")

try {
sparkContext.runJob(
rddWithNonEmptyPartitions,
rdd,
(context: TaskContext, iter: Iterator[InternalRow]) =>
DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator),
rddWithNonEmptyPartitions.partitions.indices,
rdd.partitions.indices,
(index, result: DataWritingSparkTaskResult) => {
val commitMessage = result.writerCommitMessage
messages(index) = commitMessage

@@ -480,6 +488,7 @@ private[v2] trait AtomicTableWriteExec extends V2TableWriteExec with SupportsV1W
case table: SupportsWrite =>
val writeBuilder = table.newWriteBuilder(writeOptions)
.withInputDataSchema(query.schema)
.withNumPartitions(rdd.getNumPartitions)
.withQueryId(UUID.randomUUID().toString)

val writtenRows = writeBuilder match {
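
A closing note on the refactor above: hoisting the SPARK-23271 handling into the lazy val rdd means the same non-empty RDD backs both the getNumPartitions value passed to the WriteBuilder and the actual write job. A standalone sketch of that pattern, with illustrative names:

import scala.reflect.ClassTag
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object WriteRddUtil {
  // If the query produced an RDD with zero partitions, substitute a dummy
  // single-partition empty RDD so that at least one write task runs (and
  // writes the metadata), and so the partition count reported to the
  // WriteBuilder is always at least 1.
  def nonEmptyPartitions[T: ClassTag](sc: SparkContext, input: RDD[T]): RDD[T] =
    if (input.partitions.length == 0) sc.parallelize(Seq.empty[T], numSlices = 1)
    else input
}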