Skip to content

Commit

Permalink
Fix issue where SparkRapidsAdaptor could be shut down but not through…
Browse files Browse the repository at this point in the history
… RmmSpark (#984)

Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
  • Loading branch information
revans2 authored Mar 1, 2023
1 parent 76d7acc commit 7ff429b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/main/java/com/nvidia/spark/rapids/jni/RmmSpark.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public static long getCurrentThreadId() {
*/
public static void associateThreadWithTask(long threadId, long taskId) {
synchronized (Rmm.class) {
if (sra != null) {
if (sra != null && sra.isOpen()) {
sra.associateThreadWithTask(threadId, taskId);
}
}
Expand All @@ -140,7 +140,7 @@ public static void associateCurrentThreadWithTask(long taskId) {
*/
public static void associateThreadWithShuffle(long threadId) {
synchronized (Rmm.class) {
if (sra != null) {
if (sra != null && sra.isOpen()) {
sra.associateThreadWithShuffle(threadId);
}
}
Expand All @@ -160,7 +160,7 @@ public static void associateCurrentThreadWithShuffle() {
*/
public static void removeThreadAssociation(long threadId) {
synchronized (Rmm.class) {
if (sra != null) {
if (sra != null && sra.isOpen()) {
sra.removeThreadAssociation(threadId);
}
}
Expand All @@ -180,7 +180,7 @@ public static void removeCurrentThreadAssociation() {
*/
public static void taskDone(long taskId) {
synchronized (Rmm.class) {
if (sra != null) {
if (sra != null && sra.isOpen()) {
sra.taskDone(taskId);
}
}
Expand All @@ -192,7 +192,7 @@ public static void taskDone(long taskId) {
*/
public static void threadCouldBlockOnShuffle(long threadId) {
synchronized (Rmm.class) {
if (sra != null) {
if (sra != null && sra.isOpen()) {
sra.threadCouldBlockOnShuffle(threadId);
}
}
Expand All @@ -211,7 +211,7 @@ public static void threadCouldBlockOnShuffle() {
*/
public static void threadDoneWithShuffle(long threadId) {
synchronized (Rmm.class) {
if (sra != null) {
if (sra != null && sra.isOpen()) {
sra.threadDoneWithShuffle(threadId);
}
}
Expand Down Expand Up @@ -247,7 +247,7 @@ public static void blockThreadUntilReady() {
// Technically there is a race here, but because this can block we cannot hold the Rmm
// lock while doing this, or we can deadlock. So we are going to rely on Rmm shutting down
// or being reconfigured to be rare.
if (local != null) {
if (local != null && local.isOpen()) {
local.blockThreadUntilReady();
}
}
Expand All @@ -267,7 +267,7 @@ public static void forceRetryOOM(long threadId) {
*/
public static void forceRetryOOM(long threadId, int numOOMs) {
synchronized (Rmm.class) {
if (sra != null) {
if (sra != null && sra.isOpen()) {
sra.forceRetryOOM(threadId, numOOMs);
} else {
throw new IllegalStateException("RMM has not been configured for OOM injection");
Expand All @@ -290,7 +290,7 @@ public static void forceSplitAndRetryOOM(long threadId) {
*/
public static void forceSplitAndRetryOOM(long threadId, int numOOMs) {
synchronized (Rmm.class) {
if (sra != null) {
if (sra != null && sra.isOpen()) {
sra.forceSplitAndRetryOOM(threadId, numOOMs);
} else {
throw new IllegalStateException("RMM has not been configured for OOM injection");
Expand All @@ -315,7 +315,7 @@ public static void forceCudfException(long threadId) {
*/
public static void forceCudfException(long threadId, int numTimes) {
synchronized (Rmm.class) {
if (sra != null) {
if (sra != null && sra.isOpen()) {
sra.forceCudfException(threadId, numTimes);
} else {
throw new IllegalStateException("RMM has not been configured for OOM injection");
Expand All @@ -325,7 +325,7 @@ public static void forceCudfException(long threadId, int numTimes) {

public static RmmSparkThreadState getStateOf(long threadId) {
synchronized (Rmm.class) {
if (sra != null) {
if (sra != null && sra.isOpen()) {
return sra.getStateOf(threadId);
} else {
// sra is not set so the thread is by definition unknown to it.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ public void close() {
super.close();
}


public boolean isOpen() {
return handle != 0;
}

/**
* Associate a thread with a given task id.
* @param threadId the thread ID to use (not java thread id)
Expand Down

0 comments on commit 7ff429b

Please sign in to comment.