Enable vectorized UDF via Arrow stream protocol.
ueshin committed Sep 26, 2017
1 parent 8016721 commit e62d619
Showing 9 changed files with 39 additions and 8 deletions.
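
For orientation, here is a minimal, hedged sketch of what the change enables end to end. It is not part of the diff; it assumes a build containing this commit, pandas and pyarrow on the Python workers, and the pandas_udf API of the Spark 2.3 development line (plus_one is an illustrative name):

    # Run a vectorized (pandas) UDF with the new Arrow stream protocol enabled
    # via the internal conf added below in SQLConf.
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, pandas_udf
    from pyspark.sql.types import LongType

    spark = SparkSession.builder.getOrCreate()
    spark.conf.set("spark.sql.execution.arrow.stream.enable", "true")

    # The lambda receives a whole batch of values as a pandas.Series.
    plus_one = pandas_udf(lambda v: v + 1, LongType())

    spark.range(0, 10).select(plus_one(col('id'))).show()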
@@ -36,6 +36,7 @@ private[spark] object PythonEvalType {
  val NON_UDF = 0
  val SQL_BATCHED_UDF = 1
  val SQL_PANDAS_UDF = 2
  val SQL_PANDAS_UDF_STREAM = 3
}

/**
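The JVM writes this eval type as a single int at the start of the control stream, and the Python worker reads it back before deciding how to deserialize the UDF data. A rough sketch of that handshake from the Python side (read_int and PythonEvalType are the real pyspark.serializers names; handshake_eval_type is illustrative):

    from pyspark.serializers import read_int, PythonEvalType

    def handshake_eval_type(infile):
        # The eval type arrives first; the two Arrow-based values select a
        # pandas/Arrow serializer instead of pickled rows (see worker.py below).
        eval_type = read_int(infile)
        uses_arrow = eval_type in (PythonEvalType.SQL_PANDAS_UDF,
                                   PythonEvalType.SQL_PANDAS_UDF_STREAM)
        return eval_type, uses_arrow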
1 change: 1 addition & 0 deletions python/pyspark/serializers.py
@@ -85,6 +85,7 @@ class PythonEvalType(object):
    NON_UDF = 0
    SQL_BATCHED_UDF = 1
    SQL_PANDAS_UDF = 2
    SQL_PANDAS_UDF_STREAM = 3


class Serializer(object):
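The new ArrowStreamPandasSerializer referenced below is what makes this the "stream" path: the schema is sent once up front and every record batch follows in one continuous Arrow stream, instead of each batch being framed as an independent Arrow payload. A hedged sketch of the writer side using plain pyarrow (the function name and column naming are illustrative, not the serializer from this commit):

    import pyarrow as pa

    def write_pandas_as_arrow_stream(series_batches, sink):
        # series_batches: iterator of lists of pandas.Series, one list per batch.
        writer = None
        for series_list in series_batches:
            arrays = [pa.Array.from_pandas(s) for s in series_list]
            batch = pa.RecordBatch.from_arrays(
                arrays, ["_%d" % i for i in range(len(arrays))])
            if writer is None:
                # Stream protocol: write the schema once, then the batches.
                writer = pa.RecordBatchStreamWriter(sink, batch.schema)
            writer.write_batch(batch)
        if writer is not None:
            writer.close()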
10 changes: 10 additions & 0 deletions python/pyspark/sql/tests.py
@@ -3376,6 +3376,16 @@ def test_vectorized_udf_empty_partition(self):
        res = df.select(f(col('id')))
        self.assertEquals(df.collect(), res.collect())


@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
class ArrowStreamVectorizedUDFTests(VectorizedUDFTests):

    @classmethod
    def setUpClass(cls):
        VectorizedUDFTests.setUpClass()
        cls.spark.conf.set("spark.sql.execution.arrow.stream.enable", "true")


if __name__ == "__main__":
    from pyspark.sql.tests import *
    if xmlrunner:
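The subclass simply reruns every test in VectorizedUDFTests with the stream protocol switched on. A hypothetical companion cleanup, not in this commit, would reset the flag so later suites are unaffected:

    # Illustrative only; assumes the same test class as above.
    @classmethod
    def tearDownClass(cls):
        cls.spark.conf.unset("spark.sql.execution.arrow.stream.enable")
        VectorizedUDFTests.tearDownClass()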
10 changes: 6 additions & 4 deletions python/pyspark/worker.py
@@ -31,7 +31,7 @@
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
    write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
-    BatchedSerializer, ArrowPandasSerializer
+    BatchedSerializer, ArrowPandasSerializer, ArrowStreamPandasSerializer
from pyspark.sql.types import toArrowType
from pyspark import shuffle

@@ -98,10 +98,10 @@ def read_single_udf(pickleSer, infile, eval_type):
        else:
            row_func = chain(row_func, f)
    # the last returnType will be the return type of UDF
-    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
-        return arg_offsets, wrap_pandas_udf(row_func, return_type)
-    else:
+    if eval_type == PythonEvalType.SQL_BATCHED_UDF:
        return arg_offsets, wrap_udf(row_func, return_type)
+    else:
+        return arg_offsets, wrap_pandas_udf(row_func, return_type)


def read_udfs(pickleSer, infile, eval_type):
@@ -124,6 +124,8 @@ def read_udfs(pickleSer, infile, eval_type):

    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
        ser = ArrowPandasSerializer()
+    elif eval_type == PythonEvalType.SQL_PANDAS_UDF_STREAM:
+        ser = ArrowStreamPandasSerializer()
    else:
        ser = BatchedSerializer(PickleSerializer(), 100)

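With this change read_single_udf treats SQL_BATCHED_UDF as the special case and falls back to the pandas wrapper for both Arrow eval types. Roughly, that wrapper pairs the user function's pandas.Series result with its Arrow type and checks the result length; a simplified sketch (the real wrap_pandas_udf is defined earlier in worker.py, and the exact checks there may differ):

    from pyspark.sql.types import toArrowType  # helper imported by worker.py above

    def wrap_pandas_udf_sketch(f, return_type):
        arrow_return_type = toArrowType(return_type)

        def verify_result_length(*series):
            result = f(*series)
            if len(result) != len(series[0]):
                raise RuntimeError(
                    "Result vector from pandas_udf was not the required length: "
                    "expected %d, got %d" % (len(series[0]), len(result)))
            return result

        # The serializer uses the Arrow type to encode the returned column.
        return lambda *series: (verify_result_length(*series), arrow_return_type)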
@@ -925,6 +925,13 @@ object SQLConf {
      .intConf
      .createWithDefault(10000)

  val ARROW_EXECUTION_STREAM_ENABLE =
    buildConf("spark.sql.execution.arrow.stream.enable")
      .internal()
      .doc("When using Apache Arrow, use Arrow stream protocol if possible.")
      .booleanConf
      .createWithDefault(false)

  object Deprecated {
    val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
  }
@@ -1203,6 +1210,8 @@ class SQLConf extends Serializable with Logging {

  def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)

  def arrowStreamEnable: Boolean = getConf(ARROW_EXECUTION_STREAM_ENABLE)

  /** ********************** SQLConf functionality methods ************ */

  /** Set Spark SQL configuration properties. */
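Because the entry is internal and defaults to false, the stream path is strictly opt-in. It can be set like any other SQL conf, for example from PySpark:

    from pyspark.sql import SparkSession

    # At session construction...
    spark = (SparkSession.builder
             .config("spark.sql.execution.arrow.stream.enable", "true")
             .getOrCreate())

    # ...or toggled at runtime; reading it back returns the string form.
    spark.conf.set("spark.sql.execution.arrow.stream.enable", "false")
    print(spark.conf.get("spark.sql.execution.arrow.stream.enable"))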
@@ -462,6 +462,9 @@ public int numValidRows() {
    return numRows - numRowsFiltered;
  }

  /**
   * Returns the schema that makes up this batch.
   */
  public StructType schema() { return schema; }

  /**
@@ -102,7 +102,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {

  /** A sequence of rules that will be applied in order to the physical plan before execution. */
  protected def preparations: Seq[Rule[SparkPlan]] = Seq(
-    python.ExtractPythonUDFs,
+    python.ExtractPythonUDFs(sparkSession.sessionState.conf),
    PlanSubqueries(sparkSession),
    new ReorderJoinPredicates,
    EnsureRequirements(sparkSession.sessionState.conf),
@@ -47,7 +47,7 @@ case class ArrowStreamEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute

    val columnarBatchIter = new ArrowStreamPythonUDFRunner(
      funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker,
-      PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema)
+      PythonEvalType.SQL_PANDAS_UDF_STREAM, argOffsets, schema)
      .compute(iter, context.partitionId(), context)

    new Iterator[InternalRow] {
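On the Python side, the SQL_PANDAS_UDF_STREAM eval type means the worker receives one continuous Arrow stream rather than a sequence of self-contained Arrow payloads. A hedged reader-side sketch with plain pyarrow (read_arrow_stream_as_pandas is illustrative, not code from this commit):

    import pyarrow as pa

    def read_arrow_stream_as_pandas(source):
        # Consume the schema message and every record batch in the stream,
        # then hand each column back as a pandas object.
        reader = pa.RecordBatchStreamReader(source)
        table = reader.read_all()
        return [column.to_pandas() for column in table.columns]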
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
import org.apache.spark.sql.internal.SQLConf


/**
@@ -90,7 +91,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 * This has the limitation that the input to the Python UDF is not allowed to include attributes from
 * multiple child operators.
 */
-object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
+case class ExtractPythonUDFs(conf: SQLConf) extends Rule[SparkPlan] with PredicateHelper {

  private def hasPythonUDF(e: Expression): Boolean = {
    e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -141,7 +142,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {

      val evaluation = validUdfs.partition(_.vectorized) match {
        case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
-          ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
+          if (conf.arrowStreamEnable) {
+            ArrowStreamEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
+          } else {
+            ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
+          }
        case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
          BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
        case _ =>
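A quick, hedged way to confirm which operator the planner picks once the flag is set (the operator names come from the Scala classes touched here; the exact text printed by explain() may vary):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, pandas_udf
    from pyspark.sql.types import LongType

    spark = SparkSession.builder.getOrCreate()
    plus_one = pandas_udf(lambda v: v + 1, LongType())  # illustrative UDF

    spark.conf.set("spark.sql.execution.arrow.stream.enable", "true")
    spark.range(5).select(plus_one(col('id'))).explain()
    # expect an ArrowStreamEvalPython operator in the physical plan

    spark.conf.set("spark.sql.execution.arrow.stream.enable", "false")
    spark.range(5).select(plus_one(col('id'))).explain()
    # expect the non-stream ArrowEvalPython operator instead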
