Commit

Merge branch 'master' into groupby
Conflicts:
	python/pyspark/shuffle.py
davies committed Sep 14, 2014
2 parents fbc504a + 4e3fbe8 commit 8ef965e
Showing 4 changed files with 47 additions and 14 deletions.
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -124,6 +124,10 @@ private[spark] class PythonRDD(
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += diskBytesSpilled
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
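Note: the two new readLong() calls consume the two longs that the Python worker writes immediately after its timing report (see the python/pyspark/worker.py hunk below). A minimal sketch of that worker-side write order, using the existing write_long/write_int helpers from pyspark.serializers (the report_metrics wrapper itself is hypothetical):

    from pyspark.serializers import write_long, write_int, SpecialLengths

    def report_metrics(outfile, memory_bytes_spilled, disk_bytes_spilled):
        # Written right after report_times(); PythonRDD reads the two values
        # back in the same order with stream.readLong().
        write_long(memory_bytes_spilled, outfile)
        write_long(disk_bytes_spilled, outfile)
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
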
28 changes: 25 additions & 3 deletions python/pyspark/shuffle.py
@@ -78,6 +78,11 @@ def _get_local_dirs(sub):
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]


# global stats
MemoryBytesSpilled = 0L
DiskBytesSpilled = 0L


class Aggregator(object):

"""
@@ -318,10 +323,12 @@ def _spill(self):
It will dump the data in batch for better performance.
"""
global MemoryBytesSpilled, DiskBytesSpilled
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)

used_memory = get_used_memory()
if not self.pdata:
# The data has not been partitioned, it will iterator the
# dataset once, write them into different files, has no
@@ -339,6 +346,7 @@ def _spill(self):
self.serializer.dump_stream([(k, v)], streams[h])

for s in streams:
DiskBytesSpilled += s.tell()
s.close()

self.data.clear()
@@ -351,9 +359,11 @@ def _spill(self):
# dump items in batch
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)

self.spills += 1
gc.collect() # release the memory as much as possible
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20

def iteritems(self):
""" Return all merged items as iterator """
@@ -454,7 +464,6 @@ def __init__(self, memory_limit, serializer=None):
self.local_dirs = _get_local_dirs("sort")
self.serializer = serializer or BatchedSerializer(
CompressedSerializer(PickleSerializer()), 1024)
self._spilled_bytes = 0

def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -468,6 +477,7 @@ def sorted(self, iterator, key=None, reverse=False):
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
batch = 10
chunks, current_chunk = [], []
iterator = iter(iterator)
@@ -478,17 +488,19 @@ def sorted(self, iterator, key=None, reverse=False):
if len(chunk) < batch:
break

if get_used_memory() > self.memory_limit:
used_memory = get_used_memory()
if used_memory > self.memory_limit:
# sort them inplace will save memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
self.serializer.dump_stream(current_chunk, f)
self._spilled_bytes += os.path.getsize(path)
chunks.append(self.serializer.load_stream(open(path)))
os.unlink(path) # data will be deleted after close
current_chunk = []
gc.collect()
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
DiskBytesSpilled += os.path.getsize(path)

elif not chunks:
batch = min(batch * 2, 10000)
@@ -569,6 +581,7 @@ def append(self, value):

def _spill(self):
""" dump the values into disk """
global MemoryBytesSpilled, DiskBytesSpilled
if self._file is None:
dirs = _get_local_dirs("objects")
d = dirs[id(self) % len(dirs)]
@@ -578,9 +591,13 @@ def _spill(self):
self._file = open(p, "w+", 65536)
self._ser = CompressedSerializer(PickleSerializer())

used_memory = get_used_memory()
pos = self._file.tell()
self._ser.dump_stream([self.values], self._file)
DiskBytesSpilled += self._file.tell() - pos
self.values = []
gc.collect()
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20


class GroupByKey(object):
@@ -641,10 +658,12 @@ def _spill(self):
"""
dump already partitioned data into disks.
"""
global MemoryBytesSpilled, DiskBytesSpilled
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)

used_memory = get_used_memory()
if not self.pdata:
# The data has not been partitioned, it will iterator the
# data once, write them into different files, has no
@@ -669,6 +688,7 @@ def _spill(self):
self.serializer.dump_stream([(k, v)], streams[h])

for s in streams:
DiskBytesSpilled += s.tell()
s.close()

self.data.clear()
@@ -687,9 +707,11 @@ def _spill(self):
else:
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)

self.spills += 1
gc.collect() # release the memory as much as possible
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20

def _merged_items(self, index, limit=0):
size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
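Note on the unit conversion in the _spill() methods above: get_used_memory() in pyspark.shuffle reports resident memory in megabytes, so the drop in usage after gc.collect() is shifted left by 20 bits (multiplied by 2**20) to record MemoryBytesSpilled in bytes. A small worked example with made-up readings:

    # Hypothetical readings around a spill, in MB as returned by get_used_memory()
    used_before_mb = 512
    used_after_mb = 300
    memory_bytes_spilled = (used_before_mb - used_after_mb) << 20
    assert memory_bytes_spilled == 212 * 1024 * 1024  # 212 MB expressed in bytes

DiskBytesSpilled, by contrast, is measured directly in bytes from the spill files via stream.tell() or os.path.getsize().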
15 changes: 8 additions & 7 deletions python/pyspark/tests.py
@@ -46,6 +46,7 @@
CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType
from pyspark import shuffle

_have_scipy = False
_have_numpy = False
@@ -138,17 +139,17 @@ def test_external_sort(self):
random.shuffle(l)
sorter = ExternalSorter(1)
self.assertEquals(sorted(l), list(sorter.sorted(l)))
self.assertGreater(sorter._spilled_bytes, 0)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, 0)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
self.assertGreater(sorter._spilled_bytes, last)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, last)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
self.assertGreater(sorter._spilled_bytes, last)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, last)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
self.assertGreater(sorter._spilled_bytes, last)
self.assertGreater(shuffle.DiskBytesSpilled, last)

def test_external_sort_in_rdd(self):
conf = SparkConf().set("spark.python.worker.memory", "1m")
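The updated test reads the module-level shuffle.DiskBytesSpilled counter instead of the removed per-sorter _spilled_bytes attribute; because the global only grows, each assertion compares against a snapshot taken after the previous sort. A standalone usage sketch along the same lines (the data size is arbitrary; the final assertion only holds if memory actually exceeds the 1 MB limit and a spill occurs):

    from pyspark import shuffle
    from pyspark.shuffle import ExternalSorter

    shuffle.DiskBytesSpilled = 0            # counters are plain module globals
    sorter = ExternalSorter(1)              # 1 MB memory limit
    data = list(range(1 << 20))
    assert list(sorter.sorted(data)) == sorted(data)
    assert shuffle.DiskBytesSpilled > 0     # at least one chunk hit the disk
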
14 changes: 10 additions & 4 deletions python/pyspark/worker.py
@@ -23,16 +23,14 @@
import time
import socket
import traceback
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.

from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
CompressedSerializer

from pyspark import shuffle

pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()
@@ -52,6 +50,11 @@ def main(infile, outfile):
if split_index == -1: # for unit tests
return

# initialize global state
shuffle.MemoryBytesSpilled = 0
shuffle.DiskBytesSpilled = 0
_accumulatorRegistry.clear()

# fetch name of workdir
spark_files_dir = utf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
@@ -97,6 +100,9 @@ def main(infile, outfile):
exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
write_long(shuffle.MemoryBytesSpilled, outfile)
write_long(shuffle.DiskBytesSpilled, outfile)

# Mark the beginning of the accumulators section of the output
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
write_int(len(_accumulatorRegistry), outfile)
