diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index 43b91459b56ea..a28f6efde1c4f 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -247,7 +247,7 @@ def mergeValues(self, iterator):
         """ Combine the items by creator and combiner """
         # speedup attribute lookup
         creator, comb = self.agg.createCombiner, self.agg.mergeValue
-        c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, 100
+        c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
         limit = self.memory_limit
 
         for k, v in iterator:
@@ -259,36 +259,40 @@ def mergeValues(self, iterator):
                 if get_used_memory() >= limit:
                     self._spill()
                     limit = self._next_limit()
+                    batch /= 2
+                    c = 0
                 else:
-                    batch = min(batch * 2, self.batch)
-                    c = 0
+                    batch *= 1.5
 
     def _partition(self, key):
         """ Return the partition for key """
         return hash((key, self._seed)) % self.partitions
 
+    def _object_size(self, obj):
+        return 1
+
     def mergeCombiners(self, iterator, limit=None):
         """ Merge (K,V) pair by mergeCombiner """
         if limit is None:
             limit = self.memory_limit
         # speedup attribute lookup
-        comb, hfun = self.agg.mergeCombiners, self._partition
-        c, data, pdata, batch = 0, self.data, self.pdata, 1
+        comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
+        c, data, pdata, batch = 0, self.data, self.pdata, self.batch
         for k, v in iterator:
             d = pdata[hfun(k)] if pdata else data
             d[k] = comb(d[k], v) if k in d else v
             if not limit:
                 continue
 
-            c += 1
+            c += objsize(v)
             if c > batch:
                 if get_used_memory() > limit:
                     self._spill()
                     limit = self._next_limit()
-                    batch /= 4
+                    batch /= 2
+                    c = 0
                 else:
-                    batch = min(batch * 2, self.batch)
-                    c = 0
+                    batch *= 1.5
 
     def _spill(self):
         """
@@ -476,18 +480,42 @@ class SameKey(object):
     >>> l = zip(range(2), range(2))
     >>> list(SameKey(0, [1], iter(l), GroupByKey(iter([]))))
     [1, 0]
+    >>> s = SameKey(0, [1], iter(l), GroupByKey(iter([])))
+    >>> for i in range(2000):
+    ...     s.append(i)
+    >>> len(list(s))
+    2002
     """
     def __init__(self, key, values, it, groupBy):
         self.key = key
         self.values = values
         self.it = it
         self.groupBy = groupBy
-        self._index = 0
+        self._file = None
+        self._ser = None
+        self._index = None
 
     def __iter__(self):
         return self
 
     def next(self):
+        if self._index is None:
+            # beginning of iteration
+            if self._file is not None:
+                if self.values:
+                    self._spill()
+                self._file.flush()
+                self._file.seek(0)
+            self._index = 0
+
+        if self._index >= len(self.values) and self._file is not None:
+            try:
+                self.values = next(self._ser.load_stream(self._file))
+                self._index = 0
+            except StopIteration:
+                self._file.close()
+                self._file = None
+
         if self._index < len(self.values):
             value = self.values[self._index]
             self._index += 1
@@ -503,6 +531,29 @@ def next(self):
             raise StopIteration
         return value
 
+    def append(self, value):
+        if self._index is not None:
+            raise ValueError("Cannot append value while iterating")
+
+        self.values.append(value)
+        # dump the values to disk if this key is huge
+        if len(self.values) >= 10240:
+            self._spill()
+
+    def _spill(self):
+        if self._file is None:
+            dirs = _get_local_dirs("objects")
+            d = dirs[id(self) % len(dirs)]
+            if not os.path.exists(d):
+                os.makedirs(d)
+            p = os.path.join(d, str(id(self)))
+            self._file = open(p, "w+b")
+            self._ser = CompressedSerializer(PickleSerializer())
+
+        self._ser.dump_stream([self.values], self._file)
+        self.values = []
+        gc.collect()
+
 
 class GroupByKey(object):
     """
@@ -528,16 +579,12 @@ def next(self):
         if self._next_item is None:
             while True:
                 key, value = self.it.next()
                 if self.current is None or key != self.current.key:
                     break
-                self.current.values.append(value)
-
+                self.current.append(value)
         else:
             key, value = self._next_item
             self._next_item = None
 
-        if self.current is not None:
-            self.current.it = None
         self.current = SameKey(key, [value], self.it, self)
-
         return key, (v for vs in self.current for v in vs)
@@ -557,6 +604,9 @@ def _flatted_serializer(self):
             ser = FlattedValuesSerializer(ser, 20)
         return ser
 
+    def _object_size(self, obj):
+        return len(obj)
+
     def _spill(self):
         """
         dump already partitioned data into disks.
@@ -615,8 +665,9 @@ def _merged_items(self, index, limit=0):
         size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
                    for j in range(self.spills))
         # if the memory can not hold all the partition,
-        # then use sort based merge
-        if (size >> 20) > self.memory_limit / 2:
+        # then use sort-based merge; because of compression,
+        # the data on disk is much smaller than the memory needed
+        if (size >> 20) > self.memory_limit / 10:
            return self._sorted_items(index)
 
         self.data = {}
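
A note on the batching change in mergeValues() and mergeCombiners(): memory usage is only probed once every `batch` items, and the interval now adapts in both directions. After a spill the interval is halved and the counter restarts; while memory stays under the limit the interval grows by 1.5x (and `c` is not reset, so successive probes are spaced geometrically). A minimal standalone sketch of the policy; `merge`, `get_used_memory`, `spill`, and `next_limit` are stand-ins for the real helpers, not part of the patch:

def merge_with_adaptive_checks(items, merge, get_used_memory,
                               spill, next_limit, limit, batch=1000):
    # Merge every item, probing memory usage only once per `batch` items.
    c = 0
    for item in items:
        merge(item)
        c += 1
        if c > batch:
            if get_used_memory() >= limit:
                spill()
                limit = next_limit()
                batch /= 2    # spilled: probe memory more often
                c = 0
            else:
                batch *= 1.5  # under the limit: back off the probing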
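
The reworked SameKey behaves like a disk-backed list of values for one key: append() buffers values in memory and spills every 10240 of them to disk as one compressed pickled batch, and the first call to next() spills any leftover buffer, rewinds the file, and streams the batches back before falling through to the usual same-key iteration. The same pattern, reduced to the standard library (plain pickle in a temporary file instead of CompressedSerializer, and append-then-iterate only, which SameKey enforces with its ValueError), might look like this sketch:

import pickle
import tempfile


class SpillableList(object):
    """Buffer appended values; pickle full batches to a temp file."""

    def __init__(self, batch=10240):
        self.batch = batch
        self.values = []
        self._file = None

    def append(self, value):
        self.values.append(value)
        if len(self.values) >= self.batch:
            self._spill()

    def _spill(self):
        # Write the current buffer as one pickled batch, then drop it.
        if self._file is None:
            self._file = tempfile.TemporaryFile()
        pickle.dump(self.values, self._file)
        self.values = []

    def __iter__(self):
        # Replay spilled batches in append order, then the in-memory
        # tail; meant to be iterated only after all appends are done.
        if self._file is not None:
            self._file.seek(0)
            while True:
                try:
                    batch = pickle.load(self._file)
                except EOFError:
                    break
                for v in batch:
                    yield v
        for v in self.values:
            yield v

With batch=3, appending range(10) and then calling list() on the object returns 0 through 9 in order while at most three values are buffered in memory at any time.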
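
On the _merged_items() threshold: the spill files are compressed, so their size on disk understates the memory needed to hold the deserialized partition. Moving the cutoff from memory_limit / 2 to memory_limit / 10 amounts to assuming the in-memory footprint can be roughly an order of magnitude larger than the compressed bytes. The decision spelled out as a sketch; the explicit expansion constant and the helper name are illustrative, not part of the patch:

def should_use_sort_merge(spill_bytes, memory_limit_mb, expansion=10):
    # Fall back to sort-based merge when the decompressed partition
    # probably cannot be held in memory all at once.
    return (spill_bytes >> 20) > memory_limit_mb / expansion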