diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 22b1ffc900751..c3cc7426b177b 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -324,6 +324,12 @@ def test_chained_udf(self):
         [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
+    def test_single_udf_with_repeated_argument(self):
+        # regression test for SPARK-20685
+        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+        row = self.spark.sql("SELECT add(1, 1)").first()
+        self.assertEqual(tuple(row), (2, ))
+
     def test_multiple_udfs(self):
         self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
         [row] = self.spark.sql("SELECT double(1), double(2)").collect()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 09182829538fc..5eeaac70b80ed 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -86,22 +86,19 @@ def read_single_udf(pickleSer, infile):
 
 def read_udfs(pickleSer, infile):
     num_udfs = read_int(infile)
-    if num_udfs == 1:
-        # fast path for single UDF
-        _, udf = read_single_udf(pickleSer, infile)
-        mapper = lambda a: udf(*a)
-    else:
-        udfs = {}
-        call_udf = []
-        for i in range(num_udfs):
-            arg_offsets, udf = read_single_udf(pickleSer, infile)
-            udfs['f%d' % i] = udf
-            args = ["a[%d]" % o for o in arg_offsets]
-            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
-        # Create function like this:
-        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
-        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
-        mapper = eval(mapper_str, udfs)
+    udfs = {}
+    call_udf = []
+    for i in range(num_udfs):
+        arg_offsets, udf = read_single_udf(pickleSer, infile)
+        udfs['f%d' % i] = udf
+        args = ["a[%d]" % o for o in arg_offsets]
+        call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+    # Create function like this:
+    #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
+    # In the special case of a single UDF this will return a single result rather
+    # than a tuple of results; this is the format that the JVM side expects.
+    mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+    mapper = eval(mapper_str, udfs)
 
     func = lambda _, it: map(mapper, it)
     ser = BatchedSerializer(PickleSerializer(), 100)
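
For context, here is a minimal standalone sketch of the failure mode this patch removes, assuming the column deduplication described in SPARK-20685: the JVM side deduplicates a UDF's argument columns and sends `arg_offsets` so the worker can map each parameter back to a column, but the deleted fast path discarded the offsets and unpacked the deduplicated row directly. The names `add`, `row`, and `arg_offsets` below are illustrative stand-ins, not Spark internals.

```python
# Sketch of the SPARK-20685 failure mode; not Spark code.
add = lambda x, y: x + y

# For SELECT add(1, 1) the JVM deduplicates the repeated column, so the
# worker receives one value plus offsets mapping both parameters to it.
row = (1,)
arg_offsets = [0, 0]

# Old fast path: ignored arg_offsets and unpacked the row directly,
# so the two-argument UDF was called with a single value.
try:
    add(*row)  # add(1) -> TypeError, the second argument is missing
except TypeError as e:
    print("fast path:", e)

# Patched path: the call is generated from arg_offsets, so the shared
# column is passed twice. For a single UDF the generated string
# "lambda a: (f0(a[0], a[0]))" evaluates to a scalar (a parenthesized
# expression, not a tuple), matching what the JVM side expects.
args = ", ".join("a[%d]" % o for o in arg_offsets)
mapper = eval("lambda a: (f0(%s))" % args, {"f0": add})
print("patched path:", mapper(row))  # 2
```

Dropping the fast path entirely, rather than teaching it about `arg_offsets`, keeps a single code path for any number of UDFs; the comment added in the patch records the one subtlety, that the single-UDF mapper returns a scalar rather than a one-element tuple.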