From efa23df9a82c02169bd526e7f99d30c7f6a95de2 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Sat, 16 Aug 2014 00:11:15 -0700
Subject: [PATCH] refactor, add spark.shuffle.sort=False

---
 python/pyspark/rdd.py | 51 ++++++++++++++++++++++++++++---------------
 1 file changed, 34 insertions(+), 17 deletions(-)

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 12baf6072255c..05a3570a9b8ba 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -652,14 +652,13 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
-        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         serializer = self._jrdd_deserializer
 
         def sortPartition(iterator):
-            if spill:
-                sorted = ExternalSorter(memory * 0.9, serializer).sorted
-            return sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))
+            sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
+            return sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))
 
         if numPartitions == 1:
             if self.getNumPartitions() > 1:
@@ -1505,10 +1504,8 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
             numPartitions = self._defaultReducePartitions()
 
         serializer = self.ctx.serializer
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
-                 == 'true')
-        memory = _parse_memory(self.ctx._conf.get(
-            "spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
         def combineLocally(iterator):
@@ -1562,7 +1559,10 @@ def createZero():
         return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
 
     def _can_spill(self):
-        return (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
+        return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
+
+    def _sort_based(self):
+        return self.ctx._conf.get("spark.shuffle.sort", "False").lower() == "true"
 
     def _memory_limit(self):
         return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
@@ -1577,6 +1577,14 @@ def groupByKey(self, numPartitions=None):
         sum or average) over each key, using reduceByKey will provide much
         better performance.
 
+        By default it uses hash based aggregation, which can spill items to
+        disk when memory cannot hold them all, but it still needs to keep all
+        the values of a single key in memory.
+
+        When spark.shuffle.sort is True, it switches to a sort based approach,
+        which can group a key with a large number of values in a small amount
+        of memory, but is slower than the hash based approach.
+
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
         >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
         [('a', [1, 1]), ('b', [1])]
@@ -1592,9 +1600,13 @@ def mergeCombiners(a, b):
             a.extend(b)
             return a
 
-        serializer = self._jrdd_deserializer
         spill = self._can_spill()
+        sort_based = self._sort_based()
+        if sort_based and not spill:
+            raise ValueError("cannot use sort based group when"
+                             " spark.shuffle.spill is false")
         memory = self._memory_limit()
+        serializer = self._jrdd_deserializer
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
         def combineLocally(iterator):
@@ -1608,16 +1620,21 @@ def combineLocally(iterator):
         shuffled = locally_combined.partitionBy(numPartitions)
 
         def groupByKey(it):
-            if spill:
+            if sort_based:
                 # Flatten the combined values, so it will not consume huge
                 # memory during merging sort.
-                serializer = FlattedValuesSerializer(
-                    BatchedSerializer(PickleSerializer(), 1024), 10)
-                sorted = ExternalSorter(memory * 0.9, serializer).sorted
+                ser = FlattedValuesSerializer(
+                    BatchedSerializer(PickleSerializer(), 1024), 10)
+                sorter = ExternalSorter(memory * 0.9, ser)
+                it = sorter.sorted(it, key=operator.itemgetter(0))
+                return imap(lambda (k, v): (k, ResultIterable(v)), GroupByKey(it))
 
-            it = sorted(it, key=operator.itemgetter(0))
-            for k, v in GroupByKey(it):
-                yield k, ResultIterable(v)
+            else:
+                # this is faster than sort based grouping
+                merger = ExternalMerger(agg, memory * 0.9, serializer) \
+                    if spill else InMemoryMerger(agg)
+                merger.mergeCombiners(it)
+                return merger.iteritems()
 
         return shuffled.mapPartitions(groupByKey)
 
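
Usage sketch (not authoritative, assumes this patch is applied): the new sort based
groupByKey path is selected through the spark.shuffle.sort key read by the added
_sort_based() helper, and spark.shuffle.spill must stay enabled or groupByKey raises
ValueError. SparkConf and SparkContext are the standard PySpark entry points; the
app name below is made up for the example.

    from pyspark import SparkConf, SparkContext

    # Enable the sort based groupByKey path; spilling must stay on, since the
    # patch rejects sort based grouping when spark.shuffle.spill is false.
    conf = (SparkConf()
            .set("spark.shuffle.spill", "true")
            .set("spark.shuffle.sort", "true")
            .set("spark.python.worker.memory", "512m"))  # memory limit read by _memory_limit()
    sc = SparkContext(appName="sort-based-groupByKey", conf=conf)

    rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
    # groupByKey still returns (key, iterable-of-values) pairs.
    print(sorted((k, list(v)) for k, v in rdd.groupByKey().collect()))
    # [('a', [1, 1]), ('b', [1])]
    sc.stop()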