Commit
use external sort in sortBy() and sortByKey()
davies committed Aug 15, 2014
1 parent fd9fcd2 commit 55602ee
Showing 6 changed files with 997 additions and 20 deletions.
1 change: 1 addition & 0 deletions .rat-excludes
@@ -31,6 +31,7 @@ sorttable.js
.*data
.*log
cloudpickle.py
heapq3.py
join.py
SparkExprTyper.scala
SparkILoop.scala
890 changes: 890 additions & 0 deletions python/pyspark/heapq3.py

Large diffs are not rendered by default.
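The new heapq3.py is not rendered above; judging by how shuffle.py uses it below, it is a vendored heapq variant whose merge() accepts key= and reverse= keyword arguments, which the Python 2 standard library's heapq.merge does not (the stdlib only gained them in Python 3.5). A quick illustration of that k-way merge primitive, using the standard module on Python 3.5+ rather than the vendored copy; note that the stdlib signature takes *iterables, while shuffle.py passes its list of chunks directly to the vendored merge:

# Illustration of the merge primitive the external sorter relies on; this uses
# the standard heapq on Python 3.5+, not the vendored pyspark.heapq3.
import heapq

chunks = [[1, 4, 7], [2, 5, 8], [0, 3, 6, 9]]            # each run already sorted
assert list(heapq.merge(*chunks)) == sorted(sum(chunks, []))

# Descending merge, as sorted(..., reverse=True) over spilled runs needs:
runs = [sorted(c, reverse=True) for c in chunks]
assert list(heapq.merge(*runs, reverse=True)) == sorted(sum(chunks, []), reverse=True)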

23 changes: 13 additions & 10 deletions python/pyspark/rdd.py
@@ -44,7 +44,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
get_used_memory
get_used_memory, ExternalSorter

from py4j.java_collections import ListConverter, MapConverter

@@ -587,14 +587,19 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()

spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
serializer = self._jrdd_deserializer

def sortPartition(iterator):
if spill:
sorted = ExternalSorter(memory * 0.9, serializer).sorted
return sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))

if numPartitions == 1:
if self.getNumPartitions() > 1:
self = self.coalesce(1)

def sort(iterator):
return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))

return self.mapPartitions(sort)
return self.mapPartitions(sortPartition)

# first compute the boundary of each part via sampling: we want to partition
# the key-space into bins such that the bins have roughly the same
@@ -617,10 +622,8 @@ def rangePartitionFunc(k):
else:
return numPartitions - 1 - p

def mapFunc(iterator):
return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))

return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)
return (self.partitionBy(numPartitions, rangePartitionFunc)
.mapPartitions(sortPartition, True))
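
Both code paths now funnel through the same sortPartition closure: with one target partition the RDD is coalesced and sorted directly, otherwise the keys are range-partitioned by the sampled boundaries and each partition is sorted with sortPartition. One caveat in the hunk as shown: assigning to the name sorted inside the `if spill:` branch makes sorted local to the closure, so with spark.shuffle.spill=false the builtin would no longer be reachable there. A minimal sketch of an equivalent closure that keeps both paths working — illustrative only, not this commit's code, with the Python 2 tuple-unpacking lambda rewritten for portability:

# Sketch only; spill, memory, serializer, keyfunc and ascending stand in for
# the values sortByKey() computes above.
from pyspark.shuffle import ExternalSorter

def make_sort_partition(spill, memory, serializer, keyfunc, ascending):
    def sortPartition(iterator):
        # Pick the sorter up front so the non-spilling path still sees the builtin.
        sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
        return sort(iterator, key=lambda kv: keyfunc(kv[0]), reverse=(not ascending))
    return sortPartition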

def sortBy(self, keyfunc, ascending=True, numPartitions=None):
"""
76 changes: 68 additions & 8 deletions python/pyspark/shuffle.py
@@ -21,7 +21,10 @@
import shutil
import warnings
import gc
import itertools
import operator

import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer

try:
@@ -54,6 +57,13 @@ def get_used_memory():
return 0


def _get_local_dirs(sub):
""" Get all the directories """
path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
dirs = path.split(",")
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]
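
The helper spreads spill files across every directory listed in SPARK_LOCAL_DIR (comma-separated), namespaced by the worker's pid and a caller-supplied suffix, so concurrent workers and different spill consumers do not collide. A small restatement with made-up values:

# Restatement with hypothetical values; the real helper reads os.environ
# and os.getpid() instead of taking parameters.
import os

def local_dirs(sub, spark_local_dir="/mnt/disk1,/mnt/disk2", pid=4242):
    return [os.path.join(d, "python", str(pid), sub)
            for d in spark_local_dir.split(",")]

assert local_dirs("sort") == ["/mnt/disk1/python/4242/sort",
                              "/mnt/disk2/python/4242/sort"]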


class Aggregator(object):

"""
@@ -196,7 +206,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
# default serializer is only used for tests
self.serializer = serializer or \
BatchedSerializer(PickleSerializer(), 1024)
self.localdirs = localdirs or self._get_dirs()
self.localdirs = localdirs or _get_local_dirs(str(id(self)))
# number of partitions when spill data into disks
self.partitions = partitions
# check the memory after # of items merged
@@ -212,13 +222,6 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
# randomize the hash of key, id(o) is the address of o (aligned by 8)
self._seed = id(self) + 7

def _get_dirs(self):
""" Get all the directories """
path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
dirs = path.split(",")
return [os.path.join(d, "python", str(os.getpid()), str(id(self)))
for d in dirs]

def _get_spill_dir(self, n):
""" Choose one directory for spill by number n """
return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
@@ -434,6 +437,63 @@ def _recursive_merged_items(self, start):
os.remove(os.path.join(path, str(i)))


class ExternalSorter(object):
"""
    ExternalSorter will divide the elements into chunks, sort them in
    memory, dump them onto disk, and finally merge them back.
    The spilling only happens when the used memory goes above
    the limit.
"""
def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)

def _get_path(self, n):
""" Choose one directory for spill by number n """
d = self.local_dirs[n % len(self.local_dirs)]
if not os.path.exists(d):
os.makedirs(d)
return os.path.join(d, str(n))

def sorted(self, iterator, key=None, reverse=False):
"""
        Sort the elements in the iterator, doing an external (on-disk) sort
        when the used memory goes above the limit.
"""
batch = 10
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
# pick elements in batch
chunk = list(itertools.islice(iterator, batch))
current_chunk.extend(chunk)
if len(chunk) < batch:
break

if get_used_memory() > self.memory_limit:
                # sorting them in place saves memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
self.serializer.dump_stream(current_chunk, f)
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []

elif not chunks:
batch = min(batch * 2, 10000)

current_chunk.sort(key=key, reverse=reverse)
if not chunks:
return current_chunk

if current_chunk:
chunks.append(iter(current_chunk))

return heapq.merge(chunks, key=key, reverse=reverse)


if __name__ == "__main__":
import doctest
doctest.testmod()
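
ExternalSorter.sorted grows its read batch geometrically (up to 10000) as long as nothing has spilled and memory stays under the limit; once get_used_memory() crosses the limit it sorts the current chunk in place, dumps it through the serializer into one of the local dirs, and keeps a lazy load_stream over that file, and the final result is a heap merge over the spilled runs plus the last in-memory chunk. A toy version of the same idea, for illustration only (it is not the class above and uses a chunk-size threshold as a stand-in for the real memory check):

# Toy external sort sketch; Python 3.5+ for heapq.merge's key/reverse arguments.
import heapq
import pickle
import tempfile

def external_sorted(iterator, key=None, reverse=False, chunk_size=10000):
    chunks, current = [], []
    for item in iterator:
        current.append(item)
        if len(current) >= chunk_size:                 # stand-in for a memory check
            current.sort(key=key, reverse=reverse)
            f = tempfile.TemporaryFile()
            pickle.dump(current, f)
            f.seek(0)
            # The real class streams spilled runs back lazily through the
            # serializer; reloading eagerly keeps this sketch short.
            chunks.append(iter(pickle.load(f)))
            current = []
    current.sort(key=key, reverse=reverse)
    if not chunks:
        return current                                 # nothing spilled: plain list
    chunks.append(iter(current))
    return heapq.merge(*chunks, key=key, reverse=reverse)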
25 changes: 24 additions & 1 deletion python/pyspark/tests.py
@@ -30,6 +30,7 @@
import tempfile
import time
import zipfile
import random

if sys.version_info[:2] <= (2, 6):
import unittest2 as unittest
@@ -40,7 +41,7 @@
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter

_have_scipy = False
_have_numpy = False
@@ -117,6 +118,28 @@ def test_huge_dataset(self):
m._cleanup()


class TestSorter(unittest.TestCase):
def test_in_memory_sort(self):
l = range(1024)
random.shuffle(l)
sorter = ExternalSorter(1024)
self.assertEquals(sorted(l), sorter.sorted(l))
self.assertEquals(sorted(l, reverse=True), sorter.sorted(l, reverse=True))
self.assertEquals(sorted(l, key=lambda x: -x), sorter.sorted(l, key=lambda x: -x))
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
sorter.sorted(l, key=lambda x: -x, reverse=True))

def test_external_sort(self):
l = range(100)
random.shuffle(l)
sorter = ExternalSorter(1)
self.assertEquals(sorted(l), list(sorter.sorted(l)))
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
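
The two cases differ only in the memory limit: 1024 MB is comfortably above the test process's footprint, so nothing spills and sorted() returns a plain list, while a 1 MB limit is below any real process's footprint, so batches spill (assuming get_used_memory() can report real usage) and the result comes back as a lazy merged iterator — which is why only test_external_sort wraps the results in list(). A usage sketch along the same lines (assumes a PySpark checkout of this vintage on the path; not part of the commit):

# Usage sketch mirroring the tests above.
import random
from pyspark.shuffle import ExternalSorter

data = list(range(1024))
random.shuffle(data)

assert ExternalSorter(1024).sorted(data) == sorted(data)        # stays in memory
assert list(ExternalSorter(1).sorted(data)) == sorted(data)     # forced to spill
assert list(ExternalSorter(1).sorted(data, key=lambda x: -x,
                                      reverse=True)) == sorted(data)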


class SerializationTestCase(unittest.TestCase):

def test_namedtuple(self):
2 changes: 1 addition & 1 deletion tox.ini
@@ -15,4 +15,4 @@

[pep8]
max-line-length=100
exclude=cloudpickle.py
exclude=cloudpickle.py,heapq3.py
