Skip to content

Commit

Permalink
[SPARK-40559][PYTHON] Add applyInArrow to groupBy and cogroup
Browse files Browse the repository at this point in the history
Add `applyInArrow` method to PySpark `groupBy` and `groupBy.cogroup` to allow for user functions that work on Arrow. Similar to existing `mapInArrow`.
PySpark allows transforming a `DataFrame` via the Pandas and Arrow APIs:
```
df.mapInArrow(map_arrow, schema="...")
df.mapInPandas(map_pandas, schema="...")
```

For `df.groupBy(...)` and `df.groupBy(...).cogroup(...)`, there is only a Pandas interface, no Arrow interface:
```
df.groupBy("id").applyInPandas(apply_pandas, schema="...")
```

Providing a pure Arrow interface allows user code to use **any** Arrow-based data framework, not only Pandas, e.g. Polars:
```
def apply_polars(df: polars.DataFrame) -> polars.DataFrame:
  return df

def apply_arrow(table: pyarrow.Table) -> pyarrow.Table:
  df = polars.from_arrow(table)
  return apply_polars(df).to_arrow()

df.groupBy("id").applyInArrow(apply_arrow, schema="...")
```
This adds the method `applyInArrow` to PySpark `groupBy` and `groupBy.cogroup`.
Tested with unit tests.

Closes apache#38624 from EnricoMi/branch-pyspark-grouped-apply-in-arrow.

Authored-by: Enrico Minack <github@enrico.minack.dev>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
EnricoMi authored and Kimahriman committed Apr 18, 2024
1 parent 406181f commit 2f8ffc6
Show file tree
Hide file tree
Showing 22 changed files with 1,670 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ private[spark] object PythonEvalType {
val SQL_COGROUPED_MAP_PANDAS_UDF = 206
val SQL_MAP_ARROW_ITER_UDF = 207
val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
val SQL_GROUPED_MAP_ARROW_UDF = 209
val SQL_COGROUPED_MAP_ARROW_UDF = 210

val SQL_TABLE_UDF = 300
val SQL_ARROW_TABLE_UDF = 301
Expand All @@ -72,6 +74,8 @@ private[spark] object PythonEvalType {
case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
case SQL_GROUPED_MAP_ARROW_UDF => "SQL_GROUPED_MAP_ARROW_UDF"
case SQL_COGROUPED_MAP_ARROW_UDF => "SQL_COGROUPED_MAP_ARROW_UDF"
case SQL_TABLE_UDF => "SQL_TABLE_UDF"
case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF"
}
Expand Down
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ def __hash__(self):
"pyspark.sql.pandas.utils",
"pyspark.sql.observation",
# unittests
"pyspark.sql.tests.arrow.test_arrow_cogrouped_map",
"pyspark.sql.tests.arrow.test_arrow_grouped_map",
"pyspark.sql.tests.test_arrow",
"pyspark.sql.tests.test_arrow_python_udf",
"pyspark.sql.tests.test_catalog",
Expand Down
15 changes: 15 additions & 0 deletions python/pyspark/errors/error_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,11 @@
"`<arg_name>` should be one the values from PandasUDFType, got <arg_type>"
]
},
"INVALID_RETURN_TYPE_FOR_ARROW_UDF": {
"message": [
"Grouped and Cogrouped map Arrow UDF should return StructType for <eval_type>, got <return_type>."
]
},
"INVALID_RETURN_TYPE_FOR_PANDAS_UDF": {
"message": [
"Pandas UDF should return StructType for <eval_type>, got <return_type>."
Expand Down Expand Up @@ -648,6 +653,11 @@
"transformation. For more information, see SPARK-5063."
]
},
"RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDF" : {
"message" : [
"Column names of the returned pyarrow.Table do not match specified schema.<missing><extra>"
]
},
"RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF" : {
"message" : [
"Column names of the returned pandas.DataFrame do not match specified schema.<missing><extra>"
Expand All @@ -663,6 +673,11 @@
"The length of output in Scalar iterator pandas UDF should be the same with the input's; however, the length of output was <output_length> and the length of input was <input_length>."
]
},
"RESULT_TYPE_MISMATCH_FOR_ARROW_UDF" : {
"message" : [
"Columns do not match in their data type: <mismatch>."
]
},
"SCHEMA_MISMATCH_FOR_PANDAS_UDF" : {
"message" : [
"Result vector from pandas_udf was not the required length: expected <expected>, got <actual>."
Expand Down
4 changes: 4 additions & 0 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@
PandasCogroupedMapUDFType,
ArrowMapIterUDFType,
PandasGroupedMapUDFWithStateType,
ArrowGroupedMapUDFType,
ArrowCogroupedMapUDFType,
)
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import AtomicType, StructType
Expand Down Expand Up @@ -158,6 +160,8 @@ class PythonEvalType:
SQL_COGROUPED_MAP_PANDAS_UDF: "PandasCogroupedMapUDFType" = 206
SQL_MAP_ARROW_ITER_UDF: "ArrowMapIterUDFType" = 207
SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: "PandasGroupedMapUDFWithStateType" = 208
SQL_GROUPED_MAP_ARROW_UDF: "ArrowGroupedMapUDFType" = 209
SQL_COGROUPED_MAP_ARROW_UDF: "ArrowCogroupedMapUDFType" = 210

SQL_TABLE_UDF: "SQLTableUDFType" = 300
SQL_ARROW_TABLE_UDF: "SQLArrowTableUDFType" = 301
Expand Down
11 changes: 11 additions & 0 deletions python/pyspark/sql/pandas/_typing/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ PandasMapIterUDFType = Literal[205]
PandasCogroupedMapUDFType = Literal[206]
ArrowMapIterUDFType = Literal[207]
PandasGroupedMapUDFWithStateType = Literal[208]
ArrowGroupedMapUDFType = Literal[209]
ArrowCogroupedMapUDFType = Literal[210]

class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...
Expand Down Expand Up @@ -341,4 +343,13 @@ PandasCogroupedMapFunction = Union[
Callable[[Any, DataFrameLike, DataFrameLike], DataFrameLike],
]

ArrowGroupedMapFunction = Union[
Callable[[pyarrow.Table], pyarrow.Table],
Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table], pyarrow.Table],
]
ArrowCogroupedMapFunction = Union[
Callable[[pyarrow.Table, pyarrow.Table], pyarrow.Table],
Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table, pyarrow.Table], pyarrow.Table],
]

