SPARK-25401: reorder join predicates to match child outputOrdering
David Vrba committed Dec 9, 2018
1 parent 877f82c commit 6022e77
Showing 2 changed files with 64 additions and 8 deletions.
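
The change targets sort-merge joins whose key order, as written in the query, differs from the sorted order of the join children, which previously forced redundant SortExec nodes into the plan. A hypothetical repro of the symptom (Spark 2.4-era bucketing APIs; the table and column names are invented for illustration):

// Two tables bucketed and sorted by (a, b); the join condition lists the keys as (b, a).
spark.range(100).selectExpr("id % 10 AS a", "id % 5 AS b", "id AS c")
  .write.bucketBy(8, "a", "b").sortBy("a", "b").saveAsTable("t1")
spark.range(100).selectExpr("id % 10 AS a", "id % 5 AS b", "id AS d")
  .write.bucketBy(8, "a", "b").sortBy("a", "b").saveAsTable("t2")

// Before this commit the planner compared the keys (b, a) against the child
// ordering (a, b), missed the match, and added a SortExec on each side of the
// SortMergeJoin; with the keys reordered, the existing ordering is reused.
spark.table("t1").join(spark.table("t2"), Seq("b", "a")).explain()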
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -208,7 +208,14 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     children = withExchangeCoordinator(children, requiredChildDistributions)
 
     // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
-    children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+    ensureOrdering(
+      reorderJoinPredicatesForOrdering(operator.withNewChildren(children))
+    )
+  }
+
+  private def ensureOrdering(operator: SparkPlan): SparkPlan = {
+    var children: Seq[SparkPlan] = operator.children
+    children = children.zip(operator.requiredChildOrdering).map { case (child, requiredOrdering) =>
       // If child.outputOrdering already satisfies the requiredOrdering, we do not need to sort.
       if (SortOrder.orderingSatisfies(child.outputOrdering, requiredOrdering)) {
         child
@@ -243,24 +250,38 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     (leftKeysBuffer, rightKeysBuffer)
   }
 
-  private def reorderJoinKeys(
+  private def reorderJoinKeys[A](
       leftKeys: Seq[Expression],
       rightKeys: Seq[Expression],
-      leftPartitioning: Partitioning,
-      rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+      leftChildDist: A,
+      rightChildDist: A): (Seq[Expression], Seq[Expression]) = {
     if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
-      leftPartitioning match {
+      leftChildDist match {
         case HashPartitioning(leftExpressions, _)
           if leftExpressions.length == leftKeys.length &&
             leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
           reorder(leftKeys, rightKeys, leftExpressions, leftKeys)
 
-        case _ => rightPartitioning match {
+        case leftOrders: Seq[_]
+          if leftOrders.forall(_.isInstanceOf[Expression]) &&
+            leftOrders.length == leftKeys.length &&
+            leftKeys.forall { x =>
+              leftOrders.map(_.asInstanceOf[Expression]).exists(_.semanticEquals(x)) } =>
+          reorder(leftKeys, rightKeys, leftOrders.map(_.asInstanceOf[Expression]), leftKeys)
+
+        case _ => rightChildDist match {
           case HashPartitioning(rightExpressions, _)
             if rightExpressions.length == rightKeys.length &&
               rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
             reorder(leftKeys, rightKeys, rightExpressions, rightKeys)
 
+          case rightOrders: Seq[_]
+            if rightOrders.forall(_.isInstanceOf[Expression]) &&
+              rightOrders.length == rightKeys.length &&
+              rightKeys.forall { x =>
+                rightOrders.map(_.asInstanceOf[Expression]).exists(_.semanticEquals(x)) } =>
+            reorder(leftKeys, rightKeys, rightOrders.map(_.asInstanceOf[Expression]), rightKeys)
+
           case _ => (leftKeys, rightKeys)
         }
       }
@@ -276,7 +297,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
    * introduced). This rule will change the ordering of the join keys to match with the
    * partitioning of the join nodes' children.
    */
-  private def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
+  private def reorderJoinPredicatesForPartitioning(plan: SparkPlan): SparkPlan = {
     plan match {
       case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
         val (reorderedLeftKeys, reorderedRightKeys) =
@@ -293,6 +314,21 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
     }
   }
 
+  private def reorderJoinPredicatesForOrdering(plan: SparkPlan): SparkPlan = {
+    plan match {
+      case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(
+            leftKeys,
+            rightKeys,
+            left.outputOrdering.map(_.child),
+            right.outputOrdering.map(_.child))
+        SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+
+      case other => other
+    }
+  }
+
   def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
     // TODO: remove this after we create a physical operator for `RepartitionByExpression`.
     case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, _) =>
@@ -301,6 +337,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         case _ => operator
       }
     case operator: SparkPlan =>
-      ensureDistributionAndOrdering(reorderJoinPredicates(operator))
+      ensureDistributionAndOrdering(reorderJoinPredicatesForPartitioning(operator))
   }
 }
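
The generic reorderJoinKeys[A] above now accepts either a Partitioning (the existing shuffle-oriented path) or a child's ordering expressions as a Seq[Expression], dispatching on runtime type. A minimal, self-contained sketch of the underlying reordering step, with plain strings standing in for Catalyst expressions and exact equality standing in for semanticEquals (reorderKeys is a hypothetical name, not the method in this diff):

// Rearranges paired join keys so they follow an expected expression order.
def reorderKeys[K](
    leftKeys: Seq[K],
    rightKeys: Seq[K],
    expectedOrder: Seq[K],
    currentOrder: Seq[K]): (Seq[K], Seq[K]) = {
  // currentOrder(i) names the key pair (leftKeys(i), rightKeys(i)).
  val byCurrent = currentOrder.zip(leftKeys.zip(rightKeys)).toMap
  val reordered = expectedOrder.flatMap(byCurrent.get)
  (reordered.map(_._1), reordered.map(_._2))
}

// Keys written as (b, a) while the children are ordered by (a, b):
reorderKeys(Seq("bL", "aL"), Seq("bR", "aR"),
  expectedOrder = Seq("a", "b"), currentOrder = Seq("b", "a"))
// => (Seq("aL", "bL"), Seq("aR", "bR"))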
sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -780,6 +780,26 @@ class PlannerSuite extends SharedSQLContext {
         classOf[PartitioningCollection])
     }
   }
+
+  test("SPARK-25401: Reorder the join predicates to match child output ordering") {
+    val plan1 = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB),
+      outputPartitioning = HashPartitioning(exprB :: exprA :: Nil, 5))
+    val plan2 = DummySparkPlan(outputOrdering = Seq(orderingA, orderingB),
+      outputPartitioning = HashPartitioning(exprB :: exprA :: Nil, 5))
+    val smjExec = SortMergeJoinExec(
+      exprB :: exprA :: Nil, exprB :: exprA :: Nil, Inner, None, plan1, plan2)
+
+    val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec)
+    outputPlan match {
+      case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) =>
+        assert(leftKeys == Seq(exprA, exprB))
+        assert(rightKeys == Seq(exprA, exprB))
+      case _ => fail()
+    }
+    if (outputPlan.collect { case s: SortExec => true }.nonEmpty) {
+      fail(s"No sorts should have been added:\n$outputPlan")
+    }
+  }
 }
 
 // Used for unit-testing EnsureRequirements
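
For reference, the fixtures the new test relies on (exprA, exprB, orderingA, orderingB, DummySparkPlan) are defined elsewhere in PlannerSuite and are not part of this diff; the expression and ordering fixtures look roughly like this (a reconstruction, assumed rather than shown here):

// Simple literal expressions with ascending sort orders, as used by the
// EnsureRequirements tests in PlannerSuite.
private val exprA = Literal(1)
private val exprB = Literal(2)
private val orderingA = SortOrder(exprA, Ascending)
private val orderingB = SortOrder(exprB, Ascending)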