Skip to content

Commit

Permalink
Merge pull request alteryx#109 from pwendell/master
Browse files Browse the repository at this point in the history
Adding Java/Java Streaming versions of `repartition` with associated tests
  • Loading branch information
rxin committed Oct 25, 2013
2 parents 99ad4a6 + ad5f579 commit 4f2c943
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 6 deletions.
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD = {
  // Delegate to the underlying RDD[scala.Double], then wrap the result via fromRDD.
  val coalesced = srdd.coalesce(numPartitions, shuffle)
  fromRDD(coalesced)
}

/**
 * Redistribute the elements of this RDD across exactly `numPartitions` partitions.
 *
 * The level of parallelism may go up or down relative to the current one. Internally, a
 * shuffle is used to move the data.
 *
 * When the goal is only to reduce the number of partitions, consider `coalesce` instead,
 * which can avoid performing a shuffle.
 */
def repartition(numPartitions: Int): JavaDoubleRDD = {
  val reshuffled = srdd.repartition(numPartitions)
  fromRDD(reshuffled)
}

/**
* Return an RDD with the elements from `this` that are not in `other`.
*
Expand Down
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] = {
  // Delegate to the underlying RDD[(K, V)], then wrap the result via fromRDD.
  val coalesced = rdd.coalesce(numPartitions, shuffle)
  fromRDD(coalesced)
}

/**
 * Redistribute the pairs of this RDD across exactly `numPartitions` partitions.
 *
 * The level of parallelism may go up or down relative to the current one. Internally, a
 * shuffle is used to move the data.
 *
 * When the goal is only to reduce the number of partitions, consider `coalesce` instead,
 * which can avoid performing a shuffle.
 */
def repartition(numPartitions: Int): JavaPairRDD[K, V] = {
  val reshuffled = rdd.repartition(numPartitions)
  fromRDD(reshuffled)
}

/**
* Return a sampled subset of this RDD.
*/
Expand Down
11 changes: 11 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ JavaRDDLike[T, JavaRDD[T]] {
def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] = {
  // Delegate to the underlying RDD; the result is returned as the declared JavaRDD[T].
  val coalesced = rdd.coalesce(numPartitions, shuffle)
  coalesced
}

/**
 * Redistribute the elements of this RDD across exactly `numPartitions` partitions.
 *
 * The level of parallelism may go up or down relative to the current one. Internally, a
 * shuffle is used to move the data.
 *
 * When the goal is only to reduce the number of partitions, consider `coalesce` instead,
 * which can avoid performing a shuffle.
 */
def repartition(numPartitions: Int): JavaRDD[T] = {
  val reshuffled = rdd.repartition(numPartitions)
  reshuffled
}

/**
* Return a sampled subset of this RDD.
*/
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD that has exactly numPartitions partitions.
*
* Used to increase or decrease the level of parallelism in this RDD. This will use
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
* a shuffle to redistribute data.
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
Expand Down
21 changes: 21 additions & 0 deletions core/src/test/scala/org/apache/spark/JavaAPISuite.java
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,27 @@ public Iterable<Integer> call(Iterator<Integer> iter) {
Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
}

/**
 * Checks that {@code repartition} produces exactly the requested number of partitions, both
 * when the partition count is increased and when it is decreased, and that none of the
 * resulting partitions is empty for this eight-element input.
 */
@Test
public void repartition() {
  // Growing number of partitions (2 -> 4)
  JavaRDD<Integer> in1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 2);
  JavaRDD<Integer> repartitioned1 = in1.repartition(4);
  // glom() exposes each partition as its own list, so the outer size is the partition count.
  List<List<Integer>> result1 = repartitioned1.glom().collect();
  Assert.assertEquals(4, result1.size());
  for (List<Integer> l: result1) {
    Assert.assertTrue(l.size() > 0);
  }

  // Shrinking number of partitions (4 -> 2)
  JavaRDD<Integer> in2 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4);
  JavaRDD<Integer> repartitioned2 = in2.repartition(2);
  List<List<Integer>> result2 = repartitioned2.glom().collect();
  Assert.assertEquals(2, result2.size());
  for (List<Integer> l: result2) {
    Assert.assertTrue(l.size() > 0);
  }
}

