[SPARK-43440][PYTHON][CONNECT] Support registration of an Arrow-optimized Python UDF

### What changes were proposed in this pull request?
This PR adds support for registering Arrow-optimized Python UDFs in both vanilla PySpark and Spark Connect.

### Why are the changes needed?
Currently, when users register an Arrow-optimized Python UDF, it is registered as a pickled Python UDF and is therefore executed without Arrow optimization.

We should support registering Arrow-optimized Python UDFs and execute them with Arrow optimization, as illustrated in the sketch below.
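
As a hedged illustration of the pre-PR behavior (the plan output here is abridged and illustrative, not copied from a real run; it assumes a running `spark` session): registration used to fall back to pickled execution, which surfaces in `explain()` as a `BatchEvalPython` node rather than `ArrowEvalPython`.

```python
# Sketch of the behavior BEFORE this change.
from pyspark.sql.functions import udf

@udf(useArrow=True)
def f(x):
    return str(x)

spark.range(2).createOrReplaceTempView("df")
spark.udf.register("str_f", f)
spark.sql("select str_f(id) from df").explain()
# == Physical Plan ==
# ... BatchEvalPython ...   <- pickled execution; no ArrowEvalPython node
```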

### Does this PR introduce _any_ user-facing change?
Yes. No API changes, but result differences are expected in some cases.

Previously, a registered Arrow-optimized Python UDF would be executed without Arrow optimization.
Now, it is executed with Arrow optimization, as shown below.

```python
>>> df = spark.range(2)
>>> df.createOrReplaceTempView("df")
>>> from pyspark.sql.functions import udf
>>> @udf(useArrow=True)
... def f(x):
...     return str(x)
...

>>> spark.udf.register('str_f', f)
<pyspark.sql.udf.UserDefinedFunction object at 0x7fa1980c16a0>

>>> spark.sql("select str_f(id) from df").explain()  # Executed with Arrow optimization
== Physical Plan ==
*(2) Project [pythonUDF0#32 AS f(id)#30]
+- ArrowEvalPython [f(id#27L)#29], [pythonUDF0#32], 101
   +- *(1) Range (0, 2, step=1, splits=16)
```

Enabling or disabling Arrow optimization can produce result differences in some cases; we are working on minimizing those differences.
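
One known difference, taken directly from the unit test added in this patch: with Arrow optimization on, an array column arrives in the UDF as a NumPy array, so `str()` renders it without commas. A minimal sketch, assuming a running `spark` session:

```python
from pyspark.sql.functions import udf

df = spark.range(1).selectExpr("array(1, 2, 3) as array")
str_repr = udf(lambda x: str(x), useArrow=True)

# With Arrow optimization, the UDF input is a NumPy array: "[1 2 3]".
# Without it (useArrow=False), the input is a Python list: "[1, 2, 3]".
df.select(str_repr("array")).show()
```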

### How was this patch tested?
Unit test.

Closes #41125 from xinrong-meng/registerArrowPythonUDF.

Authored-by: Xinrong Meng <xinrong@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
xinrong-meng authored and HyukjinKwon committed May 15, 2023
1 parent dd4db21 commit 7cd8f90
Showing 4 changed files with 39 additions and 14 deletions.
15 changes: 8 additions & 7 deletions python/pyspark/sql/connect/udf.py
```diff
@@ -252,31 +252,32 @@ def register(
             f = cast("UserDefinedFunctionLike", f)
             if f.evalType not in [
                 PythonEvalType.SQL_BATCHED_UDF,
+                PythonEvalType.SQL_ARROW_BATCHED_UDF,
                 PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                 PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
             ]:
                 raise PySparkTypeError(
                     error_class="INVALID_UDF_EVAL_TYPE",
                     message_parameters={
-                        "eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
-                        "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF"
+                        "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
+                        "SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or "
+                        "SQL_GROUPED_AGG_PANDAS_UDF"
                     },
                 )
-            return_udf = f
             self.sparkSession._client.register_udf(
                 f.func, f.returnType, name, f.evalType, f.deterministic
             )
+            return f
         else:
             if returnType is None:
                 returnType = StringType()
-            return_udf = _create_udf(
+            py_udf = _create_udf(
                 f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name
             )
 
-            self.sparkSession._client.register_udf(f, returnType, name)
-
-        return return_udf
+            self.sparkSession._client.register_udf(py_udf.func, returnType, name)
+            return py_udf
 
     register.__doc__ = PySparkUDFRegistration.register.__doc__
```

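For context, a minimal usage sketch of the Connect-side path above, mirroring the classic API; the server address is a hypothetical placeholder:

```python
# Sketch: registering an Arrow-optimized UDF through Spark Connect.
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()
str_f = spark.udf.register("str_f", udf(lambda x: str(x), useArrow=True))
spark.sql("select str_f(id) from range(2)").show()
```
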
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
```diff
@@ -219,7 +219,7 @@ def test_register_grouped_map_udf(self):
             exception=pe.exception,
             error_class="INVALID_UDF_EVAL_TYPE",
             message_parameters={
-                "eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
+                "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
                 "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF"
             },
         )
```
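The test above only updates the expected message for `INVALID_UDF_EVAL_TYPE`. A sketch of how that error is triggered (grouped-map pandas UDFs are not registrable; the legacy `PandasUDFType` decorator style is used here for brevity and assumes a running `spark` session):

```python
# Sketch: spark.udf.register rejects eval types outside the allowed list.
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf("id long", PandasUDFType.GROUPED_MAP)
def identity(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf

try:
    spark.udf.register("identity", identity)
except Exception as e:
    print(e)  # INVALID_UDF_EVAL_TYPE: ... SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, ...
```
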
18 changes: 18 additions & 0 deletions python/pyspark/sql/tests/test_arrow_python_udf.py
```diff
@@ -119,6 +119,24 @@ def test_eval_type(self):
             udf(lambda x: str(x), useArrow=False).evalType, PythonEvalType.SQL_BATCHED_UDF
         )
 
+    def test_register(self):
+        df = self.spark.range(1).selectExpr(
+            "array(1, 2, 3) as array",
+        )
+        str_repr_func = self.spark.udf.register("str_repr", udf(lambda x: str(x), useArrow=True))
+
+        # To verify that Arrow optimization is on
+        self.assertEquals(
+            df.selectExpr("str_repr(array) AS str_id").first()[0],
+            "[1 2 3]",  # The input is a NumPy array when the Arrow optimization is on
+        )
+
+        # To verify that a UserDefinedFunction is returned
+        self.assertListEqual(
+            df.selectExpr("str_repr(array) AS str_id").collect(),
+            df.select(str_repr_func("array").alias("str_id")).collect(),
+        )
+
 
 class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase):
     @classmethod
```
18 changes: 12 additions & 6 deletions python/pyspark/sql/udf.py
```diff
@@ -623,32 +623,38 @@ def register(
             f = cast("UserDefinedFunctionLike", f)
             if f.evalType not in [
                 PythonEvalType.SQL_BATCHED_UDF,
+                PythonEvalType.SQL_ARROW_BATCHED_UDF,
                 PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                 PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                 PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
             ]:
                 raise PySparkTypeError(
                     error_class="INVALID_UDF_EVAL_TYPE",
                     message_parameters={
-                        "eval_type": "SQL_BATCHED_UDF, SQL_SCALAR_PANDAS_UDF, "
-                        "SQL_SCALAR_PANDAS_ITER_UDF or SQL_GROUPED_AGG_PANDAS_UDF"
+                        "eval_type": "SQL_BATCHED_UDF, SQL_ARROW_BATCHED_UDF, "
+                        "SQL_SCALAR_PANDAS_UDF, SQL_SCALAR_PANDAS_ITER_UDF or "
+                        "SQL_GROUPED_AGG_PANDAS_UDF"
                     },
                 )
-            register_udf = _create_udf(
+            source_udf = _create_udf(
                 f.func,
                 returnType=f.returnType,
                 name=name,
                 evalType=f.evalType,
                 deterministic=f.deterministic,
-            )._unwrapped  # type: ignore[attr-defined]
-            return_udf = f
+            )
+            if f.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF:
+                register_udf = _create_arrow_py_udf(source_udf)._unwrapped
+            else:
+                register_udf = source_udf._unwrapped  # type: ignore[attr-defined]
+            return_udf = register_udf
         else:
             if returnType is None:
                 returnType = StringType()
             return_udf = _create_udf(
                 f, returnType=returnType, evalType=PythonEvalType.SQL_BATCHED_UDF, name=name
             )
-            register_udf = return_udf._unwrapped  # type: ignore[attr-defined]
+            register_udf = return_udf._unwrapped
         self.sparkSession._jsparkSession.udf().registerPython(name, register_udf._judf)
         return return_udf
```

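A quick way to observe the effect of the udf.py change above: the function returned by `register` now retains the Arrow eval type. A sketch, assuming a running `spark` session; the import path for `PythonEvalType` is an assumption and may vary by Spark version:

```python
# Sketch: the UDF returned by register() keeps SQL_ARROW_BATCHED_UDF.
from pyspark.rdd import PythonEvalType  # assumed import path
from pyspark.sql.functions import udf

f = udf(lambda x: str(x), useArrow=True)
registered = spark.udf.register("str_f", f)
assert registered.evalType == PythonEvalType.SQL_ARROW_BATCHED_UDF
```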
