diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 3055225fdafed..536ef552519e1 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -315,6 +315,7 @@ def test_chained_udf(self): self.assertEqual(row[0], 6) def test_multiple_udfs(self): + self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType()) [row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect() self.assertEqual(tuple(row), (2, 4)) [row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 9ce754aa893cc..b7bb1edc40890 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -91,8 +91,9 @@ def read_udfs(pickleSer, infile): if num_udfs == 1: udf = udfs[0][2] + # fast path for single UDF def mapper(args): - return (udf(*args),) + return udf(*args) else: def mapper(args): return tuple(udf(*args[start:end]) for start, end, udf in udfs) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index 8c642e026fb7c..a9c8bd4f6752b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -71,7 +71,6 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c val (pyFuncs, children) = udfs.map(collectFunctions).unzip val numArgs = children.map(_.length) - val resultType = StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) val pickle = new Pickler // flatten all the arguments @@ -97,15 +96,26 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c .compute(inputIterator, context.partitionId(), context) val unpickle = new Unpickler - val row = new GenericMutableRow(1) + val mutableRow = new GenericMutableRow(1) val joined = new JoinedRow + val resultType = if (udfs.length == 1) { + udfs.head.dataType + } else { + StructType(udfs.map(u => StructField("", u.dataType, u.nullable))) + } val resultProj = UnsafeProjection.create(output, output) outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => - val row = EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + val row = if (udfs.length == 1) { + // fast path for single UDF + mutableRow(0) = EvaluatePython.fromJava(result, resultType) + mutableRow + } else { + EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow] + } resultProj(joined(queue.poll(), row)) } }