
Commit

[SPARK-46229][PYTHON][CONNECT] Add applyInArrow to groupBy and cogroup in Spark Connect

### What changes were proposed in this pull request?

This PR implements the Spark Connect version of #38624.

### Why are the changes needed?

For feature parity.

### Does this PR introduce _any_ user-facing change?

Yes, it adds a new API to the Python Spark Connect client.
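
For illustration, a minimal usage sketch of the new grouped `applyInArrow` API from the Python Spark Connect client (a hedged example: it assumes an active Spark Connect session `spark`, and the toy data and `mean_per_group` function are illustrative, not part of this PR):

import pyarrow as pa
import pyarrow.compute as pc

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0)], ("id", "v"))

def mean_per_group(table: pa.Table) -> pa.Table:
    # Each group arrives as a single pyarrow.Table; the returned Table must
    # match the schema declared in applyInArrow.
    return pa.Table.from_pydict(
        {
            "id": [table.column("id")[0].as_py()],
            "mean_v": [pc.mean(table.column("v")).as_py()],
        }
    )

df.groupBy("id").applyInArrow(mean_per_group, schema="id long, mean_v double").show()

The function may alternatively accept the grouping key as a tuple of pyarrow scalars in its first argument, mirroring the ArrowGroupedMapFunction alias added in _typing.py below.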

### How was this patch tested?

Reused existing unit tests and doctests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #44146 from HyukjinKwon/connect-arrow-api.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Dec 4, 2023
1 parent b23ae15 commit fc56bb5
Showing 8 changed files with 223 additions and 71 deletions.
@@ -595,12 +595,21 @@ class SparkConnectPlanner(
        val cols =
          rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
            Column(transformExpression(expr)))

        Dataset
        val group = Dataset
          .ofRows(session, transformRelation(rel.getInput))
          .groupBy(cols: _*)
          .flatMapGroupsInPandas(pythonUdf)
          .logicalPlan

        pythonUdf.evalType match {
          case PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF =>
            group.flatMapGroupsInPandas(pythonUdf).logicalPlan

          case PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF =>
            group.flatMapGroupsInArrow(pythonUdf).logicalPlan

          case _ =>
            throw InvalidPlanInput(
              s"Function with EvalType: ${pythonUdf.evalType} is not supported")
        }

      case _ =>
        throw InvalidPlanInput(
@@ -718,7 +727,17 @@
          .ofRows(session, transformRelation(rel.getOther))
          .groupBy(otherCols: _*)

        input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
        pythonUdf.evalType match {
          case PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF =>
            input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan

          case PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF =>
            input.flatMapCoGroupsInArrow(other, pythonUdf).logicalPlan

          case _ =>
            throw InvalidPlanInput(
              s"Function with EvalType: ${pythonUdf.evalType} is not supported")
        }

      case _ =>
        throw InvalidPlanInput(
6 changes: 4 additions & 2 deletions dev/sparktestsupport/modules.py
@@ -491,9 +491,9 @@ def __hash__(self):
"pyspark.sql.pandas.utils",
"pyspark.sql.observation",
# unittests
"pyspark.sql.tests.arrow.test_arrow_cogrouped_map",
"pyspark.sql.tests.arrow.test_arrow_grouped_map",
"pyspark.sql.tests.test_arrow",
"pyspark.sql.tests.test_arrow_cogrouped_map",
"pyspark.sql.tests.test_arrow_grouped_map",
"pyspark.sql.tests.test_arrow_python_udf",
"pyspark.sql.tests.test_catalog",
"pyspark.sql.tests.test_column",
@@ -894,6 +894,8 @@ def __hash__(self):
"pyspark.sql.tests.connect.test_parity_arrow_map",
"pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
"pyspark.sql.tests.connect.test_parity_pandas_cogrouped_map",
"pyspark.sql.tests.connect.test_parity_arrow_grouped_map",
"pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map",
"pyspark.sql.tests.connect.test_utils",
"pyspark.sql.tests.connect.client.test_artifact",
"pyspark.sql.tests.connect.client.test_client",
12 changes: 10 additions & 2 deletions python/pyspark/sql/connect/_typing.py
@@ -14,14 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import sys

if sys.version_info >= (3, 8):
    from typing import Protocol
    from typing import Protocol, Tuple
else:
    from typing_extensions import Protocol

from typing import Tuple
from types import FunctionType
from typing import Any, Callable, Iterable, Union, Optional, NewType
import datetime
@@ -69,6 +69,14 @@
PandasGroupedMapFunctionWithState = Callable[
    [Any, Iterable[DataFrameLike], GroupState], Iterable[DataFrameLike]
]
ArrowGroupedMapFunction = Union[
    Callable[[pyarrow.Table], pyarrow.Table],
    Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table], pyarrow.Table],
]
ArrowCogroupedMapFunction = Union[
    Callable[[pyarrow.Table, pyarrow.Table], pyarrow.Table],
    Callable[[Tuple[pyarrow.Scalar, ...], pyarrow.Table, pyarrow.Table], pyarrow.Table],
]


class UserDefinedFunctionLike(Protocol):
53 changes: 52 additions & 1 deletion python/pyspark/sql/connect/group.py
@@ -14,12 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import warnings

from pyspark.sql.connect.utils import check_dependencies

check_dependencies(__name__)

import warnings
from typing import (
    Any,
    Dict,
@@ -49,6 +49,8 @@
    PandasGroupedMapFunction,
    GroupedMapPandasUserDefinedFunction,
    PandasCogroupedMapFunction,
    ArrowCogroupedMapFunction,
    ArrowGroupedMapFunction,
    PandasGroupedMapFunctionWithState,
)
from pyspark.sql.connect.dataframe import DataFrame
@@ -353,6 +355,30 @@ def applyInPandasWithState(

    applyInPandasWithState.__doc__ = PySparkGroupedData.applyInPandasWithState.__doc__

    def applyInArrow(
        self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
    ) -> "DataFrame":
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame

        udf_obj = UserDefinedFunction(
            func,
            returnType=schema,
            evalType=PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
        )

        return DataFrame(
            plan.GroupMap(
                child=self._df._plan,
                grouping_cols=self._grouping_cols,
                function=udf_obj,
                cols=self._df.columns,
            ),
            session=self._df._session,
        )

    applyInArrow.__doc__ = PySparkGroupedData.applyInArrow.__doc__

    def cogroup(self, other: "GroupedData") -> "PandasCogroupedOps":
        return PandasCogroupedOps(self, other)

@@ -393,6 +419,31 @@ def applyInPandas(

    applyInPandas.__doc__ = PySparkPandasCogroupedOps.applyInPandas.__doc__

    def applyInArrow(
        self, func: "ArrowCogroupedMapFunction", schema: Union[StructType, str]
    ) -> "DataFrame":
        from pyspark.sql.connect.udf import UserDefinedFunction
        from pyspark.sql.connect.dataframe import DataFrame

        udf_obj = UserDefinedFunction(
            func,
            returnType=schema,
            evalType=PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
        )

        return DataFrame(
            plan.CoGroupMap(
                input=self._gd1._df._plan,
                input_grouping_cols=self._gd1._grouping_cols,
                other=self._gd2._df._plan,
                other_grouping_cols=self._gd2._grouping_cols,
                function=udf_obj,
            ),
            session=self._gd1._df._session,
        )

    applyInArrow.__doc__ = PySparkPandasCogroupedOps.applyInArrow.__doc__

    @staticmethod
    def _extract_cols(gd: "GroupedData") -> List[Column]:
        df = gd._df
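For context, here is a minimal sketch of the cogrouped counterpart wired up in group.py above (hedged: it assumes an active Spark Connect session `spark`; the toy DataFrames and the `join_columns` function are illustrative only and assume one row per key on each side):

import pyarrow as pa

df1 = spark.createDataFrame([(1, 1.0), (2, 2.0)], ("id", "v1"))
df2 = spark.createDataFrame([(1, 3.0), (2, 4.0)], ("id", "v2"))

def join_columns(left: pa.Table, right: pa.Table) -> pa.Table:
    # Both sides of one cogroup arrive as pyarrow.Tables; with exactly one
    # row per key on each side we can simply place their columns side by side.
    return pa.table(
        {
            "id": left.column("id"),
            "v1": left.column("v1"),
            "v2": right.column("v2"),
        }
    )

df1.groupby("id").cogroup(df2.groupby("id")).applyInArrow(
    join_columns, schema="id long, v1 double, v2 double"
).show()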
python/pyspark/sql/tests/connect/test_parity_arrow_cogrouped_map.py
@@ -0,0 +1,39 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
import unittest

from pyspark.sql.tests.test_arrow_cogrouped_map import CogroupedMapInArrowTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class CogroupedMapInArrowParityTests(CogroupedMapInArrowTestsMixin, ReusedConnectTestCase):
    def quiet_test(self):
        # No-op
        return contextlib.nullcontext()


if __name__ == "__main__":
    from pyspark.sql.tests.connect.test_parity_arrow_cogrouped_map import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
python/pyspark/sql/tests/connect/test_parity_arrow_grouped_map.py
@@ -14,3 +14,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
import unittest

from pyspark.sql.tests.test_arrow_grouped_map import GroupedMapInArrowTestsMixin
from pyspark.testing.connectutils import ReusedConnectTestCase


class GroupedApplyInArrowParityTests(GroupedMapInArrowTestsMixin, ReusedConnectTestCase):
    def quiet_test(self):
        # No-op
        return contextlib.nullcontext()


if __name__ == "__main__":
    from pyspark.sql.tests.connect.test_parity_arrow_grouped_map import *  # noqa: F401

    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)