diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 361163cab4827..43b91459b56ea 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -25,7 +25,8 @@ import operator import pyspark.heapq3 as heapq -from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattedValuesSerializer +from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattedValuesSerializer, \ + CompressedSerializer try: import psutil @@ -204,8 +205,16 @@ def __init__(self, aggregator, memory_limit=512, serializer=None, Merger.__init__(self, aggregator) self.memory_limit = memory_limit # default serializer is only used for tests - self.serializer = serializer or \ - BatchedSerializer(PickleSerializer(), 1024) + self.serializer = serializer or PickleSerializer() + # add compression + if isinstance(self.serializer, BatchedSerializer): + if not isinstance(self.serializer.serializer, CompressedSerializer): + self.serializer = BatchedSerializer( + CompressedSerializer(self.serializer.serializer), + self.serializer.batchSize) + else: + if not isinstance(self.serializer, CompressedSerializer): + self.serializer = CompressedSerializer(self.serializer) self.localdirs = localdirs or _get_local_dirs(str(id(self))) # number of partitions when spill data into disks self.partitions = partitions