From 84394e8ad21d4db0faf8188464a5e954aa09f1a5 Mon Sep 17 00:00:00 2001
From: JkSelf
Date: Fri, 7 Sep 2018 17:02:18 +0800
Subject: [PATCH] Auto-calculate the initial partition number with adaptive execution (#61)

---
 .../apache/spark/sql/internal/SQLConf.scala   | 11 +++++
 .../sql/execution/DataSourceScanExec.scala    |  8 +++-
 .../exchange/EnsureRequirements.scala         | 48 ++++++++++++++-----
 .../spark/sql/execution/PlannerSuite.scala    | 40 ++++++++++++++++
 .../spark/sql/hive/HiveMetastoreCatalog.scala | 11 ++++-
 .../hive/execution/HiveTableScanExec.scala    | 16 ++++++-
 6 files changed, 116 insertions(+), 18 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 31c7a8f4d90f2..45aeb170d199e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -304,6 +304,14 @@ object SQLConf {
         "must be a positive integer.")
       .createWithDefault(500)
 
+  val ADAPTIVE_EXECUTION_AUTO_CALCULATE_INITIAL_PARTITION_NUM =
+    buildConf("spark.sql.adaptive.autoCalculateInitialPartitionNum")
+      .doc("When true and adaptive execution is enabled," +
+        " Spark will calculate the initial shuffle partition number" +
+        " based on the statistics of the required columns.")
+      .booleanConf
+      .createWithDefault(false)
+
   val SUBEXPRESSION_ELIMINATION_ENABLED =
     buildConf("spark.sql.subexpressionElimination.enabled")
       .internal()
@@ -1382,6 +1390,9 @@ class SQLConf extends Serializable with Logging {
 
   def maxNumPostShufflePartitions: Int = getConf(SHUFFLE_MAX_NUM_POSTSHUFFLE_PARTITIONS)
 
+  def adaptiveAutoCalculateInitialPartitionNum: Boolean =
+    getConf(ADAPTIVE_EXECUTION_AUTO_CALCULATE_INITIAL_PARTITION_NUM)
+
   def minBatchesToRetain: Int = getConf(MIN_BATCHES_TO_RETAIN)
 
   def parquetFilterPushDown: Boolean = getConf(PARQUET_FILTER_PUSHDOWN_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index df41b2df11268..0715dbc0d6304 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -74,7 +74,13 @@ trait DataSourceScanExec extends LeafExecNode with CodegenSupport {
   }
 
   override def computeStats(): Statistics = {
-    Statistics(sizeInBytes = relation.sizeInBytes)
+    // There is some overhead in a Row object, so the row size should not be zero even when
+    // there are no columns; this also prevents a divide-by-zero error below.
+    val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
+    val dataSchema = sqlContext.sparkSession.sessionState.catalog.getTableMetadata(
+      tableIdentifier.get).dataSchema
+    val totalRowSize = dataSchema.map(_.dataType.defaultSize).sum + 8
+    Statistics(sizeInBytes = (relation.sizeInBytes * outputRowSize) / totalRowSize)
   }
 }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 34e4b18ca2b7b..c5dfcc5bb45dd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -36,14 +36,31 @@ import org.apache.spark.sql.internal.SQLConf
  * the input partition ordering requirements are met.
  */
 case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
-  private def defaultNumPreShufflePartitions: Int =
+  private def defaultNumPreShufflePartitions(plan: SparkPlan): Int =
     if (conf.adaptiveExecutionEnabled) {
-      conf.maxNumPostShufflePartitions
+      if (conf.adaptiveAutoCalculateInitialPartitionNum) {
+        autoCalculateInitialPartitionNum(plan)
+      } else {
+        conf.maxNumPostShufflePartitions
+      }
     } else {
       conf.numShufflePartitions
     }
 
-  private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
+  private def autoCalculateInitialPartitionNum(plan: SparkPlan): Int = {
+    val totalInputFileSize = plan.collectLeaves().map(_.statsPlan.sizeInBytes).sum
+    val autoInitialPartitionsNum = Math.ceil(
+      totalInputFileSize.toLong * 1.0 / conf.targetPostShuffleInputSize).toInt
+    if (autoInitialPartitionsNum < conf.minNumPostShufflePartitions) {
+      conf.minNumPostShufflePartitions
+    } else if (autoInitialPartitionsNum > conf.maxNumPostShufflePartitions) {
+      conf.maxNumPostShufflePartitions
+    } else {
+      autoInitialPartitionsNum
+    }
+  }
+
+  private def ensureDistributionAndOrdering(operator: SparkPlan, rootNode: SparkPlan): SparkPlan = {
     val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
     val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
     var children: Seq[SparkPlan] = operator.children
@@ -58,7 +75,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         BroadcastExchangeExec(mode, child)
       case (child, distribution) =>
         val numPartitions = distribution.requiredNumPartitions
-          .getOrElse(defaultNumPreShufflePartitions)
+          .getOrElse(defaultNumPreShufflePartitions(rootNode))
         ShuffleExchangeExec(distribution.createPartitioning(numPartitions), child)
     }
 
@@ -196,14 +213,19 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     }
   }
 
-  def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
-    // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
-    case operator @ ShuffleExchangeExec(upper: HashPartitioning, child) =>
-      child.outputPartitioning match {
-        case lower: HashPartitioning if upper.semanticEquals(lower) => child
-        case _ => operator
-      }
-    case operator: SparkPlan =>
-      ensureDistributionAndOrdering(reorderJoinPredicates(operator))
+  def apply(plan: SparkPlan): SparkPlan = {
+    // Record the root node so that all of its leaf nodes can be collected
+    // when calculating the initial partition number.
+    val rootNode = plan
+    plan.transformUp {
+      // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
+      case operator @ ShuffleExchangeExec(upper: HashPartitioning, child) =>
+        child.outputPartitioning match {
+          case lower: HashPartitioning if upper.semanticEquals(lower) => child
+          case _ => operator
+        }
+      case operator: SparkPlan =>
+        ensureDistributionAndOrdering(reorderJoinPredicates(operator), rootNode)
+    }
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 906c3616bf3f5..7e96a6c541760 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -336,6 +336,46 @@ class PlannerSuite extends SharedSQLContext {
     }
   }
 
+  test("EnsureRequirements with the initial partition number that" +
+    " is based on the statistics of leaf nodes") {
+    val distribution = ClusteredDistribution(Literal(1) :: Nil)
+    val childPartitioning = HashPartitioning(Literal(2) :: Nil, 1)
+
+    val inputPlan = DummySparkPlan(
+      children = Seq(
+        DummySparkPlan(outputPartitioning = childPartitioning),
+        DummySparkPlan(outputPartitioning = childPartitioning)
+      ),
+      requiredChildDistribution = Seq(distribution, distribution),
+      requiredChildOrdering = Seq(Seq.empty, Seq.empty)
+    )
+    withSQLConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "1") {
+
+      val totalInputFileSize = inputPlan.collectLeaves().map(_.stats.sizeInBytes).sum
+      val expectedNum = Math.ceil(
+        totalInputFileSize.toLong * 1.0 / conf.targetPostShuffleInputSize).toInt
+
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+        SQLConf.ADAPTIVE_EXECUTION_AUTO_CALCULATE_INITIAL_PARTITION_NUM.key -> "true") {
+        val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
+        outputPlan.collect {
+          case plan: ShuffleExchangeExec =>
+            val realNum = plan.outputPartitioning.numPartitions
+            assert(realNum == expectedNum)
+        }
+      }
+
+      withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+        val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan)
+        outputPlan.collect {
+          case plan: ShuffleExchangeExec =>
+            val realNum = plan.outputPartitioning.numPartitions
+            assert(realNum != expectedNum)
+        }
+      }
+    }
+  }
+
   test("EnsureRequirements with compatible child partitionings that satisfy distribution") {
     // In this case, all requirements are satisfied and no exchange should be added.
     val distribution = ClusteredDistribution(Literal(1) :: Nil)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 8adfda07d29d5..de5c28167f67f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -21,7 +21,7 @@ import scala.util.control.NonFatal
 
 import com.google.common.util.concurrent.Striped
 import org.apache.hadoop.fs.Path
-
+import org.apache.hadoop.hive.common.StatsSetupConst
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
@@ -154,7 +154,14 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Logging {
         Some(partitionSchema))
 
       val logicalRelation = cached.getOrElse {
-        val sizeInBytes = relation.stats.sizeInBytes.toLong
+        // For a partitioned table, relation.stats.sizeInBytes is Long.MaxValue rather than the
+        // real size in HDFS, so sum the totalSize statistics of all partitions instead.
+        val sizeInBytes = if (relation.isPartitioned && relation.stats.sizeInBytes.toLong == Long.MaxValue) {
+          sparkSession.sharedState.externalCatalog.listPartitions(tableIdentifier.database,
+            tableIdentifier.name).map(_.parameters.get(StatsSetupConst.TOTAL_SIZE).get.toLong).sum
+        } else {
+          relation.stats.sizeInBytes.toLong
+        }
         val fileIndex = {
           val index = new CatalogFileIndex(sparkSession, relation.tableMeta, sizeInBytes)
           if (lazyPruningEnabled) {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 233b13f1e790f..9515730f27393 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.execution
 import scala.collection.JavaConverters._
 
 import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.common.StatsSetupConst
 import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition}
 import org.apache.hadoop.hive.ql.plan.TableDesc
 import org.apache.hadoop.hive.serde.serdeConstants
@@ -215,7 +216,18 @@ case class HiveTableScanExec(
   override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
 
   override def computeStats(): Statistics = {
-    val stats = relation.computeStats()
-    Statistics(stats.sizeInBytes)
+    // There is some overhead in a Row object, so the row size should not be zero even when
+    // there are no columns; this also prevents a divide-by-zero error below.
+    val outputRowSize = output.map(_.dataType.defaultSize).sum + 8
+    val totalRowSize = relation.tableMeta.dataSchema.map(_.dataType.defaultSize).sum + 8
+    // For a partitioned table, only use the statistics of the selected partitions.
+    val sizeInBytes = if (relation.isPartitioned && rawPartitions.nonEmpty) {
+      BigInt(rawPartitions.map(_.getParameters.get(StatsSetupConst.TOTAL_SIZE).toLong).sum)
+    } else {
+      relation.computeStats().sizeInBytes
+    }
+    // sizeInBytes is the compressed size, so multiply it by the compression factor.
+    val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor.toLong
+    Statistics((sizeInBytes * compressionFactor * outputRowSize) / totalRowSize)
   }
 }
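
Note (not part of the patch above): the sizing rule that autoCalculateInitialPartitionNum applies is
ceil(total leaf-node input size / conf.targetPostShuffleInputSize), clamped to the configured
[conf.minNumPostShufflePartitions, conf.maxNumPostShufflePartitions] range. The following is a
minimal standalone Scala sketch of that rule; the object name and the example values in main() are
illustrative assumptions, not code taken from this change.

// Standalone sketch mirroring the clamping logic of autoCalculateInitialPartitionNum.
object InitialPartitionNumSketch {
  def initialPartitionNum(
      totalInputFileSize: Long,
      targetPostShuffleInputSize: Long,
      minNumPostShufflePartitions: Int,
      maxNumPostShufflePartitions: Int): Int = {
    // ceil(total input size / target post-shuffle input size) ...
    val raw = math.ceil(totalInputFileSize.toDouble / targetPostShuffleInputSize).toInt
    // ... clamped into the configured [min, max] range.
    math.min(math.max(raw, minNumPostShufflePartitions), maxNumPostShufflePartitions)
  }

  def main(args: Array[String]): Unit = {
    // Assumed example: 10 GB of scanned input and a 64 MB post-shuffle target give
    // ceil(10240 / 64) = 160 initial partitions, which lies inside an assumed [1, 500] range.
    println(initialPartitionNum(10L << 30, 64L << 20, 1, 500)) // prints 160
  }
}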
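
As a hedged usage sketch: on a build that carries this patch, the feature would be switched on
through configuration roughly as below. The builder API is standard Spark, "spark.sql.adaptive.enabled"
is the existing adaptive-execution switch, and "spark.sql.adaptive.autoCalculateInitialPartitionNum"
is the key introduced in SQLConf by this patch; the app name and master are illustrative.

// Assumed usage sketch: enable adaptive execution plus the new auto-calculation flag.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("auto-initial-partition-num-demo")
  .master("local[*]") // illustrative; any master works
  .config("spark.sql.adaptive.enabled", "true")
  .config("spark.sql.adaptive.autoCalculateInitialPartitionNum", "true")
  .getOrCreate()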