
Commit

refactor, minor turning
davies committed Aug 27, 2014
1 parent b48cda5 commit 2c1d05b
Showing 2 changed files with 71 additions and 122 deletions.
12 changes: 10 additions & 2 deletions python/pyspark/rdd.py
@@ -1660,14 +1660,22 @@ def _memory_limit(self):
def groupByKey(self, numPartitions=None):
"""
Group the values for each key in the RDD into a single sequence.
Hash-partitions the resulting RDD with into numPartitions partitions.
Hash-partitions the resulting RDD into numPartitions
partitions.
Each value in the resulting RDD is an iterable L{ResultIterable}
object; it can be iterated only once. Calling `len(values)` will
iterate the values, so they cannot be iterated again after calling
`len(values)`.
Note: If you are grouping in order to perform an aggregation (such as a
sum or average) over each key, using reduceByKey will provide much
better performance.
>>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
>>> sorted(x.groupByKey().mapValues(len).collect())
[('a', 2), ('b', 1)]
>>> sorted(x.groupByKey().mapValues(list).collect())
[('a', [1, 1]), ('b', [1])]
"""
def createCombiner(x):
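An aside on the groupByKey change above: the docstring now points aggregations at reduceByKey, which combines values on each partition before the shuffle, whereas groupByKey ships every value across the network. An illustrative sketch of that advice (assumes an existing SparkContext named sc, as in the doctests; not code from this commit):

x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])

# groupByKey collects every value per key and then sums them ...
grouped = x.groupByKey().mapValues(sum).collect()

# ... while reduceByKey combines values on each partition first,
# so far less data crosses the shuffle.
reduced = x.reduceByKey(lambda a, b: a + b).collect()

assert sorted(grouped) == sorted(reduced) == [("a", 2), ("b", 1)]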
181 changes: 61 additions & 120 deletions python/pyspark/shuffle.py
@@ -33,14 +33,19 @@
try:
import psutil

process = None

def get_used_memory():
""" Return the used memory in MB """
process = psutil.Process(os.getpid())
global process
if process is None or process._pid != os.getpid():
process = psutil.Process(os.getpid())
if hasattr(process, "memory_info"):
info = process.memory_info()
else:
info = process.get_memory_info()
return info.rss >> 20

except ImportError:

def get_used_memory():
@@ -49,6 +54,7 @@ def get_used_memory():
for line in open('/proc/self/status'):
if line.startswith('VmRSS:'):
return int(line.split()[1]) >> 10

else:
warnings.warn("Please install psutil to have better "
"support with spilling")
@@ -57,6 +63,7 @@ def get_used_memory():
rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
return rss >> 20
# TODO: support windows

return 0
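Note on the psutil branch above: it now caches the Process handle and rebuilds it only when the pid changes (e.g. after a fork), instead of constructing a new handle on every call. A standalone sketch of the same idea, assuming psutil is installed (names below are illustrative, not PySpark's):

import os
import psutil

_process = None  # cached handle; rebuilt if the pid changes after a fork

def used_memory_mb():
    """ Return the resident set size of the current process in MB """
    global _process
    if _process is None or _process.pid != os.getpid():
        _process = psutil.Process(os.getpid())
    # newer psutil exposes memory_info(); very old releases used get_memory_info()
    info = _process.memory_info()
    return info.rss >> 20

Caching the handle avoids re-creating it every time the memory limit is checked during merging.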


@@ -146,7 +153,7 @@ def mergeCombiners(self, iterator):
d[k] = comb(d[k], v) if k in d else v

def iteritems(self):
""" Return the merged items ad iterator """
""" Return the merged items as iterator """
return self.data.iteritems()


@@ -210,18 +217,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
localdirs=None, scale=1, partitions=59, batch=1000):
Merger.__init__(self, aggregator)
self.memory_limit = memory_limit
# default serializer is only used for tests
self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
# add compression
if isinstance(self.serializer, BatchedSerializer):
if not isinstance(self.serializer.serializer, CompressedSerializer):
self.serializer = BatchedSerializer(
CompressedSerializer(self.serializer.serializer),
self.serializer.batchSize)
else:
if not isinstance(self.serializer, CompressedSerializer):
self.serializer = CompressedSerializer(self.serializer)

self.serializer = self._compressed_serializer(serializer)
self.localdirs = localdirs or _get_local_dirs(str(id(self)))
# number of partitions to use when spilling data to disk
self.partitions = partitions
@@ -238,6 +234,18 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
# randomize the hash of key, id(o) is the address of o (aligned by 8)
self._seed = id(self) + 7

def _compressed_serializer(self, serializer=None):
# default serializer is only used for tests
ser = serializer or PickleSerializer()
# add compression
if isinstance(ser, BatchedSerializer):
if not isinstance(ser.serializer, CompressedSerializer):
ser = BatchedSerializer(CompressedSerializer(ser.serializer), ser.batchSize)
else:
if not isinstance(ser, CompressedSerializer):
ser = BatchedSerializer(CompressedSerializer(ser), 1024)
return ser
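The new _compressed_serializer helper above wraps whatever serializer it receives so that spilled batches are compressed on disk. A rough standalone analogue using only the standard library (zlib and pickle stand in for PySpark's CompressedSerializer and BatchedSerializer; every name below is hypothetical, and Python 3 is assumed):

import pickle
import zlib
from itertools import islice

def dump_compressed_batches(items, stream, batch_size=1024):
    """ Pickle items in fixed-size batches and compress each batch before writing """
    items = iter(items)
    while True:
        batch = list(islice(items, batch_size))
        if not batch:
            break
        block = zlib.compress(pickle.dumps(batch))
        stream.write(len(block).to_bytes(4, "big"))  # length prefix, then payload
        stream.write(block)

def load_compressed_batches(stream):
    """ Inverse of dump_compressed_batches: decompress and unpickle batch by batch """
    while True:
        header = stream.read(4)
        if not header:
            break
        block = stream.read(int.from_bytes(header, "big"))
        for item in pickle.loads(zlib.decompress(block)):
            yield item

Batching keeps the per-item overhead of compression low, which is the same motivation as wrapping with BatchedSerializer(CompressedSerializer(...), 1024) above.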

def _get_spill_dir(self, n):
""" Choose one directory for spill by number n """
return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
@@ -276,6 +284,9 @@ def _partition(self, key):
return hash((key, self._seed)) % self.partitions
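The _partition context above hashes (key, self._seed) rather than the bare key, so the placement of keys differs between merger instances (the seed is derived from the object's address). A tiny illustration of salted-hash partitioning (values are arbitrary):

def partition_for(key, seed, num_partitions=59):
    """ Salted hash partitioning: the seed perturbs where each key lands """
    return hash((key, seed)) % num_partitions

# the same key can be routed to different partitions under different seeds
print(partition_for("a", seed=7), partition_for("a", seed=12345))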

def _object_size(self, obj):
""" How much of memory for this obj, assume that all the objects
consume similar bytes of memory
"""
return 1

def mergeCombiners(self, iterator, limit=None):
@@ -485,18 +496,18 @@ class SameKey(object):
This is used by GroupByKey.
>>> l = zip(range(2), range(2))
>>> list(SameKey(0, [1], iter(l), GroupByKey(iter([]))))
>>> list(SameKey(0, 1, iter(l), GroupByKey(iter([]))))
[1, 0]
>>> s = SameKey(0, [1], iter(l), GroupByKey(iter([])))
>>> s = SameKey(0, 1, iter(l), GroupByKey(iter([])))
>>> for i in range(2000):
... s.append(i)
>>> len(list(s))
2002
"""
def __init__(self, key, values, it, groupBy):
def __init__(self, key, value, iterator, groupBy):
self.key = key
self.values = values
self.it = it
self.values = [value]
self.iterator = iterator
self.groupBy = groupBy
self._file = None
self._ser = None
@@ -516,27 +527,22 @@ def next(self):
self._index = 0

if self._index >= len(self.values) and self._file is not None:
try:
self.values = next(self._ser.load_stream(self._file))
self._index = 0
except StopIteration:
self._file.close()
self._file = None
# load next chunk of values from disk
self.values = next(self._ser.load_stream(self._file))
self._index = 0

if self._index < len(self.values):
value = self.values[self._index]
self._index += 1
return value

if self.it is None:
raise StopIteration
key, value = next(self.iterator)
if key == self.key:
return value

key, value = self.it.next()
if key != self.key:
self.groupBy._next_item = (key, value)
self.it = None
raise StopIteration
return value
# push them back into groupBy
self.groupBy.next_item = (key, value)
raise StopIteration

def append(self, value):
if self._index is not None:
@@ -548,13 +554,14 @@ def append(self, value):
self._spill()

def _spill(self):
""" dump the values into disk """
if self._file is None:
dirs = _get_local_dirs("objects")
d = dirs[id(self) % len(dirs)]
if not os.path.exists(d):
os.makedirs(d)
p = os.path.join(d, str(id))
self._file = open(p, "w+")
self._file = open(p, "w+", 65536)
self._ser = CompressedSerializer(PickleSerializer())

self._ser.dump_stream([self.values], self._file)
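SameKey above buffers the values of a single key in memory and spills them to a local file once the buffer grows large, replaying the spilled chunks when it is eventually iterated. A minimal self-contained sketch of that buffer-then-spill pattern (hypothetical class and threshold, not PySpark's implementation):

import pickle
import tempfile

class SpillingValues(object):
    """ Buffer values in memory; append them to a temp file when the buffer fills up """
    def __init__(self, threshold=10000):
        self.threshold = threshold
        self.values = []
        self._file = None

    def append(self, value):
        self.values.append(value)
        if len(self.values) >= self.threshold:
            self._spill()

    def _spill(self):
        if self._file is None:
            self._file = tempfile.TemporaryFile()
        pickle.dump(self.values, self._file)  # one pickled chunk per spill
        self.values = []

    def __iter__(self):
        # replay spilled chunks from disk first, then the in-memory tail
        if self._file is not None:
            self._file.seek(0)
            while True:
                try:
                    chunk = pickle.load(self._file)
                except EOFError:
                    break
                for v in chunk:
                    yield v
        for v in self.values:
            yield v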
@@ -567,32 +574,34 @@ class GroupByKey(object):
group a sorted iterator into [(k1, it1), (k2, it2), ...]
>>> k = [i/3 for i in range(6)]
>>> v = [[i] for i in range(6)]
>>> v = [i for i in range(6)]
>>> g = GroupByKey(iter(zip(k, v)))
>>> [(k, list(it)) for k, it in g]
[(0, [0, 1, 2]), (1, [3, 4, 5])]
"""
def __init__(self, it):
self.it = iter(it)
self._next_item = None
def __init__(self, iterator):
self.iterator = iterator
self.next_item = None
self.current = None

def __iter__(self):
return self

def next(self):
if self._next_item is None:
if self.next_item is None:
while True:
key, value = self.it.next()
key, value = next(self.iterator)
if self.current is None or key != self.current.key:
break
# the current key has not been visited.
self.current.append(value)
else:
key, value = self._next_item
self._next_item = None
# next key was popped while visiting current key
key, value = self.next_item
self.next_item = None

self.current = SameKey(key, [value], self.it, self)
return key, (v for vs in self.current for v in vs)
self.current = SameKey(key, value, self.iterator, self)
return key, self.current
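GroupByKey turns a key-sorted stream into (key, iterator-of-values) pairs lazily; it exists because the values of a single key may not fit in memory, which is what the SameKey buffering and spilling above handles. For data that does fit in memory the same grouping is just itertools.groupby — an illustrative equivalent, not the commit's code:

import itertools
import operator

def group_by_key(sorted_items):
    """ Group a key-sorted iterator of (k, v) pairs into (k, iterator of v) """
    for key, pairs in itertools.groupby(sorted_items, key=operator.itemgetter(0)):
        yield key, (v for _, v in pairs)

k = [i // 3 for i in range(6)]
v = list(range(6))
assert [(key, list(it)) for key, it in group_by_key(zip(k, v))] == \
    [(0, [0, 1, 2]), (1, [3, 4, 5])]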


class ExternalGroupBy(ExternalMerger):
@@ -624,7 +633,7 @@ def _spill(self):

if not self.pdata:
# The data has not been partitioned; it will iterate over the
# dataset once, write them into different files, and needs no
# data once, write them into different files, and needs no
# additional memory. It is only called when the memory goes
# above the limit for the first time.

@@ -636,12 +645,10 @@ def _spill(self):
# sort them before dumping into disks
self._sorted = len(self.data) < self.SORT_KEY_LIMIT
if self._sorted:
ser = self._flatted_serializer()
self.serializer = self._flatted_serializer()
for k in sorted(self.data.keys()):
v = self.data[k]
h = self._partition(k)
ser.dump_stream([(k, v)], streams[h])
self.serializer = ser
self.serializer.dump_stream([(k, self.data[k])], streams[h])
else:
for k, v in self.data.iteritems():
h = self._partition(k)
@@ -651,6 +658,7 @@ def _spill(self):
s.close()

self.data.clear()
# self.pdata is cached in `mergeValues` and `mergeCombiners`
self.pdata.extend([{} for i in range(self.partitions)])

else:
Expand All @@ -659,8 +667,9 @@ def _spill(self):
with open(p, "w") as f:
# dump items in batch
if self._sorted:
self.serializer.dump_stream(
sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0)), f)
# sort by key only (stable)
sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
self.serializer.dump_stream(sorted_items, f)
else:
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.pdata[i].clear()
@@ -706,75 +715,7 @@ def load_partition(j):
sorted_items = sorter.sorted(itertools.chain(*disk_items),
key=operator.itemgetter(0))

return GroupByKey(sorted_items)


class ExternalSorter(object):
"""
ExternalSorter will divide the elements into chunks, sort them in
memory and dump them to disk, then finally merge them back.
The spilling will only happen when the used memory goes above
the limit.
>>> sorter = ExternalSorter(1) # 1M
>>> import random
>>> l = range(1024)
>>> random.shuffle(l)
>>> sorted(l) == list(sorter.sorted(l))
True
>>> sorted(l) == list(sorter.sorted(l, key=lambda x: -x, reverse=True))
True
"""
def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
self._spilled_bytes = 0

def _get_path(self, n):
""" Choose one directory for spill by number n """
d = self.local_dirs[n % len(self.local_dirs)]
if not os.path.exists(d):
os.makedirs(d)
return os.path.join(d, str(n))

def sorted(self, iterator, key=None, reverse=False):
"""
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
batch = 10
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
# pick elements in batch
chunk = list(itertools.islice(iterator, batch))
current_chunk.extend(chunk)
if len(chunk) < batch:
break

if get_used_memory() > self.memory_limit:
# sorting them in place saves memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
self.serializer.dump_stream(current_chunk, f)
self._spilled_bytes += os.path.getsize(path)
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []

elif not chunks:
batch = min(batch * 2, 10000)

current_chunk.sort(key=key, reverse=reverse)
if not chunks:
return current_chunk

if current_chunk:
chunks.append(iter(current_chunk))

return heapq.merge(chunks, key=key, reverse=reverse)
return ((k, itertools.chain.from_iterable(vs)) for k, vs in GroupByKey(sorted_items))
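The values coming out of GroupByKey here are chunks (lists) of values per key, because the spilled groupByKey combiners are lists, so itertools.chain.from_iterable flattens them back into single values — the same job the removed `(v for vs in self.current for v in vs)` expression did. A tiny illustration of that flattening step:

import itertools

chunks = [[1, 2], [3], [4, 5, 6]]   # value chunks collected for one key
assert list(itertools.chain.from_iterable(chunks)) == [1, 2, 3, 4, 5, 6]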


if __name__ == "__main__":

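The ExternalSorter class deleted above implements a classic chunked external sort: accumulate elements, sort the current chunk in memory when the memory limit is exceeded, spill it to disk, and lazily merge the sorted runs at the end. A self-contained sketch of that idea (hypothetical names; it uses a fixed chunk size instead of probing memory, and relies on heapq.merge accepting key=, which exists only in Python 3.5+):

import heapq
import itertools
import pickle
import tempfile

def _spill_run(chunk):
    """ Write one sorted run to a temp file, one pickled item per record """
    f = tempfile.TemporaryFile()
    for item in chunk:
        pickle.dump(item, f)
    f.seek(0)
    return f

def _read_run(f):
    """ Lazily stream the items of a spilled run back from disk """
    while True:
        try:
            yield pickle.load(f)
        except EOFError:
            return

def external_sorted(iterator, key=None, chunk_size=100000):
    """ Sort a possibly huge iterator by sorting bounded chunks and merging the runs """
    runs = []
    iterator = iter(iterator)
    while True:
        chunk = list(itertools.islice(iterator, chunk_size))
        if not chunk:
            break
        chunk.sort(key=key)
        runs.append(_read_run(_spill_run(chunk)))
    return heapq.merge(*runs, key=key)

With a single run this degenerates to streaming one sorted chunk back from disk; the removed class instead kept the last chunk in memory and only merged when earlier chunks had already been spilled.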