diff --git a/core/pom.xml b/core/pom.xml
index bab50f5ce2888..6cb58dbd291c4 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -67,6 +67,10 @@
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-math3</artifactId>
+ </dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index aa03e9276fb34..2fdf45a0c8b8e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -379,8 +379,17 @@ abstract class RDD[T: ClassTag](
}.toArray
}
- def takeSample(withReplacement: Boolean, num: Int, seed: Long = Utils.random.nextLong): Array[T] =
- {
+ /**
+ * Return a fixed-size sampled subset of this RDD in an array.
+ *
+ * @param withReplacement whether sampling is done with replacement
+ * @param num size of the returned sample
+ * @param seed seed for the random number generator
+ * @return sample of specified size in an array
+ */
+ def takeSample(withReplacement: Boolean,
+ num: Int,
+ seed: Long = Utils.random.nextLong): Array[T] = {
var fraction = 0.0
var total = 0
val multiplier = 3.0
@@ -402,10 +411,11 @@ abstract class RDD[T: ClassTag](
}
if (num > initialCount && !withReplacement) {
+ // special case not covered in computeFraction
total = maxSelected
fraction = multiplier * (maxSelected + 1) / initialCount
} else {
- fraction = multiplier * (num + 1) / initialCount
+ fraction = computeFraction(num, initialCount, withReplacement)
total = num
}
@@ -421,6 +431,22 @@ abstract class RDD[T: ClassTag](
Utils.randomizeInPlace(samples, rand).take(total)
}
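+
+ /**
+ * Computes the fraction to pass to `sample` so that the result contains at least `num`
+ * items with high probability. With replacement, the sample size is roughly
+ * Poisson(fraction * total), so the fraction is padded by several standard deviations
+ * above num / total. Without replacement, it is roughly Binomial(total, fraction), and the
+ * fraction solves an exponential (Chernoff-style) tail bound with failure probability delta.
+ */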
+ private[spark] def computeFraction(num: Int, total: Long, withReplacement: Boolean): Double = {
+ val fraction = num.toDouble / total
+ if (withReplacement) {
+ // special case to guarantee the sample size for small num: pad by more standard deviations
+ val numStDev = if (num < 12) 9 else 5
+ fraction + numStDev * math.sqrt(fraction / total)
+ } else {
+ val delta = 0.00005
+ val gamma = -math.log(delta) / total
+ math.min(1, fraction + gamma + math.sqrt(gamma * gamma + 2 * gamma * fraction))
+ }
+ }
+
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 4dc8ada00a3e8..e53103755b279 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -70,7 +70,7 @@ class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
}
/**
- * Return a sampler with is the complement of the range specified of the current sampler.
+ * Return a sampler which is the complement of the range specified by the current sampler.
*/
def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index e686068f7a99a..5bdcb9bef6d62 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -22,6 +22,7 @@ import scala.reflect.ClassTag
import org.scalatest.FunSuite
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.rdd._
@@ -494,56 +495,84 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sortedTopK === nums.sorted(ord).take(5))
}
+ test("computeFraction") {
+ // test that the computed fraction guarantees enough data points in the sample
+ // with a failure rate of at most 0.0001
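+ // the resulting sample size is approximately Poisson(frac * n) in both sampling modes,
+ // so check that its 0.01% quantile is at least s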
+ val data = new EmptyRDD[Int](sc)
+ val n = 100000
+
+ for (s <- 1 to 15) {
+ val frac = data.computeFraction(s, n, true)
+ val qpois = new PoissonDistribution(frac * n)
+ assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ for (s <- 1 to 15) {
+ val frac = data.computeFraction(s, n, false)
+ val qpois = new PoissonDistribution(frac * n)
+ assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ for (s <- List(1, 10, 100, 1000)) {
+ val frac = data.computeFraction(s, n, true)
+ val qpois = new PoissonDistribution(frac * n)
+ assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ for (s <- List(1, 10, 100, 1000)) {
+ val frac = data.computeFraction(s, n, false)
+ val qpois = new PoissonDistribution(frac * n)
+ assert(qpois.inverseCumulativeProbability(0.0001) >= s, "Computed fraction is too low")
+ }
+ }
+
test("takeSample") {
- val data = sc.parallelize(1 to 100, 2)
+ val n = 1000000
+ val data = sc.parallelize(1 to n, 2)
for (num <- List(5, 20, 100)) {
val sample = data.takeSample(withReplacement=false, num=num)
assert(sample.size === num) // Got exactly num elements
assert(sample.toSet.size === num) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=false, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
assert(sample.toSet.size === 20) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=false, 200, seed)
+ val sample = data.takeSample(withReplacement=false, 100, seed)
assert(sample.size === 100) // Got only 100 elements
assert(sample.toSet.size === 100) // Elements are distinct
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
}
for (seed <- 1 to 5) {
val sample = data.takeSample(withReplacement=true, 20, seed)
assert(sample.size === 20) // Got exactly 20 elements
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
}
{
val sample = data.takeSample(withReplacement=true, num=20)
assert(sample.size === 20) // Got exactly 100 elements
assert(sample.toSet.size <= 20, "sampling with replacement returned all distinct elements")
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
}
{
- val sample = data.takeSample(withReplacement=true, num=100)
- assert(sample.size === 100) // Got exactly 100 elements
+ val sample = data.takeSample(withReplacement=true, num=n)
+ assert(sample.size === n) // Got exactly n elements
// Chance of getting all distinct elements is astronomically low, so test we got < 100
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
- assert(sample.forall(x => 1 <= x && x <= 100), "elements not in [1, 100]")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
+ assert(sample.forall(x => 1 <= x && x <= n), "elements not in [1, n]")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=true, 100, seed)
- assert(sample.size === 100) // Got exactly 100 elements
+ val sample = data.takeSample(withReplacement=true, n, seed)
+ assert(sample.size === n) // Got exactly n elements
// Chance of getting all distinct elements is astronomically low, so test we got < 100
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
for (seed <- 1 to 5) {
- val sample = data.takeSample(withReplacement=true, 200, seed)
- assert(sample.size === 200) // Got exactly 200 elements
+ val sample = data.takeSample(withReplacement=true, 2 * n, seed)
+ assert(sample.size === 2 * n) // Got exactly 2 * n elements
// Chance of getting all distinct elements is still quite low, so test we got < 100
- assert(sample.toSet.size < 100, "sampling with replacement returned all distinct elements")
+ assert(sample.toSet.size < n, "sampling with replacement returned all distinct elements")
}
}
diff --git a/pom.xml b/pom.xml
index 7bf9f135fd340..01d6eef32be63 100644
--- a/pom.xml
+++ b/pom.xml
@@ -245,6 +245,11 @@
<artifactId>commons-codec</artifactId>
<version>1.5</version>
</dependency>
+ <dependency>
+ <groupId>org.apache.commons</groupId>
+ <artifactId>commons-math3</artifactId>
+ <version>3.2</version>
+ </dependency>
<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 8ef1e91f609fb..a6b6c26a49395 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -331,6 +331,7 @@ object SparkBuild extends Build {
libraryDependencies ++= Seq(
"com.google.guava" % "guava" % "14.0.1",
"org.apache.commons" % "commons-lang3" % "3.3.2",
+ "org.apache.commons" % "commons-math3" % "3.2",
"com.google.code.findbugs" % "jsr305" % "1.3.9",
"log4j" % "log4j" % "1.2.17",
"org.slf4j" % "slf4j-api" % slf4jVersion,
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 07578b8d937fc..b400404ad97c7 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -31,6 +31,7 @@
import warnings
import heapq
from random import Random
+from math import sqrt, log
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, pack_long
@@ -374,7 +375,7 @@ def takeSample(self, withReplacement, num, seed=None):
total = maxSelected
fraction = multiplier * (maxSelected + 1) / initialCount
else:
- fraction = multiplier * (num + 1) / initialCount
+ fraction = self._computeFraction(num, initialCount, withReplacement)
total = num
samples = self.sample(withReplacement, fraction, seed).collect()
@@ -390,6 +391,18 @@ def takeSample(self, withReplacement, num, seed=None):
sampler.shuffle(samples)
return samples[0:total]
+ def _computeFraction(self, num, total, withReplacement):
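+ """
+ Compute a sampling fraction large enough that the sample contains at
+ least num items with high probability; mirrors RDD.computeFraction
+ on the Scala side.
+ """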
+ fraction = float(num) / total
+ if withReplacement:
+ numStDev = 5
+ if num < 12:
+ # special case to guarantee sample size for small num
+ numStDev = 9
+ return fraction + numStDev * sqrt(fraction / total)
+ else:
+ delta = 0.00005
+ gamma = -log(delta) / total
+ return min(1, fraction + gamma + sqrt(gamma * gamma + 2 * gamma * fraction))
+
def union(self, other):
"""
Return the union of this RDD and another one.