Skip to content

Commit

Permalink
Added simpler version of updateStateByKey API with initialRDD and test.
Browse files Browse the repository at this point in the history
  • Loading branch information
soumitrak committed Nov 11, 2014
1 parent 9781135 commit 304f636
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -443,23 +443,6 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
scalaFunc
}

/**
 * Adapts a Java update function into a Scala function that works on an iterator of
 * (key, new values, optional previous state) triples.
 *
 * For every triple the Java function `in` is applied to the new values (as a `JList`) and
 * the previous state (as an `Optional`). When the returned `Optional` is present the
 * (key, new state) pair is emitted; when it is absent the key is dropped from the output.
 *
 * @param in Java function computing the new state from the values and the previous state
 * @tparam S State type
 * @return a function from an iterator of triples to an iterator of (key, state) pairs
 */
private def convertUpdateStateFunctionWithIterator[S]
    (in: JFunction2[JList[V], Optional[S], Optional[S]]):
    (Iterator[(K, Seq[V], Option[S])]) => Iterator[(K, S)] = {
  (records: Iterator[(K, Seq[V], Option[S])]) =>
    records.flatMap { case (key, values, previousState) =>
      // The type ascription hands the Scala Seq to the Java function as a java.util.List.
      val javaValues: JList[V] = values
      val javaState: Optional[S] = JavaUtils.optionToOptional(previousState)
      val updated: Optional[S] = in.apply(javaValues, javaState)
      if (updated.isPresent) Some((key, updated.get())) else None
    }
}

/**
* Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
Expand Down Expand Up @@ -526,8 +509,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
initialRDD: JavaPairRDD[K, S]
): JavaPairDStream[K, S] = {
implicit val cm: ClassTag[S] = fakeClassTag
dstream.updateStateByKey(convertUpdateStateFunctionWithIterator(updateFunc),
partitioner, true, initialRDD)
dstream.updateStateByKey(convertUpdateStateFunction(updateFunc), partitioner, initialRDD)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,28 @@ class PairDStreamFunctions[K, V](self: DStream[(K,V)])
new StateDStream(self, ssc.sc.clean(updateFunc), partitioner, rememberPartitioner, None)
}

/**
 * Return a new "state" DStream in which the state of every key is recomputed by applying
 * `updateFunc` to the new values of that key and its previous state.
 * org.apache.spark.Partitioner is used to control the partitioning of each RDD.
 *
 * This is a convenience overload: the per-key `updateFunc` is lifted into an
 * iterator-based function and passed to the iterator-based `updateStateByKey` overload.
 *
 * @param updateFunc State update function. If this function returns None, the
 *                   corresponding state key-value pair is eliminated.
 * @param partitioner Partitioner for controlling the partitioning of each RDD in the new
 *                    DStream.
 * @param initialRDD Initial state value of each key.
 * @tparam S State type
 */
def updateStateByKey[S: ClassTag](
    updateFunc: (Seq[V], Option[S]) => Option[S],
    partitioner: Partitioner,
    initialRDD: RDD[(K, S)]
  ): DStream[(K, S)] = {
  // Lift the per-key function: keys whose update yields None disappear from the output.
  val perPartitionUpdate = (entries: Iterator[(K, Seq[V], Option[S])]) =>
    entries.flatMap { case (key, values, previousState) =>
      updateFunc(values, previousState).map(nextState => (key, nextState))
    }
  updateStateByKey(perPartitionUpdate, partitioner, true, initialRDD)
}

