From 1ea0669a72ce11e29e78da747f3d3e1d2f28df8a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 18 Aug 2014 22:04:38 -0700 Subject: [PATCH] choose sort based groupByKey() automatically --- python/pyspark/rdd.py | 108 +---------- python/pyspark/resultiterable.py | 5 +- python/pyspark/shuffle.py | 304 ++++++++++++++++++++++--------- python/pyspark/tests.py | 6 +- 4 files changed, 233 insertions(+), 190 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index a843646f6657a..4daf81480528e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -44,7 +44,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.resultiterable import ResultIterable from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \ - get_used_memory, ExternalSorter + get_used_memory, ExternalSorter, ExternalGroupBy from py4j.java_collections import ListConverter, MapConverter @@ -201,71 +201,6 @@ def _replaceRoot(self, value): self._sink(1) -class SameKey(object): - """ - take the first few items which has the same expected key - - This is used by GroupByKey. - """ - def __init__(self, key, values, it, groupBy): - self.key = key - self.values = values - self.it = it - self.groupBy = groupBy - self._index = 0 - - def __iter__(self): - return self - - def next(self): - if self._index >= len(self.values): - if self.it is None: - raise StopIteration - - key, values = self.it.next() - if key != self.key: - self.groupBy._next_item = (key, values) - raise StopIteration - self.values = values - self._index = 0 - - self._index += 1 - return self.values[self._index - 1] - - -class GroupByKey(object): - """ - group a sorted iterator into [(k1, it1), (k2, it2), ...] - """ - def __init__(self, it): - self.it = iter(it) - self._next_item = None - self.current = None - - def __iter__(self): - return self - - def next(self): - if self._next_item is None: - while True: - key, values = self.it.next() - if self.current is None: - break - if key != self.current.key: - break - self.current.values.extend(values) - - else: - key, values = self._next_item - self._next_item = None - - if self.current is not None: - self.current.it = None - self.current = SameKey(key, values, self.it, self) - - return key, self.current - - def _parse_memory(s): """ Parse a memory string in the format supported by Java (e.g. 1g, 200m) and @@ -1561,9 +1496,6 @@ def createZero(): def _can_spill(self): return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true" - def _sort_based(self): - return self.ctx._conf.get("spark.shuffle.sort", "False").lower() == "true" - def _memory_limit(self): return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m")) @@ -1577,14 +1509,6 @@ def groupByKey(self, numPartitions=None): sum or average) over each key, using reduceByKey will provide much better performance. - By default, it will use hash based aggregation, it can spill the items - into disks when the memory can not hold all the items, but it still - need to hold all the values for single key in memory. - - When spark.shuffle.sort is True, it will switch to sort based approach, - then it can support single key with large number of values under small - amount of memory. But it is slower than hash based approach. 
- >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) [('a', [1, 1]), ('b', [1])] @@ -1601,42 +1525,26 @@ def mergeCombiners(a, b): return a spill = self._can_spill() - sort_based = self._sort_based() - if sort_based and not spill: - raise ValueError("can not use sort based group when" - " spark.executor.spill is false") memory = self._memory_limit() serializer = self._jrdd_deserializer agg = Aggregator(createCombiner, mergeValue, mergeCombiners) - def combineLocally(iterator): + def combine(iterator): merger = ExternalMerger(agg, memory * 0.9, serializer) \ if spill else InMemoryMerger(agg) merger.mergeValues(iterator) return merger.iteritems() - # combine them before shuffle could reduce the comparison later - locally_combined = self.mapPartitions(combineLocally) + locally_combined = self.mapPartitions(combine) shuffled = locally_combined.partitionBy(numPartitions) def groupByKey(it): - if sort_based: - # Flatten the combined values, so it will not consume huge - # memory during merging sort. - ser = FlattedValuesSerializer( - BatchedSerializer(PickleSerializer(), 1024), 10) - sorter = ExternalSorter(memory * 0.9, ser) - it = sorter.sorted(it, key=operator.itemgetter(0)) - return imap(lambda (k, v): (k, ResultIterable(v)), GroupByKey(it)) - - else: - # this is faster than sort based - merger = ExternalMerger(agg, memory * 0.9, serializer) \ - if spill else InMemoryMerger(agg) - merger.mergeCombiners(it) - return merger.iteritems() + merger = ExternalGroupBy(agg, memory, serializer)\ + if spill else InMemoryMerger(agg) + merger.mergeCombiners(it) + return merger.iteritems() - return shuffled.mapPartitions(groupByKey) + return shuffled.mapPartitions(groupByKey).mapValues(ResultIterable) # TODO: add tests def flatMapValues(self, f): diff --git a/python/pyspark/resultiterable.py b/python/pyspark/resultiterable.py index fa6e0dcfe4b32..a9436cdad55ba 100644 --- a/python/pyspark/resultiterable.py +++ b/python/pyspark/resultiterable.py @@ -33,7 +33,10 @@ def __iter__(self): return iter(self.it) def __len__(self): - return sum(1 for _ in self.it) + try: + return len(self.it) + except TypeError: + return sum(1 for _ in self.it) def __reduce__(self): return (ResultIterable, (list(self.it),)) diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 905bcf6fea7b4..3cfe17e03ca91 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -25,7 +25,7 @@ import operator import pyspark.heapq3 as heapq -from pyspark.serializers import BatchedSerializer, PickleSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattedValuesSerializer try: import psutil @@ -236,72 +236,50 @@ def _next_limit(self): def mergeValues(self, iterator): """ Combine the items by creator and combiner """ - iterator = iter(iterator) # speedup attribute lookup creator, comb = self.agg.createCombiner, self.agg.mergeValue - d, c, batch = self.data, 0, self.batch + c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, 100 + limit = self.memory_limit for k, v in iterator: + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else creator(v) c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeValues(iterator, self._next_limit()) - break + if c >= batch: + if get_used_memory() >= limit: + self._spill() + limit = self._next_limit() + else: + batch = min(batch * 2, self.batch) + c = 0 def _partition(self, key): """ 
Return the partition for key """ return hash((key, self._seed)) % self.partitions - def _partitioned_mergeValues(self, iterator, limit=0): - """ Partition the items by key, then combine them """ - # speedup attribute lookup - creator, comb = self.agg.createCombiner, self.agg.mergeValue - c, pdata, hfun, batch = 0, self.pdata, self._partition, self.batch - - for k, v in iterator: - d = pdata[hfun(k)] - d[k] = comb(d[k], v) if k in d else creator(v) - if not limit: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() - - def mergeCombiners(self, iterator, check=True): + def mergeCombiners(self, iterator, limit=None): """ Merge (K,V) pair by mergeCombiner """ - iterator = iter(iterator) + if limit is None: + limit = self.memory_limit # speedup attribute lookup - d, comb, batch = self.data, self.agg.mergeCombiners, self.batch - c = 0 + comb, hfun = self.agg.mergeCombiners, self._partition + c, data, pdata, batch = 0, self.data, self.pdata, 1 for k, v in iterator: - d[k] = comb(d[k], v) if k in d else v - if not check: - continue - - c += 1 - if c % batch == 0 and get_used_memory() > self.memory_limit: - self._spill() - self._partitioned_mergeCombiners(iterator, self._next_limit()) - break - - def _partitioned_mergeCombiners(self, iterator, limit=0): - """ Partition the items by key, then merge them """ - comb, pdata = self.agg.mergeCombiners, self.pdata - c, hfun = 0, self._partition - for k, v in iterator: - d = pdata[hfun(k)] + d = pdata[hfun(k)] if pdata else data d[k] = comb(d[k], v) if k in d else v if not limit: continue c += 1 - if c % self.batch == 0 and get_used_memory() > limit: - self._spill() - limit = self._next_limit() + if c > batch: + if get_used_memory() > limit: + self._spill() + limit = self._next_limit() + batch /= 4 + else: + batch = min(batch * 2, self.batch) + c = 0 def _spill(self): """ @@ -333,7 +311,7 @@ def _spill(self): s.close() self.data.clear() - self.pdata = [{} for i in range(self.partitions)] + self.pdata.extend([{} for i in range(self.partitions)]) else: for i in range(self.partitions): @@ -354,9 +332,9 @@ def iteritems(self): def _external_items(self): """ Return all partitioned items as iterator """ - assert not self.data if any(self.pdata): self._spill() + self.pdata = [] hard_limit = self._next_limit() try: @@ -366,8 +344,7 @@ def _external_items(self): path = self._get_spill_dir(j) p = os.path.join(path, str(i)) # do not check memory during merging - self.mergeCombiners(self.serializer.load_stream(open(p)), - False) + self.mergeCombiners(self.serializer.load_stream(open(p)), 0) # limit the total partitions if (self.scale * self.partitions < self.MAX_TOTAL_PARTITIONS @@ -377,12 +354,14 @@ def _external_items(self): gc.collect() # release the memory as much as possible for v in self._recursive_merged_items(i): yield v - return + break + else: + for v in self.data.iteritems(): + yield v + self.data.clear() - for v in self.data.iteritems(): - yield v - self.data.clear() gc.collect() + hard_limit = self._next_limit() # remove the merged partition for j in range(self.spills): @@ -397,44 +376,29 @@ def _cleanup(self): for d in self.localdirs: shutil.rmtree(d, True) - def _recursive_merged_items(self, start): + def _recursive_merged_items(self, index): """ merge the partitioned items and return the as iterator If one partition can not be fit in memory, then them will be partitioned and merged recursively. """ - # make sure all the data are dumps into disks. 
-        assert not self.data
-        if any(self.pdata):
-            self._spill()
-        assert self.spills > 0
-
-        for i in range(start, self.partitions):
-            subdirs = [os.path.join(d, "parts", str(i))
-                       for d in self.localdirs]
-            m = ExternalMerger(self.agg, self.memory_limit, self.serializer,
-                               subdirs, self.scale * self.partitions)
-            m.pdata = [{} for _ in range(self.partitions)]
-            limit = self._next_limit()
-
-            for j in range(self.spills):
-                path = self._get_spill_dir(j)
-                p = os.path.join(path, str(i))
-                m._partitioned_mergeCombiners(
-                    self.serializer.load_stream(open(p)))
-
-                if get_used_memory() > limit:
-                    m._spill()
-                    limit = self._next_limit()
-
-            for v in m._external_items():
-                yield v
+        subdirs = [os.path.join(d, "parts", str(index)) for d in self.localdirs]
+        m = ExternalMerger(self.agg, self.memory_limit, self.serializer, subdirs,
+                           self.scale * self.partitions, self.partitions, self.batch)
+        m.pdata = [{} for _ in range(self.partitions)]
+        limit = self._next_limit()
+
+        for j in range(self.spills):
+            path = self._get_spill_dir(j)
+            p = os.path.join(path, str(index))
+            m.mergeCombiners(self.serializer.load_stream(open(p)), 0)
+
+            if get_used_memory() > limit:
+                m._spill()
+                limit = self._next_limit()
 
-            # remove the merged partition
-            for j in range(self.spills):
-                path = self._get_spill_dir(j)
-                os.remove(os.path.join(path, str(i)))
+        return m._external_items()
 
 
 class ExternalSorter(object):
@@ -495,6 +459,174 @@ def sorted(self, iterator, key=None, reverse=False):
         return heapq.merge(chunks, key=key, reverse=reverse)
 
 
+class SameKey(object):
+    """
+    take the first few items which have the same expected key
+
+    This is used by GroupByKey.
+
+    >>> l = zip(range(2), range(2))
+    >>> list(SameKey(0, [1], iter(l), GroupByKey(iter([]))))
+    [1, 0]
+    """
+    def __init__(self, key, values, it, groupBy):
+        self.key = key
+        self.values = values
+        self.it = it
+        self.groupBy = groupBy
+        self._index = 0
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        if self._index < len(self.values):
+            value = self.values[self._index]
+            self._index += 1
+            return value
+
+        if self.it is None:
+            raise StopIteration
+
+        key, value = self.it.next()
+        if key != self.key:
+            self.groupBy._next_item = (key, value)
+            self.it = None
+            raise StopIteration
+        return value
+
+
+class GroupByKey(object):
+    """
+    group a sorted iterator into [(k1, it1), (k2, it2), ...]
+
+    >>> k = [i/3 for i in range(6)]
+    >>> v = [[i] for i in range(6)]
+    >>> g = GroupByKey(iter(zip(k, v)))
+    >>> [(k, list(it)) for k, it in g]
+    [(0, [0, 1, 2]), (1, [3, 4, 5])]
+    """
+    def __init__(self, it):
+        self.it = iter(it)
+        self._next_item = None
+        self.current = None
+
+    def __iter__(self):
+        return self
+
+    def next(self):
+        if self._next_item is None:
+            while True:
+                key, value = self.it.next()
+                if self.current is None or key != self.current.key:
+                    break
+                self.current.values.append(value)
+
+        else:
+            key, value = self._next_item
+            self._next_item = None
+
+        if self.current is not None:
+            self.current.it = None
+        self.current = SameKey(key, [value], self.it, self)
+
+        return key, (v for vs in self.current for v in vs)
+
+
+class ExternalGroupBy(ExternalMerger):
+
+    """
+    Group the items by key. If any partition of them can not be held
+    in memory, it will switch to sort-based group by.
+    """
+    SORT_KEY_LIMIT = 1000
+
+    def _flatted_serializer(self):
+        ser = self.serializer
+        if not isinstance(ser, (BatchedSerializer, FlattedValuesSerializer)):
+            ser = BatchedSerializer(ser, 1024)
+        if not isinstance(ser, FlattedValuesSerializer):
+            ser = FlattedValuesSerializer(ser, 20)
+        return ser
+
+    def _spill(self):
+        """
+        dump already partitioned data into disks.
+        """
+        path = self._get_spill_dir(self.spills)
+        if not os.path.exists(path):
+            os.makedirs(path)
+
+        if not self.pdata:
+            # The data has not been partitioned yet; iterate over the
+            # dataset once and write the items into different files,
+            # without using additional memory. This is only called the
+            # first time memory usage goes above the limit.
+
+            # open all the files for writing
+            streams = [open(os.path.join(path, str(i)), 'w')
+                       for i in range(self.partitions)]
+
+            # If the number of keys is small, the overhead of sorting is
+            # small, so sort the keys before dumping them to disk.
+            self._sorted = len(self.data) < self.SORT_KEY_LIMIT
+            if self._sorted:
+                ser = self._flatted_serializer()
+                for k in sorted(self.data.keys()):
+                    v = self.data[k]
+                    h = self._partition(k)
+                    ser.dump_stream([(k, v)], streams[h])
+                self.serializer = ser
+            else:
+                for k, v in self.data.iteritems():
+                    h = self._partition(k)
+                    self.serializer.dump_stream([(k, v)], streams[h])
+
+            for s in streams:
+                s.close()
+
+            self.data.clear()
+            self.pdata.extend([{} for i in range(self.partitions)])
+
+        else:
+            for i in range(self.partitions):
+                p = os.path.join(path, str(i))
+                with open(p, "w") as f:
+                    # dump items in batch
+                    if self._sorted:
+                        self.serializer.dump_stream(
+                            sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0)), f)
+                    else:
+                        self.serializer.dump_stream(self.pdata[i].iteritems(), f)
+                self.pdata[i].clear()
+
+        self.spills += 1
+        gc.collect()  # release the memory as much as possible
+
+    def _recursive_merged_items(self, index):
+        """ load a partition from disk, then sort and group by key """
+        def load_partition(j):
+            path = self._get_spill_dir(j)
+            p = os.path.join(path, str(index))
+            return self.serializer.load_stream(open(p, 'r', 65536))
+
+        disk_items = [load_partition(j) for j in range(self.spills)]
+
+        if self._sorted:
+            # all the partitions are already sorted
+            sorted_items = heapq.merge(disk_items, key=operator.itemgetter(0))
+
+        else:
+            # Flatten the combined values, so they will not consume huge
+            # amounts of memory during the merge sort.
+ ser = self._flatted_serializer() + sorter = ExternalSorter(self.memory_limit, ser) + sorted_items = sorter.sorted(itertools.chain(*disk_items), + key=operator.itemgetter(0)) + + return GroupByKey(sorted_items) + + if __name__ == "__main__": import doctest doctest.testmod() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index aaad62ce419da..62dcd486c3010 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -97,7 +97,7 @@ def test_small_dataset(self): sum(xrange(self.N))) def test_medium_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 30) m.mergeValues(self.data) self.assertTrue(m.spills >= 1) self.assertEqual(sum(sum(v) for k, v in m.iteritems()), @@ -110,10 +110,10 @@ def test_medium_dataset(self): sum(xrange(self.N)) * 3) def test_huge_dataset(self): - m = ExternalMerger(self.agg, 10) + m = ExternalMerger(self.agg, 10, partitions=3) m.mergeCombiners(map(lambda (k, v): (k, [str(v)]), self.data * 10)) self.assertTrue(m.spills >= 1) - self.assertEqual(sum(len(v) for k, v in m._recursive_merged_items(0)), + self.assertEqual(sum(len(v) for k, v in m.iteritems()), self.N * 10) m._cleanup()
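
Reviewer note, not part of the patch: the sort-based path in ExternalGroupBy rests on one idea, namely that each spill file holds (key, value) pairs already sorted by key, so the files can be streamed through a k-way merge and grouped on the fly, without ever holding all values of a partition in memory. The sketch below is a minimal, self-contained illustration of that idea using only the standard library; the function name group_sorted and the sample runs are invented for illustration and do not appear in the patch, which instead uses pyspark.heapq3 (whose merge() accepts a key argument, unlike the Python 2 stdlib heapq), ExternalSorter for the unsorted case, and GroupByKey/SameKey to yield values lazily.

    import heapq
    import itertools
    import operator

    def group_sorted(sorted_runs):
        """Merge several runs of (key, value) pairs, each already sorted by key,
        and yield (key, values_iterator) groups without building a dict."""
        # heapq.merge streams the runs in key order; groupby then slices out
        # consecutive pairs sharing the same key.
        merged = heapq.merge(*sorted_runs)
        for key, pairs in itertools.groupby(merged, key=operator.itemgetter(0)):
            yield key, (value for _, value in pairs)

    if __name__ == "__main__":
        # Two "spill files", each sorted by key; integer values keep the tuples
        # comparable so plain heapq.merge works without a key function.
        run_a = [("a", 1), ("b", 2), ("c", 3)]
        run_b = [("a", 4), ("c", 5)]
        grouped = [(k, list(vs)) for k, vs in group_sorted([run_a, run_b])]
        assert grouped == [("a", [1, 4]), ("b", [2]), ("c", [3, 5])]
        print(grouped)

The sketch keeps the same single-pass flavor as the patch: each group's value iterator should be consumed before the next key is requested, since everything is streamed straight out of the merged runs.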