diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 2672da36c1f50..94bebc310bad6 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -114,6 +114,9 @@ def __ne__(self, other):
     def __repr__(self):
         return "<%s object>" % self.__class__.__name__

+    def __hash__(self):
+        return hash(str(self))
+

 class FramedSerializer(Serializer):
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 1c7cb5604e5cc..c4a1014ab9ab0 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -15,16 +15,51 @@
 # limitations under the License.
 #

-from pyspark.serializers import UTF8Deserializer
+from pyspark import RDD
+from pyspark.serializers import UTF8Deserializer, BatchedSerializer
 from pyspark.context import SparkContext
+from pyspark.storagelevel import StorageLevel
 from pyspark.streaming.dstream import DStream
-from pyspark.streaming.duration import Duration, Seconds
+from pyspark.streaming.duration import Seconds

 from py4j.java_collections import ListConverter

 __all__ = ["StreamingContext"]


+def _daemonize_callback_server():
+    """
+    Hack Py4J to daemonize the callback server
+    """
+    # TODO: create a patch for Py4J
+    import socket
+    import py4j.java_gateway
+    logger = py4j.java_gateway.logger
+    from py4j.java_gateway import Py4JNetworkError
+    from threading import Thread
+
+    def start(self):
+        """Starts the CallbackServer. This method should be called by the
+        client instead of run()."""
+        self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
+                                      1)
+        try:
+            self.server_socket.bind((self.address, self.port))
+            # self.port = self.server_socket.getsockname()[1]
+        except Exception:
+            msg = 'An error occurred while trying to start the callback server'
+            logger.exception(msg)
+            raise Py4JNetworkError(msg)
+
+        # Maybe the thread needs to be cleaned up?
+        self.thread = Thread(target=self.run)
+        self.thread.daemon = True
+        self.thread.start()
+
+    py4j.java_gateway.CallbackServer.start = start
+
+
 class StreamingContext(object):
     """
     Main entry point for Spark Streaming functionality. A StreamingContext represents the
@@ -53,7 +88,9 @@ def _start_callback_server(self):
         gw = self._sc._gateway
         # getattr will fallback to JVM
         if "_callback_server" not in gw.__dict__:
+            _daemonize_callback_server()
             gw._start_callback_server(gw._python_proxy_port)
+            gw._python_proxy_port = gw._callback_server.port  # update port with real port

     def _initialize_context(self, sc, duration):
         return self._jvm.JavaStreamingContext(sc._jsc, duration._jduration)
@@ -92,26 +129,44 @@ def stop(self, stopSparkContext=True, stopGraceFully=False):

     def remember(self, duration):
         """
-        Set each DStreams in this context to remember RDDs it generated in the last given duration.
-        DStreams remember RDDs only for a limited duration of time and releases them for garbage
-        collection. This method allows the developer to specify how to long to remember the RDDs (
-        if the developer wishes to query old data outside the DStream computation).
-        @param duration pyspark.streaming.duration.Duration object or seconds.
-                        Minimum duration that each DStream should remember its RDDs
+        Set each DStream in this context to remember the RDDs it generated
+        in the last given duration. DStreams remember RDDs only for a
+        limited duration of time and release them for garbage collection.
+        This method allows the developer to specify how long to remember
+        the RDDs (if the developer wishes to query old data outside the
+        DStream computation).
+
+        @param duration Minimum duration (in seconds) that each DStream
+                        should remember its RDDs
         """
         if isinstance(duration, (int, long, float)):
             duration = Seconds(duration)

         self._jssc.remember(duration._jduration)

-    # TODO: add storageLevel
-    def socketTextStream(self, hostname, port):
+    def checkpoint(self, directory):
+        """
+        Sets the context to periodically checkpoint the DStream operations for master
+        fault-tolerance. The graph will be checkpointed every batch interval.
+
+        @param directory HDFS-compatible directory where the checkpoint data
+                         will be reliably stored
+        """
+        self._jssc.checkpoint(directory)
+
+    def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
         """
         Create an input from TCP source hostname:port. Data is received using
         a TCP socket and receive byte is interpreted as UTF8 encoded '\n' delimited
         lines.
