diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 7fb614402b7b7..15a04e6558685 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -55,20 +55,20 @@ class DAGScheduler( mapOutputTracker: MapOutputTracker, blockManagerMaster: BlockManagerMaster, env: SparkEnv) - extends TaskSchedulerListener with Logging { + extends Logging { def this(taskSched: TaskScheduler) { this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get) } - taskSched.setListener(this) + taskSched.setDAGScheduler(this) // Called by TaskScheduler to report task's starting. - override def taskStarted(task: Task[_], taskInfo: TaskInfo) { + def taskStarted(task: Task[_], taskInfo: TaskInfo) { eventQueue.put(BeginEvent(task, taskInfo)) } // Called by TaskScheduler to report task completions or failures. - override def taskEnded( + def taskEnded( task: Task[_], reason: TaskEndReason, result: Any, @@ -79,18 +79,18 @@ class DAGScheduler( } // Called by TaskScheduler when an executor fails. - override def executorLost(execId: String) { + def executorLost(execId: String) { eventQueue.put(ExecutorLost(execId)) } // Called by TaskScheduler when a host is added - override def executorGained(execId: String, host: String) { + def executorGained(execId: String, host: String) { eventQueue.put(ExecutorGained(execId, host)) } // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or // cancellation of the job itself. - override def taskSetFailed(taskSet: TaskSet, reason: String) { + def taskSetFailed(taskSet: TaskSet, reason: String) { eventQueue.put(TaskSetFailed(taskSet, reason)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 6a51efe8d6c05..10e047810827c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -24,8 +24,7 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode * Each TaskScheduler schedulers task for a single SparkContext. * These schedulers get sets of tasks submitted to them from the DAGScheduler for each stage, * and are responsible for sending the tasks to the cluster, running them, retrying if there - * are failures, and mitigating stragglers. They return events to the DAGScheduler through - * the TaskSchedulerListener interface. + * are failures, and mitigating stragglers. They return events to the DAGScheduler. */ private[spark] trait TaskScheduler { @@ -48,8 +47,8 @@ private[spark] trait TaskScheduler { // Cancel a stage. def cancelTasks(stageId: Int) - // Set a listener for upcalls. This is guaranteed to be set before submitTasks is called. - def setListener(listener: TaskSchedulerListener): Unit + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. + def setDAGScheduler(dagScheduler: DAGScheduler): Unit // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs. def defaultParallelism(): Int diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala deleted file mode 100644 index 593fa9fb93a55..0000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerListener.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import scala.collection.mutable.Map - -import org.apache.spark.TaskEndReason -import org.apache.spark.executor.TaskMetrics - -/** - * Interface for getting events back from the TaskScheduler. - */ -private[spark] trait TaskSchedulerListener { - // A task has started. - def taskStarted(task: Task[_], taskInfo: TaskInfo) - - // A task has finished or failed. - def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any], - taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit - - // A node was added to the cluster. - def executorGained(execId: String, host: String): Unit - - // A node was lost from the cluster. - def executorLost(execId: String): Unit - - // The TaskScheduler wants to abort an entire task set. - def taskSetFailed(taskSet: TaskSet, reason: String): Unit -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 7a72ff047442a..4ea8bf88534cf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -79,7 +79,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) private val executorIdToHost = new HashMap[String, String] // Listener object to pass upcalls into - var listener: TaskSchedulerListener = null + var dagScheduler: DAGScheduler = null var backend: SchedulerBackend = null @@ -94,8 +94,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener + override def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler } def initialize(context: SchedulerBackend) { @@ -297,7 +297,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } // Update the DAGScheduler without holding a lock on this, since that can deadlock if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + dagScheduler.executorLost(failedExecutor.get) backend.reviveOffers() } if (taskFailed) { @@ -397,9 +397,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) logError("Lost an executor " + executorId + " (already removed): " + reason) } } - // Call listener.executorLost without holding the lock on this to prevent deadlock + // Call dagScheduler.executorLost without holding the lock on this to prevent deadlock if (failedExecutor != None) { - listener.executorLost(failedExecutor.get) + dagScheduler.executorLost(failedExecutor.get) backend.reviveOffers() } } @@ -418,7 +418,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) } def executorGained(execId: String, host: String) { - listener.executorGained(execId, host) + dagScheduler.executorGained(execId, host) } def getExecutorsAliveOnHost(host: String): Option[Set[String]] = synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 7bd3499300e98..29093e3b4f511 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -415,11 +415,11 @@ private[spark] class ClusterTaskSetManager( } private def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) + sched.dagScheduler.taskStarted(task, info) } /** - * Marks the task as successful and notifies the listener that a task has ended. + * Marks the task as successful and notifies the DAGScheduler that a task has ended. */ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = { val info = taskInfos(tid) @@ -429,7 +429,7 @@ private[spark] class ClusterTaskSetManager( if (!successful(index)) { logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( tid, info.duration, info.host, tasksSuccessful, numTasks)) - sched.listener.taskEnded( + sched.dagScheduler.taskEnded( tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) // Mark successful and stop if all the tasks have succeeded. @@ -445,7 +445,8 @@ private[spark] class ClusterTaskSetManager( } /** - * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener. + * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the + * DAG Scheduler. */ def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { val info = taskInfos(tid) @@ -463,7 +464,7 @@ private[spark] class ClusterTaskSetManager( reason.foreach { case fetchFailed: FetchFailed => logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null) + sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) successful(index) = true tasksSuccessful += 1 sched.taskSetFinished(this) @@ -472,11 +473,11 @@ private[spark] class ClusterTaskSetManager( case TaskKilled => logWarning("Task %d was killed.".format(tid)) - sched.listener.taskEnded(tasks(index), reason.get, null, null, info, null) + sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) return case ef: ExceptionFailure => - sched.listener.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) + sched.dagScheduler.taskEnded(tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) val key = ef.description val now = clock.getTime() val (printFull, dupCount) = { @@ -504,7 +505,7 @@ private[spark] class ClusterTaskSetManager( case TaskResultLost => logWarning("Lost result for TID %s on host %s".format(tid, info.host)) - sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null) + sched.dagScheduler.taskEnded(tasks(index), TaskResultLost, null, null, info, null) case _ => {} } @@ -533,7 +534,7 @@ private[spark] class ClusterTaskSetManager( failed = true causeOfFailure = message // TODO: Kill running tasks if we were not terminated due to a Mesos error - sched.listener.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message) removeAllRunningTasks() sched.taskSetFinished(this) } @@ -606,7 +607,7 @@ private[spark] class ClusterTaskSetManager( addPendingTask(index) // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our // stage finishes when a total of tasks.size tasks finish. - sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null) + sched.dagScheduler.taskEnded(tasks(index), Resubmitted, null, null, info, null) } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index b445260d1b3f1..2699f0b33e8d3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -81,7 +81,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: val env = SparkEnv.get val attemptId = new AtomicInteger - var listener: TaskSchedulerListener = null + var dagScheduler: DAGScheduler = null // Application dependencies (added through SparkContext) that we've fetched so far on this node. // Each map holds the master's timestamp for the version of that file or JAR we got. @@ -114,8 +114,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test") } - override def setListener(listener: TaskSchedulerListener) { - this.listener = listener + override def setDAGScheduler(dagScheduler: DAGScheduler) { + this.dagScheduler = dagScheduler } override def submitTasks(taskSet: TaskSet) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala index f72e77d40fc12..55f8313e8719e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala @@ -133,7 +133,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } def taskStarted(task: Task[_], info: TaskInfo) { - sched.listener.taskStarted(task, info) + sched.dagScheduler.taskStarted(task, info) } def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) { @@ -148,7 +148,8 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } } result.metrics.resultSize = serializedData.limit() - sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics) + sched.dagScheduler.taskEnded(task, Success, result.value, result.accumUpdates, info, + result.metrics) numFinished += 1 decreaseRunningTasks(1) finished(index) = true @@ -165,7 +166,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas decreaseRunningTasks(1) val reason: ExceptionFailure = ser.deserialize[ExceptionFailure]( serializedData, getClass.getClassLoader) - sched.listener.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) + sched.dagScheduler.taskEnded(task, reason, null, null, info, reason.metrics.getOrElse(null)) if (!finished(index)) { copiesRunning(index) -= 1 numFailures(index) += 1 @@ -176,7 +177,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format( taskSet.id, index, 4, reason.description) decreaseRunningTasks(runningTasks) - sched.listener.taskSetFailed(taskSet, errorMessage) + sched.dagScheduler.taskSetFailed(taskSet, errorMessage) // need to delete failed Taskset from schedule queue sched.taskSetFinished(this) } @@ -184,7 +185,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas } override def error(message: String) { - sched.listener.taskSetFailed(taskSet, message) + sched.dagScheduler.taskSetFailed(taskSet, message) sched.taskSetFinished(this) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 838179c6b582e..2a2f828be6967 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -60,7 +60,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont taskSets += taskSet } override def cancelTasks(stageId: Int) {} - override def setListener(listener: TaskSchedulerListener) = {} + override def setDAGScheduler(dagScheduler: DAGScheduler) = {} override def defaultParallelism() = 2 } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala index 80d0c5a5e929a..b97f2b19b581c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala @@ -28,6 +28,30 @@ import org.apache.spark.executor.TaskMetrics import java.nio.ByteBuffer import org.apache.spark.util.{Utils, FakeClock} +class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) { + override def taskStarted(task: Task[_], taskInfo: TaskInfo) { + taskScheduler.startedTasks += taskInfo.index + } + + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: mutable.Map[Long, Any], + taskInfo: TaskInfo, + taskMetrics: TaskMetrics) { + taskScheduler.endedTasks(taskInfo.index) = reason + } + + override def executorGained(execId: String, host: String) {} + + override def executorLost(execId: String) {} + + override def taskSetFailed(taskSet: TaskSet, reason: String) { + taskScheduler.taskSetsFailed += taskSet.id + } +} + /** * A mock ClusterScheduler implementation that just remembers information about tasks started and * feedback received from the TaskSetManagers. Note that it's important to initialize this with @@ -44,30 +68,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* val executors = new mutable.HashMap[String, String] ++ liveExecutors - listener = new TaskSchedulerListener { - def taskStarted(task: Task[_], taskInfo: TaskInfo) { - startedTasks += taskInfo.index - } - - def taskEnded( - task: Task[_], - reason: TaskEndReason, - result: Any, - accumUpdates: mutable.Map[Long, Any], - taskInfo: TaskInfo, - taskMetrics: TaskMetrics) - { - endedTasks(taskInfo.index) = reason - } - - def executorGained(execId: String, host: String) {} - - def executorLost(execId: String) {} - - def taskSetFailed(taskSet: TaskSet, reason: String) { - taskSetsFailed += taskSet.id - } - } + dagScheduler = new FakeDAGScheduler(this) def removeExecutor(execId: String): Unit = executors -= execId