@Test
public void persist() {
JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ class JavaDStream[T](val dstream: DStream[T])(implicit val classManifest: ClassM
*/
def union(that: JavaDStream[T]): JavaDStream[T] = {
  // Unwrap the other stream and delegate to the underlying DStream.
  val other = that.dstream
  dstream.union(other)
}

/**
 * Return a new DStream whose level of parallelism has been raised or lowered: every RDD in
 * the returned DStream has exactly `numPartitions` partitions.
 */
def repartition(numPartitions: Int): JavaDStream[T] = {
  val reshuffled = dstream.repartition(numPartitions)
  reshuffled
}
}

object JavaDStream {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
/** Persist the RDDs of this DStream using the supplied storage level. */
def persist(storageLevel: StorageLevel): JavaPairDStream[K, V] = {
  val persisted = dstream.persist(storageLevel)
  persisted
}

/**
 * Return a new DStream whose level of parallelism has been raised or lowered: every RDD in
 * the returned DStream has exactly `numPartitions` partitions.
 */
def repartition(numPartitions: Int): JavaPairDStream[K, V] = {
  val reshuffled = dstream.repartition(numPartitions)
  reshuffled
}

/** Method that generates a RDD for the given Duration */
def compute(validTime: Time): JavaPairRDD[K, V] = {
dstream.compute(validTime) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,39 @@ public Boolean call(String s) throws Exception {
assertOrderInvariantEquals(expected, result);
}

/**
 * Repartitions a two-batch input stream created with 2 partitions into 4 partitions and
 * checks that each output RDD has 4 partitions which together hold all 10 input elements.
 */
@Test
public void testRepartitionMorePartitions() {
  List<List<Integer>> inputData = Arrays.asList(
      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
  // Parameterized types instead of raw JavaDStream, so element types are compiler-checked.
  JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 2);
  JavaDStream<Integer> repartitioned = stream.repartition(4);
  JavaTestUtils.attachTestOutputStream(repartitioned);
  List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
  Assert.assertEquals(2, result.size());
  for (List<List<Integer>> rdd : result) {
    Assert.assertEquals(4, rdd.size());
    // Sum over every partition rather than indexing a fixed number of them by hand.
    int totalElements = 0;
    for (List<Integer> partition : rdd) {
      totalElements += partition.size();
    }
    Assert.assertEquals(10, totalElements);
  }
}

/**
 * Repartitions a two-batch input stream created with 4 partitions into 2 partitions and
 * checks that each output RDD has 2 partitions which together hold all 10 input elements.
 */
@Test
public void testRepartitionFewerPartitions() {
  List<List<Integer>> inputData = Arrays.asList(
      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
      Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));
  // Parameterized types instead of raw JavaDStream, so element types are compiler-checked.
  JavaDStream<Integer> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 4);
  JavaDStream<Integer> repartitioned = stream.repartition(2);
  JavaTestUtils.attachTestOutputStream(repartitioned);
  List<List<List<Integer>>> result = JavaTestUtils.runStreamsWithPartitions(ssc, 2, 2);
  Assert.assertEquals(2, result.size());
  for (List<List<Integer>> rdd : result) {
    Assert.assertEquals(2, rdd.size());
    // Sum over every partition rather than indexing a fixed number of them by hand.
    int totalElements = 0;
    for (List<Integer> partition : rdd) {
      totalElements += partition.size();
    }
    Assert.assertEquals(10, totalElements);
  }
}

@Test
public void testGlom() {
List<List<String>> inputData = Arrays.asList(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ trait JavaTestBase extends TestSuiteBase {
* The stream will be derived from the supplied lists of Java objects.
**/
def attachTestInputStream[T](
ssc: JavaStreamingContext,
data: JList[JList[T]],
numPartitions: Int) = {
ssc: JavaStreamingContext,
data: JList[JList[T]],
numPartitions: Int) = {
val seqData = data.map(Seq(_:_*))

implicit val cm: ClassManifest[T] =
Expand All @@ -50,7 +50,7 @@ trait JavaTestBase extends TestSuiteBase {
* [[org.apache.spark.streaming.TestOutputStream]].
**/
def attachTestOutputStream[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T, R]](
dstream: JavaDStreamLike[T, This, R]) =
dstream: JavaDStreamLike[T, This, R]) =
{
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
Expand All @@ -62,16 +62,39 @@ trait JavaTestBase extends TestSuiteBase {
* Process all registered streams for numBatches batches, failing if
* numExpectedOutput RDDs are not generated. Generated RDDs are collected
* and returned, represented as a list for each batch interval.
*
* Returns a list of items for each RDD.
*/
def runStreams[V](
ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
ssc: JavaStreamingContext, numBatches: Int, numExpectedOutput: Int): JList[JList[V]] = {
implicit val cm: ClassManifest[V] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
val res = runStreams[V](ssc.ssc, numBatches, numExpectedOutput)
val out = new ArrayList[JList[V]]()
res.map(entry => out.append(new ArrayList[V](entry)))
out
}

/**
 * Process all registered streams for numBatches batches, failing if
 * numExpectedOutput RDDs are not generated. Generated RDDs are collected
 * and returned, represented as a list for each batch interval.
 *
 * Returns a list of RDDs. Each RDD is represented as a list of partitions, and each
 * partition as a list of its items.
 *
 * @param ssc the Java streaming context whose underlying `ssc.ssc` is run
 * @param numBatches number of batches to process
 * @param numExpectedOutput number of output RDDs expected to be generated
 * @return one entry per output RDD; each entry holds one list of items per partition
 */
def runStreamsWithPartitions[V](ssc: JavaStreamingContext, numBatches: Int,
    numExpectedOutput: Int): JList[JList[JList[V]]] = {
  // Stand-in manifest: the AnyRef manifest is cast to ClassManifest[V] and put in implicit scope.
  implicit val cm: ClassManifest[V] =
    implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
  val res = runStreamsWithPartitions[V](ssc.ssc, numBatches, numExpectedOutput)
  val out = new ArrayList[JList[JList[V]]]()
  // foreach, not map: the traversal exists only to fill `out`, so no result collection is built.
  res.foreach { entry =>
    val lists = entry.map(new ArrayList[V](_))
    out.add(new ArrayList[JList[V]](lists))
  }
  out
}
}

object JavaTestUtils extends JavaTestBase {
Expand Down

0 comments on commit 4f2c943

Please sign in to comment.