+
+        @param hostname      Hostname to connect to for receiving data
+        @param port          Port to connect to for receiving data
+        @param storageLevel  Storage level to use for storing the received objects
         """
-        return DStream(self._jssc.socketTextStream(hostname, port), self, UTF8Deserializer())
+        jlevel = self._sc._getJavaStorageLevel(storageLevel)
+        return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
+                       UTF8Deserializer())

     def textFileStream(self, directory):
         """
@@ -122,14 +177,52 @@ def textFileStream(self, directory):
         """
         return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())

-    def _makeStream(self, inputs, numSlices=None):
+    def _check_serializers(self, rdds):
+        # make sure they all have the same serializer
+        if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
+            for i in range(len(rdds)):
+                # reset them to sc.serializer
+                rdds[i] = rdds[i].map(lambda x: x, preservesPartitioning=True)
+
+    def queueStream(self, queue, oneAtATime=False, default=None):
         """
-        This function is only for unittest.
-        It requires a list as input, and returns the i_th element at the i_th batch
-        under manual clock.
+        Create an input stream from a queue of RDDs or lists. In each batch,
+        it will process either one or all of the RDDs returned by the queue.
+
+        NOTE: changes to the queue after the stream is created will not be recognized.
+        @param queue Queue of RDDs
+        @tparam T    Type of objects in the RDD
         """
-        rdds = [self._sc.parallelize(input, numSlices) for input in inputs]
+        if queue and not isinstance(queue[0], RDD):
+            rdds = [self._sc.parallelize(input) for input in queue]
+        else:
+            rdds = queue
+        self._check_serializers(rdds)
         jrdds = ListConverter().convert([r._jrdd for r in rdds],
                                         SparkContext._gateway._gateway_client)
-        jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds).asJavaDStream()
-        return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
+        jdstream = self._jvm.PythonDataInputStream(self._jssc, jrdds, oneAtATime,
+                                                   default and default._jrdd)
+        return DStream(jdstream.asJavaDStream(), self, rdds[0]._jrdd_deserializer)
+
+    def transform(self, dstreams, transformFunc):
+        """
+        Create a new DStream in which each RDD is generated by applying a function on RDDs of
+        the DStreams. The order of the JavaRDDs in the transform function parameter will be the
+        same as the order of corresponding DStreams in the list.
+        """
+        # TODO
+
+    def union(self, *dstreams):
+        """
+        Create a unified DStream from multiple DStreams of the same
+        type and same slide duration.
+ """ + if not dstreams: + raise ValueError("should have at least one DStream to union") + if len(dstreams) == 1: + return dstreams[0] + self._check_serialzers(dstreams) + first = dstreams[0] + jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]], + SparkContext._gateway._gateway_client) + return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer) diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index 27e1400b8ba0b..9dd3556327477 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -315,16 +315,16 @@ def repartitions(self, numPartitions): return self.transform(lambda rdd: rdd.repartition(numPartitions)) def union(self, other): - return self.transformWith(lambda a, b: a.union(b), other, True) + return self.transformWith(lambda a, b, t: a.union(b), other, True) def cogroup(self, other): - return self.transformWith(lambda a, b: a.cogroup(b), other) + return self.transformWith(lambda a, b, t: a.cogroup(b), other) def leftOuterJoin(self, other): - return self.transformWith(lambda a, b: a.leftOuterJion(b), other) + return self.transformWith(lambda a, b, t: a.leftOuterJion(b), other) def rightOuterJoin(self, other): - return self.transformWith(lambda a, b: a.rightOuterJoin(b), other) + return self.transformWith(lambda a, b, t: a.rightOuterJoin(b), other) def _jtime(self, milliseconds): return self.ctx._jvm.Time(milliseconds) diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 755ea224e56da..a585bbfa06f5b 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -40,27 +40,25 @@ def setUp(self): class_name = self.__class__.__name__ self.sc = SparkContext(appName=class_name) self.sc.setCheckpointDir("/tmp") + # TODO: decrease duration to speed up tests self.ssc = StreamingContext(self.sc, duration=Seconds(1)) def tearDown(self): self.ssc.stop() - self.sc.stop() @classmethod def tearDownClass(cls): # Make sure tp shutdown the callback server SparkContext._gateway._shutdown_callback_server() - def _test_func(self, input, func, expected, numSlices=None, sort=False): + def _test_func(self, input, func, expected, sort=False): """ - Start stream and return the result. @param input: dataset for the test. This should be list of lists. @param func: wrapped function. This function should return PythonDStream object. @param expected: expected output for this testcase. - @param numSlices: the number of slices in the rdd in the dstream. """ # Generate input stream with user-defined input. - input_stream = self.ssc._makeStream(input, numSlices) + input_stream = self.ssc.queueStream(input) # Apply test function to stream. 
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index 755ea224e56da..a585bbfa06f5b 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -40,27 +40,25 @@ def setUp(self):
         class_name = self.__class__.__name__
         self.sc = SparkContext(appName=class_name)
         self.sc.setCheckpointDir("/tmp")
+        # TODO: decrease duration to speed up tests
         self.ssc = StreamingContext(self.sc, duration=Seconds(1))

     def tearDown(self):
         self.ssc.stop()
-        self.sc.stop()

     @classmethod
     def tearDownClass(cls):
         # Make sure tp shutdown the callback server
         SparkContext._gateway._shutdown_callback_server()

-    def _test_func(self, input, func, expected, numSlices=None, sort=False):
+    def _test_func(self, input, func, expected, sort=False):
         """
