Refactor EvalPythonExec.
ueshin committed Sep 19, 2017
1 parent d49a3db commit 69112a5
Showing 3 changed files with 230 additions and 229 deletions.
ArrowEvalPythonExec.scala
@@ -17,110 +17,45 @@

package org.apache.spark.sql.execution.python

import java.io.File

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils

import org.apache.spark.sql.types.StructType

/**
* A physical plan that evaluates a [[PythonUDF]],
*/
case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends SparkPlan {

def children: Seq[SparkPlan] = child :: Nil

override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))

private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (chained, children) = collectFunctions(u)
(ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}

protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)

inputRDD.mapPartitions { iter =>
val context = TaskContext.get()

// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
val queue = HybridRowQueue(context.taskMemoryManager(),
new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
context.addTaskCompletionListener { _ =>
queue.close()
}

val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

// flatten all the arguments
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
val argOffsets = inputs.map { input =>
input.map { e =>
if (allInputs.exists(_.semanticEquals(e))) {
allInputs.indexWhere(_.semanticEquals(e))
} else {
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
}
}.toArray
}.toArray
val projection = newMutableProjection(allInputs, child.output)
val schemaIn = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
StructField(s"_$i", dt)
})

// Iterator to construct Arrow payloads. Add rows to queue to join later with the result.
val projectedRowIter = iter.map { inputRow =>
queue.add(inputRow.asInstanceOf[UnsafeRow])
projection(inputRow)
}

val inputIterator = ArrowConverters.toPayloadIterator(
projectedRowIter, schemaIn, conf.arrowMaxRecordsPerBatch, context)
.map(_.asPythonSerializable)

// Output iterator for results from Python.
val outputIterator = new PythonRunner(
pyFuncs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets)
.compute(inputIterator, context.partitionId(), context)

val outputRowIterator = ArrowConverters.fromPayloadIterator(
outputIterator.map(new ArrowPayload(_)), context)

// Verify that the output schema is correct
val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
.map { case (attr, i) => attr.withName(s"_$i") })
assert(schemaOut.equals(outputRowIterator.schema),
s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")

val joined = new JoinedRow
val resultProj = UnsafeProjection.create(output, output)

outputRowIterator.map { outputRow =>
resultProj(joined(queue.remove(), outputRow))
}
}
extends EvalPythonExec(udfs, output, child) {

protected override def evaluate(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuseWorker: Boolean,
argOffsets: Array[Array[Int]],
iter: Iterator[InternalRow],
schema: StructType,
context: TaskContext): Iterator[InternalRow] = {
val inputIterator = ArrowConverters.toPayloadIterator(
iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable)

// Output iterator for results from Python.
val outputIterator = new PythonRunner(
funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets)
.compute(inputIterator, context.partitionId(), context)

val outputRowIterator = ArrowConverters.fromPayloadIterator(
outputIterator.map(new ArrowPayload(_)), context)

// Verify that the output schema is correct
val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex
.map { case (attr, i) => attr.withName(s"_$i") })
assert(schemaOut.equals(outputRowIterator.schema),
s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}")

outputRowIterator
}
}
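
The EvalPythonExec parent referenced above is introduced elsewhere in this commit, in a file whose diff is not expanded on this page. Judging only from the constructor call extends EvalPythonExec(udfs, output, child) and from the signature of the evaluate overrides in the two files shown here, its interface is presumably along the following lines; this is an inferred sketch, not text taken from the diff.

// Inferred sketch of the new base class; the actual EvalPythonExec.scala added by
// this commit may differ in details.
abstract class EvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
  extends SparkPlan {

  // Bookkeeping that both subclasses previously declared themselves.
  def children: Seq[SparkPlan] = child :: Nil

  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))

  // Hook implemented by ArrowEvalPythonExec (Arrow payloads) and BatchEvalPythonExec
  // (pickled batches): given one partition of projected input rows, return the rows
  // produced by the Python workers.
  protected def evaluate(
      funcs: Seq[ChainedPythonFunctions],
      bufferSize: Int,
      reuseWorker: Boolean,
      argOffsets: Array[Array[Int]],
      iter: Iterator[InternalRow],
      schema: StructType,
      context: TaskContext): Iterator[InternalRow]
}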
BatchEvalPythonExec.scala
@@ -17,154 +17,78 @@

package org.apache.spark.sql.execution.python

import java.io.File

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.TaskContext
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils

import org.apache.spark.sql.types.{StructField, StructType}

/**
* A physical plan that evaluates a [[PythonUDF]], one partition of tuples at a time.
*
* Python evaluation works by sending the necessary (projected) input data via a socket to an
* external Python process, and combine the result from the Python process with the original row.
*
* For each row we send to Python, we also put it in a queue first. For each output row from Python,
* we drain the queue to find the original input row. Note that if the Python process is way too
* slow, this could lead to the queue growing unbounded and spill into disk when run out of memory.
*
* Here is a diagram to show how this works:
*
*            Downstream (for parent)
*             /      \
*            /     socket  (output of UDF)
*           /         \
*        RowQueue    Python
*           \         /
*            \     socket  (input of UDF)
*             \       /
*          upstream (from child)
*
* The rows sent to and received from Python are packed into batches (100 rows) and serialized,
* there should be always some rows buffered in the socket or Python process, so the pulling from
* RowQueue ALWAYS happened after pushing into it.
* A physical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)
extends SparkPlan {

def children: Seq[SparkPlan] = child :: Nil

override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))

private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (chained, children) = collectFunctions(u)
(ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(ChainedPythonFunctions(Seq(udf.func)), udf.children)
}
}

protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)

inputRDD.mapPartitions { iter =>
EvaluatePython.registerPicklers() // register pickler for Row

// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
val queue = HybridRowQueue(TaskContext.get().taskMemoryManager(),
new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
TaskContext.get().addTaskCompletionListener({ ctx =>
queue.close()
})

val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

// flatten all the arguments
val allInputs = new ArrayBuffer[Expression]
val dataTypes = new ArrayBuffer[DataType]
val argOffsets = inputs.map { input =>
input.map { e =>
if (allInputs.exists(_.semanticEquals(e))) {
allInputs.indexWhere(_.semanticEquals(e))
} else {
allInputs += e
dataTypes += e.dataType
allInputs.length - 1
}
}.toArray
}.toArray
val projection = newMutableProjection(allInputs, child.output)
val schema = StructType(dataTypes.map(dt => StructField("", dt)))
val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)

// enable memo iff we serialize the row with schema (schema and class should be memorized)
val pickle = new Pickler(needConversion)
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.map { inputRow =>
queue.add(inputRow.asInstanceOf[UnsafeRow])
val row = projection(inputRow)
if (needConversion) {
EvaluatePython.toJava(row, schema)
} else {
// fast path for these types that does not need conversion in Python
val fields = new Array[Any](row.numFields)
var i = 0
while (i < row.numFields) {
val dt = dataTypes(i)
fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
i += 1
}
fields
}
}.grouped(100).map(x => pickle.dumps(x.toArray))

val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(
pyFuncs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
val mutableRow = new GenericInternalRow(1)
val joined = new JoinedRow
val resultType = if (udfs.length == 1) {
udfs.head.dataType
extends EvalPythonExec(udfs, output, child) {

protected override def evaluate(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuseWorker: Boolean,
argOffsets: Array[Array[Int]],
iter: Iterator[InternalRow],
schema: StructType,
context: TaskContext): Iterator[InternalRow] = {
EvaluatePython.registerPicklers() // register pickler for Row

val dataTypes = schema.map(_.dataType)
val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython)

// enable memo iff we serialize the row with schema (schema and class should be memorized)
val pickle = new Pickler(needConversion)
// Input iterator to Python: input rows are grouped so we send them in batches to Python.
// For each row, add it to the queue.
val inputIterator = iter.map { row =>
if (needConversion) {
EvaluatePython.toJava(row, schema)
} else {
StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
}
val resultProj = UnsafeProjection.create(output, output)
outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
val row = if (udfs.length == 1) {
// fast path for single UDF
mutableRow(0) = EvaluatePython.fromJava(result, resultType)
mutableRow
} else {
EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
// fast path for these types that does not need conversion in Python
val fields = new Array[Any](row.numFields)
var i = 0
while (i < row.numFields) {
val dt = dataTypes(i)
fields(i) = EvaluatePython.toJava(row.get(i, dt), dt)
i += 1
}
resultProj(joined(queue.remove(), row))
fields
}
}.grouped(100).map(x => pickle.dumps(x.toArray))

// Output iterator for results from Python.
val outputIterator = new PythonRunner(
funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
val mutableRow = new GenericInternalRow(1)
val resultType = if (udfs.length == 1) {
udfs.head.dataType
} else {
StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
}
outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
if (udfs.length == 1) {
// fast path for single UDF
mutableRow(0) = EvaluatePython.fromJava(result, resultType)
mutableRow
} else {
EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
}
}
}
EvalPythonExec.scala — presumably the third changed file; its diff is not rendered on this page.
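
Although that file's diff is not shown, the bulk of the two doExecute bodies deleted above is identical: buffer each input row in a HybridRowQueue, de-duplicate the UDF arguments into argOffsets, project the inputs, ship them to Python, and join each result back to its original input row, exactly as the removed row-queue diagram describes. That shared flow presumably now lives in EvalPythonExec.doExecute as a template method delegating to the evaluate hook. The sketch below reconstructs it from the code removed above rather than from the new file itself.

// Reconstructed from the deleted doExecute bodies above; the actual EvalPythonExec.scala
// may differ. Sits inside the EvalPythonExec class sketched earlier and reuses the
// collectFunctions helper that both deleted files contained.
protected override def doExecute(): RDD[InternalRow] = {
  val inputRDD = child.execute().map(_.copy())
  val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
  val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true)

  inputRDD.mapPartitions { iter =>
    val context = TaskContext.get()

    // Buffer every input row so it can be joined with the Python output later.
    val queue = HybridRowQueue(context.taskMemoryManager(),
      new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
    context.addTaskCompletionListener { _ => queue.close() }

    val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip

    // Flatten and de-duplicate the UDF arguments, recording each UDF's argument offsets.
    val allInputs = new ArrayBuffer[Expression]
    val dataTypes = new ArrayBuffer[DataType]
    val argOffsets = inputs.map { input =>
      input.map { e =>
        if (allInputs.exists(_.semanticEquals(e))) {
          allInputs.indexWhere(_.semanticEquals(e))
        } else {
          allInputs += e
          dataTypes += e.dataType
          allInputs.length - 1
        }
      }.toArray
    }.toArray
    val projection = newMutableProjection(allInputs, child.output)
    val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
      StructField(s"_$i", dt)
    })

    // Project each row for Python and enqueue the original row for the later join.
    val projectedRowIter = iter.map { inputRow =>
      queue.add(inputRow.asInstanceOf[UnsafeRow])
      projection(inputRow)
    }

    // Subclass-specific part: Arrow batches vs. pickled batches.
    val outputRowIterator = evaluate(
      pyFuncs, bufferSize, reuseWorker, argOffsets, projectedRowIter, schema, context)

    // Re-attach each Python result to its buffered input row.
    val joined = new JoinedRow
    val resultProj = UnsafeProjection.create(output, output)
    outputRowIterator.map { outputRow =>
      resultProj(joined(queue.remove(), outputRow))
    }
  }
}

With this split, ArrowEvalPythonExec and BatchEvalPythonExec differ only in how a partition is serialized to and deserialized from the Python workers, which is exactly what their evaluate overrides above implement.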
