
Commit

refactor, add spark.shuffle.sort=False
davies committed Aug 16, 2014
1 parent 250be4e commit efa23df
Showing 1 changed file with 34 additions and 17 deletions.
python/pyspark/rdd.py: 51 changes (34 additions, 17 deletions)
@@ -652,14 +652,13 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
-        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         serializer = self._jrdd_deserializer
 
         def sortPartition(iterator):
-            if spill:
-                sorted = ExternalSorter(memory * 0.9, serializer).sorted
-            return sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))
+            sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
+            return sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))
 
         if numPartitions == 1:
             if self.getNumPartitions() > 1:
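As a usage-level sanity check (a hedged sketch assuming a live SparkContext named sc, not part of the commit), the refactor does not change what sortByKey returns; it only changes whether the per-partition sort may spill through ExternalSorter:

    # Hedged sketch: the result is the same whether spark.shuffle.spill is true
    # (ExternalSorter, may spill to disk) or false (built-in sorted in memory).
    rdd = sc.parallelize([("b", 2), ("c", 3), ("a", 1)], 2)
    print(rdd.sortByKey(ascending=True).collect())
    # [('a', 1), ('b', 2), ('c', 3)]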
@@ -1505,10 +1504,8 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
             numPartitions = self._defaultReducePartitions()
 
         serializer = self.ctx.serializer
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
-                 == 'true')
-        memory = _parse_memory(self.ctx._conf.get(
-            "spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
         def combineLocally(iterator):
@@ -1562,7 +1559,10 @@ def createZero():
         return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
 
     def _can_spill(self):
-        return (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
+        return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
+
+    def _sort_based(self):
+        return self.ctx._conf.get("spark.shuffle.sort", "False").lower() == "true"
 
     def _memory_limit(self):
         return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
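These helpers centralize the configuration reads: spark.shuffle.spill (default true), the new spark.shuffle.sort (default false), and spark.python.worker.memory (default 512m). As an illustration only, _memory_limit is assumed to rely on _parse_memory turning a JVM-style size string into megabytes; the stand-in below sketches that assumption and is not the actual implementation:

    # Illustrative stand-in for _parse_memory (assumption: it returns megabytes).
    def parse_memory_sketch(s):
        units = {'k': 1.0 / 1024, 'm': 1, 'g': 1024}  # assumed unit table
        if not s or s[-1].lower() not in units:
            raise ValueError("invalid memory string: %r" % s)
        return int(float(s[:-1]) * units[s[-1].lower()])

    print(parse_memory_sketch("512m"))  # 512
    print(parse_memory_sketch("2g"))    # 2048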
@@ -1577,6 +1577,14 @@ def groupByKey(self, numPartitions=None):
         sum or average) over each key, using reduceByKey will provide much
         better performance.
 
+        By default, it uses hash based aggregation: items can spill to disk
+        when they do not fit in memory, but all the values for a single key
+        still need to be held in memory.
+
+        When spark.shuffle.sort is True, it switches to a sort based approach,
+        which can handle a single key with a large number of values in a small
+        amount of memory, although it is slower than the hash based approach.
+
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
         >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
         [('a', [1, 1]), ('b', [1])]
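Following the new docstring, the sort based path is opt-in through configuration. A minimal, hedged sketch of how a job could enable it (spark.shuffle.spill must remain true, otherwise the ValueError added later in this diff is raised):

    from pyspark import SparkConf, SparkContext

    # Sketch: switch groupByKey from hash based to sort based aggregation.
    conf = (SparkConf()
            .set("spark.shuffle.spill", "true")    # default; required for the sort based path
            .set("spark.shuffle.sort", "true"))    # opt in to sort based groupByKey
    sc = SparkContext(appName="groupByKeyDemo", conf=conf)

    rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
    print(sorted((k, list(v)) for k, v in rdd.groupByKey().collect()))
    # [('a', [1, 1]), ('b', [1])]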
@@ -1592,9 +1600,13 @@ def mergeCombiners(a, b):
             a.extend(b)
             return a
 
-        serializer = self._jrdd_deserializer
         spill = self._can_spill()
+        sort_based = self._sort_based()
+        if sort_based and not spill:
+            raise ValueError("cannot use sort based groupByKey when"
+                             " spark.shuffle.spill is false")
         memory = self._memory_limit()
+        serializer = self._jrdd_deserializer
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
         def combineLocally(iterator):
@@ -1608,16 +1620,21 @@ def combineLocally(iterator):
         shuffled = locally_combined.partitionBy(numPartitions)
 
         def groupByKey(it):
-            if spill:
+            if sort_based:
                 # Flatten the combined values, so it will not consume huge
                 # memory during merging sort.
-                serializer = FlattedValuesSerializer(
+                ser = FlattedValuesSerializer(
                     BatchedSerializer(PickleSerializer(), 1024), 10)
-                sorted = ExternalSorter(memory * 0.9, serializer).sorted
+                sorter = ExternalSorter(memory * 0.9, ser)
+                it = sorter.sorted(it, key=operator.itemgetter(0))
+                return imap(lambda (k, v): ResultIterable(v), GroupByKey(it))
 
-            it = sorted(it, key=operator.itemgetter(0))
-            for k, v in GroupByKey(it):
-                yield k, ResultIterable(v)
+            else:
+                # this is faster than sort based
+                merger = ExternalMerger(agg, memory * 0.9, serializer) \
+                    if spill else InMemoryMerger(agg)
+                merger.mergeCombiners(it)
+                return merger.iteritems()
 
         return shuffled.mapPartitions(groupByKey)
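The sort based branch leans on one property: once a partition is sorted by key, all values for a key sit next to each other and can be grouped in a single streaming pass. A hedged, purely in-memory analogy of that idea (the actual code streams through ExternalSorter and GroupByKey so the sort can spill to disk, and wraps each group's values in ResultIterable):

    from itertools import groupby
    from operator import itemgetter

    # Sort the (key, value) pairs by key, then group adjacent pairs sharing a key.
    pairs = [("a", 1), ("b", 1), ("a", 1)]
    pairs.sort(key=itemgetter(0))
    grouped = [(k, [v for _, v in kvs]) for k, kvs in groupby(pairs, key=itemgetter(0))]
    print(grouped)  # [('a', [1, 1]), ('b', [1])]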

