Skip to content

Commit

Permalink
[SPARK-33277][PYSPARK][SQL] Writer thread must not access input after…
Browse files Browse the repository at this point in the history
… task completion listener returns

### What changes were proposed in this pull request?

Python UDFs in Spark SQL are run in a separate Python process. The Python process is fed input by a dedicated thread (`BasePythonRunner.WriterThread`). This writer thread drives the child plan by pulling rows from its output iterator and serializing them across a socket.

When the child exec node is the off-heap vectorized Parquet reader, these rows are backed by off-heap memory. The child node uses a task completion listener to free the off-heap memory at the end of the task, which invalidates the output iterator and any rows it has produced. Since task completion listeners are registered bottom-up and executed in reverse order of registration, this is safe as long as an exec node never accesses its input after its task completion listener has executed.[^1]

The BasePythonRunner task completion listener violates this assumption. It interrupts the writer thread, but does not wait for it to exit. This causes a race condition that can lead to an executor crash:
1. The Python writer thread is processing a row backed by off-heap memory.
2. The task finishes, for example because it has reached a row limit.
3. The BasePythonRunner task completion listener sets the interrupt status of the writer thread, but the writer thread does not check it immediately.
4. The child plan's task completion listener frees its off-heap memory, invalidating the row that the Python writer thread is processing.
5. The Python writer thread attempts to access the invalidated row. The use-after-free triggers a segfault that crashes the executor.

This PR fixes the bug by making the BasePythonRunner task completion listener wait for the writer thread to exit before returning. This prevents its input from being invalidated while the thread is running. The sequence of events is now as follows:
1. The Python writer thread is processing a row backed by off-heap memory.
2. The task finishes, for example because it has reached a row limit.
3. The BasePythonRunner task completion listener sets the interrupt status of the writer thread and waits for the writer thread to exit.
4. The child plan's task completion listener can safely free its off-heap memory without invalidating live rows.

TaskContextImpl previously held a lock while invoking the task completion listeners. This would now cause a deadlock because the writer thread's exception handler calls `TaskContextImpl#isCompleted()`, which needs to acquire the same lock. To avoid deadlock, this PR modifies TaskContextImpl to release the lock before invoking the listeners, while still maintaining sequential execution of listeners.

