Skip to content

Commit

Permalink
fast path for single UDF
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Mar 30, 2016
1 parent f6b7373 commit 8e6e5bc
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
1 change: 1 addition & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def test_chained_udf(self):
self.assertEqual(row[0], 6)

def test_multiple_udfs(self):
self.sqlCtx.registerFunction("double", lambda x: x * 2, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1), double(2)").collect()
self.assertEqual(tuple(row), (2, 4))
[row] = self.sqlCtx.sql("SELECT double(double(1)), double(double(2) + 2)").collect()
Expand Down
3 changes: 2 additions & 1 deletion python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ def read_udfs(pickleSer, infile):
if num_udfs == 1:
udf = udfs[0][2]

# fast path for single UDF
def mapper(args):
return (udf(*args),)
return udf(*args)
else:
def mapper(args):
return tuple(udf(*args[start:end]) for start, end, udf in udfs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c

val (pyFuncs, children) = udfs.map(collectFunctions).unzip
val numArgs = children.map(_.length)
val resultType = StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))

val pickle = new Pickler
// flatten all the arguments
Expand All @@ -97,15 +96,26 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
val row = new GenericMutableRow(1)
val mutableRow = new GenericMutableRow(1)
val joined = new JoinedRow
val resultType = if (udfs.length == 1) {
udfs.head.dataType
} else {
StructType(udfs.map(u => StructField("", u.dataType, u.nullable)))
}
val resultProj = UnsafeProjection.create(output, output)

outputIterator.flatMap { pickedResult =>
val unpickledBatch = unpickle.loads(pickedResult)
unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala
}.map { result =>
val row = EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
val row = if (udfs.length == 1) {
// fast path for single UDF
mutableRow(0) = EvaluatePython.fromJava(result, resultType)
mutableRow
} else {
EvaluatePython.fromJava(result, resultType).asInstanceOf[InternalRow]
}
resultProj(joined(queue.poll(), row))
}
}
Expand Down

0 comments on commit 8e6e5bc

Please sign in to comment.