fix thread safety & add ut
lianhuiwang committed Apr 19, 2016
1 parent 7c36ef0 commit 70bcffa
Showing 6 changed files with 141 additions and 59 deletions.
@@ -134,7 +134,7 @@ protected void freePage(MemoryBlock page) {
/**
* Allocates a heap memory of `size`.
*/
public long allocateHeapExecutionMemory(long size) {
public long acquireOnHeapMemory(long size) {
long granted =
taskMemoryManager.acquireExecutionMemory(size, MemoryMode.ON_HEAP, this);
used += granted;
@@ -144,7 +144,7 @@ public long allocateHeapExecutionMemory(long size) {
/**
* Release N bytes of heap memory.
*/
public void freeHeapExecutionMemory(long size) {
public void freeOnHeapMemory(long size) {
taskMemoryManager.releaseExecutionMemory(size, MemoryMode.ON_HEAP, this);
used -= size;
}
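Aside from the rename (allocateHeapExecutionMemory / freeHeapExecutionMemory become acquireOnHeapMemory / freeOnHeapMemory), the acquire/release discipline is unchanged; the renamed call sites appear in Spillable.scala further down in this diff. A minimal sketch of that discipline, assuming only the two method signatures shown in this hunk; the trait, object, and helper names are illustrative, not Spark code:

object MemoryReservationSketch {
  // Stand-in for a Spillable-like consumer exposing the two renamed methods.
  trait OnHeapMemoryConsumer {
    def acquireOnHeapMemory(size: Long): Long   // may grant less than requested
    def freeOnHeapMemory(size: Long): Unit
  }

  // Reserve before growing an in-memory structure; always hand the reservation back.
  def withReservation[T](consumer: OnHeapMemoryConsumer, bytesNeeded: Long)(body: Long => T): T = {
    val granted = consumer.acquireOnHeapMemory(bytesNeeded)
    try body(granted)                           // body sees how much was actually granted
    finally consumer.freeOnHeapMemory(granted)  // release exactly what was acquired
  }
}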
@@ -115,7 +115,7 @@ class ExternalAppendOnlyMap[K, V, C](
private val keyComparator = new HashComparator[K]
private val ser = serializer.newInstance()

private var inMemoryOrDiskIterator: Iterator[(K, C)] = null
private var readingIterator: SpillableIterator = null

/**
* Number of files this map has spilled so far.
@@ -192,14 +192,12 @@ class ExternalAppendOnlyMap[K, V, C](
* It will be called by TaskMemoryManager when there is not enough memory for the task.
*/
override protected[this] def forceSpill(): Boolean = {
assert(inMemoryOrDiskIterator != null)
val inMemoryIterator = inMemoryOrDiskIterator
logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
val diskMapIterator = spillMemoryIteratorToDisk(inMemoryIterator)
inMemoryOrDiskIterator = diskMapIterator
currentMap = null
true
assert(readingIterator != null)
val isSpilled = readingIterator.spill()
if (isSpilled) {
currentMap = null
}
isSpilled
}
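As the doc comment above says, forceSpill() is invoked by the TaskMemoryManager when the task runs short of memory, so it can arrive on a different thread than the one draining the map; the new body simply asks readingIterator to spill itself and reports whether any memory was actually released. A rough sketch of that boolean contract from the caller's side, using a toy memory-pressure loop (ForceSpillable and reclaim are illustrative names, not Spark's TaskMemoryManager API):

object ForceSpillContractSketch {
  // Stand-in for a consumer whose forceSpill() may be called from another thread.
  trait ForceSpillable {
    def forceSpill(): Boolean   // true only if memory was actually freed
    def getUsed(): Long         // bytes currently held by this consumer
  }

  // Toy pressure loop: keep asking consumers to spill until enough memory is freed.
  def reclaim(consumers: Seq[ForceSpillable], bytesNeeded: Long): Long = {
    var freed = 0L
    val it = consumers.iterator
    while (freed < bytesNeeded && it.hasNext) {
      val consumer = it.next()
      val before = consumer.getUsed()
      if (consumer.forceSpill()) {        // false: already spilled or already fully consumed
        freed += before - consumer.getUsed()
      }
    }
    freed
  }
}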

/**
@@ -270,14 +268,10 @@ class ExternalAppendOnlyMap[K, V, C](
* it returns pairs from an on-disk map.
*/
def destructiveIterator(inMemoryIterator: Iterator[(K, C)]): Iterator[(K, C)] = {
inMemoryOrDiskIterator = inMemoryIterator
new Iterator[(K, C)] {

override def hasNext = inMemoryOrDiskIterator.hasNext

override def next() = inMemoryOrDiskIterator.next()
}
readingIterator = new SpillableIterator(inMemoryIterator)
readingIterator
}

/**
* Return a destructive iterator that merges the in-memory map with the spilled maps.
* If no spill has occurred, simply return the in-memory map's iterator.
@@ -573,6 +567,39 @@ class ExternalAppendOnlyMap[K, V, C](
context.addTaskCompletionListener(context => cleanup())
}

private[this] class SpillableIterator(var upstream: Iterator[(K, C)])
extends Iterator[(K, C)] {

private var nextUpstream: Iterator[(K, C)] = null

private var cur: (K, C) = null

def spill(): Boolean = synchronized {
if (upstream == null || nextUpstream != null) {
false
} else {
logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
s"it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
nextUpstream = spillMemoryIteratorToDisk(upstream)
true
}
}

override def hasNext: Boolean = synchronized {
if (nextUpstream != null) {
upstream = nextUpstream
nextUpstream = null
}
val r = upstream.hasNext
if (r) {
cur = upstream.next()
}
r
}

override def next(): (K, C) = cur
}
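The thread-safety fix is concentrated here: spill() runs on the memory-manager side while hasNext()/next() run on the task's reading thread, and both synchronize on the iterator, so a forced spill can only happen between element reads and the next hasNext() call atomically switches upstream to the disk-backed iterator. (hasNext pre-fetches the element into cur; next() just returns it, so callers alternate the two as ordinary iterator consumers do.) A self-contained sketch of the same hand-off with simplified types, exercised from two threads; Handoff and spillTo are illustrative stand-ins, not the Spark classes:

object SpillableHandoffSketch {
  // Simplified hand-off: spill() swaps in a replacement source, hasNext() applies
  // the swap and pre-fetches the next element, both under the same lock.
  class Handoff[A](private var upstream: Iterator[A], spillTo: Iterator[A] => Iterator[A])
    extends Iterator[A] {

    private var nextUpstream: Iterator[A] = null
    private var cur: A = _

    def spill(): Boolean = synchronized {
      if (upstream == null || nextUpstream != null) false
      else { nextUpstream = spillTo(upstream); true }
    }

    override def hasNext: Boolean = synchronized {
      if (nextUpstream != null) { upstream = nextUpstream; nextUpstream = null }
      val r = upstream.hasNext
      if (r) cur = upstream.next()
      r
    }

    override def next(): A = cur
  }

  def main(args: Array[String]): Unit = {
    // "Spilling" here just materializes the remainder, standing in for a disk write.
    val it = new Handoff[Int]((1 to 1000).iterator, rest => rest.toArray.iterator)
    val spiller = new Thread(new Runnable { def run(): Unit = it.spill() })
    spiller.start()
    val sum = it.foldLeft(0)(_ + _)   // reader keeps draining across the swap
    spiller.join()
    assert(sum == (1 to 1000).sum)    // nothing lost or duplicated by the hand-off
  }
}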

/** Convenience function to hash the given (K, C) pair by the key. */
private def hashKey(kc: (K, C)): Int = ExternalAppendOnlyMap.hash(kc._1)

@@ -136,8 +136,8 @@ private[spark] class ExternalSorter[K, V, C](
def peakMemoryUsedBytes: Long = _peakMemoryUsedBytes

private var isShuffleSort: Boolean = true
var forceSpillFile: Option[SpilledFile] = None
private var inMemoryOrDiskIterator: Iterator[((Int, K), C)] = null
private val forceSpillFiles = new ArrayBuffer[SpilledFile]
private var readingIterator: SpillableIterator = null

// A comparator for keys K that orders them within a partition to allow aggregation or sorting.
// Can be a partial ordering by hash code if a total ordering is not provided through by the
@@ -163,7 +163,7 @@ private[spark] class ExternalSorter[K, V, C](
// Information about a spilled file. Includes sizes in bytes of "batches" written by the
// serializer as we periodically reset its stream, as well as number of elements in each
// partition, used to efficiently keep track of partitions when merging.
private[collection] case class SpilledFile(
private[this] case class SpilledFile(
file: File,
blockId: BlockId,
serializerBatchSizes: Array[Long],
@@ -250,31 +250,13 @@ private[spark] class ExternalSorter[K, V, C](
if (isShuffleSort) {
false
} else {
assert(inMemoryOrDiskIterator != null)
val it = inMemoryOrDiskIterator
val inMemoryIterator = new WritablePartitionedIterator {
private[this] var cur = if (it.hasNext) it.next() else null

def writeNext(writer: DiskBlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (it.hasNext) it.next() else null
}

def hasNext(): Boolean = cur != null

def nextPartition(): Int = cur._1._1
}
logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
forceSpillFile = Some(spillMemoryIteratorToDisk(inMemoryIterator))
val spillReader = new SpillReader(forceSpillFile.get)
inMemoryOrDiskIterator = (0 until numPartitions).iterator.flatMap { p =>
val iterator = spillReader.readNextPartition()
iterator.map(cur => ((p, cur._1), cur._2))
assert(readingIterator != null)
val isSpilled = readingIterator.spill()
if (isSpilled) {
map = null
buffer = null
}
map = null
buffer = null
true
isSpilled
}
}

@@ -655,13 +637,8 @@ private[spark] class ExternalSorter[K, V, C](
if (isShuffleSort) {
memoryIterator
} else {
inMemoryOrDiskIterator = memoryIterator
new Iterator[((Int, K), C)] {

override def hasNext = inMemoryOrDiskIterator.hasNext

override def next() = inMemoryOrDiskIterator.next()
}
readingIterator = new SpillableIterator(memoryIterator)
readingIterator
}
}

@@ -762,18 +739,15 @@ private[spark] class ExternalSorter[K, V, C](
def stop(): Unit = {
spills.foreach(s => s.file.delete())
spills.clear()
forceSpillFile.foreach(_.file.delete())
forceSpillFiles.foreach(s => s.file.delete())
forceSpillFiles.clear()
if (map != null || buffer != null) {
map = null // So that the memory can be garbage-collected
buffer = null // So that the memory can be garbage-collected
releaseMemory()
}
}

override def toString(): String = {
this.getClass.getName + "@" + java.lang.Integer.toHexString(this.hashCode())
}

/**
* Given a stream of ((partition, key), combiner) pairs *assumed to be sorted by partition ID*,
* group together the pairs for each partition into a sub-iterator.
@@ -805,4 +779,55 @@ private[spark] class ExternalSorter[K, V, C](
(elem._1._2, elem._2)
}
}
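The doc comment shown in the previous hunk describes the step that consumes these sorted streams: ((partition, key), combiner) pairs, already sorted by partition id, are regrouped into one sub-iterator per partition. A sketch of that grouping idea on a plain iterator, assuming a BufferedIterator-based simplification rather than ExternalSorter's own IteratorForPartition (which also yields empty iterators for empty partitions):

object GroupByPartitionSketch {
  // Given pairs sorted by partition id, produce (partition, sub-iterator) pairs.
  // Each sub-iterator must be fully consumed before asking for the next partition.
  def groupByPartition[K, C](sorted: Iterator[((Int, K), C)]): Iterator[(Int, Iterator[(K, C)])] = {
    val buffered = sorted.buffered
    new Iterator[(Int, Iterator[(K, C)])] {
      override def hasNext: Boolean = buffered.hasNext
      override def next(): (Int, Iterator[(K, C)]) = {
        val p = buffered.head._1._1
        // Sub-iterator that stops as soon as the partition id changes.
        val sub = new Iterator[(K, C)] {
          override def hasNext: Boolean = buffered.hasNext && buffered.head._1._1 == p
          override def next(): (K, C) = {
            val ((_, k), c) = buffered.next()
            (k, c)
          }
        }
        (p, sub)
      }
    }
  }
}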

private[this] class SpillableIterator(var upstream: Iterator[((Int, K), C)])
extends Iterator[((Int, K), C)] {

private var nextUpstream: Iterator[((Int, K), C)] = null

private var cur: ((Int, K), C) = null

def spill(): Boolean = synchronized {
if (upstream == null || nextUpstream != null) {
false
} else {
val inMemoryIterator = new WritablePartitionedIterator {
private[this] var cur = if (upstream.hasNext) upstream.next() else null

def writeNext(writer: DiskBlockObjectWriter): Unit = {
writer.write(cur._1._2, cur._2)
cur = if (upstream.hasNext) upstream.next() else null
}

def hasNext(): Boolean = cur != null

def nextPartition(): Int = cur._1._1
}
logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map to disk and " +
s" it will release ${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
forceSpillFiles.append(spillFile)
val spillReader = new SpillReader(spillFile)
nextUpstream = (0 until numPartitions).iterator.flatMap { p =>
val iterator = spillReader.readNextPartition()
iterator.map(cur => ((p, cur._1), cur._2))
}
true
}
}

override def hasNext: Boolean = synchronized {
if (nextUpstream != null) {
upstream = nextUpstream
nextUpstream = null
}
val r = upstream.hasNext
if (r) {
cur = upstream.next()
}
r
}

override def next(): ((Int, K), C) = cur
}
}
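In the sorter's SpillableIterator, spill() writes the remaining in-memory data out partition by partition and then rebuilds an equivalent stream by reading each spilled partition back and re-attaching its partition id, so downstream code keeps seeing ((partition, key), combiner) pairs. A small sketch of just that rebuilding step, where readPartition stands in for reading one spilled partition back from disk (SpillReader.readNextPartition in the diff above):

object ReattachPartitionIdSketch {
  // Rebuild a ((partition, key), combiner) stream from per-partition readers.
  def rebuild[K, C](numPartitions: Int,
                    readPartition: Int => Iterator[(K, C)]): Iterator[((Int, K), C)] =
    (0 until numPartitions).iterator.flatMap { p =>
      readPartition(p).map { case (k, c) => ((p, k), c) }
    }

  def main(args: Array[String]): Unit = {
    val spilled = Map(0 -> Seq(("a", 1), ("b", 2)), 1 -> Seq(("c", 3)))
    val rebuilt = rebuild(2, p => spilled.getOrElse(p, Seq.empty).iterator).toList
    assert(rebuilt == List(((0, "a"), 1), ((0, "b"), 2), ((1, "c"), 3)))
  }
}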
@@ -77,7 +77,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) {
// Claim up to double our current memory from the shuffle memory pool
val amountToRequest = 2 * currentMemory - myMemoryThreshold
val granted = allocateHeapExecutionMemory(amountToRequest)
val granted = acquireOnHeapMemory(amountToRequest)
myMemoryThreshold += granted
// If we were granted too little memory to grow further (either tryToAcquire returned 0,
// or we already had more memory than myMemoryThreshold), spill the current collection
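The two context lines above encode the growth policy around the renamed call: when the collection crosses its current threshold, it asks for enough extra memory to double its current usage, and if the grant leaves the threshold at or below current usage the collection spills instead. Worked numbers for that arithmetic, as a small sketch (the 8 MB / 5 MB figures are illustrative):

object DoublingRequestSketch {
  // The growth policy around acquireOnHeapMemory: ask for enough to double current usage.
  def amountToRequest(currentMemory: Long, myMemoryThreshold: Long): Long =
    2 * currentMemory - myMemoryThreshold

  def main(args: Array[String]): Unit = {
    val mb = 1L << 20
    val current = 8 * mb
    val threshold = 5 * mb
    val request = amountToRequest(current, threshold)   // 2 * 8 MB - 5 MB = 11 MB
    assert(request == 11 * mb)
    // Fully granted: the new threshold is 16 MB, i.e. double the current usage.
    assert(threshold + request == 2 * current)
    // Granted only 2 MB: the threshold (7 MB) stays below current usage, so the caller spills.
    assert(threshold + 2 * mb < current)
  }
}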
@@ -126,7 +126,7 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager)
* Release our memory back to the execution pool so that other tasks can grab it.
*/
def releaseMemory(): Unit = {
freeHeapExecutionMemory(myMemoryThreshold)
freeOnHeapMemory(myMemoryThreshold)
myMemoryThreshold = 0L
}

@@ -418,4 +418,18 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext {
}
}

test("force to spill for external aggregation") {
val conf = createSparkConf(loadDefaults = false)
.set("spark.shuffle.memoryFraction", "0.01")
.set("spark.memory.useLegacyMode", "true")
.set("spark.shuffle.sort.bypassMergeThreshold", "0")
sc = new SparkContext("local", "test", conf)
val N = 2e6.toInt
sc.parallelize(1 to N, 10)
.map { i => (i, i) }
.groupByKey()
.reduceByKey(_ ++ _)
.count()
}

}
@@ -608,4 +608,20 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
}
}
}

test("force to spill for sorting") {
val conf = createSparkConf(loadDefaults = false, kryo = false)
.set("spark.shuffle.memoryFraction", "0.01")
.set("spark.memory.useLegacyMode", "true")
.set("spark.shuffle.sort.bypassMergeThreshold", "0")
sc = new SparkContext("local", "test", conf)
val N = 2e6.toInt
val p = new org.apache.spark.HashPartitioner(10)
val p2 = new org.apache.spark.HashPartitioner(5)
sc.parallelize(1 to N, 10)
.map { x => (x % 10000) -> x.toLong }
.repartitionAndSortWithinPartitions(p2)
.repartitionAndSortWithinPartitions(p)
.count()
}
}
