Fix a corner case in the SparkResourceAdaptor state machine (#2976)
This fixes NVIDIA/spark-rapids#12158

But there are no real test changes as a part of this because
we are not set up for testing spill in spark-rapids-jni.

---------

Signed-off-by: Robert (Bobby) Evans <bobby@apache.org>
revans2 authored Feb 25, 2025
1 parent 2aaf1b3 commit 458b7da
Showing 2 changed files with 37 additions and 9 deletions.
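Before the diff, here is a minimal, self-contained C++ sketch of the state-machine corner case this commit addresses. The thread_state names and the is_retry_alloc_before_bufn flag mirror the diff below; the struct, the free functions, and the simplified transitions are hypothetical and exist only for illustration.

#include <iostream>

enum class thread_state { THREAD_RUNNING, THREAD_BLOCKED, THREAD_BUFN_THROW };

struct thread_info {
  thread_state state = thread_state::THREAD_RUNNING;
  // Set when the thread gets one more shot at its allocation before
  // being sent to THREAD_BUFN_THROW.
  bool is_retry_alloc_before_bufn = false;
};

// Deadlock breaking: if this is the last thread still blocked, retry the
// allocation instead of throwing split-and-retry right away.
void break_deadlock(thread_info& t, int blocked_thread_count) {
  if (blocked_thread_count == 1) {
    t.is_retry_alloc_before_bufn = true;
    t.state = thread_state::THREAD_RUNNING;  // wake it to retry
  } else {
    t.state = thread_state::THREAD_BUFN_THROW;
  }
}

// If the retried allocation fails with an OOM, now we really go to BUFN.
void on_alloc_failed(thread_info& t, bool is_oom) {
  if (is_oom && t.is_retry_alloc_before_bufn) {
    t.is_retry_alloc_before_bufn = false;
    t.state = thread_state::THREAD_BUFN_THROW;
  }
}

int main() {
  thread_info t;
  t.state = thread_state::THREAD_BLOCKED;
  break_deadlock(t, /*blocked_thread_count=*/1);
  std::cout << (t.state == thread_state::THREAD_RUNNING) << "\n";     // 1
  on_alloc_failed(t, /*is_oom=*/true);
  std::cout << (t.state == thread_state::THREAD_BUFN_THROW) << "\n";  // 1
}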
40 changes: 33 additions & 7 deletions src/main/cpp/src/SparkResourceAdaptorJni.cpp
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
* Copyright (c) 2023-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -285,6 +285,15 @@ class full_thread_state {
bool is_cpu_alloc = false;
// Is the thread transitively blocked on a pool or not.
bool pool_blocked = false;
// We keep track of when memory is freed, which lets us wake up
// blocked threads to make progress. But we do not keep track of
// when buffers are made spillable. This can result in us
// throwing a split and retry exception even if memory was made
// spillable. So, instead of tracking when any buffer is made
// spillable, we retry the allocation before going to the
// BUFN_THROW state. This variable records whether we are in
// the middle of this retry or not.
bool is_retry_alloc_before_bufn = false;

oom_state_type retry_oom;
oom_state_type split_and_retry_oom;
@@ -1361,6 +1370,8 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
// pre allocate checks
auto const thread = threads.find(thread_id);
if (!was_recursive && thread != threads.end()) {
// The allocation succeeded so we are no longer doing a retry
thread->second.is_retry_alloc_before_bufn = false;
switch (thread->second.state) {
case thread_state::THREAD_ALLOC:
// fall through
@@ -1627,10 +1638,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
if (need_to_break_deadlock) {
// Find the task thread with the lowest priority that is not already BUFN
thread_priority to_bufn(-1, -1);
bool is_to_bufn_set = false;
int blocked_thread_count = 0;
for (auto const& [thread_id, t_state] : threads) {
switch (t_state.state) {
case thread_state::THREAD_BLOCKED: {
blocked_thread_count++;
thread_priority const& current = t_state.priority();
if (!is_to_bufn_set || current < to_bufn) {
to_bufn = current;
@@ -1644,11 +1657,19 @@
long const thread_id_to_bufn = to_bufn.get_thread_id();
auto const thread = threads.find(thread_id_to_bufn);
if (thread != threads.end()) {
transition(thread->second, thread_state::THREAD_BUFN_THROW);
if (blocked_thread_count == 1) {
// This is the very last thread that is going to
// transition to BUFN. When that happens the
// thread would throw a split and retry exception.
// But because we are not tracking when data is made
// spillable, we retry the allocation first, in case data
// was made spillable, instead of going straight to BUFN.
thread->second.is_retry_alloc_before_bufn = true;
transition(thread->second, thread_state::THREAD_RUNNING);
} else {
transition(thread->second, thread_state::THREAD_BUFN_THROW);
}
thread->second.wake_condition->notify_all();
// We are explicitly not going to update the state around BUFN
// here, because we really want to wait for the retry to run
// its course instead of doing a split right away.
}
}
// We now need a way to detect if we need to split the input and retry.
@@ -1727,7 +1748,12 @@ class spark_resource_adaptor final : public rmm::mr::device_memory_resource {
transition(thread->second, thread_state::THREAD_RUNNING);
break;
case thread_state::THREAD_ALLOC:
if (is_oom && blocking) {
if (is_oom && thread->second.is_retry_alloc_before_bufn) {
thread->second.is_retry_alloc_before_bufn = false;
transition(thread->second, thread_state::THREAD_BUFN_THROW);
thread->second.wake_condition->notify_all();
} else if (is_oom && blocking) {
thread->second.is_retry_alloc_before_bufn = false;
transition(thread->second, thread_state::THREAD_BLOCKED);
} else {
// don't block unless it is OOM on a blocking allocation
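To see why the extra retry matters, here is a hedged model of the situation it addresses. The type, fields, and byte counts below are made up for illustration and are not the real RMM interface: blocked threads are only woken when memory is freed, but making a buffer spillable frees nothing, so without the retry the last blocked thread could throw a split-and-retry exception even though a spill would have satisfied its allocation.

#include <cstddef>
#include <iostream>

struct device_memory {
  std::size_t free_bytes      = 0;    // nothing free right now
  std::size_t spillable_bytes = 256;  // but 256 bytes could be spilled

  bool try_alloc(std::size_t bytes) {
    if (bytes <= free_bytes) {
      free_bytes -= bytes;
      return true;
    }
    return false;
  }

  // Spilling turns spillable bytes into free bytes. Crucially, in the real
  // adaptor this does not fire the "memory was freed" wakeup that unblocks
  // waiting threads.
  void spill() {
    free_bytes += spillable_bytes;
    spillable_bytes = 0;
  }
};

int main() {
  device_memory mem;
  std::size_t const request = 128;

  // First attempt fails: no free memory, and the adaptor never learns that
  // bytes have since become spillable.
  std::cout << "first attempt: " << mem.try_alloc(request) << "\n";  // 0

  // Another task makes buffers spillable and a spill runs.
  mem.spill();

  // The pre-BUFN retry now succeeds where the old code would already have
  // thrown a split-and-retry exception.
  std::cout << "retry before BUFN: " << mem.try_alloc(request) << "\n";  // 1
}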
6 changes: 4 additions & 2 deletions src/test/java/com/nvidia/spark/rapids/jni/RmmSparkTest.java
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2024, NVIDIA CORPORATION.
* Copyright (c) 2020-2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -1262,7 +1262,9 @@ public void testAllocationDuringSpill() {
RmmSpark.removeDedicatedThreadAssociation(threadId, taskId);
}
});
assertEquals(11, rmmEventHandler.getAllocationCount());
// We retry the failed allocation for the last thread before going into
// the BUFN state. So we have 22 allocations instead of the expected 11
assertEquals(22, rmmEventHandler.getAllocationCount());
}

@Test
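As a quick sanity check on the updated assertion, assuming (for illustration only, this is our reading and not something the test states) that every allocation in this test hits the pre-BUFN retry path exactly once, the handler observes two attempts per logical allocation:

#include <iostream>

int main() {
  int const logical_allocations    = 11;  // count asserted before this fix
  int const attempts_per_allocation = 2;  // original attempt + pre-BUFN retry
  std::cout << logical_allocations * attempts_per_allocation << "\n";  // 22
}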
