Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add in support for OOM retry #7822

Merged
merged 3 commits into from
Mar 6, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -28,6 +28,7 @@ import ai.rapids.cudf.{BaseDeviceMemoryBuffer, MemoryBuffer, NvtxColor, NvtxRang
import com.nvidia.spark.rapids.{Arm, GpuDeviceManager, RapidsConf}
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.ThreadFactoryBuilder
import com.nvidia.spark.rapids.jni.RmmSpark
import com.nvidia.spark.rapids.shuffle.{ClientConnection, MemoryRegistrationCallback, MessageType, MetadataTransportBuffer, TransportBuffer, TransportUtils}
import org.openucx.jucx._
import org.openucx.jucx.ucp._
Expand Down Expand Up @@ -105,7 +106,9 @@ class UCX(transport: UCXShuffleTransport, executor: BlockManagerId, rapidsConf:
new ThreadFactoryBuilder()
.setNameFormat("progress-thread-%d")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

// The pending queues are used to enqueue [[PendingReceive]] or [[PendingSend]], from executor
// task threads and [[progressThread]] will hand them to the UcpWorker thread.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import ai.rapids.cudf.{BaseDeviceMemoryBuffer, CudaMemoryBuffer, DeviceMemoryBuffer, HostMemoryBuffer, MemoryBuffer}
import com.nvidia.spark.rapids.{GpuDeviceManager, HashedPriorityQueue, RapidsConf}
import com.nvidia.spark.rapids.ThreadFactoryBuilder
import com.nvidia.spark.rapids.jni.RmmSpark
import com.nvidia.spark.rapids.shuffle._
import com.nvidia.spark.rapids.shuffle.{BounceBufferManager, BufferReceiveState, ClientConnection, PendingTransferRequest, RapidsShuffleClient, RapidsShuffleRequestHandler, RapidsShuffleServer, RapidsShuffleTransport, RefCountedDirectByteBuffer}

Expand Down Expand Up @@ -248,7 +249,9 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
new ThreadFactoryBuilder()
.setNameFormat("shuffle-transport-client-exec-%d")
.setDaemon(true)
.build),
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()),
// if we can't hand off because we are too busy, block the caller (in UCX's case,
// the progress thread)
new CallerRunsAndLogs())
Expand All @@ -258,7 +261,9 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
GpuDeviceManager.wrapThreadFactory(new ThreadFactoryBuilder()
.setNameFormat("shuffle-client-copy-thread-%d")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

override def makeClient(blockManagerId: BlockManagerId): RapidsShuffleClient = {
val peerExecutorId = blockManagerId.executorId.toLong
Expand All @@ -280,14 +285,18 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
GpuDeviceManager.wrapThreadFactory(new ThreadFactoryBuilder()
.setNameFormat(s"shuffle-server-conn-thread-${shuffleServerId.executorId}-%d")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

// This executor handles any task that would block (e.g. wait for spill synchronously due to OOM)
private[this] val serverCopyExecutor = Executors.newSingleThreadExecutor(
GpuDeviceManager.wrapThreadFactory(new ThreadFactoryBuilder()
.setNameFormat(s"shuffle-server-copy-thread-%d")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

// This is used to queue up on the server all the [[BufferSendState]] as the server waits for
// bounce buffers to become available (it is the equivalent of the transport's throttle, minus
Expand All @@ -296,7 +305,9 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
GpuDeviceManager.wrapThreadFactory(new ThreadFactoryBuilder()
.setNameFormat(s"shuffle-server-bss-thread-%d")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

/**
* Construct a server instance
Expand Down Expand Up @@ -356,7 +367,9 @@ class UCXShuffleTransport(shuffleServerId: BlockManagerId, rapidsConf: RapidsCon
new ThreadFactoryBuilder()
.setNameFormat(s"shuffle-transport-throttle-monitor")
.setDaemon(true)
.build))
.build,
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

// helper class to hold transfer requests that have a bounce buffer
// and should be ready to be handled by a `BufferReceiveState`
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION.
* Copyright (c) 2020-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -350,15 +350,26 @@ object GpuDeviceManager extends Logging {
}

/** Wrap a thread factory with one that will set the GPU device on each thread created. */
def wrapThreadFactory(factory: ThreadFactory): ThreadFactory = new ThreadFactory() {
def wrapThreadFactory(factory: ThreadFactory,
before: () => Unit = null,
after: () => Unit = null): ThreadFactory = new ThreadFactory() {
private[this] val devId = getDeviceId.getOrElse {
throw new IllegalStateException("Device ID is not set")
}

override def newThread(runnable: Runnable): Thread = {
factory.newThread(() => {
Cuda.setDevice(devId)
runnable.run()
try {
if (before != null) {
before()
}
Comment on lines +370 to +372
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add logging here? An exception from before()/after() might be difficult to contextualize since it is in a different thread.

Comment on lines +370 to +372
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider

Suggested change
if (before != null) {
before()
}
Option(before).foreach(_.apply())

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is more functional. I get that. This is not performance critical code, but it is replacing a check and a branch (probably 3 or 4 instructions) with calling a static method to create an object, which then calls a method on that object with a function that is probably a separate class that had to be created, possibly as a singleton.

I personally prefer the null check, but if for consistency with other code styles we want the functional one liner I am fine with it.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am ok with your preference. Performance considerations are irrelevant here. Thanks for considering the suggestion.

I just realized that we probably need neither version of the null check if you make the default parameter value a nop () => () instead of null

runnable.run()
} finally {
if (after != null) {
after()
}
}
})
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.util.concurrent.{ConcurrentHashMap, Semaphore}
import scala.collection.mutable

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.jni.RmmSpark
import org.apache.commons.lang3.mutable.MutableInt

import org.apache.spark.TaskContext
Expand Down Expand Up @@ -132,6 +133,7 @@ private final class GpuSemaphore() extends Logging with Arm {
}
logDebug(s"Task $taskAttemptId acquiring GPU with $permits permits")
semaphore.acquire(permits)
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
if (refs != null) {
refs.count.increment()
} else {
Expand All @@ -142,13 +144,17 @@ private final class GpuSemaphore() extends Logging with Arm {
context.addTaskCompletionListener[Unit](completeTask)
}
GpuDeviceManager.initializeFromTask()
} else {
// Already had the semaphore, but we don't know if the thread is new or not
RmmSpark.associateCurrentThreadWithTask(taskAttemptId)
}
}
}

def releaseIfNecessary(context: TaskContext): Unit = {
val nvtxRange = new NvtxRange("Release GPU", NvtxColor.RED)
try {
RmmSpark.removeCurrentThreadAssociation()
val taskAttemptId = context.taskAttemptId()
val refs = activeTasks.get(taskAttemptId)
if (refs != null && refs.count.getValue > 0) {
Expand All @@ -164,6 +170,7 @@ private final class GpuSemaphore() extends Logging with Arm {

def completeTask(context: TaskContext): Unit = {
val taskAttemptId = context.taskAttemptId()
RmmSpark.taskDone(taskAttemptId)
val refs = activeTasks.remove(taskAttemptId)
if (refs == null) {
throw new IllegalStateException(s"Completion of unknown task $taskAttemptId")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.nvidia.spark.rapids.RapidsBufferCatalog.getExistingRapidsBufferAndAcq
import com.nvidia.spark.rapids.RapidsPluginImplicits._
import com.nvidia.spark.rapids.StorageTier.StorageTier
import com.nvidia.spark.rapids.format.TableMeta
import com.nvidia.spark.rapids.jni.RmmSpark

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -743,7 +744,19 @@ object RapidsBufferCatalog extends Logging with Arm {
rapidsConf.gpuOomDumpDir,
rapidsConf.isGdsSpillEnabled,
rapidsConf.gpuOomMaxRetries)
Rmm.setEventHandler(memoryEventHandler)

if (rapidsConf.sparkRmmStateEnable) {
val debugLoc = if (rapidsConf.sparkRmmDebugLocation.isEmpty) {
null
} else {
rapidsConf.sparkRmmDebugLocation
}

RmmSpark.setEventHandler(memoryEventHandler, debugLoc)
} else {
logWarning("SparkRMM retry has been disabled")
Rmm.setEventHandler(memoryEventHandler)
}

_shouldUnspill = rapidsConf.isUnspillEnabled
}
Expand Down
23 changes: 23 additions & 0 deletions sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsConf.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,25 @@ object RapidsConf {
.stringConf
.createWithDefault("NONE")

val SPARK_RMM_STATE_DEBUG = conf("spark.rapids.memory.gpu.state.debug")
.doc("To better recover from out of memory errors, RMM will track several states for " +
"the threads that interact with the GPU. This provides a log of those state " +
"transitions to aid in debugging it. STDOUT or STDERR will have the logging go there " +
"empty string will disable logging and anything else will be treated as a file to " +
"write the logs to.")
.startupOnly()
.stringConf
.createWithDefault("")

val SPARK_RMM_STATE_ENABLE = conf("spark.rapids.memory.gpu.state.enable")
.doc("Enabled or disable using the SparkRMM state tracking to improve " +
"OOM response. This includes possibly retrying parts of the processing in " +
"the case of an OOM")
.startupOnly()
.internal()
.booleanConf
.createWithDefault(true)

val GPU_OOM_DUMP_DIR = conf("spark.rapids.memory.gpu.oomDumpDir")
.doc("The path to a local directory where a heap dump will be created if the GPU " +
"encounters an unrecoverable out-of-memory (OOM) error. The filename will be of the " +
Expand Down Expand Up @@ -1959,6 +1978,10 @@ class RapidsConf(conf: Map[String, String]) extends Logging {

lazy val rmmDebugLocation: String = get(RMM_DEBUG)

lazy val sparkRmmDebugLocation: String = get(SPARK_RMM_STATE_DEBUG)

lazy val sparkRmmStateEnable: Boolean = get(SPARK_RMM_STATE_ENABLE)

lazy val gpuOomDumpDir: Option[String] = get(GPU_OOM_DUMP_DIR)

lazy val gpuOomMaxRetries: Int = get(GPU_OOM_MAX_RETRIES)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
* Copyright (c) 2021-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,7 @@ import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}

import scala.collection.mutable.ArrayBuffer

import com.nvidia.spark.rapids.jni.RmmSpark
import org.apache.commons.lang3.mutable.MutableLong

import org.apache.spark.SparkEnv
Expand Down Expand Up @@ -194,7 +195,9 @@ class RapidsShuffleHeartbeatEndpoint(pluginContext: PluginContext, conf: RapidsC
GpuDeviceManager.wrapThreadFactory(new ThreadFactoryBuilder()
.setNameFormat("rapids-shuffle-hb")
.setDaemon(true)
.build()))
.build(),
() => RmmSpark.associateCurrentThreadWithShuffle(),
() => RmmSpark.removeCurrentThreadAssociation()))

private class InitializeShuffleManager(ctx: PluginContext,
shuffleManager: RapidsShuffleInternalManagerBase) extends Runnable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import scala.collection.mutable

import ai.rapids.cudf.{NvtxColor, NvtxRange}
import com.nvidia.spark.rapids.{Arm, GpuSemaphore, NoopMetric, RapidsBuffer, RapidsBufferHandle, RapidsConf, ShuffleReceivedBufferCatalog}
import com.nvidia.spark.rapids.jni.RmmSpark

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -345,9 +346,14 @@ class RapidsShuffleIterator(

val blockedStart = System.currentTimeMillis()
var result: Option[ShuffleClientResult] = None

result = pollForResult(timeoutSeconds)
RmmSpark.threadCouldBlockOnShuffle()
try {
result = pollForResult(timeoutSeconds)
} finally {
RmmSpark.threadDoneWithShuffle()
}
val blockedTime = System.currentTimeMillis() - blockedStart

result match {
case Some(BufferReceived(handle)) =>
val nvtxRangeAfterGettingBatch = new NvtxRange("RapidsShuffleIterator.gotBatch",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

package com.nvidia.spark.rapids

import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer}
import ai.rapids.cudf.{Cuda, CudfException, DeviceMemoryBuffer}
import org.scalatest.{BeforeAndAfter, FunSuite}

import org.apache.spark.SparkConf
Expand Down Expand Up @@ -54,8 +54,15 @@ class GpuDeviceManagerSuite extends FunSuite with Arm with BeforeAndAfter {
// initial allocation should fit within pool size
withResource(DeviceMemoryBuffer.allocate(allocSize)) { _ =>
assertThrows[OutOfMemoryError] {
// this should exceed the specified pool size
DeviceMemoryBuffer.allocate(allocSize).close()
try {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

debug code?

// this should exceed the specified pool size
DeviceMemoryBuffer.allocate(allocSize).close()
} catch {
case e: CudfException =>
System.err.println(e)
e.printStackTrace(System.err)
throw e
}
}
}
}
Expand Down
Loading