diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 7879d1b7679d9..ce8aec613d08b 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -20,6 +20,7 @@
 from pyspark.context import SparkContext
 from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.dstream import DStream
+from pyspark.streaming.util import RDDFunction
 
 from py4j.java_collections import ListConverter
 from py4j.java_gateway import java_import
@@ -212,11 +213,20 @@ def queueStream(self, queue, oneAtATime=True, default=None):
 
     def transform(self, dstreams, transformFunc):
         """
-        Create a new DStream in which each RDD is generated by applying a function on RDDs of
-        the DStreams. The order of the JavaRDDs in the transform function parameter will be the
-        same as the order of corresponding DStreams in the list.
+        Create a new DStream in which each RDD is generated by applying
+        a function on RDDs of the DStreams. The order of the JavaRDDs in
+        the transform function parameter will be the same as the order
+        of corresponding DStreams in the list.
         """
-        # TODO
+        jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
+                                            SparkContext._gateway._gateway_client)
+        # change the final serializer to sc.serializer
+        jfunc = RDDFunction(self._sc,
+                            lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
+                            *[d._jrdd_deserializer for d in dstreams])
+
+        jdstream = self._jvm.PythonDStream.callTransform(self._jssc, jdstreams, jfunc)
+        return DStream(jdstream, self, self._sc.serializer)
 
     def union(self, *dstreams):
         """
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 2653e75ccbc54..ae5be72952c76 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -132,7 +132,7 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
         return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc))
 
     def foreach(self, func):
-        return self.foreachRDD(lambda rdd, _: rdd.foreach(func))
+        return self.foreachRDD(lambda _, rdd: rdd.foreach(func))
 
     def foreachRDD(self, func):
         """
@@ -142,7 +142,7 @@ def foreachRDD(self, func):
         This is an output operator, so this DStream will be registered as an output
         stream and there materialized.
         """
-        jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self._jrdd_deserializer)
+        jfunc = RDDFunction(self.ctx, func, self._jrdd_deserializer)
         api = self._ssc._jvm.PythonDStream
         api.callForeachRDD(self._jdstream, jfunc)
 
@@ -151,10 +151,10 @@ def pprint(self):
         Print the first ten elements of each RDD generated in this DStream.
         This is an output operator, so this DStream will be registered as an output
         stream and there materialized.
""" - def takeAndPrint(rdd, time): + def takeAndPrint(timestamp, rdd): taken = rdd.take(11) print "-------------------------------------------" - print "Time: %s" % datetime.fromtimestamp(time / 1000.0) + print "Time: %s" % datetime.fromtimestamp(timestamp / 1000.0) print "-------------------------------------------" for record in taken[:10]: print record @@ -176,15 +176,15 @@ def take(self, n): """ rdds = [] - def take(rdd, _): - if rdd: + def take(_, rdd): + if rdd and len(rdds) < n: rdds.append(rdd) - if len(rdds) == n: - # FIXME: NPE in JVM - self._ssc.stop(False) self.foreachRDD(take) + self._ssc.start() - self._ssc.awaitTermination() + while len(rdds) < n: + time.sleep(0.01) + self._ssc.stop(False, True) return rdds def collect(self): @@ -195,7 +195,7 @@ def collect(self): """ result = [] - def get_output(rdd, time): + def get_output(_, rdd): r = rdd.collect() result.append(r) self.foreachRDD(get_output) @@ -317,7 +317,7 @@ def transform(self, func): Return a new DStream in which each RDD is generated by applying a function on each RDD of 'this' DStream. """ - return TransformedDStream(self, lambda a, t: func(a), True) + return TransformedDStream(self, lambda t, a: func(a), True) def transformWithTime(self, func): """ @@ -331,7 +331,7 @@ def transformWith(self, func, other, keepSerializer=False): Return a new DStream in which each RDD is generated by applying a function on each RDD of 'this' DStream and 'other' DStream. """ - jfunc = RDDFunction(self.ctx, lambda a, b, t: func(a, b), self._jrdd_deserializer) + jfunc = RDDFunction(self.ctx, lambda t, a, b: func(a, b), self._jrdd_deserializer) dstream = self.ctx._jvm.PythonTransformed2DStream(self._jdstream.dstream(), other._jdstream.dstream(), jfunc) jrdd_serializer = self._jrdd_deserializer if keepSerializer else self.ctx.serializer @@ -549,14 +549,14 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None self._check_window(windowDuration, slideDuration) reduced = self.reduceByKey(func) - def reduceFunc(a, b, t): + def reduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) r = a.union(b).reduceByKey(func, numPartitions) if a else b if filterFunc: r = r.filter(filterFunc) return r - def invReduceFunc(a, b, t): + def invReduceFunc(t, a, b): b = b.reduceByKey(func, numPartitions) joined = a.leftOuterJoin(b, numPartitions) return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1) @@ -582,7 +582,7 @@ def updateStateByKey(self, updateFunc, numPartitions=None): @param updateFunc State update function ([(k, vs, s)] -> [(k, s)]). If `s` is None, then `k` will be eliminated. 
""" - def reduceFunc(a, b, t): + def reduceFunc(t, a, b): if a is None: g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None)) else: @@ -610,7 +610,7 @@ def __init__(self, prev, func, reuse=False): not prev.is_cached and not prev.is_checkpointed): prev_func = prev.func old_func = func - func = lambda rdd, t: old_func(prev_func(rdd, t), t) + func = lambda t, rdd: old_func(t, prev_func(t, rdd)) reuse = reuse and prev.reuse prev = prev.prev @@ -625,7 +625,7 @@ def _jdstream(self): return self._jdstream_val func = self.func - jfunc = RDDFunction(self.ctx, lambda a, _, t: func(a, t), self.prev._jrdd_deserializer) + jfunc = RDDFunction(self.ctx, func, self.prev._jrdd_deserializer) jdstream = self.ctx._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc, self.reuse).asJavaDStream() self._jdstream_val = jdstream diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index c547971cd7741..ecf88cce47beb 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -374,6 +374,19 @@ def test_union(self): expected = [i * 2 for i in input] self.assertEqual(expected, result[:3]) + def test_transform(self): + dstream1 = self.ssc.queueStream([[1]]) + dstream2 = self.ssc.queueStream([[2]]) + dstream3 = self.ssc.queueStream([[3]]) + + def func(rdds): + rdd1, rdd2, rdd3 = rdds + return rdd2.union(rdd3).union(rdd1) + + dstream = self.ssc.transform([dstream1, dstream2, dstream3], func) + + self.assertEqual([2, 3, 1], dstream.first().collect()) + if __name__ == "__main__": unittest.main() diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 885411ed63936..57791805e8f9f 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -22,21 +22,25 @@ class RDDFunction(object): """ This class is for py4j callback. 
""" - def __init__(self, ctx, func, deserializer, deserializer2=None): + def __init__(self, ctx, func, *deserializers): self.ctx = ctx self.func = func - self.deserializer = deserializer - self.deserializer2 = deserializer2 or deserializer + self.deserializers = deserializers + emptyRDD = getattr(self.ctx, "_emptyRDD", None) + if emptyRDD is None: + self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache() + self.emptyRDD = emptyRDD - def call(self, jrdd, jrdd2, milliseconds): + def call(self, milliseconds, jrdds): try: - emptyRDD = getattr(self.ctx, "_emptyRDD", None) - if emptyRDD is None: - self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache() + # extend deserializers with the first one + sers = self.deserializers + if len(sers) < len(jrdds): + sers += (sers[0],) * (len(jrdds) - len(sers)) - rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD - other = RDD(jrdd2, self.ctx, self.deserializer2) if jrdd2 else emptyRDD - r = self.func(rdd, other, milliseconds) + rdds = [RDD(jrdd, self.ctx, ser) if jrdd else self.emptyRDD + for jrdd, ser in zip(jrdds, sers)] + r = self.func(milliseconds, *rdds) if r: return r._jrdd except Exception: @@ -44,7 +48,7 @@ def call(self, jrdd, jrdd2, milliseconds): traceback.print_exc() def __repr__(self): - return "RDDFunction2(%s)" % (str(self.func)) + return "RDDFunction(%s)" % (str(self.func)) class Java: implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction'] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 5a8eef1372e23..ab6a6de074a80 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -413,7 +413,7 @@ class StreamingContext private[streaming] ( dstreams: Seq[DStream[_]], transformFunc: (Seq[RDD[_]], Time) => RDD[T] ): DStream[T] = { - new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc)) + new TransformedDStream[T](dstreams, (transformFunc)) } /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 16ac1b93b5f22..8ba8c0441ef35 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -17,11 +17,12 @@ package org.apache.spark.streaming.api.python -import java.util.{ArrayList => JArrayList} +import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.api.java._ -import org.apache.spark.api.java.function.{Function2 => JFunction2} import org.apache.spark.api.python._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -29,18 +30,19 @@ import org.apache.spark.streaming.{Interval, Duration, Time} import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.api.java._ + /** * Interface for Python callback function with three arguments */ trait PythonRDDFunction { - def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]] + def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]] } -class RDDFunction(pfunc: 
-
-  def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
-    apply(rdd, None, time)
-  }
+/**
+ * Wrapper for PythonRDDFunction
+ */
+class RDDFunction(pfunc: PythonRDDFunction)
+  extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] with Serializable {
 
   def wrapRDD(rdd: Option[RDD[_]]): JavaRDD[_] = {
     if (rdd.isDefined) {
@@ -50,14 +52,25 @@ class RDDFunction(pfunc: PythonRDDFunction) extends Serializable {
     } else {
       null
     }
   }
 
-  def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
-    val r = pfunc.call(wrapRDD(rdd), wrapRDD(rdd2), time.milliseconds)
-    if (r != null) {
-      Some(r.rdd)
+  def some(jrdd: JavaRDD[Array[Byte]]): Option[RDD[Array[Byte]]] = {
+    if (jrdd != null) {
+      Some(jrdd.rdd)
     } else {
       None
     }
   }
+
+  def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+    some(pfunc.call(time.milliseconds, List(wrapRDD(rdd)).asJava))
+  }
+
+  def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+    some(pfunc.call(time.milliseconds, List(wrapRDD(rdd), wrapRDD(rdd2)).asJava))
+  }
+
+  def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
+    pfunc.call(time.milliseconds, rdds)
+  }
 }
 
 private[python]
@@ -74,8 +87,16 @@ private[spark] object PythonDStream {
 
   // helper function for DStream.foreachRDD(),
   // cannot be `foreachRDD`, it will confusing py4j
-  def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction): Unit = {
-    jdstream.dstream.foreachRDD((rdd, time) => pyfunc.call(rdd, null, time.milliseconds))
+  def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pyfunc: PythonRDDFunction) {
+    val func = new RDDFunction(pyfunc)
+    jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
+  }
+
+  // helper function for ssc.transform()
+  def callTransform(ssc: JavaStreamingContext, jdstreams: JList[JavaDStream[_]], pyfunc: PythonRDDFunction)
+    : JavaDStream[Array[Byte]] = {
+    val func = new RDDFunction(pyfunc)
+    ssc.transform(jdstreams, func)
   }
 
   // convert list of RDD into queue of RDDs, for ssc.queueStream()
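
Usage sketch (not part of the patch): the new multi-stream `ssc.transform()` wired up above can be driven from Python roughly as follows. The local-mode setup (master, app name, batch interval) and the variable names are illustrative assumptions; the call pattern itself mirrors the `test_transform` case added in tests.py.

```python
from pyspark import SparkContext
from pyspark.streaming.context import StreamingContext

# Illustrative local setup; master, app name and batch interval are assumptions.
sc = SparkContext("local[2]", "transform-demo")
ssc = StreamingContext(sc, 1)

# Three toy input streams, as in the added test_transform.
d1 = ssc.queueStream([[1]])
d2 = ssc.queueStream([[2]])
d3 = ssc.queueStream([[3]])

def func(rdds):
    # The RDDs arrive in the same order as the DStreams passed to transform().
    rdd1, rdd2, rdd3 = rdds
    return rdd2.union(rdd3).union(rdd1)

merged = ssc.transform([d1, d2, d3], func)
merged.pprint()
```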
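A second hedged sketch of the changed callback convention: `RDDFunction.call()` now receives the batch time first and a list of RDDs second, so user callbacks passed to `foreachRDD()` (like the reworked `takeAndPrint` above) take `(time, rdd)` rather than `(rdd, time)`. The DStream name below is an assumption, not something defined in the patch.

```python
from datetime import datetime

def dump_batch(time, rdd):
    # With this patch the batch time (milliseconds since the epoch) comes first,
    # and the RDD for that batch comes second.
    print "batch at %s: %d records" % (datetime.fromtimestamp(time / 1000.0), rdd.count())

lines.foreachRDD(dump_batch)  # 'lines' stands for any existing DStream
```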