diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py
index 74d9a2ce65608..55388900f3f46 100644
--- a/python/pyspark/sql/pandas/serializers.py
+++ b/python/pyspark/sql/pandas/serializers.py
@@ -566,151 +566,6 @@ def __repr__(self):
         return "ArrowStreamPandasUDFSerializer"
 
 
-class ArrowStreamPandasUDTFSerializer(ArrowStreamPandasUDFSerializer):
-    """
-    Serializer used by Python worker to evaluate Arrow-optimized Python UDTFs.
-    """
-
-    def __init__(self, timezone, safecheck):
-        super(ArrowStreamPandasUDTFSerializer, self).__init__(
-            timezone=timezone,
-            safecheck=safecheck,
-            # The output pandas DataFrame's columns are unnamed.
-            assign_cols_by_name=False,
-            # Set to 'False' to avoid converting struct type inputs into a pandas DataFrame.
-            df_for_struct=False,
-            # Defines how struct type inputs are converted. If set to "row", struct type inputs
-            # are converted into Rows. Without this setting, a struct type input would be treated
-            # as a dictionary. For example, for named_struct('name', 'Alice', 'age', 1),
-            # if struct_in_pandas="dict", it becomes {"name": "Alice", "age": 1}
-            # if struct_in_pandas="row", it becomes Row(name="Alice", age=1)
-            struct_in_pandas="row",
-            # When dealing with array type inputs, Arrow converts them into numpy.ndarrays.
-            # To ensure consistency across regular and arrow-optimized UDTFs, we further
-            # convert these numpy.ndarrays into Python lists.
-            ndarray_as_list=True,
-            # Enables explicit casting for mismatched return types of Arrow Python UDTFs.
-            arrow_cast=True,
-        )
-        self._converter_map = dict()
-
-    def _create_batch(self, series):
-        """
-        Create an Arrow record batch from the given pandas.Series pandas.DataFrame
-        or list of Series or DataFrame, with optional type.
-
-        Parameters
-        ----------
-        series : pandas.Series or pandas.DataFrame or list
-            A single series or dataframe, list of series or dataframe,
-            or list of (series or dataframe, arrow_type)
-
-        Returns
-        -------
-        pyarrow.RecordBatch
-            Arrow RecordBatch
-        """
-        import pandas as pd
-        import pyarrow as pa
-
-        # Make input conform to [(series1, type1), (series2, type2), ...]
-        if not isinstance(series, (list, tuple)) or (
-            len(series) == 2 and isinstance(series[1], pa.DataType)
-        ):
-            series = [series]
-        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
-
-        arrs = []
-        for s, t in series:
-            if not isinstance(s, pd.DataFrame):
-                raise PySparkValueError(
-                    "Output of an arrow-optimized Python UDTFs expects "
-                    f"a pandas.DataFrame but got: {type(s)}"
-                )
-
-            arrs.append(self._create_struct_array(s, t))
-
-        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
-
-    def _get_or_create_converter_from_pandas(self, dt):
-        if dt not in self._converter_map:
-            conv = _create_converter_from_pandas(
-                dt,
-                timezone=self._timezone,
-                error_on_duplicated_field_names=False,
-                ignore_unexpected_complex_type_values=True,
-            )
-            self._converter_map[dt] = conv
-        return self._converter_map[dt]
-
-    def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
-        """
-        Override the `_create_array` method in the superclass to create an Arrow Array
-        from a given pandas.Series and an arrow type. The difference here is that we always
-        use arrow cast when creating the arrow array. Also, the error messages are specific
-        to arrow-optimized Python UDTFs.
-
-        Parameters
-        ----------
-        series : pandas.Series
-            A single series
-        arrow_type : pyarrow.DataType, optional
-            If None, pyarrow's inferred type will be used
-        spark_type : DataType, optional
-            If None, spark type converted from arrow_type will be used
-        arrow_cast: bool, optional
-            Whether to apply Arrow casting when the user-specified return type mismatches the
-            actual return values.
-
-        Returns
-        -------
-        pyarrow.Array
-        """
-        import pyarrow as pa
-        import pandas as pd
-
-        if isinstance(series.dtype, pd.CategoricalDtype):
-            series = series.astype(series.dtypes.categories.dtype)
-
-        if arrow_type is not None:
-            dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
-            conv = self._get_or_create_converter_from_pandas(dt)
-            series = conv(series)
-
-        if hasattr(series.array, "__arrow_array__"):
-            mask = None
-        else:
-            mask = series.isnull()
-
-        try:
-            try:
-                return pa.Array.from_pandas(
-                    series, mask=mask, type=arrow_type, safe=self._safecheck
-                )
-            except pa.lib.ArrowException:
-                if arrow_cast:
-                    return pa.Array.from_pandas(series, mask=mask).cast(
-                        target_type=arrow_type, safe=self._safecheck
-                    )
-                else:
-                    raise
-        except pa.lib.ArrowException:
-            # Display the most user-friendly error messages instead of showing
-            # arrow's error message. This also works better with Spark Connect
-            # where the exception messages are by default truncated.
-            raise PySparkRuntimeError(
-                errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
-                messageParameters={
-                    "col_name": series.name,
-                    "col_type": str(series.dtype),
-                    "arrow_type": arrow_type,
-                },
-            ) from None
-
-    def __repr__(self):
-        return "ArrowStreamPandasUDTFSerializer"
-
-
 class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer):
     """
     Serializes pyarrow.RecordBatch data with Arrow streaming format.
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index c727671dd59e5..6727f2951d4b7 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -3011,7 +3011,7 @@ def eval(self):
             ("x: smallint", err),
             ("x: int", err),
             ("x: bigint", err),
-            ("x: string", err),
+            ("x: string", [Row(x="[0, 1.1, 2]")]),
             ("x: date", err),
             ("x: timestamp", err),
             ("x: byte", err),
@@ -3020,7 +3020,7 @@
             ("x: double", err),
             ("x: decimal(10, 0)", err),
             ("x: array<string>", [Row(x=["0", "1.1", "2"])]),
-            ("x: array<boolean>", [Row(x=[False, True, True])]),
+            ("x: array<boolean>", err),
            ("x: array<int>", [Row(x=[0, 1, 2])]),
             ("x: array<float>", [Row(x=[0, 1.1, 2])]),
             ("x: array<array<int>>", err),
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 29dfd65c0e2b8..2e2e2adf3cf34 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -49,10 +49,10 @@
     CPickleSerializer,
     BatchedSerializer,
 )
+from pyspark.sql.conversion import LocalDataToArrowConversion, ArrowTableToRowsConversion
 from pyspark.sql.functions import SkipRestOfInputTableException
 from pyspark.sql.pandas.serializers import (
     ArrowStreamPandasUDFSerializer,
-    ArrowStreamPandasUDTFSerializer,
     CogroupArrowUDFSerializer,
     CogroupPandasUDFSerializer,
     ArrowStreamUDFSerializer,
@@ -976,6 +976,8 @@ def use_large_var_types(runner_conf):
 # ensure the UDTF is valid. This function also prepares a mapper function for applying
 # the UDTF logic to input rows.
 def read_udtf(pickleSer, infile, eval_type):
+    prefers_large_var_types = False
+
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
         runner_conf = {}
         # Load conf used for arrow evaluation.
@@ -984,14 +986,9 @@ def read_udtf(pickleSer, infile, eval_type):
             k = utf8_deserializer.loads(infile)
             v = utf8_deserializer.loads(infile)
             runner_conf[k] = v
+        prefers_large_var_types = use_large_var_types(runner_conf)
 
-        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
-        timezone = runner_conf.get("spark.sql.session.timeZone", None)
-        safecheck = (
-            runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
-            == "true"
-        )
-        ser = ArrowStreamPandasUDTFSerializer(timezone, safecheck)
+        ser = ArrowStreamUDFSerializer()
     else:
         # Each row is a group so do not batch but send one by one.
         ser = BatchedSerializer(CPickleSerializer(), 1)
@@ -1301,7 +1298,7 @@ def check(row):
     if eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF:
 
         def wrap_arrow_udtf(f, return_type):
-            import pandas as pd
+            import pyarrow as pa
 
             arrow_return_type = to_arrow_type(
                 return_type, prefers_large_types=use_large_var_types(runner_conf)
@@ -1309,7 +1306,7 @@ def wrap_arrow_udtf(f, return_type):
             return_type_size = len(return_type)
 
             def verify_result(result):
-                if not isinstance(result, pd.DataFrame):
+                if not isinstance(result, pa.RecordBatch):
                     raise PySparkTypeError(
                         errorClass="INVALID_ARROW_UDTF_RETURN_TYPE",
                         messageParameters={
@@ -1335,8 +1332,12 @@ def verify_result(result):
                         )
 
                 # Verify the type and the schema of the result.
-                verify_pandas_result(
-                    result, return_type, assign_cols_by_name=False, truncate_return_schema=False
+                verify_arrow_result(
+                    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
+                    assign_cols_by_name=False,
+                    expected_cols_and_types=[
+                        (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
+                    ],
                 )
                 return result
@@ -1372,19 +1373,39 @@ def check_return_value(res):
                     else:
                         yield from res
 
-            def evaluate(*args: pd.Series):
+            def convert_to_arrow(data: Iterable):
+                data = list(check_return_value(data))
+                if len(data) == 0:
+                    return [
+                        pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
+                    ]
+                try:
+                    return LocalDataToArrowConversion.convert(
+                        data, return_type, prefers_large_var_types
+                    ).to_batches()
+                except Exception as e:
+                    raise PySparkRuntimeError(
+                        errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
+                        messageParameters={
+                            "col_name": return_type.names,
+                            "col_type": return_type.simpleString(),
+                            "arrow_type": arrow_return_type,
+                        },
+                    ) from e
+
+            def evaluate(*args: pa.ChunkedArray):
                 if len(args) == 0:
-                    res = func()
-                    yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
+                    for batch in convert_to_arrow(func()):
+                        yield verify_result(batch), arrow_return_type
+
                 else:
-                    # Create tuples from the input pandas Series, each tuple
-                    # represents a row across all Series.
-                    row_tuples = zip(*args)
-                    for row in row_tuples:
-                        res = func(*row)
-                        yield verify_result(
-                            pd.DataFrame(check_return_value(res))
-                        ), arrow_return_type
+                    rows = ArrowTableToRowsConversion.convert(
+                        pa.Table.from_arrays(list(args), names=["_0"]),
+                        schema=return_type,
+                    )
+                    for row in rows:
+                        for batch in convert_to_arrow(func(*row)):
+                            yield verify_result(batch), arrow_return_type
 
             return evaluate
 
@@ -1404,7 +1425,7 @@ def mapper(_, it):
         try:
             for a in it:
                 # The eval function yields an iterator. Each element produced by this
-                # iterator is a tuple in the form of (pandas.DataFrame, arrow_return_type).
+                # iterator is a tuple in the form of (pyarrow.RecordBatch, arrow_return_type).
                 yield from eval(*[a[o] for o in args_kwargs_offsets])
             if terminate is not None:
                 yield from terminate()
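
For illustration only (not part of the patch): a minimal, self-contained sketch of the RecordBatch-based result path that the patched verify_result and convert_to_arrow follow, written against plain pyarrow. The helper names to_record_batch, verify_record_batch, demo_rows, and the example schema are invented for this sketch; the real worker builds batches via LocalDataToArrowConversion as shown in the diff above.

# Hypothetical sketch: Arrow-optimized UDTF results are now pyarrow.RecordBatch
# objects built against the declared return schema, instead of pandas DataFrames.
# All names below are illustrative and do not exist in pyspark.
import pyarrow as pa

# Example return schema, analogous to arrow_return_type in worker.py.
arrow_return_type = pa.struct([("x", pa.int64()), ("y", pa.string())])
schema = pa.schema(list(arrow_return_type))


def to_record_batch(rows):
    # Mirrors the empty-output branch of convert_to_arrow: an empty result
    # still needs a batch that carries the declared schema.
    if not rows:
        return pa.RecordBatch.from_pylist([], schema=schema)
    # Build the batch from row tuples keyed by the schema's column names.
    return pa.RecordBatch.from_pylist(
        [dict(zip(schema.names, row)) for row in rows], schema=schema
    )


def verify_record_batch(result):
    # Mirrors the patched verify_result: the evaluation path now checks for
    # pyarrow.RecordBatch rather than pandas.DataFrame.
    if not isinstance(result, pa.RecordBatch):
        raise TypeError(f"expected pyarrow.RecordBatch, got {type(result).__name__}")
    return result


demo_rows = [(1, "a"), (2, "b")]
batch = verify_record_batch(to_record_batch(demo_rows))
print(batch.num_rows, batch.schema.names)  # 2 ['x', 'y']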