GroupedMapPandasUserDefinedFunction = NewType("GroupedMapPandasUserDefinedFunction", FunctionType)
22 changes: 22 additions & 0 deletions python/pyspark/sql/pandas/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
None,
]: # None means it should infer the type from type hints.

Expand Down Expand Up @@ -416,6 +418,8 @@ def _create_pandas_udf(f, returnType, evalType):
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_ARROW_BATCHED_UDF,
]:
# In case of 'SQL_GROUPED_MAP_PANDAS_UDF', deprecation warning is being triggered
Expand Down Expand Up @@ -463,6 +467,15 @@ def _create_pandas_udf(f, returnType, evalType):
},
)

if evalType == PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF and len(argspec.args) not in (1, 2):
raise PySparkValueError(
error_class="INVALID_PANDAS_UDF",
message_parameters={
"detail": "the function in groupby.applyInArrow must take either one argument "
"(data) or two arguments (key, data).",
},
)

if evalType == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF and len(argspec.args) not in (2, 3):
raise PySparkValueError(
error_class="INVALID_PANDAS_UDF",
Expand All @@ -472,6 +485,15 @@ def _create_pandas_udf(f, returnType, evalType):
},
)

if evalType == PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF and len(argspec.args) not in (2, 3):
raise PySparkValueError(
error_class="INVALID_PANDAS_UDF",
message_parameters={
"detail": "the function in cogroup.applyInArrow must take either two arguments "
"(left, right) or three arguments (key, left, right).",
},
)

if is_remote():
from pyspark.sql.connect.udf import _create_udf as _create_connect_udf

Expand Down
Loading

0 comments on commit 2f8ffc6

Please sign in to comment.