minor fixes and Java API.
Punting on Python for now. Moved aggregateWithContext out of RDD.
dorx committed Jul 3, 2014
1 parent 4ad516b commit 254e03c
Showing 6 changed files with 85 additions and 80 deletions.
34 changes: 33 additions & 1 deletion core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -17,7 +17,7 @@

package org.apache.spark.api.java

import java.util.{Comparator, List => JList}
import java.util.{Comparator, List => JList, Map => JMap}
import java.lang.{Iterable => JIterable}

import scala.collection.JavaConversions._
@@ -129,6 +129,38 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
def sample(withReplacement: Boolean, fraction: Double, seed: Long): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))

/**
* Return a subset of this RDD sampled by key (via stratified sampling).
*/
def sampleByKey(withReplacement: Boolean,
fractions: JMap[K, Double],
exact: Boolean,
seed: Long): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions, exact, seed))


/**
* Return a subset of this RDD sampled by key (via stratified sampling).
*/
def sampleByKey(withReplacement: Boolean,
fractions: JMap[K, Double],
exact: Boolean): JavaPairRDD[K, V] =
sampleByKey(withReplacement, fractions, exact, Utils.random.nextLong)

/**
* Return a subset of this RDD sampled by key (via stratified sampling).
*/
def sampleByKey(withReplacement: Boolean,
fractions: JMap[K, Double],
seed: Long): JavaPairRDD[K, V] =
sampleByKey(withReplacement, fractions, true, seed)

/**
* Return a subset of this RDD sampled by key (via stratified sampling).
*/
def sampleByKey(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] =
sampleByKey(withReplacement, fractions, true, Utils.random.nextLong)

/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
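For illustration only (not part of the diff): a rough usage sketch of the new Java-API overloads, written in Scala against JavaPairRDD.fromRDD. The data, sampling rates, and local master setting are invented for the example; the comments restate the defaults visible in the overloads above.

import java.util.{HashMap => JHashMap}

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.api.java.JavaPairRDD

object SampleByKeyJavaApiSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("sampleByKey-java-api"))
    // Wrap a plain Scala pair RDD in the Java API (illustrative data).
    val javaPairs = JavaPairRDD.fromRDD(
      sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3), ("b", 4), ("c", 5))))

    // Per-key sampling rates, passed as a java.util.Map.
    val fractions = new JHashMap[String, Double]()
    fractions.put("a", 0.5)
    fractions.put("b", 1.0)
    fractions.put("c", 0.2)

    // Full overload: withReplacement, fractions, exact, seed.
    val exactSample = javaPairs.sampleByKey(false, fractions, true, 42L)
    // Two-argument overload: exact defaults to true, seed to Utils.random.nextLong.
    val defaultSample = javaPairs.sampleByKey(false, fractions)

    println(s"exact: ${exactSample.count()}, default: ${defaultSample.count()}")
    sc.stop()
  }
}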
30 changes: 14 additions & 16 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -19,12 +19,10 @@ package org.apache.spark.rdd

import java.nio.ByteBuffer
import java.text.SimpleDateFormat
import java.util.Date
import java.util.{HashMap => JHashMap}
import java.util.{Date, HashMap => JHashMap}

import scala.collection.{Map, mutable}
import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag

@@ -34,16 +32,14 @@ import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob,
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}

import org.apache.spark._
import org.apache.spark.annotation.Experimental
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.SparkHadoopWriter
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
@@ -216,24 +212,26 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* need two additional passes.
*
* @param withReplacement whether to sample with or without replacement
* @param fractionByKey function mapping key to sampling rate
* @param fractions map of specific keys to sampling rates
* @param seed seed for the random number generator
* @param exact whether sample size needs to be exactly math.ceil(fraction * size) per stratum
* @return RDD containing the sampled subset
*/
def sampleByKey(withReplacement: Boolean,
fractionByKey: Map[K, Double],
seed: Long = Utils.random.nextLong,
exact: Boolean = true): RDD[(K, V)]= {
require(fractionByKey.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.")
fractions: Map[K, Double],
exact: Boolean = true,
seed: Long = Utils.random.nextLong): RDD[(K, V)]= {

require(fractions.forall({case(k, v) => v >= 0.0}), "Invalid sampling rates.")

if (withReplacement) {
val counts = if (exact) Some(this.countByKey()) else None
val samplingFunc =
StratifiedSampler.getPoissonSamplingFunction(self, fractionByKey, exact, counts, seed)
StratifiedSampler.getPoissonSamplingFunction(self, fractions, exact, counts, seed)
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
} else {
val samplingFunc =
StratifiedSampler.getBernoulliSamplingFunction(self, fractionByKey, exact, seed)
StratifiedSampler.getBernoulliSamplingFunction(self, fractions, exact, seed)
self.mapPartitionsWithIndex(samplingFunc, preservesPartitioning = true)
}
}
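For illustration only (not part of the diff): a minimal Scala sketch of the reordered signature, where exact now comes before seed. The data and sampling rates are invented; the comments restate the behavior described in the doc comment above.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._  // brings PairRDDFunctions into scope

object SampleByKeyScalaSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("sampleByKey-scala"))
    // 500 "even" and 500 "odd" records (illustrative data).
    val pairs = sc.parallelize(1 to 1000).map(i => (if (i % 2 == 0) "even" else "odd", i))
    val fractions = Map("even" -> 0.1, "odd" -> 0.2)

    // exact = true (the default): per-stratum sample size is exactly
    // math.ceil(fraction * size), at the cost of additional passes over the data.
    val exactSample = pairs.sampleByKey(false, fractions, exact = true, seed = 7L)
    // exact = false: single pass, per-stratum size is only approximately fraction * size.
    val approxSample = pairs.sampleByKey(false, fractions, exact = false)

    println(exactSample.countByKey())
    println(approxSample.countByKey())
    sc.stop()
  }
}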
21 changes: 0 additions & 21 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -875,27 +875,6 @@ abstract class RDD[T: ClassTag](
jobResult
}

/**
* A version of {@link #aggregate()} that passes the TaskContext to the function that does
* aggregation for each partition.
*/
def aggregateWithContext[U: ClassTag](zeroValue: U)(seqOp: ((TaskContext, U), T) => U,
combOp: (U, U) => U): U = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
// pad seqOp and combOp with taskContext to conform to aggregate's signature in TraversableOnce
val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
val paddedcombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
(arg1._1, combOp(arg1._2, arg1._2))
val cleanSeqOp = sc.clean(paddedSeqOp)
val cleanCombOp = sc.clean(paddedcombOp)
val aggregatePartition = (tc: TaskContext, it: Iterator[T]) =>
(it.aggregate(tc, zeroValue)(cleanSeqOp, cleanCombOp))._2
val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
sc.runJob(this, aggregatePartition, mergeResult)
jobResult
}

/**
* Return the number of elements in the RDD.
*/
40 changes: 34 additions & 6 deletions core/src/main/scala/org/apache/spark/util/random/StratifiedSampler.scala
@@ -17,15 +17,43 @@

package org.apache.spark.util.random

import scala.collection.{Map, mutable}
import scala.collection.mutable.ArrayBuffer
import scala.collection.{mutable, Map}
import scala.reflect.ClassTag

import org.apache.commons.math3.random.RandomDataGenerator
import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.util.random.{PoissonBounds => PB}
import scala.Some
import org.apache.spark.{Logging, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
import org.apache.spark.util.random.{PoissonBounds => PB}

private[spark] object StratifiedSampler extends Logging {

/**
* A version of {@link #aggregate()} that passes the TaskContext to the function that does
* aggregation for each partition. This function avoids creating an extra depth in the RDD
* lineage, as opposed to using mapPartitionsWithIndex, which results in slightly improved run time.
*/
def aggregateWithContext[U: ClassTag, T: ClassTag](zeroValue: U)
(rdd: RDD[T],
seqOp: ((TaskContext, U), T) => U,
combOp: (U, U) => U): U = {
val sc: SparkContext = rdd.sparkContext
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
// pad seqOp and combOp with taskContext to conform to aggregate's signature in TraversableOnce
val paddedSeqOp = (arg1: (TaskContext, U), item: T) => (arg1._1, seqOp(arg1, item))
val paddedcombOp = (arg1: (TaskContext, U), arg2: (TaskContext, U)) =>
(arg1._1, combOp(arg1._2, arg1._2))
val cleanSeqOp = sc.clean(paddedSeqOp)
val cleanCombOp = sc.clean(paddedcombOp)
val aggregatePartition = (tc: TaskContext, it: Iterator[T]) =>
(it.aggregate(tc, zeroValue)(cleanSeqOp, cleanCombOp))._2
val mergeResult = (index: Int, taskResult: U) => jobResult = combOp(jobResult, taskResult)
sc.runJob(rdd, aggregatePartition, mergeResult)
jobResult
}

/**
* Returns the function used by aggregate to collect sampling statistics for each partition.
*/
@@ -153,7 +181,7 @@ private[spark] object StratifiedSampler extends Logging {
val seqOp = StratifiedSampler.getSeqOp[K,V](false, fractionByKey, None)
val combOp = StratifiedSampler.getCombOp[K]()
val zeroU = new Result[K](Map[K, Stratum](), seed = seed)
val finalResult = rdd.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
val finalResult = aggregateWithContext(zeroU)(rdd, seqOp, combOp).resultMap
samplingRateByKey = StratifiedSampler.computeThresholdByKey(finalResult, fractionByKey)
}
(idx: Int, iter: Iterator[(K, V)]) => {
@@ -183,7 +211,7 @@ private[spark] object StratifiedSampler extends Logging {
val seqOp = StratifiedSampler.getSeqOp[K,V](true, fractionByKey, counts)
val combOp = StratifiedSampler.getCombOp[K]()
val zeroU = new Result[K](Map[K, Stratum](), seed = seed)
val finalResult = rdd.aggregateWithContext(zeroU)(seqOp, combOp).resultMap
val finalResult = aggregateWithContext(zeroU)(rdd, seqOp, combOp).resultMap
val thresholdByKey = StratifiedSampler.computeThresholdByKey(finalResult, fractionByKey)
(idx: Int, iter: Iterator[(K, V)]) => {
val random = new RandomDataGenerator()
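For illustration only (not part of the diff): a rough sketch of calling the relocated aggregateWithContext helper. It is private[spark], so the sketch sits inside the same package; counting records per partition id is an invented aggregation, not what StratifiedSampler actually computes.

package org.apache.spark.util.random

import scala.collection.mutable

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object AggregateWithContextSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("awc-sketch"))
    val rdd = sc.parallelize(1 to 100, 4)

    val zero = mutable.Map.empty[Int, Long]
    // seqOp receives (TaskContext, accumulator), so it can read the partition id.
    val seqOp = (arg: (TaskContext, mutable.Map[Int, Long]), x: Int) => {
      val (tc, acc) = arg
      acc(tc.partitionId) = acc.getOrElse(tc.partitionId, 0L) + 1L
      acc
    }
    // combOp merges two per-partition accumulators on the driver.
    val combOp = (a: mutable.Map[Int, Long], b: mutable.Map[Int, Long]) => {
      b.foreach { case (k, v) => a(k) = a.getOrElse(k, 0L) + v }
      a
    }

    // Runs as a single job without adding a step to the RDD lineage (see the doc comment above).
    val perPartitionCounts = StratifiedSampler.aggregateWithContext(zero)(rdd, seqOp, combOp)
    println(perPartitionCounts)  // e.g. Map(0 -> 25, 1 -> 25, 2 -> 25, 3 -> 25)
    sc.stop()
  }
}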
8 changes: 4 additions & 4 deletions core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -106,8 +106,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
n: Long) = {
val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
math.ceil(count * samplingRate).toInt)
val fractionByKey = Map("1" -> samplingRate, "0" -> samplingRate)
val sample = stratifiedData.sampleByKey(false, fractionByKey, seed, exact)
val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
val sample = stratifiedData.sampleByKey(false, fractions, exact, seed)
val sampleCounts = sample.countByKey()
val takeSample = sample.collect()
assert(sampleCounts.forall({case(k,v) =>
@@ -124,8 +124,8 @@ class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
n: Long) = {
val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
math.ceil(count * samplingRate).toInt)
val fractionByKey = Map("1" -> samplingRate, "0" -> samplingRate)
val sample = stratifiedData.sampleByKey(true, fractionByKey, seed, exact)
val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
val sample = stratifiedData.sampleByKey(true, fractions, exact, seed)
val sampleCounts = sample.countByKey()
val takeSample = sample.collect()
assert(sampleCounts.forall({case(k,v) =>
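For illustration only (not part of the diff): a stand-alone sketch in the spirit of the updated suite, but not its actual assertions, checking that an exact, without-replacement sample hits math.ceil(fraction * count) per key. Data, rates, and seed are invented.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._

object ExactSampleSizeCheck {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("exact-sample-check"))
    val samplingRate = 0.1
    // Keys "0" and "1", 500 records each (illustrative data).
    val data = sc.parallelize(1 to 1000).map(i => ((i % 2).toString, i))
    val fractions = Map("0" -> samplingRate, "1" -> samplingRate)

    val expectedSizes = data.countByKey().mapValues(c => math.ceil(c * samplingRate).toLong)
    val observedSizes = data.sampleByKey(false, fractions, exact = true, seed = 1L).countByKey()

    // With exact = true and no replacement, each stratum should match exactly.
    assert(observedSizes.forall { case (k, v) => v == expectedSizes(k) })
    sc.stop()
  }
}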
32 changes: 0 additions & 32 deletions core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -141,38 +141,6 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}

test("aggregateWithContext") {
val data = Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))
val numPartitions = 2
val pairs = sc.makeRDD(data, numPartitions)
//determine the partitionId for each pair
type StringMap = HashMap[String, Int]
val partitions = pairs.collectPartitions()
val offSets = new StringMap
for (i <- 0 to numPartitions - 1) {
partitions(i).foreach({ case (k, v) => offSets.put(k, offSets.getOrElse(k, 0) + i)})
}
val emptyMap = new StringMap {
override def default(key: String): Int = 0
}
val mergeElement: ((TaskContext, StringMap), (String, Int)) => StringMap = (arg1, pair) => {
val stringMap = arg1._2
val tc = arg1._1
stringMap(pair._1) += pair._2 + tc.partitionId
stringMap
}
val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => {
for ((key, value) <- map2) {
map1(key) += value
}
map1
}
val result = pairs.aggregateWithContext(emptyMap)(mergeElement, mergeMaps)
val expected = Set(("a", 6), ("b", 2), ("c", 5))
.map({ case (k, v) => (k -> (offSets.getOrElse(k, 0) + v))})
assert(result.toSet === expected)
}

test("basic caching") {
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
assert(rdd.collect().toList === List(1, 2, 3, 4))
