SPARK-2565. Update ShuffleReadMetrics as blocks are fetched #1507

Closed (wants to merge 1 commit)
@@ -374,6 +374,7 @@ private[spark] class Executor(
for (taskRunner <- runningTasks.values()) {
if (!taskRunner.attemptedTask.isEmpty) {
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
metrics.updateShuffleReadMetrics
tasksMetrics += ((taskRunner.taskId, metrics))
}
}
core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala (57 changes: 43 additions & 14 deletions)
@@ -17,6 +17,8 @@

package org.apache.spark.executor

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.storage.{BlockId, BlockStatus}

@@ -81,11 +83,26 @@ class TaskMetrics extends Serializable {
var inputMetrics: Option[InputMetrics] = None

/**
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here.
* This includes read metrics aggregated over all the task's shuffle dependencies.
*/
private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None

def shuffleReadMetrics = _shuffleReadMetrics
def shuffleReadMetrics() = _shuffleReadMetrics
Review comment (Contributor): nit: since this doesn't mutate internal state the original lack of parentheses is correct style.

/**
* This should only be used when recreating TaskMetrics, not when updating read metrics in
* executors.
*/
private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) {
_shuffleReadMetrics = shuffleReadMetrics
}

/**
* ShuffleReadMetrics per dependency for collecting independently while task is in progress.
*/
@transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] =
new ArrayBuffer[ShuffleReadMetrics]()

/**
* If this task writes to shuffle output, metrics on the written shuffle data will be collected
@@ -98,19 +115,31 @@ class TaskMetrics extends Serializable {
*/
var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None

/** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. */
def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized {
_shuffleReadMetrics match {
case Some(existingMetrics) =>
existingMetrics.shuffleFinishTime = math.max(
existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime)
existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime
existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched
existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched
existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead
case None =>
_shuffleReadMetrics = Some(newMetrics)
/**
* A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization
* issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each
* dependency, and merge these metrics before reporting them to the driver. This method returns
* a ShuffleReadMetrics for a dependency and registers it for merging later.
*/
private [spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized {
val readMetrics = new ShuffleReadMetrics()
depsShuffleReadMetrics += readMetrics
readMetrics
}

/**
* Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
*/
private[spark] def updateShuffleReadMetrics() = synchronized {
val merged = new ShuffleReadMetrics()
for (depMetrics <- depsShuffleReadMetrics) {
merged.fetchWaitTime += depMetrics.fetchWaitTime
merged.localBlocksFetched += depMetrics.localBlocksFetched
merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched
merged.remoteBytesRead += depMetrics.remoteBytesRead
merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime)
}
_shuffleReadMetrics = Some(merged)
}
}
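
Taken together, the new methods define a small lifecycle: each shuffle reader registers its own ShuffleReadMetrics, fetcher threads update those independently, and the per-dependency values are merged into the single reported ShuffleReadMetrics. A minimal sketch of that call pattern (not part of the patch; it assumes compilation inside the org.apache.spark package, since the helpers are private[spark]):

```scala
package org.apache.spark

import org.apache.spark.executor.TaskMetrics

// Illustrative only: the intended per-dependency lifecycle of shuffle read metrics.
object ShuffleReadMetricsLifecycleSketch {
  def example(): Unit = {
    val taskMetrics = new TaskMetrics()

    // Each shuffle reader asks for its own metrics object (one per dependency).
    val dep1 = taskMetrics.createShuffleReadMetricsForDependency()
    val dep2 = taskMetrics.createShuffleReadMetricsForDependency()

    // Fetcher iterators update their own object as blocks arrive, without touching
    // metrics owned by readers for other dependencies.
    dep1.remoteBlocksFetched += 1
    dep1.remoteBytesRead += 1024L
    dep2.localBlocksFetched += 2

    // Before metrics are reported (heartbeat or task completion), the per-dependency
    // values are merged into the single ShuffleReadMetrics exposed to the driver.
    taskMetrics.updateShuffleReadMetrics()
    assert(taskMetrics.shuffleReadMetrics().get.remoteBytesRead == 1024L)
  }
}
```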

@@ -32,7 +32,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
serializer: Serializer,
shuffleMetrics: ShuffleReadMetrics)
: Iterator[T] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
@@ -73,17 +74,11 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
}
}

val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics)
val itr = blockFetcherItr.flatMap(unpackBlock)

val completionIter = CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics)
context.taskMetrics.updateShuffleReadMetrics()
})

new InterruptibleIterator[T](context, completionIter)
@@ -36,8 +36,10 @@ private[spark] class HashShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val ser = Serializer.getSerializer(dep.serializer)
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser,
readMetrics)

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
@@ -27,6 +27,7 @@ import scala.util.{Failure, Success}
import io.netty.buffer.ByteBuf

import org.apache.spark.{Logging, SparkException}
import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.network.BufferMessage
import org.apache.spark.network.ConnectionManagerId
import org.apache.spark.network.netty.ShuffleCopier
@@ -47,10 +48,6 @@ import org.apache.spark.util.Utils
private[storage]
trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
def initialize()
def numLocalBlocks: Int
def numRemoteBlocks: Int
def fetchWaitTime: Long
def remoteBytesRead: Long
}


@@ -72,14 +69,12 @@ object BlockFetcherIterator {
class BasicBlockFetcherIterator(
private val blockManager: BlockManager,
val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
serializer: Serializer,
readMetrics: ShuffleReadMetrics)
extends BlockFetcherIterator {

import blockManager._

private var _remoteBytesRead = 0L
private var _fetchWaitTime = 0L

if (blocksByAddress == null) {
throw new IllegalArgumentException("BlocksByAddress is null")
}
Expand All @@ -89,13 +84,9 @@ object BlockFetcherIterator {

protected var startTime = System.currentTimeMillis

// This represents the number of local blocks, also counting zero-sized blocks
private var numLocal = 0
// BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
protected val localBlocksToFetch = new ArrayBuffer[BlockId]()

// This represents the number of remote blocks, also counting zero-sized blocks
private var numRemote = 0
// BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
protected val remoteBlocksToFetch = new HashSet[BlockId]()

@@ -132,7 +123,10 @@ object BlockFetcherIterator {
val networkSize = blockMessage.getData.limit()
results.put(new FetchResult(blockId, sizeMap(blockId),
() => dataDeserialize(blockId, blockMessage.getData, serializer)))
_remoteBytesRead += networkSize
// TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can
// be incrementing bytes read at the same time (SPARK-2625).
readMetrics.remoteBytesRead += networkSize
readMetrics.remoteBlocksFetched += 1
logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
}
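
The TODO above refers to the Netty fetch path, where several copier threads may run these `+=` updates concurrently on plain vars and lose increments. A hypothetical illustration of one way to serialize those updates (this is not how the patch or SPARK-2625 resolves it):

```scala
package org.apache.spark.storage

import org.apache.spark.executor.ShuffleReadMetrics

// Hypothetical helper, not part of this change: funnel concurrent updates from
// multiple fetcher threads through a lock on the metrics instance so that the
// read-modify-write on the plain var fields cannot interleave.
private[spark] object SynchronizedReadMetricsUpdate {
  def recordRemoteBlock(metrics: ShuffleReadMetrics, bytes: Long): Unit =
    metrics.synchronized {
      metrics.remoteBytesRead += bytes
      metrics.remoteBlocksFetched += 1
    }
}
```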
@@ -155,14 +149,14 @@ object BlockFetcherIterator {
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val remoteRequests = new ArrayBuffer[FetchRequest]
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
if (address == blockManagerId) {
numLocal = blockInfos.size
// Filter out zero-sized blocks
localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
_numBlocksToFetch += localBlocksToFetch.size
} else {
numRemote += blockInfos.size
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = new ArrayBuffer[(BlockId, Long)]
Expand Down Expand Up @@ -192,7 +186,7 @@ object BlockFetcherIterator {
}
}
logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
(numLocal + numRemote) + " blocks")
totalBlocks + " blocks")
Review comment (Contributor): Is this ever used other than for logging?
Reply (Contributor, Author): Naw

remoteRequests
}

@@ -205,6 +199,7 @@
// getLocalFromDisk never return None but throws BlockException
val iter = getLocalFromDisk(id, serializer).get
// Pass 0 as size since it's not in flight
readMetrics.localBlocksFetched += 1
results.put(new FetchResult(id, 0, () => iter))
logDebug("Got local block " + id)
} catch {
@@ -238,12 +233,6 @@ object BlockFetcherIterator {
logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
}

override def numLocalBlocks: Int = numLocal
override def numRemoteBlocks: Int = numRemote
override def fetchWaitTime: Long = _fetchWaitTime
override def remoteBytesRead: Long = _remoteBytesRead


// Implementing the Iterator methods with an iterator that reads fetched blocks off the queue
// as they arrive.
@volatile protected var resultsGotten = 0
@@ -255,7 +244,7 @@
val startFetchWait = System.currentTimeMillis()
val result = results.take()
val stopFetchWait = System.currentTimeMillis()
_fetchWaitTime += (stopFetchWait - startFetchWait)
readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
if (! result.failed) bytesInFlight -= result.size
while (!fetchRequests.isEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
@@ -269,8 +258,9 @@
class NettyBlockFetcherIterator(
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer)
extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
serializer: Serializer,
readMetrics: ShuffleReadMetrics)
extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) {

import blockManager._

core/src/main/scala/org/apache/spark/storage/BlockManager.scala (11 changes: 7 additions & 4 deletions)
@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
import sun.nio.ch.DirectBuffer

import org.apache.spark._
import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics}
import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.serializer.Serializer
@@ -539,12 +539,15 @@ private[spark] class BlockManager(
*/
def getMultiple(
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer): BlockFetcherIterator = {
serializer: Serializer,
readMetrics: ShuffleReadMetrics): BlockFetcherIterator = {
val iter =
if (conf.getBoolean("spark.shuffle.use.netty", false)) {
new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer,
readMetrics)
} else {
new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer,
readMetrics)
}
iter.initialize()
iter
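
For reference, the branch above is driven by a plain boolean conf key. A sketch, not part of the patch, of enabling the Netty fetcher path that the readMetrics plumbing now also covers, assuming a standard SparkConf setup:

```scala
import org.apache.spark.SparkConf

// Sketch only: the conf flag checked in getMultiple above. When true, shuffle fetches
// go through NettyBlockFetcherIterator; otherwise BasicBlockFetcherIterator is used.
object NettyShuffleConfSketch {
  def conf: SparkConf = new SparkConf().set("spark.shuffle.use.netty", "true")
}
```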
core/src/main/scala/org/apache/spark/util/JsonProtocol.scala (5 changes: 2 additions & 3 deletions)
@@ -560,9 +560,8 @@ private[spark] object JsonProtocol {
metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long]
metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long]
metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long]
Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics =>
metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics))
}
metrics.setShuffleReadMetrics(
Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson))
metrics.shuffleWriteMetrics =
Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
metrics.inputMetrics =
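
On the deserialization path the merged value is restored directly with setShuffleReadMetrics rather than re-merged, since the per-dependency buffer is @transient and does not survive serialization. A rough round-trip sketch (assuming the surrounding taskMetricsToJson/taskMetricsFromJson helpers in JsonProtocol, and compilation inside the org.apache.spark package):

```scala
package org.apache.spark

import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.util.JsonProtocol

// Illustrative only: metrics reconstructed from JSON get their shuffle read metrics
// set wholesale via setShuffleReadMetrics, not merged from per-dependency buffers.
object JsonRoundTripSketch {
  def example(): TaskMetrics = {
    val metrics = new TaskMetrics()
    val sr = new ShuffleReadMetrics()
    sr.remoteBytesRead = 1000L
    metrics.setShuffleReadMetrics(Some(sr))

    val json = JsonProtocol.taskMetricsToJson(metrics)
    JsonProtocol.taskMetricsFromJson(json) // shuffleReadMetrics() is Some(...) again
  }
}
```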
@@ -33,6 +33,7 @@ import org.mockito.invocation.InvocationOnMock

import org.apache.spark.storage.BlockFetcherIterator._
import org.apache.spark.network.{ConnectionManager, Message}
import org.apache.spark.executor.ShuffleReadMetrics

class BlockFetcherIteratorSuite extends FunSuite with Matchers {

@@ -70,8 +71,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
(bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)
val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
new ShuffleReadMetrics())

iterator.initialize()

@@ -121,8 +122,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
(bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)
val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
new ShuffleReadMetrics())

iterator.initialize()

@@ -165,7 +166,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)
blocksByAddress, null, new ShuffleReadMetrics())

iterator.initialize()
iterator.foreach{
@@ -219,7 +220,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
)

val iterator = new BasicBlockFetcherIterator(blockManager,
blocksByAddress, null)
blocksByAddress, null, new ShuffleReadMetrics())
iterator.initialize()
iterator.foreach{
case (_, r) => {
@@ -65,7 +65,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc

// finish this task, should get updated shuffleRead
shuffleReadMetrics.remoteBytesRead = 1000
taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
taskInfo.finishTime = 1
var task = new ShuffleMapTask(0)
@@ -142,7 +142,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
val taskMetrics = new TaskMetrics()
val shuffleReadMetrics = new ShuffleReadMetrics()
val shuffleWriteMetrics = new ShuffleWriteMetrics()
taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
shuffleReadMetrics.remoteBytesRead = base + 1
shuffleReadMetrics.remoteBlocksFetched = base + 2
@@ -539,7 +539,7 @@ class JsonProtocolSuite extends FunSuite {
sr.localBlocksFetched = e
sr.fetchWaitTime = a + d
sr.remoteBlocksFetched = f
t.updateShuffleReadMetrics(sr)
t.setShuffleReadMetrics(Some(sr))
}
sw.shuffleBytesWritten = a + b + c
sw.shuffleWriteTime = b + c + d