From 1e3fb8a278d3c1f1836acb01a63f0194713e764e Mon Sep 17 00:00:00 2001
From: jerryshao
Date: Wed, 31 May 2017 15:42:35 +0800
Subject: [PATCH] Further address the comments

Change-Id: I5eba16903914932392e05ba56c27808c36b033b3
---
 .../apache/spark/deploy/SparkHadoopUtil.scala | 25 ++++++++++++++-----
 1 file changed, 19 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index f30b788ff0fa5..4ac706054a9d1 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -24,6 +24,7 @@ import java.util.{Arrays, Comparator, Date, Locale}
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.common.primitives.Longs
@@ -148,13 +149,25 @@ class SparkHadoopUtil extends Logging {
   private[spark] def getFSBytesReadOnThreadCallback(): () => Long = {
     val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum
     val baseline = (Thread.currentThread().getId, f())
-    val bytesReadMap = new ConcurrentHashMap[Long, Long]()
-    () => {
-      bytesReadMap.put(Thread.currentThread().getId, f())
-      bytesReadMap.asScala.map { case (k, v) =>
-        v - (if (k == baseline._1) baseline._2 else 0)
-      }.sum
+    new Function0[Long] {
+      private val bytesReadMap = new mutable.HashMap[Long, Long]()
+
+      /**
+       * Returns the Hadoop FileSystem bytes read by this task so far.
+       * This method may be called in both spawned child threads and parent task thread (in
+       * PythonRDD), and Hadoop FileSystem uses thread local variables to track the statistics.
+       * So we need a map to track the bytes read from the child threads and parent thread,
+       * summing them together to get the bytes read of this task.
+       */
+      override def apply(): Long = {
+        bytesReadMap.synchronized {
+          bytesReadMap.put(Thread.currentThread().getId, f())
+          bytesReadMap.map { case (k, v) =>
+            v - (if (k == baseline._1) baseline._2 else 0)
+          }.sum
+        }
+      }
     }
   }
 
 