From fc56bb53910ebe8050a2b8e3a47aab77fd3d2cc3 Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon
Date: Mon, 4 Dec 2023 16:23:55 +0900
Subject: [PATCH] [SPARK-46229][PYTHON][CONNECT] Add applyInArrow to groupBy and cogroup in Spark Connect

### What changes were proposed in this pull request?

This PR implements the Spark Connect version of
https://github.com/apache/spark/pull/38624.

### Why are the changes needed?

For feature parity with the classic (non-Connect) PySpark API.

### Does this PR introduce _any_ user-facing change?

Yes, it adds a new `applyInArrow` API to the Python Spark Connect client
(minimal usage sketches appear after the diff below).

### How was this patch tested?

Reused the existing unittests and doctests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #44146 from HyukjinKwon/connect-arrow-api.

Authored-by: Hyukjin Kwon
Signed-off-by: Hyukjin Kwon
---
 .../connect/planner/SparkConnectPlanner.scala | 29 ++++++--
 dev/sparktestsupport/modules.py               |  6 +-
 python/pyspark/sql/connect/_typing.py         | 12 +++-
 python/pyspark/sql/connect/group.py           | 53 +++++++++++++-
 .../test_parity_arrow_cogrouped_map.py        | 39 ++++++++++
 .../test_parity_arrow_grouped_map.py}         | 23 ++++++
 .../{arrow => }/test_arrow_cogrouped_map.py   | 71 ++++++++++---------
 .../{arrow => }/test_arrow_grouped_map.py     | 61 ++++++++--------
 8 files changed, 223 insertions(+), 71 deletions(-)
 create mode 100644 python/pyspark/sql/tests/connect/test_parity_arrow_cogrouped_map.py
 rename python/pyspark/sql/tests/{arrow/__init__.py => connect/test_parity_arrow_grouped_map.py} (52%)
 rename python/pyspark/sql/tests/{arrow => }/test_arrow_cogrouped_map.py (92%)
 rename python/pyspark/sql/tests/{arrow => }/test_arrow_grouped_map.py (96%)

diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e2b4a3c782ecf..b64fecafa311a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -595,12 +595,21 @@ class SparkConnectPlanner(
         val cols =
           rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
-
-        Dataset
+        val group = Dataset
           .ofRows(session, transformRelation(rel.getInput))
           .groupBy(cols: _*)
-          .flatMapGroupsInPandas(pythonUdf)
-          .logicalPlan
+
+        pythonUdf.evalType match {
+          case PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF =>
+            group.flatMapGroupsInPandas(pythonUdf).logicalPlan
+
+          case PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF =>
+            group.flatMapGroupsInArrow(pythonUdf).logicalPlan
+
+          case _ =>
+            throw InvalidPlanInput(
+              s"Function with EvalType: ${pythonUdf.evalType} is not supported")
+        }

       case _ =>
         throw InvalidPlanInput(
@@ -718,7 +727,17 @@
           .ofRows(session, transformRelation(rel.getOther))
           .groupBy(otherCols: _*)

-        input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+        pythonUdf.evalType match {
+          case PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF =>
+            input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
+
+          case PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF =>
+            input.flatMapCoGroupsInArrow(other, pythonUdf).logicalPlan
+
+          case _ =>
+            throw InvalidPlanInput(
+              s"Function with EvalType: ${pythonUdf.evalType} is not supported")
+        }

       case _ =>
         throw InvalidPlanInput(
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 718a25097412f..8995b7de0df93 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -491,9 +491,9 @@ def __hash__(self):
     "pyspark.sql.pandas.utils",
     "pyspark.sql.observation",
     # unittests
-    "pyspark.sql.tests.arrow.test_arrow_cogrouped_map",
-    "pyspark.sql.tests.arrow.test_arrow_grouped_map",
     "pyspark.sql.tests.test_arrow",
+    "pyspark.sql.tests.test_arrow_cogrouped_map",
+    "pyspark.sql.tests.test_arrow_grouped_map",
     "pyspark.sql.tests.test_arrow_python_udf",
     "pyspark.sql.tests.test_catalog",
     "pyspark.sql.tests.test_column",
@@ -894,6 +894,8 @@ def __hash__(self):
     "pyspark.sql.tests.connect.test_parity_arrow_map",
     "pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
     "pyspark.sql.tests.connect.test_parity_pandas_cogrouped_map",
+    "pyspark.sql.tests.connect.test_parity_arrow_grouped_map",
+    "pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
     "pyspark.sql.tests.connect.test_utils",
     "pyspark.sql.tests.connect.client.test_artifact",
     "pyspark.sql.tests.connect.client.test_client",
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 471af24f40dc8..392c62bf50d37 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -14,14 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
 import sys

 if sys.version_info >= (3, 8):
-    from typing import Protocol
+    from typing import Protocol, Tuple
 else:
     from typing_extensions import Protocol

+from typing import Tuple
 from types import FunctionType
 from typing import Any, Callable, Iterable, Union, Optional, NewType
 import datetime
@@ -69,6 +69,14 @@ PandasGroupedMapFunctionWithState = Callable[
     [Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike]
 ]
+ArrowGroupedMapFunction = Union[
+    Callable[[pyarrow.Table], pyarrow.Table],
+    Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table], pyarrow.Table],
+]
+ArrowCogroupedMapFunction = Union[
+    Callable[[pyarrow.Table, pyarrow.Table], pyarrow.Table],
+    Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table, pyarrow.Table], pyarrow.Table],
+]


 class UserDefinedFunctionLike(Protocol):
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index bb963c910e2f5..2ccd7463b9e0d 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -14,12 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-import warnings

 from pyspark.sql.connect.utils import check_dependencies

 check_dependencies(__name__)

