[SPARK-21394][SPARK-21432][PYTHON] Reviving callable object/partial function support in UDF in PySpark

## What changes were proposed in this pull request?

This PR proposes to leave `__name__` (and likewise `__module__`) out of the tuple of attribute names that `functools.wraps` assigns directly from the wrapped function to the wrapper function, and to set them manually afterwards, using `self._name` (`func.__name__`, or `obj.__class__.__name__` when the callable has no `__name__`) for the wrapper's name.
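
As a standalone illustration of the approach (a minimal sketch, not the PySpark source; `make_wrapper` is a hypothetical name), `functools.wraps` accepts an `assigned` tuple, so the fragile attributes can be filtered out of `functools.WRAPPER_ASSIGNMENTS` and assigned by hand:

```python
import functools

def make_wrapper(func):
    # Skip attributes that callable instances or functools.partial
    # objects may not have; copy the rest (e.g. __doc__) as usual.
    assignments = tuple(
        a for a in functools.WRAPPER_ASSIGNMENTS
        if a not in ('__name__', '__module__'))

    @functools.wraps(func, assigned=assignments)
    def wrapper(*args):
        return func(*args)

    # Manually set the skipped attributes, falling back to the class
    # when the callable itself does not carry them.
    wrapper.__name__ = getattr(func, '__name__', func.__class__.__name__)
    wrapper.__module__ = getattr(func, '__module__', func.__class__.__module__)
    return wrapper
```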

After SPARK-19161, we happened to break callable objects as UDFs in Python as below:

```python
from pyspark.sql import functions

class F(object):
    def __call__(self, x):
        return x

foo = F()
udf = functions.udf(foo)
```

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../spark/python/pyspark/sql/functions.py", line 2142, in udf
    return _udf(f=f, returnType=returnType)
  File ".../spark/python/pyspark/sql/functions.py", line 2133, in _udf
    return udf_obj._wrapped()
  File ".../spark/python/pyspark/sql/functions.py", line 2090, in _wrapped
    functools.wraps(self.func)
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/functools.py", line 33, in update_wrapper
    setattr(wrapper, attr, getattr(wrapped, attr))
AttributeError: F instance has no attribute '__name__'
```
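
The failure is not Spark-specific: `__name__` is a descriptor provided by the metaclass `type`, so it is reachable from a class but not from its instances, and `functools.wraps` with its default `assigned` tuple therefore fails on any callable object. A quick check:

```python
class F(object):
    def __call__(self, x):
        return x

print(F.__name__)                # 'F' -- provided by the metaclass, type
print(hasattr(F(), '__name__'))  # False -- instances do not see it, which
                                 # is exactly what functools.wraps trips over
```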

This worked in Spark 2.1:

```python
from pyspark.sql import functions

class F(object):
    def __call__(self, x):
        return x

foo = F()
udf = functions.udf(foo)
spark.range(1).select(udf("id")).show()
```

```
+-----+
|F(id)|
+-----+
|    0|
+-----+
```

**After**

```python
from pyspark.sql import functions

class F(object):
    def __call__(self, x):
        return x

foo = F()
udf = functions.udf(foo)
spark.range(1).select(udf("id")).show()
```

```
+-----+
|F(id)|
+-----+
|    0|
+-----+
```

_In addition, we happened to break partial functions as below_:

```python
from pyspark.sql import functions
from functools import partial

partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
```

```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File ".../spark/python/pyspark/sql/functions.py", line 2154, in udf
    return _udf(f=f, returnType=returnType)
  File ".../spark/python/pyspark/sql/functions.py", line 2145, in _udf
    return udf_obj._wrapped()
  File ".../spark/python/pyspark/sql/functions.py", line 2099, in _wrapped
    functools.wraps(self.func, assigned=assignments)
  File "/System/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/functools.py", line 33, in update_wrapper
    setattr(wrapper, attr, getattr(wrapped, attr))
AttributeError: 'functools.partial' object has no attribute '__module__'
```
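
`functools.partial` objects hit the same problem: they carry the wrapped callable in `.func`, but expose neither `__name__` nor (on Python 2.7, the environment in the traceback above) `__module__`:

```python
from functools import partial

partial_func = partial(lambda x: x, x=1)
print(hasattr(partial_func, '__name__'))    # False
print(hasattr(partial_func, '__module__'))  # False on Python 2.7
print(partial_func.func)                    # the wrapped lambda
```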

This worked in Spark 2.1:

```python
from pyspark.sql import functions
from functools import partial

partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
spark.range(1).select(udf()).show()
```

```
+---------+
|partial()|
+---------+
|        1|
+---------+
```

**After**

```python
from pyspark.sql import functions
from functools import partial

partial_func = partial(lambda x: x, x=1)
udf = functions.udf(partial_func)
spark.range(1).select(udf()).show()
```

```
+---------+
|partial()|
+---------+
|        1|
+---------+
```

## How was this patch tested?

Unit tests in `python/pyspark/sql/tests.py` and manual tests.
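
For reference, the SQL test module can be run on its own with Spark's standard test script (a typical invocation from a built checkout; adjust the Python executable as needed):

```
./python/run-tests --python-executables=python2.7 --modules=pyspark-sql
```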

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #18615 from HyukjinKwon/callable-object.
HyukjinKwon authored and holdenk committed Jul 17, 2017
1 parent e398c28 commit 4ce735e
Showing 2 changed files with 34 additions and 1 deletion.

`python/pyspark/sql/functions.py` (13 additions, 1 deletion):

```diff
@@ -2087,10 +2087,22 @@ def _wrapped(self):
         """
         Wrap this udf with a function and attach docstring from func
         """
-        @functools.wraps(self.func)
+
+        # It is possible for a callable instance without __name__ attribute or/and
+        # __module__ attribute to be wrapped here. For example, functools.partial. In this case,
+        # we should avoid wrapping the attributes from the wrapped function to the wrapper
+        # function. So, we take out these attribute names from the default names to set and
+        # then manually assign it after being wrapped.
+        assignments = tuple(
+            a for a in functools.WRAPPER_ASSIGNMENTS if a != '__name__' and a != '__module__')
+
+        @functools.wraps(self.func, assigned=assignments)
         def wrapper(*args):
             return self(*args)
 
+        wrapper.__name__ = self._name
+        wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__')
+                              else self.func.__class__.__module__)
         wrapper.func = self.func
         wrapper.returnType = self.returnType
 
```
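
With this change in place, the wrapper's metadata can be sanity-checked directly (an illustrative example against a PySpark build containing this fix; the printed names match the column names shown above):

```python
from functools import partial
from pyspark.sql import functions

class F(object):
    """Identity"""
    def __call__(self, x):
        return x

wrapped = functions.udf(F())
print(wrapped.__name__)    # 'F' -- from self._name, hence the F(id) column
print(wrapped.__doc__)     # contains 'Identity'; __doc__ is still copied

wrapped_partial = functions.udf(partial(lambda x: x, x=1))
print(wrapped_partial.__name__)  # 'partial' -- hence the partial() column
```
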
`python/pyspark/sql/tests.py` (21 additions):

```diff
@@ -679,6 +679,27 @@ def f(x):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
+        class F(object):
+            """Identity"""
+            def __call__(self, x):
+                return x
+
+        f = F()
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
+        f = functools.partial(f, x=1)
+        return_type = IntegerType()
+        f_ = udf(f, return_type)
+
+        self.assertTrue(f.__doc__ in f_.__doc__)
+        self.assertEqual(f, f_.func)
+        self.assertEqual(return_type, f_.returnType)
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
```
