Rewrite join implementation to allow streaming of one relation.
marmbrus committed Mar 27, 2014
1 parent 1fa48d9 commit bc0cb84
Showing 5 changed files with 113 additions and 40 deletions.
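The change replaces the old SparkEquiInnerJoin operator, which materialized both inputs as (key, row) pairs and joined them with PartitionLocalRDDFunctions.joinLocally, with a HashJoin operator that builds an in-memory hash table from one input (the build side) and streams the other input through it one row at a time. Below is a minimal sketch of that build/probe pattern, using plain Scala collections and toy (key, payload) pairs as illustrative stand-ins for Spark's Row, Projection, and partition iterators; unlike the real operator, it omits the NULL-key handling that Row.anyNull provides.

    import scala.collection.mutable.ArrayBuffer

    // Toy build/probe hash join over (key, payload) pairs.
    object HashJoinSketch {
      def hashJoin[K, A, B](build: Iterator[(K, A)], stream: Iterator[(K, B)]): Iterator[(A, B)] = {
        // Phase 1: materialize only the build side into a key -> rows multi-map.
        val table = new java.util.HashMap[K, ArrayBuffer[A]]()
        build.foreach { case (k, a) =>
          var matches = table.get(k)
          if (matches == null) {
            matches = new ArrayBuffer[A]()
            table.put(k, matches)
          }
          matches += a
        }
        // Phase 2: probe lazily; the streamed side is never held in memory.
        stream.flatMap { case (k, b) =>
          val matches = table.get(k)
          if (matches == null) Iterator.empty else matches.iterator.map(a => (a, b))
        }
      }

      def main(args: Array[String]): Unit = {
        val build = Iterator(1 -> "a", 2 -> "b", 2 -> "c")
        val probe = Iterator(2 -> "x", 3 -> "y", 2 -> "z")
        hashJoin(build, probe).foreach(println) // (b,x), (c,x), (b,z), (c,z)
      }
    }

Only the build side is held in memory; the streamed side is consumed one row at a time, which is what the commit title means by streaming one relation.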
@@ -44,6 +44,16 @@ trait Row extends Seq[Any] with Serializable {
     s"[${this.mkString(",")}]"
 
   def copy(): Row
+
+  /** Returns true if there are any NULL values in this row. */
+  def anyNull: Boolean = {
+    var i = 0
+    while(i < length) {
+      if(isNullAt(i)) return true
+      i += 1
+    }
+    false
+  }
 }
 
 /**
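The new anyNull helper exists because of SQL's three-valued logic: a comparison involving NULL is never true, so a row whose join key contains a NULL can never match anything and can be skipped outright (the old implementation did this with a separate filterNulls pass, removed further down). A tiny illustrative sketch of the contract, using a plain Seq in place of a Row:

    object AnyNullSketch extends App {
      // A stand-in for a Row's columns; Row.anyNull checks isNullAt per position.
      val joinKey: Seq[Any] = Seq(1, null)
      val skip = joinKey.exists(_ == null)
      assert(skip) // NULL = NULL is not true in SQL, so this row can never join
    }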
@@ -117,7 +117,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
     val strategies: Seq[Strategy] =
       TopK ::
       PartialAggregation ::
-      SparkEquiInnerJoin ::
+      HashJoin ::
       ParquetOperations ::
       BasicOperators ::
       CartesianProduct ::
@@ -28,7 +28,7 @@ import org.apache.spark.sql.parquet._
 abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   self: SQLContext#SparkPlanner =>
 
-  object SparkEquiInnerJoin extends Strategy {
+  object HashJoin extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case FilteredOperation(predicates, logical.Join(left, right, Inner, condition)) =>
         logger.debug(s"Considering join: ${predicates ++ condition}")
@@ -51,8 +51,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val leftKeys = joinKeys.map(_._1)
         val rightKeys = joinKeys.map(_._2)
 
-        val joinOp = execution.SparkEquiInnerJoin(
-          leftKeys, rightKeys, planLater(left), planLater(right))
+        val joinOp = execution.HashJoin(
+          leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
 
         // Make sure other conditions are met if present.
         if (otherPredicates.nonEmpty) {
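Note that the planner always passes BuildRight here, so the right child is hashed and the left child is streamed. A natural refinement (hypothetical; not part of this commit) would pick the build side from estimated input sizes, so the hash table is always built over the smaller relation:

    object BuildSideHeuristic {
      // Redeclared locally so the sketch stands alone; the real definitions
      // live in joins.scala below.
      sealed abstract class BuildSide
      case object BuildLeft extends BuildSide
      case object BuildRight extends BuildSide

      // Hypothetical heuristic, not in this commit: hash the smaller input
      // and stream the larger one.
      def chooseBuildSide(leftRowEstimate: Long, rightRowEstimate: Long): BuildSide =
        if (leftRowEstimate <= rightRowEstimate) BuildLeft else BuildRight
    }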
133 changes: 98 additions & 35 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -15,23 +15,33 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.execution
+package org.apache.spark.sql
+package execution
 
-import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, BitSet}
 
-import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext
 
-import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import catalyst.errors._
+import catalyst.expressions._
+import catalyst.plans._
+import catalyst.plans.physical.{ClusteredDistribution, Partitioning}
 
-import org.apache.spark.rdd.PartitionLocalRDDFunctions._
+sealed abstract class BuildSide
+case object BuildLeft extends BuildSide
+case object BuildRight extends BuildSide
 
-case class SparkEquiInnerJoin(
+object InterpretCondition {
+  def apply(expression: Expression): (Row => Boolean) = {
+    (r: Row) => expression.apply(r).asInstanceOf[Boolean]
+  }
+}
+
+case class HashJoin(
     leftKeys: Seq[Expression],
     rightKeys: Seq[Expression],
+    buildSide: BuildSide,
     left: SparkPlan,
     right: SparkPlan) extends BinaryNode {
 
@@ -40,33 +50,85 @@ case class SparkEquiInnerJoin(
   override def requiredChildDistribution =
     ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
 
+  val (buildPlan, streamedPlan) = buildSide match {
+    case BuildLeft => (left, right)
+    case BuildRight => (right, left)
+  }
+
+  val (buildKeys, streamedKeys) = buildSide match {
+    case BuildLeft => (leftKeys, rightKeys)
+    case BuildRight => (rightKeys, leftKeys)
+  }
+
   def output = left.output ++ right.output
 
-  def execute() = attachTree(this, "execute") {
-    val leftWithKeys = left.execute().mapPartitions { iter =>
-      val generateLeftKeys = new Projection(leftKeys, left.output)
-      iter.map(row => (generateLeftKeys(row), row.copy()))
-    }
+  @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+  @transient lazy val streamSideKeyGenerator =
+    () => new MutableProjection(streamedKeys, streamedPlan.output)
 
-    val rightWithKeys = right.execute().mapPartitions { iter =>
-      val generateRightKeys = new Projection(rightKeys, right.output)
-      iter.map(row => (generateRightKeys(row), row.copy()))
-    }
+  def execute() = {
 
-    // Do the join.
-    val joined = filterNulls(leftWithKeys).joinLocally(filterNulls(rightWithKeys))
-    // Drop join keys and merge input tuples.
-    joined.map { case (_, (leftTuple, rightTuple)) => buildRow(leftTuple ++ rightTuple) }
-  }
+    buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+      val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+      var currentRow: Row = null
 
+      // Create a mapping of buildKeys -> rows
+      while(buildIter.hasNext) {
+        currentRow = buildIter.next()
+        val rowKey = buildSideKeyGenerator(currentRow)
+        if(!rowKey.anyNull) {
+          val existingMatchList = hashTable.get(rowKey)
+          val matchList = if (existingMatchList == null) {
+            val newMatchList = new ArrayBuffer[Row]()
+            hashTable.put(rowKey, newMatchList)
+            newMatchList
+          } else {
+            existingMatchList
+          }
+          matchList += currentRow.copy()
+        }
+      }
+
-  /**
-   * Filters any rows where any of the join keys is null, ensuring three-valued
-   * logic for the equi-join conditions.
-   */
-  protected def filterNulls(rdd: RDD[(Row, Row)]) =
-    rdd.filter {
-      case (key: Seq[_], _) => !key.exists(_ == null)
-    }
+      new Iterator[Row] {
+        private[this] var currentRow: Row = _
+        private[this] var currentMatches: ArrayBuffer[Row] = _
+        private[this] var currentPosition: Int = -1
+
+        // Mutable per row objects.
+        private[this] val joinRow = new JoinedRow
+
+        @transient private val joinKeys = streamSideKeyGenerator()
+
+        def hasNext: Boolean =
+          (currentPosition != -1 && currentPosition < currentMatches.size) ||
+          (streamIter.hasNext && fetchNext())
+
+        def next() = {
+          val ret = joinRow(currentRow, currentMatches(currentPosition))
+          currentPosition += 1
+          ret
+        }
+
+        private def fetchNext(): Boolean = {
+          currentMatches = null
+          currentPosition = -1
+
+          while (currentMatches == null && streamIter.hasNext) {
+            currentRow = streamIter.next()
+            if(!joinKeys(currentRow).anyNull)
+              currentMatches = hashTable.get(joinKeys.currentValue)
+          }
+
+          if (currentMatches == null) {
+            false
+          } else {
+            currentPosition = 0
+            true
+          }
+        }
+      }
+    }
+  }
 }

 case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode {
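The anonymous Iterator[Row] above is the heart of the streaming behavior: it pulls one probe row at a time from the streamed partition and replays that row's build-side matches before advancing. A condensed sketch of the same hasNext/fetchNext protocol over toy types (the names here are illustrative, not from the commit):

    import scala.collection.mutable.ArrayBuffer

    // Condensed probe-side iterator: `table` plays the role of the hash table
    // built from the build side, `stream` the streamed partition.
    final class ProbeIterator[K, A, B](
        table: java.util.HashMap[K, ArrayBuffer[A]],
        stream: Iterator[(K, B)]) extends Iterator[(A, B)] {

      private var current: (K, B) = _
      private var matches: ArrayBuffer[A] = _
      private var pos = -1

      // Either unconsumed matches remain for the current probe row, or we try
      // to advance to the next probe row that has matches.
      def hasNext: Boolean =
        (pos != -1 && pos < matches.size) || (stream.hasNext && fetchNext())

      def next(): (A, B) = {
        val out = (matches(pos), current._2)
        pos += 1
        out
      }

      // Advance the streamed side until some row has at least one build-side match.
      private def fetchNext(): Boolean = {
        matches = null
        pos = -1
        while (matches == null && stream.hasNext) {
          current = stream.next()
          matches = table.get(current._1)
        }
        if (matches == null) false else { pos = 0; true }
      }
    }

All state lives in a few mutable fields, so the loop allocates almost nothing per probe row; the real operator goes further and reuses a single JoinedRow for the output, per its "Mutable per row objects" comment.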
@@ -95,17 +157,18 @@ case class BroadcastNestedLoopJoin(
   def right = broadcast
 
   @transient lazy val boundCondition =
-    condition
-      .map(c => BindReferences.bindReference(c, left.output ++ right.output))
-      .getOrElse(Literal(true))
+    InterpretCondition(
+      condition
+        .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+        .getOrElse(Literal(true)))
 
 
   def execute() = {
     val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
-      val matchedRows = new mutable.ArrayBuffer[Row]
-      val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size)
+      val matchedRows = new ArrayBuffer[Row]
+      val includedBroadcastTuples = new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
       val joinedRow = new JoinedRow
 
       streamedIter.foreach { streamedRow =>
@@ -115,7 +178,7 @@ case class BroadcastNestedLoopJoin(
         while (i < broadcastedRelation.value.size) {
           // TODO: One bitset per partition instead of per row.
           val broadcastedRow = broadcastedRelation.value(i)
-          if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) {
+          if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
             matchedRows += buildRow(streamedRow ++ broadcastedRow)
             matched = true
             includedBroadcastTuples += i
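The boundCondition change above is why InterpretCondition exists: the expression is wrapped once into a plain Row => Boolean, so the inner loop no longer casts each result with asInstanceOf[Boolean] at the call site. A stand-alone sketch of the same wrapping trick, with toy stand-ins for catalyst's Expression, Row, and Literal:

    object InterpretConditionSketch {
      // Toy stand-ins for the catalyst types, for illustration only.
      trait Row
      trait Expression { def apply(r: Row): Any }
      case class Literal(value: Any) extends Expression { def apply(r: Row): Any = value }

      // Mirrors InterpretCondition: centralize the Boolean cast inside the
      // returned predicate so call sites get a typed Row => Boolean.
      def interpret(expression: Expression): Row => Boolean =
        (r: Row) => expression(r).asInstanceOf[Boolean]

      def main(args: Array[String]): Unit = {
        val pred = interpret(Literal(true))
        println(pred(new Row {})) // true; no per-row cast at the call site
      }
    }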
@@ -194,7 +194,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
       DataSinks,
       Scripts,
       PartialAggregation,
-      SparkEquiInnerJoin,
+      HashJoin,
       BasicOperators,
       CartesianProduct,
       BroadcastNestedLoopJoin
