show spilled bytes in Python in web ui
davies committed Sep 9, 2014
1 parent f0f1ba0 commit fbe9029
Showing 5 changed files with 39 additions and 13 deletions.
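
The change threads PySpark's shuffle-spill totals through to the Spark web UI: after a task finishes, worker.py appends the memory and disk spill totals as two longs, PythonRDD.scala reads them in the same order and adds them to the task's TaskMetrics, and JobProgressListener passes those metrics on to the UI. A minimal sketch of the framing, assuming the big-endian 8-byte encoding used by pyspark.serializers.write_long/read_long (those helpers are not part of this diff):

```python
import struct
from io import BytesIO

def write_long(value, stream):
    # big-endian signed 8-byte integer -- what the JVM side's readLong() expects
    stream.write(struct.pack("!q", value))

def read_long(stream):
    return struct.unpack("!q", stream.read(8))[0]

# worker side: append the two totals after the timing report
out = BytesIO()
write_long(1 << 20, out)   # MemoryBytesSpilled
write_long(4096, out)      # DiskBytesSpilled

# JVM side (PythonRDD) reads them back in the same order and adds them
# to the task's TaskMetrics
inp = BytesIO(out.getvalue())
assert read_long(inp) == 1 << 20
assert read_long(inp) == 4096
```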
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -115,6 +115,10 @@ private[spark] class PythonRDD(
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += diskBytesSpilled
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -242,7 +242,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
t.taskMetrics)

// Overwrite task metrics
t.taskMetrics = Some(taskMetrics)
// FIXME: deepcopy the metrics, or they will be the same object in local mode
t.taskMetrics = Some(scala.util.Marshal.load[TaskMetrics](scala.util.Marshal.dump(taskMetrics)))
}
}
}
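
The JobProgressListener change works around a local-mode quirk the FIXME describes: with the driver and executor in one JVM, the TaskMetrics delivered in the task-end event can be the very object the executor keeps mutating, so storing a bare reference lets later updates bleed into already-recorded rows; the stop-gap here is a deep copy via serialization. A sketch of the aliasing problem in Python, using a hypothetical TaskMetrics stand-in:

```python
import copy

class TaskMetrics(object):
    # hypothetical stand-in for the JVM TaskMetrics class
    def __init__(self):
        self.disk_bytes_spilled = 0

live = TaskMetrics()                  # object the "executor" keeps updating
recorded_ref = live                   # what storing a bare reference would keep
recorded_copy = copy.deepcopy(live)   # what the deep copy snapshots

live.disk_bytes_spilled += 4096       # executor reports more spilling later

assert recorded_ref.disk_bytes_spilled == 4096   # aliased: history silently changes
assert recorded_copy.disk_bytes_spilled == 0     # snapshot stays fixed
```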
19 changes: 16 additions & 3 deletions python/pyspark/shuffle.py
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]


# global stats
MemoryBytesSpilled = 0L
DiskBytesSpilled = 0L


class Aggregator(object):

"""
@@ -313,10 +318,12 @@ def _spill(self):
It will dump the data in batch for better performance.
"""
global MemoryBytesSpilled, DiskBytesSpilled
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)

used_memory = get_used_memory()
if not self.pdata:
# The data has not been partitioned, it will iterator the
# dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ def _spill(self):
self.serializer.dump_stream([(k, v)], streams[h])

for s in streams:
DiskBytesSpilled += s.tell()
s.close()

self.data.clear()
@@ -346,9 +354,11 @@ def _spill(self):
# dump items in batch
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)

self.spills += 1
gc.collect() # release the memory as much as possible
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20

def iteritems(self):
""" Return all merged items as iterator """
@@ -462,7 +472,6 @@ def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
self._spilled_bytes = 0

def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ def sorted(self, iterator, key=None, reverse=False):
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
batch = 10
chunks, current_chunk = [], []
iterator = iter(iterator)
@@ -486,15 +496,18 @@ def sorted(self, iterator, key=None, reverse=False):
if len(chunk) < batch:
break

if get_used_memory() > self.memory_limit:
used_memory = get_used_memory()
if used_memory > self.memory_limit:
# sort them inplace will save memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
self.serializer.dump_stream(current_chunk, f)
self._spilled_bytes += os.path.getsize(path)
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []
gc.collect()
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
DiskBytesSpilled += os.path.getsize(path)

elif not chunks:
batch = min(batch * 2, 10000)
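
ExternalSorter now reports through the shared module counters instead of its own `_spilled_bytes` field. For orientation, the surrounding method batches the input and, whenever memory crosses the limit, sorts the current chunk in place, dumps it to a spill file, and later merges the sorted streams. A much-simplified sketch of that spill-then-merge shape (an illustration only, not PySpark's implementation: the real code checks get_used_memory() against the worker memory limit and supports key/reverse):

```python
import heapq
import os
import pickle
import tempfile

def _load(path):
    # stream a spilled chunk back from disk, one pickled item at a time
    with open(path, "rb") as f:
        while True:
            try:
                yield pickle.load(f)
            except EOFError:
                return

def external_sorted(iterator, chunk_size=1000):
    chunks, current = [], []
    for item in iterator:
        current.append(item)
        if len(current) >= chunk_size:       # stand-in for the memory-limit check
            current.sort()                   # sort the chunk in place
            fd, path = tempfile.mkstemp()
            os.close(fd)
            with open(path, "wb") as f:      # spill the sorted chunk to disk
                for x in current:
                    pickle.dump(x, f)
            chunks.append(_load(path))
            current = []
    current.sort()
    chunks.append(iter(current))
    return heapq.merge(*chunks)              # merge the sorted streams lazily

assert list(external_sorted(iter([5, 3, 1, 4, 2]), chunk_size=2)) == [1, 2, 3, 4, 5]
```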
15 changes: 8 additions & 7 deletions python/pyspark/tests.py
@@ -44,6 +44,7 @@
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType
from pyspark import shuffle

_have_scipy = False
_have_numpy = False
@@ -136,17 +137,17 @@ def test_external_sort(self):
random.shuffle(l)
sorter = ExternalSorter(1)
self.assertEquals(sorted(l), list(sorter.sorted(l)))
self.assertGreater(sorter._spilled_bytes, 0)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, 0)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
self.assertGreater(sorter._spilled_bytes, last)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, last)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
self.assertGreater(sorter._spilled_bytes, last)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, last)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
self.assertGreater(sorter._spilled_bytes, last)
self.assertGreater(shuffle.DiskBytesSpilled, last)

def test_external_sort_in_rdd(self):
conf = SparkConf().set("spark.python.worker.memory", "1m")
11 changes: 9 additions & 2 deletions python/pyspark/worker.py
@@ -27,12 +27,11 @@
# copy_reg module.
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
CompressedSerializer

from pyspark import shuffle

pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()
@@ -52,6 +51,11 @@ def main(infile, outfile):
if split_index == -1: # for unit tests
return

# initialize global state
shuffle.MemoryBytesSpilled = 0
shuffle.DiskBytesSpilled = 0
_accumulatorRegistry.clear()

# fetch name of workdir
spark_files_dir = utf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
@@ -92,6 +96,9 @@ def main(infile, outfile):
exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
write_long(shuffle.MemoryBytesSpilled, outfile)
write_long(shuffle.DiskBytesSpilled, outfile)

# Mark the beginning of the accumulators section of the output
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
write_int(len(_accumulatorRegistry), outfile)
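
worker.py zeroes both counters at the top of main(), right next to _accumulatorRegistry.clear(), so a worker process that handles more than one task does not report stale totals; after the task it writes the two longs between the timing report and END_OF_DATA_SECTION, matching the read order in PythonRDD.scala above. A toy model of the per-task reset:

```python
# toy model of a worker process serving several tasks: the counters live at
# module scope, so each task has to start them from zero
counters = {"MemoryBytesSpilled": 0, "DiskBytesSpilled": 0}   # stand-in globals

def main(task):
    counters["MemoryBytesSpilled"] = 0     # the reset added at the top of main()
    counters["DiskBytesSpilled"] = 0
    task()                                 # user code may bump the counters
    return counters["MemoryBytesSpilled"], counters["DiskBytesSpilled"]

def spilling_task():
    counters["DiskBytesSpilled"] += 4096

assert main(spilling_task) == (0, 4096)
assert main(lambda: None) == (0, 0)        # nothing carries over from the last task
```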
