[SPARK-21190][PYSPARK] Python Vectorized UDFs
This PR adds vectorized UDFs to the Python API

**Proposed API**
Introduce a flag to turn on vectorization for a defined UDF, for example:

```
@pandas_udf(DoubleType())
def plus(a, b):
    return a + b
```
or

```
plus = pandas_udf(lambda a, b: a + b, DoubleType())
```
Usage is the same as normal UDFs.
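For instance, applying `plus` to two columns of a DataFrame might look like the following minimal sketch (it assumes an active `SparkSession` named `spark`; the column names are illustrative):

```
from pyspark.sql.functions import col

df = spark.range(10).select(col("id").alias("a"), col("id").alias("b"))
df.select(plus(col("a"), col("b"))).show()
```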

**0-parameter UDFs**
pandas_udf functions can declare an optional `**kwargs`; when evaluated, it will contain a key "size" that gives the required length of the output. For example:

```
@pandas_udf(LongType())
def f0(**kwargs):
    return pd.Series(1).repeat(kwargs["size"])

df.select(f0())
```

Added new unit tests in pyspark.sql that are enabled if pyarrow and Pandas are available.

- [x] Fix support for promoted types with null values
- [ ] Discuss 0-param UDF API (use of kwargs)
- [x] Add tests for chained UDFs
- [ ] Discuss behavior when pyarrow not installed / enabled
- [ ] Cleanup pydoc and add user docs

Author: Bryan Cutler <cutlerb@gmail.com>
Author: Takuya UESHIN <ueshin@databricks.com>

Closes #18659 from BryanCutler/arrow-vectorized-udfs-SPARK-21404.
BryanCutler authored and cloud-fan committed Sep 22, 2017
1 parent 8f130ad commit 27fc536
Showing 13 changed files with 666 additions and 173 deletions.
22 changes: 17 additions & 5 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -83,10 +83,23 @@ private[spark] case class PythonFunction(
*/
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])

/**
* Enumerate the type of command that will be sent to the Python worker
*/
private[spark] object PythonEvalType {
val NON_UDF = 0
val SQL_BATCHED_UDF = 1
val SQL_PANDAS_UDF = 2
}

private[spark] object PythonRunner {
def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = {
new PythonRunner(
Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuse_worker, false, Array(Array(0)))
Seq(ChainedPythonFunctions(Seq(func))),
bufferSize,
reuse_worker,
PythonEvalType.NON_UDF,
Array(Array(0)))
}
}

@@ -100,7 +113,7 @@ private[spark] class PythonRunner(
funcs: Seq[ChainedPythonFunctions],
bufferSize: Int,
reuse_worker: Boolean,
isUDF: Boolean,
evalType: Int,
argOffsets: Array[Array[Int]])
extends Logging {

@@ -309,8 +322,8 @@ private[spark] class PythonRunner(
}
dataOut.flush()
// Serialized command:
if (isUDF) {
dataOut.writeInt(1)
dataOut.writeInt(evalType)
if (evalType != PythonEvalType.NON_UDF) {
dataOut.writeInt(funcs.length)
funcs.zip(argOffsets).foreach { case (chained, offsets) =>
dataOut.writeInt(offsets.length)
@@ -324,7 +337,6 @@ private[spark] class PythonRunner(
}
}
} else {
dataOut.writeInt(0)
val command = funcs.head.funcs.head.command
dataOut.writeInt(command.length)
dataOut.write(command)
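For context, a hypothetical Python-side reader for the framing written above might look as follows. `read_int` and `read_bytes` are assumed stand-ins for the worker's framed stream reads, and this is a sketch rather than Spark's actual worker.py; the per-function command bytes elided in the hunk above are glossed over.

```
# Hypothetical sketch of a worker consuming the new command framing.
NON_UDF, SQL_BATCHED_UDF, SQL_PANDAS_UDF = 0, 1, 2

def read_command_section(read_int, read_bytes):
    eval_type = read_int()                  # eval type is now always written first
    if eval_type != NON_UDF:
        udfs = []
        for _ in range(read_int()):         # number of chained UDF groups
            arg_offsets = [read_int() for _ in range(read_int())]
            # ... the pickled command(s) for this chain would be read here ...
            udfs.append((arg_offsets, read_bytes()))
        return eval_type, udfs
    return eval_type, read_bytes()          # single pickled command for non-UDF jobs
```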
65 changes: 63 additions & 2 deletions python/pyspark/serializers.py
@@ -81,6 +81,12 @@ class SpecialLengths(object):
NULL = -5


class PythonEvalType(object):
NON_UDF = 0
SQL_BATCHED_UDF = 1
SQL_PANDAS_UDF = 2


class Serializer(object):

def dump_stream(self, iterator, stream):
@@ -187,8 +193,14 @@ class ArrowSerializer(FramedSerializer):
Serializes an Arrow stream.
"""

def dumps(self, obj):
raise NotImplementedError
def dumps(self, batch):
import pyarrow as pa
import io
sink = io.BytesIO()
writer = pa.RecordBatchFileWriter(sink, batch.schema)
writer.write_batch(batch)
writer.close()
return sink.getvalue()

def loads(self, obj):
import pyarrow as pa
@@ -199,6 +211,55 @@ def __repr__(self):
return "ArrowSerializer"


class ArrowPandasSerializer(ArrowSerializer):
"""
Serializes Pandas.Series as Arrow data.
"""

def __init__(self):
super(ArrowPandasSerializer, self).__init__()

def dumps(self, series):
"""
Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or
a list of series accompanied by an optional pyarrow type to coerce the data to.
"""
import pyarrow as pa
# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or \
(len(series) == 2 and isinstance(series[1], pa.DataType)):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

# If a nullable integer series has been promoted to floating point with NaNs, need to cast
# NOTE: this is not necessary with Arrow >= 0.7
def cast_series(s, t):
if t is None or s.dtype == t.to_pandas_dtype():
return s
else:
return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)

arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
return super(ArrowPandasSerializer, self).dumps(batch)

def loads(self, obj):
"""
Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series
followed by a dictionary containing length of the loaded batches.
"""
import pyarrow as pa
reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
batches = [reader.get_batch(i) for i in xrange(reader.num_record_batches)]
# NOTE: a 0-parameter pandas_udf will produce an empty batch that can have num_rows set
num_rows = sum((batch.num_rows for batch in batches))
table = pa.Table.from_batches(batches)
return [c.to_pandas() for c in table.itercolumns()] + [{"length": num_rows}]

def __repr__(self):
return "ArrowPandasSerializer"


class BatchedSerializer(Serializer):

"""
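The `cast_series` helper above guards against pandas promoting a nullable integer Series to float64. A small pandas/pyarrow-only sketch of that behavior, independent of Spark and assuming the pre-0.7 pyarrow noted in the code:

```
import pandas as pd
import pyarrow as pa

# An integer Series containing a null is promoted to float64 by pandas.
s = pd.Series([1, None, 3])
print(s.dtype)  # float64

# Fill the NaNs and cast back before building the Arrow array, preserving
# the null positions via the mask; this mirrors the cast_series trick above.
t = pa.int64()
casted = s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
arr = pa.Array.from_pandas(casted, mask=s.isnull(), type=t)
print(arr)  # int64 array with values 1, null, 3
```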
49 changes: 37 additions & 12 deletions python/pyspark/sql/functions.py
@@ -2044,7 +2044,7 @@ class UserDefinedFunction(object):
.. versionadded:: 1.3
"""
def __init__(self, func, returnType, name=None):
def __init__(self, func, returnType, name=None, vectorized=False):
if not callable(func):
raise TypeError(
"Not a function or callable (__call__ is not defined): "
@@ -2058,6 +2058,7 @@ def __init__(self, func, returnType, name=None):
self._name = name or (
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self._vectorized = vectorized

@property
def returnType(self):
@@ -2089,7 +2090,7 @@ def _create_judf(self):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
self._name, wrapped_func, jdt)
self._name, wrapped_func, jdt, self._vectorized)
return judf

def __call__(self, *cols):
@@ -2123,6 +2124,22 @@ def wrapper(*args):
return wrapper


def _create_udf(f, returnType, vectorized):

def _udf(f, returnType=StringType(), vectorized=vectorized):
udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized)
return udf_obj._wrapped()

# decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf
if f is None or isinstance(f, (str, DataType)):
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(_udf, returnType=return_type, vectorized=vectorized)
else:
return _udf(f=f, returnType=returnType, vectorized=vectorized)


@since(1.3)
def udf(f=None, returnType=StringType()):
"""Creates a :class:`Column` expression representing a user defined function (UDF).
@@ -2154,18 +2171,26 @@ def udf(f=None, returnType=StringType()):
| 8| JOHN DOE| 22|
+----------+--------------+------------+
"""
def _udf(f, returnType=StringType()):
udf_obj = UserDefinedFunction(f, returnType)
return udf_obj._wrapped()
return _create_udf(f, returnType=returnType, vectorized=False)

# decorator @udf, @udf() or @udf(dataType())
if f is None or isinstance(f, (str, DataType)):
# If DataType has been passed as a positional argument
# for decorator use it as a returnType
return_type = f or returnType
return functools.partial(_udf, returnType=return_type)

@since(2.3)
def pandas_udf(f=None, returnType=StringType()):
"""
Creates a :class:`Column` expression representing a user defined function (UDF) that accepts
`Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length.
:param f: python function if used as a standalone function
:param returnType: a :class:`pyspark.sql.types.DataType` object
# TODO: doctest
"""
import inspect
# If function "f" does not define the optional kwargs, then wrap with a kwargs placeholder
if inspect.getargspec(f).keywords is None:
return _create_udf(lambda *a, **kwargs: f(*a), returnType=returnType, vectorized=True)
else:
return _udf(f=f, returnType=returnType)
return _create_udf(f, returnType=returnType, vectorized=True)


blacklist = ['map', 'since', 'ignore_unicode_prefix']
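To illustrate the call shapes `_create_udf` handles, here is a brief sketch (assuming pandas and pyarrow are installed) of the decorator and direct-call forms now shared by `udf` and `pandas_udf`; the function names are made up:

```
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType, LongType

# Direct call: f is a callable, so _create_udf wraps it immediately.
double_it = pandas_udf(lambda s: s * 2, DoubleType())

# Decorator with a return type: f is a DataType here, so _create_udf returns
# a functools.partial that is then applied to the decorated function.
@pandas_udf(LongType())
def plus_one(s):
    return s + 1
```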
