From 6022e77877aa2ea2e0e7c6847f31e42ff3f1f1c8 Mon Sep 17 00:00:00 2001 From: David Vrba Date: Sun, 9 Dec 2018 21:46:38 +0100 Subject: [PATCH] spark-25401 reorder join predicates to match child outputOrdering --- .../exchange/EnsureRequirements.scala | 52 ++++++++++++++++--- .../spark/sql/execution/PlannerSuite.scala | 20 +++++++ 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index d2d5011bbcb97..b37e5cedaa59f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -208,7 +208,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { children = withExchangeCoordinator(children, requiredChildDistributions) // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings: - children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) => + ensureOrdering( + reorderJoinPredicatesForOrdering(operator.withNewChildren(children)) + ) + } + + private def ensureOrdering(operator: SparkPlan): SparkPlan = { + var children: Seq[SparkPlan] = operator.children + children = children.zip(operator.requiredChildOrdering).map { case (child, requiredOrdering) => // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort. 
if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) { child @@ -243,24 +250,38 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { (leftKeysBuffer, rightKeysBuffer) } - private def reorderJoinKeys( + private def reorderJoinKeys[A]( leftKeys: Seq[Expression], rightKeys: Seq[Expression], - leftPartitioning: Partitioning, - rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = { + leftChildDist: A, + rightChildDist: A): (Seq[Expression], Seq[Expression]) = { if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) { - leftPartitioning match { + leftChildDist match { case HashPartitioning(leftExpressions, _) if leftExpressions.length == leftKeys.length && leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) => reorder(leftKeys, rightKeys, leftExpressions, leftKeys) - case _ => rightPartitioning match { + case leftOrders: Seq[_] + if leftOrders.forall(_.isInstanceOf[Expression]) && + leftOrders.length == leftKeys.length && + leftKeys.forall { x => + (leftOrders.map(_.asInstanceOf[Expression])).exists(_.semanticEquals(x))} => + reorder(leftKeys, rightKeys, leftOrders.map(_.asInstanceOf[Expression]), leftKeys) + + case _ => rightChildDist match { case HashPartitioning(rightExpressions, _) if rightExpressions.length == rightKeys.length && rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) => reorder(leftKeys, rightKeys, rightExpressions, rightKeys) + case rightOrders: Seq[_] + if rightOrders.forall(_.isInstanceOf[Expression]) && + rightOrders.length == rightKeys.length && + rightKeys.forall { x => + (rightOrders.map(_.asInstanceOf[Expression])).exists(_.semanticEquals(x))} => + reorder(leftKeys, rightKeys, rightOrders.map(_.asInstanceOf[Expression]), rightKeys) + case _ => (leftKeys, rightKeys) } } @@ -276,7 +297,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { * introduced).
This rule will change the ordering of the join keys to match with the * partitioning of the join nodes' children. */ - private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = { + private def reorderJoinPredicatesForPartitioning(plan: SparkPlan): SparkPlan = { plan match { case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = @@ -293,6 +314,21 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } } + private def reorderJoinPredicatesForOrdering(plan: SparkPlan): SparkPlan = { + plan match { + case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => + val (reorderedLeftKeys, reorderedRightKeys) = + reorderJoinKeys( + leftKeys, + rightKeys, + left.outputOrdering.map(_.child), + right.outputOrdering.map(_.child)) + SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + + case other => other + } + } + def apply(plan: SparkPlan): SparkPlan = plan.transformUp { // TODO: remove this after we create a physical operator for `RepartitionByExpression`. 
case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) => @@ -301,6 +337,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case _ => operator } case operator: SparkPlan => - ensureDistributionAndOrdering(reorderJoinPredicates(operator)) + ensureDistributionAndOrdering(reorderJoinPredicatesForPartitioning(operator)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 142ab6170a734..6469103352fd9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -780,6 +780,26 @@ class PlannerSuite extends SharedSQLContext { classOf[PartitioningCollection]) } } + + test("SPARK-25401: Reorder the join predicates to match child output ordering") { + val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB), + outputPartitioning = HashPartitioning(exprB :: exprA :: Nil, 5)) + val plan2 = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB), + outputPartitioning = HashPartitioning(exprB :: exprA :: Nil, 5)) + val smjExec = SortMergeJoinExec( + exprB :: exprA :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan2) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) + outputPlan match { + case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) => + assert(leftKeys == Seq(exprA, exprB)) + assert(rightKeys == Seq(exprA, exprB)) + case _ => fail() + } + if (outputPlan.collect { case s: SortExec => true }.nonEmpty) { + fail(s"No sorts should have been added:\n$outputPlan") + } + } } // Used for unit-testing EnsureRequirements