Move dispatching to the event thread.
This change makes the event queue implement SparkListenerBus and inherit
all the metrics and dispatching behavior, making the change easier on the
scheduler and also restoring per-listener metrics.
Marcelo Vanzin committed Sep 13, 2017
1 parent 6bee214 commit 2915a5e
Showing 4 changed files with 49 additions and 162 deletions.
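
In outline, the shape this commit moves to is: each queue is itself a bus that owns its listeners and dispatches to them on its own thread, while the front-end bus merely fans events out to the queues. Below is a minimal, self-contained sketch of that idea with simplified, hypothetical names (Bus, AsyncQueue); the real SparkListenerBus also adds per-listener timing and error handling.

import java.util.concurrent.{CopyOnWriteArrayList, LinkedBlockingQueue}

import scala.collection.JavaConverters._

// A bus owns its listeners and knows how to deliver one event to all of them.
trait Bus[E <: AnyRef] {
  private val listeners = new CopyOnWriteArrayList[E => Unit]()
  def addListener(l: E => Unit): Unit = listeners.add(l)
  def postToAll(event: E): Unit = listeners.asScala.foreach(_.apply(event))
}

// The async queue IS a bus: it inherits the dispatching above and only adds
// the blocking queue and the thread, which is the core idea of this commit.
class AsyncQueue[E <: AnyRef](poison: E) extends Bus[E] {
  private val queue = new LinkedBlockingQueue[E]()
  private val thread = new Thread {
    override def run(): Unit = {
      var next = queue.take()
      while (next ne poison) { // the poison pill ends the dispatch loop
        postToAll(next)        // dispatching happens on this event thread
        next = queue.take()
      }
    }
  }
  def start(): Unit = thread.start()
  def post(event: E): Unit = queue.put(event)
  def stop(): Unit = { queue.put(poison); thread.join() }
}

A front end like LiveListenerBus then only needs to keep a list of queues and forward post() to each of them, which is exactly what the LiveListenerBus hunks further down do.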
core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala (158 changes: 22 additions & 136 deletions)
@@ -17,13 +17,10 @@

 package org.apache.spark.scheduler

-import java.util.{ArrayList, List => JList}
-import java.util.concurrent._
+import java.util.concurrent.LinkedBlockingQueue
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}

-import scala.util.control.NonFatal
-
-import com.codahale.metrics.{Counter, Gauge, MetricRegistry}
+import com.codahale.metrics.{Gauge, Timer}

 import org.apache.spark.{SparkConf, SparkContext}
 import org.apache.spark.internal.Logging
@@ -36,36 +33,16 @@ import org.apache.spark.util.Utils
  *
  * Delivery will only begin when the `start()` method is called. The `stop()` method should be
  * called when no more events need to be delivered.
- *
- * Instances of `AsyncEventQueue` are listeners themselves, but they're not to be used like regular
- * listeners; they are used internally by `LiveListenerBus`, and are tightly coupled to the
- * lifecycle of that implementation.
  */
 private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveListenerBusMetrics)
-  extends SparkListenerInterface
+  extends SparkListenerBus
   with Logging {

   import AsyncEventQueue._

-  private val _listeners = new CopyOnWriteArrayList[SparkListenerInterface]()
-
-  def addListener(l: SparkListenerInterface): Unit = {
-    _listeners.add(l)
-  }
-
-  /**
-   * @return Whether there are remainning listeners in the queue.
-   */
-  def removeListener(l: SparkListenerInterface): Boolean = {
-    _listeners.remove(l)
-    !_listeners.isEmpty()
-  }
-
-  def listeners: JList[SparkListenerInterface] = new ArrayList(_listeners)
-
   // Cap the capacity of the queue so we get an explicit error (rather than an OOM exception) if
   // it's perpetually being added to more quickly than it's being drained.
-  private val taskQueue = new LinkedBlockingQueue[SparkListenerInterface => Unit](
+  private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](
     conf.get(LISTENER_BUS_EVENT_QUEUE_CAPACITY))
@@ -87,12 +64,13 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
   private val stopped = new AtomicBoolean(false)

   private val droppedEvents = metrics.metricRegistry.counter(s"queue.$name.numDroppedEvents")
+  private val processingTime = metrics.metricRegistry.timer(s"queue.$name.listenerProcessingTime")

   // Remove the queue size gauge first, in case it was created by a previous incarnation of
   // this queue that was removed from the listener bus.
   metrics.metricRegistry.remove(s"queue.$name.size")
   metrics.metricRegistry.register(s"queue.$name.size", new Gauge[Int] {
-    override def getValue: Int = taskQueue.size()
+    override def getValue: Int = eventQueue.size()
   })

   private val dispatchThread = new Thread(s"spark-listener-group-$name") {
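
The remove-before-register pair above is needed because Dropwizard's MetricRegistry.register throws an IllegalArgumentException when a metric with the same name already exists. A small sketch of that behavior (the metric name is borrowed from the code above; the demo class is hypothetical):

import com.codahale.metrics.{Gauge, MetricRegistry}

object GaugeReRegister {
  def main(args: Array[String]): Unit = {
    val registry = new MetricRegistry
    def register(value: Int): Unit =
      registry.register("queue.default.size", new Gauge[Int] {
        override def getValue: Int = value
      })
    register(1)
    // Without this remove(), the second register() below would throw
    // IllegalArgumentException: a metric with that name already exists.
    registry.remove("queue.default.size")
    register(2)
    println(registry.getGauges().get("queue.default.size").getValue) // 2
  }
}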
@@ -104,20 +82,16 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi

   private def dispatch(): Unit = LiveListenerBus.withinListenerThread.withValue(true) {
     try {
-      var task: SparkListenerInterface => Unit = taskQueue.take()
-      while (task != POISON_PILL) {
-        val it = _listeners.iterator()
-        while (it.hasNext()) {
-          val listener = it.next()
-          try {
-            task(listener)
-          } catch {
-            case NonFatal(e) =>
-              logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
-          }
+      var next: SparkListenerEvent = eventQueue.take()
+      while (next != POISON_PILL) {
+        val ctx = processingTime.time()
+        try {
+          super.postToAll(next)
+        } finally {
+          ctx.stop()
         }
         eventCount.decrementAndGet()
-        task = taskQueue.take()
+        next = eventQueue.take()
       }
       eventCount.decrementAndGet()
     } catch {
@@ -126,6 +100,10 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
     }
   }

+  override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = {
+    metrics.getTimer(listener.getClass().getName())
+  }
+
   /**
    * Start an asynchronous thread to dispatch events to the underlying listeners.
    *
@@ -150,19 +128,19 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
       throw new IllegalStateException(s"Attempted to stop $name that has not yet started!")
     }
     if (stopped.compareAndSet(false, true)) {
-      taskQueue.put(POISON_PILL)
+      eventQueue.put(POISON_PILL)
       eventCount.incrementAndGet()
     }
     dispatchThread.join()
   }

-  private def post(event: SparkListenerEvent)(task: SparkListenerInterface => Unit): Unit = {
+  def post(event: SparkListenerEvent): Unit = {
     if (stopped.get()) {
       return
     }

     eventCount.incrementAndGet()
-    if (taskQueue.offer(task)) {
+    if (eventQueue.offer(event)) {
       return
     }
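
The visible half of post() relies on the contract of LinkedBlockingQueue.offer: unlike put, it never blocks, returning false immediately when the bounded queue is full. The lines collapsed below this hunk handle that case, presumably via the droppedEvents counter declared earlier. A tiny demonstration of the contract:

import java.util.concurrent.LinkedBlockingQueue

object OfferDemo {
  def main(args: Array[String]): Unit = {
    val q = new LinkedBlockingQueue[Int](2) // capacity of 2
    println(q.offer(1)) // true
    println(q.offer(2)) // true
    println(q.offer(3)) // false: queue is full, the element is dropped
  }
}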

@@ -210,102 +188,10 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi
     true
   }

-  override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
-    post(event)(_.onStageCompleted(event))
-  }
-
-  override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = {
-    post(event)(_.onStageSubmitted(event))
-  }
-
-  override def onTaskStart(event: SparkListenerTaskStart): Unit = {
-    post(event)(_.onTaskStart(event))
-  }
-
-  override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = {
-    post(event)(_.onTaskGettingResult(event))
-  }
-
-  override def onTaskEnd(event: SparkListenerTaskEnd): Unit = {
-    post(event)(_.onTaskEnd(event))
-  }
-
-  override def onJobStart(event: SparkListenerJobStart): Unit = {
-    post(event)(_.onJobStart(event))
-  }
-
-  override def onJobEnd(event: SparkListenerJobEnd): Unit = {
-    post(event)(_.onJobEnd(event))
-  }
-
-  override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = {
-    post(event)(_.onEnvironmentUpdate(event))
-  }
-
-  override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = {
-    post(event)(_.onBlockManagerAdded(event))
-  }
-
-  override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = {
-    post(event)(_.onBlockManagerRemoved(event))
-  }
-
-  override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = {
-    post(event)(_.onUnpersistRDD(event))
-  }
-
-  override def onApplicationStart(event: SparkListenerApplicationStart): Unit = {
-    post(event)(_.onApplicationStart(event))
-  }
-
-  override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = {
-    post(event)(_.onApplicationEnd(event))
-  }
-
-  override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = {
-    post(event)(_.onExecutorMetricsUpdate(event))
-  }
-
-  override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = {
-    post(event)(_.onExecutorAdded(event))
-  }
-
-  override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = {
-    post(event)(_.onExecutorRemoved(event))
-  }
-
-  override def onExecutorBlacklisted(event: SparkListenerExecutorBlacklisted): Unit = {
-    post(event)(_.onExecutorBlacklisted(event))
-  }
-
-  override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = {
-    post(event)(_.onExecutorUnblacklisted(event))
-  }
-
-  override def onNodeBlacklisted(event: SparkListenerNodeBlacklisted): Unit = {
-    post(event)(_.onNodeBlacklisted(event))
-  }
-
-  override def onNodeUnblacklisted(event: SparkListenerNodeUnblacklisted): Unit = {
-    post(event)(_.onNodeUnblacklisted(event))
-  }
-
-  override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {
-    post(event)(_.onBlockUpdated(event))
-  }
-
-  override def onSpeculativeTaskSubmitted(event: SparkListenerSpeculativeTaskSubmitted): Unit = {
-    post(event)(_.onSpeculativeTaskSubmitted(event))
-  }
-
-  override def onOtherEvent(event: SparkListenerEvent): Unit = {
-    post(event)(_.onOtherEvent(event))
-  }
-
 }

 private object AsyncEventQueue {

-  val POISON_PILL: SparkListenerInterface => Unit = { _ => Unit }
+  val POISON_PILL = new SparkListenerEvent() { }

 }
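
The per-listener metrics that the commit message says are restored come from Dropwizard Timers: dispatch() wraps every super.postToAll(next) in a processingTime context, and the getTimer override lets the inherited SparkListenerBus time each listener class separately. A minimal sketch of the Timer API in play (the registry and metric name are illustrative):

import com.codahale.metrics.{MetricRegistry, Timer}

object TimerDemo {
  def main(args: Array[String]): Unit = {
    val registry = new MetricRegistry
    val timer: Timer = registry.timer("queue.default.listenerProcessingTime")
    val ctx = timer.time() // start a timing context
    try {
      Thread.sleep(10)     // stand-in for super.postToAll(event)
    } finally {
      ctx.stop()           // records the elapsed time even if dispatch threw
    }
    println(s"count=${timer.getCount}, mean rate=${timer.getMeanRate}")
  }
}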
core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -41,7 +41,7 @@ import org.apache.spark.metrics.source.Source
  * has started will events be actually propagated to all attached listeners. This listener bus
  * is stopped when `stop()` is called, and it will drop further events after stopping.
  */
-private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus {
+private[spark] class LiveListenerBus(conf: SparkConf) {

   import LiveListenerBus._

@@ -60,8 +60,10 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus {
   /** When `droppedEventsCounter` was logged last time in milliseconds. */
   @volatile private var lastReportTimestamp = 0L

+  private val queues = new CopyOnWriteArrayList[AsyncEventQueue]()
+
   /** Add a listener to the default queue. */
-  override def addListener(listener: SparkListenerInterface): Unit = {
+  def addListener(listener: SparkListenerInterface): Unit = {
     addToQueue(listener, "default")
   }

@@ -85,37 +87,36 @@ private[spark] class LiveListenerBus(conf: SparkConf) extends SparkListenerBus {
       if (started.get() && !stopped.get()) {
         newQueue.start(sparkContext)
       }
-      super.addListener(newQueue)
+      queues.add(newQueue)
     }
   }

-  override def removeListener(listener: SparkListenerInterface): Unit = synchronized {
+  def removeListener(listener: SparkListenerInterface): Unit = synchronized {
     // Remove listener from all queues it was added to, and stop queues that have become empty.
     queues.asScala
-      .filter(!_.removeListener(listener))
+      .filter { queue =>
+        queue.removeListener(listener)
+        queue.listeners.isEmpty()
+      }
       .foreach { toRemove =>
         if (started.get() && !stopped.get()) {
           toRemove.stop()
         }
-        super.removeListener(toRemove)
+        queues.remove(toRemove)
       }
   }

-  /** An alias for postToAll(), to avoid changing all call sites. */
-  def post(event: SparkListenerEvent): Unit = postToAll(event)
-
-  override def postToAll(event: SparkListenerEvent): Unit = {
+  def post(event: SparkListenerEvent): Unit = {
     if (!stopped.get()) {
       metrics.numEventsPosted.inc()
-      super.postToAll(event)
+      val it = queues.iterator()
+      while (it.hasNext()) {
+        it.next().post(event)
+      }
     }
   }

-  override protected def getTimer(listener: SparkListenerInterface): Option[Timer] = {
-    val name = listener.asInstanceOf[AsyncEventQueue].name
-    metrics.getTimer(s"queue.$name")
-  }

   /**
    * Start sending events to attached listeners.
    *
@@ -168,21 +169,21 @@ private[spark] class LiveListenerBus(conf: SparkConf) {
     }
   }

-  override private[spark] def findListenersByClass[T <: SparkListenerInterface : ClassTag]():
+  private[spark] def findListenersByClass[T <: SparkListenerInterface : ClassTag]():
       Seq[T] = {
     val c = implicitly[ClassTag[T]].runtimeClass
     queues.asScala.flatMap { queue =>
       queue.listeners.asScala.filter(_.getClass() == c).map(_.asInstanceOf[T])
     }
   }

-  override private[spark] def listeners: JList[SparkListenerInterface] = {
+  private[spark] def listeners: JList[SparkListenerInterface] = {
     queues.asScala.flatMap(_.listeners.asScala).asJava
   }

-  // Exposed for testing.
-  private[scheduler] def queues: JList[AsyncEventQueue] = {
-    super.listeners.asInstanceOf[JList[AsyncEventQueue]]
+  // For testing only.
+  private[scheduler] def activeQueues(): Seq[String] = {
+    queues.asScala.map(_.name).toSeq
   }

 }
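
post() above walks the CopyOnWriteArrayList with an explicit Java iterator. Besides likely avoiding wrapper allocation on a hot path, this is safe under concurrency because the list's iterator is a snapshot: a queue added or removed while an event is being posted never causes a ConcurrentModificationException; the event simply goes to the queues that existed when post() started. A small sketch of that snapshot semantics:

import java.util.concurrent.CopyOnWriteArrayList

object SnapshotIteration {
  def main(args: Array[String]): Unit = {
    val queues = new CopyOnWriteArrayList[String]()
    queues.add("default")
    val it = queues.iterator()
    queues.add("other") // modification after the iterator was created
    while (it.hasNext()) {
      println(it.next()) // prints only "default": the iterator is a snapshot
    }
  }
}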
core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala
@@ -165,8 +165,8 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit
     eventLogger.start()
     listenerBus.start(Mockito.mock(classOf[SparkContext]), Mockito.mock(classOf[MetricsSystem]))
     listenerBus.addListener(eventLogger)
-    listenerBus.postToAll(applicationStart)
-    listenerBus.postToAll(applicationEnd)
+    listenerBus.post(applicationStart)
+    listenerBus.post(applicationEnd)
     listenerBus.stop()
     eventLogger.stop()

core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -461,19 +461,19 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
     bus.addListener(counter1)
     bus.addToQueue(counter2, "other")
     bus.addToQueue(counter3, "other")
-    assert(bus.queues.asScala.map(_.name) === Seq("default", "other"))
+    assert(bus.activeQueues() === Seq("default", "other"))
     assert(bus.findListenersByClass[BasicJobCounter]().size === 3)

     bus.removeListener(counter1)
-    assert(bus.queues.asScala.map(_.name) === Seq("other"))
+    assert(bus.activeQueues() === Seq("other"))
     assert(bus.findListenersByClass[BasicJobCounter]().size === 2)

     bus.removeListener(counter2)
-    assert(bus.queues.asScala.map(_.name) === Seq("other"))
+    assert(bus.activeQueues() === Seq("other"))
     assert(bus.findListenersByClass[BasicJobCounter]().size === 1)

     bus.removeListener(counter3)
-    assert(bus.queues.isEmpty)
+    assert(bus.activeQueues().isEmpty)
     assert(bus.findListenersByClass[BasicJobCounter]().isEmpty)
   }
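
The assertions above pin down the lifecycle rule introduced in LiveListenerBus.removeListener: a queue disappears as soon as its last listener is removed, and the "other" queue survives while either counter2 or counter3 is still attached. A toy model of just that bookkeeping (hypothetical classes, not Spark's):

import scala.collection.mutable

object QueueLifecycleDemo {
  final class Queue(val name: String) {
    val listeners = mutable.Buffer[Any]()
  }
  private val queues = mutable.Buffer[Queue]()

  def addToQueue(listener: Any, name: String): Unit = {
    val q = queues.find(_.name == name).getOrElse {
      val created = new Queue(name)
      queues += created
      created
    }
    q.listeners += listener
  }

  def removeListener(listener: Any): Unit = {
    queues.foreach(_.listeners -= listener)
    queues --= queues.filter(_.listeners.isEmpty) // drop queues that became empty
  }

  def activeQueues(): Seq[String] = queues.map(_.name).toSeq

  def main(args: Array[String]): Unit = {
    val (c1, c2, c3) = (new Object, new Object, new Object)
    addToQueue(c1, "default"); addToQueue(c2, "other"); addToQueue(c3, "other")
    println(activeQueues().mkString(", ")) // default, other
    removeListener(c1)
    println(activeQueues().mkString(", ")) // other
    removeListener(c2); removeListener(c3)
    println(activeQueues().isEmpty)        // true
  }
}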
