diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 4fd58bc99caf4..a9c5f2669a9ec 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -36,6 +36,7 @@ private[spark] object PythonEvalType {
   val NON_UDF = 0
   val SQL_BATCHED_UDF = 1
   val SQL_PANDAS_UDF = 2
+  val SQL_PANDAS_UDF_STREAM = 3
 }
 
 /**
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 1755d30c60e83..d3651f4b6d01f 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -85,6 +85,7 @@ class PythonEvalType(object):
     NON_UDF = 0
     SQL_BATCHED_UDF = 1
     SQL_PANDAS_UDF = 2
+    SQL_PANDAS_UDF_STREAM = 3
 
 
 class Serializer(object):
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1b3af42c47ad2..03636ac31005e 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3376,6 +3376,16 @@ def test_vectorized_udf_empty_partition(self):
         res = df.select(f(col('id')))
         self.assertEquals(df.collect(), res.collect())
 
+
+@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
+class ArrowStreamVectorizedUDFTests(VectorizedUDFTests):
+
+    @classmethod
+    def setUpClass(cls):
+        VectorizedUDFTests.setUpClass()
+        cls.spark.conf.set("spark.sql.execution.arrow.stream.enable", "true")
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index fd917c400c872..08cf47e1acea5 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -31,7 +31,7 @@
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \
-    BatchedSerializer, ArrowPandasSerializer
+    BatchedSerializer, ArrowPandasSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import toArrowType
 from pyspark import shuffle
 
@@ -98,10 +98,10 @@ def read_single_udf(pickleSer, infile, eval_type):
         else:
             row_func = chain(row_func, f)
     # the last returnType will be the return type of UDF
-    if eval_type == PythonEvalType.SQL_PANDAS_UDF:
-        return arg_offsets, wrap_pandas_udf(row_func, return_type)
-    else:
+    if eval_type == PythonEvalType.SQL_BATCHED_UDF:
         return arg_offsets, wrap_udf(row_func, return_type)
+    else:
+        return arg_offsets, wrap_pandas_udf(row_func, return_type)
 
 
 def read_udfs(pickleSer, infile, eval_type):
@@ -124,6 +124,8 @@ def read_udfs(pickleSer, infile, eval_type):
 
     if eval_type == PythonEvalType.SQL_PANDAS_UDF:
         ser = ArrowPandasSerializer()
+    elif eval_type == PythonEvalType.SQL_PANDAS_UDF_STREAM:
+        ser = ArrowStreamPandasSerializer()
    else:
         ser = BatchedSerializer(PickleSerializer(), 100)
 
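The worker-side change above keys the serializer choice off the new eval type: SQL_PANDAS_UDF_STREAM selects ArrowStreamPandasSerializer, which exchanges record batches with the JVM over Arrow's streaming IPC format (presumably avoiding the per-batch framing used by the existing ArrowPandasSerializer), while read_single_udf now treats every non-SQL_BATCHED_UDF eval type as a pandas UDF so both Arrow paths share wrap_pandas_udf. The snippet below is a minimal, hypothetical sketch of that stream round-trip written directly against pyarrow; it is not the serializer code added by this patch, and the helper names are made up for illustration.

import pyarrow as pa


def write_pandas_stream(series_iter, sink):
    # Write each incoming pandas.Series as one record batch on a single Arrow stream.
    writer = None
    for series in series_iter:
        batch = pa.RecordBatch.from_arrays([pa.Array.from_pandas(series)], ["_0"])
        if writer is None:
            # The stream's schema is fixed by the first batch written.
            writer = pa.RecordBatchStreamWriter(sink, batch.schema)
        writer.write_batch(batch)
    if writer is not None:
        writer.close()


def read_pandas_stream(source):
    # Yield a list of pandas.Series (one per column) for every batch read from the stream.
    reader = pa.RecordBatchStreamReader(source)
    for batch in reader:
        table = pa.Table.from_batches([batch])
        yield [column.to_pandas() for column in table.itercolumns()]

For a quick in-memory check, sink can be a pa.BufferOutputStream() and source a pa.BufferReader over its getvalue().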
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index d00c672487532..dc8707b0b21e3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -925,6 +925,13 @@
       .intConf
       .createWithDefault(10000)
 
+  val ARROW_EXECUTION_STREAM_ENABLE =
+    buildConf("spark.sql.execution.arrow.stream.enable")
+      .internal()
+      .doc("When using Apache Arrow, use Arrow stream protocol if possible.")
+      .booleanConf
+      .createWithDefault(false)
+
   object Deprecated {
     val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
   }
@@ -1203,6 +1210,8 @@ class SQLConf extends Serializable with Logging {
 
   def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)
 
+  def arrowStreamEnable: Boolean = getConf(ARROW_EXECUTION_STREAM_ENABLE)
+
   /** ********************** SQLConf functionality methods ************ */
 
   /** Set Spark SQL configuration properties. */
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
index a469880cb0c14..bc546c7c425b1 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java
@@ -462,6 +462,9 @@ public int numValidRows() {
     return numRows - numRowsFiltered;
   }
 
+  /**
+   * Returns the schema that makes up this batch.
+   */
   public StructType schema() { return schema; }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 4accf54a18232..e3dc63f07b3df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -102,7 +102,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
 
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
-    python.ExtractPythonUDFs,
+    python.ExtractPythonUDFs(sparkSession.sessionState.conf),
     PlanSubqueries(sparkSession),
     new ReorderJoinPredicates,
     EnsureRequirements(sparkSession.sessionState.conf),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala
index 418ea48ca2796..07678144e99b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowStreamEvalPythonExec.scala
@@ -47,7 +47,7 @@ case class ArrowStreamEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute
 
     val columnarBatchIter = new ArrowStreamPythonUDFRunner(
         funcs, conf.arrowMaxRecordsPerBatch, bufferSize, reuseWorker,
-        PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema)
+        PythonEvalType.SQL_PANDAS_UDF_STREAM, argOffsets, schema)
      .compute(iter, context.partitionId(), context)
 
     new Iterator[InternalRow] {
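The new flag is internal and off by default, so the existing Arrow UDF path is unchanged unless it is explicitly enabled; the ArrowStreamVectorizedUDFTests class added in tests.py simply re-runs the whole vectorized UDF suite with the flag turned on. A hedged usage sketch, assuming a SparkSession named spark built from this branch with pandas and pyarrow installed:

from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

# Opt in to the Arrow stream protocol added by this patch (the default is false).
spark.conf.set("spark.sql.execution.arrow.stream.enable", "true")

# A vectorized UDF: receives and returns a pandas.Series per batch.
plus_one = pandas_udf(lambda s: s + 1, LongType())
spark.range(10).select(plus_one(col("id")).alias("plus_one")).show()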
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index fec456d86dbe2..3780fd37e6b66 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Proj
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
+import org.apache.spark.sql.internal.SQLConf
 
 
 /**
@@ -90,7 +91,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
  * This has the limitation that the input to the Python UDF is not allowed include attributes from
  * multiple child operators.
  */
-object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
+case class ExtractPythonUDFs(conf: SQLConf) extends Rule[SparkPlan] with PredicateHelper {
 
   private def hasPythonUDF(e: Expression): Boolean = {
     e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -141,7 +142,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
 
     val evaluation = validUdfs.partition(_.vectorized) match {
       case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty =>
-        ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
+        if (conf.arrowStreamEnable) {
+          ArrowStreamEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
+        } else {
+          ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child)
+        }
       case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty =>
         BatchEvalPythonExec(plainUdfs, child.output ++ resultAttrs, child)
       case _ =>
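With ExtractPythonUDFs now carrying the SQLConf, a batch made up solely of vectorized UDFs is planned as ArrowStreamEvalPythonExec when arrowStreamEnable is set and as ArrowEvalPythonExec otherwise, while batches with no vectorized UDFs still go through BatchEvalPythonExec. A hedged way to observe the switch, again assuming a SparkSession named spark; the exact operator names printed by explain() are an assumption, not output taken from this patch:

from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

plus_one = pandas_udf(lambda s: s + 1, LongType())

# Flag off: the physical plan should contain the existing Arrow evaluation operator.
spark.conf.set("spark.sql.execution.arrow.stream.enable", "false")
spark.range(10).select(plus_one(col("id"))).explain()

# Flag on: the plan should instead contain the new Arrow stream evaluation operator.
spark.conf.set("spark.sql.execution.arrow.stream.enable", "true")
spark.range(10).select(plus_one(col("id"))).explain()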