[SPARK-22125][PYSPARK][SQL] Enable Arrow Stream format for vectorized UDF. #19349

Closed
wants to merge 15 commits into from
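For context, here is a minimal sketch of the kind of vectorized (pandas) UDF whose batches the Arrow stream format carries between the JVM and the Python worker. The session setup, column names, and data are illustrative assumptions, not part of this patch; it assumes pyspark with pandas and pyarrow installed and the scalar pandas_udf API (Spark 2.3+).

# Illustrative sketch only: a scalar pandas UDF. Each invocation receives a
# pandas.Series built from one Arrow record batch streamed to the worker.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf

spark = SparkSession.builder.appName("arrow-stream-udf-sketch").getOrCreate()

@pandas_udf("long")  # scalar pandas UDF; return type declared as Spark SQL "long"
def plus_one(v):
    # v is a pandas.Series; the arithmetic is applied to the whole batch at once
    return v + 1

df = spark.range(0, 8)  # single "id" column of type long
df.select(plus_one(col("id")).alias("id_plus_one")).show()

The SQL_PANDAS_UDF eval type in PythonEvalType below corresponds to this kind of UDF, while plain row-at-a-time Python UDFs use SQL_BATCHED_UDF.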
325 changes: 4 additions & 321 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -48,7 +48,7 @@ private[spark] class PythonRDD(
extends RDD[Array[Byte]](parent) {

val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true)

override def getPartitions: Array[Partition] = firstParent.partitions

@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = PythonRunner(func, bufferSize, reuse_worker)
val runner = PythonRunner(func, bufferSize, reuseWorker)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -83,318 +83,9 @@ private[spark] case class PythonFunction(
*/
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

/**
* Enumerate the type of command that will be sent to the Python worker
*/
private[spark] object PythonEvalType {
val NON_UDF = 0
val SQL_BATCHED_UDF = 1
val SQL_PANDAS_UDF = 2
}

private[spark] object PythonRunner {
def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(
Seq(ChainedPythonFunctions(Seq(func))),
bufferSize,
reuse_worker,
PythonEvalType.NON_UDF,
Array(Array(0)))
}
}

/**
* A helper class to run Python mapPartition/UDFs in Spark.
*
* funcs is a list of independent Python functions, each of which is a list of chained Python
* functions (from bottom to top).
*/
private[spark] class PythonRunner(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
evalType: Int,
argOffsets: Array[Array[Int]])
extends Logging {

require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")

// All the Python functions should have the same exec, version and envvars.
private val envVars = funcs.head.funcs.head.envVars
private val pythonExec = funcs.head.funcs.head.pythonExec
private val pythonVer = funcs.head.funcs.head.pythonVer

// TODO: support accumulator in multiple UDF
private val accumulator = funcs.head.funcs.head.accumulator

def compute(
inputIterator: Iterator[_],
partitionIndex: Int,
context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
if (reuse_worker) {
envVars.put("SPARK_REUSE_WORKER", "1")
}
val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
// Whether the worker has been released into the idle pool
@volatile var released = false

// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context)

context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
if (!reuse_worker || !released) {
try {
worker.close()
} catch {
case e: Exception =>
logWarning("Failed to close worker socket", e)
}
}
}

writerThread.start()
new MonitorThread(env, worker, context).start()

// Return an iterator that reads lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
val stdoutIterator = new Iterator[Array[Byte]] {
override def next(): Array[Byte] = {
val obj = _nextObj
if (hasNext) {
_nextObj = read()
}
obj
}

private def read(): Array[Byte] = {
if (writerThread.exception.isDefined) {
throw writerThread.exception.get
}
try {
stream.readInt() match {
case length if length > 0 =>
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
case 0 => Array.empty[Byte]
case SpecialLengths.TIMING_DATA =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
val finishTime = stream.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj, StandardCharsets.UTF_8),
writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
// read some accumulator updates:
val numAccumulatorUpdates = stream.readInt()
(1 to numAccumulatorUpdates).foreach { _ =>
val updateLen = stream.readInt()
val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator.add(update)
}
// Check whether the worker is ready to be re-used.
if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
if (reuse_worker) {
env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
released = true
}
}
null
}
} catch {

case e: Exception if context.isInterrupted =>
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason"))

case e: Exception if env.isStopped =>
logDebug("Exception thrown after context is stopped", e)
null // exit silently

case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
logError("This may have been caused by a prior exception:", writerThread.exception.get)
throw writerThread.exception.get

case eof: EOFException =>
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
}

var _nextObj = read()

override def hasNext: Boolean = _nextObj != null
}
new InterruptibleIterator(context, stdoutIterator)
}

/**
* The thread responsible for writing the data from the PythonRDD's parent iterator to the
* Python process.
*/
class WriterThread(
env: SparkEnv,
worker: Socket,
inputIterator: Iterator[_],
partitionIndex: Int,
context: TaskContext)
extends Thread(s"stdout writer for $pythonExec") {

@volatile private var _exception: Exception = null

private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))

setDaemon(true)

/** Contains the exception thrown while writing the parent iterator to the Python process. */
def exception: Option[Exception] = Option(_exception)

/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
def shutdownOnTaskCompletion() {
assert(context.isCompleted)
this.interrupt()
}

override def run(): Unit = Utils.logUncaughtExceptions {
try {
TaskContext.setTaskContext(context)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
// Partition index
dataOut.writeInt(partitionIndex)
// Python version of driver
PythonRDD.writeUTF(pythonVer, dataOut)
// Write out the TaskContextInfo
dataOut.writeInt(context.stageId())
dataOut.writeInt(context.partitionId())
dataOut.writeInt(context.attemptNumber())
dataOut.writeLong(context.taskAttemptId())
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.size)
for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
val cnt = toRemove.size + newBids.diff(oldBids).size
dataOut.writeInt(cnt)
for (bid <- toRemove) {
// remove the broadcast from worker
dataOut.writeLong(- bid - 1) // bid >= 0
oldBids.remove(bid)
}
for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
PythonRDD.writeUTF(broadcast.value.path, dataOut)
oldBids.add(broadcast.id)
}
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(evalType)
if (evalType != PythonEvalType.NON_UDF) {
dataOut.writeInt(funcs.length)
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
offsets.foreach { offset =>
dataOut.writeInt(offset)
}
dataOut.writeInt(chained.funcs.length)
chained.funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
}
}
} else {
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
}
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
if (!worker.isClosed) {
Utils.tryLog(worker.shutdownOutput())
}

case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
if (!worker.isClosed) {
Utils.tryLog(worker.shutdownOutput())
}
}
}
}

/**
* It is necessary to have a monitor thread for python workers if the user cancels with
* interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
* threads can block indefinitely.
*/
class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
extends Thread(s"Worker Monitor for $pythonExec") {

setDaemon(true)

override def run() {
// Kill the worker if it is interrupted, checking until task completion.
// TODO: This has a race condition if interruption occurs, as completed may still become true.
while (!context.isInterrupted && !context.isCompleted) {
Thread.sleep(2000)
}
if (!context.isCompleted) {
try {
logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
} catch {
case e: Exception =>
logError("Exception when trying to kill worker", e)
}
}
}
}
}

/** Thrown for exceptions in user Python code. */
private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause)
private[spark] class PythonException(msg: String, cause: Exception)
extends RuntimeException(msg, cause)

/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
@@ -411,14 +102,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte]
val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}

private object SpecialLengths {
val END_OF_DATA_SECTION = -1
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
val END_OF_STREAM = -4
val NULL = -5
}

private[spark] object PythonRDD extends Logging {

// remember the broadcasts sent to each worker