From aabd7fa7b79d35be9281fb6750220fb59031197f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 28 Jul 2014 18:06:29 -0700 Subject: [PATCH] fix pickle itemgetter with cloudpickle --- python/pyspark/cloudpickle.py | 5 +++-- python/pyspark/tests.py | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/python/pyspark/cloudpickle.py b/python/pyspark/cloudpickle.py index 4fda2a9b950b8..68062483dedaa 100644 --- a/python/pyspark/cloudpickle.py +++ b/python/pyspark/cloudpickle.py @@ -560,8 +560,9 @@ class ItemGetterType(ctypes.Structure): ] - itemgetter_obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents - return self.save_reduce(operator.itemgetter, (itemgetter_obj.item,)) + obj = ctypes.cast(ctypes.c_void_p(id(obj)), ctypes.POINTER(ItemGetterType)).contents + return self.save_reduce(operator.itemgetter, + obj.item if obj.nitems > 1 else (obj.item,)) if PyObject_HEAD: dispatch[operator.itemgetter] = save_itemgetter diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 63cc5e9ad96fa..621ddcd3ed8e3 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -278,6 +278,12 @@ def combOp(x, y): self.assertEqual(set([2]), sets[3]) self.assertEqual(set([1, 3]), sets[5]) + def test_itemgetter(self): + rdd = self.sc.parallelize([range(10)]) + from operator import itemgetter + self.assertEqual([1], rdd.map(itemgetter(1)).collect()) + self.assertEqual([(2, 3)], rdd.map(itemgetter(2, 3)).collect()) + class TestIO(PySparkTestCase):