review comments
tejasapatil committed Nov 27, 2017
1 parent d218fc3 commit d9620ef
Showing 2 changed files with 94 additions and 15 deletions.
sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -17,11 +17,14 @@

package org.apache.spark.sql.execution.exchange

+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.joins.ReorderJoinPredicates
+import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec,
+  SortMergeJoinExec}
import org.apache.spark.sql.internal.SQLConf

/**
@@ -32,8 +35,6 @@ import org.apache.spark.sql.internal.SQLConf
* the input partition ordering requirements are met.
*/
case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
-  private val reorderJoinPredicates = new ReorderJoinPredicates
-
private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions

private def targetPostShuffleInputSize: Long = conf.targetPostShuffleInputSize
@@ -251,6 +252,75 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
operator.withNewChildren(children)
}

+  /**
+   * When the physical operators are created for a JOIN, the ordering of the join keys is based
+   * on the order in which they appear in the user query. That might not match the output
+   * partitioning of the join node's children (thus leading to an extra sort / shuffle being
+   * introduced). This rule changes the ordering of the join keys to match the partitioning of
+   * the join nodes' children.
+   */
+  def reorderJoinPredicates(plan: SparkPlan): SparkPlan = {
+    def reorderJoinKeys(
+        leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression],
+        leftPartitioning: Partitioning,
+        rightPartitioning: Partitioning): (Seq[Expression], Seq[Expression]) = {
+
+      def reorder(expectedOrderOfKeys: Seq[Expression],
+          currentOrderOfKeys: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+        val leftKeysBuffer = ArrayBuffer[Expression]()
+        val rightKeysBuffer = ArrayBuffer[Expression]()
+
+        expectedOrderOfKeys.foreach(expression => {
+          val index = currentOrderOfKeys.indexWhere(e => e.semanticEquals(expression))
+          leftKeysBuffer.append(leftKeys(index))
+          rightKeysBuffer.append(rightKeys(index))
+        })
+        (leftKeysBuffer, rightKeysBuffer)
+      }
+
+      if (leftKeys.forall(_.deterministic) && rightKeys.forall(_.deterministic)) {
+        leftPartitioning match {
+          case HashPartitioning(leftExpressions, _)
+            if leftExpressions.length == leftKeys.length &&
+              leftKeys.forall(x => leftExpressions.exists(_.semanticEquals(x))) =>
+            reorder(leftExpressions, leftKeys)
+
+          case _ => rightPartitioning match {
+            case HashPartitioning(rightExpressions, _)
+              if rightExpressions.length == rightKeys.length &&
+                rightKeys.forall(x => rightExpressions.exists(_.semanticEquals(x))) =>
+              reorder(rightExpressions, rightKeys)
+
+            case _ => (leftKeys, rightKeys)
+          }
+        }
+      } else {
+        (leftKeys, rightKeys)
+      }
+    }
+
+    plan.transformUp {
+      case BroadcastHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left,
+          right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        BroadcastHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide,
+          condition, left, right)
+
+      case ShuffledHashJoinExec(leftKeys, rightKeys, joinType, buildSide, condition, left, right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide,
+          condition, left, right)
+
+      case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) =>
+        val (reorderedLeftKeys, reorderedRightKeys) =
+          reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning)
+        SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right)
+    }
+  }
+
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator @ ShuffleExchangeExec(partitioning, child, _) =>
child.children match {
@@ -259,6 +329,6 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
case _ => operator
}
case operator: SparkPlan =>
-      ensureDistributionAndOrdering(reorderJoinPredicates.apply(operator))
+      ensureDistributionAndOrdering(reorderJoinPredicates(operator))
}
}
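
To make the pairwise reordering above concrete, here is a minimal, self-contained Scala sketch, with plain strings standing in for Catalyst expressions. The names are illustrative only, and plain equality replaces the semanticEquals matching the real rule uses; this is not the committed code.

import scala.collection.mutable.ArrayBuffer

object ReorderSketch {
  // Realign (leftKeys, rightKeys) pairwise so that leftKeys follows `expected`,
  // i.e. the order of a child's hash-partitioning expressions.
  def reorder(
      expected: Seq[String],
      leftKeys: Seq[String],
      rightKeys: Seq[String]): (Seq[String], Seq[String]) = {
    val newLeft = ArrayBuffer[String]()
    val newRight = ArrayBuffer[String]()
    expected.foreach { e =>
      // The real rule matches with semanticEquals; plain equality suffices here.
      val i = leftKeys.indexWhere(_ == e)
      newLeft.append(leftKeys(i))
      newRight.append(rightKeys(i))
    }
    (newLeft.toList, newRight.toList)
  }

  def main(args: Array[String]): Unit = {
    // The query wrote the keys as (k, j) / (y, x), but the left child is
    // hash partitioned by (j, k): realign both sides to that order.
    val (l, r) = reorder(Seq("j", "k"), Seq("k", "j"), Seq("y", "x"))
    println(s"left = $l, right = $r")  // left = List(j, k), right = List(x, y)
  }
}

With the keys aligned this way, a child that is already hash partitioned by (j, k) satisfies the join's required distribution, so EnsureRequirements need not insert another exchange for it.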
sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -609,17 +609,26 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
df.write.format("parquet").bucketBy(8, "j", "k").saveAsTable("bucketed_table")

withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "0") {
sql("""
|SELECT *
|FROM (
| SELECT a.i, a.j, a.k
| FROM bucketed_table a
| JOIN table1 b
| ON a.i = b.i
|) c
|JOIN table2
|ON c.i = table2.i
|""".stripMargin).explain()
checkAnswer(
sql("""
|SELECT ab.i, ab.j, ab.k, c.i, c.j, c.k
|FROM (
| SELECT a.i, a.j, a.k
| FROM bucketed_table a
| JOIN table1 b
| ON a.i = b.i
|) ab
|JOIN table2 c
|ON ab.i = c.i
|""".stripMargin),
sql("""
|SELECT a.i, a.j, a.k, c.i, c.j, c.k
|FROM bucketed_table a
|JOIN table1 b
|ON a.i = b.i
|JOIN table2 c
|ON a.i = c.i
|""".stripMargin))
}
}
}
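
The rewritten test asserts result correctness via checkAnswer, comparing the nested-join query against an equivalent flat three-way join, rather than merely printing the plan. A complementary plan-level check could count the exchanges directly; the sketch below is not part of this commit, and assumes the suite's `spark` session and the tables created above are in scope.

import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec

// Run the nested-join query and inspect its physical plan.
val plan = spark.sql(
  """
    |SELECT ab.i, ab.j, ab.k, c.i, c.j, c.k
    |FROM (
    |  SELECT a.i, a.j, a.k FROM bucketed_table a JOIN table1 b ON a.i = b.i
    |) ab
    |JOIN table2 c ON ab.i = c.i
    |""".stripMargin).queryExecution.executedPlan

// Count shuffle exchanges in the executed plan; reordering join keys to match
// a child's existing hash partitioning avoids inserting extra exchanges
// wherever the partitioning already lines up.
val numShuffles = plan.collect { case e: ShuffleExchangeExec => e }.length
println(s"shuffle exchanges in plan: $numShuffles")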
