auto calculate the initial partition number with ae (apache#61)
JkSelf authored and luzhonghao committed Dec 11, 2018
1 parent b74f4f5 commit 84394e8
Showing 6 changed files with 116 additions and 18 deletions.
@@ -304,6 +304,14 @@ object SQLConf {
"must be a positive integer.")
.createWithDefault(500)

val ADAPTIVE_EXECUTION_AUTO_CALCULATE_INITIAL_PARTITION_NUM =
buildConf("spark.sql.adaptive.autoCalculateInitialPartitionNum")
.doc("When true and adaptive execution is enabled," +
" spark will calculate the initial partition number" +
" based on the statistics of the needed column.")
.booleanConf
.createWithDefault(false)

val SUBEXPRESSION_ELIMINATION_ENABLED =
buildConf("spark.sql.subexpressionElimination.enabled")
.internal()
@@ -1382,6 +1390,9 @@ class SQLConf extends Serializable with Logging {

def maxNumPostShufflePartitions: Int = getConf(SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS)

def adaptiveAutoCalculateInitialPartitionNum: Boolean =
getConf(ADAPTIVE_EXECUTION_AUTO_CALCULATE_INITIAL_PARTITION_NUM)

def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)

def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED)
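For reference, a minimal sketch of how the new flag would be exercised in a session. The settings are assumptions for illustration: spark.sql.adaptive.enabled and spark.sql.adaptive.shuffle.targetPostShuffleInputSize are the existing adaptive-execution keys, and spark.sql.adaptive.autoCalculateInitialPartitionNum is the key added above.

// Sketch: enabling the auto-calculated initial partition number (assumed settings).
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("ae-auto-initial-partitions")
  .config("spark.sql.adaptive.enabled", "true")
  .config("spark.sql.adaptive.autoCalculateInitialPartitionNum", "true")
  .config("spark.sql.adaptive.shuffle.targetPostShuffleInputSize", "67108864") // 64 MiB per post-shuffle partition
  .getOrCreate()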
@@ -74,7 +74,13 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
}

override def computeStats(): Statistics = {
Statistics(sizeInBytes = relation.sizeInBytes)
// There should be some overhead in a Row object, so the size should not be zero when there
// are no columns; this helps prevent a divide-by-zero error.
val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
val dataSchema = sqlContext.sparkSession.sessionState.catalog.getTableMetadata(
tableIdentifier.get).dataSchema
val totalRowSize = dataSchema.map(_.dataType.defaultSize).sum + 8
Statistics(sizeInBytes = (relation.sizeInBytes * outputRowSize) / totalRowSize)
}
}
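As a rough illustration of the estimate above (assumed schema and sizes, not taken from the patch): with a full data schema of IntegerType, LongType and StringType, and only the first two columns projected, the scan's estimated size shrinks in proportion to the projected row width.

// Illustration with assumed values; defaultSize is 4 for IntegerType, 8 for LongType, 20 for StringType.
val relationSizeInBytes = BigInt(1L * 1024 * 1024 * 1024)                       // 1 GiB on disk
val outputRowSize = 4 + 8 + 8                                                   // projected columns plus 8-byte row overhead = 20
val totalRowSize = 4 + 8 + 20 + 8                                               // full data schema plus 8-byte row overhead = 40
val estimatedSizeInBytes = (relationSizeInBytes * outputRowSize) / totalRowSize // 512 MiB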

@@ -36,14 +36,31 @@ import org.apache.spark.sql.internal.SQLConf
* the input partition ordering requirements are met.
*/
case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
private def defaultNumPreShufflePartitions: Int =
private def defaultNumPreShufflePartitions(plan: SparkPlan): Int =
if (conf.adaptiveExecutionEnabled) {
conf.maxNumPostShufflePartitions
if (conf.adaptiveAutoCalculateInitialPartitionNum) {
autoCalculateInitialPartitionNum(plan)
} else {
conf.maxNumPostShufflePartitions
}
} else {
conf.numShufflePartitions
}

private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
private def autoCalculateInitialPartitionNum(plan: SparkPlan): Int = {
val totalInputFileSize = plan.collectLeaves().map(_.statsPlan.sizeInBytes).sum
val autoInitialPartitionsNum = Math.ceil(
totalInputFileSize.toLong * 1.0 / conf.targetPostShuffleInputSize).toInt
if (autoInitialPartitionsNum < conf.minNumPostShufflePartitions) {
conf.minNumPostShufflePartitions
} else if (autoInitialPartitionsNum > conf.maxNumPostShufflePartitions) {
conf.maxNumPostShufflePartitions
} else {
autoInitialPartitionsNum
}
}

private def ensureDistributionAndOrdering(operator: SparkPlan, rootNode: SparkPlan): SparkPlan = {
val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
var children: Seq[SparkPlan] = operator.children
@@ -58,7 +75,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
BroadcastExchangeExec(mode, child)
case (child, distribution) =>
val numPartitions = distribution.requiredNumPartitions
.getOrElse(defaultNumPreShufflePartitions)
.getOrElse(defaultNumPreShufflePartitions(rootNode))
ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
}

@@ -196,14 +213,19 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
}
}

def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
// TODO: remove this after we create a physical operator for `RepartitionByExpression`.
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child) =>
child.outputPartitioning match {
case lower: HashPartitioning if upper.semanticEquals(lower) => child
case _ => operator
}
case operator: SparkPlan =>
ensureDistributionAndOrdering(reorderJoinPredicates(operator))
def apply(plan: SparkPlan): SparkPlan = {
// Record the root node in order to collect all the leaf nodes of the plan
// when calculating the initial partition number.
val rootNode = plan
plan.transformUp {
// TODO: remove this after we create a physical operator for `RepartitionByExpression`.
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child) =>
child.outputPartitioning match {
case lower: HashPartitioning if upper.semanticEquals(lower) => child
case _ => operator
}
case operator: SparkPlan =>
ensureDistributionAndOrdering(reorderJoinPredicates(operator), rootNode)
}
}
}
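To sanity-check the arithmetic in autoCalculateInitialPartitionNum with assumed numbers (not taken from the patch): 10 GiB of leaf-node input and a 64 MiB post-shuffle target give 160 initial partitions, then clamped to the configured minimum and maximum bounds.

// Sketch with assumed inputs.
val totalInputFileSize = 10L * 1024 * 1024 * 1024                                // 10 GiB summed over all leaf nodes
val targetPostShuffleInputSize = 64L * 1024 * 1024                               // 64 MiB target per post-shuffle partition
val raw = math.ceil(totalInputFileSize * 1.0 / targetPostShuffleInputSize).toInt // 160
val minNumPostShufflePartitions = 1                                              // assumed lower bound
val maxNumPostShufflePartitions = 500                                            // assumed upper bound
val initialPartitionNum =
  math.min(math.max(raw, minNumPostShufflePartitions), maxNumPostShufflePartitions) // 160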
@@ -336,6 +336,46 @@ class PlannerSuite extends SharedSQLContext {
}
}

test("EnsureRequirements with the initial partition number that" +
" is based on the statistics of leaf node") {
val distribution = ClusteredDistribution(Literal(1) :: Nil)
val childPartitioning = HashPartitioning(Literal(2) :: Nil, 1)

val inputPlan = DummySparkPlan(
children = Seq(
DummySparkPlan(outputPartitioning = childPartitioning),
DummySparkPlan(outputPartitioning = childPartitioning)
),
requiredChildDistribution = Seq(distribution, distribution),
requiredChildOrdering = Seq(Seq.empty, Seq.empty)
)
withSQLConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "1") {

val totalInputFileSize = inputPlan.collectLeaves().map(_.stats.sizeInBytes).sum
val expectedNum = Math.ceil(
totalInputFileSize.toLong * 1.0 / conf.targetPostShuffleInputSize).toInt

withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.ADAPTIVE_EXECUTION_AUTO_CALCULATE_INITIAL_PARTITION_NUM.key -> "true") {
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
outputPlan.collect {
case plan: ShuffleExchangeExec =>
val realNum = plan.outputPartitioning.numPartitions
assert(realNum == expectedNum)
}
}

withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
outputPlan.collect {
case plan: ShuffleExchangeExec =>
val realNum = plan.outputPartitioning.numPartitions
assert(realNum != expectedNum)
}
}
}
}

test("EnsureRequirements with compatible child partitionings that satisfy distribution") {
// In this case, all requirements are satisfied and no exchange should be added.
val distribution = ClusteredDistribution(Literal(1) :: Nil)
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal

import com.google.common.util.concurrent.Striped
import org.apache.hadoop.fs.Path

import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
@@ -154,7 +154,14 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
Some(partitionSchema))

val logicalRelation = cached.getOrElse {
val sizeInBytes = relation.stats.sizeInBytes.toLong
// For a partitioned table, relation.stats.sizeInBytes is Long.MaxValue,
// not the real size in HDFS.
val sizeInBytes = if (relation.isPartitioned && relation.stats.sizeInBytes.toLong == Long.MaxValue) {
sparkSession.sharedState.externalCatalog.listPartitions(tableIdentifier.database,
tableIdentifier.name).map(_.parameters.get(StatsSetupConst.TOTAL_SIZE).get.toLong).sum
} else {
relation.stats.sizeInBytes.toLong
}
val fileIndex = {
val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes)
if (lazyPruningEnabled) {
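For context, a small sketch of what the fallback above relies on: each Hive partition carries a parameters map whose totalSize entry (StatsSetupConst.TOTAL_SIZE) records its on-disk size, and the table-level size is the sum over the listed partitions. The values below are assumptions for illustration, not taken from the patch.

// Sketch with assumed partition parameters.
val partitionParameters: Seq[Map[String, String]] = Seq(
  Map("totalSize" -> "1048576", "numRows" -> "1000"),  // first partition
  Map("totalSize" -> "2097152", "numRows" -> "2000"))  // second partition
val tableSizeInBytes = partitionParameters.map(_("totalSize").toLong).sum // 3145728 bytes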
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution
import scala.collection.JavaConverters._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hive.common.StatsSetupConst
import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition}
import org.apache.hadoop.hive.ql.plan.TableDesc
import org.apache.hadoop.hive.serde.serdeConstants
@@ -215,7 +216,18 @@ case class HiveTableScanExec(
override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)

override def computeStats(): Statistics = {
val stats = relation.computeStats()
Statistics(stats.sizeInBytes)
// There should be some overhead in a Row object, so the size should not be zero when there
// are no columns; this helps prevent a divide-by-zero error.
val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
val totalRowSize = relation.tableMeta.dataSchema.map(_.dataType.defaultSize).sum + 8
// For a partitioned table, only use the statistics of the selected partitions.
val sizeInBytes = if (relation.isPartitioned && rawPartitions.nonEmpty) {
BigInt(rawPartitions.map(_.getParameters.get(StatsSetupConst.TOTAL_SIZE).toLong).sum)
} else {
relation.computeStats().sizeInBytes
}
// sizeInBytes is the compressed size, so we need to multiply by the compression factor.
val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor.toLong
Statistics((sizeInBytes * compressionFactor * outputRowSize) / totalRowSize)
}
}
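Continuing the assumed numbers from the earlier DataSourceScanExec sketch, the Hive-side estimate combines the selected partitions' size, the file compression factor, and the same projected-to-full row-width ratio; all values below are assumptions for illustration, not taken from the patch.

// Sketch with assumed values.
val selectedPartitionsSizeInBytes = BigInt(256L * 1024 * 1024) // sum of totalSize over the selected partitions
val compressionFactor = 1L                                     // compression factor conf truncated to Long, as in the patch
val outputRowSize = 20                                         // projected columns plus 8-byte row overhead
val totalRowSize = 40                                          // full table schema plus 8-byte row overhead
val estimatedSizeInBytes =
  (selectedPartitionsSizeInBytes * compressionFactor * outputRowSize) / totalRowSize // 128 MiB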
