diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index 2831ae74f5606..9f3f88fa5cfec 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -55,6 +55,8 @@ private[spark] object PythonEvalType { val SQL_COGROUPED_MAP_PANDAS_UDF = 206 val SQL_MAP_ARROW_ITER_UDF = 207 val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208 + val SQL_GROUPED_MAP_ARROW_UDF = 209 + val SQL_COGROUPED_MAP_ARROW_UDF = 210 val SQL_TABLE_UDF = 300 val SQL_ARROW_TABLE_UDF = 301 @@ -72,6 +74,8 @@ private[spark] object PythonEvalType { case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE" + case SQL_GROUPED_MAP_ARROW_UDF => "SQL_GROUPED_MAP_ARROW_UDF" + case SQL_COGROUPED_MAP_ARROW_UDF => "SQL_COGROUPED_MAP_ARROW_UDF" case SQL_TABLE_UDF => "SQL_TABLE_UDF" case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF" } diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d29fc8726018d..c25f852232016 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -469,6 +469,8 @@ def __hash__(self): "pyspark.sql.pandas.utils", "pyspark.sql.observation", # unittests + "pyspark.sql.tests.arrow.test_arrow_cogrouped_map", + "pyspark.sql.tests.arrow.test_arrow_grouped_map", "pyspark.sql.tests.test_arrow", "pyspark.sql.tests.test_arrow_python_udf", "pyspark.sql.tests.test_catalog", diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 0fbe489f623c2..8b8d1d4da36de 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -282,6 +282,11 @@ "`<arg_name>` should be one of the values from PandasUDFType, got <arg_type>" ] }, + "INVALID_RETURN_TYPE_FOR_ARROW_UDF": { + "message": [ + "Grouped and Cogrouped map Arrow UDF should return StructType for <eval_type>, got <return_type>." + ] + }, "INVALID_RETURN_TYPE_FOR_PANDAS_UDF": { "message": [ "Pandas UDF should return StructType for <eval_type>, got <return_type>." ] }, @@ -648,6 +653,11 @@ "transformation. For more information, see SPARK-5063." ] }, + "RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDF" : { + "message" : [ + "Column names of the returned pyarrow.Table do not match specified schema.<missing><extra>" + ] + }, "RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF" : { "message" : [ "Column names of the returned pandas.DataFrame do not match specified schema.<missing><extra>" ] }, @@ -663,6 +673,11 @@ "The length of output in Scalar iterator pandas UDF should be the same with the input's; however, the length of output was <output_length> and the length of input was <input_length>." ] }, + "RESULT_TYPE_MISMATCH_FOR_ARROW_UDF" : { + "message" : [ + "Columns do not match in their data type: <mismatch>." + ] + }, "SCHEMA_MISMATCH_FOR_PANDAS_UDF" : { "message" : [ "Result vector from pandas_udf was not the required length: expected <expected>, got <actual>." 
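For a concrete sense of when the two new Arrow error classes fire, here is a minimal sketch (illustrative only, not part of the patch; it assumes an existing `spark` session): a grouped `applyInArrow` UDF whose returned `pyarrow.Table` disagrees with the declared schema surfaces RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDF (wrong column names) or RESULT_TYPE_MISMATCH_FOR_ARROW_UDF (wrong column types) as a PythonException on the driver, which is exactly what the new tests below assert.

    import pyarrow as pa

    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

    def identity(table: pa.Table) -> pa.Table:
        # Returns columns "id" and "v" unchanged.
        return table

    # Declared schema expects columns "id" and "m" -> RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDF.
    df.groupby("id").applyInArrow(identity, schema="id long, m double").collect()

    # "v" is declared string but the returned column is float64 -> RESULT_TYPE_MISMATCH_FOR_ARROW_UDF.
    df.groupby("id").applyInArrow(identity, schema="id long, v string").collect()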
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 8ea9a31022298..fb8a0a5b00140 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -107,6 +107,8 @@ PandasCogroupedMapUDFType, ArrowMapIterUDFType, PandasGroupedMapUDFWithStateType, + ArrowGroupedMapUDFType, + ArrowCogroupedMapUDFType, ) from pyspark.sql.dataframe import DataFrame from pyspark.sql.types import AtomicType, StructType @@ -158,6 +160,8 @@ class PythonEvalType: SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206 SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207 SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208 + SQL_GROUPED_MAP_ARROW_UDF: "ArrowGroupedMapUDFType" = 209 + SQL_COGROUPED_MAP_ARROW_UDF: "ArrowCogroupedMapUDFType" = 210 SQL_TABLE_UDF: "SQLTableUDFType" = 300 SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301 diff --git a/python/pyspark/sql/pandas/_typing/__init__.pyi b/python/pyspark/sql/pandas/_typing/__init__.pyi index 69279727ca9c2..0838f446279b9 100644 --- a/python/pyspark/sql/pandas/_typing/__init__.pyi +++ b/python/pyspark/sql/pandas/_typing/__init__.pyi @@ -53,6 +53,8 @@ PandasMapIterUDFType = Literal[205] PandasCogroupedMapUDFType = Literal[206] ArrowMapIterUDFType = Literal[207] PandasGroupedMapUDFWithStateType = Literal[208] +ArrowGroupedMapUDFType = Literal[209] +ArrowCogroupedMapUDFType = Literal[210] class PandasVariadicScalarToScalarFunction(Protocol): def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ... @@ -341,4 +343,13 @@ PandasCogroupedMapFunction = Union[ Callable[[Any, DataFrameLike, DataFrameLike], DataFrameLike], ] +ArrowGroupedMapFunction = Union[ + Callable[[pyarrow.Table], pyarrow.Table], + Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table], pyarrow.Table], +] +ArrowCogroupedMapFunction = Union[ + Callable[[pyarrow.Table, pyarrow.Table], pyarrow.Table], + Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table, pyarrow.Table], pyarrow.Table], +] + GroupedMapPandasUserDefinedFunction = NewType("GroupedMapPandasUserDefinedFunction", FunctionType) diff --git a/python/pyspark/sql/pandas/functions.py b/python/pyspark/sql/pandas/functions.py index b7d381f04c76c..8c039ca1bed4d 100644 --- a/python/pyspark/sql/pandas/functions.py +++ b/python/pyspark/sql/pandas/functions.py @@ -378,6 +378,8 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]: PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, + PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, None, ]: # None means it should infer the type from type hints. 
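To make the new typing aliases concrete, here is a small sketch (names invented for illustration) of user functions that satisfy `ArrowGroupedMapFunction` and `ArrowCogroupedMapFunction`: each variant may optionally accept the grouping key as a tuple of `pyarrow.Scalar`s ahead of the table argument(s).

    from typing import Tuple
    import pyarrow as pa

    def grouped(table: pa.Table) -> pa.Table:
        # Callable[[pyarrow.Table], pyarrow.Table]
        return table

    def grouped_with_key(key: Tuple[pa.Scalar, ...], table: pa.Table) -> pa.Table:
        # key carries one pyarrow.Scalar per grouping expression
        return table

    def cogrouped(left: pa.Table, right: pa.Table) -> pa.Table:
        return left

    def cogrouped_with_key(key: Tuple[pa.Scalar, ...], left: pa.Table, right: pa.Table) -> pa.Table:
        return left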
@@ -416,6 +418,8 @@ def _create_pandas_udf(f, returnType, evalType): PythonEvalType.SQL_MAP_ARROW_ITER_UDF, PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, + PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, PythonEvalType.SQL_ARROW_BATCHED_UDF, ]: # In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered @@ -463,6 +467,15 @@ def _create_pandas_udf(f, returnType, evalType): }, ) + if evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF and len(argspec.args) not in (1, 2): + raise PySparkValueError( + error_class="INVALID_PANDAS_UDF", + message_parameters={ + "detail": "the function in groupby.applyInArrow must take either one argument " + "(data) or two arguments (key, data).", + }, + ) + if evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF and len(argspec.args) not in (2, 3): raise PySparkValueError( error_class="INVALID_PANDAS_UDF", @@ -472,6 +485,15 @@ def _create_pandas_udf(f, returnType, evalType): }, ) + if evalType == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF and len(argspec.args) not in (2, 3): + raise PySparkValueError( + error_class="INVALID_PANDAS_UDF", + message_parameters={ + "detail": "the function in cogroup.applyInArrow must take either two arguments " + "(left, right) or three arguments (key, left, right).", + }, + ) + if is_remote(): from pyspark.sql.connect.udf import _create_udf as _create_connect_udf diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py index 56403482b9deb..d7e0a0b86c43c 100644 --- a/python/pyspark/sql/pandas/group_ops.py +++ b/python/pyspark/sql/pandas/group_ops.py @@ -30,6 +30,8 @@ PandasGroupedMapFunction, PandasGroupedMapFunctionWithState, PandasCogroupedMapFunction, + ArrowGroupedMapFunction, + ArrowCogroupedMapFunction, ) from pyspark.sql.group import GroupedData @@ -148,7 +150,7 @@ def applyInPandas( Examples -------- >>> import pandas as pd # doctest: +SKIP - >>> from pyspark.sql.functions import pandas_udf, ceil + >>> from pyspark.sql.functions import ceil >>> df = spark.createDataFrame( ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ... ("id", "v")) # doctest: +SKIP @@ -354,6 +356,133 @@ def applyInPandasWithState( ) return DataFrame(jdf, self.session) + def applyInArrow( + self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str] + ) -> "DataFrame": + """ + Maps each group of the current :class:`DataFrame` using an Arrow udf and returns the result + as a `DataFrame`. + + The function should take a `pyarrow.Table` and return another + `pyarrow.Table`. Alternatively, the user can pass a function that takes + a tuple of `pyarrow.Scalar` grouping key(s) and a `pyarrow.Table`. + For each group, all columns are passed together as a `pyarrow.Table` + to the user-function and the returned `pyarrow.Table` are combined as a + :class:`DataFrame`. + + The `schema` should be a :class:`StructType` describing the schema of the returned + `pyarrow.Table`. The column labels of the returned `pyarrow.Table` must either match + the field names in the defined schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pyarrow.Table` can be arbitrary. + + .. 
versionadded:: 4.0.0 + + Parameters + ---------- + func : function + a Python native function that takes a `pyarrow.Table` and outputs a + `pyarrow.Table`, or that takes one tuple (grouping keys) and a + `pyarrow.Table` and outputs a `pyarrow.Table`. + schema : :class:`pyspark.sql.types.DataType` or str + the return type of the `func` in PySpark. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + + Examples + -------- + >>> from pyspark.sql.functions import ceil + >>> import pyarrow # doctest: +SKIP + >>> import pyarrow.compute as pc # doctest: +SKIP + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> def normalize(table): + ... v = table.column("v") + ... norm = pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1)) + ... return table.set_column(1, "v", norm) + >>> df.groupby("id").applyInArrow( + ... normalize, schema="id long, v double").show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + Alternatively, the user can pass a function that takes two arguments. + In this case, the grouping key(s) will be passed as the first argument and the data will + be passed as the second argument. The grouping key(s) will be passed as a tuple of Arrow + scalar types, e.g., `pyarrow.Int32Scalar` and `pyarrow.FloatScalar`. The data will still + be passed in as a `pyarrow.Table` containing all columns from the original Spark DataFrame. + This is useful when the user does not want to hardcode grouping key(s) in the function. + + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> def mean_func(key, table): + ... # key is a tuple of one pyarrow.Int64Scalar, which is the value + ... # of 'id' for the current group + ... mean = pc.mean(table.column("v")) + ... return pyarrow.Table.from_pydict({"id": [key[0].as_py()], "v": [mean.as_py()]}) + >>> df.groupby('id').applyInArrow( + ... mean_func, schema="id long, v double").show() # doctest: +SKIP + +---+---+ + | id| v| + +---+---+ + | 1|1.5| + | 2|6.0| + +---+---+ + + >>> def sum_func(key, table): + ... # key is a tuple of two pyarrow.Int64Scalars, which are the values + ... # of 'id' and 'ceil(df.v / 2)' for the current group + ... sum = pc.sum(table.column("v")) + ... return pyarrow.Table.from_pydict({ + ... "id": [key[0].as_py()], + ... "ceil(v / 2)": [key[1].as_py()], + ... "v": [sum.as_py()] + ... }) + >>> df.groupby(df.id, ceil(df.v / 2)).applyInArrow( + ... sum_func, schema="id long, `ceil(v / 2)` long, v double").show() # doctest: +SKIP + +---+-----------+----+ + | id|ceil(v / 2)| v| + +---+-----------+----+ + | 2| 5|10.0| + | 1| 1| 3.0| + | 2| 3| 5.0| + | 2| 2| 3.0| + +---+-----------+----+ + + Notes + ----- + This function requires a full shuffle. All the data of a group will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + This API is experimental. 
+ + See Also + -------- + pyspark.sql.functions.pandas_udf + """ + from pyspark.sql import GroupedData + from pyspark.sql.functions import pandas_udf + + assert isinstance(self, GroupedData) + + # The usage of the pandas_udf is internal so type checking is disabled. + udf = pandas_udf( + func, returnType=schema, functionType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF + ) # type: ignore[call-overload] + df = self._df + udf_column = udf(*[df[col] for col in df.columns]) + jdf = self._jgd.flatMapGroupsInArrow(udf_column._jc.expr()) + return DataFrame(jdf, self.session) + def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps": """ Cogroups this group with another group so that we can run cogrouped operations. @@ -428,7 +557,6 @@ def applyInPandas( Examples -------- - >>> from pyspark.sql.functions import pandas_udf >>> df1 = spark.createDataFrame( ... [(20000101, 1, 1.0), (20000101, 2, 2.0), (20000102, 1, 3.0), (20000102, 2, 4.0)], ... ("time", "id", "v1")) @@ -495,6 +623,104 @@ def applyInPandas( jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr()) return DataFrame(jdf, self._gd1.session) + def applyInArrow( + self, func: "ArrowCogroupedMapFunction", schema: Union[StructType, str] + ) -> "DataFrame": + """ + Applies a function to each cogroup using Arrow and returns the result + as a `DataFrame`. + + The function should take two `pyarrow.Table`s and return another + `pyarrow.Table`. Alternatively, the user can pass a function that takes + a tuple of `pyarrow.Scalar` grouping key(s) and the two `pyarrow.Table`s. + For each side of the cogroup, all columns are passed together as a + `pyarrow.Table` to the user-function and the returned `pyarrow.Table` are combined as + a :class:`DataFrame`. + + The `schema` should be a :class:`StructType` describing the schema of the returned + `pyarrow.Table`. The column labels of the returned `pyarrow.Table` must either match + the field names in the defined schema if specified as strings, or match the + field data types by position if not strings, e.g. integer indices. + The length of the returned `pyarrow.Table` can be arbitrary. + + .. versionadded:: 4.0.0 + + Parameters + ---------- + func : function + a Python native function that takes two `pyarrow.Table`s, and + outputs a `pyarrow.Table`, or that takes one tuple (grouping keys) and two + ``pyarrow.Table``s, and outputs a ``pyarrow.Table``. + schema : :class:`pyspark.sql.types.DataType` or str + the return type of the `func` in PySpark. The value can be either a + :class:`pyspark.sql.types.DataType` object or a DDL-formatted type string. + + Examples + -------- + >>> import pyarrow # doctest: +SKIP + >>> df1 = spark.createDataFrame([(1, 1.0), (2, 2.0), (1, 3.0), (2, 4.0)], ("id", "v1")) + >>> df2 = spark.createDataFrame([(1, "x"), (2, "y")], ("id", "v2")) + >>> def summarize(l, r): + ... return pyarrow.Table.from_pydict({ + ... "left": [l.num_rows], + ... "right": [r.num_rows] + ... }) + >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow( + ... summarize, schema="left long, right long" + ... ).show() # doctest: +SKIP + +----+-----+ + |left|right| + +----+-----+ + | 2| 1| + | 2| 1| + +----+-----+ + + Alternatively, the user can define a function that takes three arguments. In this case, + the grouping key(s) will be passed as the first argument and the data will be passed as the + second and third arguments. The grouping key(s) will be passed as a tuple of Arrow scalar + types, e.g., `pyarrow.Int32Scalar` and `pyarrow.FloatScalar`. 
The data will still be passed + in as two `pyarrow.Table`s containing all columns from the original Spark DataFrames. + + >>> def summarize(key, l, r): + ... return pyarrow.Table.from_pydict({ + ... "key": [key[0].as_py()], + ... "left": [l.num_rows], + ... "right": [r.num_rows] + ... }) + >>> df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow( + ... summarize, schema="key long, left long, right long" + ... ).show() # doctest: +SKIP + +---+----+-----+ + |key|left|right| + +---+----+-----+ + | 1| 2| 1| + | 2| 2| 1| + +---+----+-----+ + + Notes + ----- + This function requires a full shuffle. All the data of a cogroup will be loaded + into memory, so the user should be aware of the potential OOM risk if data is skewed + and certain groups are too large to fit in memory. + + This API is experimental. + + See Also + -------- + pyspark.sql.functions.pandas_udf + """ + from pyspark.sql.pandas.functions import pandas_udf + + # The usage of the pandas_udf is internal so type checking is disabled. + udf = pandas_udf( + func, returnType=schema, functionType=PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF + ) # type: ignore[call-overload] + + all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2) + udf_column = udf(*all_cols) + jdf = self._gd1._jgd.flatMapCoGroupsInArrow(self._gd2._jgd, udf_column._jc.expr()) + return DataFrame(jdf, self._gd1.session) + @staticmethod def _extract_cols(gd: "GroupedData") -> List[Column]: df = gd._df diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 2cc3db15c9cd5..87b967bf91bd4 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -161,6 +161,49 @@ def wrap_and_init_stream(): return super(ArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream) +class ArrowStreamGroupUDFSerializer(ArrowStreamUDFSerializer): + """ + Serializes pyarrow.RecordBatch data with Arrow streaming format. + + Loads Arrow record batches as ``[[pyarrow.RecordBatch]]`` (one ``[pyarrow.RecordBatch]`` per + group) and serializes ``[([pyarrow.RecordBatch], arrow_type)]``. + + Parameters + ---------- + assign_cols_by_name : bool + If True, then DataFrames will get columns by name + """ + + def __init__(self, assign_cols_by_name): + super(ArrowStreamGroupUDFSerializer, self).__init__() + self._assign_cols_by_name = assign_cols_by_name + + def dump_stream(self, iterator, stream): + import pyarrow as pa + + # flatten inner list [([pa.RecordBatch], arrow_type)] into [(pa.RecordBatch, arrow_type)] + # so strip off inner iterator induced by ArrowStreamUDFSerializer.load_stream + batch_iter = ( + (batch, arrow_type) + for batches, arrow_type in iterator # tuple constructed in wrap_grouped_map_arrow_udf + for batch in batches + ) + + if self._assign_cols_by_name: + batch_iter = ( + ( + pa.RecordBatch.from_arrays( + [batch.column(field.name) for field in arrow_type], + names=[field.name for field in arrow_type], + ), + arrow_type, + ) + for batch, arrow_type in batch_iter + ) + + super(ArrowStreamGroupUDFSerializer, self).dump_stream(batch_iter, stream) + + class ArrowStreamPandasSerializer(ArrowStreamSerializer): """ Serializes pandas.Series as Arrow data with Arrow streaming format. @@ -618,7 +661,43 @@ def __repr__(self): return "ArrowStreamPandasUDTFSerializer" -class CogroupUDFSerializer(ArrowStreamPandasUDFSerializer): +class CogroupArrowUDFSerializer(ArrowStreamGroupUDFSerializer): + """ + Serializes pyarrow.RecordBatch data with Arrow streaming format. 
+ + Loads Arrow record batches as `[([pa.RecordBatch], [pa.RecordBatch])]` (one tuple per group) + and serializes `[([pa.RecordBatch], arrow_type)]`. + + Parameters + ---------- + assign_cols_by_name : bool + If True, then DataFrames will get columns by name + """ + + def __init__(self, assign_cols_by_name): + super(CogroupArrowUDFSerializer, self).__init__(assign_cols_by_name) + + def load_stream(self, stream): + """ + Deserialize Cogrouped ArrowRecordBatches and yield as two `pyarrow.RecordBatch`es. + """ + dataframes_in_group = None + + while dataframes_in_group is None or dataframes_in_group > 0: + dataframes_in_group = read_int(stream) + + if dataframes_in_group == 2: + batches1 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)] + batches2 = [batch for batch in ArrowStreamSerializer.load_stream(self, stream)] + yield batches1, batches2 + + elif dataframes_in_group != 0: + raise ValueError( + "Invalid number of dataframes in group {0}".format(dataframes_in_group) + ) + + +class CogroupPandasUDFSerializer(ArrowStreamPandasUDFSerializer): def load_stream(self, stream): """ Deserialize Cogrouped ArrowRecordBatches to a tuple of Arrow tables and yield as two diff --git a/python/pyspark/sql/tests/arrow/__init__.py b/python/pyspark/sql/tests/arrow/__init__.py new file mode 100644 index 0000000000000..cce3acad34a49 --- /dev/null +++ b/python/pyspark/sql/tests/arrow/__init__.py @@ -0,0 +1,16 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py new file mode 100644 index 0000000000000..0206d4c2c6ded --- /dev/null +++ b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py @@ -0,0 +1,300 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import os +import time +import unittest + +from pyspark.errors import PythonException +from pyspark.sql import Row +from pyspark.sql.functions import col +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pyarrow, + pyarrow_requirement_message, +) +from pyspark.testing.utils import QuietTest + + +if have_pyarrow: + import pyarrow as pa + import pyarrow.compute as pc + + +@unittest.skipIf( + not have_pyarrow, + pyarrow_requirement_message, # type: ignore[arg-type] +) +class CogroupedMapInArrowTests(ReusedSQLTestCase): + @property + def left(self): + return self.spark.range(0, 10, 2, 3).withColumn("v", col("id") * 10) + + @property + def right(self): + return self.spark.range(0, 10, 3, 3).withColumn("v", col("id") * 10) + + @property + def cogrouped(self): + grouped_left_df = self.left.groupBy((col("id") / 4).cast("int")) + grouped_right_df = self.right.groupBy((col("id") / 4).cast("int")) + return grouped_left_df.cogroup(grouped_right_df) + + @classmethod + def setUpClass(cls): + ReusedSQLTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.sc.environment["TZ"] = tz + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedSQLTestCase.tearDownClass() + + @staticmethod + def apply_in_arrow_func(left, right): + assert isinstance(left, pa.Table) + assert isinstance(right, pa.Table) + assert left.schema.names == ["id", "v"] + assert right.schema.names == ["id", "v"] + + left_ids = left.to_pydict()["id"] + right_ids = right.to_pydict()["id"] + result = { + "metric": ["min", "max", "len", "sum"], + "left": [min(left_ids), max(left_ids), len(left_ids), sum(left_ids)], + "right": [min(right_ids), max(right_ids), len(right_ids), sum(right_ids)], + } + return pa.Table.from_pydict(result) + + @staticmethod + def apply_in_arrow_with_key_func(key_column): + def func(key, left, right): + assert isinstance(key, tuple) + assert all(isinstance(scalar, pa.Scalar) for scalar in key) + if key_column: + assert all( + (pc.divide(k, pa.scalar(4)).cast(pa.int32()),) == key + for table in [left, right] + for k in table.column(key_column) + ) + return CogroupedMapInArrowTests.apply_in_arrow_func(left, right) + + return func + + @staticmethod + def apply_in_pandas_with_key_func(key_column): + def func(key, left, right): + return CogroupedMapInArrowTests.apply_in_arrow_with_key_func(key_column)( + tuple(pa.scalar(k) for k in key), + pa.Table.from_pandas(left), + pa.Table.from_pandas(right), + ).to_pandas() + + return func + + def do_test_apply_in_arrow(self, cogrouped_df, key_column="id"): + schema = "metric string, left long, right long" + + # compare with result of applyInPandas + expected = cogrouped_df.applyInPandas( + CogroupedMapInArrowTests.apply_in_pandas_with_key_func(key_column), schema + ) + + # apply in arrow without key + actual = cogrouped_df.applyInArrow( + CogroupedMapInArrowTests.apply_in_arrow_func, schema + ).collect() + self.assertEqual(actual, expected.collect()) + + # apply in arrow with key + actual2 = cogrouped_df.applyInArrow( + CogroupedMapInArrowTests.apply_in_arrow_with_key_func(key_column), schema + ).collect() + self.assertEqual(actual2, expected.collect()) + + def test_apply_in_arrow(self): + self.do_test_apply_in_arrow(self.cogrouped) + + def 
test_apply_in_arrow_empty_groupby(self): + grouped_left_df = self.left.groupBy() + grouped_right_df = self.right.groupBy() + cogrouped_df = grouped_left_df.cogroup(grouped_right_df) + self.do_test_apply_in_arrow(cogrouped_df, key_column=None) + + def test_apply_in_arrow_not_returning_arrow_table(self): + def func(key, left, right): + return key + + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + "Return type of the user-defined function should be pyarrow.Table, but is tuple", + ): + self.cogrouped.applyInArrow(func, schema="id long").collect() + + def test_apply_in_arrow_returning_wrong_types(self): + for schema, expected in [ + ("id integer, v long", "column 'id' \\(expected int32, actual int64\\)"), + ( + "id integer, v integer", + "column 'id' \\(expected int32, actual int64\\), " + "column 'v' \\(expected int32, actual int64\\)", + ), + ("id long, v integer", "column 'v' \\(expected int32, actual int64\\)"), + ("id long, v string", "column 'v' \\(expected string, actual int64\\)"), + ]: + with self.subTest(schema=schema): + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + f"Columns do not match in their data type: {expected}", + ): + self.cogrouped.applyInArrow( + lambda left, right: left, schema=schema + ).collect() + + def test_apply_in_arrow_returning_wrong_types_positional_assignment(self): + for schema, expected in [ + ("a integer, b long", "column 'a' \\(expected int32, actual int64\\)"), + ( + "a integer, b integer", + "column 'a' \\(expected int32, actual int64\\), " + "column 'b' \\(expected int32, actual int64\\)", + ), + ("a long, b int", "column 'b' \\(expected int32, actual int64\\)"), + ("a long, b string", "column 'b' \\(expected string, actual int64\\)"), + ]: + with self.subTest(schema=schema): + with self.sql_conf( + {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False} + ): + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + f"Columns do not match in their data type: {expected}", + ): + self.cogrouped.applyInArrow( + lambda left, right: left, schema=schema + ).collect() + + def test_apply_in_arrow_returning_wrong_column_names(self): + def stats(key, left, right): + # returning three columns + return pa.Table.from_pydict( + { + "id": [key[0].as_py()], + "v": [pc.mean(left.column("v")).as_py()], + "v2": [pc.stddev(right.column("v")).as_py()], + } + ) + + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + "Column names of the returned pyarrow.Table do not match specified schema. " + "Missing: m. 
Unexpected: v, v2.\n", + ): + # stats returns three columns while here we set schema with two columns + self.cogrouped.applyInArrow(stats, schema="id long, m double").collect() + + def test_apply_in_arrow_returning_empty_dataframe(self): + def odd_means(key, left, right): + if key[0].as_py() == 0: + return pa.table([]) + else: + return pa.Table.from_pydict( + { + "id": [key[0].as_py()], + "m": [pc.mean(left.column("v")).as_py()], + "n": [pc.mean(right.column("v")).as_py()], + } + ) + + schema = "id long, m double, n double" + actual = self.cogrouped.applyInArrow(odd_means, schema=schema).sort("id").collect() + expected = [Row(id=1, m=50.0, n=60.0), Row(id=2, m=80.0, n=90.0)] + self.assertEqual(expected, actual) + + def test_apply_in_arrow_returning_empty_dataframe_and_wrong_column_names(self): + def odd_means(key, left, _): + if key[0].as_py() % 2 == 0: + return pa.table([[]], names=["id"]) + else: + return pa.Table.from_pydict( + {"id": [key[0].as_py()], "m": [pc.mean(left.column("v")).as_py()]} + ) + + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + "Column names of the returned pyarrow.Table do not match specified schema. " + "Missing: m.\n", + ): + # stats returns one column for even keys while here we set schema with two columns + self.cogrouped.applyInArrow(odd_means, schema="id long, m double").collect() + + def test_apply_in_arrow_column_order(self): + df = self.left + expected = df.select(df.id, (df.v * 3).alias("u"), df.v).collect() + + # Function returns a table with required column names but different order + def change_col_order(left, _): + return left.append_column("u", pc.multiply(left.column("v"), 3)) + + # The result should assign columns by name from the table + result = ( + self.cogrouped.applyInArrow(change_col_order, "id long, u long, v long") + .sort("id", "v") + .select("id", "u", "v") + .collect() + ) + self.assertEqual(expected, result) + + def test_positional_assignment_conf(self): + with self.sql_conf( + {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False} + ): + + def foo(left, right): + return pa.Table.from_pydict({"x": ["hi"], "y": [1]}) + + result = self.cogrouped.applyInArrow(foo, "a string, b long").select("a", "b").collect() + for r in result: + self.assertEqual(r.a, "hi") + self.assertEqual(r.b, 1) + + +if __name__ == "__main__": + from pyspark.sql.tests.arrow.test_arrow_cogrouped_map import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py new file mode 100644 index 0000000000000..fa43648d42dcc --- /dev/null +++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py @@ -0,0 +1,291 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import os +import time +import unittest + +from pyspark.errors import PythonException +from pyspark.sql import Row +from pyspark.sql.functions import array, col, explode, lit, mean, stddev +from pyspark.sql.window import Window +from pyspark.testing.sqlutils import ( + ReusedSQLTestCase, + have_pyarrow, + pyarrow_requirement_message, +) +from pyspark.testing.utils import QuietTest + + +if have_pyarrow: + import pyarrow as pa + import pyarrow.compute as pc + + +@unittest.skipIf( + not have_pyarrow, + pyarrow_requirement_message, # type: ignore[arg-type] +) +class GroupedMapInArrowTests(ReusedSQLTestCase): + @property + def data(self): + return ( + self.spark.range(10) + .toDF("id") + .withColumn("vs", array([lit(i) for i in range(20, 30)])) + .withColumn("v", explode(col("vs"))) + .drop("vs") + ) + + @classmethod + def setUpClass(cls): + ReusedSQLTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.sc.environment["TZ"] = tz + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedSQLTestCase.tearDownClass() + + def test_apply_in_arrow(self): + def func(group): + assert isinstance(group, pa.Table) + assert group.schema.names == ["id", "value"] + return group + + df = self.spark.range(10).withColumn("value", col("id") * 10) + grouped_df = df.groupBy((col("id") / 4).cast("int")) + expected = df.collect() + + actual = grouped_df.applyInArrow(func, "id long, value long").collect() + self.assertEqual(actual, expected) + + def test_apply_in_arrow_with_key(self): + def func(key, group): + assert isinstance(key, tuple) + assert all(isinstance(scalar, pa.Scalar) for scalar in key) + assert isinstance(group, pa.Table) + assert group.schema.names == ["id", "value"] + assert all( + (pc.divide(k, pa.scalar(4)).cast(pa.int32()),) == key for k in group.column("id") + ) + return group + + df = self.spark.range(10).withColumn("value", col("id") * 10) + grouped_df = df.groupBy((col("id") / 4).cast("int")) + expected = df.collect() + + actual2 = grouped_df.applyInArrow(func, "id long, value long").collect() + self.assertEqual(actual2, expected) + + def test_apply_in_arrow_empty_groupby(self): + df = self.data + + def normalize(table): + v = table.column("v") + return table.set_column( + 1, "v", pc.divide(pc.subtract(v, pc.mean(v)), pc.stddev(v, ddof=1)) + ) + + # casting doubles to floats to get rid of numerical precision issues + # when comparing Arrow and Spark values + actual = ( + df.groupby() + .applyInArrow(normalize, "id long, v double") + .withColumn("v", col("v").cast("float")) + .sort("id", "v") + ) + windowSpec = Window.partitionBy() + expected = df.withColumn( + "v", + ((df.v - mean(df.v).over(windowSpec)) / stddev(df.v).over(windowSpec)).cast("float"), + ) + self.assertEqual(actual.collect(), expected.collect()) + + def test_apply_in_arrow_not_returning_arrow_table(self): 
+ df = self.data + + def stats(key, _): + return key + + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + "Return type of the user-defined function should be pyarrow.Table, but is tuple", + ): + df.groupby("id").applyInArrow(stats, schema="id long, m double").collect() + + def test_apply_in_arrow_returning_wrong_types(self): + df = self.data + + for schema, expected in [ + ("id integer, v integer", "column 'id' \\(expected int32, actual int64\\)"), + ( + "id integer, v long", + "column 'id' \\(expected int32, actual int64\\), " + "column 'v' \\(expected int64, actual int32\\)", + ), + ("id long, v long", "column 'v' \\(expected int64, actual int32\\)"), + ("id long, v string", "column 'v' \\(expected string, actual int32\\)"), + ]: + with self.subTest(schema=schema): + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + f"Columns do not match in their data type: {expected}", + ): + df.groupby("id").applyInArrow(lambda table: table, schema=schema).collect() + + def test_apply_in_arrow_returning_wrong_types_positional_assignment(self): + df = self.data + + for schema, expected in [ + ("a integer, b integer", "column 'a' \\(expected int32, actual int64\\)"), + ( + "a integer, b long", + "column 'a' \\(expected int32, actual int64\\), " + "column 'b' \\(expected int64, actual int32\\)", + ), + ("a long, b long", "column 'b' \\(expected int64, actual int32\\)"), + ("a long, b string", "column 'b' \\(expected string, actual int32\\)"), + ]: + with self.subTest(schema=schema): + with self.sql_conf( + {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False} + ): + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + f"Columns do not match in their data type: {expected}", + ): + df.groupby("id").applyInArrow( + lambda table: table, schema=schema + ).collect() + + def test_apply_in_arrow_returning_wrong_column_names(self): + df = self.data + + def stats(key, table): + # returning three columns + return pa.Table.from_pydict( + { + "id": [key[0].as_py()], + "v": [pc.mean(table.column("v")).as_py()], + "v2": [pc.stddev(table.column("v")).as_py()], + } + ) + + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + "Column names of the returned pyarrow.Table do not match specified schema. " + "Missing: m. Unexpected: v, v2.\n", + ): + # stats returns three columns while here we set schema with two columns + df.groupby("id").applyInArrow(stats, schema="id long, m double").collect() + + def test_apply_in_arrow_returning_empty_dataframe(self): + df = self.data + + def odd_means(key, table): + if key[0].as_py() % 2 == 0: + return pa.table([]) + else: + return pa.Table.from_pydict( + {"id": [key[0].as_py()], "m": [pc.mean(table.column("v")).as_py()]} + ) + + schema = "id long, m double" + actual = df.groupby("id").applyInArrow(odd_means, schema=schema).sort("id").collect() + expected = [Row(id=id, m=24.5) for id in range(1, 10, 2)] + self.assertEqual(expected, actual) + + def test_apply_in_arrow_returning_empty_dataframe_and_wrong_column_names(self): + df = self.data + + def odd_means(key, table): + if key[0].as_py() % 2 == 0: + return pa.table([[]], names=["id"]) + else: + return pa.Table.from_pydict( + {"id": [key[0].as_py()], "m": [pc.mean(table.column("v")).as_py()]} + ) + + with QuietTest(self.sc): + with self.assertRaisesRegex( + PythonException, + "Column names of the returned pyarrow.Table do not match specified schema. 
" + "Missing: m.\n", + ): + # stats returns one column for even keys while here we set schema with two columns + df.groupby("id").applyInArrow(odd_means, schema="id long, m double").collect() + + def test_apply_in_arrow_column_order(self): + df = self.data + grouped_df = df.groupby("id") + expected = df.select(df.id, (df.v * 3).alias("u"), df.v).collect() + + # Function returns a table with required column names but different order + def change_col_order(table): + return table.append_column("u", pc.multiply(table.column("v"), 3)) + + # The result should assign columns by name from the table + result = ( + grouped_df.applyInArrow(change_col_order, "id long, u long, v int") + .sort("id", "v") + .select("id", "u", "v") + .collect() + ) + self.assertEqual(expected, result) + + def test_positional_assignment_conf(self): + with self.sql_conf( + {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False} + ): + + def foo(_): + return pa.Table.from_pydict({"x": ["hi"], "y": [1]}) + + df = self.data + result = ( + df.groupBy("id").applyInArrow(foo, "a string, b long").select("a", "b").collect() + ) + for r in result: + self.assertEqual(r.a, "hi") + self.assertEqual(r.b, 1) + + +if __name__ == "__main__": + from pyspark.sql.tests.arrow.test_arrow_grouped_map import * # noqa: F401 + + try: + import xmlrunner # type: ignore[import] + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index bdd3aba502b89..b915fcae98f09 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -274,6 +274,26 @@ def returnType(self) -> DataType: "return_type": str(self._returnType_placeholder), }, ) + elif self.evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF: + if isinstance(self._returnType_placeholder, StructType): + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise PySparkNotImplementedError( + error_class="NOT_IMPLEMENTED", + message_parameters={ + "feature": "Invalid return type with grouped map Arrow UDFs or " + f"at groupby.applyInArrow: {self._returnType_placeholder}" + }, + ) + else: + raise PySparkTypeError( + error_class="INVALID_RETURN_TYPE_FOR_ARROW_UDF", + message_parameters={ + "eval_type": "SQL_GROUPED_MAP_ARROW_UDF", + "return_type": str(self._returnType_placeholder), + }, + ) elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: if isinstance(self._returnType_placeholder, StructType): try: @@ -294,6 +314,26 @@ def returnType(self) -> DataType: "return_type": str(self._returnType_placeholder), }, ) + elif self.evalType == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: + if isinstance(self._returnType_placeholder, StructType): + try: + to_arrow_type(self._returnType_placeholder) + except TypeError: + raise PySparkNotImplementedError( + error_class="NOT_IMPLEMENTED", + message_parameters={ + "feature": "Invalid return type in cogroup.applyInArrow: " + f"{self._returnType_placeholder}" + }, + ) + else: + raise PySparkTypeError( + error_class="INVALID_RETURN_TYPE_FOR_ARROW_UDF", + message_parameters={ + "eval_type": "SQL_COGROUPED_MAP_ARROW_UDF", + "return_type": str(self._returnType_placeholder), + }, + ) elif self.evalType == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: try: # StructType is not yet allowed as a return type, explicitly check here to fail fast diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 
90b11d0623166..4becda459bcd9 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -58,8 +58,10 @@ from pyspark.sql.pandas.serializers import ( ArrowStreamPandasUDFSerializer, ArrowStreamPandasUDTFSerializer, - CogroupUDFSerializer, + CogroupArrowUDFSerializer, + CogroupPandasUDFSerializer, ArrowStreamUDFSerializer, + ArrowStreamGroupUDFSerializer, ApplyInPandasWithStateSerializer, ) from pyspark.sql.pandas.types import to_arrow_type @@ -306,6 +308,33 @@ def verify_element(elem): ) +def wrap_cogrouped_map_arrow_udf(f, return_type, argspec, runner_conf): + _assign_cols_by_name = assign_cols_by_name(runner_conf) + + if _assign_cols_by_name: + expected_cols_and_types = { + col.name: to_arrow_type(col.dataType) for col in return_type.fields + } + else: + expected_cols_and_types = [ + (col.name, to_arrow_type(col.dataType)) for col in return_type.fields + ] + + def wrapped(left_key_table, left_value_table, right_key_table, right_value_table): + if len(argspec.args) == 2: + result = f(left_value_table, right_value_table) + elif len(argspec.args) == 3: + key_table = left_key_table if left_key_table.num_rows > 0 else right_key_table + key = tuple(c[0] for c in key_table.columns) + result = f(key, left_value_table, right_value_table) + + verify_arrow_result(result, _assign_cols_by_name, expected_cols_and_types) + + return result.to_batches() + + return lambda kl, vl, kr, vr: (wrapped(kl, vl, kr, vr), to_arrow_type(return_type)) + + def wrap_cogrouped_map_pandas_udf(f, return_type, argspec, runner_conf): _assign_cols_by_name = assign_cols_by_name(runner_conf) @@ -330,6 +359,104 @@ def wrapped(left_key_series, left_value_series, right_key_series, right_value_se return lambda kl, vl, kr, vr: [(wrapped(kl, vl, kr, vr), to_arrow_type(return_type))] +def verify_arrow_result(table, assign_cols_by_name, expected_cols_and_types): + import pyarrow as pa + + if not isinstance(table, pa.Table): + raise PySparkTypeError( + error_class="UDF_RETURN_TYPE", + message_parameters={ + "expected": "pyarrow.Table", + "actual": type(table).__name__, + }, + ) + + # the types of the fields have to be identical to return type + # an empty table can have no columns; if there are columns, they have to match + if table.num_columns != 0 or table.num_rows != 0: + # columns are either mapped by name or position + if assign_cols_by_name: + actual_cols_and_types = { + name: dataType for name, dataType in zip(table.schema.names, table.schema.types) + } + missing = sorted( + list(set(expected_cols_and_types.keys()).difference(actual_cols_and_types.keys())) + ) + extra = sorted( + list(set(actual_cols_and_types.keys()).difference(expected_cols_and_types.keys())) + ) + + if missing or extra: + missing = f" Missing: {', '.join(missing)}." if missing else "" + extra = f" Unexpected: {', '.join(extra)}." 
if extra else "" + + raise PySparkRuntimeError( + error_class="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDF", + message_parameters={ + "missing": missing, + "extra": extra, + }, + ) + + column_types = [ + (name, expected_cols_and_types[name], actual_cols_and_types[name]) + for name in sorted(expected_cols_and_types.keys()) + ] + else: + actual_cols_and_types = [ + (name, dataType) for name, dataType in zip(table.schema.names, table.schema.types) + ] + column_types = [ + (expected_name, expected_type, actual_type) + for (expected_name, expected_type), (actual_name, actual_type) in zip( + expected_cols_and_types, actual_cols_and_types + ) + ] + + type_mismatch = [ + (name, expected, actual) + for name, expected, actual in column_types + if actual != expected + ] + + if type_mismatch: + raise PySparkRuntimeError( + error_class="RESULT_TYPE_MISMATCH_FOR_ARROW_UDF", + message_parameters={ + "mismatch": ", ".join( + "column '{}' (expected {}, actual {})".format(name, expected, actual) + for name, expected, actual in type_mismatch + ) + }, + ) + + +def wrap_grouped_map_arrow_udf(f, return_type, argspec, runner_conf): + _assign_cols_by_name = assign_cols_by_name(runner_conf) + + if _assign_cols_by_name: + expected_cols_and_types = { + col.name: to_arrow_type(col.dataType) for col in return_type.fields + } + else: + expected_cols_and_types = [ + (col.name, to_arrow_type(col.dataType)) for col in return_type.fields + ] + + def wrapped(key_table, value_table): + if len(argspec.args) == 1: + result = f(value_table) + elif len(argspec.args) == 2: + key = tuple(c[0] for c in key_table.columns) + result = f(key, value_table) + + verify_arrow_result(result, _assign_cols_by_name, expected_cols_and_types) + + return result.to_batches() + + return lambda k, v: (wrapped(k, v), to_arrow_type(return_type)) + + def wrap_grouped_map_pandas_udf(f, return_type, argspec, runner_conf): _assign_cols_by_name = assign_cols_by_name(runner_conf) @@ -555,12 +682,18 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): return arg_offsets, wrap_arrow_batch_iter_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it - return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) + return args_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec, runner_conf) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF: + argspec = getfullargspec(chained_func) # signature was lost when wrapping it + return args_offsets, wrap_grouped_map_arrow_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: return arg_offsets, wrap_grouped_map_pandas_udf_with_state(func, return_type) elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: argspec = getfullargspec(chained_func) # signature was lost when wrapping it - return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf) + return args_offsets, wrap_cogrouped_map_pandas_udf(func, return_type, argspec, runner_conf) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: + argspec = getfullargspec(chained_func) # signature was lost when wrapping it + return args_offsets, wrap_cogrouped_map_arrow_udf(func, return_type, argspec, runner_conf) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == 
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF: @@ -571,7 +704,10 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): raise ValueError("Unknown eval type: {}".format(eval_type)) -# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when +# Used by SQL_GROUPED_MAP_PANDAS_UDF, SQL_GROUPED_MAP_ARROW_UDF, +# SQL_COGROUPED_MAP_PANDAS_UDF, SQL_COGROUPED_MAP_ARROW_UDF, +# SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, +# SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when # returning StructType def assign_cols_by_name(runner_conf): return ( @@ -831,6 +967,8 @@ def read_udfs(pickleSer, infile, eval_type): PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF, PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE, + PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF, + PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF, ): # Load conf used for pandas_udf evaluation @@ -850,9 +988,12 @@ def read_udfs(pickleSer, infile, eval_type): runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower() == "true" ) + _assign_cols_by_name = assign_cols_by_name(runner_conf) - if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: - ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name(runner_conf)) + if eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: + ser = CogroupArrowUDFSerializer(_assign_cols_by_name) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: + ser = CogroupPandasUDFSerializer(timezone, safecheck, _assign_cols_by_name) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: arrow_max_records_per_batch = runner_conf.get( "spark.sql.execution.arrow.maxRecordsPerBatch", 10000 @@ -862,12 +1003,14 @@ def read_udfs(pickleSer, infile, eval_type): ser = ApplyInPandasWithStateSerializer( timezone, safecheck, - assign_cols_by_name(runner_conf), + _assign_cols_by_name, state_object_schema, arrow_max_records_per_batch, ) elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF: ser = ArrowStreamUDFSerializer() + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF: + ser = ArrowStreamGroupUDFSerializer(_assign_cols_by_name) else: # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of # pandas Series. See SPARK-27240. @@ -886,7 +1029,7 @@ def read_udfs(pickleSer, infile, eval_type): ser = ArrowStreamPandasUDFSerializer( timezone, safecheck, - assign_cols_by_name(runner_conf), + _assign_cols_by_name, df_for_struct, struct_in_pandas, ndarray_as_list, @@ -1009,6 +1152,32 @@ def mapper(a): vals = [a[o] for o in parsed_offsets[0][1]] return f(keys, vals) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF: + import pyarrow as pa + + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. 
+ assert num_udfs == 1 + + # See FlatMapGroupsInPandasExec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes + arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + parsed_offsets = extract_key_value_indexes(arg_offsets) + + def batch_from_offset(batch, offsets): + return pa.RecordBatch.from_arrays( + arrays=[batch.columns[o] for o in offsets], + names=[batch.schema.names[o] for o in offsets], + ) + + def table_from_batches(batches, offsets): + return pa.Table.from_batches([batch_from_offset(batch, offsets) for batch in batches]) + + def mapper(a): + keys = table_from_batches(a, parsed_offsets[0][0]) + vals = table_from_batches(a, parsed_offsets[0][1]) + return f(keys, vals) + elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: # We assume there is only one UDF here because grouped map doesn't # support combining multiple UDFs. @@ -1061,6 +1230,32 @@ def mapper(a): df2_vals = [a[1][o] for o in parsed_offsets[1][1]] return f(df1_keys, df1_vals, df2_keys, df2_vals) + elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF: + import pyarrow as pa + + # We assume there is only one UDF here because cogrouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0) + + parsed_offsets = extract_key_value_indexes(arg_offsets) + + def batch_from_offset(batch, offsets): + return pa.RecordBatch.from_arrays( + arrays=[batch.columns[o] for o in offsets], + names=[batch.schema.names[o] for o in offsets], + ) + + def table_from_batches(batches, offsets): + return pa.Table.from_batches([batch_from_offset(batch, offsets) for batch in batches]) + + def mapper(a): + df1_keys = table_from_batches(a[0], parsed_offsets[0][0]) + df1_vals = table_from_batches(a[0], parsed_offsets[0][1]) + df2_keys = table_from_batches(a[1], parsed_offsets[1][0]) + df2_vals = table_from_batches(a[1], parsed_offsets[1][1]) + return f(df1_keys, df1_vals, df2_keys, df2_vals) + else: udfs = [] for i in range(num_udfs): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala index 593ed619cb32b..f5930c5272a2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -46,6 +46,29 @@ case class FlatMapGroupsInPandas( copy(child = newChild) } +/** + * FlatMap groups using a udf: iter(pyarrow.RecordBatch) -> iter(pyarrow.RecordBatch). + * This is used by DataFrame.groupby().applyInArrow(). + */ +case class FlatMapGroupsInArrow( + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + + /** + * This is needed because output attributes are considered `references` when + * passed through the constructor. + * + * Without this, catalyst will complain that output attributes are missing + * from the input. + */ + override val producedAttributes = AttributeSet(output) + + override protected def withNewChildInternal(newChild: LogicalPlan): FlatMapGroupsInArrow = + copy(child = newChild) +} + /** * Map partitions using a udf: iter(pandas.Dataframe) -> iter(pandas.DataFrame). 
* This is used by DataFrame.mapInPandas() @@ -135,6 +158,31 @@ case class FlatMapGroupsInPandasWithState( newChild: LogicalPlan): FlatMapGroupsInPandasWithState = copy(child = newChild) } +/** + * Flatmap cogroups using a udf: iter(pyarrow.RecordBatch) -> iter(pyarrow.RecordBatch) + * This is used by DataFrame.groupby().cogroup().applyInArrow(). + */ +case class FlatMapCoGroupsInArrow( + leftGroupingLen: Int, + rightGroupingLen: Int, + functionExpr: Expression, + output: Seq[Attribute], + left: LogicalPlan, + right: LogicalPlan) extends BinaryNode { + + override val producedAttributes = AttributeSet(output) + override lazy val references: AttributeSet = + AttributeSet(leftAttributes ++ rightAttributes ++ functionExpr.references) -- producedAttributes + + def leftAttributes: Seq[Attribute] = left.output.take(leftGroupingLen) + + def rightAttributes: Seq[Attribute] = right.output.take(rightGroupingLen) + + override protected def withNewChildrenInternal( + newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapCoGroupsInArrow = + copy(left = newLeft, right = newRight) +} + trait BaseEvalPython extends UnaryNode { def udfs: Seq[PythonUDF] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 11327cdf7d1d3..877e54a4f1bd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -552,7 +552,7 @@ class RelationalGroupedDataset protected[sql]( */ private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - "Must pass a grouped map udf") + "Must pass a grouped map pandas udf") require(expr.dataType.isInstanceOf[StructType], s"The returnType of the udf must be a ${StructType.simpleString}") @@ -570,6 +570,38 @@ class RelationalGroupedDataset protected[sql]( Dataset.ofRows(df.sparkSession, plan) } + /** + * Applies a grouped vectorized python user-defined function to each group of data. + * The user-defined function defines a transformation: `pyarrow.Table` -> `pyarrow.Table`. + * For each group, all elements in the group are passed as a `pyarrow.Table` and the results + * for all groups are combined into a new [[DataFrame]]. + * + * This function does not support partial aggregation, and requires shuffling all the data in + * the [[DataFrame]]. + * + * This function uses Apache Arrow as serialization format between Java executors and Python + * workers. 
+   */
+  private[sql] def flatMapGroupsInArrow(expr: PythonUDF): DataFrame = {
+    require(expr.evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+      "Must pass a grouped map arrow udf")
+    require(expr.dataType.isInstanceOf[StructType],
+      s"The returnType of the udf must be a ${StructType.simpleString}")
+
+    val groupingNamedExpressions = groupingExprs.map {
+      case ne: NamedExpression => ne
+      case other => Alias(other, other.toString)()
+    }
+    val child = df.logicalPlan
+    val project = df.sparkSession.sessionState.executePlan(
+      Project(groupingNamedExpressions ++ child.output, child)).analyzed
+    val groupingAttributes = project.output.take(groupingNamedExpressions.length)
+    val output = toAttributes(expr.dataType.asInstanceOf[StructType])
+    val plan = FlatMapGroupsInArrow(groupingAttributes, expr, output, project)
+
+    Dataset.ofRows(df.sparkSession, plan)
+  }
+
   /**
    * Applies a vectorized python user-defined function to each cogrouped data.
    * The user-defined function defines a transformation:
@@ -584,7 +616,7 @@ class RelationalGroupedDataset protected[sql](
       r: RelationalGroupedDataset,
       expr: PythonUDF): DataFrame = {
     require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
-      "Must pass a cogrouped map udf")
+      "Must pass a cogrouped map pandas udf")
     require(this.groupingExprs.length == r.groupingExprs.length,
       "Cogroup keys must have same size: " +
         s"${this.groupingExprs.length} != ${r.groupingExprs.length}")
@@ -616,6 +648,52 @@ class RelationalGroupedDataset protected[sql](
     Dataset.ofRows(df.sparkSession, plan)
   }
 
+  /**
+   * Applies a vectorized python user-defined function to each cogrouped data.
+   * The user-defined function defines a transformation:
+   * `pyarrow.Table`, `pyarrow.Table` -> `pyarrow.Table`.
+   * For each group in the cogrouped data, all elements in the group are passed as a
+   * `pyarrow.Table` and the results for all cogroups are combined into a new [[DataFrame]].
+   *
+   * This function uses Apache Arrow as serialization format between Java executors and Python
+   * workers.
+   */
+  private[sql] def flatMapCoGroupsInArrow(
+      r: RelationalGroupedDataset,
+      expr: PythonUDF): DataFrame = {
+    require(expr.evalType == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
+      "Must pass a cogrouped map arrow udf")
+    require(this.groupingExprs.length == r.groupingExprs.length,
+      "Cogroup keys must have same size: " +
+        s"${this.groupingExprs.length} != ${r.groupingExprs.length}")
+    require(expr.dataType.isInstanceOf[StructType],
+      s"The returnType of the udf must be a ${StructType.simpleString}")
+
+    val leftGroupingNamedExpressions = groupingExprs.map {
+      case ne: NamedExpression => ne
+      case other => Alias(other, other.toString)()
+    }
+
+    val rightGroupingNamedExpressions = r.groupingExprs.map {
+      case ne: NamedExpression => ne
+      case other => Alias(other, other.toString)()
+    }
+
+    val leftChild = df.logicalPlan
+    val rightChild = r.df.logicalPlan
+
+    val left = df.sparkSession.sessionState.executePlan(
+      Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)).analyzed
+    val right = r.df.sparkSession.sessionState.executePlan(
+      Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)).analyzed
+
+    val output = toAttributes(expr.dataType.asInstanceOf[StructType])
+    val plan = FlatMapCoGroupsInArrow(
+      leftGroupingNamedExpressions.length, rightGroupingNamedExpressions.length,
+      expr, output, left, right)
+    Dataset.ofRows(df.sparkSession, plan)
+  }
+
   /**
    * Applies a grouped vectorized python user-defined function to each group of data.
 * The user-defined function defines a transformation: iterator of `pandas.DataFrame` ->
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index d851eacd5ab92..983c6a653d4dc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -825,10 +825,16 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           f, p, b, is, ot, planLater(child)) :: Nil
       case logical.FlatMapGroupsInPandas(grouping, func, output, child) =>
         execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil
+      case logical.FlatMapGroupsInArrow(grouping, func, output, child) =>
+        execution.python.FlatMapGroupsInArrowExec(grouping, func, output, planLater(child)) :: Nil
       case f @ logical.FlatMapCoGroupsInPandas(_, _, func, output, left, right) =>
         execution.python.FlatMapCoGroupsInPandasExec(
           f.leftAttributes, f.rightAttributes, func, output, planLater(left), planLater(right)) :: Nil
+      case f @ logical.FlatMapCoGroupsInArrow(_, _, func, output, left, right) =>
+        execution.python.FlatMapCoGroupsInArrowExec(
+          f.leftAttributes, f.rightAttributes,
+          func, output, planLater(left), planLater(right)) :: Nil
       case logical.MapInPandas(func, output, child, isBarrier) =>
         execution.python.MapInPandasExec(func, output, planLater(child), isBarrier) :: Nil
       case logical.PythonMapInArrow(func, output, child, isBarrier) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala
new file mode 100644
index 0000000000000..17c68a86b7592
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInArrowExec.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+
+
+/**
+ * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapCoGroupsInArrow]]
+ *
+ * The input dataframes are first Cogrouped. Rows from each side of the cogroup are passed to the
+ * Python worker via Arrow. As each side of the cogroup may have a different schema we send every
+ * group in its own Arrow stream.
+ * The Python worker turns the resulting record batches to `pyarrow.Table`s, invokes the
+ * user-defined function, and passes the resulting `pyarrow.Table`
+ * as an Arrow record batch. Finally, each record batch is turned to
+ * Iterator[InternalRow] using ColumnarBatch.
+ *
+ * Note on memory usage:
+ * Both the Python worker and the Java executor need to have enough memory to
+ * hold the largest cogroup. The memory on the Java side is used to construct the
+ * record batches (off heap memory). The memory on the Python side is used for
+ * holding the `pyarrow.Table`. It's possible to further split one group into
+ * multiple record batches to reduce the memory footprint on the Java side, this
+ * is left as future work.
+ */
+case class FlatMapCoGroupsInArrowExec(
+    leftGroup: Seq[Attribute],
+    rightGroup: Seq[Attribute],
+    func: Expression,
+    output: Seq[Attribute],
+    left: SparkPlan,
+    right: SparkPlan)
+  extends FlatMapCoGroupsInPythonExec {
+
+  protected val pythonEvalType: Int = PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): FlatMapCoGroupsInArrowExec =
+    copy(left = newLeft, right = newRight)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
index bbfe97d194778..32d7748bcaac5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -17,15 +17,9 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.JobArtifactSet
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
-import org.apache.spark.sql.execution.python.PandasGroupUtils._
+import org.apache.spark.sql.execution.SparkPlan
 
 /**
@@ -54,57 +48,9 @@ case class FlatMapCoGroupsInPandasExec(
     output: Seq[Attribute],
     left: SparkPlan,
     right: SparkPlan)
-  extends SparkPlan with BinaryExecNode with PythonSQLMetrics {
+  extends FlatMapCoGroupsInPythonExec {
 
-  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-  private val pandasFunction = func.asInstanceOf[PythonUDF].func
-  private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction)))
-
-  override def producedAttributes: AttributeSet = AttributeSet(output)
-
-  override def outputPartitioning: Partitioning = left.outputPartitioning
-
-  override def requiredChildDistribution: Seq[Distribution] = {
-    val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup)
-    val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup)
-    leftDist :: rightDist :: Nil
-  }
-
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
-    leftGroup
-      .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
-  }
-
-  override protected def doExecute(): RDD[InternalRow] = {
-    val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup)
-    val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup)
-    val
jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) - - // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty - left.execute().zipPartitions(right.execute()) { (leftData, rightData) => - if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else { - - val leftGrouped = groupAndProject(leftData, leftGroup, left.output, leftDedup) - val rightGrouped = groupAndProject(rightData, rightGroup, right.output, rightDedup) - val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) - .map { case (_, l, r) => (l, r) } - - val runner = new CoGroupedArrowPythonRunner( - chainedFunc, - PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, - Array(leftArgOffsets ++ rightArgOffsets), - DataTypeUtils.fromAttributes(leftDedup), - DataTypeUtils.fromAttributes(rightDedup), - sessionLocalTimeZone, - pythonRunnerConf, - pythonMetrics, - jobArtifactUUID) - - executePython(data, output, runner) - } - } - } + protected val pythonEvalType: Int = PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF override protected def withNewChildrenInternal( newLeft: SparkPlan, newRight: SparkPlan): FlatMapCoGroupsInPandasExec = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPythonExec.scala new file mode 100644 index 0000000000000..f75b0019f1065 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPythonExec.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.JobArtifactSet +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan} +import org.apache.spark.sql.execution.python.PandasGroupUtils._ + + +/** + * Base class for Python-based FlatMapCoGroupsIn*Exec. 
+ */ +trait FlatMapCoGroupsInPythonExec extends SparkPlan with BinaryExecNode with PythonSQLMetrics { + val leftGroup: Seq[Attribute] + val rightGroup: Seq[Attribute] + val func: Expression + val output: Seq[Attribute] + val left: SparkPlan + val right: SparkPlan + + protected val pythonEvalType: Int + + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) + private val pandasFunction = func.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup) + val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup) + leftDist :: rightDist :: Nil + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = { + leftGroup + .map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil + } + + override protected def doExecute(): RDD[InternalRow] = { + val (leftDedup, leftArgOffsets) = resolveArgOffsets(left.output, leftGroup) + val (rightDedup, rightArgOffsets) = resolveArgOffsets(right.output, rightGroup) + val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + // Map cogrouped rows to ArrowPythonRunner results, Only execute if partition is not empty + left.execute().zipPartitions(right.execute()) { (leftData, rightData) => + if (leftData.isEmpty && rightData.isEmpty) Iterator.empty else { + + val leftGrouped = groupAndProject(leftData, leftGroup, left.output, leftDedup) + val rightGrouped = groupAndProject(rightData, rightGroup, right.output, rightDedup) + val data = new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup) + .map { case (_, l, r) => (l, r) } + + val runner = new CoGroupedArrowPythonRunner( + chainedFunc, + pythonEvalType, + Array(leftArgOffsets ++ rightArgOffsets), + DataTypeUtils.fromAttributes(leftDedup), + DataTypeUtils.fromAttributes(rightDedup), + sessionLocalTimeZone, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID) + + executePython(data, output, runner) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala new file mode 100644 index 0000000000000..b0dd800af8f74 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInArrowExec.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.execution.python
+
+import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.types.{StructField, StructType}
+
+
+/**
+ * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInArrow]]
+ *
+ * Rows in each group are passed to the Python worker as an Arrow record batch.
+ * The Python worker turns the record batch to a `pyarrow.Table`, invokes the
+ * user-defined function, and passes the resulting `pyarrow.Table`
+ * as an Arrow record batch. Finally, each record batch is turned to
+ * Iterator[InternalRow] using ColumnarBatch.
+ *
+ * Note on memory usage:
+ * Both the Python worker and the Java executor need to have enough memory to
+ * hold the largest group. The memory on the Java side is used to construct the
+ * record batch (off heap memory). The memory on the Python side is used for
+ * holding the `pyarrow.Table`. It's possible to further split one group into
+ * multiple record batches to reduce the memory footprint on the Java side, this
+ * is left as future work.
+ */
+case class FlatMapGroupsInArrowExec(
+    groupingAttributes: Seq[Attribute],
+    func: Expression,
+    output: Seq[Attribute],
+    child: SparkPlan)
+  extends FlatMapGroupsInPythonExec {
+
+  protected val pythonEvalType: Int = PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF
+
+  override protected def groupedData(iter: Iterator[InternalRow], attrs: Seq[Attribute]):
+      Iterator[Iterator[InternalRow]] =
+    super.groupedData(iter, attrs)
+      // Here we wrap each row in another row so that the Python side understands it as a DataFrame.
+      .map(_.map(InternalRow(_)))
+
+  override protected def groupedSchema(attrs: Seq[Attribute]): StructType =
+    StructType(StructField("struct", super.groupedSchema(attrs)) :: Nil)
+
+  override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInArrowExec =
+    copy(child = newChild)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
index f2d21ce8e9646..8874789972048 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala
@@ -17,15 +17,9 @@
 
 package org.apache.spark.sql.execution.python
 
-import org.apache.spark.JobArtifactSet
-import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.api.python.PythonEvalType
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
-import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.python.PandasGroupUtils._
+import org.apache.spark.sql.execution.SparkPlan
 
 /**
@@ -50,55 +44,9 @@ case class FlatMapGroupsInPandasExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan)
-  extends SparkPlan with UnaryExecNode with PythonSQLMetrics {
+  extends FlatMapGroupsInPythonExec {
 
-  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val largeVarTypes =
conf.arrowUseLargeVarTypes - private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) - private val pandasFunction = func.asInstanceOf[PythonUDF].func - private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) - - override def producedAttributes: AttributeSet = AttributeSet(output) - - override def outputPartitioning: Partitioning = child.outputPartitioning - - override def requiredChildDistribution: Seq[Distribution] = { - if (groupingAttributes.isEmpty) { - AllTuples :: Nil - } else { - ClusteredDistribution(groupingAttributes) :: Nil - } - } - - override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) - - override protected def doExecute(): RDD[InternalRow] = { - val inputRDD = child.execute() - - val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, groupingAttributes) - - // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty - inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { - - val data = groupAndProject(iter, groupingAttributes, child.output, dedupAttributes) - .map { case (_, x) => x } - - val runner = new ArrowPythonRunner( - chainedFunc, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - Array(argOffsets), - DataTypeUtils.fromAttributes(dedupAttributes), - sessionLocalTimeZone, - largeVarTypes, - pythonRunnerConf, - pythonMetrics, - jobArtifactUUID) - - executePython(data, output, runner) - }} - } + protected val pythonEvalType: Int = PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF override protected def withNewChildInternal(newChild: SparkPlan): FlatMapGroupsInPandasExec = copy(child = newChild) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala new file mode 100644 index 0000000000000..0c18206a825aa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.python + +import org.apache.spark.JobArtifactSet +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.python.PandasGroupUtils._ +import org.apache.spark.sql.types.StructType + + +/** + * Base class for Python-based FlatMapGroupsIn*Exec. + */ +trait FlatMapGroupsInPythonExec extends SparkPlan with UnaryExecNode with PythonSQLMetrics { + val groupingAttributes: Seq[Attribute] + val func: Expression + val output: Seq[Attribute] + val child: SparkPlan + + protected val pythonEvalType: Int + + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val largeVarTypes = conf.arrowUseLargeVarTypes + private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) + private val pythonFunction = func.asInstanceOf[PythonUDF].func + private val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction))) + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + protected def groupedData(iter: Iterator[InternalRow], attrs: Seq[Attribute]): + Iterator[Iterator[InternalRow]] = + groupAndProject(iter, groupingAttributes, child.output, attrs) + .map { case (_, x) => x } + + protected def groupedSchema(attrs: Seq[Attribute]): StructType = + DataTypeUtils.fromAttributes(attrs) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val (dedupAttributes, argOffsets) = resolveArgOffsets(child.output, groupingAttributes) + + // Map grouped rows to ArrowPythonRunner results, Only execute if partition is not empty + inputRDD.mapPartitionsInternal { iter => if (iter.isEmpty) iter else { + + val data = groupedData(iter, dedupAttributes) + + val runner = new ArrowPythonRunner( + chainedFunc, + pythonEvalType, + Array(argOffsets), + groupedSchema(dedupAttributes), + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID) + + executePython(data, output, runner) + }} + } +}
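For reference, here is a minimal usage sketch of the Python API this change wires up, DataFrame.groupby().applyInArrow() and DataFrame.groupby().cogroup().applyInArrow(). It assumes the Python-side signature mirrors applyInPandas, i.e. a function plus a DDL result schema; the session, DataFrames, and column names are illustrative only, not part of this patch.

    import pyarrow as pa
    import pyarrow.compute as pc
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

    def center(table: pa.Table) -> pa.Table:
        # Grouped map: one group arrives as a pyarrow.Table, a pyarrow.Table is returned.
        v = table.column("v")
        centered = pc.subtract(v, pc.mean(v))
        return table.set_column(table.schema.get_field_index("v"), "v", centered)

    df.groupBy("id").applyInArrow(center, schema="id long, v double").show()

    other = spark.createDataFrame([(1, "x"), (2, "y")], ("id", "w"))

    def join_groups(left: pa.Table, right: pa.Table) -> pa.Table:
        # Cogrouped map: the matching groups from both sides arrive as pyarrow.Tables.
        return left.join(right, keys="id")

    (df.groupBy("id")
        .cogroup(other.groupBy("id"))
        .applyInArrow(join_groups, schema="id long, v double, w string")
        .show())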
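The worker-side mapper for the grouped case passes the grouping-key columns and the value columns to the UDF separately. Assuming the Python wrapper surfaces this as an optional two-argument form, a tuple of pyarrow scalars plus the group's pyarrow.Table, in the same way applyInPandas offers a (key, pandas.DataFrame) variant, a per-group aggregation could look like the following (reusing df, pa, and pc from the sketch above; this is a sketch of the assumed calling convention, not of the merged API).

    def mean_per_key(key, table: pa.Table) -> pa.Table:
        # key is assumed to be a tuple of pyarrow scalars, e.g. (pa.scalar(1),) for the group where id == 1.
        return pa.table({"id": [key[0].as_py()], "mean_v": [pc.mean(table.column("v")).as_py()]})

    df.groupBy("id").applyInArrow(mean_per_key, schema="id long, mean_v double").show()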