[SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory #10024

Closed
wants to merge 21 commits
Changes from 9 commits
20 changes: 19 additions & 1 deletion core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
@@ -45,7 +45,7 @@ protected MemoryConsumer(TaskMemoryManager taskMemoryManager) {
/**
* Returns the size of used memory in bytes.
*/
long getUsed() {
protected long getUsed() {
return used;
}

@@ -130,4 +130,22 @@ protected void freePage(MemoryBlock page) {
used -= page.size();
taskMemoryManager.freePage(page, this);
}

/**
* Allocates `size` bytes of on-heap execution memory.
*/
public long allocateHeapExecutionMemory(long size) {
Contributor comment:
This function does not actually create any object, I'd like to call it acquireOnHeapMemory

long granted =
taskMemoryManager.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this);
used += granted;
return granted;
}

/**
* Releases `size` bytes of on-heap execution memory.
*/
public void freeHeapExecutionMemory(long size) {
taskMemoryManager.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this);
used -= size;
}
}
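For orientation, here is a minimal Scala sketch, outside the diff, of how a consumer subclass might use the new on-heap acquire/release pair. OnHeapBufferConsumer and its fields are hypothetical, the method names follow this revision of the patch (the reviewer suggests acquireOnHeapMemory instead), and the spill(size, trigger) override is assumed from MemoryConsumer's existing callback contract.

```scala
import org.apache.spark.memory.{MemoryConsumer, TaskMemoryManager}

// Hypothetical consumer sketch: reserves on-heap execution memory through the
// methods added in this patch and hands it back when asked to spill.
class OnHeapBufferConsumer(tmm: TaskMemoryManager) extends MemoryConsumer(tmm) {
  private var reserved = 0L

  // Ask for `size` bytes before growing an on-heap buffer; the manager may
  // grant less than requested.
  def reserve(size: Long): Long = {
    val granted = allocateHeapExecutionMemory(size)
    reserved += granted
    granted
  }

  // Return everything we reserved once the buffer is dropped.
  def release(): Unit = {
    freeHeapExecutionMemory(reserved)
    reserved = 0L
  }

  // Called when another consumer on this task needs memory; here we simply
  // drop the buffer and report how much was freed.
  override def spill(size: Long, trigger: MemoryConsumer): Long = {
    val freed = reserved
    release()
    freed
  }
}
```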
core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
@@ -408,4 +408,11 @@ public long cleanUpAllAllocatedMemory() {
public long getMemoryConsumptionForThisTask() {
return memoryManager.getExecutionMemoryUsageForTask(taskAttemptId);
}

/**
* Returns the Tungsten memory mode.
*/
public MemoryMode getTungstenMemoryMode() {
return tungstenMemoryMode;
}
}
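For context on why the Spillable collections below gain a forceSpill hook: when a task cannot obtain enough execution memory, TaskMemoryManager asks the task's other registered consumers to spill and then retries the grant. The sketch below is a heavily simplified illustration of that control flow; AcquireSketch and grantFromPool are illustrative stand-ins, and the real acquireExecutionMemory also handles pool accounting, locking, and memory modes.

```scala
import org.apache.spark.memory.MemoryConsumer

// Simplified, illustrative sketch of the acquisition path that motivates
// forceSpill; not the actual TaskMemoryManager implementation.
object AcquireSketch {
  def acquireWithSpilling(
      required: Long,
      requester: MemoryConsumer,
      otherConsumers: Seq[MemoryConsumer],
      grantFromPool: Long => Long): Long = {
    var got = grantFromPool(required)
    val others = otherConsumers.iterator.filterNot(_ eq requester)
    // If the pool came up short, ask other consumers on the same task to
    // spill, then retry the grant with whatever they released.
    while (got < required && others.hasNext) {
      val released = others.next().spill(required - got, requester)
      if (released > 0) {
        got += grantFromPool(required - got)
      }
    }
    got
  }
}
```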
core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -61,10 +61,10 @@ class ExternalAppendOnlyMap[K, V, C](
blockManager: BlockManager = SparkEnv.get.blockManager,
context: TaskContext = TaskContext.get(),
serializerManager: SerializerManager = SparkEnv.get.serializerManager)
extends Iterable[(K, C)]
extends Spillable[SizeTracker](context.taskMemoryManager())
with Serializable
with Logging
with Spillable[SizeTracker] {
with Iterable[(K, C)] {

if (context == null) {
throw new IllegalStateException(
@@ -81,8 +81,6 @@ class ExternalAppendOnlyMap[K, V, C](
this(createCombiner, mergeValue, mergeCombiners, serializer, blockManager, TaskContext.get())
}

override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()

private var currentMap = new SizeTrackingAppendOnlyMap[K, C]
private val spilledMaps = new ArrayBuffer[DiskMapIterator]
private val sparkConf = SparkEnv.get.conf
@@ -117,6 +115,8 @@ class ExternalAppendOnlyMap[K, V, C](
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()

private var inMemoryOrDiskIterator: Iterator[(K, C)] = null

/**
* Number of files this map has spilled so far.
* Exposed for testing.
@@ -182,6 +182,31 @@ class ExternalAppendOnlyMap[K, V, C](
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
override protected[this] def spill(collection: SizeTracker): Unit = {
val inMemoryIterator = currentMap.destructiveSortedIterator(keyComparator)
val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator)
spilledMaps.append(diskMapIterator)
}

/**
* Force spilling the current in-memory collection to disk to release memory.
* It will be called by TaskMemoryManager when there is not enough memory for the task.
*/
override protected[this] def forceSpill(): Boolean = {
assert(inMemoryOrDiskIterator != null)
val inMemoryIterator = inMemoryOrDiskIterator
logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator)
inMemoryOrDiskIterator = diskMapIterator
currentMap = null
true
}

/**
* Spill the in-memory Iterator to a temporary file on disk.
*/
private[this] def spillMemoryIteratorToDisk(inMemoryIterator: Iterator[(K, C)])
: DiskMapIterator = {
val (blockId, file) = diskBlockManager.createTempLocalBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
@@ -202,9 +227,8 @@ class ExternalAppendOnlyMap[K, V, C](

var success = false
try {
val it = currentMap.destructiveSortedIterator(keyComparator)
while (it.hasNext) {
val kv = it.next()
while (inMemoryIterator.hasNext) {
val kv = inMemoryIterator.next()
writer.write(kv._1, kv._2)
objectsWritten += 1

@@ -237,9 +261,23 @@ class ExternalAppendOnlyMap[K, V, C](
}
}

spilledMaps.append(new DiskMapIterator(file, blockId, batchSizes))
new DiskMapIterator(file, blockId, batchSizes)
}

/**
* Returns a destructive iterator for iterating over the entries of this map.
* If this iterator is force-spilled to disk to release memory when there is not enough memory,
* it returns pairs from an on-disk map.
*/
def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = {
inMemoryOrDiskIterator = inMemoryIterator
new Iterator[(K, C)] {

override def hasNext = inMemoryOrDiskIterator.hasNext

override def next() = inMemoryOrDiskIterator.next()
}
}
/**
* Return a destructive iterator that merges the in-memory map with the spilled maps.
* If no spill has occurred, simply return the in-memory map's iterator.
@@ -250,15 +288,18 @@ class ExternalAppendOnlyMap[K, V, C](
"ExternalAppendOnlyMap.iterator is destructive and should only be called once.")
}
if (spilledMaps.isEmpty) {
CompletionIterator[(K, C), Iterator[(K, C)]](currentMap.iterator, freeCurrentMap())
CompletionIterator[(K, C), Iterator[(K, C)]](
destructiveIterator(currentMap.iterator), freeCurrentMap())
} else {
new ExternalIterator()
}
}

private def freeCurrentMap(): Unit = {
currentMap = null // So that the memory can be garbage-collected
releaseMemory()
if (currentMap != null) {
currentMap = null // So that the memory can be garbage-collected
releaseMemory()
}
}

/**
@@ -272,8 +313,8 @@ class ExternalAppendOnlyMap[K, V, C](

// Input streams are derived both from the in-memory map and spilled maps on disk
// The in-memory map is sorted in place, while the spilled maps are already in sorted order
private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](
currentMap.destructiveSortedIterator(keyComparator), freeCurrentMap())
private val sortedMap = CompletionIterator[(K, C), Iterator[(K, C)]](destructiveIterator(
currentMap.destructiveSortedIterator(keyComparator)), freeCurrentMap())
private val inputStreams = (Seq(sortedMap) ++ spilledMaps).map(it => it.buffered)

inputStreams.foreach { it =>
@@ -534,6 +575,10 @@ class ExternalAppendOnlyMap[K, V, C](

/** Convenience function to hash the given (K, C) pair by the key. */
private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1)

override def toString(): String = {
this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode())
}
}

private[spark] object ExternalAppendOnlyMap {
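Both ExternalAppendOnlyMap and ExternalSorter rely on the same indirection: the iterator handed to callers never captures the in-memory iterator directly but delegates through a mutable field (inMemoryOrDiskIterator), so forceSpill can drain the remaining entries to disk and repoint that field at a disk-backed iterator without the caller noticing. A stripped-down sketch of that pattern, using illustrative names rather than Spark's classes:

```scala
// Illustrative sketch of the iterator-swap pattern behind destructiveIterator
// and forceSpill; SwappableIterator is not a Spark class.
class SwappableIterator[T](initial: Iterator[T]) {
  // Current backing iterator: starts in memory and may be swapped to a
  // disk-backed iterator by a force-spill.
  @volatile private var current: Iterator[T] = initial

  // Called after the remaining in-memory entries have been written out:
  // readers transparently continue from the on-disk copy.
  def swapTo(onDisk: Iterator[T]): Unit = { current = onDisk }

  // The iterator handed to callers; every call re-reads `current`.
  val iterator: Iterator[T] = new Iterator[T] {
    override def hasNext: Boolean = current.hasNext
    override def next(): T = current.next()
  }
}

// Usage sketch: swapping the backing source mid-iteration.
object SwappableIteratorExample extends App {
  val swappable = new SwappableIterator(Iterator(1, 2, 3, 4))
  println(swappable.iterator.next())     // reads from the in-memory side
  swappable.swapTo(Iterator(20, 30, 40)) // e.g. after spilling to disk
  println(swappable.iterator.toList)     // continues from the swapped source
}
```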
core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -93,10 +93,8 @@ private[spark] class ExternalSorter[K, V, C](
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
serializer: Serializer = SparkEnv.get.serializer)
extends Logging
with Spillable[WritablePartitionedPairCollection[K, C]] {

override protected[this] def taskMemoryManager: TaskMemoryManager = context.taskMemoryManager()
extends Spillable[WritablePartitionedPairCollection[K, C]](context.taskMemoryManager())
with Logging {

private val conf = SparkEnv.get.conf

@@ -137,6 +135,10 @@ private[spark] class ExternalSorter[K, V, C](
private var _peakMemoryUsedBytes: Long = 0L
def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes

private var isShuffleSort: Boolean = true
var forceSpillFile: Option[SpilledFile] = None
private var inMemoryOrDiskIterator: Iterator[((Int, K), C)] = null

// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided through by the
// user. (A partial ordering means that equal keys have comparator.compare(k, k) = 0, but some
@@ -161,7 +163,7 @@ private[spark] class ExternalSorter[K, V, C](
// Information about a spilled file. Includes sizes in bytes of "batches" written by the
// serializer as we periodically reset its stream, as well as number of elements in each
// partition, used to efficiently keep track of partitions when merging.
private[this] case class SpilledFile(
private[collection] case class SpilledFile(
file: File,
blockId: BlockId,
serializerBatchSizes: Array[Long],
@@ -235,6 +237,52 @@ private[spark] class ExternalSorter[K, V, C](
* @param collection whichever collection we're using (map or buffer)
*/
override protected[this] def spill(collection: WritablePartitionedPairCollection[K, C]): Unit = {
val inMemoryIterator = collection.destructiveSortedWritablePartitionedIterator(comparator)
val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
spills.append(spillFile)
}

/**
* Force spilling the current in-memory collection to disk to release memory.
* It will be called by TaskMemoryManager when there is not enough memory for the task.
*/
override protected[this] def forceSpill(): Boolean = {
if (isShuffleSort) {
Contributor comment:
This could be triggered by a different thread, so it should be thread safe.

false
} else {
assert(inMemoryOrDiskIterator != null)
val it = inMemoryOrDiskIterator
val inMemoryIterator = new WritablePartitionedIterator {
private[this] var cur = if (it.hasNext) it.next() else null

def writeNext(writer: DiskBlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}

def hasNext(): Boolean = cur != null

def nextPartition(): Int = cur._1._1
}
logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
forceSpillFile = Some(spillMemoryIteratorToDisk(inMemoryIterator))
val spillReader = new SpillReader(forceSpillFile.get)
inMemoryOrDiskIterator = (0 until numPartitions).iterator.flatMap { p =>
val iterator = spillReader.readNextPartition()
iterator.map(cur => ((p, cur._1), cur._2))
}
map = null
buffer = null
true
}
}

/**
* Spill contents of in-memory iterator to a temporary file on disk.
*/
private[this] def spillMemoryIteratorToDisk(inMemoryIterator: WritablePartitionedIterator)
: SpilledFile = {
// Because these files may be read during shuffle, their compression must be controlled by
// spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
// createTempShuffleBlock here; see SPARK-3426 for more context.
@@ -271,12 +319,11 @@ private[spark] class ExternalSorter[K, V, C](

var success = false
try {
val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
while (it.hasNext) {
val partitionId = it.nextPartition()
while (inMemoryIterator.hasNext) {
val partitionId = inMemoryIterator.nextPartition()
require(partitionId >= 0 && partitionId < numPartitions,
s"partition Id: ${partitionId} should be in the range [0, ${numPartitions})")
it.writeNext(writer)
inMemoryIterator.writeNext(writer)
elementsPerPartition(partitionId) += 1
objectsWritten += 1

@@ -308,7 +355,7 @@ private[spark] class ExternalSorter[K, V, C](
}
}

spills.append(SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition))
SpilledFile(file, blockId, batchSizes.toArray, elementsPerPartition)
}

/**
@@ -599,6 +646,25 @@ private[spark] class ExternalSorter[K, V, C](
}
}

/**
* Returns a destructive iterator for iterating over the entries of this map.
* If this iterator is force-spilled to disk to release memory when there is not enough memory,
* it returns pairs from an on-disk map.
*/
def destructiveIterator(memoryIterator: Iterator[((Int, K), C)]): Iterator[((Int, K), C)] = {
if (isShuffleSort) {
memoryIterator
} else {
inMemoryOrDiskIterator = memoryIterator
new Iterator[((Int, K), C)] {

override def hasNext = inMemoryOrDiskIterator.hasNext

override def next() = inMemoryOrDiskIterator.next()
}
}
}

/**
* Return an iterator over all the data written to this object, grouped by partition and
* aggregated by the requested aggregator. For each partition we then have an iterator over its
@@ -618,21 +684,26 @@ private[spark] class ExternalSorter[K, V, C](
// we don't even need to sort by anything other than partition ID
if (!ordering.isDefined) {
// The user hasn't requested sorted keys, so only sort by partition ID, not key
groupByPartition(collection.partitionedDestructiveSortedIterator(None))
groupByPartition(destructiveIterator(collection.partitionedDestructiveSortedIterator(None)))
} else {
// We do need to sort by both partition ID and key
groupByPartition(collection.partitionedDestructiveSortedIterator(Some(keyComparator)))
groupByPartition(destructiveIterator(
collection.partitionedDestructiveSortedIterator(Some(keyComparator))))
}
} else {
// Merge spilled and in-memory data
merge(spills, collection.partitionedDestructiveSortedIterator(comparator))
merge(spills, destructiveIterator(
collection.partitionedDestructiveSortedIterator(comparator)))
}
}

/**
* Return an iterator over all the data written to this object, aggregated by our aggregator.
*/
def iterator: Iterator[Product2[K, C]] = partitionedIterator.flatMap(pair => pair._2)
def iterator: Iterator[Product2[K, C]] = {
isShuffleSort = false
partitionedIterator.flatMap(pair => pair._2)
}

/**
* Write all the data added into this ExternalSorter into a file in the disk store. This is
@@ -689,11 +760,18 @@ private[spark] class ExternalSorter[K, V, C](
}

def stop(): Unit = {
map = null // So that the memory can be garbage-collected
buffer = null // So that the memory can be garbage-collected
spills.foreach(s => s.file.delete())
spills.clear()
releaseMemory()
forceSpillFile.foreach(_.file.delete())
if (map != null || buffer != null) {
map = null // So that the memory can be garbage-collected
buffer = null // So that the memory can be garbage-collected
releaseMemory()
}
}

override def toString(): String = {
this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode())
Contributor comment:
Why this?

}

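On the reviewer's thread-safety comment in forceSpill: the call can arrive from another task's thread via TaskMemoryManager while the owning task is iterating, so the isShuffleSort check and the iterator swap need to be mutually synchronized. The following is a hypothetical sketch of one possible guard, shown only to illustrate the concern rather than the fix this PR adopts:

```scala
// Hypothetical guard illustrating the reviewer's thread-safety concern.
// The flag and the spill work are serialized on the collection's own lock so
// a force-spill from another thread cannot run twice or race the owning
// task's iterator swap.
trait ForceSpillGuard {
  private var spilled = false

  // Subclasses do the actual work of writing the in-memory data to disk.
  protected def doForceSpill(): Boolean

  final def forceSpill(): Boolean = synchronized {
    if (spilled) {
      false // already spilled; nothing left to release
    } else {
      spilled = doForceSpill()
      spilled
    }
  }
}
```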