
Commit

Add s at the end and a couple other fixes
sryza committed Sep 8, 2014
1 parent 9b0ba99 commit 4c25a54
Showing 5 changed files with 11 additions and 50 deletions.
@@ -765,9 +765,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    * This is more efficient than calling `repartition` and then sorting within each partition
    * because it can push the sorting down into the shuffle machinery.
    */
-  def repartitionAndSortWithinPartition(partitioner: Partitioner): JavaPairRDD[K, V] = {
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner): JavaPairRDD[K, V] = {
     val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
-    repartitionAndSortWithinPartition(partitioner, comp)
+    repartitionAndSortWithinPartitions(partitioner, comp)
   }

@@ -777,11 +777,11 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
   /**
    * This is more efficient than calling `repartition` and then sorting within each partition
    * because it can push the sorting down into the shuffle machinery.
    */
-  def repartitionAndSortWithinPartition(partitioner: Partitioner, comp: Comparator[K])
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner, comp: Comparator[K])
       : JavaPairRDD[K, V] = {
     implicit val ordering = comp // Allow implicit conversion of Comparator to Ordering.
     fromRDD(
-      new OrderedRDDFunctions[K, V, (K, V)](rdd).repartitionAndSortWithinPartition(partitioner))
+      new OrderedRDDFunctions[K, V, (K, V)](rdd).repartitionAndSortWithinPartitions(partitioner))
   }

   /**
@@ -72,8 +72,7 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
    * This is more efficient than calling `repartition` and then sorting within each partition
    * because it can push the sorting down into the shuffle machinery.
    */
-  def repartitionAndSortWithinPartition(partitioner: Partitioner)
-      : RDD[(K, V)] = {
+  def repartitionAndSortWithinPartitions(partitioner: Partitioner): RDD[(K, V)] = {
     new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
   }

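For context, and not part of the commit itself, here is a minimal Scala sketch of how the renamed operator is called against the Spark 1.x API; the sample data, HashPartitioner, local master, and object name are illustrative assumptions. It does the partitioning and the per-partition sort in a single shuffle, which is the efficiency the doc comment contrasts with repartitioning and then sorting each partition separately:

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
import org.apache.spark.SparkContext._ // Spark 1.x implicits that add OrderedRDDFunctions to pair RDDs

object RepartitionAndSortSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    val pairs = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)))

    // One shuffle: each output partition receives its records already sorted by key.
    val sorted = pairs.repartitionAndSortWithinPartitions(new HashPartitioner(2))
    sorted.glom().collect().foreach(part => println(part.mkString(", ")))

    // The slower pattern the doc comment refers to: shuffle first, then sort every
    // partition in a second pass over the shuffled data.
    val repartitionThenSort = pairs
      .partitionBy(new HashPartitioner(2))
      .mapPartitions(iter => iter.toArray.sortBy(_._1).iterator, preservesPartitioning = true)
    repartitionThenSort.glom().collect().foreach(part => println(part.mkString(", ")))

    sc.stop()
  }
}

Both variants end up with the same partition contents; the first avoids re-reading and sorting each partition after the shuffle because the sort happens inside the shuffle machinery itself.
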
2 changes: 1 addition & 1 deletion core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -690,7 +690,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
       def getPartition(key: Any): Int = key.asInstanceOf[Int] % 2
     }

-    val repartitioned = data.repartitionAndSortWithinPartition(partitioner)
+    val repartitioned = data.repartitionAndSortWithinPartitions(partitioner)
     val partitions = repartitioned.glom().collect()
     assert(partitions(0) === Seq((0, 5), (0, 8), (2, 6)))
     assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8)))
6 changes: 3 additions & 3 deletions python/pyspark/rdd.py
@@ -520,14 +520,14 @@ def __add__(self, other):
             raise TypeError
         return self.union(other)

-    def repartitionAndSortWithinPartition(self, ascending=True, numPartitions=None,
-                                          partitionFunc=portable_hash, keyfunc=lambda x: x):
+    def repartitionAndSortWithinPartitions(self, numPartitions=None, partitionFunc=portable_hash,
+                                           ascending=True, keyfunc=lambda x: x):
         """
         Repartition the RDD according to the given partitioner and, within each resulting partition,
         sort records by their keys.

         >>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
-        >>> rdd2 = rdd.repartitionAndSortWithinPartition(True, lambda x: x % 2, 2)
+        >>> rdd2 = rdd.repartitionAndSortWithinPartitions(2, lambda x: x % 2, True)
         >>> rdd2.glom().collect()
         [[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]]
         """
42 changes: 2 additions & 40 deletions python/pyspark/tests.py
@@ -43,7 +43,6 @@
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType

 _have_scipy = False
 _have_numpy = False
@@ -425,22 +424,6 @@ def test_zip_with_different_number_of_items(self):
         self.assertEquals(a.count(), b.count())
         self.assertRaises(Exception, lambda: a.zip(b).count())

-    def test_count_approx_distinct(self):
-        rdd = self.sc.parallelize(range(1000))
-        self.assertTrue(950 < rdd.countApproxDistinct(0.04) < 1050)
-        self.assertTrue(950 < rdd.map(float).countApproxDistinct(0.04) < 1050)
-        self.assertTrue(950 < rdd.map(str).countApproxDistinct(0.04) < 1050)
-        self.assertTrue(950 < rdd.map(lambda x: (x, -x)).countApproxDistinct(0.04) < 1050)
-
-        rdd = self.sc.parallelize([i % 20 for i in range(1000)], 7)
-        self.assertTrue(18 < rdd.countApproxDistinct() < 22)
-        self.assertTrue(18 < rdd.map(float).countApproxDistinct() < 22)
-        self.assertTrue(18 < rdd.map(str).countApproxDistinct() < 22)
-        self.assertTrue(18 < rdd.map(lambda x: (x, -x)).countApproxDistinct() < 22)
-
-        self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.00000001))
-        self.assertRaises(ValueError, lambda: rdd.countApproxDistinct(0.5))
-
     def test_histogram(self):
         # empty
         rdd = self.sc.parallelize([])
@@ -545,36 +528,15 @@ def test_histogram(self):
         self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
         self.assertRaises(TypeError, lambda: rdd.histogram(2))

-    def test_repartitionAndSortWithinPartition(self):
+    def test_repartitionAndSortWithinPartitions(self):
         rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)

-        repartitioned = rdd.repartitionAndSortWithinPartition(True, 2, lambda key: key % 2)
+        repartitioned = rdd.repartitionAndSortWithinPartitions(2, lambda key: key % 2)
         partitions = repartitioned.glom().collect()
         self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)])
         self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])


-class TestSQL(PySparkTestCase):
-
-    def setUp(self):
-        PySparkTestCase.setUp(self)
-        self.sqlCtx = SQLContext(self.sc)
-
-    def test_udf(self):
-        self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType())
-        [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect()
-        self.assertEqual(row[0], 5)
-
-    def test_broadcast_in_udf(self):
-        bar = {"a": "aa", "b": "bb", "c": "abc"}
-        foo = self.sc.broadcast(bar)
-        self.sqlCtx.registerFunction("MYUDF", lambda x: foo.value[x] if x else '')
-        [res] = self.sqlCtx.sql("SELECT MYUDF('c')").collect()
-        self.assertEqual("abc", res[0])
-        [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
-        self.assertEqual("", res[0])
-
-
 class TestIO(PySparkTestCase):

     def test_stdout_redirection(self):
