From e78c15c41ff0c1cf393bb6e65eb513ebd8357c75 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 8 Apr 2015 20:44:06 -0700
Subject: [PATCH] address comments

---
 python/pyspark/shuffle.py | 62 ++++++++++++++++-----------------------
 python/pyspark/tests.py   |  2 +-
 2 files changed, 26 insertions(+), 38 deletions(-)

diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 7f2defad17842..e2bef5fe0683a 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -533,14 +533,14 @@ class ExternalList(object):
     >>> l.append(10)
     >>> len(l)
     101
-    >>> for i in range(10240):
+    >>> for i in range(20240):
     ...     l.append(i)
     >>> len(l)
-    10341
+    20341
     >>> import pickle
     >>> l2 = pickle.loads(pickle.dumps(l))
     >>> len(l2)
-    10341
+    20341
     >>> list(l2)[100]
     10
     """
@@ -577,9 +577,8 @@ def __iter__(self):
             # read all items from disks first
             with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
                 f.seek(0)
-                for values in self._ser.load_stream(f):
-                    for v in values:
-                        yield v
+                for v in self._ser.load_stream(f):
+                    yield v
 
         for v in self.values:
             yield v
@@ -601,7 +600,7 @@ def _open_file(self):
             os.makedirs(d)
         p = os.path.join(d, str(id))
         self._file = open(p, "w+", 65536)
-        self._ser = CompressedSerializer(PickleSerializer())
+        self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
         os.unlink(p)
 
     def _spill(self):
@@ -612,7 +611,7 @@ def _spill(self):
 
         used_memory = get_used_memory()
         pos = self._file.tell()
-        self._ser.dump_stream([self.values], self._file)
+        self._ser.dump_stream(self.values, self._file)
         self.values = []
         gc.collect()
         DiskBytesSpilled += self._file.tell() - pos
@@ -622,7 +621,17 @@ class ExternalListOfList(ExternalList):
     """
     An external list for list.
+
+    >>> l = ExternalListOfList([[i, i] for i in range(100)])
+    >>> len(l)
+    200
+    >>> l.append(range(10))
+    >>> len(l)
+    210
+    >>> len(list(l))
+    210
     """
+
     def __init__(self, values):
         ExternalList.__init__(self, values)
         self.count = sum(len(i) for i in values)
@@ -632,20 +641,23 @@ def append(self, value):
         # already counted 1 in ExternalList.append
         self.count += len(value) - 1
 
+    def __iter__(self):
+        for values in ExternalList.__iter__(self):
+            for v in values:
+                yield v
+
 
 class GroupByKey(object):
     """
     Group a sorted iterator as [(k1, it1), (k2, it2), ...]
 
     >>> k = [i/3 for i in range(6)]
-    >>> v = [i for i in range(6)]
+    >>> v = [[i] for i in range(6)]
     >>> g = GroupByKey(iter(zip(k, v)))
     >>> [(k, list(it)) for k, it in g]
     [(0, [0, 1, 2]), (1, [3, 4, 5])]
     """
-    external_class = ExternalList
-
     def __init__(self, iterator):
         self.iterator = iterator
         self.next_item = None
@@ -655,7 +667,7 @@ def __iter__(self):
 
     def next(self):
         key, value = self.next_item if self.next_item else next(self.iterator)
-        values = self.external_class([value])
+        values = ExternalListOfList([value])
         try:
             while True:
                 k, v = next(self.iterator)
@@ -668,30 +680,6 @@ def next(self):
         return key, values
 
 
-class GroupListsByKey(GroupByKey):
-    """
-    Group a sorted iterator of list as [(k1, it1), (k2, it2), ...]
-    """
-    external_class = ExternalListOfList
-
-
-class ChainedIterable(object):
-    """
-    Picklable chained iterator, similar to itertools.chain.from_iterable()
-    """
-    def __init__(self, iterators):
-        self.iterators = iterators
-
-    def __len__(self):
-        try:
-            return len(self.iterators)
-        except TypeError:
-            return sum(len(i) for i in self.iterators)
-
-    def __iter__(self):
-        return itertools.chain.from_iterable(self.iterators)
-
-
 class ExternalGroupBy(ExternalMerger):
 
     """
@@ -835,7 +823,7 @@ def load_partition(j):
             sorter = ExternalSorter(self.memory_limit, ser)
             sorted_items = sorter.sorted(itertools.chain(*disk_items),
                                          key=operator.itemgetter(0))
-            return ((k, ChainedIterable(vs)) for k, vs in GroupListsByKey(sorted_items))
+            return ((k, vs) for k, vs in GroupByKey(sorted_items))
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 03fdebaf21291..15cb4685e18a1 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -745,7 +745,7 @@ def test_external_group_by_key(self):
                          filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
         result = filtered.collect()[0][1]
         self.assertEqual(N/3, len(result))
-        self.assertTrue(isinstance(result.data, shuffle.ChainedIterable))
+        self.assertTrue(isinstance(result.data, shuffle.ExternalList))
 
     def test_sort_on_empty_rdd(self):
         self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
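
Reviewer note (illustrative sketch, not part of the patch): the serializer change writes the spilled values as batches of 1024 items instead of one big pickled list, and BatchedSerializer.load_stream yields the individual items back, which is why the inner loop in ExternalList.__iter__ goes away. A minimal round-trip sketch, assuming a PySpark source tree on the Python path (the buffer and item counts here are arbitrary):

    from io import BytesIO
    from pyspark.serializers import BatchedSerializer, CompressedSerializer, PickleSerializer

    # same serializer composition as in _open_file above
    ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
    buf = BytesIO()
    ser.dump_stream(range(3000), buf)                # written as three compressed pickled batches
    buf.seek(0)
    assert len(list(ser.load_stream(buf))) == 3000   # load_stream unbatches back to single items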