
[SPARK-22465][Core][WIP] Add a safety-check to RDD defaultPartitioner that ignores existing Partitioners if they are more than a single order of magnitude smaller than the max number of upstream partitions
sujithjay committed Dec 16, 2017
1 parent 1e44dd0 commit 176270b
Showing 3 changed files with 58 additions and 2 deletions.
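In essence, the patch below reuses an existing partitioner only when its partition count is within one order of magnitude, compared via floor of log10, of the largest upstream partition count. Here is a minimal standalone sketch of that rule (REPL-style; the helper name eligible and the sample counts are illustrative, not part of the patch):

import scala.math.log10

// Illustrative stand-in for the isEligiblePartitioner check added below:
// compare orders of magnitude (floor of log10) of the max upstream
// partition count and the candidate partitioner's partition count.
def eligible(partitionerParts: Int, maxUpstreamParts: Int): Boolean =
  log10(maxUpstreamParts).floor - log10(partitionerParts).floor < 1

eligible(10, 20)    // true:  floor(1.30) - floor(1.00) = 0, same magnitude
eligible(10, 99)    // true:  floor(1.996) - floor(1.00) = 0
eligible(10, 1000)  // false: floor(3.00) - floor(1.00) = 2
eligible(10, 100)   // false: floor(2.00) - floor(1.00) = 1, not < 1
eligible(9, 10)     // false: floor(1.00) - floor(0.95) = 1; the floor
                    // comparison is coarse at decade boundaries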
17 changes: 15 additions & 2 deletions core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -21,6 +21,7 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.math.log10
 import scala.reflect.ClassTag
 import scala.util.hashing.byteswap32
 
@@ -42,7 +43,9 @@ object Partitioner {
   /**
    * Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
    *
-   * If any of the RDDs already has a partitioner, choose that one.
+   * If any of the RDDs already has a partitioner, and the number of partitions of that
+   * partitioner is either greater than, or is less than and within a single order of
+   * magnitude of, the max number of upstream partitions, choose that one.
    *
    * Otherwise, we use a default HashPartitioner. For the number of partitions, if
    * spark.default.parallelism is set, then we'll use the value from SparkContext
@@ -57,7 +60,7 @@ object Partitioner {
   def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
     val rdds = (Seq(rdd) ++ others)
     val hasPartitioner = rdds.filter(_.partitioner.exists(_.numPartitions > 0))
-    if (hasPartitioner.nonEmpty) {
+    if (hasPartitioner.nonEmpty && isEligiblePartitioner(hasPartitioner.maxBy(_.partitions.length), rdds)) {
       hasPartitioner.maxBy(_.partitions.length).partitioner.get
     } else {
       if (rdd.context.conf.contains("spark.default.parallelism")) {
@@ -67,6 +70,16 @@ }
       }
     }
   }
+
+  /**
+   * Returns true if the number of partitions of the RDD is either greater than, or is
+   * less than and within a single order of magnitude of, the max number of upstream
+   * partitions; otherwise, returns false.
+   */
+  private def isEligiblePartitioner(hasMaxPartitioner: RDD[_], rdds: Seq[RDD[_]]): Boolean = {
+    val maxPartitions = rdds.map(_.partitions.length).max
+    log10(maxPartitions).floor - log10(hasMaxPartitioner.getNumPartitions).floor < 1
+  }
 }
 
 /**
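As a usage sketch of the behaviour this buys (assuming a running SparkContext sc with spark.default.parallelism unset; the values mirror the cogroup tests below):

import org.apache.spark.HashPartitioner

// A wide RDD with no partitioner and a small pre-partitioned RDD.
val wide  = sc.parallelize((1 to 1000).map(x => (x, x)), 1000)
val small = sc.parallelize(Seq((1, 1), (2, 2))).partitionBy(new HashPartitioner(10))

// Without this patch, cogroup reuses small's HashPartitioner and the
// result collapses to 10 partitions. With the patch, 10 is more than an
// order of magnitude below 1000, so that partitioner is ignored and a
// HashPartitioner(1000) is chosen instead.
val grouped = wide.cogroup(small)
grouped.getNumPartitions  // 1000 with this patch; 10 without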
21 changes: 21 additions & 0 deletions core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -259,6 +259,27 @@ class PartitioningSuite extends SparkFunSuite with SharedSparkContext with Priva
     val partitioner = new RangePartitioner(22, rdd)
     assert(partitioner.numPartitions === 3)
   }
+
+  test("defaultPartitioner") {
+    val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 150)
+    val rdd2 = sc
+      .parallelize(Array((1, 2), (2, 3), (2, 4), (3, 4)))
+      .partitionBy(new HashPartitioner(10))
+    val rdd3 = sc
+      .parallelize(Array((1, 6), (7, 8), (3, 10), (5, 12), (13, 14)))
+      .partitionBy(new HashPartitioner(100))
+
+    val partitioner1 = Partitioner.defaultPartitioner(rdd1, rdd2)
+    val partitioner2 = Partitioner.defaultPartitioner(rdd2, rdd3)
+    val partitioner3 = Partitioner.defaultPartitioner(rdd3, rdd1)
+    val partitioner4 = Partitioner.defaultPartitioner(rdd1, rdd2, rdd3)
+
+    assert(partitioner1.numPartitions == rdd1.getNumPartitions)
+    assert(partitioner2.numPartitions == rdd3.getNumPartitions)
+    assert(partitioner3.numPartitions == rdd3.getNumPartitions)
+    assert(partitioner4.numPartitions == rdd3.getNumPartitions)
+
+  }
 }
 
 
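A quick trace of why those assertions hold (arithmetic added here for the reader, not part of the commit):

import scala.math.log10

// partitioner1 = defaultPartitioner(rdd1, rdd2): only rdd2 (10 partitions)
// carries a partitioner, while the upstream max is rdd1's 150 partitions.
log10(150).floor - log10(10).floor   // 2.0 - 1.0 = 1.0, not < 1: ineligible,
                                     // so a fresh HashPartitioner(150) is
                                     // built and partitioner1 matches rdd1.

// partitioner2, partitioner3, partitioner4: the largest existing partitioner
// is rdd3's HashPartitioner(100), and the upstream max is 100 or 150.
log10(100).floor - log10(100).floor  // 0.0 < 1: eligible (partitioner2)
log10(150).floor - log10(100).floor  // 2.0 - 2.0 = 0.0 < 1: eligible
                                     // (partitioner3 and partitioner4), so
                                     // all three reuse rdd3's 100 partitions.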
22 changes: 22 additions & 0 deletions core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -310,6 +310,28 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
     assert(joined.size > 0)
   }
 
+  // See SPARK-22465
+  test("cogroup between multiple RDDs " +
+    "with an order of magnitude difference in number of partitions") {
+    val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 1000)
+    val rdd2 = sc
+      .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+      .partitionBy(new HashPartitioner(10))
+    val joined = rdd1.cogroup(rdd2)
+    assert(joined.getNumPartitions == rdd1.getNumPartitions)
+  }
+
+  // See SPARK-22465
+  test("cogroup between multiple RDDs " +
+    "with number of partitions similar in order of magnitude") {
+    val rdd1 = sc.parallelize((1 to 1000).map(x => (x, x)), 20)
+    val rdd2 = sc
+      .parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
+      .partitionBy(new HashPartitioner(10))
+    val joined = rdd1.cogroup(rdd2)
+    assert(joined.getNumPartitions == rdd2.getNumPartitions)
+  }
+
   test("rightOuterJoin") {
     val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
     val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