[^1]: This guarantee was not historically recognized, leading to similar bugs as far back as 2014 ([SPARK-1019](https://issues.apache.org/jira/browse/SPARK-1019?focusedCommentId=13953661&page=com.atlassian.jira.plugin.system.issuetabpanels%3Acomment-tabpanel#comment-13953661)). The root cause was the lack of a reliably-ordered mechanism for operators to free resources at the end of a task. Such a mechanism (task completion listeners) was added and gradually refined, and we can now make this guarantee explicit. (An alternative approach is to use closeable iterators everywhere, but this would be a major change.)

### Why are the changes needed?

Without this PR, attempting to use Python UDFs while the off-heap vectorized Parquet reader is enabled (`spark.sql.columnVector.offheap.enabled true`) can cause executors to segfault.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

A [previous PR](apache#30177) reduced the likelihood of encountering this race condition, but did not eliminate it. The accompanying tests were therefore flaky and had to be disabled. This PR eliminates the race condition, allowing us to re-enable these tests. One of the tests, `test_pandas_udf_scalar`, previously failed 30/1000 times and now always succeeds.

An internal workload previously failed with a segfault about 40% of the time when run with `spark.sql.columnVector.offheap.enabled true`, and now succeeds 100% of the time.

Closes apache#34245 from ankurdave/SPARK-33277-thread-join.

Authored-by: Ankur Dave <ankurdave@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit dfca1d1)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
ankurdave authored and catalinii committed Mar 4, 2022
1 parent 2f94758 commit f54aafa
Show file tree
Hide file tree
Showing 7 changed files with 299 additions and 43 deletions.
5 changes: 5 additions & 0 deletions core/src/main/scala/org/apache/spark/TaskContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ abstract class TaskContext extends Serializable {
* This will be called in all situations - success, failure, or cancellation. Adding a listener
* to an already completed task will result in that listener being called immediately.
*
* Two listeners registered in the same thread will be invoked in reverse order of registration if
* the task completes after both are registered. There are no ordering guarantees for listeners
* registered in different threads, or for listeners registered after the task completes.
* Listeners are guaranteed to execute sequentially.
*
* An example use is for HadoopRDD to register a callback to close the input stream.
*
* Exceptions thrown by the listener will result in failure of the task.
Expand Down
144 changes: 102 additions & 42 deletions core/src/main/scala/org/apache/spark/TaskContextImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark

import java.util.Properties
import java.util.{Properties, Stack}
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
Expand All @@ -39,9 +39,9 @@ import org.apache.spark.util._
* A small note on thread safety. The interrupted & fetchFailed fields are volatile, this makes
* sure that updates are always visible across threads. The complete & failed flags and their
* callbacks are protected by locking on the context instance. For instance, this ensures
* that you cannot add a completion listener in one thread while we are completing (and calling
* the completion listeners) in another thread. Other state is immutable, however the exposed
* `TaskMetrics` & `MetricsSystem` objects are not thread safe.
* that you cannot add a completion listener in one thread while we are completing in another
* thread. Other state is immutable, however the exposed `TaskMetrics` & `MetricsSystem` objects are
* not thread safe.
*/
private[spark] class TaskContextImpl(
override val stageId: Int,
Expand All @@ -58,81 +58,141 @@ private[spark] class TaskContextImpl(
extends TaskContext
with Logging {

/** List of callback functions to execute when the task completes. */
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
/**
* List of callback functions to execute when the task completes.
*
* Using a stack causes us to process listeners in reverse order of registration. As listeners are
* invoked, they are popped from the stack.
*/
@transient private val onCompleteCallbacks = new Stack[TaskCompletionListener]

/** List of callback functions to execute when the task fails. */
@transient private val onFailureCallbacks = new ArrayBuffer[TaskFailureListener]
@transient private val onFailureCallbacks = new Stack[TaskFailureListener]

/**
* The thread currently executing task completion or failure listeners, if any.
*
* `invokeListeners()` uses this to ensure listeners are called sequentially.
*/
@transient private var listenerInvocationThread: Option[Thread] = None

// If defined, the corresponding task has been killed and this option contains the reason.
@volatile private var reasonIfKilled: Option[String] = None

// Whether the task has completed.
private var completed: Boolean = false

// Whether the task has failed.
private var failed: Boolean = false

// Throwable that caused the task to fail
private var failure: Throwable = _
// If defined, the task has failed and this option contains the Throwable that caused the task to
// fail.
private var failureCauseOpt: Option[Throwable] = None

// If there was a fetch failure in the task, we store it here, to make sure user-code doesn't
// hide the exception. See SPARK-19276
@volatile private var _fetchFailedException: Option[FetchFailedException] = None

@GuardedBy("this")
override def addTaskCompletionListener(listener: TaskCompletionListener)
: this.type = synchronized {
if (completed) {
listener.onTaskCompletion(this)
} else {
onCompleteCallbacks += listener
override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = {
val needToCallListener = synchronized {
// If there is already a thread invoking listeners, adding the new listener to
// `onCompleteCallbacks` will cause that thread to execute the new listener, and the call to
// `invokeTaskCompletionListeners()` below will be a no-op.
//
// If there is no such thread, the call to `invokeTaskCompletionListeners()` below will
// execute all listeners, including the new listener.
onCompleteCallbacks.push(listener)
completed
}
if (needToCallListener) {
invokeTaskCompletionListeners(None)
}
this
}

@GuardedBy("this")
override def addTaskFailureListener(listener: TaskFailureListener)
: this.type = synchronized {
if (failed) {
listener.onTaskFailure(this, failure)
} else {
onFailureCallbacks += listener
}
override def addTaskFailureListener(listener: TaskFailureListener): this.type = {
synchronized {
onFailureCallbacks.push(listener)
failureCauseOpt
}.foreach(invokeTaskFailureListeners)
this
}

override def resourcesJMap(): java.util.Map[String, ResourceInformation] = {
resources.asJava
}

@GuardedBy("this")
private[spark] override def markTaskFailed(error: Throwable): Unit = synchronized {
if (failed) return
failed = true
failure = error
invokeListeners(onFailureCallbacks.toSeq, "TaskFailureListener", Option(error)) {
_.onTaskFailure(this, error)
private[spark] override def markTaskFailed(error: Throwable): Unit = {
synchronized {
if (failureCauseOpt.isDefined) return
failureCauseOpt = Some(error)
}
invokeTaskFailureListeners(error)
}

@GuardedBy("this")
private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = synchronized {
if (completed) return
completed = true
invokeListeners(onCompleteCallbacks.toSeq, "TaskCompletionListener", error) {
private[spark] override def markTaskCompleted(error: Option[Throwable]): Unit = {
synchronized {
if (completed) return
completed = true
}
invokeTaskCompletionListeners(error)
}

private def invokeTaskCompletionListeners(error: Option[Throwable]): Unit = {
// It is safe to access the reference to `onCompleteCallbacks` without holding the TaskContext
// lock. `invokeListeners()` acquires the lock before accessing the contents.
invokeListeners(onCompleteCallbacks, "TaskCompletionListener", error) {
_.onTaskCompletion(this)
}
}

private def invokeTaskFailureListeners(error: Throwable): Unit = {
// It is safe to access the reference to `onFailureCallbacks` without holding the TaskContext
// lock. `invokeListeners()` acquires the lock before accessing the contents.
invokeListeners(onFailureCallbacks, "TaskFailureListener", Option(error)) {
_.onTaskFailure(this, error)
}
}

private def invokeListeners[T](
listeners: Seq[T],
listeners: Stack[T],
name: String,
error: Option[Throwable])(
callback: T => Unit): Unit = {
// This method is subject to two constraints:
//
// 1. Listeners must be run sequentially to uphold the guarantee provided by the TaskContext
// API.
//
// 2. Listeners may spawn threads that call methods on this TaskContext. To avoid deadlock, we
// cannot call listeners while holding the TaskContext lock.
//
// We meet these constraints by ensuring there is at most one thread invoking listeners at any
// point in time.
synchronized {
if (listenerInvocationThread.nonEmpty) {
// If another thread is already invoking listeners, do nothing.
return
} else {
// If no other thread is invoking listeners, register this thread as the listener invocation
// thread. This prevents other threads from invoking listeners until this thread is
// deregistered.
listenerInvocationThread = Some(Thread.currentThread())
}
}

def getNextListenerOrDeregisterThread(): Option[T] = synchronized {
if (listeners.empty()) {
// We have executed all listeners that have been added so far. Deregister this thread as the
// callback invocation thread.
listenerInvocationThread = None
None
} else {
Some(listeners.pop())
}
}

val errorMsgs = new ArrayBuffer[String](2)
// Process callbacks in the reverse order of registration
listeners.reverse.foreach { listener =>
var listenerOption: Option[T] = None
while ({listenerOption = getNextListenerOrDeregisterThread(); listenerOption.nonEmpty}) {
val listener = listenerOption.get
try {
callback(listener)
} catch {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,20 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
/** Contains the throwable thrown while writing the parent iterator to the Python process. */
def exception: Option[Throwable] = Option(_exception)

/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
/**
* Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur
* due to cleanup.
*/
def shutdownOnTaskCompletion(): Unit = {
assert(context.isCompleted)
this.interrupt()
// Task completion listeners that run after this method returns may invalidate
// `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized
// reader, a task completion listener will free the underlying off-heap buffers. If the writer
// thread is still running when `inputIterator` is invalidated, it can cause a use-after-free
// bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer
// thread to exit before returning.
this.join()
}

/**
Expand Down
121 changes: 121 additions & 0 deletions core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
package org.apache.spark.scheduler

import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.ArrayBuffer

import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito._
Expand Down Expand Up @@ -334,6 +337,124 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark
assert(e.getMessage.contains("exception in task"))
}

test("listener registers another listener (reentrancy)") {
val context = TaskContext.empty()
var invocations = 0
val simpleListener = new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
invocations += 1
}
}

// Create a listener that registers another listener.
val reentrantListener = new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
context.addTaskCompletionListener(simpleListener)
invocations += 1
}
}
context.addTaskCompletionListener(reentrantListener)

// Ensure the listener can execute without encountering deadlock.
assert(invocations == 0)
context.markTaskCompleted(None)
assert(invocations == 2)
}

test("listener registers another listener using a second thread") {
val context = TaskContext.empty()
val invocations = new AtomicInteger(0)
val simpleListener = new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
invocations.getAndIncrement()
}
}

// Create a listener that registers another listener using a second thread.
val multithreadedListener = new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
val thread = new Thread(new Runnable {
override def run(): Unit = {
context.addTaskCompletionListener(simpleListener)
}
})
thread.start()
invocations.getAndIncrement()
thread.join()
}
}
context.addTaskCompletionListener(multithreadedListener)

// Ensure the listener can execute without encountering deadlock.
assert(invocations.get() == 0)
context.markTaskCompleted(None)
assert(invocations.get() == 2)
}

test("listeners registered from different threads are called sequentially") {
val context = TaskContext.empty()
val invocations = new AtomicInteger(0)
val numRunningListeners = new AtomicInteger(0)

// Create a listener that will throw if more than one instance is running at the same time.
val registerExclusiveListener = new Runnable {
override def run(): Unit = {
context.addTaskCompletionListener(new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
if (numRunningListeners.getAndIncrement() != 0) throw new Exception()
Thread.sleep(100)
if (numRunningListeners.decrementAndGet() != 0) throw new Exception()
invocations.getAndIncrement()
}
})
}
}

// Register it multiple times from different threads before and after the task completes.
assert(invocations.get() == 0)
assert(numRunningListeners.get() == 0)
val thread1 = new Thread(registerExclusiveListener)
val thread2 = new Thread(registerExclusiveListener)
thread1.start()
thread2.start()
thread1.join()
thread2.join()
assert(invocations.get() == 0)
context.markTaskCompleted(None)
assert(invocations.get() == 2)
val thread3 = new Thread(registerExclusiveListener)
val thread4 = new Thread(registerExclusiveListener)
thread3.start()
thread4.start()
thread3.join()
thread4.join()
assert(invocations.get() == 4)
assert(numRunningListeners.get() == 0)
}

test("listeners registered from same thread are called in reverse order") {
val context = TaskContext.empty()
val invocationOrder = ArrayBuffer.empty[String]

// Create listeners that log an id to `invocationOrder` when they are invoked.
def makeLoggingListener(id: String): TaskCompletionListener = new TaskCompletionListener {
override def onTaskCompletion(context: TaskContext): Unit = {
invocationOrder += id
}
}
context.addTaskCompletionListener(makeLoggingListener("A"))
context.addTaskCompletionListener(makeLoggingListener("B"))
context.addTaskCompletionListener(makeLoggingListener("C"))

// Ensure the listeners are called in reverse order of registration, except when they are called
// after the task is complete.
assert(invocationOrder === Seq.empty)
context.markTaskCompleted(None)
assert(invocationOrder === Seq("C", "B", "A"))
context.addTaskCompletionListener(makeLoggingListener("D"))
assert(invocationOrder === Seq("C", "B", "A", "D"))
}

}

private object TaskContextSuite {
Expand Down
Loading

0 comments on commit f54aafa

Please sign in to comment.