group the same key before shuffle, reduce the comparison during sorting
davies committed Aug 16, 2014
1 parent 083d842 commit d05060d
Showing 1 changed file with 36 additions and 17 deletions.
53 changes: 36 additions & 17 deletions python/pyspark/rdd.py
@@ -218,19 +218,19 @@ def __iter__(self):
         return self

     def next(self):
-        if self._index < len(self.values):
-            self._index += 1
-            return self.values[self._index - 1]
-
-        if self.it is None:
-            raise StopIteration
-
-        key, value = self.it.next()
-        if key == self.key:
-            return value
-
-        self.groupBy._next_item = (key, value)
-        raise StopIteration
+        if self._index >= len(self.values):
+            if self.it is None:
+                raise StopIteration
+
+            key, values = self.it.next()
+            if key != self.key:
+                self.groupBy._next_item = (key, values)
+                raise StopIteration
+            self.values = values
+            self._index = 0
+
+        self._index += 1
+        return self.values[self._index - 1]


 class GroupByKey(object):
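The rewritten SameKey.next serves values out of the current batch by index and only touches the shared iterator once per batch: the sorted stream now carries (key, [values]) pairs, so key comparisons happen per batch rather than per value. A minimal Python 3 paraphrase of the new logic follows (not the PySpark source; the GroupByKey parent hookup is assumed):

    class SameKey:
        """Yield the values of one key from an iterator of (key, [values]) batches."""

        def __init__(self, key, values, it, group_by):
            self.key = key            # key this iterator serves
            self.values = values      # current batch of values
            self.it = it              # shared batch iterator, or None once detached
            self.group_by = group_by  # parent; receives the first batch of the next key
            self._index = 0

        def __iter__(self):
            return self

        def __next__(self):
            if self._index >= len(self.values):      # current batch exhausted
                if self.it is None:                  # parent already moved on
                    raise StopIteration
                key, values = next(self.it)          # pull one whole batch
                if key != self.key:                  # the next key has started:
                    self.group_by._next_item = (key, values)  # park it for the parent
                    raise StopIteration
                self.values = values                 # same key: keep serving values
                self._index = 0
            self._index += 1
            return self.values[self._index - 1]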
@@ -248,20 +248,20 @@ def __iter__(self):
     def next(self):
         if self._next_item is None:
             while True:
-                key, value = self.it.next()
+                key, values = self.it.next()
                 if self.current is None:
                     break
                 if key != self.current.key:
                     break
-                self.current.values.append(value)
+                self.current.values.extend(values)

         else:
-            key, value = self._next_item
+            key, values = self._next_item
             self._next_item = None

         if self.current is not None:
             self.current.it = None
-        self.current = SameKey(key, [value], self.it, self)
+        self.current = SameKey(key, values, self.it, self)

         return key, self.current
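For intuition, here is a small self-contained sketch (not PySpark code) of the stream shape these two classes now consume: sorted (key, value) pairs arrive pre-grouped into (key, [values]) batches, and consecutive batches that share a key (e.g. from two merged sorted spill streams) are stitched together with one extend per batch instead of one append per value:

    from itertools import groupby

    def batched(sorted_pairs):
        """Collapse sorted (key, value) pairs into (key, [values]) batches."""
        for key, group in groupby(sorted_pairs, key=lambda kv: kv[0]):
            yield key, [v for _, v in group]

    def stitch(batches):
        """Merge consecutive same-key batches, as GroupByKey does."""
        current_key, current_values = None, None
        for key, values in batches:
            if key == current_key:
                current_values.extend(values)   # one call per batch, not per value
            else:
                if current_key is not None:
                    yield current_key, current_values
                current_key, current_values = key, values
        if current_key is not None:
            yield current_key, current_values

    parts = [[("a", 1), ("a", 2)], [("a", 3), ("b", 4)]]           # two sorted streams
    stream = [batch for part in parts for batch in batched(part)]  # what the sort emits
    print(stream)                 # [('a', [1, 2]), ('a', [3]), ('b', [4])]
    print(list(stitch(stream)))   # [('a', [1, 2, 3]), ('b', [4])]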

@@ -1581,9 +1581,30 @@ def groupByKey(self, numPartitions=None):
         >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
         [('a', [1, 1]), ('b', [1])]
         """
+        def createCombiner(x):
+            return [x]
+
+        def mergeValue(xs, x):
+            xs.append(x)
+            return xs
+
+        def mergeCombiners(a, b):
+            a.extend(b)
+            return a
+
+        serializer = self.ctx.serializer
+        spill = self._can_spill()
+        memory = self._memory_limit()
+        agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
+
+        def combineLocally(iterator):
+            merger = ExternalMerger(agg, memory * 0.9, serializer) \
+                if spill else InMemoryMerger(agg)
+            merger.mergeValues(iterator)
+            return merger.iteritems()
+
+        locally_combined = self.mapPartitions(combineLocally)
+        shuffled = locally_combined.partitionBy(numPartitions)
+
         def groupByKey(it):
             if spill:
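The three functions above follow the usual combineByKey contract: createCombiner starts a list for a key's first value, mergeValue folds another value into an existing list, and mergeCombiners concatenates two partial lists. This is what lets each mapper ship one (key, [values]) record per key instead of one record per value. A hypothetical pure-Python simulation of the map-side combine step (a plain dict stands in for the InMemoryMerger/ExternalMerger from pyspark.shuffle):

    def combine_locally(iterator):
        # Conceptually what merger.mergeValues does: build one list per key
        # within a partition before anything is shuffled.
        combiners = {}
        for key, value in iterator:
            if key not in combiners:
                combiners[key] = [value]        # createCombiner
            else:
                combiners[key].append(value)    # mergeValue
        return iter(combiners.items())

    print(list(combine_locally(iter([("a", 1), ("b", 1), ("a", 1)]))))
    # [('a', [1, 1]), ('b', [1])]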
@@ -1592,8 +1613,6 @@ def groupByKey(it):
             for k, v in GroupByKey(it):
                 yield k, ResultIterable(v)

-        # TODO: combine before shuffle ?
-        shuffled = self.partitionBy(numPartitions)
         return shuffled.mapPartitions(groupByKey)

     # TODO: add tests
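Taken together, the commit resolves the old "combine before shuffle?" TODO: observable behavior is unchanged while the shuffle payload and the number of sort-time key comparisons shrink. A hypothetical usage sketch matching the doctest above (assumes a live SparkContext named sc):

    x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
    pairs = sorted(x.groupByKey().collect())
    print([(k, list(v)) for k, v in pairs])   # [('a', [1, 1]), ('b', [1])]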
