From 2915a5ec1bd9d4bc7a40b0ad20ca5b0db8f5382e Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin
Date: Wed, 13 Sep 2017 09:10:28 -0700
Subject: [PATCH] Move dispatching to the event thread.

This change makes the event queue implement SparkListenerBus and inherit
all of its metrics and dispatching behavior, which keeps the changes on
the scheduler side small and also restores per-listener metrics.
---
 .../spark/scheduler/AsyncEventQueue.scala          | 158 +++---------------
 .../spark/scheduler/LiveListenerBus.scala          |  41 ++---
 .../scheduler/EventLoggingListenerSuite.scala      |   4 +-
 .../spark/scheduler/SparkListenerSuite.scala       |   8 +-
 4 files changed, 49 insertions(+), 162 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala
index ee3a300f4e116..c8481f8dcf2d4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala
@@ -17,13 +17,10 @@
 package org.apache.spark.scheduler
 
-import java.util.{ArrayList, List => JList}
-import java.util.concurrent._
+import java.util.concurrent.LinkedBlockingQueue
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
 
-import scala.util.control.NonFatal
-
-import com.codahale.metrics.{Counter, Gauge, MetricRegistry}
+import com.codahale.metrics.{Gauge, Timer}
 
 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
 
@@ -36,36 +33,16 @@ import org.apache.spark.util.Utils
  *
  * Delivery will only begin when the `start()` method is called. The `stop()` method should be
  * called when no more events need to be delivered.
- *
- * Instances of `AsyncEventQueue` are listeners themselves, but they're not to be used like regular
- * listeners; they are used internally by `LiveListenerBus`, and are tightly coupled to the
- * lifecycle of that implementation.
  */
 private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics)
-  extends SparkListenerInterface
+  extends SparkListenerBus
   with Logging {
 
   import AsyncEventQueue._
 
-  private val _listeners = new CopyOnWriteArrayList[SparkListenerInterface]()
-
-  def addListener(l: SparkListenerInterface): Unit = {
-    _listeners.add(l)
-  }
-
-  /**
-   * @return Whether there are remainning listeners in the queue.
-   */
-  def removeListener(l: SparkListenerInterface): Boolean = {
-    _listeners.remove(l)
-    !_listeners.isEmpty()
-  }
-
-  def listeners: JList[SparkListenerInterface] = new ArrayList(_listeners)
-
   // Cap the capacity of the queue so we get an explicit error (rather than an OOM exception) if
   // it's perpetually being added to more quickly than it's being drained.
-  private val taskQueue = new LinkedBlockingQueue[SparkListenerInterface => Unit](
+  private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](
     conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY))
 
   // Keep the event count separately, so that waitUntilEmpty() can be implemented properly;
@@ -87,12 +64,13 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
   private val stopped = new AtomicBoolean(false)
 
   private val droppedEvents = metrics.metricRegistry.counter(s"queue.$name.numDroppedEvents")