+import warnings
 from typing import (
     Any,
     Dict,
@@ -49,6 +49,8 @@
     PandasGroupedMapFunction,
     GroupedMapPandasUserDefinedFunction,
     PandasCogroupedMapFunction,
+    ArrowCogroupedMapFunction,
+    ArrowGroupedMapFunction,
     PandasGroupedMapFunctionWithState,
 )
 from pyspark.sql.connect.dataframe import DataFrame
@@ -353,6 +355,30 @@ def applyInPandasWithState(

         applyInPandasWithState.__doc__ = PySparkGroupedData.applyInPandasWithState.__doc__

+    def applyInArrow(
+        self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
+        from pyspark.sql.connect.udf import UserDefinedFunction
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        udf_obj = UserDefinedFunction(
+            func,
+            returnType=schema,
+            evalType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
+        )
+
+        return DataFrame(
+            plan.GroupMap(
+                child=self._df._plan,
+                grouping_cols=self._grouping_cols,
+                function=udf_obj,
+                cols=self._df.columns,
+            ),
+            session=self._df._session,
+        )
+
+    applyInArrow.__doc__ = PySparkGroupedData.applyInArrow.__doc__
+
     def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
         return PandasCogroupedOps(self, other)

@@ -393,6 +419,31 @@ def applyInPandas(

         applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__

+    def applyInArrow(
+        self, func: "ArrowCogroupedMapFunction", schema: Union[StructType, str]
+    ) -> "DataFrame":
+        from pyspark.sql.connect.udf import UserDefinedFunction
+        from pyspark.sql.connect.dataframe import DataFrame
+
+        udf_obj = UserDefinedFunction(
+            func,
+            returnType=schema,
+            evalType=PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
+        )
+
+        return DataFrame(
+            plan.CoGroupMap(
+                input=self._gd1._df._plan,
+                input_grouping_cols=self._gd1._grouping_cols,
+                other=self._gd2._df._plan,
+                other_grouping_cols=self._gd2._grouping_cols,
+                function=udf_obj,
+            ),
+            session=self._gd1._df._session,
+        )
+
+    applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__
+
     @staticmethod
     def _extract_cols(gd: "GroupedData") -> List[Column]:
         df = gd._df
diff --git a/python/pyspark/sql/tests/connect/test_parity_arrow_cogrouped_map.py b/python/pyspark/sql/tests/connect/test_parity_arrow_cogrouped_map.py
new file mode 100644
index 0000000000000..fda2efa821e90
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow_cogrouped_map.py
@@ -0,0 +1,39 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import contextlib
+import unittest
+
+from pyspark.sql.tests.test_arrow_cogrouped_map import CogroupedMapInArrowTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class CogroupedMapInArrowParityTests(CogroupedMapInArrowTestsMixin, ReusedConnectTestCase):
+    def quiet_test(self):
+        # No-op
+        return contextlib.nullcontext()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/arrow/__init__.py b/python/pyspark/sql/tests/connect/test_parity_arrow_grouped_map.py
similarity index 52%
rename from python/pyspark/sql/tests/arrow/__init__.py
rename to python/pyspark/sql/tests/connect/test_parity_arrow_grouped_map.py
index cce3acad34a49..3687faeee7c1a 100644
--- a/python/pyspark/sql/tests/arrow/__init__.py
+++ b/python/pyspark/sql/tests/connect/test_parity_arrow_grouped_map.py
@@ -14,3 +14,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+import contextlib
+import unittest
+
+from pyspark.sql.tests.test_arrow_grouped_map import GroupedMapInArrowTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class GroupedApplyInArrowParityTests(GroupedMapInArrowTestsMixin, ReusedConnectTestCase):
+    def quiet_test(self):
+        # No-op
+        return contextlib.nullcontext()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_arrow_grouped_map import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py b/python/pyspark/sql/tests/test_arrow_cogrouped_map.py
similarity index 92%
rename from python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
rename to python/pyspark/sql/tests/test_arrow_cogrouped_map.py
index 0206d4c2c6ded..406ccfc30d87e 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
+++ b/python/pyspark/sql/tests/test_arrow_cogrouped_map.py
@@ -38,7 +38,7 @@
     not have_pyarrow,
     pyarrow_requirement_message,  # type: ignore[arg-type]
 )
-class CogroupedMapInArrowTests(ReusedSQLTestCase):
+class CogroupedMapInArrowTestsMixin:
     @property
     def left(self):
         return self.spark.range(0, 10, 2, 3).withColumn("v", col("id") * 10)
@@ -53,27 +53,6 @@ def cogrouped(self):
         grouped_right_df = self.right.groupBy((col("id") / 4).cast("int"))
         return grouped_left_df.cogroup(grouped_right_df)

-    @classmethod
-    def setUpClass(cls):
-        ReusedSQLTestCase.setUpClass()
-
-        # Synchronize default timezone between Python and Java
-        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
-        tz = "America/Los_Angeles"
-        os.environ["TZ"] = tz
-        time.tzset()
-
-        cls.sc.environment["TZ"] = tz
-        cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
-    @classmethod
-    def tearDownClass(cls):
-        del os.environ["TZ"]
-        if cls.tz_prev is not None:
-            os.environ["TZ"] = cls.tz_prev
-        time.tzset()
-        ReusedSQLTestCase.tearDownClass()
-
     @staticmethod
     def apply_in_arrow_func(left, right):
         assert isinstance(left, pa.Table)
@@ -101,14 +80,14 @@ def func(key, left, right):
             for table in [left, right]
             for k in table.column(key_column)
         )
-        return CogroupedMapInArrowTests.apply_in_arrow_func(left, right)
+        return CogroupedMapInArrowTestsMixin.apply_in_arrow_func(left, right)

         return func

     @staticmethod
     def apply_in_pandas_with_key_func(key_column):
         def func(key, left, right):
-            return CogroupedMapInArrowTests.apply_in_arrow_with_key_func(key_column)(
+            return CogroupedMapInArrowTestsMixin.apply_in_arrow_with_key_func(key_column)(
                 tuple(pa.scalar(k) for k in key),
                 pa.Table.from_pandas(left),
                 pa.Table.from_pandas(right),
@@ -121,18 +100,18 @@ def do_test_apply_in_arrow(self, cogrouped_df, key_column="id"):

         # compare with result of applyInPandas
         expected = cogrouped_df.applyInPandas(
-            CogroupedMapInArrowTests.apply_in_pandas_with_key_func(key_column), schema
+            CogroupedMapInArrowTestsMixin.apply_in_pandas_with_key_func(key_column), schema
         )

         # apply in arrow without key
         actual = cogrouped_df.applyInArrow(
-            CogroupedMapInArrowTests.apply_in_arrow_func, schema
+            CogroupedMapInArrowTestsMixin.apply_in_arrow_func, schema
         ).collect()
         self.assertEqual(actual, expected.collect())

         # apply in arrow with key
         actual2 = cogrouped_df.applyInArrow(
-            CogroupedMapInArrowTests.apply_in_arrow_with_key_func(key_column), schema
+            CogroupedMapInArrowTestsMixin.apply_in_arrow_with_key_func(key_column), schema
         ).collect()
         self.assertEqual(actual2, expected.collect())

@@ -149,7 +128,7 @@ def test_apply_in_arrow_not_returning_arrow_table(self):
         def func(key, left, right):
             return key

-        with QuietTest(self.sc):
+        with self.quiet_test():
             with self.assertRaisesRegex(
                 PythonException,
                 "Return type of the user-defined function should be pyarrow.Table, but is tuple",
@@ -168,7 +147,7 @@ def test_apply_in_arrow_returning_wrong_types(self):
             ("id long, v string", "column 'v' \\(expected string, actual int64\\)"),
         ]:
             with self.subTest(schema=schema):
-                with QuietTest(self.sc):
+                with self.quiet_test():
                     with self.assertRaisesRegex(
                         PythonException,
                         f"Columns do not match in their data type: {expected}",
@@ -192,7 +171,7 @@ def test_apply_in_arrow_returning_wrong_types_positional_assignment(self):
                 with self.sql_conf(
                     {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
                 ):
-                    with QuietTest(self.sc):
+                    with self.quiet_test():
                         with self.assertRaisesRegex(
                             PythonException,
                             f"Columns do not match in their data type: {expected}",
@@ -212,7 +191,7 @@ def stats(key, left, right):
                 }
             )

-        with QuietTest(self.sc):
+        with self.quiet_test():
             with self.assertRaisesRegex(
                 PythonException,
                 "Column names of the returned pyarrow.Table do not match specified schema. "
@@ -248,7 +227,7 @@ def odd_means(key, left, _):
                 {"id": [key[0].as_py()], "m": [pc.mean(left.column("v")).as_py()]}
             )

-        with QuietTest(self.sc):
+        with self.quiet_test():
             with self.assertRaisesRegex(
                 PythonException,
                 "Column names of the returned pyarrow.Table do not match specified schema. "
@@ -288,8 +267,34 @@ def foo(left, right):
         self.assertEqual(r.b, 1)


+class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase):
+    @classmethod
+    def setUpClass(cls):
+        ReusedSQLTestCase.setUpClass()
+
+        # Synchronize default timezone between Python and Java
+        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
+        tz = "America/Los_Angeles"
+        os.environ["TZ"] = tz
+        time.tzset()
+
+        cls.sc.environment["TZ"] = tz
+        cls.spark.conf.set("spark.sql.session.timeZone", tz)
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ["TZ"]
+        if cls.tz_prev is not None:
+            os.environ["TZ"] = cls.tz_prev
+        time.tzset()
+        ReusedSQLTestCase.tearDownClass()
+
+    def quiet_test(self):
+        return QuietTest(self.sc)
+
+
 if __name__ == "__main__":
-    from pyspark.sql.tests.arrow.test_arrow_cogrouped_map import *  # noqa: F401
+    from pyspark.sql.tests.test_arrow_cogrouped_map import *  # noqa: F401

     try:
         import xmlrunner  # type: ignore[import]
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/test_arrow_grouped_map.py
similarity index 96%
rename from python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
rename to python/pyspark/sql/tests/test_arrow_grouped_map.py
index fa43648d42dcc..d46154c104c07 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
+++ b/python/pyspark/sql/tests/test_arrow_grouped_map.py
@@ -39,7 +39,7 @@
     not have_pyarrow,
     pyarrow_requirement_message,  # type: ignore[arg-type]
 )
-class GroupedMapInArrowTests(ReusedSQLTestCase):
+class GroupedMapInArrowTestsMixin:
     @property
     def data(self):
         return (
@@ -50,27 +50,6 @@ def data(self):
             .drop("vs")
         )

-    @classmethod
-    def setUpClass(cls):
-        ReusedSQLTestCase.setUpClass()
-
-        # Synchronize default timezone between Python and Java
-        cls.tz_prev = os.environ.get("TZ", None)  # save current tz if set
-        tz = "America/Los_Angeles"
-        os.environ["TZ"] = tz
-        time.tzset()
-
-        cls.sc.environment["TZ"] = tz
-        cls.spark.conf.set("spark.sql.session.timeZone", tz)
-
-    @classmethod
-    def tearDownClass(cls):
-        del os.environ["TZ"]
-        if cls.tz_prev is not None:
-            os.environ["TZ"] = cls.tz_prev
-        time.tzset()
-        ReusedSQLTestCase.tearDownClass()
-
     def test_apply_in_arrow(self):
         def func(group):
             assert isinstance(group, pa.Table)
@@ -132,7 +111,7 @@ def test_apply_in_arrow_not_returning_arrow_table(self):
         def stats(key, _):
             return key

-        with QuietTest(self.sc):
+        with self.quiet_test():
             with self.assertRaisesRegex(
                 PythonException,
                 "Return type of the user-defined function should be pyarrow.Table, but is tuple",
@@ -153,7 +132,7 @@ def test_apply_in_arrow_returning_wrong_types(self):
             ("id long, v string", "column 'v' \\(expected string, actual int32\\)"),
         ]:
             with self.subTest(schema=schema):
-                with QuietTest(self.sc):
+                with self.quiet_test():
                     with self.assertRaisesRegex(
                         PythonException,
                         f"Columns do not match in their data type: {expected}",
@@ -177,7 +156,7 @@ def test_apply_in_arrow_returning_wrong_types_positional_assignment(self):
                 with self.sql_conf(
                     {"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}
                 ):
-                    with QuietTest(self.sc):
+                    with self.quiet_test():
                         with self.assertRaisesRegex(
                             PythonException,
                             f"Columns do not match in their data type: {expected}",
@@ -199,7 +178,7 @@ def stats(key, table):
                 }
             )

-        with QuietTest(self.sc):
+        with self.quiet_test():
             with self.assertRaisesRegex(
                 PythonException,
                 "Column names of the returned pyarrow.Table do not match specified schema. "
" @@ -235,7 +214,7 @@ def odd_means(key, table): {"id": [key[0].as_py()], "m": [pc.mean(table.column("v")).as_py()]} ) - with QuietTest(self.sc): + with self.quiet_test(): with self.assertRaisesRegex( PythonException, "Column names of the returned pyarrow.Table do not match specified schema. " @@ -279,8 +258,34 @@ def foo(_): self.assertEqual(r.b, 1) +class GroupedMapInArrowTests(GroupedMapInArrowTestsMixin, ReusedSQLTestCase): + @classmethod + def setUpClass(cls): + ReusedSQLTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + + cls.sc.environment["TZ"] = tz + cls.spark.conf.set("spark.sql.session.timeZone", tz) + + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev + time.tzset() + ReusedSQLTestCase.tearDownClass() + + def quiet_test(self): + return QuietTest(self.sc) + + if __name__ == "__main__": - from pyspark.sql.tests.arrow.test_arrow_grouped_map import * # noqa: F401 + from pyspark.sql.tests.test_arrow_grouped_map import * # noqa: F401 try: import xmlrunner # type: ignore[import]