diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 9cc321af4bde2..6afe58bff5229 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -23,6 +23,7 @@ import java.text.DateFormat
 import java.util.{Arrays, Comparator, Date, Locale}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.control.NonFatal
 
 import com.google.common.primitives.Longs
@@ -143,14 +144,29 @@ class SparkHadoopUtil extends Logging {
    * Returns a function that can be called to find Hadoop FileSystem bytes read. If
    * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
    * return the bytes read on r since t.
-   *
-   * @return None if the required method can't be found.
    */
   private[spark] def getFSBytesReadOnThreadCallback(): () => Long = {
-    val threadStats = FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics)
-    val f = () => threadStats.map(_.getBytesRead).sum
-    val baselineBytesRead = f()
-    () => f() - baselineBytesRead
+    val f = () => FileSystem.getAllStatistics.asScala.map(_.getThreadStatistics.getBytesRead).sum
+    val baseline = (Thread.currentThread().getId, f())
+
+    /**
+     * This function may be called both from spawned child threads and from the parent task
+     * thread (in PythonRDD), and Hadoop FileSystem uses thread-local variables to track the
+     * statistics. So we need a map to track the bytes read from the child threads and the
+     * parent thread, summing them together to get the bytes read of this task.
+     */
+    new Function0[Long] {
+      private val bytesReadMap = new mutable.HashMap[Long, Long]()
+
+      override def apply(): Long = {
+        bytesReadMap.synchronized {
+          bytesReadMap.put(Thread.currentThread().getId, f())
+          bytesReadMap.map { case (k, v) =>
+            v - (if (k == baseline._1) baseline._2 else 0)
+          }.sum
+        }
+      }
+    }
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 4bf8ecc383542..76ea8b86c53d2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -251,7 +251,13 @@ class HadoopRDD[K, V](
         null
       }
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener{ context => closeIfNeeded() }
+      context.addTaskCompletionListener { context =>
+        // Update the bytes read before closing to make sure that lingering bytesRead statistics
+        // in this thread get correctly added.
+        updateBytesRead()
+        closeIfNeeded()
+      }
+
       private val key: K = if (reader == null) null.asInstanceOf[K] else reader.createKey()
       private val value: V = if (reader == null) null.asInstanceOf[V] else reader.createValue()
 
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index ce3a9a2a1e2a8..482875e6c1ac5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -191,7 +191,13 @@ class NewHadoopRDD[K, V](
       }
 
       // Register an on-task-completion callback to close the input stream.
-      context.addTaskCompletionListener(context => close())
+      context.addTaskCompletionListener { context =>
+        // Update the bytes read before closing to make sure that lingering bytesRead statistics
+        // in this thread get correctly added.
+        updateBytesRead()
+        close()
+      }
+
       private var havePair = false
       private var recordsSinceMetricsUpdate = 0
 
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
index 5d522189a0c29..6f4203da1d866 100644
--- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala
@@ -34,7 +34,7 @@ import org.scalatest.BeforeAndAfter
 
 import org.apache.spark.{SharedSparkContext, SparkFunSuite}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
   with BeforeAndAfter {
@@ -319,6 +319,35 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext
     }
     assert(bytesRead >= tmpFile.length())
   }
+
+  test("input metrics with old Hadoop API in different thread") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.textFile(tmpFilePath, 4).mapPartitions { iter =>
+        val buf = new ArrayBuffer[String]()
+        ThreadUtils.runInNewThread("testThread", false) {
+          iter.flatMap(_.split(" ")).foreach(buf.append(_))
+        }
+
+        buf.iterator
+      }.count()
+    }
+    assert(bytesRead >= tmpFile.length())
+  }
+
+  test("input metrics with new Hadoop API in different thread") {
+    val bytesRead = runAndReturnBytesRead {
+      sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable],
+        classOf[Text]).mapPartitions { iter =>
+        val buf = new ArrayBuffer[String]()
+        ThreadUtils.runInNewThread("testThread", false) {
+          iter.map(_._2.toString).flatMap(_.split(" ")).foreach(buf.append(_))
+        }
+
+        buf.iterator
+      }.count()
+    }
+    assert(bytesRead >= tmpFile.length())
+  }
 }
 
 /**
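
The heart of the patch is the callback returned by getFSBytesReadOnThreadCallback: Hadoop's FileSystem statistics are thread-local, so bytes read by a thread spawned inside a task (as PythonRDD does) would otherwise never show up in that task's input metrics. Below is a minimal, self-contained sketch of the same per-thread aggregation pattern; threadLocalBytes, read, and makeCallback are illustrative stand-ins for Hadoop's per-thread statistics and the real callback, not code from this patch.

    import scala.collection.mutable

    object BytesReadCallbackSketch {
      // Stand-in for Hadoop's thread-local FileSystem statistics: each thread
      // sees only the bytes it has "read" itself.
      private val threadLocalBytes = new ThreadLocal[Long] {
        override def initialValue(): Long = 0L
      }

      def read(n: Long): Unit = threadLocalBytes.set(threadLocalBytes.get() + n)

      // Mirrors the patched callback: record a baseline for the creating thread,
      // then have every thread that invokes the callback publish its own counter
      // into a shared map; the task total is the sum over all threads, with the
      // creating thread's pre-existing bytes subtracted out.
      def makeCallback(): () => Long = {
        val f = () => threadLocalBytes.get()
        val baseline = (Thread.currentThread().getId, f())
        val bytesReadMap = new mutable.HashMap[Long, Long]()

        () => bytesReadMap.synchronized {
          bytesReadMap.put(Thread.currentThread().getId, f())
          bytesReadMap.map { case (k, v) =>
            v - (if (k == baseline._1) baseline._2 else 0L)
          }.sum
        }
      }

      def main(args: Array[String]): Unit = {
        read(100L)                     // bytes read on this thread before the task
        val callback = makeCallback()  // baseline for this thread is 100
        read(40L)                      // the parent task thread reads 40 more bytes
        callback()                     // publishes the parent thread: 140 - 100 = 40

        val child = new Thread(new Runnable {
          override def run(): Unit = { read(25L); callback() }
        })
        child.start()
        child.join()

        println(callback())            // 40 (parent) + 25 (child) = 65
      }
    }

Only the creating thread's baseline is subtracted because child threads spawned after the callback is created start from zero-valued thread-local counters, so their raw readings are already bytes read on behalf of this task. This is also why the HadoopRDD and NewHadoopRDD completion listeners call updateBytesRead() before closing: any bytes recorded on the listener's thread since the last per-record update are folded into the total before the metric is finalized.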