+  private val processingTime = metrics.metricRegistry.timer(s"queue.$name.listenerProcessingTime")
 
   // Remove the queue size gauge first, in case it was created by a previous incarnation of
   // this queue that was removed from the listener bus.
   metrics.metricRegistry.remove(s"queue.$name.size")
   metrics.metricRegistry.register(s"queue.$name.size", new Gauge[Int] {
-    override def getValue: Int = taskQueue.size()
+    override def getValue: Int = eventQueue.size()
   })
 
   private val dispatchThread = new Thread(s"spark-listener-group-$name") {
@@ -104,20 +82,16 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
 
   private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) {
     try {
-      var task: SparkListenerInterface => Unit = taskQueue.take()
-      while (task != POISON_PILL) {
-        val it = _listeners.iterator()
-        while (it.hasNext()) {
-          val listener = it.next()
-          try {
-            task(listener)
-          } catch {
-            case NonFatal(e) =>
-              logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
-          }
+      var next: SparkListenerEvent = eventQueue.take()
+      while (next != POISON_PILL) {
+        val ctx = processingTime.time()
+        try {
+          super.postToAll(next)
+        } finally {
+          ctx.stop()
         }
         eventCount.decrementAndGet()
-        task = taskQueue.take()
+        next = eventQueue.take()
       }
       eventCount.decrementAndGet()
     } catch {
@@ -126,6 +100,10 @@
     }
   }
 
+  override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = {
+    metrics.getTimer(listener.getClass().getName())
+  }
+
   /**
    * Start an asynchronous thread to dispatch events to the underlying listeners.
    *
@@ -150,19 +128,19 @@
       throw new IllegalStateException(s"Attempted to stop $name that has not yet started!")
     }
     if (stopped.compareAndSet(false, true)) {
-      taskQueue.put(POISON_PILL)
+      eventQueue.put(POISON_PILL)
       eventCount.incrementAndGet()
     }
     dispatchThread.join()
   }
 
-  private def post(event: SparkListenerEvent)(task: SparkListenerInterface => Unit): Unit = {
+  def post(event: SparkListenerEvent): Unit = {
     if (stopped.get()) {
       return
     }
 
     eventCount.incrementAndGet()
-    if (taskQueue.offer(task)) {
+    if (eventQueue.offer(event)) {
      return
     }
 
@@ -210,102 +188,10 @@
     true
   }
 
-  override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
-    post(event)(_.onStageCompleted(event))
-  }
-
-  override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = {
-    post(event)(_.onStageSubmitted(event))
-  }
-
-  override def onTaskStart(event: SparkListenerTaskStart): Unit = {
-    post(event)(_.onTaskStart(event))
-  }
-
-  override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = {
-    post(event)(_.onTaskGettingResult(event))
-  }
-
-  override def onTaskEnd(event: SparkListenerTaskEnd): Unit = {
-    post(event)(_.onTaskEnd(event))
-  }
-
-  override def onJobStart(event: SparkListenerJobStart): Unit = {
-    post(event)(_.onJobStart(event))
-  }
-
-  override def onJobEnd(event: SparkListenerJobEnd): Unit = {
-    post(event)(_.onJobEnd(event))
-  }
-
-  override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = {
-    post(event)(_.onEnvironmentUpdate(event))
-  }
-
-  override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = {
-    post(event)(_.onBlockManagerAdded(event))
-  }
-
-  override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = {
-    post(event)(_.onBlockManagerRemoved(event))
-  }
-
-  override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = {
-    post(event)(_.onUnpersistRDD(event))
-  }
-
-  override def onApplicationStart(event: SparkListenerApplicationStart): Unit = {
-    post(event)(_.onApplicationStart(event))
-  }
-
-  override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = {
-    post(event)(_.onApplicationEnd(event))
-  }
-
-  override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = {
-    post(event)(_.onExecutorMetricsUpdate(event))
-  }
-
-  override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = {
-    post(event)(_.onExecutorAdded(event))
-  }
-
-  override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = {
-    post(event)(_.onExecutorRemoved(event))
-  }
-
-  override def onExecutorBlacklisted(event: SparkListenerExecutorBlacklisted): Unit = {
-    post(event)(_.onExecutorBlacklisted(event))
-  }
-
-  override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = {
-    post(event)(_.onExecutorUnblacklisted(event))
-  }
-
-  override def onNodeBlacklisted(event: SparkListenerNodeBlacklisted): Unit = {
-    post(event)(_.onNodeBlacklisted(event))
-  }
-
-  override def onNodeUnblacklisted(event: SparkListenerNodeUnblacklisted): Unit = {
-    post(event)(_.onNodeUnblacklisted(event))
-  }
-
-  override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {
-    post(event)(_.onBlockUpdated(event))
-  }
-
-  override def onSpeculativeTaskSubmitted(event: SparkListenerSpeculativeTaskSubmitted): Unit = {
-    post(event)(_.onSpeculativeTaskSubmitted(event))
-  }
-
-  override def onOtherEvent(event: SparkListenerEvent): Unit = {
-    post(event)(_.onOtherEvent(event))
-  }
-
 }
 
 private object AsyncEventQueue {
 
-  val POISON_PILL: SparkListenerInterface => Unit = { _ => Unit }
+  val POISON_PILL = new SparkListenerEvent() { }
 
 }

diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
index e8d196f411526..2cf81ce9d8bcd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -41,7 +41,7 @@ import org.apache.spark.metrics.source.Source
  * has started will events be actually propagated to all attached listeners. This listener bus
  * is stopped when `stop()` is called, and it will drop further events after stopping.
  */
-private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus {
+private[spark] class LiveListenerBus(conf: SparkConf) {
 
   import LiveListenerBus._
 
@@ -60,8 +60,10 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus {
   /** When `droppedEventsCounter` was logged last time in milliseconds. */
   @volatile private var lastReportTimestamp = 0L
 
+  private val queues = new CopyOnWriteArrayList[AsyncEventQueue]()
+
   /** Add a listener to the default queue. */
-  override def addListener(listener: SparkListenerInterface): Unit = {
+  def addListener(listener: SparkListenerInterface): Unit = {
     addToQueue(listener, "default")
   }
 
@@ -85,37 +87,36 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus {
       if (started.get() && !stopped.get()) {
         newQueue.start(sparkContext)
       }
-      super.addListener(newQueue)
+      queues.add(newQueue)
     }
   }
 
-  override def removeListener(listener: SparkListenerInterface): Unit = synchronized {
+  def removeListener(listener: SparkListenerInterface): Unit = synchronized {
     // Remove listener from all queues it was added to, and stop queues that have become empty.
     queues.asScala
-      .filter(!_.removeListener(listener))
+      .filter { queue =>
+        queue.removeListener(listener)
+        queue.listeners.isEmpty()
+      }
       .foreach { toRemove =>
         if (started.get() && !stopped.get()) {
           toRemove.stop()
         }
-        super.removeListener(toRemove)
+        queues.remove(toRemove)
       }
   }
 
   /** An alias for postToAll(), to avoid changing all call sites. */
-  def post(event: SparkListenerEvent): Unit = postToAll(event)
-
-  override def postToAll(event: SparkListenerEvent): Unit = {
+  def post(event: SparkListenerEvent): Unit = {
     if (!stopped.get()) {
       metrics.numEventsPosted.inc()
-      super.postToAll(event)
+      val it = queues.iterator()
+      while (it.hasNext()) {
+        it.next().post(event)
+      }
     }
   }
 
-  override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = {
-    val name = listener.asInstanceOf[AsyncEventQueue].name
-    metrics.getTimer(s"queue.$name")
-  }
-
   /**
    * Start sending events to attached listeners.
    *
@@ -168,7 +169,7 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus {
     }
   }
 
-  override private[spark] def findListenersByClass[T <: SparkListenerInterface : ClassTag]():
+  private[spark] def findListenersByClass[T <: SparkListenerInterface : ClassTag]():
       Seq[T] = {
     val c = implicitly[ClassTag[T]].runtimeClass
     queues.asScala.flatMap { queue =>
@@ -176,13 +177,13 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus {
     }
   }
 
-  override private[spark] def listeners: JList[SparkListenerInterface] = {
+  private[spark] def listeners: JList[SparkListenerInterface] = {
     queues.asScala.flatMap(_.listeners.asScala).asJava
   }
 
-  // Exposed for testing.
-  private[scheduler] def queues: JList[AsyncEventQueue] = {
-    super.listeners.asInstanceOf[JList[AsyncEventQueue]]
+  // For testing only.
+  private[scheduler] def activeQueues(): Seq[String] = {
+    queues.asScala.map(_.name).toSeq
   }
 
 }

diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
index 0afd07b851cf9..b14461a2c2873 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -165,8 +165,8 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
     eventLogger.start()
     listenerBus.start(Mockito.mock(classOf[SparkContext]), Mockito.mock(classOf[MetricsSystem]))
     listenerBus.addListener(eventLogger)
-    listenerBus.postToAll(applicationStart)
-    listenerBus.postToAll(applicationEnd)
+    listenerBus.post(applicationStart)
+    listenerBus.post(applicationEnd)
     listenerBus.stop()
     eventLogger.stop()
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 27854464598e7..da5e5a7e86cac 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -461,19 +461,19 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
     bus.addListener(counter1)
     bus.addToQueue(counter2, "other")
     bus.addToQueue(counter3, "other")
-    assert(bus.queues.asScala.map(_.name) === Seq("default", "other"))
+    assert(bus.activeQueues() === Seq("default", "other"))
    assert(bus.findListenersByClass[BasicJobCounter]().size === 3)
 
     bus.removeListener(counter1)
-    assert(bus.queues.asScala.map(_.name) === Seq("other"))
+    assert(bus.activeQueues() === Seq("other"))
     assert(bus.findListenersByClass[BasicJobCounter]().size === 2)
 
     bus.removeListener(counter2)
-    assert(bus.queues.asScala.map(_.name) === Seq("other"))
+    assert(bus.activeQueues() === Seq("other"))
     assert(bus.findListenersByClass[BasicJobCounter]().size === 1)
 
     bus.removeListener(counter3)
-    assert(bus.queues.isEmpty)
+    assert(bus.activeQueues().isEmpty)
     assert(bus.findListenersByClass[BasicJobCounter]().isEmpty)
   }