diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 26fa0cb6d7bde..8a0f5a602de12 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -76,10 +76,6 @@ class ExternalAppendOnlyMap[K, V, C](
   private val sparkConf = SparkEnv.get.conf
   private val diskBlockManager = blockManager.diskBlockManager
 
-  // Number of pairs inserted since last spill; note that we count them even if a value is merged
-  // with a previous key in case we're doing something like groupBy where the result grows
-  protected[this] var elementsRead = 0L
-
   /**
    * Size of object batches when reading/writing from serializers.
    *
@@ -132,7 +128,7 @@ class ExternalAppendOnlyMap[K, V, C](
         currentMap = new SizeTrackingAppendOnlyMap[K, C]
       }
       currentMap.changeValue(curEntry._1, update)
-      elementsRead += 1
+      addElementsRead()
     }
   }
 
@@ -209,8 +205,6 @@ class ExternalAppendOnlyMap[K, V, C](
     }
 
     spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
-
-    elementsRead = 0
   }
 
   def diskBytesSpilled: Long = _diskBytesSpilled
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index c1ce13683b569..c617ff5c51d04 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -119,10 +119,6 @@ private[spark] class ExternalSorter[K, V, C](
   private var map = new SizeTrackingAppendOnlyMap[(Int, K), C]
   private var buffer = new SizeTrackingPairBuffer[(Int, K), C]
 
-  // Number of pairs read from input since last spill; note that we count them even if a value is
-  // merged with a previous key in case we're doing something like groupBy where the result grows
-  protected[this] var elementsRead = 0L
-
   // Total spilling statistics
   private var _diskBytesSpilled = 0L
 
@@ -204,7 +200,7 @@ private[spark] class ExternalSorter[K, V, C](
         if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
       }
       while (records.hasNext) {
-        elementsRead += 1
+        addElementsRead()
         kv = records.next()
         map.changeValue((getPartition(kv._1), kv._1), update)
         maybeSpillCollection(usingMap = true)
@@ -212,7 +208,7 @@ private[spark] class ExternalSorter[K, V, C](
     } else {
       // Stick values into our buffer
       while (records.hasNext) {
-        elementsRead += 1
+        addElementsRead()
         val kv = records.next()
         buffer.insert((getPartition(kv._1), kv._1), kv._2.asInstanceOf[C])
         maybeSpillCollection(usingMap = false)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
index 0e4c6d633a4a9..cb73b377fca98 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala
@@ -36,7 +36,11 @@ private[spark] trait Spillable[C] {
   protected def spill(collection: C): Unit
 
   // Number of elements read from input since last spill
-  protected var elementsRead: Long
+  protected def elementsRead: Long = _elementsRead
+
+  // Called by subclasses every time a record is read
+  // It's used for checking spilling frequency
+  protected def addElementsRead(): Unit = { _elementsRead += 1 }
 
   // Memory manager that can be used to acquire/release memory
   private[this] val shuffleMemoryManager = SparkEnv.get.shuffleMemoryManager
@@ -44,6 +48,9 @@ private[spark] trait Spillable[C] {
   // What threshold of elementsRead we start estimating collection size at
   private[this] val trackMemoryThreshold = 1000
 
+  // Number of elements read from input since last spill
+  private[this] var _elementsRead = 0L
+
   // How much of the shared memory pool this collection has claimed
   private[this] var myMemoryThreshold = 0L
 
@@ -76,6 +83,7 @@ private[spark] trait Spillable[C] {
       spill(collection)
+      _elementsRead = 0
       // Keep track of spills, and release memory
       _memoryBytesSpilled += currentMemory
       releaseMemoryForThisThread()