fix memory when groupByKey().count()
davies committed Aug 20, 2014
1 parent 905b233 commit acd8e1b
Showing 1 changed file with 68 additions and 17 deletions.
85 changes: 68 additions & 17 deletions python/pyspark/shuffle.py
@@ -247,7 +247,7 @@ def mergeValues(self, iterator):
""" Combine the items by creator and combiner """
# speedup attribute lookup
creator, comb = self.agg.createCombiner, self.agg.mergeValue
c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, 100
c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
limit = self.memory_limit

for k, v in iterator:
@@ -259,36 +259,40 @@ def mergeValues(self, iterator):
if get_used_memory() >= limit:
self._spill()
limit = self._next_limit()
batch /= 2
c = 0
else:
batch = min(batch * 2, self.batch)
c = 0
batch *= 1.5
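
The loop above probes memory usage only once every `batch` items and adapts that interval: after a spill the interval is halved so pressure is noticed sooner, otherwise it grows by 1.5x so the relatively expensive memory check is skipped most of the time. A minimal standalone sketch of the pattern, where `handle_item`, `spill`, and `used_memory_mb` are hypothetical callables standing in for the combiner update, `self._spill()`, and `get_used_memory()`:

```python
def merge_with_adaptive_checks(items, handle_item, spill, used_memory_mb,
                               memory_limit_mb, initial_batch=1000):
    """Process items, probing memory only every `batch` items (sketch, not Spark code)."""
    c, batch = 0, initial_batch
    for item in items:
        handle_item(item)
        c += 1
        if c >= batch:
            if used_memory_mb() >= memory_limit_mb:
                spill()                       # write buffered data to disk
                batch = max(1, batch // 2)    # probe more often after a spill
            else:
                batch = int(batch * 1.5)      # probe less often while under the limit
            c = 0
```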

def _partition(self, key):
""" Return the partition for key """
return hash((key, self._seed)) % self.partitions

def _object_size(self, obj):
return 1

def mergeCombiners(self, iterator, limit=None):
""" Merge (K,V) pair by mergeCombiner """
if limit is None:
limit = self.memory_limit
# speedup attribute lookup
comb, hfun = self.agg.mergeCombiners, self._partition
c, data, pdata, batch = 0, self.data, self.pdata, 1
comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
c, data, pdata, batch = 0, self.data, self.pdata, self.batch
for k, v in iterator:
d = pdata[hfun(k)] if pdata else data
d[k] = comb(d[k], v) if k in d else v
if not limit:
continue

c += 1
c += objsize(v)
if c > batch:
if get_used_memory() > limit:
self._spill()
limit = self._next_limit()
batch /= 4
batch /= 2
c = 0
else:
batch = min(batch * 2, self.batch)
c = 0
batch *= 1.5

def _spill(self):
"""
@@ -476,18 +480,42 @@ class SameKey(object):
>>> l = zip(range(2), range(2))
>>> list(SameKey(0, [1], iter(l), GroupByKey(iter([]))))
[1, 0]
>>> s = SameKey(0, [1], iter(l), GroupByKey(iter([])))
>>> for i in range(2000):
... s.append(i)
>>> len(list(s))
2002
"""
def __init__(self, key, values, it, groupBy):
self.key = key
self.values = values
self.it = it
self.groupBy = groupBy
self._index = 0
self._file = None
self._ser = None
self._index = None

def __iter__(self):
return self

def next(self):
if self._index is None:
# beginning of iteration
if self._file is not None:
if self.values:
self._spill()
self._file.flush()
self._file.seek(0)
self._index = 0

if self._index >= len(self.values) and self._file is not None:
try:
self.values = next(self._ser.load_stream(self._file))
self._index = 0
except StopIteration:
self._file.close()
self._file = None

if self._index < len(self.values):
value = self.values[self._index]
self._index += 1
@@ -503,6 +531,29 @@ def next(self):
raise StopIteration
return value

def append(self, value):
if self._index is not None:
raise ValueError("Can not append value while iterating")

self.values.append(value)
# dump them into disk if the key is huge
if len(self.values) >= 10240:
self._spill()

def _spill(self):
if self._file is None:
dirs = _get_local_dirs("objects")
d = dirs[id(self) % len(dirs)]
if not os.path.exists(d):
os.makedirs(d)
p = os.path.join(d, str(id(self)))
self._file = open(p, "w+")
self._ser = CompressedSerializer(PickleSerializer())

self._ser.dump_stream([self.values], self._file)
self.values = []
gc.collect()
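
`SameKey` buffers the values of one key in memory, spills the buffer to a compressed file once it reaches 10240 items, and on iteration flushes the remainder and streams the spilled batches back. A simplified, self-contained sketch of that append/spill/replay pattern, using `pickle` and a temporary file instead of `CompressedSerializer` and Spark's local dirs (the class and its names are illustrative only):

```python
import pickle
import tempfile


class SpillableValues(object):
    """Buffer values in memory; spill full batches to disk and replay them on iteration."""

    def __init__(self, threshold=10240):
        self.values = []
        self.threshold = threshold
        self._file = None

    def append(self, value):
        self.values.append(value)
        if len(self.values) >= self.threshold:
            self._spill()

    def _spill(self):
        if self._file is None:
            self._file = tempfile.TemporaryFile()   # binary temp file, removed on close
        pickle.dump(self.values, self._file)        # one pickled batch per spill
        self.values = []

    def __iter__(self):
        # Replay spilled batches first, then whatever is still buffered in memory.
        if self._file is not None:
            self._file.seek(0)
            while True:
                try:
                    for v in pickle.load(self._file):
                        yield v
                except EOFError:
                    break
        for v in self.values:
            yield v
```

For example, a `SpillableValues(threshold=3)` fed 10 values spills three times and ends with a single value still in memory, yet iterating it still yields all 10.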


class GroupByKey(object):
"""
@@ -528,16 +579,12 @@ def next(self):
key, value = self.it.next()
if self.current is None or key != self.current.key:
break
self.current.values.append(value)

self.current.append(value)
else:
key, value = self._next_item
self._next_item = None

if self.current is not None:
self.current.it = None
self.current = SameKey(key, [value], self.it, self)

return key, (v for vs in self.current for v in vs)
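
`GroupByKey` walks a key-sorted `(key, value)` stream and wraps each run of equal keys in a `SameKey` iterator; the change here routes buffered values through `SameKey.append` so that an oversized group spills to disk instead of growing an in-memory list. A minimal sketch of grouping a sorted stream, using `itertools.groupby` in place of the `SameKey` machinery:

```python
from itertools import groupby
from operator import itemgetter


def group_sorted_pairs(pairs):
    """Group a key-sorted iterable of (key, value) pairs into (key, values-iterator) pairs.

    As with the streamed groups above, each values-iterator must be consumed
    before advancing to the next key.
    """
    for key, run in groupby(pairs, key=itemgetter(0)):
        yield key, (value for _, value in run)


pairs = [("a", 1), ("a", 2), ("b", 3)]            # already sorted by key
grouped = [(k, list(vs)) for k, vs in group_sorted_pairs(pairs)]
assert grouped == [("a", [1, 2]), ("b", [3])]
```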


@@ -557,6 +604,9 @@ def _flatted_serializer(self):
ser = FlattedValuesSerializer(ser, 20)
return ser

def _object_size(self, obj):
return len(obj)
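
`_object_size` is the weight the adaptive counter charges per merged value. The base merger charges 1 per value, while this grouping merger charges `len(obj)`, because here each value is itself a list of grouped items and a handful of huge lists can exhaust memory long before the item count grows large. A tiny illustration with made-up data:

```python
# Hypothetical merged values in the groupBy path: each value is a list of grouped items.
values = [[1, 2, 3], [4], list(range(100))]

count_by_item = sum(1 for v in values)        # 3   (base Merger: _object_size == 1)
count_by_size = sum(len(v) for v in values)   # 104 (groupBy path: _object_size == len)
assert (count_by_item, count_by_size) == (3, 104)
```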

def _spill(self):
"""
dump already partitioned data into disks.
@@ -615,8 +665,9 @@ def _merged_items(self, index, limit=0):
size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
for j in range(self.spills))
# if the memory can not hold all the partition,
# then use sort based merge
if (size >> 20) > self.memory_limit / 2:
# then use sort based merge. Because of compression,
# the data on disks will be much smaller than needed memory
if (size >> 20) > self.memory_limit / 10:
return self._sorted_items(index)

self.data = {}
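
This last hunk tightens the condition for the fast in-memory merge: previously the hash merge was attempted whenever the on-disk partition was under half the memory limit, now only when it is under a tenth of it, because the spill files are compressed and serialized and can expand roughly tenfold once loaded. A sketch of that decision, with `"sorted_merge"`/`"hash_merge"` as hypothetical labels for the `_sorted_items` and in-memory paths:

```python
import os


def choose_merge_strategy(spill_paths, memory_limit_mb, expansion_factor=10):
    """Pick sort-based merging when a decompressed partition may not fit in memory.

    spill_paths: compressed on-disk spill files for one partition.
    expansion_factor: assumed in-memory blow-up of compressed, pickled data.
    """
    on_disk_mb = sum(os.path.getsize(p) for p in spill_paths) >> 20
    if on_disk_mb * expansion_factor > memory_limit_mb:
        return "sorted_merge"   # stream values in sorted order, bounded memory
    return "hash_merge"         # load the whole partition into a dict: faster, memory-hungry
```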
