Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Apr 8, 2015
1 parent 0dcf320 commit 0b0fde8
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions python/pyspark/shuffle.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,7 @@ class ExternalList(object):

def __init__(self, values):
self.values = values
if values and isinstance(values[0], list):
self.count = sum(len(i) for i in values)
else:
self.count = len(values)
self.count = len(values)
self._file = None
self._ser = None

Expand Down Expand Up @@ -592,7 +589,7 @@ def __len__(self):

def append(self, value):
self.values.append(value)
self.count += len(value) if isinstance(value, list) else 1
self.count += 1
# dump them into disk if the key is huge
if len(self.values) >= self.LIMIT:
self._spill()
Expand Down Expand Up @@ -622,16 +619,33 @@ def _spill(self):
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20


class ExternalListOfList(ExternalList):
"""
An external list for list.
"""
def __init__(self, values):
ExternalList.__init__(self, values)
self.count = sum(len(i) for i in values)

def append(self, value):
ExternalList.append(self, value)
# already counted 1 in ExternalList.append
self.count += len(value) - 1


class GroupByKey(object):
"""
group a sorted iterator into [(k1, it1), (k2, it2), ...]
Group a sorted iterator as [(k1, it1), (k2, it2), ...]
>>> k = [i/3 for i in range(6)]
>>> v = [i for i in range(6)]
>>> g = GroupByKey(iter(zip(k, v)))
>>> [(k, list(it)) for k, it in g]
[(0, [0, 1, 2]), (1, [3, 4, 5])]
"""

external_class = ExternalList

def __init__(self, iterator):
self.iterator = iterator
self.next_item = None
Expand All @@ -641,7 +655,7 @@ def __iter__(self):

def next(self):
key, value = self.next_item if self.next_item else next(self.iterator)
values = ExternalList([value])
values = self.external_class([value])
try:
while True:
k, v = next(self.iterator)
Expand All @@ -654,6 +668,13 @@ def next(self):
return key, values


class GroupListsByKey(GroupByKey):
"""
Group a sorted iterator of list as [(k1, it1), (k2, it2), ...]
"""
external_class = ExternalListOfList


class ChainedIterable(object):
"""
Picklable chained iterator, similar to itertools.chain.from_iterable()
Expand All @@ -664,7 +685,7 @@ def __init__(self, iterators):
def __len__(self):
try:
return len(self.iterators)
except:
except TypeError:
return sum(len(i) for i in self.iterators)

def __iter__(self):
Expand Down Expand Up @@ -814,7 +835,7 @@ def load_partition(j):
sorter = ExternalSorter(self.memory_limit, ser)
sorted_items = sorter.sorted(itertools.chain(*disk_items),
key=operator.itemgetter(0))
return ((k, ChainedIterable(vs)) for k, vs in GroupByKey(sorted_items))
return ((k, ChainedIterable(vs)) for k, vs in GroupListsByKey(sorted_items))


if __name__ == "__main__":
Expand Down

0 comments on commit 0b0fde8

Please sign in to comment.