choose sort based groupByKey() automatically
davies committed Aug 19, 2014
1 parent b40bae7 commit 1ea0669
Showing 4 changed files with 233 additions and 190 deletions.
108 changes: 8 additions & 100 deletions python/pyspark/rdd.py
@@ -44,7 +44,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
    get_used_memory, ExternalSorter
    get_used_memory, ExternalSorter, ExternalGroupBy

from py4j.java_collections import ListConverter, MapConverter

@@ -201,71 +201,6 @@ def _replaceRoot(self, value):
        self._sink(1)


class SameKey(object):
    """
    take the first few items which has the same expected key
    This is used by GroupByKey.
    """
    def __init__(self, key, values, it, groupBy):
        self.key = key
        self.values = values
        self.it = it
        self.groupBy = groupBy
        self._index = 0

    def __iter__(self):
        return self

    def next(self):
        if self._index >= len(self.values):
            if self.it is None:
                raise StopIteration

            key, values = self.it.next()
            if key != self.key:
                self.groupBy._next_item = (key, values)
                raise StopIteration
            self.values = values
            self._index = 0

        self._index += 1
        return self.values[self._index - 1]


class GroupByKey(object):
    """
    group a sorted iterator into [(k1, it1), (k2, it2), ...]
    """
    def __init__(self, it):
        self.it = iter(it)
        self._next_item = None
        self.current = None

    def __iter__(self):
        return self

    def next(self):
        if self._next_item is None:
            while True:
                key, values = self.it.next()
                if self.current is None:
                    break
                if key != self.current.key:
                    break
                self.current.values.extend(values)

        else:
            key, values = self._next_item
            self._next_item = None

        if self.current is not None:
            self.current.it = None
        self.current = SameKey(key, values, self.it, self)

        return key, self.current

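For reference, the SameKey and GroupByKey helpers deleted above implemented lazy grouping over a key-sorted stream of (key, values_batch) pairs. A minimal stand-alone sketch of that pattern with itertools (not pyspark code; it assumes the input is already sorted by key, as ExternalSorter guaranteed):

from itertools import chain, groupby
from operator import itemgetter

def group_sorted(stream):
    # Collapse a key-sorted stream of (key, values_batch) pairs into
    # (key, iterator_of_values) pairs. As with SameKey/GroupByKey, each
    # group's values should be consumed before advancing to the next key.
    for key, batches in groupby(stream, key=itemgetter(0)):
        yield key, chain.from_iterable(values for _, values in batches)

stream = [("a", [1, 2]), ("a", [3]), ("b", [4])]   # already sorted by key
print([(k, list(vs)) for k, vs in group_sorted(stream)])
# [('a', [1, 2, 3]), ('b', [4])]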

def _parse_memory(s):
"""
Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
@@ -1561,9 +1496,6 @@ def createZero():
    def _can_spill(self):
        return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"

    def _sort_based(self):
        return self.ctx._conf.get("spark.shuffle.sort", "False").lower() == "true"

    def _memory_limit(self):
        return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))

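The worker memory limit read above is a Java-style size string handled by _parse_memory (its body is truncated in this view). A rough stand-alone equivalent, assuming the usual k/m/g/t suffixes and a result in megabytes, could look like this:

def parse_memory(s):
    # Convert a Java-style memory string such as "512m" or "1g" to megabytes.
    units = {"k": 1.0 / 1024, "m": 1, "g": 1024, "t": 1 << 20}
    suffix = s[-1].lower()
    if suffix not in units:
        raise ValueError("invalid format: " + s)
    return int(float(s[:-1]) * units[suffix])

print(parse_memory("512m"))  # 512
print(parse_memory("1g"))    # 1024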
@@ -1577,14 +1509,6 @@ def groupByKey(self, numPartitions=None):
        sum or average) over each key, using reduceByKey will provide much
        better performance.

        By default, it will use hash based aggregation, it can spill the items
        into disks when the memory can not hold all the items, but it still
        need to hold all the values for single key in memory.

        When spark.shuffle.sort is True, it will switch to sort based approach,
        then it can support single key with large number of values under small
        amount of memory. But it is slower than hash based approach.

        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
        [('a', [1, 1]), ('b', [1])]
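The paragraph removed from the docstring describes the behaviour this commit automates: the sort-based path used to be opt-in through configuration, and it only worked with spilling enabled (see the check removed in the next hunk). A hypothetical snippet showing how it had to be switched on before this change:

from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .set("spark.shuffle.spill", "true")   # spilling must be on ...
        .set("spark.shuffle.sort", "true"))   # ... to opt in to sort-based grouping
sc = SparkContext(conf=conf)

# With the sort-based path, a single key with many values no longer has to
# fit in memory on the reduce side.
rdd = sc.parallelize([("a", i) for i in range(10 ** 6)])
grouped = rdd.groupByKey()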
@@ -1601,42 +1525,26 @@ def mergeCombiners(a, b):
            return a

        spill = self._can_spill()
        sort_based = self._sort_based()
        if sort_based and not spill:
            raise ValueError("can not use sort based group when"
                             " spark.executor.spill is false")
        memory = self._memory_limit()
        serializer = self._jrdd_deserializer
        agg = Aggregator(createCombiner, mergeValue, mergeCombiners)

        def combineLocally(iterator):
        def combine(iterator):
            merger = ExternalMerger(agg, memory * 0.9, serializer) \
                if spill else InMemoryMerger(agg)
            merger.mergeValues(iterator)
            return merger.iteritems()

        # combine them before shuffle could reduce the comparison later
        locally_combined = self.mapPartitions(combineLocally)
        locally_combined = self.mapPartitions(combine)
        shuffled = locally_combined.partitionBy(numPartitions)

        def groupByKey(it):
            if sort_based:
                # Flatten the combined values, so it will not consume huge
                # memory during merging sort.
                ser = FlattedValuesSerializer(
                    BatchedSerializer(PickleSerializer(), 1024), 10)
                sorter = ExternalSorter(memory * 0.9, ser)
                it = sorter.sorted(it, key=operator.itemgetter(0))
                return imap(lambda (k, v): (k, ResultIterable(v)), GroupByKey(it))

            else:
                # this is faster than sort based
                merger = ExternalMerger(agg, memory * 0.9, serializer) \
                    if spill else InMemoryMerger(agg)
                merger.mergeCombiners(it)
                return merger.iteritems()
            merger = ExternalGroupBy(agg, memory, serializer)\
                if spill else InMemoryMerger(agg)
            merger.mergeCombiners(it)
            return merger.iteritems()

        return shuffled.mapPartitions(groupByKey)
        return shuffled.mapPartitions(groupByKey).mapValues(ResultIterable)

    # TODO: add tests
    def flatMapValues(self, f):
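Putting the new groupByKey() together: values are pre-combined per input partition, shuffled by key hash, and merged per reduce partition, with ExternalGroupBy (or InMemoryMerger when spilling is off) making the hash-versus-sort choice automatically, per the commit title. A toy, Spark-free sketch of that dataflow, using plain dicts in place of the mergers:

def create_combiner(x):
    return [x]

def merge_value(xs, x):
    xs.append(x)
    return xs

def merge_combiners(a, b):
    a.extend(b)
    return a

def combine_locally(records):
    # map side: pre-combine the values for each key within one input partition
    combined = {}
    for k, v in records:
        combined[k] = merge_value(combined[k], v) if k in combined else create_combiner(v)
    return list(combined.items())

def partition_by(pairs, num_partitions):
    # shuffle: route each (key, values) pair to a reduce partition by key hash
    partitions = [[] for _ in range(num_partitions)]
    for k, vs in pairs:
        partitions[hash(k) % num_partitions].append((k, vs))
    return partitions

def group_by_key(partition):
    # reduce side: merge the pre-combined lists arriving for each key
    merged = {}
    for k, vs in partition:
        merged[k] = merge_combiners(merged[k], vs) if k in merged else vs
    return list(merged.items())

data = [("a", 1), ("b", 1), ("a", 1)]
parts = partition_by(combine_locally(data), num_partitions=2)
print(sorted(kv for p in parts for kv in group_by_key(p)))
# [('a', [1, 1]), ('b', [1])]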
5 changes: 4 additions & 1 deletion python/pyspark/resultiterable.py
@@ -33,7 +33,10 @@ def __iter__(self):
        return iter(self.it)

    def __len__(self):
        return sum(1 for _ in self.it)
        try:
            return len(self.it)
        except TypeError:
            return sum(1 for _ in self.it)

    def __reduce__(self):
        return (ResultIterable, (list(self.it),))
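The change above lets __len__ use the wrapped iterable's own length when it has one and fall back to counting otherwise, since grouped values may now arrive as an unsized, lazily produced iterable rather than a list. A small generic illustration of that fallback (not the pyspark class itself):

def result_len(it):
    # Prefer the underlying __len__; otherwise count by iterating
    # (which assumes `it` can be re-iterated if it is needed again).
    try:
        return len(it)
    except TypeError:
        return sum(1 for _ in it)

class Unsized(object):
    # Hypothetical stand-in for a lazily produced group of values.
    def __iter__(self):
        return iter((1, 2, 3))

print(result_len([1, 2, 3]))  # 3, via list.__len__
print(result_len(Unsized()))  # 3, by counting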
(2 more changed files not shown)
