Improve Python UDTF arrow serializer performance
HyukjinKwon committed Feb 27, 2025
1 parent 53fc763 commit ae78d98
Showing 3 changed files with 52 additions and 173 deletions.
152 changes: 5 additions & 147 deletions python/pyspark/sql/pandas/serializers.py
@@ -144,8 +144,11 @@ def load_stream(self, stream):

batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
for batch in batches:
struct = batch.column(0)
yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))]
if len(batch.columns) > 0:
struct = batch.column(0)
yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))]
else:
yield [pa.RecordBatch.from_pylist([])]

def dump_stream(self, iterator, stream):
"""
@@ -566,151 +569,6 @@ def __repr__(self):
return "ArrowStreamPandasUDFSerializer"


class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
"""
Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
"""

def __init__(self, timezone, safecheck):
super(ArrowStreamPandasUDTFSerializer, self).__init__(
timezone=timezone,
safecheck=safecheck,
# The output pandas DataFrame's columns are unnamed.
assign_cols_by_name=False,
# Set to 'False' to avoid converting struct type inputs into a pandas DataFrame.
df_for_struct=False,
# Defines how struct type inputs are converted. If set to "row", struct type inputs
# are converted into Rows. Without this setting, a struct type input would be treated
# as a dictionary. For example, for named_struct('name', 'Alice', 'age', 1),
# if struct_in_pandas="dict", it becomes {"name": "Alice", "age": 1}
# if struct_in_pandas="row", it becomes Row(name="Alice", age=1)
struct_in_pandas="row",
# When dealing with array type inputs, Arrow converts them into numpy.ndarrays.
# To ensure consistency across regular and arrow-optimized UDTFs, we further
# convert these numpy.ndarrays into Python lists.
ndarray_as_list=True,
# Enables explicit casting for mismatched return types of Arrow Python UDTFs.
arrow_cast=True,
)
self._converter_map = dict()

def _create_batch(self, series):
"""
Create an Arrow record batch from the given pandas.Series, pandas.DataFrame,
or list of Series or DataFrame, with optional type.
Parameters
----------
series : pandas.Series or pandas.DataFrame or list
A single series or dataframe, list of series or dataframe,
or list of (series or dataframe, arrow_type)
Returns
-------
pyarrow.RecordBatch
Arrow RecordBatch
"""
import pandas as pd
import pyarrow as pa

# Make input conform to [(series1, type1), (series2, type2), ...]
if not isinstance(series, (list, tuple)) or (
len(series) == 2 and isinstance(series[1], pa.DataType)
):
series = [series]
series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)

arrs = []
for s, t in series:
if not isinstance(s, pd.DataFrame):
raise PySparkValueError(
"Output of an arrow-optimized Python UDTFs expects "
f"a pandas.DataFrame but got: {type(s)}"
)

arrs.append(self._create_struct_array(s, t))

return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])

def _get_or_create_converter_from_pandas(self, dt):
if dt not in self._converter_map:
conv = _create_converter_from_pandas(
dt,
timezone=self._timezone,
error_on_duplicated_field_names=False,
ignore_unexpected_complex_type_values=True,
)
self._converter_map[dt] = conv
return self._converter_map[dt]

def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
"""
Override the `_create_array` method in the superclass to create an Arrow Array
from a given pandas.Series and an arrow type. The difference here is that we always
use arrow cast when creating the arrow array. Also, the error messages are specific
to arrow-optimized Python UDTFs.
Parameters
----------
series : pandas.Series
A single series
arrow_type : pyarrow.DataType, optional
If None, pyarrow's inferred type will be used
spark_type : DataType, optional
If None, spark type converted from arrow_type will be used
arrow_cast: bool, optional
Whether to apply Arrow casting when the user-specified return type mismatches the
actual return values.
Returns
-------
pyarrow.Array
"""
import pyarrow as pa
import pandas as pd

if isinstance(series.dtype, pd.CategoricalDtype):
series = series.astype(series.dtypes.categories.dtype)

if arrow_type is not None:
dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
conv = self._get_or_create_converter_from_pandas(dt)
series = conv(series)

if hasattr(series.array, "__arrow_array__"):
mask = None
else:
mask = series.isnull()

try:
try:
return pa.Array.from_pandas(
series, mask=mask, type=arrow_type, safe=self._safecheck
)
except pa.lib.ArrowException:
if arrow_cast:
return pa.Array.from_pandas(series, mask=mask).cast(
target_type=arrow_type, safe=self._safecheck
)
else:
raise
except pa.lib.ArrowException:
# Display the most user-friendly error messages instead of showing
# arrow's error message. This also works better with Spark Connect
# where the exception messages are by default truncated.
raise PySparkRuntimeError(
errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
messageParameters={
"col_name": series.name,
"col_type": str(series.dtype),
"arrow_type": arrow_type,
},
) from None

def __repr__(self):
return "ArrowStreamPandasUDTFSerializer"


class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
"""
Serializes pyarrow.RecordBatch data with Arrow streaming format.
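The ArrowStreamPandasUDTFSerializer removed above carried its own pandas-to-Arrow conversion, including the cast-on-mismatch fallback in _create_array. A hedged, standalone sketch of that pattern (pandas and pyarrow only; the series and target type are illustrative):

import pandas as pd
import pyarrow as pa

series = pd.Series([0, 1, 2])
arrow_type = pa.string()
try:
    # Strict conversion against the declared Arrow type.
    arr = pa.Array.from_pandas(series, type=arrow_type, safe=True)
except pa.lib.ArrowException:
    # Mirrors arrow_cast=True: convert first, then cast explicitly.
    arr = pa.Array.from_pandas(series).cast(target_type=arrow_type, safe=True)

print(arr)  # ["0", "1", "2"]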
4 changes: 2 additions & 2 deletions python/pyspark/sql/tests/test_udtf.py
@@ -3011,7 +3011,7 @@ def eval(self):
("x: smallint", err),
("x: int", err),
("x: bigint", err),
("x: string", err),
("x: string", [Row(x="[0, 1.1, 2]")]),
("x: date", err),
("x: timestamp", err),
("x: byte", err),
@@ -3020,7 +3020,7 @@ def eval(self):
("x: double", err),
("x: decimal(10, 0)", err),
("x: array<string>", [Row(x=["0", "1.1", "2"])]),
("x: array<boolean>", [Row(x=[False, True, True])]),
("x: array<boolean>", err),
("x: array<int>", [Row(x=[0, 1, 2])]),
("x: array<float>", [Row(x=[0, 1.1, 2])]),
("x: array<array<int>>", err),
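The two flipped expectations above follow from the serializer swap: with the Arrow-based conversion, a Python list returned for a string column is rendered as its string form rather than rejected, while numeric lists are no longer coerced to booleans. A rough illustration of the string case (assuming the value is stringified via str(), which matches the expected result):

value = [0, 1.1, 2]
print(str(value))  # "[0, 1.1, 2]", the value the updated test now expects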
69 changes: 45 additions & 24 deletions python/pyspark/worker.py
@@ -49,10 +49,10 @@
CPickleSerializer,
BatchedSerializer,
)
from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion
from pyspark.sql.functions import SkipRestOfInputTableException
from pyspark.sql.pandas.serializers import (
ArrowStreamPandasUDFSerializer,
ArrowStreamPandasUDTFSerializer,
CogroupArrowUDFSerializer,
CogroupPandasUDFSerializer,
ArrowStreamUDFSerializer,
@@ -976,6 +976,8 @@ def use_large_var_types(runner_conf):
# ensure the UDTF is valid. This function also prepares a mapper function for applying
# the UDTF logic to input rows.
def read_udtf(pickleSer, infile, eval_type):
prefers_large_var_types = False

if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
runner_conf = {}
# Load conf used for arrow evaluation.
@@ -984,14 +986,9 @@ def read_udtf(pickleSer, infile, eval_type):
k = utf8_deserializer.loads(infile)
v = utf8_deserializer.loads(infile)
runner_conf[k] = v
prefers_large_var_types = use_large_var_types(runner_conf)

# NOTE: if timezone is set here, that implies respectSessionTimeZone is True
timezone = runner_conf.get("spark.sql.session.timeZone", None)
safecheck = (
runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
== "true"
)
ser = ArrowStreamPandasUDTFSerializer(timezone, safecheck)
ser = ArrowStreamUDFSerializer()
else:
# Each row is a group so do not batch but send one by one.
ser = BatchedSerializer(CPickleSerializer(), 1)
@@ -1301,15 +1298,15 @@ def check(row):
if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:

def wrap_arrow_udtf(f, return_type):
import pandas as pd
import pyarrow as pa

arrow_return_type = to_arrow_type(
return_type, prefers_large_types=use_large_var_types(runner_conf)
)
return_type_size = len(return_type)

def verify_result(result):
if not isinstance(result, pd.DataFrame):
if not isinstance(result, pa.RecordBatch):
raise PySparkTypeError(
errorClass="INVALID_ARROW_UDTF_RETURN_TYPE",
messageParameters={
@@ -1335,8 +1332,12 @@ def verify_result(result):
)

# Verify the type and the schema of the result.
verify_pandas_result(
result, return_type, assign_cols_by_name=False, truncate_return_schema=False
verify_arrow_result(
pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
assign_cols_by_name=False,
expected_cols_and_types=[
(col.name, to_arrow_type(col.dataType)) for col in return_type.fields
],
)
return result

@@ -1372,19 +1373,39 @@ def check_return_value(res):
else:
yield from res

def evaluate(*args: pd.Series):
def convert_to_arrow(data: Iterable):
data = list(check_return_value(data))
if len(data) == 0:
return [
pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
]
try:
return LocalDataToArrowConversion.convert(
data, return_type, prefers_large_var_types
).to_batches()
except Exception as e:
raise PySparkRuntimeError(
errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
messageParameters={
"col_name": return_type.names,
"col_type": return_type.simpleString(),
"arrow_type": arrow_return_type,
},
) from e

def evaluate(*args: pa.ChunkedArray):
if len(args) == 0:
res = func()
yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
for batch in convert_to_arrow(func()):
yield verify_result(batch), arrow_return_type

else:
# Create tuples from the input pandas Series, each tuple
# represents a row across all Series.
row_tuples = zip(*args)
for row in row_tuples:
res = func(*row)
yield verify_result(
pd.DataFrame(check_return_value(res))
), arrow_return_type
rows = ArrowTableToRowsConversion.convert(
pa.Table.from_arrays(list(args), names=["_0"]),
schema=return_type,
)
for row in rows:
for batch in convert_to_arrow(func(*row)):
yield verify_result(batch), arrow_return_type

return evaluate
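For readers unfamiliar with the two conversion helpers used above, a hedged sketch of the round trip they provide, with the call signatures as they appear in this patch (requires a pyspark build that ships pyspark.sql.conversion; schema and data are illustrative):

import pyarrow as pa
from pyspark.sql.conversion import (
    ArrowTableToRowsConversion,
    LocalDataToArrowConversion,
)
from pyspark.sql.types import IntegerType, StructField, StructType

return_type = StructType([StructField("x", IntegerType())])

# UDTF output tuples -> an Arrow table, skipping pandas entirely.
table = LocalDataToArrowConversion.convert([(1,), (2,)], return_type, False)
print(table.to_batches())

# Arrow input columns -> pyspark Rows, one Row per eval() call.
rows = ArrowTableToRowsConversion.convert(pa.table({"x": [1, 2]}), schema=return_type)
print(rows)  # [Row(x=1), Row(x=2)]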

Expand All @@ -1404,7 +1425,7 @@ def mapper(_, it):
try:
for a in it:
# The eval function yields an iterator. Each element produced by this
# iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type).
# iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type).
yield from eval(*[a[o] for o in args_kwargs_offsets])
if terminate is not None:
yield from terminate()