/**
* Return a new "state" DStream where the state for each key is updated by applying
* the given function on the previous state of the key and the new values of each key.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag](
val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => {
val i = iterator.map(t => {
val itr = t._2._2.iterator
val headOption = itr.hasNext match {
case true => Some(itr.next())
case false => None
}
val headOption = if(itr.hasNext) Some(itr.next) else None
(t._1, t._2._1.toSeq, headOption)
})
updateFuncLocal(i)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -806,15 +806,17 @@ public void testUnion() {
* Performs an order-invariant comparison of lists representing two RDD streams. This allows
* us to account for ordering variation within individual RDD's which occurs during windowing.
*/
public static <T extends Comparable<T>> void assertOrderInvariantEquals(
public static <T> void assertOrderInvariantEquals(
List<List<T>> expected, List<List<T>> actual) {
List<Set<T>> expectedSets = new ArrayList<Set<T>>();
for (List<T> list: expected) {
Collections.sort(list);
expectedSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
}
List<Set<T>> actualSets = new ArrayList<Set<T>>();
for (List<T> list: actual) {
Collections.sort(list);
actualSets.add(Collections.unmodifiableSet(new HashSet<T>(list)));
}
Assert.assertEquals(expected, actual);
Assert.assertEquals(expectedSets, actualSets);
}


Expand Down Expand Up @@ -1252,12 +1254,12 @@ public void testUpdateStateByKeyWithInitial() {
JavaPairRDD<String, Integer> initialRDD = JavaPairRDD.fromJavaRDD (tmpRDD);

List<List<Tuple2<String, Integer>>> expected = Arrays.asList(
Arrays.asList(new Tuple2<String, Integer>("new york", 7),
new Tuple2<String, Integer>("california", 5)),
Arrays.asList(new Tuple2<String, Integer>("new york", 11),
new Tuple2<String, Integer>("california", 15)),
Arrays.asList(new Tuple2<String, Integer>("new york", 11),
new Tuple2<String, Integer>("california", 15)));
Arrays.asList(new Tuple2<String, Integer>("california", 4),
new Tuple2<String, Integer>("new york", 5)),
Arrays.asList(new Tuple2<String, Integer>("california", 14),
new Tuple2<String, Integer>("new york", 9)),
Arrays.asList(new Tuple2<String, Integer>("california", 14),
new Tuple2<String, Integer>("new york", 9)));

JavaDStream<Tuple2<String, Integer>> stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
JavaPairDStream<String, Integer> pairStream = JavaPairDStream.fromJavaDStream(stream);
Expand All @@ -1279,7 +1281,7 @@ public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
JavaTestUtils.attachTestOutputStream(updated);
List<List<Tuple2<String, Integer>>> result = JavaTestUtils.runStreams(ssc, 3, 3);

Assert.assertEquals(expected, result);
assertOrderInvariantEquals(expected, result);
}

@SuppressWarnings("unchecked")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,41 @@ class BasicOperationsSuite extends TestSuiteBase {
testOperation(inputData, updateStateOperation, outputData, true)
}

test("updateStateByKey - simple with initial value RDD") {
  // State present before the first batch arrives: "a" starts at 1 and "c" at 2.
  val seedState = Seq(("a", 1), ("c", 2))

  // One inner Seq per batch; every occurrence of a key counts as 1.
  val inputData = Seq(
    Seq("a"),
    Seq("a", "b"),
    Seq("a", "b", "c"),
    Seq("a", "b"),
    Seq("a"),
    Seq()
  )

  // Running count per key after each batch, seeded values included. "c" appears in the
  // very first output (carried over from the seed) although it is only seen in batch three.
  val outputData = Seq(
    Seq(("a", 2), ("c", 2)),
    Seq(("a", 3), ("b", 1), ("c", 2)),
    Seq(("a", 4), ("b", 2), ("c", 3)),
    Seq(("a", 5), ("b", 3), ("c", 3)),
    Seq(("a", 6), ("b", 3), ("c", 3)),
    Seq(("a", 6), ("b", 3), ("c", 3))
  )

  val updateStateOperation = (stream: DStream[String]) => {
    // New state = previous state (0 when the key is new) + sum of the batch's values.
    val updateFunc = (values: Seq[Int], state: Option[Int]) =>
      Some(state.getOrElse(0) + values.sum)
    val initialRDD = stream.context.sparkContext.makeRDD(seedState)
    val partitioner = new HashPartitioner(numInputPartitions)
    stream.map(word => (word, 1)).updateStateByKey[Int](updateFunc, partitioner, initialRDD)
  }

  testOperation(inputData, updateStateOperation, outputData, true)
}

test("updateStateByKey - with initial value RDD") {
val initial = Seq(("a", 1), ("c", 2))

Expand Down

0 comments on commit 304f636

Please sign in to comment.