-        Start stream and return the result.
         @param input: dataset for the test. This should be list of lists.
         @param func: wrapped function. This function should return PythonDStream object.
         @param expected: expected output for this testcase.
-        @param numSlices: the number of slices in the rdd in the dstream.
         """
         # Generate input stream with user-defined input.
-        input_stream = self.ssc._makeStream(input, numSlices)
+        input_stream = self.ssc.queueStream(input)
         # Apply test function to stream.
         stream = func(input_stream)
         result = stream.collect()
@@ -121,7 +119,7 @@ def func(dstream):

     def test_count(self):
         """Basic operation test for DStream.count."""
-        input = [range(1, 5), range(1, 10), range(1, 20)]
+        input = [range(5), range(10), range(20)]

         def func(dstream):
             return dstream.count()
@@ -178,24 +176,24 @@ def func(dstream):
     def test_glom(self):
         """Basic operation test for DStream.glom."""
         input = [range(1, 5), range(5, 9), range(9, 13)]
-        numSlices = 2
+        rdds = [self.sc.parallelize(r, 2) for r in input]

         def func(dstream):
             return dstream.glom()
         expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
-        self._test_func(input, func, expected, numSlices)
+        self._test_func(rdds, func, expected)

     def test_mapPartitions(self):
         """Basic operation test for DStream.mapPartitions."""
         input = [range(1, 5), range(5, 9), range(9, 13)]
-        numSlices = 2
+        rdds = [self.sc.parallelize(r, 2) for r in input]

         def func(dstream):
             def f(iterator):
                 yield sum(iterator)
             return dstream.mapPartitions(f)
         expected = [[3, 7], [11, 15], [19, 23]]
-        self._test_func(input, func, expected, numSlices)
+        self._test_func(rdds, func, expected)

     def test_countByValue(self):
         """Basic operation test for DStream.countByValue."""
@@ -236,14 +234,14 @@ def add(a, b):
         self._test_func(input, func, expected, sort=True)

     def test_union(self):
-        input1 = [range(3), range(5), range(1)]
+        input1 = [range(3), range(5), range(1), range(6)]
         input2 = [range(3, 6), range(5, 6), range(1, 6)]
-        d1 = self.ssc._makeStream(input1)
-        d2 = self.ssc._makeStream(input2)
+        d1 = self.ssc.queueStream(input1)
+        d2 = self.ssc.queueStream(input2)
         d = d1.union(d2)
         result = d.collect()
-        expected = [range(6), range(6), range(6)]
+        expected = [range(6), range(6), range(6), range(6)]

         self.ssc.start()
         start_time = time.time()
@@ -317,33 +315,49 @@ def func(dstream):
 class TestStreamingContext(unittest.TestCase):
     def setUp(self):
         self.sc = SparkContext(master="local[2]", appName=self.__class__.__name__)
-        self.batachDuration = Seconds(1)
-        self.ssc = None
+        self.batchDuration = Seconds(0.1)
+        self.ssc = StreamingContext(self.sc, self.batchDuration)

     def tearDown(self):
-        if self.ssc is not None:
-            self.ssc.stop()
+        self.ssc.stop()
         self.sc.stop()

     def test_stop_only_streaming_context(self):
-        self.ssc = StreamingContext(self.sc, self.batachDuration)
-        self._addInputStream(self.ssc)
+        self._addInputStream()
         self.ssc.start()
         self.ssc.stop(False)
         self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)

     def test_stop_multiple_times(self):
-        self.ssc = StreamingContext(self.sc, self.batachDuration)
-        self._addInputStream(self.ssc)
+        self._addInputStream()
         self.ssc.start()
         self.ssc.stop()
         self.ssc.stop()

-    def _addInputStream(self, s):
+    def _addInputStream(self):
         # Make sure each length of input is over 3
         inputs = map(lambda x: range(1, x), range(5, 101))
-        stream = s._makeStream(inputs)
+        stream = self.ssc.queueStream(inputs)
         stream.collect()

+    def test_queueStream(self):
+        input = [range(i) for i in range(3)]
+        dstream = self.ssc.queueStream(input)
+        result = dstream.collect()
+        self.ssc.start()
+        time.sleep(1)
+        self.assertEqual(input, result)
+
+    def test_union(self):
+        input = [range(i) for i in range(3)]
+        dstream = self.ssc.queueStream(input)
+        dstream2 = self.ssc.union(dstream, dstream)
+        result = dstream2.collect()
+        self.ssc.start()
+        time.sleep(1)
+        expected = [i * 2 for i in input]
+        self.assertEqual(expected, result)
+

 if __name__ == "__main__":
     unittest.main()
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index fdbd01ec1766d..feff1b3889c49 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -30,7 +30,10 @@ def __init__(self, ctx, func, jrdd_deserializer):

     def call(self, jrdd, milliseconds):
         try:
-            rdd = RDD(jrdd, self.ctx, self.deserializer)
+            emptyRDD = getattr(self.ctx, "_emptyRDD", None)
+            if emptyRDD is None:
+                self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
+            rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else emptyRDD
             r = self.func(rdd, milliseconds)
             if r:
                 return r._jrdd
@@ -58,8 +61,12 @@ def __init__(self, ctx, func, jrdd_deserializer, jrdd_deserializer2=None):

     def call(self, jrdd, jrdd2, milliseconds):
         try:
-            rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else None
-            other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else None
+            emptyRDD = getattr(self.ctx, "_emptyRDD", None)
+            if emptyRDD is None:
+                self.ctx._emptyRDD = emptyRDD = self.ctx.parallelize([]).cache()
+
+            rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else emptyRDD
+            other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else emptyRDD
             r = self.func(rdd, other, milliseconds)
             if r:
                 return r._jrdd
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index b904e273eb438..828a620e4c08f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -39,6 +39,22 @@ trait PythonRDDFunction {
   def call(rdd: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
 }

+class RDDFunction(pfunc: PythonRDDFunction) {
+  def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+    val jrdd = if (rdd.isDefined) {
+      JavaRDD.fromRDD(rdd.get)
+    } else {
+      null
+    }
+    val r = pfunc.call(jrdd, time.milliseconds)
+    if (r != null) {
+      Some(r.rdd)
+    } else {
+      None
+    }
+  }
+}
+
 /**
  * Interface for Python callback function with three arguments
  */
@@ -46,33 +62,61 @@ trait PythonRDDFunction2 {
   def call(rdd: JavaRDD[_], rdd2: JavaRDD[_], time: Long): JavaRDD[Array[Byte]]
 }

+class RDDFunction2(pfunc: PythonRDDFunction2) {
+  def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
+    val jrdd = if (rdd.isDefined) {
+      JavaRDD.fromRDD(rdd.get)
+    } else {
+      null
+    }
+    val jrdd2 = if (rdd2.isDefined) {
+      JavaRDD.fromRDD(rdd2.get)
+    } else {
+      null
+    }
+    val r = pfunc.call(jrdd, jrdd2, time.milliseconds)
+    if (r != null) {
+      Some(r.rdd)
+    } else {
+      None
+    }
+  }
+}
+
+private[python]
+abstract class PythonDStream(parent: DStream[_]) extends DStream[Array[Byte]] (parent.ssc) {
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
 /**
  * Transformed DStream in Python.
  *
  * If the result RDD is PythonRDD, then it will cache it as an template for future use,
  * this can reduce the Python callbacks.
 */
-private[spark] class PythonTransformedDStream (parent: DStream[_], func: PythonRDDFunction,
+private[spark] class PythonTransformedDStream (parent: DStream[_], pfunc: PythonRDDFunction,
                                                var reuse: Boolean = false)
-  extends DStream[Array[Byte]] (parent.ssc) {
+  extends PythonDStream(parent) {

+  val func = new RDDFunction(pfunc)
   var lastResult: PythonRDD = _

-  override def dependencies = List(parent)
-
-  override def slideDuration: Duration = parent.slideDuration
-
   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
-    val rdd1 = parent.getOrCompute(validTime).getOrElse(null)
-    if (rdd1 == null) {
+    val rdd1 = parent.getOrCompute(validTime)
+    if (rdd1.isEmpty) {
       return None
     }
     if (reuse && lastResult != null) {
-      Some(lastResult.copyTo(rdd1))
+      Some(lastResult.copyTo(rdd1.get))
     } else {
-      val r = func.call(JavaRDD.fromRDD(rdd1), validTime.milliseconds).rdd
-      if (reuse && lastResult == null) {
-        r match {
+      val r = func(rdd1, validTime)
+      if (reuse && r.isDefined && lastResult == null) {
+        r.get match {
           case rdd: PythonRDD =>
             if (rdd.parent(0) == rdd1) {
               // only one PythonRDD
@@ -83,46 +127,65 @@ private[spark] class PythonTransformedDStream (parent: DStream[_], func: PythonR
           }
         }
       }
-      Some(r)
+      r
     }
   }
-
-  val asJavaDStream = JavaDStream.fromDStream(this)
 }

 /**
  * Transformed from two DStreams in Python.
  */
-private[spark] class PythonTransformed2DStream (parent: DStream[_], parent2: DStream[_], func: PythonRDDFunction2)
+private[spark]
+class PythonTransformed2DStream(parent: DStream[_], parent2: DStream[_],
+                                pfunc: PythonRDDFunction2)
   extends DStream[Array[Byte]] (parent.ssc) {

-  override def dependencies = List(parent, parent2)
+  val func = new RDDFunction2(pfunc)

   override def slideDuration: Duration = parent.slideDuration

+  override def dependencies = List(parent, parent2)
+
   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
-    def resultRdd(stream: DStream[_]): JavaRDD[_] = stream.getOrCompute(validTime) match {
-      case Some(rdd) => JavaRDD.fromRDD(rdd)
-      case None => null
-    }
-    Some(func.call(resultRdd(parent), resultRdd(parent2), validTime.milliseconds))
+    func(parent.getOrCompute(validTime), parent2.getOrCompute(validTime), validTime)
   }

   val asJavaDStream = JavaDStream.fromDStream(this)
 }

+/**
+ * similar to StateDStream
+ */
+private[spark]
+class PythonStateDStream(parent: DStream[Array[Byte]], preduceFunc: PythonRDDFunction2)
+  extends PythonDStream(parent) {
+
+  val reduceFunc = new RDDFunction2(preduceFunc)
+
+  super.persist(StorageLevel.MEMORY_ONLY)
+  override val mustCheckpoint = true
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    val lastState = getOrCompute(validTime - slideDuration)
+    val rdd = parent.getOrCompute(validTime)
+    if (rdd.isDefined) {
+      reduceFunc(lastState, rdd, validTime)
+    } else {
+      lastState
+    }
+  }
+}

 /**
  * Copied from ReducedWindowedDStream
  */
 private[spark]
-class PythonReducedWindowedDStream(
-    parent: DStream[Array[Byte]],
-    reduceFunc: PythonRDDFunction2,
-    invReduceFunc: PythonRDDFunction2,
-    _windowDuration: Duration,
-    _slideDuration: Duration
-  ) extends DStream[Array[Byte]](parent.ssc) {
+class PythonReducedWindowedDStream(parent: DStream[Array[Byte]],
+                                   preduceFunc: PythonRDDFunction2,
+                                   pinvReduceFunc: PythonRDDFunction2,
+                                   _windowDuration: Duration,
+                                   _slideDuration: Duration
+                                   ) extends PythonStateDStream(parent, preduceFunc) {

   assert(_windowDuration.isMultipleOf(parent.slideDuration),
     "The window duration of ReducedWindowedDStream (" + _windowDuration + ") " +
@@ -134,18 +197,10 @@ class PythonReducedWindowedDStream(
     "must be multiple of the slide duration of parent DStream (" + parent.slideDuration + ")"
   )
+  val invReduceFunc = new RDDFunction2(pinvReduceFunc)

-  // Persist RDDs to memory by default as these RDDs are going to be reused.
-  super.persist(StorageLevel.MEMORY_ONLY)
-
-  def windowDuration: Duration = _windowDuration
-
-  override def dependencies = List(parent)
-
+  def windowDuration: Duration = _windowDuration
   override def slideDuration: Duration = _slideDuration
-
-  override val mustCheckpoint = true
-
   override def parentRememberDuration: Duration = rememberDuration + windowDuration

   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
@@ -171,20 +226,17 @@ class PythonReducedWindowedDStream(
     //  old RDDs       new RDDs
     //

-    // Get the RDD of the reduced value of the previous window
-    val previousWindowRDD =
-      getOrCompute(previousWindow.endTime)
+    val previousWindowRDD = getOrCompute(previousWindow.endTime)

     if (windowDuration > slideDuration * 5 && previousWindowRDD.isDefined) {
       // subtle the values from old RDDs
       val oldRDDs =
         parent.slice(previousWindow.beginTime, currentWindow.beginTime - parent.slideDuration)
       val subbed = if (oldRDDs.size > 0) {
-        invReduceFunc.call(JavaRDD.fromRDD(previousWindowRDD.get),
-          JavaRDD.fromRDD(ssc.sc.union(oldRDDs)), validTime.milliseconds).rdd
+        invReduceFunc(previousWindowRDD, Some(ssc.sc.union(oldRDDs)), validTime)
       } else {
-        previousWindowRDD.get
+        previousWindowRDD
       }

       // add the RDDs of the reduced values in "new time steps"
@@ -192,58 +244,21 @@ class PythonReducedWindowedDStream(
         parent.slice(previousWindow.endTime, currentWindow.endTime - parent.slideDuration)

       if (newRDDs.size > 0) {
-        Some(reduceFunc.call(JavaRDD.fromRDD(subbed), JavaRDD.fromRDD(ssc.sc.union(newRDDs)), validTime.milliseconds))
+        reduceFunc(subbed, Some(ssc.sc.union(newRDDs)), validTime)
       } else {
-        Some(subbed)
+        subbed
       }
     } else {
       // Get the RDDs of the reduced values in current window
       val currentRDDs =
         parent.slice(currentWindow.beginTime, currentWindow.endTime - parent.slideDuration)
       if (currentRDDs.size > 0) {
-        Some(reduceFunc.call(null, JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds))
+        reduceFunc(None, Some(ssc.sc.union(currentRDDs)), validTime)
       } else {
         None
       }
     }
   }
-
-  val asJavaDStream = JavaDStream.fromDStream(this)
-}
-
-
-/**
- * Copied from ReducedWindowedDStream
- */
-private[spark]
-class PythonStateDStream(
-    parent: DStream[Array[Byte]],
-    reduceFunc: PythonRDDFunction2
-  ) extends DStream[Array[Byte]](parent.ssc) {
-
-  super.persist(StorageLevel.MEMORY_ONLY)
-
-  override def dependencies = List(parent)
-
-  override def slideDuration: Duration = parent.slideDuration
-
-  override val mustCheckpoint = true
-
-  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
-    val lastState = getOrCompute(validTime - slideDuration)
-    val newRDD = parent.getOrCompute(validTime)
-    if (newRDD.isDefined) {
-      if (lastState.isDefined) {
-        Some(reduceFunc.call(JavaRDD.fromRDD(lastState.get), JavaRDD.fromRDD(newRDD.get), validTime.milliseconds))
-      } else {
-        Some(reduceFunc.call(null, JavaRDD.fromRDD(newRDD.get), validTime.milliseconds))
-      }
-    } else {
-      lastState
-    }
-  }
-
-  val asJavaDStream = JavaDStream.fromDStream(this)
 }

 /**
@@ -255,7 +270,9 @@ class PythonForeachDStream(
   ) extends ForEachDStream[Array[Byte]](
     prev,
     (rdd: RDD[Array[Byte]], time: Time) => {
-      foreachFunction.call(rdd.toJavaRDD(), time.milliseconds)
+      if (rdd != null) {
+        foreachFunction.call(rdd, time.milliseconds)
+      }
     }
   ) {

@@ -264,34 +281,42 @@ class PythonForeachDStream(

 /**
- * This is a input stream just for the unitest. This is equivalent to a checkpointable,
- * replayable, reliable message queue like Kafka. It requires a JArrayList of JavaRDD,
- * and returns the i_th element at the i_th batch under manual clock.
+ * similar to QueueInputStream
  */
 class PythonDataInputStream(
     ssc_ : JavaStreamingContext,
-    inputRDDs: JArrayList[JavaRDD[Array[Byte]]]
+    inputRDDs: JArrayList[JavaRDD[Array[Byte]]],
+    oneAtATime: Boolean,
+    defaultRDD: JavaRDD[Array[Byte]]
   ) extends InputDStream[Array[Byte]](JavaStreamingContext.toStreamingContext(ssc_)) {

+  val emptyRDD = if (defaultRDD != null) {
+    Some(defaultRDD.rdd)
+  } else {
+    None  // ssc.sparkContext.emptyRDD[Array[Byte]]
+  }
+
   def start() {}

   def stop() {}

   def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
-    val emptyRDD = ssc.sparkContext.emptyRDD[Array[Byte]]
     val index = ((validTime - zeroTime) / slideDuration - 1).toInt
-    val selectedRDD = {
-      if (inputRDDs.isEmpty) {
+    if (oneAtATime) {
+      if (index == 0) {
+        val rdds = inputRDDs.toArray.map(_.asInstanceOf[JavaRDD[Array[Byte]]].rdd).toSeq
+        Some(ssc.sparkContext.union(rdds))
+      } else {
         emptyRDD
-      } else if (index < inputRDDs.size()) {
-        inputRDDs.get(index).rdd
+      }
+    } else {
+      if (index < inputRDDs.size()) {
+        Some(inputRDDs.get(index).rdd)
       } else {
         emptyRDD
       }
     }
-
-    Some(selectedRDD)
   }

   val asJavaDStream = JavaDStream.fromDStream(this)