address comments
Davies Liu committed Apr 9, 2015
1 parent 0b0fde8 commit e78c15c
Showing 2 changed files with 26 additions and 38 deletions.
62 changes: 25 additions & 37 deletions python/pyspark/shuffle.py
@@ -533,14 +533,14 @@ class ExternalList(object):
>>> l.append(10)
>>> len(l)
101
- >>> for i in range(10240):
+ >>> for i in range(20240):
... l.append(i)
>>> len(l)
- 10341
+ 20341
>>> import pickle
>>> l2 = pickle.loads(pickle.dumps(l))
>>> len(l2)
- 10341
+ 20341
>>> list(l2)[100]
10
"""
@@ -577,9 +577,8 @@ def __iter__(self):
# read all items from disks first
with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
f.seek(0)
- for values in self._ser.load_stream(f):
-     for v in values:
-         yield v
+ for v in self._ser.load_stream(f):
+     yield v

for v in self.values:
yield v
@@ -601,7 +600,7 @@ def _open_file(self):
os.makedirs(d)
p = os.path.join(d, str(id))
self._file = open(p, "w+", 65536)
- self._ser = CompressedSerializer(PickleSerializer())
+ self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
os.unlink(p)

def _spill(self):
@@ -612,7 +611,7 @@ def _spill(self):

used_memory = get_used_memory()
pos = self._file.tell()
- self._ser.dump_stream([self.values], self._file)
+ self._ser.dump_stream(self.values, self._file)
self.values = []
gc.collect()
DiskBytesSpilled += self._file.tell() - pos
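Taken together, the serializer change here and in `_open_file` explains the `__iter__` simplification above: `BatchedSerializer.dump_stream` writes the spilled values as compressed, pickled batches of 1024, and its `load_stream` flattens those batches back into individual items, so the nested unpacking loop is no longer needed. A minimal round-trip sketch, assuming the `pyspark.serializers` API used in this codebase:

```python
from io import BytesIO
from pyspark.serializers import (BatchedSerializer, CompressedSerializer,
                                 PickleSerializer)

ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)

buf = BytesIO()
ser.dump_stream(range(3000), buf)   # written as three compressed batches
buf.seek(0)

# load_stream yields individual items, not the batches they were written
# in -- which is why ExternalList.__iter__ can drop its inner loop.
assert list(ser.load_stream(buf)) == list(range(3000))
```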
@@ -622,7 +621,17 @@ def _spill(self):
class ExternalListOfList(ExternalList):
"""
An external list for list.
+ >>> l = ExternalListOfList([[i, i] for i in range(100)])
+ >>> len(l)
+ 200
+ >>> l.append(range(10))
+ >>> len(l)
+ 210
+ >>> len(list(l))
+ 210
"""

def __init__(self, values):
ExternalList.__init__(self, values)
self.count = sum(len(i) for i in values)
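The new doctest pins down the counting contract: `len()` reports the total number of flattened elements, not the number of sub-lists. The `append` hunk below adds `len(value) - 1` because `ExternalList.append` has already counted the value once. A hedged restatement of the invariant:

```python
from pyspark.shuffle import ExternalListOfList

lol = ExternalListOfList([[i, i] for i in range(100)])
assert len(lol) == 200        # 100 sub-lists x 2 elements each

lol.append(range(10))         # ExternalList.append counts 1 ...
assert len(lol) == 210        # ... and append() here adds len(value) - 1
```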
@@ -632,20 +641,23 @@ def append(self, value):
# already counted 1 in ExternalList.append
self.count += len(value) - 1

+ def __iter__(self):
+     for values in ExternalList.__iter__(self):
+         for v in values:
+             yield v


class GroupByKey(object):
"""
Group a sorted iterator as [(k1, it1), (k2, it2), ...]
>>> k = [i/3 for i in range(6)]
- >>> v = [i for i in range(6)]
+ >>> v = [[i] for i in range(6)]
>>> g = GroupByKey(iter(zip(k, v)))
>>> [(k, list(it)) for k, it in g]
[(0, [0, 1, 2]), (1, [3, 4, 5])]
"""

- external_class = ExternalList

def __init__(self, iterator):
self.iterator = iterator
self.next_item = None
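The doctest change above (`[i]` becomes `[[i]]`) follows from the hunk below: `next()` now always wraps values in `ExternalListOfList`, whose `append` calls `len(value)`, so every grouped value must itself be a list. A hedged sketch of the shape requirement:

```python
from pyspark.shuffle import ExternalListOfList

lol = ExternalListOfList([[0]])
lol.append([1, 2])            # fine: count grows by len([1, 2])
# lol.append(1)               # would raise TypeError: an int has no len()
```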
@@ -655,7 +667,7 @@ def __iter__(self):

def next(self):
key, value = self.next_item if self.next_item else next(self.iterator)
- values = self.external_class([value])
+ values = ExternalListOfList([value])
try:
while True:
k, v = next(self.iterator)
@@ -668,30 +680,6 @@ def next(self):
return key, values


- class GroupListsByKey(GroupByKey):
-     """
-     Group a sorted iterator of list as [(k1, it1), (k2, it2), ...]
-     """
-     external_class = ExternalListOfList
-
-
- class ChainedIterable(object):
-     """
-     Picklable chained iterator, similar to itertools.chain.from_iterable()
-     """
-     def __init__(self, iterators):
-         self.iterators = iterators
-
-     def __len__(self):
-         try:
-             return len(self.iterators)
-         except TypeError:
-             return sum(len(i) for i in self.iterators)
-
-     def __iter__(self):
-         return itertools.chain.from_iterable(self.iterators)
-
-
class ExternalGroupBy(ExternalMerger):

"""
@@ -835,7 +823,7 @@ def load_partition(j):
sorter = ExternalSorter(self.memory_limit, ser)
sorted_items = sorter.sorted(itertools.chain(*disk_items),
key=operator.itemgetter(0))
- return ((k, ChainedIterable(vs)) for k, vs in GroupListsByKey(sorted_items))
+ return ((k, vs) for k, vs in GroupByKey(sorted_items))


if __name__ == "__main__":
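With the wrapper gone, each group's values are handed back as the `ExternalListOfList` that `GroupByKey.next` built, so callers get O(1) `len()` from its maintained count and flattened iteration for free. A small hedged usage sketch with toy data standing in for `sorted_items`:

```python
from pyspark.shuffle import GroupByKey

# Toy sorted (key, list-of-values) stream, standing in for sorted_items.
sorted_items = iter([("a", [1]), ("a", [2, 3]), ("b", [4])])
for key, vs in GroupByKey(sorted_items):
    print(key, len(vs), list(vs))   # len() reads the maintained count
# a 3 [1, 2, 3]
# b 1 [4]
```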
2 changes: 1 addition & 1 deletion python/pyspark/tests.py
@@ -745,7 +745,7 @@ def test_external_group_by_key(self):
filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
result = filtered.collect()[0][1]
self.assertEqual(N/3, len(result))
- self.assertTrue(isinstance(result.data, shuffle.ChainedIterable))
+ self.assertTrue(isinstance(result.data, shuffle.ExternalList))

def test_sort_on_empty_rdd(self):
self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
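The assertion is loosened to the base class: `ExternalListOfList` subclasses `ExternalList`, so the `isinstance` check passes whichever concrete list the grouping produced.

```python
from pyspark import shuffle

# Why the relaxed assertion still holds for grouped results:
assert issubclass(shuffle.ExternalListOfList, shuffle.ExternalList)
```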
