[SPARK-3463] [PySpark] aggregate and show spilled bytes in Python
Aggregate the number of bytes spilled to disk during aggregation or sorting, and show them in the Web UI.

![spilled](https://cloud.githubusercontent.com/assets/40902/4209758/4b995562-386d-11e4-97c1-8e838ee1d4e3.png)

This patch is blocked by SPARK-3465. (It includes a fix for that).
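
For orientation before the diffs: the patch keeps two module-level counters in pyspark/shuffle.py, MemoryBytesSpilled and DiskBytesSpilled, which every ExternalMerger and ExternalSorter instance bumps when it spills. worker.py resets them at the start of each task and sends the totals back to the JVM, which folds them into the task's TaskMetrics so they show up in the Web UI. A compressed, standalone sketch of that lifecycle (simplified illustration, not the actual pyspark code):

```python
# Standalone sketch of the counter lifecycle added by this patch; the real
# counters live in pyspark/shuffle.py and are reported by pyspark/worker.py.

class shuffle_module:                 # stand-in for the pyspark.shuffle module
    MemoryBytesSpilled = 0
    DiskBytesSpilled = 0

def run_task(spills):
    # worker.py: reset the global counters at the start of every task.
    shuffle_module.MemoryBytesSpilled = 0
    shuffle_module.DiskBytesSpilled = 0

    # ExternalMerger / ExternalSorter: bump the counters on every spill.
    for freed_bytes, written_bytes in spills:
        shuffle_module.MemoryBytesSpilled += freed_bytes
        shuffle_module.DiskBytesSpilled += written_bytes

    # worker.py: report the totals; PythonRDD adds them to TaskMetrics.
    return shuffle_module.MemoryBytesSpilled, shuffle_module.DiskBytesSpilled

print(run_task([(64 << 20, 8 << 20), (32 << 20, 4 << 20)]))  # (100663296, 12582912)
```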

Author: Davies Liu <davies.liu@gmail.com>

Closes #2336 from davies/metrics and squashes the following commits:

e37df38 [Davies Liu] remove outdated comments
1245eb7 [Davies Liu] remove the temporary fix
ebd2f43 [Davies Liu] Merge branch 'master' into metrics
7e4ad04 [Davies Liu] Merge branch 'master' into metrics
fbe9029 [Davies Liu] show spilled bytes in Python in web ui
davies authored and JoshRosen committed Sep 14, 2014
1 parent 2aea0da commit 4e3fbe8
Showing 4 changed files with 38 additions and 14 deletions.
4 changes: 4 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -124,6 +124,10 @@ private[spark] class PythonRDD(
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
context.taskMetrics.diskBytesSpilled += diskBytesSpilled
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
19 changes: 16 additions & 3 deletions python/pyspark/shuffle.py
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]


# global stats
MemoryBytesSpilled = 0L
DiskBytesSpilled = 0L


class Aggregator(object):

"""
@@ -313,10 +318,12 @@ def _spill(self):
It will dump the data in batch for better performance.
"""
global MemoryBytesSpilled, DiskBytesSpilled
path = self._get_spill_dir(self.spills)
if not os.path.exists(path):
os.makedirs(path)

used_memory = get_used_memory()
if not self.pdata:
# The data has not been partitioned, it will iterator the
# dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ def _spill(self):
self.serializer.dump_stream([(k, v)], streams[h])

for s in streams:
DiskBytesSpilled += s.tell()
s.close()

self.data.clear()
@@ -346,9 +354,11 @@ def _spill(self):
# dump items in batch
self.serializer.dump_stream(self.pdata[i].iteritems(), f)
self.pdata[i].clear()
DiskBytesSpilled += os.path.getsize(p)

self.spills += 1
gc.collect() # release the memory as much as possible
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20

def iteritems(self):
""" Return all merged items as iterator """
@@ -462,7 +472,6 @@ def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
self._spilled_bytes = 0

def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ def sorted(self, iterator, key=None, reverse=False):
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
batch = 10
chunks, current_chunk = [], []
iterator = iter(iterator)
@@ -486,15 +496,18 @@ def sorted(self, iterator, key=None, reverse=False):
if len(chunk) < batch:
break

if get_used_memory() > self.memory_limit:
used_memory = get_used_memory()
if used_memory > self.memory_limit:
# sort them inplace will save memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
self.serializer.dump_stream(current_chunk, f)
self._spilled_bytes += os.path.getsize(path)
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []
gc.collect()
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
DiskBytesSpilled += os.path.getsize(path)

elif not chunks:
batch = min(batch * 2, 10000)
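
A note on the accounting in the shuffle.py hunks above: get_used_memory() reports usage in megabytes, so the drop in usage across a spill is shifted left by 20 bits to record MemoryBytesSpilled in bytes, while DiskBytesSpilled is taken from the spill files themselves (stream position or os.path.getsize). A tiny standalone sketch of the two rules:

```python
# Illustrative only: the two accounting rules used in ExternalMerger._spill
# and ExternalSorter.sorted above.

def memory_spilled_bytes(used_mb_before, used_mb_after):
    # get_used_memory() is in MB, so shift by 20 to convert the freed amount to bytes.
    return (used_mb_before - used_mb_after) << 20

def disk_spilled_bytes(spill_file_sizes):
    # Disk accounting just sums the bytes actually written to the spill files.
    return sum(spill_file_sizes)

assert memory_spilled_bytes(512, 200) == 312 * 1024 * 1024
assert disk_spilled_bytes([4096, 8192]) == 12288
```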
15 changes: 8 additions & 7 deletions python/pyspark/tests.py
@@ -46,6 +46,7 @@
CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
from pyspark.sql import SQLContext, IntegerType
from pyspark import shuffle

_have_scipy = False
_have_numpy = False
@@ -138,17 +139,17 @@ def test_external_sort(self):
random.shuffle(l)
sorter = ExternalSorter(1)
self.assertEquals(sorted(l), list(sorter.sorted(l)))
self.assertGreater(sorter._spilled_bytes, 0)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, 0)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
self.assertGreater(sorter._spilled_bytes, last)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, last)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
self.assertGreater(sorter._spilled_bytes, last)
last = sorter._spilled_bytes
self.assertGreater(shuffle.DiskBytesSpilled, last)
last = shuffle.DiskBytesSpilled
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
self.assertGreater(sorter._spilled_bytes, last)
self.assertGreater(shuffle.DiskBytesSpilled, last)

def test_external_sort_in_rdd(self):
conf = SparkConf().set("spark.python.worker.memory", "1m")
14 changes: 10 additions & 4 deletions python/pyspark/worker.py
@@ -23,16 +23,14 @@
import time
import socket
import traceback
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.

from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
CompressedSerializer

from pyspark import shuffle

pickleSer = PickleSerializer()
utf8_deserializer = UTF8Deserializer()
@@ -52,6 +50,11 @@ def main(infile, outfile):
if split_index == -1: # for unit tests
return

# initialize global state
shuffle.MemoryBytesSpilled = 0
shuffle.DiskBytesSpilled = 0
_accumulatorRegistry.clear()

# fetch name of workdir
spark_files_dir = utf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
@@ -97,6 +100,9 @@ def main(infile, outfile):
exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
write_long(shuffle.MemoryBytesSpilled, outfile)
write_long(shuffle.DiskBytesSpilled, outfile)

# Mark the beginning of the accumulators section of the output
write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
write_int(len(_accumulatorRegistry), outfile)
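
The two write_long calls above have to match the two readLong calls added to PythonRDD.scala, in the same order and at the same point in the stream: after the timing report and before the END_OF_DATA_SECTION marker. A minimal sketch of both ends of that framing, using raw struct packing instead of Spark's stream helpers (the -1 marker value is my assumption for SpecialLengths.END_OF_DATA_SECTION):

```python
import io
import struct

END_OF_DATA_SECTION = -1  # assumed value of SpecialLengths.END_OF_DATA_SECTION

def worker_side(out, memory_spilled, disk_spilled):
    # worker.py: after report_times(...), write the two counters as 8-byte
    # big-endian longs, then open the accumulator section.
    out.write(struct.pack("!q", memory_spilled))
    out.write(struct.pack("!q", disk_spilled))
    out.write(struct.pack("!i", END_OF_DATA_SECTION))

def jvm_side(inp):
    # Mirrors the two stream.readLong() calls added to PythonRDD.scala; the
    # values are then added to context.taskMetrics.
    memory_spilled = struct.unpack("!q", inp.read(8))[0]
    disk_spilled = struct.unpack("!q", inp.read(8))[0]
    return memory_spilled, disk_spilled

buf = io.BytesIO()
worker_side(buf, 100 << 20, 12 << 20)
buf.seek(0)
print(jvm_side(buf))  # (104857600, 12582912)
```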
