From 3eb0ee06a588da5b9c08a72d178835c6e8bad36b Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Wed, 10 May 2017 16:50:57 -0700
Subject: [PATCH] [SPARK-20685] Fix BatchPythonEvaluation bug in case of single UDF w/ repeated arg.

## What changes were proposed in this pull request?

There's a latent corner-case bug in PySpark UDF evaluation: executing a
`BatchPythonEvaluation` with a single multi-argument UDF in which _at least
one argument value is repeated_ crashes at execution time with a confusing
error.

This problem was introduced in #12057: the code there has a fast path for
handling a "batch UDF evaluation consisting of a single Python UDF", but that
branch incorrectly assumes that a single UDF won't have repeated arguments and
therefore skips the code for unpacking arguments from the input row (whose
schema may not match the UDF's inputs, because repeated arguments are
de-duplicated in the JVM before UDF inputs are sent to Python).

The fix here is simply to remove this special-casing: it turns out that the
code in the "multiple UDFs" branch just so happens to work for the single-UDF
case because Python treats `(x)` as equivalent to `x`, not as a
single-argument tuple.

## How was this patch tested?

New regression test in the `pyspark.sql.tests` module (tested and confirmed
that it fails before my fix).

Author: Josh Rosen

Closes #17927 from JoshRosen/SPARK-20685.

(cherry picked from commit 8ddbc431d8b21d5ee57d3d209a4f25e301f15283)
Signed-off-by: Xiao Li
---
 python/pyspark/sql/tests.py |  6 ++++++
 python/pyspark/worker.py    | 29 +++++++++++++----------------
 2 files changed, 19 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 2aa2d23c6f0dd..e06f62b35bc6f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -324,6 +324,12 @@ def test_chained_udf(self):
         [row] = self.spark.sql("SELECT double(double(1) + 1)").collect()
         self.assertEqual(row[0], 6)
 
+    def test_single_udf_with_repeated_argument(self):
+        # regression test for SPARK-20685
+        self.spark.catalog.registerFunction("add", lambda x, y: x + y, IntegerType())
+        row = self.spark.sql("SELECT add(1, 1)").first()
+        self.assertEqual(tuple(row), (2, ))
+
     def test_multiple_udfs(self):
         self.spark.catalog.registerFunction("double", lambda x: x * 2, IntegerType())
         [row] = self.spark.sql("SELECT double(1), double(2)").collect()
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 25ee475c7f4d9..baaa3fe074e9a 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -87,22 +87,19 @@ def read_single_udf(pickleSer, infile):
 
 def read_udfs(pickleSer, infile):
     num_udfs = read_int(infile)
-    if num_udfs == 1:
-        # fast path for single UDF
-        _, udf = read_single_udf(pickleSer, infile)
-        mapper = lambda a: udf(*a)
-    else:
-        udfs = {}
-        call_udf = []
-        for i in range(num_udfs):
-            arg_offsets, udf = read_single_udf(pickleSer, infile)
-            udfs['f%d' % i] = udf
-            args = ["a[%d]" % o for o in arg_offsets]
-            call_udf.append("f%d(%s)" % (i, ", ".join(args)))
-        # Create function like this:
-        #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
-        mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
-        mapper = eval(mapper_str, udfs)
+    udfs = {}
+    call_udf = []
+    for i in range(num_udfs):
+        arg_offsets, udf = read_single_udf(pickleSer, infile)
+        udfs['f%d' % i] = udf
+        args = ["a[%d]" % o for o in arg_offsets]
+        call_udf.append("f%d(%s)" % (i, ", ".join(args)))
+    # Create function like this:
+    #   lambda a: (f0(a0), f1(a1, a2), f2(a3))
+    # In the special case of a single UDF this will return a single result rather
+    # than a tuple of results; this is the format that the JVM side expects.
+    mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
+    mapper = eval(mapper_str, udfs)
 
     func = lambda _, it: map(mapper, it)
     ser = BatchedSerializer(PickleSerializer(), 100)
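
To make the single-UDF behavior of the unified branch concrete, the sketch
below replays the mapper construction from `read_udfs` in isolation. The names
`fake_udfs` and `specs` and the hard-coded argument offsets are illustrative
assumptions for this sketch only; the real worker reads the UDFs and their
offsets from its input stream rather than from a list.

```python
# Minimal standalone sketch of the mapper construction in read_udfs, for the
# single-UDF repeated-argument case. `fake_udfs` and `specs` are assumed
# stand-ins for what the worker would deserialize from infile.
fake_udfs = {'f0': lambda x, y: x + y}

# For `SELECT add(1, 1)` the JVM de-duplicates the repeated argument, so the
# input row carries a single column and both argument offsets point at it.
specs = [('f0', [0, 0])]

call_udf = ["%s(%s)" % (name, ", ".join("a[%d]" % o for o in offsets))
            for name, offsets in specs]
mapper_str = "lambda a: (%s)" % (", ".join(call_udf))
# mapper_str == "lambda a: (f0(a[0], a[0]))". With one UDF the outer
# parentheses are plain grouping, not a one-element tuple, so the mapper
# returns a bare value, which is exactly the format the JVM side expects.
mapper = eval(mapper_str, fake_udfs)

assert mapper((1,)) == 2
```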