[SPARK-2521] Broadcast RDD object once per TaskSet (instead of sending it for every task).
rxin committed Jul 18, 2014
1 parent 6afca2d commit 04b17f0
Showing 7 changed files with 88 additions and 225 deletions.
28 changes: 18 additions & 10 deletions core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle
* Base class for dependencies.
*/
@DeveloperApi
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
abstract class Dependency[T] extends Serializable {
def rdd: RDD[T]
}


/**
@@ -36,41 +38,47 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
*/
@DeveloperApi
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
/**
* Get the parent partitions for a child partition.
* @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
def getParents(partitionId: Int): Seq[Int]

override def rdd: RDD[T] = _rdd
}


/**
* :: DeveloperApi ::
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
* the RDD is transient since we don't need it on the executor side.
*
* @param _rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
*/
@DeveloperApi
class ShuffleDependency[K, V, C](
@transient rdd: RDD[_ <: Product2[K, V]],
@transient _rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializer: Option[Serializer] = None,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
extends Dependency[Product2[K, V]] {

override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]]

val shuffleId: Int = rdd.context.newShuffleId()
val shuffleId: Int = _rdd.context.newShuffleId()

val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
shuffleId, rdd.partitions.size, this)
val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
shuffleId, _rdd.partitions.size, this)

rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
_rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
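
The note above about the transient parent RDD is the heart of the change in this file: ShuffleDependency now keeps _rdd out of serialized task closures and satisfies the abstract def rdd with an override that is only meaningful on the driver. A minimal sketch of the same pattern follows; the class and field names are illustrative, not part of the patch.

  import org.apache.spark.Dependency
  import org.apache.spark.rdd.RDD

  // Hypothetical dependency following the new contract: the parent RDD sits in
  // a @transient field, so it is dropped when the dependency is shipped to
  // executors, and `rdd` should therefore only be called on the driver.
  class DriverOnlyDependency[T](@transient _rdd: RDD[T]) extends Dependency[T] {
    override def rdd: RDD[T] = _rdd
  }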


2 changes: 0 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging {
// TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
ResultTask.clearCache()
listenerBus.stop()
eventLogger.foreach(_.stop())
logInfo("Successfully stopped SparkContext")
17 changes: 11 additions & 6 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1195,21 +1195,26 @@ abstract class RDD[T: ClassTag](
/**
* Return whether this RDD has been checkpointed or not
*/
def isCheckpointed: Boolean = {
checkpointData.map(_.isCheckpointed).getOrElse(false)
}
def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)

/**
* Gets the name of the file to which this RDD was checkpointed
*/
def getCheckpointFile: Option[String] = {
checkpointData.flatMap(_.getCheckpointFile)
}
def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)

// =======================================================================
// Other internal methods and fields
// =======================================================================

/**
* Broadcasted copy of this RDD, used to dispatch tasks to executors. Note that this is
* a lazy val so the broadcast is created only when tasks are scheduled on this RDD.
*/
@transient private[spark] lazy val broadcasted = {
val ser = SparkEnv.get.closureSerializer.newInstance()
sc.broadcast(ser.serialize(this).array())
}

private var storageLevel: StorageLevel = StorageLevel.NONE

/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
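
The new broadcasted field above is the driver half of the change: the RDD is serialized once with the closure serializer and the bytes are broadcast, so each task carries only a Broadcast handle. The executor half, mirrored from ResultTask.runTask further down, amounts to the sketch below; the helper name is assumed, and it relies on Spark-internal classes such as SparkEnv.

  import java.nio.ByteBuffer
  import org.apache.spark.SparkEnv
  import org.apache.spark.broadcast.Broadcast
  import org.apache.spark.rdd.RDD

  // Unwrap the broadcast bytes with the same closure serializer the driver used
  // in `broadcasted`. Deserialization still happens per task on the executor,
  // but the bytes themselves travel once per executor via the broadcast instead
  // of once per task over the task-launch path.
  def deserializeBroadcastRdd[T](rddBinary: Broadcast[Array[Byte]]): RDD[T] = {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value),
      Thread.currentThread.getContextClassLoader)
  }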
core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
}
logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
}
@@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}

private[spark] object RDDCheckpointData {
def clearTaskCaches() {
ShuffleMapTask.clearCache()
ResultTask.clearCache()
}
}
// Used for synchronization
private[spark] object RDDCheckpointData
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -376,9 +376,6 @@ class DAGScheduler(
stageIdToStage -= stageId
stageIdToJobIds -= stageId

ShuffleMapTask.removeStage(stageId)
ResultTask.removeStage(stageId)

logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@@ -723,7 +720,6 @@ class DAGScheduler(
}
}


/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
128 changes: 31 additions & 97 deletions core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -17,134 +17,68 @@

package org.apache.spark.scheduler

import scala.language.existentials
import java.nio.ByteBuffer

import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.HashMap

import org.apache.spark._
import org.apache.spark.rdd.{RDD, RDDCheckpointData}

private[spark] object ResultTask {

// A simple map between the stage id to the serialized byte array of a task.
// Served as a cache for task serialization because serialization can be
// expensive on the master node if it needs to launch thousands of tasks.
private val serializedInfoCache = new HashMap[Int, Array[Byte]]

def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
{
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
old
} else {
val out = new ByteArrayOutputStream
val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(func)
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
bytes
}
}
}

def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
{
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
(rdd, func)
}

def removeStage(stageId: Int) {
serializedInfoCache.remove(stageId)
}

def clearCache() {
synchronized {
serializedInfoCache.clear()
}
}
}

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

/**
* A task that sends back the output to the driver application.
*
* See [[org.apache.spark.scheduler.Task]] for more information.
* See [[Task]] for more information.
*
* @param stageId id of the stage this task belongs to
* @param rdd input to func
* @param rddBinary broadcast version of the serialized RDD
* @param func a function to apply on a partition of the RDD
* @param _partitionId index of the number in the RDD
* @param partition partition of the RDD this task is associated with
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
*/
private[spark] class ResultTask[T, U](
stageId: Int,
var rdd: RDD[T],
var func: (TaskContext, Iterator[T]) => U,
_partitionId: Int,
val rddBinary: Broadcast[Array[Byte]],
val func: (TaskContext, Iterator[T]) => U,
val partition: Partition,
@transient locs: Seq[TaskLocation],
var outputId: Int)
extends Task[U](stageId, _partitionId) with Externalizable {

def this() = this(0, null, null, 0, null, 0)

var split = if (rdd == null) null else rdd.partitions(partitionId)
val outputId: Int)
extends Task[U](stageId, partition.index) with Serializable {

// TODO: Should we also broadcast func? For that we would need a place to
// keep a reference to it (perhaps in DAGScheduler's job object).

def this(
stageId: Int,
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitionId: Int,
locs: Seq[TaskLocation],
outputId: Int) = {
this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId)
}

@transient private val preferredLocs: Seq[TaskLocation] = {
@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}

override def runTask(context: TaskContext): U = {
// Deserialize the RDD using the broadcast variable.
val ser = SparkEnv.get.closureSerializer.newInstance()
val rdd = ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value),
Thread.currentThread.getContextClassLoader)
metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(split, context))
func(context, rdd.iterator(partition, context))
} finally {
context.executeOnCompleteCallbacks()
}
}

// This is only callable on the driver side.
override def preferredLocations: Seq[TaskLocation] = preferredLocs

override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"

override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ResultTask.serializeInfo(
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
out.writeInt(bytes.length)
out.write(bytes)
out.writeInt(partitionId)
out.writeInt(outputId)
out.writeLong(epoch)
out.writeObject(split)
}
}

override def readExternal(in: ObjectInput) {
val stageId = in.readInt()
val numBytes = in.readInt()
val bytes = new Array[Byte](numBytes)
in.readFully(bytes)
val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
rdd = rdd_.asInstanceOf[RDD[T]]
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
partitionId = in.readInt()
outputId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
}
}
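
Taken together with the auxiliary constructor above, the driver now pays the RDD serialization cost once per stage: every ResultTask built for that stage reuses the same Broadcast[Array[Byte]] held by rdd.broadcasted. A driver-side sketch of task creation under the new scheme follows; the helper is hypothetical, and since ResultTask and TaskLocation are private[spark], code like this would have to live inside the org.apache.spark packages, roughly where DAGScheduler.submitMissingTasks does this for real.

  import org.apache.spark.TaskContext
  import org.apache.spark.rdd.RDD
  import org.apache.spark.scheduler.{ResultTask, TaskLocation}

  // Build one ResultTask per partition. `rdd.broadcasted` is a lazy val, so the
  // serialize-and-broadcast step runs once when the first task is constructed,
  // and every later task reuses the same broadcast handle. For this sketch the
  // job is assumed to cover all partitions, so outputId equals the partition id.
  def makeResultTasks[T, U](
      stageId: Int,
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      locs: Seq[Seq[TaskLocation]]): Seq[ResultTask[T, U]] = {
    (0 until rdd.partitions.size).map { pid =>
      new ResultTask(stageId, rdd, func, pid, locs(pid), pid)
    }
  }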