Skip to content

Commit

Permalink
Reviving callable objects support in UDF in PySpark
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Jul 12, 2017
1 parent 780586a commit 6d3ef48
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
5 changes: 4 additions & 1 deletion python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2087,10 +2087,13 @@ def _wrapped(self):
"""
Wrap this udf with a function and attach docstring from func
"""
@functools.wraps(self.func)
assignments = tuple(a for a in functools.WRAPPER_ASSIGNMENTS if a != "__name__")

@functools.wraps(self.func, assigned=assignments)
def wrapper(*args):
return self(*args)

wrapper.__name__ = self._name
wrapper.func = self.func
wrapper.returnType = self.returnType

Expand Down
13 changes: 13 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,19 @@ def f(x):
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

class F(object):
"""Identity"""
def __call__(self, x):
return x

f = F()
return_type = IntegerType()
f_ = udf(f, return_type)

self.assertTrue(f.__doc__ in f_.__doc__)
self.assertEqual(f, f_.func)
self.assertEqual(return_type, f_.returnType)

def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)
Expand Down

0 comments on commit 6d3ef48

Please sign in to comment.