
Commit

Serialize the final task closure as well as ShuffleDependency in taskBinary.
rxin committed Jul 30, 2014
1 parent 111007d commit 252238d
Showing 7 changed files with 50 additions and 63 deletions.
19 changes: 0 additions & 19 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1218,25 +1218,6 @@ abstract class RDD[T: ClassTag](
// Other internal methods and fields
// =======================================================================

/**
* Broadcasted copy of this RDD, used to dispatch tasks to executors. Note that we broadcast
* the serialized copy of the RDD and for each task we will deserialize it, which means each
* task gets a different copy of the RDD. This provides stronger isolation between tasks that
* might modify state of objects referenced in their closures. This is necessary in Hadoop
* where the JobConf/Configuration object is not thread-safe.
*/
@transient private[spark] def createBroadcastBinary(): Broadcast[Array[Byte]] = synchronized {
val ser = SparkEnv.get.closureSerializer.newInstance()
val bytes = ser.serialize(this).array()
val size = Utils.bytesToString(bytes.length)
if (bytes.length > (1L << 20)) {
logWarning(s"Broadcasting RDD $id ($size), which contains large objects")
} else {
logDebug(s"Broadcasting RDD $id ($size)")
}
sc.broadcast(bytes)
}

private var storageLevel: StorageLevel = StorageLevel.NONE

/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
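The doc comment deleted here (its logic moves into DAGScheduler in this commit) explains why the RDD is broadcast as serialized bytes rather than as a live object: each task deserializes its own copy, so mutations stay task-local. A minimal, self-contained sketch of that isolation property (plain Java serialization, not Spark's closure serializer; `BroadcastIsolationSketch`, `Conf`, and `roundTrip` are illustrative names, with `Conf` standing in for a mutable, non-thread-safe object such as Hadoop's JobConf):

```scala
import java.io._

object BroadcastIsolationSketch {
  // Round-trip an object through serialization; each call builds a brand-new object graph.
  def roundTrip[T <: Serializable](o: T): T = {
    val bos = new ByteArrayOutputStream()
    new ObjectOutputStream(bos).writeObject(o)
    new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
      .readObject().asInstanceOf[T]
  }

  // Hypothetical stand-in for a mutable, non-thread-safe object like Hadoop's JobConf.
  case class Conf(var value: String)

  def main(args: Array[String]): Unit = {
    val original = Conf("defaults")
    val taskCopy1 = roundTrip(original) // what "task 1" would deserialize from the broadcast bytes
    val taskCopy2 = roundTrip(original) // what "task 2" would deserialize
    taskCopy1.value = "mutated by task 1"
    assert(taskCopy2.value == "defaults") // task 2's copy is unaffected by task 1's mutation
  }
}
```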
25 changes: 21 additions & 4 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -710,9 +710,23 @@ class DAGScheduler(
// event.
listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))

var broadcastRddBinary: Broadcast[Array[Byte]] = null
// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
// Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
// the serialized copy of the RDD and for each task we will deserialize it, which means each
// task gets a different copy of the RDD. This provides stronger isolation between tasks that
// might modify state of objects referenced in their closures. This is necessary in Hadoop
// where the JobConf/Configuration object is not thread-safe.
var taskBinary: Broadcast[Array[Byte]] = null
try {
broadcastRddBinary = stage.rdd.createBroadcastBinary()
// For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep).
// For ResultTask, serialize and broadcast (rdd, func).
val taskBinaryBytes: Array[Byte] =
if (stage.isShuffleMap) {
Utils.serializeTaskClosure((stage.rdd, stage.shuffleDep.get) : AnyRef)
} else {
Utils.serializeTaskClosure((stage.rdd, stage.resultOfJob.get.func) : AnyRef)
}
taskBinary = sc.broadcast(taskBinaryBytes)
} catch {
// In the case of a failure during serialization, abort the stage.
case e: NotSerializableException =>
@@ -729,7 +743,7 @@
for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
val locs = getPreferredLocs(stage.rdd, p)
val part = stage.rdd.partitions(p)
tasks += new ShuffleMapTask(stage.id, broadcastRddBinary, stage.shuffleDep.get, part, locs)
tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs)
}
} else {
// This is a final stage; figure out its job's missing partitions
@@ -738,7 +752,7 @@
val p: Int = job.partitions(id)
val part = stage.rdd.partitions(p)
val locs = getPreferredLocs(stage.rdd, p)
tasks += new ResultTask(stage.id, broadcastRddBinary, job.func, part, locs, id)
tasks += new ResultTask(stage.id, taskBinary, part, locs, id)
}
}

@@ -747,6 +761,9 @@
// exception here because it would be fairly hard to catch the non-serializable exception
// down the road, where we have several different implementations for local scheduler and
// cluster schedulers.
//
// We've already serialized RDDs and closures in taskBinary, but here we check for all other
// objects such as Partition.
try {
SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
} catch {
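The hunks above replace the per-RDD broadcast with a single `taskBinary` per stage that also carries the `ShuffleDependency` (for shuffle map stages) or the result function (for result stages). A rough sketch of the driver and executor sides of that flow, assuming a running SparkEnv; `TaskBinarySketch`, `buildTaskBinary`, and `readTaskBinary` are hypothetical names, since the real logic lives inline in `DAGScheduler.submitMissingTasks` and the tasks' `runTask` methods:

```scala
import java.nio.ByteBuffer

import scala.language.existentials

import org.apache.spark.{ShuffleDependency, SparkEnv}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD

object TaskBinarySketch {

  // Driver side: serialize the (RDD, ShuffleDependency) pair once per stage and broadcast
  // the raw bytes, mirroring the shuffle-map branch in the diff above.
  def buildTaskBinary(rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Broadcast[Array[Byte]] = {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val bytes = ser.serialize((rdd, dep): AnyRef).array()
    rdd.sparkContext.broadcast(bytes)
  }

  // Executor side: every task deserializes its own copy of the pair, so mutations to objects
  // reachable from the RDD (e.g. a non-thread-safe Hadoop JobConf) stay task-local.
  def readTaskBinary(taskBinary: Broadcast[Array[Byte]]): (RDD[_], ShuffleDependency[_, _, _]) = {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
  }
}
```

Serializing the pair once per stage and broadcasting only bytes keeps the per-task launch message small while still giving every task an isolated copy of the RDD and its dependency.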
19 changes: 8 additions & 11 deletions core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -31,34 +31,31 @@ import org.apache.spark.rdd.RDD
* See [[Task]] for more information.
*
* @param stageId id of the stage this task belongs to
 * @param rddBinary broadcast version of the serialized RDD
* @param func a function to apply on a partition of the RDD
* @param taskBinary broadcasted version of the serialized RDD and the function to apply on each
* partition of the given RDD.
* @param partition partition of the RDD this task is associated with
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
*/
private[spark] class ResultTask[T, U](
stageId: Int,
val rddBinary: Broadcast[Array[Byte]],
val func: (TaskContext, Iterator[T]) => U,
val partition: Partition,
taskBinary: Broadcast[Array[Byte]], // (RDD[T], (TaskContext, Iterator[T]) => U)
partition: Partition,
@transient locs: Seq[TaskLocation],
val outputId: Int)
extends Task[U](stageId, partition.index) with Serializable {

// TODO: Should we also broadcast func? For that we would need a place to
// keep a reference to it (perhaps in DAGScheduler's job object).

@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}

override def runTask(context: TaskContext): U = {
// Deserialize the RDD using the broadcast variable.
// Deserialize the RDD and the func using the broadcast variables.
val ser = SparkEnv.get.closureSerializer.newInstance()
val rdd = ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value),
Thread.currentThread.getContextClassLoader)
val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(partition, context))
27 changes: 12 additions & 15 deletions core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -19,37 +19,34 @@ package org.apache.spark.scheduler

import java.nio.ByteBuffer

import scala.language.existentials

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.shuffle.ShuffleWriter

/**
 * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
 * specified in the ShuffleDependency).
 *
 * See [[org.apache.spark.scheduler.Task]] for more information.
*
* @param stageId id of the stage this task belongs to
 * @param rddBinary broadcast version of the serialized RDD
 * @param dep the ShuffleDependency
 * @param taskBinary broadcast version of the RDD and the ShuffleDependency
* @param partition partition of the RDD this task is associated with
* @param locs preferred task execution locations for locality scheduling
*/
private[spark] class ShuffleMapTask(
stageId: Int,
var rddBinary: Broadcast[Array[Byte]],
var dep: ShuffleDependency[_, _, _],
taskBinary: Broadcast[Array[Byte]], // (RDD[_], ShuffleDependency[_, _, _])
partition: Partition,
@transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId, partition.index) with Logging {

// TODO: Should we also broadcast the ShuffleDependency? For that we would need a place to
// keep a reference to it (perhaps in Stage).

/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
this(0, null, null, new Partition { override def index = 0 }, null)
this(0, null, new Partition { override def index = 0 }, null)
}

@transient private val preferredLocs: Seq[TaskLocation] = {
@@ -59,8 +56,8 @@ private[spark] class ShuffleMapTask(
override def runTask(context: TaskContext): MapStatus = {
// Deserialize the RDD and the ShuffleDependency using the broadcast variable.
val ser = SparkEnv.get.closureSerializer.newInstance()
val rdd = ser.deserialize[RDD[_]](ByteBuffer.wrap(rddBinary.value),
Thread.currentThread.getContextClassLoader)
val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

metrics = Some(context.taskMetrics)
var writer: ShuffleWriter[Any, Any] = null
8 changes: 7 additions & 1 deletion core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -38,7 +38,7 @@ import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.json4s._
import tachyon.client.{TachyonFile,TachyonFS}

import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.executor.ExecutorUncaughtExceptionHandler
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
@@ -57,6 +57,12 @@ private[spark] object Utils extends Logging {
new File(sparkHome + File.separator + "bin", which + suffix)
}

/** Serialize an object using the closure serializer. */
def serializeTaskClosure[T: ClassTag](o: T): Array[Byte] = {
val ser = SparkEnv.get.closureSerializer.newInstance()
ser.serialize(o).array()
}

/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
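The new `Utils.serializeTaskClosure` helper wraps the closure serializer so callers such as `DAGScheduler` and the test suites do not have to instantiate it themselves. A small usage sketch of the round trip, assuming a live SparkEnv (for example inside a local SparkContext); `SerializeTaskClosureSketch` and `payload` are illustrative names:

```scala
// Utils is private[spark], so this illustrative sketch must itself live under org.apache.spark.
package org.apache.spark.util

import java.nio.ByteBuffer

import org.apache.spark.SparkEnv

object SerializeTaskClosureSketch {
  def main(args: Array[String]): Unit = {
    // Assumes a live SparkEnv, e.g. a local SparkContext has already been created.
    val payload = ("illustrative task state", Seq(1, 2, 3))

    // "Driver" side: one call hides the closure-serializer boilerplate.
    val bytes: Array[Byte] = Utils.serializeTaskClosure(payload)

    // "Executor" side: deserializing the bytes yields an independent copy of the object.
    val ser = SparkEnv.get.closureSerializer.newInstance()
    val copy = ser.deserialize[(String, Seq[Int])](
      ByteBuffer.wrap(bytes), Thread.currentThread.getContextClassLoader)
    assert(copy == payload)
  }
}
```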
12 changes: 0 additions & 12 deletions core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -143,18 +143,6 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
postGCTester.assertCleanup()
}

test("automatically cleanup broadcast data for task dispatching") {
var rdd = newRDDWithShuffleDependencies()._1
rdd.count() // This triggers an action that broadcasts the RDDs.

// Test that GC causes broadcast task data cleanup after dereferencing the RDD.
val postGCTester = new CleanerTester(sc,
broadcastIds = Seq(rdd.createBroadcastBinary.id, rdd.firstParent.createBroadcastBinary.id))
rdd = null
runGC()
postGCTester.assertCleanup()
}

test("automatically cleanup RDD + shuffle + broadcast") {
val numRdds = 100
val numBroadcasts = 4 // Broadcasts are more costly
3 changes: 2 additions & 1 deletion core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
@@ -25,6 +25,7 @@ import org.apache.spark.Partition
import org.apache.spark.SparkContext
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils

class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {

@@ -40,7 +41,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte
}
val func = (c: TaskContext, i: Iterator[String]) => i.next()
val task = new ResultTask[String, String](
0, rdd.createBroadcastBinary(), func, rdd.partitions(0), Seq(), 0)
0, sc.broadcast(Utils.serializeTaskClosure((rdd, func))), rdd.partitions(0), Seq(), 0)
intercept[RuntimeException] {
task.run(0)
}
