From cae0af33b535a7772fd2861851dca056e0c2186c Mon Sep 17 00:00:00 2001
From: Reynold Xin
Date: Fri, 18 Jul 2014 23:52:47 -0700
Subject: [PATCH 01/12] [SPARK-2521] Broadcast RDD object (instead of sending it along with every task).

Currently (as of Spark 1.0.1), Spark sends the RDD object (which contains closures) over Akka along with the task itself to the executors. This is inefficient because all tasks in the same stage use the same RDD object, yet we send the RDD object to the executors once per task. This is especially bad when a closure references some variable that is very large. The current design led to users having to explicitly broadcast large variables.

The patch uses broadcast to send RDD objects and the closures to executors, and uses Akka only to send a reference to the broadcast RDD/closure along with the partition-specific information for the task. For those of you who know more about the internals: Spark already relies on broadcast to send the Hadoop JobConf every time it uses a Hadoop input, because the JobConf is large.

The user-facing impact of the change includes:

1. Users won't need to decide what to broadcast anymore, unless they want to reuse a large object across multiple operations.
2. Task size will get smaller, resulting in faster scheduling and higher task dispatch throughput.

In addition, the change will simplify some internals of Spark, eliminating the need to maintain task caches and the complex logic to broadcast the JobConf (which also led to a deadlock recently).

A simple way to test this:
```scala
val a = new Array[Byte](1000*1000); scala.util.Random.nextBytes(a);
sc.parallelize(1 to 1000, 1000).map { x => a; x }.groupBy { x => a; x }.count
```

Numbers on 3 r3.8xlarge instances on EC2:
```
master branch:    5.648436068 s, 4.715361895 s, 5.360161877 s
with this change: 3.416348793 s, 1.477846558 s, 1.553432156 s
```

Author: Reynold Xin

Closes #1452 from rxin/broadcast-task and squashes the following commits:

762e0be [Reynold Xin] Warn large broadcasts.
ade6eac [Reynold Xin] Log broadcast size.
c3b6f11 [Reynold Xin] Added a unit test for clean up.
754085f [Reynold Xin] Explain why broadcasting serialized copy of the task.
04b17f0 [Reynold Xin] [SPARK-2521] Broadcast RDD object once per TaskSet (instead of sending it for every task).

(cherry picked from commit 7b8cd175254d42c8e82f0aa8eb4b7f3508d8fde2)
Signed-off-by: Reynold Xin
---
 .../scala/org/apache/spark/Dependency.scala | 28 ++--
 .../scala/org/apache/spark/SparkContext.scala | 2 -
 .../main/scala/org/apache/spark/rdd/RDD.scala | 30 +++-
 .../apache/spark/rdd/RDDCheckpointData.scala | 9 +-
 .../apache/spark/scheduler/DAGScheduler.scala | 4 -
 .../apache/spark/scheduler/ResultTask.scala | 128 +++++-------------
 .../spark/scheduler/ShuffleMapTask.scala | 125 ++++-------------
 .../apache/spark/ContextCleanerSuite.scala | 62 +++++----
 8 files changed, 137 insertions(+), 251 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 09a60571238ea..3935c8772252e 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle
 * Base class for dependencies.
*/ @DeveloperApi -abstract class Dependency[T](val rdd: RDD[T]) extends Serializable +abstract class Dependency[T] extends Serializable { + def rdd: RDD[T] +} /** @@ -36,20 +38,24 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable * partition of the child RDD. Narrow dependencies allow for pipelined execution. */ @DeveloperApi -abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { +abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] { /** * Get the parent partitions for a child partition. * @param partitionId a partition of the child RDD * @return the partitions of the parent RDD that the child partition depends upon */ def getParents(partitionId: Int): Seq[Int] + + override def rdd: RDD[T] = _rdd } /** * :: DeveloperApi :: - * Represents a dependency on the output of a shuffle stage. - * @param rdd the parent RDD + * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle, + * the RDD is transient since we don't need it on the executor side. + * + * @param _rdd the parent RDD * @param partitioner partitioner used to partition the shuffle output * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None, * the default serializer, as specified by `spark.serializer` config option, will @@ -57,20 +63,22 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) { */ @DeveloperApi class ShuffleDependency[K, V, C]( - @transient rdd: RDD[_ <: Product2[K, V]], + @transient _rdd: RDD[_ <: Product2[K, V]], val partitioner: Partitioner, val serializer: Option[Serializer] = None, val keyOrdering: Option[Ordering[K]] = None, val aggregator: Option[Aggregator[K, V, C]] = None, val mapSideCombine: Boolean = false) - extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) { + extends Dependency[Product2[K, V]] { + + override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]] - val shuffleId: Int = rdd.context.newShuffleId() + val shuffleId: Int = _rdd.context.newShuffleId() - val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle( - shuffleId, rdd.partitions.size, this) + val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( + shuffleId, _rdd.partitions.size, this) - rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3e6addeaf04a8..fb4c86716bb8d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging { // TODO: Cache.stop()? 
env.stop() SparkEnv.set(null) - ShuffleMapTask.clearCache() - ResultTask.clearCache() listenerBus.stop() eventLogger.foreach(_.stop()) logInfo("Successfully stopped SparkContext") diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a6abc49c5359e..995920e7f8008 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._ import org.apache.spark.SparkContext._ import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.Broadcast import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1206,21 +1207,36 @@ abstract class RDD[T: ClassTag]( /** * Return whether this RDD has been checkpointed or not */ - def isCheckpointed: Boolean = { - checkpointData.map(_.isCheckpointed).getOrElse(false) - } + def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed) /** * Gets the name of the file to which this RDD was checkpointed */ - def getCheckpointFile: Option[String] = { - checkpointData.flatMap(_.getCheckpointFile) - } + def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile) // ======================================================================= // Other internal methods and fields // ======================================================================= + /** + * Broadcasted copy of this RDD, used to dispatch tasks to executors. Note that we broadcast + * the serialized copy of the RDD and for each task we will deserialize it, which means each + * task gets a different copy of the RDD. This provides stronger isolation between tasks that + * might modify state of objects referenced in their closures. This is necessary in Hadoop + * where the JobConf/Configuration object is not thread-safe. + */ + @transient private[spark] lazy val broadcasted: Broadcast[Array[Byte]] = { + val ser = SparkEnv.get.closureSerializer.newInstance() + val bytes = ser.serialize(this).array() + val size = Utils.bytesToString(bytes.length) + if (bytes.length > (1L << 20)) { + logWarning(s"Broadcasting RDD $id ($size), which contains large objects") + } else { + logDebug(s"Broadcasting RDD $id ($size)") + } + sc.broadcast(bytes) + } + private var storageLevel: StorageLevel = StorageLevel.NONE /** User code that created this RDD (e.g. `textFile`, `parallelize`). 
*/ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index c3b2a33fb54d0..f67e5f1857979 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) cpRDD = Some(newRDD) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed - RDDCheckpointData.clearTaskCaches() } logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) } @@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } } -private[spark] object RDDCheckpointData { - def clearTaskCaches() { - ShuffleMapTask.clearCache() - ResultTask.clearCache() - } -} +// Used for synchronization +private[spark] object RDDCheckpointData diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index dc6142ab79d03..492553b65d256 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -361,9 +361,6 @@ class DAGScheduler( // data structures based on StageId stageIdToStage -= stageId - ShuffleMapTask.removeStage(stageId) - ResultTask.removeStage(stageId) - logDebug("After removal of stage %d, remaining stages = %d" .format(stageId, stageIdToStage.size)) } @@ -691,7 +688,6 @@ class DAGScheduler( } } - /** Called when stage's parents are available and we can now do its task. */ private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index bbf9f7388b074..62beb0d02a9c3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -17,134 +17,68 @@ package org.apache.spark.scheduler -import scala.language.existentials +import java.nio.ByteBuffer import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.HashMap import org.apache.spark._ -import org.apache.spark.rdd.{RDD, RDDCheckpointData} - -private[spark] object ResultTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. 
- private val serializedInfoCache = new HashMap[Int, Array[Byte]] - - def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = - { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(func) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = - { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _] - (rdd, func) - } - - def removeStage(stageId: Int) { - serializedInfoCache.remove(stageId) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD /** * A task that sends back the output to the driver application. * - * See [[org.apache.spark.scheduler.Task]] for more information. + * See [[Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param rdd input to func + * @param rddBinary broadcast version of of the serialized RDD * @param func a function to apply on a partition of the RDD - * @param _partitionId index of the number in the RDD + * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). */ private[spark] class ResultTask[T, U]( stageId: Int, - var rdd: RDD[T], - var func: (TaskContext, Iterator[T]) => U, - _partitionId: Int, + val rddBinary: Broadcast[Array[Byte]], + val func: (TaskContext, Iterator[T]) => U, + val partition: Partition, @transient locs: Seq[TaskLocation], - var outputId: Int) - extends Task[U](stageId, _partitionId) with Externalizable { - - def this() = this(0, null, null, 0, null, 0) - - var split = if (rdd == null) null else rdd.partitions(partitionId) + val outputId: Int) + extends Task[U](stageId, partition.index) with Serializable { + + // TODO: Should we also broadcast func? For that we would need a place to + // keep a reference to it (perhaps in DAGScheduler's job object). + + def this( + stageId: Int, + rdd: RDD[T], + func: (TaskContext, Iterator[T]) => U, + partitionId: Int, + locs: Seq[TaskLocation], + outputId: Int) = { + this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId) + } - @transient private val preferredLocs: Seq[TaskLocation] = { + @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } override def runTask(context: TaskContext): U = { + // Deserialize the RDD using the broadcast variable. 
+ val ser = SparkEnv.get.closureSerializer.newInstance() + val rdd = ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value), + Thread.currentThread.getContextClassLoader) metrics = Some(context.taskMetrics) try { - func(context, rdd.iterator(split, context)) + func(context, rdd.iterator(partition, context)) } finally { context.executeOnCompleteCallbacks() } } + // This is only callable on the driver side. override def preferredLocations: Seq[TaskLocation] = preferredLocs override def toString = "ResultTask(" + stageId + ", " + partitionId + ")" - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partitionId) - out.writeInt(stageId) - val bytes = ResultTask.serializeInfo( - stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _]) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partitionId) - out.writeInt(outputId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes) - rdd = rdd_.asInstanceOf[RDD[T]] - func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U] - partitionId = in.readInt() - outputId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index fdaf1de83f051..033c6e52861e0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -17,71 +17,13 @@ package org.apache.spark.scheduler -import scala.language.existentials - -import java.io._ -import java.util.zip.{GZIPInputStream, GZIPOutputStream} - -import scala.collection.mutable.HashMap +import java.nio.ByteBuffer import org.apache.spark._ -import org.apache.spark.rdd.{RDD, RDDCheckpointData} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter -private[spark] object ShuffleMapTask { - - // A simple map between the stage id to the serialized byte array of a task. - // Served as a cache for task serialization because serialization can be - // expensive on the master node if it needs to launch thousands of tasks. 
- private val serializedInfoCache = new HashMap[Int, Array[Byte]] - - def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = { - synchronized { - val old = serializedInfoCache.get(stageId).orNull - if (old != null) { - return old - } else { - val out = new ByteArrayOutputStream - val ser = SparkEnv.get.closureSerializer.newInstance() - val objOut = ser.serializeStream(new GZIPOutputStream(out)) - objOut.writeObject(rdd) - objOut.writeObject(dep) - objOut.close() - val bytes = out.toByteArray - serializedInfoCache.put(stageId, bytes) - bytes - } - } - } - - def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val ser = SparkEnv.get.closureSerializer.newInstance() - val objIn = ser.deserializeStream(in) - val rdd = objIn.readObject().asInstanceOf[RDD[_]] - val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]] - (rdd, dep) - } - - // Since both the JarSet and FileSet have the same format this is used for both. - def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = { - val in = new GZIPInputStream(new ByteArrayInputStream(bytes)) - val objIn = new ObjectInputStream(in) - val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap - HashMap(set.toSeq: _*) - } - - def removeStage(stageId: Int) { - serializedInfoCache.remove(stageId) - } - - def clearCache() { - synchronized { - serializedInfoCache.clear() - } - } -} - /** * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner * specified in the ShuffleDependency). @@ -89,62 +31,47 @@ private[spark] object ShuffleMapTask { * See [[org.apache.spark.scheduler.Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param rdd the final RDD in this stage + * @param rddBinary broadcast version of of the serialized RDD * @param dep the ShuffleDependency - * @param _partitionId index of the number in the RDD + * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling */ private[spark] class ShuffleMapTask( stageId: Int, - var rdd: RDD[_], + var rddBinary: Broadcast[Array[Byte]], var dep: ShuffleDependency[_, _, _], - _partitionId: Int, + partition: Partition, @transient private var locs: Seq[TaskLocation]) - extends Task[MapStatus](stageId, _partitionId) - with Externalizable - with Logging { - - protected def this() = this(0, null, null, 0, null) + extends Task[MapStatus](stageId, partition.index) with Logging { + + // TODO: Should we also broadcast the ShuffleDependency? For that we would need a place to + // keep a reference to it (perhaps in Stage). 
+ + def this( + stageId: Int, + rdd: RDD[_], + dep: ShuffleDependency[_, _, _], + partitionId: Int, + locs: Seq[TaskLocation]) = { + this(stageId, rdd.broadcasted, dep, rdd.partitions(partitionId), locs) + } @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } - var split = if (rdd == null) null else rdd.partitions(partitionId) - - override def writeExternal(out: ObjectOutput) { - RDDCheckpointData.synchronized { - split = rdd.partitions(partitionId) - out.writeInt(stageId) - val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep) - out.writeInt(bytes.length) - out.write(bytes) - out.writeInt(partitionId) - out.writeLong(epoch) - out.writeObject(split) - } - } - - override def readExternal(in: ObjectInput) { - val stageId = in.readInt() - val numBytes = in.readInt() - val bytes = new Array[Byte](numBytes) - in.readFully(bytes) - val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes) - rdd = rdd_ - dep = dep_ - partitionId = in.readInt() - epoch = in.readLong() - split = in.readObject().asInstanceOf[Partition] - } - override def runTask(context: TaskContext): MapStatus = { + // Deserialize the RDD using the broadcast variable. + val ser = SparkEnv.get.closureSerializer.newInstance() + val rdd = ser.deserialize[RDD[_]](ByteBuffer.wrap(rddBinary.value), + Thread.currentThread.getContextClassLoader) + metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null try { val manager = SparkEnv.get.shuffleManager writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context) - writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) + writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]]) return writer.stop(success = true).get } catch { case e: Exception => diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 13b415cccb647..871f831531bee 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -52,9 +52,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } } - test("cleanup RDD") { - val rdd = newRDD.persist() + val rdd = newRDD().persist() val collected = rdd.collect().toList val tester = new CleanerTester(sc, rddIds = Seq(rdd.id)) @@ -67,7 +66,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("cleanup shuffle") { - val (rdd, shuffleDeps) = newRDDWithShuffleDependencies + val (rdd, shuffleDeps) = newRDDWithShuffleDependencies() val collected = rdd.collect().toList val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) @@ -80,7 +79,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("cleanup broadcast") { - val broadcast = newBroadcast + val broadcast = newBroadcast() val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) // Explicit cleanup @@ -89,7 +88,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup RDD") { - var rdd = newRDD.persist() + var rdd = newRDD().persist() rdd.count() // Test that GC does not cause RDD cleanup due to a strong reference @@ -107,7 +106,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup shuffle") { - var rdd = newShuffleRDD + var rdd = 
newShuffleRDD() rdd.count() // Test that GC does not cause shuffle cleanup due to a strong reference @@ -125,7 +124,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo } test("automatically cleanup broadcast") { - var broadcast = newBroadcast + var broadcast = newBroadcast() // Test that GC does not cause broadcast cleanup due to a strong reference val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) @@ -141,11 +140,23 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo postGCTester.assertCleanup() } + test("automatically cleanup broadcast data for task dispatching") { + var rdd = newRDDWithShuffleDependencies()._1 + rdd.count() // This triggers an action that broadcasts the RDDs. + + // Test that GC causes broadcast task data cleanup after dereferencing the RDD. + val postGCTester = new CleanerTester(sc, + broadcastIds = Seq(rdd.broadcasted.id, rdd.firstParent.broadcasted.id)) + rdd = null + runGC() + postGCTester.assertCleanup() + } + test("automatically cleanup RDD + shuffle + broadcast") { val numRdds = 100 val numBroadcasts = 4 // Broadcasts are more costly - val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId val broadcastIds = 0L until numBroadcasts @@ -175,8 +186,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val numRdds = 10 val numBroadcasts = 4 // Broadcasts are more costly - val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer + val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId val broadcastIds = 0L until numBroadcasts @@ -197,17 +208,18 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo //------ Helper functions ------ - def newRDD = sc.makeRDD(1 to 10) - def newPairRDD = newRDD.map(_ -> 1) - def newShuffleRDD = newPairRDD.reduceByKey(_ + _) - def newBroadcast = sc.broadcast(1 to 100) - def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { + private def newRDD() = sc.makeRDD(1 to 10) + private def newPairRDD() = newRDD().map(_ -> 1) + private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _) + private def newBroadcast() = sc.broadcast(1 to 100) + + private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = { def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { rdd.dependencies ++ rdd.dependencies.flatMap { dep => getAllDependencies(dep.rdd) } } - val rdd = newShuffleRDD + val rdd = newShuffleRDD() // Get all the shuffle dependencies val shuffleDeps = getAllDependencies(rdd) @@ -216,34 +228,34 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo (rdd, shuffleDeps) } - def randomRdd = { + private def randomRdd() = { val rdd: RDD[_] = Random.nextInt(3) match { - case 0 => newRDD - case 1 => newShuffleRDD - case 2 => newPairRDD.join(newPairRDD) + case 0 => newRDD() + case 1 => newShuffleRDD() + case 2 => newPairRDD.join(newPairRDD()) } if 
(Random.nextBoolean()) rdd.persist() rdd.count() rdd } - def randomBroadcast = { + private def randomBroadcast() = { sc.broadcast(Random.nextInt(Int.MaxValue)) } /** Run GC and make sure it actually has run */ - def runGC() { + private def runGC() { val weakRef = new WeakReference(new Object()) val startTime = System.currentTimeMillis System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. // Wait until a weak reference object has been GCed - while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { System.gc() Thread.sleep(200) } } - def cleaner = sc.cleaner.get + private def cleaner = sc.cleaner.get } From d256b456b8450706ecacd98033c3f4d40b37814c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 20 Jul 2014 00:00:12 -0700 Subject: [PATCH 02/12] Fixed unit test failures. One more to go. --- .../org/apache/spark/scheduler/ShuffleMapTask.scala | 5 +++++ .../test/scala/org/apache/spark/rdd/RDDSuite.scala | 8 +------- .../spark/ui/jobs/JobProgressListenerSuite.scala | 11 ++++++----- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 033c6e52861e0..c7f7be4ee26d5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -56,6 +56,11 @@ private[spark] class ShuffleMapTask( this(stageId, rdd.broadcasted, dep, rdd.partitions(partitionId), locs) } + /** A constructor used only in test suites. This does not require passing in an RDD. */ + def this(partitionId: Int) { + this(0, null, null, new Partition { override def index = 0 }, null) + } + @transient private val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 6654ec2d7c656..c52080378c970 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -155,19 +155,13 @@ class RDDSuite extends FunSuite with SharedSparkContext { override def getPartitions: Array[Partition] = Array(onlySplit) override val getDependencies = List[Dependency[_]]() override def compute(split: Partition, context: TaskContext): Iterator[Int] = { - if (shouldFail) { - throw new Exception("injected failure") - } else { - Array(1, 2, 3, 4).iterator - } + throw new Exception("injected failure") } }.cache() val thrown = intercept[Exception]{ rdd.collect() } assert(thrown.getMessage.contains("injected failure")) - shouldFail = false - assert(rdd.collect().toList === List(1, 2, 3, 4)) } test("empty RDD") { diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index b52f81877d557..86a271eb67000 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.util.Utils class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers { + test("test LRU eviction of stages") { val conf = new SparkConf() conf.set("spark.ui.retainedStages", 5.toString) @@ 
-66,7 +67,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics) var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - var task = new ShuffleMapTask(0, null, null, 0, null) + var task = new ShuffleMapTask(0) val taskType = Utils.getFormattedClassName(task) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) @@ -76,14 +77,14 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc taskInfo = new TaskInfo(1234L, 0, 1, 1000L, "exe-unknown", "host1", TaskLocality.NODE_LOCAL, true) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.size === 1) // finish this task, should get updated duration taskInfo = new TaskInfo(1235L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-1", fail()) .shuffleRead === 2000) @@ -91,7 +92,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc // finish this task, should get updated duration taskInfo = new TaskInfo(1236L, 0, 2, 0L, "exe-2", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - task = new ShuffleMapTask(0, null, null, 0, null) + task = new ShuffleMapTask(0) listener.onTaskEnd(SparkListenerTaskEnd(task.stageId, taskType, Success, taskInfo, taskMetrics)) assert(listener.stageIdToData.getOrElse(0, fail()).executorSummary.getOrElse("exe-2", fail()) .shuffleRead === 1000) @@ -103,7 +104,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc val metrics = new TaskMetrics() val taskInfo = new TaskInfo(1234L, 0, 3, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false) taskInfo.finishTime = 1 - val task = new ShuffleMapTask(0, null, null, 0, null) + val task = new ShuffleMapTask(0) val taskType = Utils.getFormattedClassName(task) // Go through all the failure cases to make sure we are counting them as failures. From cc152fcd14bb13104f391da0fb703a1d2203e3a6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 20 Jul 2014 18:48:18 -0700 Subject: [PATCH 03/12] Don't cache the RDD broadcast variable. 
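Caching the serialized RDD in a `lazy val` tied the broadcast's lifetime to the RDD reference itself; replacing the cached `broadcasted` value with a `createBroadcastBinary()` method creates a fresh broadcast each time a stage is submitted instead. A condensed sketch of the new method, taken from the diff below (size logging omitted):

```scala
// Serialize this RDD and broadcast the bytes on each stage submission,
// instead of caching a single broadcast for the lifetime of the RDD.
@transient private[spark] def createBroadcastBinary(): Broadcast[Array[Byte]] = synchronized {
  val ser = SparkEnv.get.closureSerializer.newInstance()
  val bytes = ser.serialize(this).array()
  sc.broadcast(bytes)
}
```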
--- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../org/apache/spark/scheduler/DAGScheduler.scala | 11 +++++++---- .../scala/org/apache/spark/scheduler/ResultTask.scala | 10 ---------- .../org/apache/spark/scheduler/ShuffleMapTask.scala | 9 --------- .../scala/org/apache/spark/ContextCleanerSuite.scala | 2 +- .../org/apache/spark/scheduler/TaskContextSuite.scala | 7 ++++--- 6 files changed, 13 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 995920e7f8008..cc7f2316278eb 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1225,7 +1225,7 @@ abstract class RDD[T: ClassTag]( * might modify state of objects referenced in their closures. This is necessary in Hadoop * where the JobConf/Configuration object is not thread-safe. */ - @transient private[spark] lazy val broadcasted: Broadcast[Array[Byte]] = { + @transient private[spark] def createBroadcastBinary(): Broadcast[Array[Byte]] = synchronized { val ser = SparkEnv.get.closureSerializer.newInstance() val bytes = ser.serialize(this).array() val size = Utils.bytesToString(bytes.length) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 492553b65d256..2b2bc0553a7a2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -694,18 +694,21 @@ class DAGScheduler( // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() var tasks = ArrayBuffer[Task[_]]() + val broadcastRddBinary = stage.rdd.createBroadcastBinary() if (stage.isShuffleMap) { for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { val locs = getPreferredLocs(stage.rdd, p) - tasks += new ShuffleMapTask(stage.id, stage.rdd, stage.shuffleDep.get, p, locs) + val part = stage.rdd.partitions(p) + tasks += new ShuffleMapTask(stage.id, broadcastRddBinary, stage.shuffleDep.get, part, locs) } } else { // This is a final stage; figure out its job's missing partitions val job = stage.resultOfJob.get for (id <- 0 until job.numPartitions if !job.finished(id)) { - val partition = job.partitions(id) - val locs = getPreferredLocs(stage.rdd, partition) - tasks += new ResultTask(stage.id, stage.rdd, job.func, partition, locs, id) + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + tasks += new ResultTask(stage.id, broadcastRddBinary, job.func, part, locs, id) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 62beb0d02a9c3..03c4f407df1e2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -50,16 +50,6 @@ private[spark] class ResultTask[T, U]( // TODO: Should we also broadcast func? For that we would need a place to // keep a reference to it (perhaps in DAGScheduler's job object). 
- def this( - stageId: Int, - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitionId: Int, - locs: Seq[TaskLocation], - outputId: Int) = { - this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId) - } - @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index c7f7be4ee26d5..fd04b035ccb22 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -47,15 +47,6 @@ private[spark] class ShuffleMapTask( // TODO: Should we also broadcast the ShuffleDependency? For that we would need a place to // keep a reference to it (perhaps in Stage). - def this( - stageId: Int, - rdd: RDD[_], - dep: ShuffleDependency[_, _, _], - partitionId: Int, - locs: Seq[TaskLocation]) = { - this(stageId, rdd.broadcasted, dep, rdd.partitions(partitionId), locs) - } - /** A constructor used only in test suites. This does not require passing in an RDD. */ def this(partitionId: Int) { this(0, null, null, new Partition { override def index = 0 }, null) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 871f831531bee..3278f28a04c85 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -146,7 +146,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Test that GC causes broadcast task data cleanup after dereferencing the RDD. val postGCTester = new CleanerTester(sc, - broadcastIds = Seq(rdd.broadcasted.id, rdd.firstParent.broadcasted.id)) + broadcastIds = Seq(rdd.createBroadcastBinary.id, rdd.firstParent.createBroadcastBinary.id)) rdd = null runGC() postGCTester.assertCleanup() diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 8bb5317cd2875..72b4fe443cf43 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -38,13 +38,14 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte sys.error("failed") } } - val func = (c: TaskContext, i: Iterator[String]) => i.next - val task = new ResultTask[String, String](0, rdd, func, 0, Seq(), 0) + val func = (c: TaskContext, i: Iterator[String]) => i.next() + val task = new ResultTask[String, String]( + 0, rdd.createBroadcastBinary(), func, rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { task.run(0) } assert(completed === true) } - case class StubPartition(val index: Int) extends Partition + case class StubPartition(index: Int) extends Partition } From de779f8704a7f586190dc0e25642836e06136cbb Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 21 Jul 2014 00:21:13 -0700 Subject: [PATCH 04/12] Fix TaskContextSuite. 
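Since the RDD (and the closures it carries) is now serialized into the broadcast and deserialized per task, the task's on-complete callback updates only its own deserialized copy of the captured local `completed` variable, so the suite-local flag never flips on the driver side. The fix below shares the flag through a JVM-wide object instead, which works here because the test runs in local mode. A condensed sketch of the pattern, mirroring the diff:

```scala
// JVM-wide flag: the task closure runs against a deserialized copy of itself,
// so captured local variables no longer propagate back to the test.
private object TaskContextSuite {
  @volatile var completed = false
}
```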
--- .../apache/spark/scheduler/TaskContextSuite.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 72b4fe443cf43..98931bd24d7ea 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -29,12 +29,12 @@ import org.apache.spark.rdd.RDD class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { test("Calls executeOnCompleteCallbacks after failure") { - var completed = false + TaskContextSuite.completed = false sc = new SparkContext("local", "test") val rdd = new RDD[String](sc, List()) { override def getPartitions = Array[Partition](StubPartition(0)) override def compute(split: Partition, context: TaskContext) = { - context.addOnCompleteCallback(() => completed = true) + context.addOnCompleteCallback(() => TaskContextSuite.completed = true) sys.error("failed") } } @@ -44,8 +44,12 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte intercept[RuntimeException] { task.run(0) } - assert(completed === true) + assert(TaskContextSuite.completed === true) } +} - case class StubPartition(index: Int) extends Partition +private object TaskContextSuite { + @volatile var completed = false } + +private case class StubPartition(index: Int) extends Partition From 991c002fc4238308108c07fb40b3400a3d448e2f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 22 Jul 2014 22:41:53 -0700 Subject: [PATCH 05/12] Use HttpBroadcast. --- .../scala/org/apache/spark/broadcast/BroadcastManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index 8f8a0b11f9f2e..c88be6aba6901 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -39,7 +39,7 @@ private[spark] class BroadcastManager( synchronized { if (!initialized) { val broadcastFactoryClass = - conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") + conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") broadcastFactory = Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] From cf384501c1874284e8412439466eb6e22d5fe6d6 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Fri, 25 Jul 2014 00:10:18 -0700 Subject: [PATCH 06/12] Use TorrentBroadcastFactory. 
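This reverts the previous commit's switch, restoring TorrentBroadcastFactory as the default implementation for dispatching task binaries. The factory remains configurable through the `spark.broadcast.factory` key used in the diff; a hypothetical user-side override, shown only for illustration:

```scala
import org.apache.spark.SparkConf

// Illustrative only: opt back into the HTTP-based broadcast implementation.
val conf = new SparkConf()
  .set("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
```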
--- .../scala/org/apache/spark/broadcast/BroadcastManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index c88be6aba6901..8f8a0b11f9f2e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -39,7 +39,7 @@ private[spark] class BroadcastManager( synchronized { if (!initialized) { val broadcastFactoryClass = - conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") broadcastFactory = Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] From bab1d8b601d88b946e3770c611c33d2040472492 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 28 Jul 2014 15:38:57 -0700 Subject: [PATCH 07/12] Check for NotSerializableException in submitMissingTasks. --- .../apache/spark/scheduler/DAGScheduler.scala | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 2b2bc0553a7a2..4ee3d9c7218eb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{NotSerializableException, PrintWriter, StringWriter} +import java.io.{NotSerializableException} import java.util.Properties import java.util.concurrent.atomic.AtomicInteger @@ -35,6 +35,7 @@ import akka.pattern.ask import akka.util.Timeout import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.TaskMetrics import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.RDD @@ -694,7 +695,21 @@ class DAGScheduler( // Get our pending tasks and remember them in our pendingTasks entry stage.pendingTasks.clear() var tasks = ArrayBuffer[Task[_]]() - val broadcastRddBinary = stage.rdd.createBroadcastBinary() + + var broadcastRddBinary: Broadcast[Array[Byte]] = null + try { + broadcastRddBinary = stage.rdd.createBroadcastBinary() + } catch { + case e: NotSerializableException => + abortStage(stage, "Task not serializable: " + e.toString) + runningStages -= stage + return + case NonFatal(e) => + abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return + } + if (stage.isShuffleMap) { for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { val locs = getPreferredLocs(stage.rdd, p) From 797c247ba12dd3eeaa5dda621b9db5a8419732f0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 28 Jul 2014 18:20:13 -0700 Subject: [PATCH 08/12] Properly send SparkListenerStageSubmitted and SparkListenerStageCompleted. 
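Because the task binary is now serialized (and can fail to serialize) before any tasks are built, the listener events need reordering: SparkListenerStageSubmitted is posted before the serialization attempt, and if the stage ends up with no tasks to run, the matching SparkListenerStageCompleted is posted explicitly. A sketch condensed from the diff below (surrounding details elided):

```scala
// Condensed ordering in submitMissingTasks:
runningStages += stage
// Submitted must be posted before the serializability check, so that the
// Completed event posted on a serialization failure has a matching pair.
listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
// ...serialize and broadcast the task binary, aborting the stage on failure...
if (tasks.isEmpty) {
  // Pair the earlier Submitted event even though nothing will run.
  listenerBus.post(SparkListenerStageCompleted(stage.info))
  runningStages -= stage
}
```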
--- .../apache/spark/scheduler/DAGScheduler.scala | 34 +++++++++++-------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4ee3d9c7218eb..c8e5f139cf9f2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{NotSerializableException} +import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger @@ -696,10 +696,25 @@ class DAGScheduler( stage.pendingTasks.clear() var tasks = ArrayBuffer[Task[_]]() + val properties = if (jobIdToActiveJob.contains(jobId)) { + jobIdToActiveJob(stage.jobId).properties + } else { + // this stage will be assigned to "default" pool + null + } + + runningStages += stage + // SparkListenerStageSubmitted should be posted before testing whether tasks are + // serializable. If tasks are not serializable, a SparkListenerStageCompleted event + // will be posted, which should always come after a corresponding SparkListenerStageSubmitted + // event. + listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) + var broadcastRddBinary: Broadcast[Array[Byte]] = null try { broadcastRddBinary = stage.rdd.createBroadcastBinary() } catch { + // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString) runningStages -= stage @@ -727,21 +742,7 @@ class DAGScheduler( } } - val properties = if (jobIdToActiveJob.contains(jobId)) { - jobIdToActiveJob(stage.jobId).properties - } else { - // this stage will be assigned to "default" pool - null - } - if (tasks.size > 0) { - runningStages += stage - // SparkListenerStageSubmitted should be posted before testing whether tasks are - // serializable. If tasks are not serializable, a SparkListenerStageCompleted event - // will be posted, which should always come after a corresponding SparkListenerStageSubmitted - // event. - listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) - // Preemptively serialize a task to make sure it can be serialized. We are catching this // exception here because it would be fairly hard to catch the non-serializable exception // down the road, where we have several different implementations for local scheduler and @@ -766,6 +767,9 @@ class DAGScheduler( new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties)) stage.info.submissionTime = Some(clock.getTime()) } else { + // Because we posted SparkListenerStageSubmitted earlier, we should post + // SparkListenerStageCompleted here in case there are no tasks to run. + listenerBus.post(SparkListenerStageCompleted(stage.info)) logDebug("Stage " + stage + " is actually done; %b %d %d".format( stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions)) runningStages -= stage From 111007d719e9c23e5baf4fc3dc374d01115c0e1f Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 28 Jul 2014 18:29:33 -0700 Subject: [PATCH 09/12] Fix broadcast tests. 
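Submitting a stage now creates broadcast variables of its own, so the cleaner tests can no longer assume that user broadcasts occupy ids `0L until numBroadcasts`. Track the ids that were actually issued instead, as the diff below does:

```scala
// Record the ids of the broadcasts this test explicitly created, rather than
// assuming they are the first ids handed out by the SparkContext.
val broadcastIds = broadcastBuffer.map(_.id)
```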
--- .../test/scala/org/apache/spark/ContextCleanerSuite.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 3278f28a04c85..8678fc32ad7d3 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -19,6 +19,9 @@ package org.apache.spark import java.lang.ref.WeakReference +import org.apache.spark.broadcast.Broadcast + +import scala.collection.mutable import scala.collection.mutable.{HashSet, SynchronizedSet} import scala.language.existentials import scala.language.postfixOps @@ -159,7 +162,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId - val broadcastIds = 0L until numBroadcasts + val broadcastIds = broadcastBuffer.map(_.id) val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() @@ -190,7 +193,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId - val broadcastIds = 0L until numBroadcasts + val broadcastIds = broadcastBuffer.map(_.id) val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() From 252238da16fe3a3dfd4a067ca4a9ac47d4fac025 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 28 Jul 2014 21:31:20 -0700 Subject: [PATCH 10/12] Serialize the final task closure as well as ShuffleDependency in taskBinary. --- .../main/scala/org/apache/spark/rdd/RDD.scala | 19 ------------- .../apache/spark/scheduler/DAGScheduler.scala | 25 ++++++++++++++--- .../apache/spark/scheduler/ResultTask.scala | 19 ++++++------- .../spark/scheduler/ShuffleMapTask.scala | 27 +++++++++---------- .../scala/org/apache/spark/util/Utils.scala | 8 +++++- .../apache/spark/ContextCleanerSuite.scala | 12 --------- .../spark/scheduler/TaskContextSuite.scala | 3 ++- 7 files changed, 50 insertions(+), 63 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index cc7f2316278eb..726b3f2bbeea7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1218,25 +1218,6 @@ abstract class RDD[T: ClassTag]( // Other internal methods and fields // ======================================================================= - /** - * Broadcasted copy of this RDD, used to dispatch tasks to executors. Note that we broadcast - * the serialized copy of the RDD and for each task we will deserialize it, which means each - * task gets a different copy of the RDD. This provides stronger isolation between tasks that - * might modify state of objects referenced in their closures. This is necessary in Hadoop - * where the JobConf/Configuration object is not thread-safe. 
- */ - @transient private[spark] def createBroadcastBinary(): Broadcast[Array[Byte]] = synchronized { - val ser = SparkEnv.get.closureSerializer.newInstance() - val bytes = ser.serialize(this).array() - val size = Utils.bytesToString(bytes.length) - if (bytes.length > (1L << 20)) { - logWarning(s"Broadcasting RDD $id ($size), which contains large objects") - } else { - logDebug(s"Broadcasting RDD $id ($size)") - } - sc.broadcast(bytes) - } - private var storageLevel: StorageLevel = StorageLevel.NONE /** User code that created this RDD (e.g. `textFile`, `parallelize`). */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c8e5f139cf9f2..8d3aa23cb0bcf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -710,9 +710,23 @@ class DAGScheduler( // event. listenerBus.post(SparkListenerStageSubmitted(stage.info, properties)) - var broadcastRddBinary: Broadcast[Array[Byte]] = null + // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times. + // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast + // the serialized copy of the RDD and for each task we will deserialize it, which means each + // task gets a different copy of the RDD. This provides stronger isolation between tasks that + // might modify state of objects referenced in their closures. This is necessary in Hadoop + // where the JobConf/Configuration object is not thread-safe. + var taskBinary: Broadcast[Array[Byte]] = null try { - broadcastRddBinary = stage.rdd.createBroadcastBinary() + // For ShuffleMapTask, serialize and broadcast (rdd, shuffleDep). + // For ResultTask, serialize and broadcast (rdd, func). + val taskBinaryBytes: Array[Byte] = + if (stage.isShuffleMap) { + Utils.serializeTaskClosure((stage.rdd, stage.shuffleDep.get) : AnyRef) + } else { + Utils.serializeTaskClosure((stage.rdd, stage.resultOfJob.get.func) : AnyRef) + } + taskBinary = sc.broadcast(taskBinaryBytes) } catch { // In the case of a failure during serialization, abort the stage. case e: NotSerializableException => @@ -729,7 +743,7 @@ class DAGScheduler( for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) { val locs = getPreferredLocs(stage.rdd, p) val part = stage.rdd.partitions(p) - tasks += new ShuffleMapTask(stage.id, broadcastRddBinary, stage.shuffleDep.get, part, locs) + tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs) } } else { // This is a final stage; figure out its job's missing partitions @@ -738,7 +752,7 @@ class DAGScheduler( val p: Int = job.partitions(id) val part = stage.rdd.partitions(p) val locs = getPreferredLocs(stage.rdd, p) - tasks += new ResultTask(stage.id, broadcastRddBinary, job.func, part, locs, id) + tasks += new ResultTask(stage.id, taskBinary, part, locs, id) } } @@ -747,6 +761,9 @@ class DAGScheduler( // exception here because it would be fairly hard to catch the non-serializable exception // down the road, where we have several different implementations for local scheduler and // cluster schedulers. + // + // We've already serialized RDDs and closures in taskBinary, but here we check for all other + // objects such as Partition. 
try { SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head) } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 03c4f407df1e2..e133dde392104 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -31,8 +31,8 @@ import org.apache.spark.rdd.RDD * See [[Task]] for more information. * * @param stageId id of the stage this task belongs to - * @param rddBinary broadcast version of of the serialized RDD - * @param func a function to apply on a partition of the RDD + * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each + * partition of the given RDD. * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling * @param outputId index of the task in this job (a job can launch tasks on only a subset of the @@ -40,25 +40,22 @@ import org.apache.spark.rdd.RDD */ private[spark] class ResultTask[T, U]( stageId: Int, - val rddBinary: Broadcast[Array[Byte]], - val func: (TaskContext, Iterator[T]) => U, - val partition: Partition, + taskBinary: Broadcast[Array[Byte]], // (RDD[T], (TaskContext, Iterator[T]) => U) + partition: Partition, @transient locs: Seq[TaskLocation], val outputId: Int) extends Task[U](stageId, partition.index) with Serializable { - // TODO: Should we also broadcast func? For that we would need a place to - // keep a reference to it (perhaps in DAGScheduler's job object). - @transient private[this] val preferredLocs: Seq[TaskLocation] = { if (locs == null) Nil else locs.toSet.toSeq } override def runTask(context: TaskContext): U = { - // Deserialize the RDD using the broadcast variable. + // Deserialize the RDD and the func using the broadcast variables. val ser = SparkEnv.get.closureSerializer.newInstance() - val rdd = ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value), - Thread.currentThread.getContextClassLoader) + val (rdd, func) = ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)]( + ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) + metrics = Some(context.taskMetrics) try { func(context, rdd.iterator(partition, context)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index fd04b035ccb22..e95a97cdb40d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -19,37 +19,34 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import scala.language.existentials + import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter /** - * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner - * specified in the ShuffleDependency). - * - * See [[org.apache.spark.scheduler.Task]] for more information. - * +* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner +* specified in the ShuffleDependency). +* +* See [[org.apache.spark.scheduler.Task]] for more information. 
+* 
 * @param stageId id of the stage this task belongs to
- * @param rddBinary broadcast version of of the serialized RDD
- * @param dep the ShuffleDependency
+ * @param taskBinary broadcast version of the RDD and the ShuffleDependency
 * @param partition partition of the RDD this task is associated with
 * @param locs preferred task execution locations for locality scheduling
 */
 private[spark] class ShuffleMapTask(
     stageId: Int,
-    var rddBinary: Broadcast[Array[Byte]],
-    var dep: ShuffleDependency[_, _, _],
+    taskBinary: Broadcast[Array[Byte]], // (RDD[_], ShuffleDependency[_, _, _])
     partition: Partition,
     @transient private var locs: Seq[TaskLocation])
   extends Task[MapStatus](stageId, partition.index) with Logging {
 
-  // TODO: Should we also broadcast the ShuffleDependency? For that we would need a place to
-  // keep a reference to it (perhaps in Stage).
-
   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) {
-    this(0, null, null, new Partition { override def index = 0 }, null)
+    this(0, null, new Partition { override def index = 0 }, null)
   }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
@@ -59,8 +56,8 @@ private[spark] class ShuffleMapTask(
   override def runTask(context: TaskContext): MapStatus = {
     // Deserialize the RDD using the broadcast variable.
     val ser = SparkEnv.get.closureSerializer.newInstance()
-    val rdd = ser.deserialize[RDD[_]](ByteBuffer.wrap(rddBinary.value),
-      Thread.currentThread.getContextClassLoader)
+    val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
+      ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
 
     metrics = Some(context.taskMetrics)
     var writer: ShuffleWriter[Any, Any] = null
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 8cbb9050f393b..538843507e1bf 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -38,7 +38,7 @@ import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
 import org.json4s._
 import tachyon.client.{TachyonFile,TachyonFS}
 
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.executor.ExecutorUncaughtExceptionHandler
 import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
@@ -57,6 +57,12 @@ private[spark] object Utils extends Logging {
     new File(sparkHome + File.separator + "bin", which + suffix)
   }
 
+  /** Serialize an object using the closure serializer. 
**/ + def serializeTaskClosure[T: ClassTag](o: T): Array[Byte] = { + val ser = SparkEnv.get.closureSerializer.newInstance() + ser.serialize(o).array() + } + /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 8678fc32ad7d3..989273628e52a 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -143,18 +143,6 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo postGCTester.assertCleanup() } - test("automatically cleanup broadcast data for task dispatching") { - var rdd = newRDDWithShuffleDependencies()._1 - rdd.count() // This triggers an action that broadcasts the RDDs. - - // Test that GC causes broadcast task data cleanup after dereferencing the RDD. - val postGCTester = new CleanerTester(sc, - broadcastIds = Seq(rdd.createBroadcastBinary.id, rdd.firstParent.createBroadcastBinary.id)) - rdd = null - runGC() - postGCTester.assertCleanup() - } - test("automatically cleanup RDD + shuffle + broadcast") { val numRdds = 100 val numBroadcasts = 4 // Broadcasts are more costly diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 98931bd24d7ea..772fb33d735ab 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.Partition import org.apache.spark.SparkContext import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { @@ -40,7 +41,7 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte } val func = (c: TaskContext, i: Iterator[String]) => i.next() val task = new ResultTask[String, String]( - 0, rdd.createBroadcastBinary(), func, rdd.partitions(0), Seq(), 0) + 0, sc.broadcast(Utils.serializeTaskClosure((rdd, func))), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { task.run(0) } From f8535dc959b6d3b733fd46adbfa07708557a1d05 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 28 Jul 2014 21:35:23 -0700 Subject: [PATCH 11/12] Fixed the style violation. --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 538843507e1bf..b1ce6442a6a00 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -57,7 +57,7 @@ private[spark] object Utils extends Logging { new File(sparkHome + File.separator + "bin", which + suffix) } - /** Serialize an object using the closure serializer. **/ + /** Serialize an object using the closure serializer. */ def serializeTaskClosure[T: ClassTag](o: T): Array[Byte] = { val ser = SparkEnv.get.closureSerializer.newInstance() ser.serialize(o).array() From f7364db1f08555eb14cec757ee9e890de76db7c3 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 29 Jul 2014 23:46:13 -0700 Subject: [PATCH 12/12] Code review feedback. 
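
Changes from review: reuse a single closure serializer instance in DAGScheduler
(safe because the scheduler runs in a single thread), move the deserialized
tuple types from inline comments into the ResultTask and ShuffleMapTask
scaladoc, drop the Utils.serializeTaskClosure helper in favor of calling the
serializer directly, and make ContextCleanerSuite verify that the broadcasted
task closure is cleaned up after GC.

For reviewers, a rough sketch of the round trip this series sets up. This is
illustrative only: shipAndReadBack is a made-up name, not code in this patch;
the real paths are DAGScheduler.submitMissingTasks and the runTask methods of
ResultTask and ShuffleMapTask.

```scala
import java.nio.ByteBuffer
import org.apache.spark.{SparkContext, SparkEnv, TaskContext}
import org.apache.spark.rdd.RDD

def shipAndReadBack[T, U](
    sc: SparkContext,
    rdd: RDD[T],
    func: (TaskContext, Iterator[T]) => U): (RDD[T], (TaskContext, Iterator[T]) => U) = {
  val ser = SparkEnv.get.closureSerializer.newInstance()
  // Driver side: serialize (rdd, func) once per stage and broadcast the bytes.
  val taskBinary = sc.broadcast(ser.serialize((rdd, func): AnyRef).array())
  // Executor side: each task deserializes its own fresh copy, which isolates
  // tasks that mutate state reachable from their closures.
  ser.deserialize[(RDD[T], (TaskContext, Iterator[T]) => U)](
    ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
}
```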
--- .../org/apache/spark/scheduler/DAGScheduler.scala | 10 +++++++--- .../org/apache/spark/scheduler/ResultTask.scala | 5 +++-- .../apache/spark/scheduler/ShuffleMapTask.scala | 5 +++-- .../main/scala/org/apache/spark/util/Utils.scala | 6 ------ .../org/apache/spark/ContextCleanerSuite.scala | 14 ++++++++++++++ .../apache/spark/scheduler/TaskContextSuite.scala | 8 +++----- 6 files changed, 30 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8d3aa23cb0bcf..50186d097a632 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -115,6 +115,10 @@ class DAGScheduler( private val dagSchedulerActorSupervisor = env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this))) + // A closure serializer that we reuse. + // This is only safe because DAGScheduler runs in a single thread. + private val closureSerializer = SparkEnv.get.closureSerializer.newInstance() + private[scheduler] var eventProcessActor: ActorRef = _ private def initializeEventProcessActor() { @@ -722,9 +726,9 @@ class DAGScheduler( // For ResultTask, serialize and broadcast (rdd, func). val taskBinaryBytes: Array[Byte] = if (stage.isShuffleMap) { - Utils.serializeTaskClosure((stage.rdd, stage.shuffleDep.get) : AnyRef) + closureSerializer.serialize((stage.rdd, stage.shuffleDep.get) : AnyRef).array() } else { - Utils.serializeTaskClosure((stage.rdd, stage.resultOfJob.get.func) : AnyRef) + closureSerializer.serialize((stage.rdd, stage.resultOfJob.get.func) : AnyRef).array() } taskBinary = sc.broadcast(taskBinaryBytes) } catch { @@ -765,7 +769,7 @@ class DAGScheduler( // We've already serialized RDDs and closures in taskBinary, but here we check for all other // objects such as Partition. try { - SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head) + closureSerializer.serialize(tasks.head) } catch { case e: NotSerializableException => abortStage(stage, "Task not serializable: " + e.toString) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index e133dde392104..d09fd7aa57642 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -32,7 +32,8 @@ import org.apache.spark.rdd.RDD * * @param stageId id of the stage this task belongs to * @param taskBinary broadcasted version of the serialized RDD and the function to apply on each - * partition of the given RDD. + * partition of the given RDD. Once deserialized, the type should be + * (RDD[T], (TaskContext, Iterator[T]) => U). 
* @param partition partition of the RDD this task is associated with
 * @param locs preferred task execution locations for locality scheduling
 * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
@@ -40,7 +41,7 @@ import org.apache.spark.rdd.RDD
 */
 private[spark] class ResultTask[T, U](
     stageId: Int,
-    taskBinary: Broadcast[Array[Byte]], // (RDD[T], (TaskContext, Iterator[T]) => U)
+    taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
     val outputId: Int)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index e95a97cdb40d1..11255c07469d4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -33,13 +33,14 @@ import org.apache.spark.shuffle.ShuffleWriter
 * See [[org.apache.spark.scheduler.Task]] for more information.
 *
 * @param stageId id of the stage this task belongs to
- * @param taskBinary broadcast version of the RDD and the ShuffleDependency
+ * @param taskBinary broadcast version of the RDD and the ShuffleDependency. Once deserialized,
+ *                   the type should be (RDD[_], ShuffleDependency[_, _, _]).
 * @param partition partition of the RDD this task is associated with
 * @param locs preferred task execution locations for locality scheduling
 */
 private[spark] class ShuffleMapTask(
     stageId: Int,
-    taskBinary: Broadcast[Array[Byte]], // (RDD[_], ShuffleDependency[_, _, _])
+    taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation])
   extends Task[MapStatus](stageId, partition.index) with Logging {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index b1ce6442a6a00..912fcb1b9cf62 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -57,12 +57,6 @@ private[spark] object Utils extends Logging {
     new File(sparkHome + File.separator + "bin", which + suffix)
   }
 
-  /** Serialize an object using the closure serializer. */
-  def serializeTaskClosure[T: ClassTag](o: T): Array[Byte] = {
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    ser.serialize(o).array()
-  }
-
   /** Serialize an object using Java serialization */
   def serialize[T](o: T): Array[Byte] = {
     val bos = new ByteArrayOutputStream()
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 989273628e52a..ad20f9b937ac1 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -164,6 +164,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     rddBuffer.clear()
     runGC()
     postGCTester.assertCleanup()
+
+    // Make sure the broadcasted task closure no longer exists after GC. 
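+    // (This assumes the task binary is the first broadcast created after the explicitly
+    // tracked ones, so its id is one past the largest tracked broadcast id.)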
+ val taskClosureBroadcastId = broadcastIds.max + 1 + assert(sc.env.blockManager.master.getMatchingBlockIds({ + case BroadcastBlockId(`taskClosureBroadcastId`, _) => true + case _ => false + }, askSlaves = true).isEmpty) } test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { @@ -195,6 +202,13 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo rddBuffer.clear() runGC() postGCTester.assertCleanup() + + // Make sure the broadcasted task closure no longer exists after GC. + val taskClosureBroadcastId = broadcastIds.max + 1 + assert(sc.env.blockManager.master.getMatchingBlockIds({ + case BroadcastBlockId(`taskClosureBroadcastId`, _) => true + case _ => false + }, askSlaves = true).isEmpty) } //------ Helper functions ------ diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 772fb33d735ab..270f7e661045a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -20,10 +20,7 @@ package org.apache.spark.scheduler import org.scalatest.FunSuite import org.scalatest.BeforeAndAfter -import org.apache.spark.LocalSparkContext -import org.apache.spark.Partition -import org.apache.spark.SparkContext -import org.apache.spark.TaskContext +import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -39,9 +36,10 @@ class TaskContextSuite extends FunSuite with BeforeAndAfter with LocalSparkConte sys.error("failed") } } + val closureSerializer = SparkEnv.get.closureSerializer.newInstance() val func = (c: TaskContext, i: Iterator[String]) => i.next() val task = new ResultTask[String, String]( - 0, sc.broadcast(Utils.serializeTaskClosure((rdd, func))), rdd.partitions(0), Seq(), 0) + 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { task.run(0) }