Fix deadlock in BounceBufferPool (#12212)
This fixes #12211

The issue can only show up if the GPU concurrency value is set to > 4 and we spill a lot. This should be rather rare, but it is possible to hit.

The fix is to keep the blocking queue inside the pool from ever actually blocking. That way we control when threads are woken up and when they are not (a minimal sketch of the pattern is included just before the diff).

Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
revans2 authored Feb 25, 2025
1 parent e96bf43 commit 8d4f21a
Showing 2 changed files with 20 additions and 9 deletions.
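
Before the diff itself, a minimal, self-contained sketch of the pattern the fix applies (the names here are illustrative, not the real spark-rapids classes): the queue is never allowed to block, callers park in wait() on the pool's monitor while the queue is empty, a returned item triggers notify(), and close() calls notifyAll() so any blocked caller can observe the closed flag instead of hanging on the queue forever.

import java.util.concurrent.ArrayBlockingQueue

// Illustrative pool: the queue itself never blocks; waiting happens on the pool's monitor.
class SimplePool[T](items: Seq[T]) extends AutoCloseable {
  private val queue = new ArrayBlockingQueue[T](items.size)
  items.foreach(i => queue.offer(i))
  private var closed = false

  def acquire(): T = synchronized {
    // Wait on the pool itself, not inside the queue, so close() can always wake us.
    while (!closed && queue.isEmpty) {
      wait()
    }
    if (closed) {
      throw new IllegalStateException("pool is closed")
    }
    // Safe: we only get here while holding the lock with a non-empty queue.
    queue.poll()
  }

  def release(item: T): Unit = synchronized {
    queue.offer(item)
    notify() // wake one waiter to pick up the returned item
  }

  override def close(): Unit = synchronized {
    closed = true
    queue.clear()
    notifyAll() // wake every waiter so it can see the closed flag and fail fast
  }
}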
@@ -2302,7 +2302,7 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
   val CHUNKED_PACK_BOUNCE_BUFFER_COUNT = conf("spark.rapids.sql.chunkedPack.bounceBuffers")
     .doc("Number of chunked pack bounce buffers, needed during spill from GPU to host memory. ")
     .internal()
-    .longConf
+    .integerConf
     .checkValue(v => v >= 1,
       "The chunked pack bounce buffer count must be at least 1")
     .createWithDefault(4)
@@ -2321,7 +2321,7 @@ val SHUFFLE_COMPRESSION_LZ4_CHUNK_SIZE = conf("spark.rapids.shuffle.compression.
     conf("spark.rapids.memory.host.spillToDiskBounceBuffers")
       .doc("Number of bounce buffers used for gpu to disk spill that bypasses the host store.")
       .internal()
-      .longConf
+      .integerConf
       .checkValue(v => v >= 1,
         "The gpu to disk spill bounce buffer count must be positive")
       .createWithDefault(4)
@@ -3273,11 +3273,11 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

   lazy val chunkedPackBounceBufferSize: Long = get(CHUNKED_PACK_BOUNCE_BUFFER_SIZE)

-  lazy val chunkedPackBounceBufferCount: Long = get(CHUNKED_PACK_BOUNCE_BUFFER_COUNT)
+  lazy val chunkedPackBounceBufferCount: Int = get(CHUNKED_PACK_BOUNCE_BUFFER_COUNT)

   lazy val spillToDiskBounceBufferSize: Long = get(SPILL_TO_DISK_BOUNCE_BUFFER_SIZE)

-  lazy val spillToDiskBounceBufferCount: Long = get(SPILL_TO_DISK_BOUNCE_BUFFER_COUNT)
+  lazy val spillToDiskBounceBufferCount: Int = get(SPILL_TO_DISK_BOUNCE_BUFFER_COUNT)

   lazy val splitUntilSizeOverride: Option[Long] = get(SPLIT_UNTIL_SIZE_OVERRIDE)
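
A side note on the .longConf to .integerConf switch above: it plausibly exists because the reworked pool in the next file sizes a java.util.concurrent.ArrayBlockingQueue with this count, and that constructor takes an Int capacity. A trivial illustration under that assumption (not taken from the commit):

import java.util.concurrent.ArrayBlockingQueue

object CapacityExample {
  val bounceBufferCount: Int = 4
  // ArrayBlockingQueue's capacity parameter is an Int, so an Int-typed config
  // value can be passed straight through; a Long would need an explicit .toInt.
  val slots = new ArrayBlockingQueue[String](bounceBufferCount)
}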
@@ -22,7 +22,7 @@ import java.nio.channels.{Channels, FileChannel, WritableByteChannel}
 import java.nio.file.StandardOpenOption
 import java.util
 import java.util.UUID
-import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue}
+import java.util.concurrent.{ArrayBlockingQueue, ConcurrentHashMap}

 import scala.collection.mutable
@@ -1805,21 +1805,28 @@ private[spill] class BounceBuffer[T <: AutoCloseable](
  * Callers should synchronize before calling close on their `DeviceMemoryBuffer`s.
  */
 class BounceBufferPool[T <: AutoCloseable](private val bufSize: Long,
-                                           private val bbCount: Long,
+                                           private val bbCount: Int,
                                            private val allocator: Long => T)
   extends AutoCloseable with Logging {

-  private val pool = new LinkedBlockingQueue[BounceBuffer[T]]
-  for (_ <- 1L to bbCount) {
+  private val pool = new ArrayBlockingQueue[BounceBuffer[T]](bbCount)
+  for (_ <- 1 to bbCount) {
     pool.offer(new BounceBuffer[T](allocator(bufSize), this))
   }

   def bufferSize: Long = bufSize
   def nextBuffer(): BounceBuffer[T] = synchronized {
     if (closed) {
-      logError("tried to acquire a bounce buffer after the" +
+      throw new IllegalStateException("tried to acquire a bounce buffer after the" +
         "pool has been closed!")
     }
+    while (pool.size() <= 0) {
+      wait()
+      if (closed) {
+        throw new IllegalStateException("tried to acquire a bounce buffer after the" +
+          "pool has been closed!")
+      }
+    }
     pool.take()
   }

@@ -1828,6 +1835,8 @@ class BounceBufferPool[T <: AutoCloseable](private val bufSize: Long,
       buffer.release()
     } else {
       pool.offer(buffer)
+      // Wake up one thread to take the next bounce buffer
+      notify()
     }
   }

@@ -1842,6 +1851,8 @@ class BounceBufferPool[T <: AutoCloseable](private val bufSize: Long,

       pool.forEach(_.release())
       pool.clear()
+      // Wake up any threads that might be waiting still...
+      notifyAll()
     }
   }
 }
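
As a usage illustration of the acquire/release/close lifecycle, here is a hypothetical driver for the SimplePool sketch shown earlier (illustrative only; the real callers live in the spill framework). With more workers than buffers, the extra worker waits inside acquire(), and either a release() or a close() can wake it:

// Hypothetical driver reusing the SimplePool sketch above: two buffers, three workers.
object PoolDemo {
  def main(args: Array[String]): Unit = {
    val pool = new SimplePool[String](Seq("buf-0", "buf-1"))

    // The third worker parks in wait() inside acquire() rather than blocking deep
    // inside the queue, so a release (or close) can always wake it up.
    val workers = (1 to 3).map { id =>
      new Thread(() => {
        val buf = pool.acquire()
        try {
          println(s"worker $id got $buf")
          Thread.sleep(100) // stand-in for copying data through a bounce buffer
        } finally {
          pool.release(buf) // the notify() inside release() un-parks a waiting worker
        }
      })
    }
    workers.foreach(_.start())
    workers.foreach(_.join())
    pool.close()
  }
}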
