From f0ea311e444655d5e13a72b85004683452919267 Mon Sep 17 00:00:00 2001 From: giwa Date: Wed, 20 Aug 2014 17:09:34 -0700 Subject: [PATCH] clean up code --- python/pyspark/streaming/context.py | 11 ++--- python/pyspark/streaming/dstream.py | 32 ++++++++----- python/pyspark/streaming/pyprint.py | 54 --------------------- python/pyspark/streaming_tests.py | 74 ++++++++++++++++------------- 4 files changed, 63 insertions(+), 108 deletions(-) delete mode 100644 python/pyspark/streaming/pyprint.py diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index f7e356319ecac..dbb6fdf1694ad 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -72,7 +72,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, # Callback sever is need only by SparkStreming; therefore the callback sever # is started in StreamingContext. SparkContext._gateway.restart_callback_server() - self._clean_up_trigger() + self._set_clean_up_trigger() self._jvm = self._sc._jvm self._jssc = self._initialize_context(self._sc._jsc, duration._jduration) @@ -80,13 +80,11 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, def _initialize_context(self, jspark_context, jduration): return self._jvm.JavaStreamingContext(jspark_context, jduration) - def _clean_up_trigger(self): + def _set_clean_up_trigger(self): """Kill py4j callback server properly using signal lib""" def clean_up_handler(*args): # Make sure stop callback server. - # This need improvement how to terminate callback sever properly. - SparkContext._gateway._shutdown_callback_server() SparkContext._gateway.shutdown() sys.exit(0) @@ -132,18 +130,15 @@ def stop(self, stopSparkContext=True, stopGraceFully=False): Stop the execution of the streams immediately (does not wait for all received data to be processed). """ - try: self._jssc.stop(stopSparkContext, stopGraceFully) finally: - # Stop Callback server - SparkContext._gateway._shutdown_callback_server() SparkContext._gateway.shutdown() def _testInputStream(self, test_inputs, numSlices=None): """ This function is only for unittest. - It requires a sequence as input, and returns the i_th element at the i_th batch + It requires a list as input, and returns the i_th element at the i_th batch under manual clock. """ test_rdds = list() diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index caf4378a9b1b9..fc15309679c2a 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -201,7 +201,7 @@ def _defaultReducePartitions(self): """ Returns the default number of partitions to use during reduce tasks (e.g., groupBy). If spark.default.parallelism is set, then we'll use the value from SparkContext - defaultParallelism, otherwise we'll use the number of partitions in this RDD. + defaultParallelism, otherwise we'll use the number of partitions in this RDD This mirrors the behavior of the Scala Partitioner#defaultPartitioner, intended to reduce the likelihood of OOMs. Once PySpark adopts Partitioner-based APIs, this behavior will @@ -216,7 +216,8 @@ def getNumPartitions(self): """ Return the number of partitions in RDD """ - # TODO: remove hardcoding. RDD has NumPartitions but DStream does not have. + # TODO: remove hardcoding. RDD has NumPartitions. How do we get the number of partition + # through DStream? 
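Note on the context.py change above: the `_set_clean_up_trigger` pattern can be sketched as a standalone helper. This is only an illustration of the idea, not the PySpark implementation; `install_cleanup_trigger` is a hypothetical name and `gateway` is assumed to be a `py4j.java_gateway.JavaGateway`.

import signal
import sys


def install_cleanup_trigger(gateway):
    """Shut the py4j gateway (and its callback server) down cleanly on termination signals."""
    def clean_up_handler(*args):
        # Make sure the callback server owned by the gateway is stopped before exiting.
        gateway.shutdown()
        sys.exit(0)

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, clean_up_handler)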
return 2 def foreachRDD(self, func): @@ -236,6 +237,10 @@ def pyprint(self): operator, so this DStream will be registered as an output stream and there materialized. """ def takeAndPrint(rdd, time): + """ + Closure to take element from RDD and print first 10 elements. + This closure is called by py4j callback server. + """ taken = rdd.take(11) print "-------------------------------------------" print "Time: %s" % (str(time)) @@ -300,17 +305,11 @@ def checkpoint(self, interval): Mark this DStream for checkpointing. It will be saved to a file inside the checkpoint directory set with L{SparkContext.setCheckpointDir()} - I am not sure this part in DStream - and - all references to its parent RDDs will be removed. This function must - be called before any job has been executed on this RDD. It is strongly - recommended that this RDD is persisted in memory, otherwise saving it - on a file will require recomputation. - - interval must be pysprak.streaming.duration + @param interval: Time interval after which generated RDD will be checkpointed + interval has to be pyspark.streaming.duration.Duration """ self.is_checkpointed = True - self._jdstream.checkpoint(interval) + self._jdstream.checkpoint(interval._jduration) return self def groupByKey(self, numPartitions=None): @@ -363,6 +362,10 @@ def saveAsTextFiles(self, prefix, suffix=None): """ def saveAsTextFile(rdd, time): + """ + Closure to save element in RDD in DStream as Pickled data in file. + This closure is called by py4j callback server. + """ path = rddToFileName(prefix, suffix, time) rdd.saveAsTextFile(path) @@ -376,6 +379,10 @@ def saveAsPickleFiles(self, prefix, suffix=None): """ def saveAsPickleFile(rdd, time): + """ + Closure to save element in RDD in the DStream as Pickled data in file. + This closure is called by py4j callback server. + """ path = rddToFileName(prefix, suffix, time) rdd.saveAsPickleFile(path) @@ -404,9 +411,10 @@ def get_output(rdd, time): # TODO: implement countByWindow # TODO: implement reduceByWindow -# Following operation has dependency to transform +# transform Operation # TODO: implement transform # TODO: implement transformWith +# Following operation has dependency with transform # TODO: implement union # TODO: implement repertitions # TODO: implement cogroup diff --git a/python/pyspark/streaming/pyprint.py b/python/pyspark/streaming/pyprint.py deleted file mode 100644 index 49517b3e5c247..0000000000000 --- a/python/pyspark/streaming/pyprint.py +++ /dev/null @@ -1,54 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
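Aside on the `takeAndPrint` closure registered by `DStream.pyprint()` above: it takes 11 elements so it can print the first 10 and mark truncation with "...". A standalone sketch of that pattern (Python 2, as in the surrounding code; `take_and_print` and `num` are illustrative names, not part of this patch):

def take_and_print(rdd, time, num=10):
    # Take one element more than we print so we know whether an ellipsis is needed.
    taken = rdd.take(num + 1)
    print "-------------------------------------------"
    print "Time: %s" % str(time)
    print "-------------------------------------------"
    for record in taken[:num]:
        print record
    if len(taken) > num:
        print "..."
    print ""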
-# - - -import sys -from itertools import chain - -from pyspark.serializers import PickleSerializer - - -def collect(binary_file_path): - """ - Read pickled file written by SparkStreaming - """ - dse = PickleSerializer() - with open(binary_file_path, 'rb') as tempFile: - for item in dse.load_stream(tempFile): - yield item - - -def main(): - try: - binary_file_path = sys.argv[1] - except: - print "Missed FilePath in argements" - - if not binary_file_path: - return - - counter = 0 - for rdd in chain.from_iterable(collect(binary_file_path)): - print rdd - counter = counter + 1 - if counter >= 10: - print "..." - break - - -if __name__ =="__main__": - exit(main()) diff --git a/python/pyspark/streaming_tests.py b/python/pyspark/streaming_tests.py index f2ef45ab23ccc..ba6c028f1fb55 100644 --- a/python/pyspark/streaming_tests.py +++ b/python/pyspark/streaming_tests.py @@ -64,7 +64,7 @@ class TestBasicOperationsSuite(PySparkStreamingTestCase): If the number of input element is over 3, that DStream use batach deserializer. If not, that DStream use unbatch deserializer. - All tests input should have list of lists. This list represents stream. + All tests input should have list of lists(3 lists are default). This list represents stream. Every batch interval, the first object of list are chosen to make DStream. e.g The first list in the list is input of the first batch. Please see the BasicTestSuits in Scala which is close to this implementation. @@ -82,7 +82,7 @@ def tearDownClass(cls): PySparkStreamingTestCase.tearDownClass() def test_map_batch(self): - """Basic operation test for DStream.map with batch deserializer""" + """Basic operation test for DStream.map with batch deserializer.""" test_input = [range(1, 5), range(5, 9), range(9, 13)] def test_func(dstream): @@ -92,7 +92,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_map_unbatach(self): - """Basic operation test for DStream.map with unbatch deserializer""" + """Basic operation test for DStream.map with unbatch deserializer.""" test_input = [range(1, 4), range(4, 7), range(7, 10)] def test_func(dstream): @@ -102,7 +102,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_flatMap_batch(self): - """Basic operation test for DStream.faltMap with batch deserializer""" + """Basic operation test for DStream.faltMap with batch deserializer.""" test_input = [range(1, 5), range(5, 9), range(9, 13)] def test_func(dstream): @@ -113,7 +113,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_flatMap_unbatch(self): - """Basic operation test for DStream.faltMap with unbatch deserializer""" + """Basic operation test for DStream.faltMap with unbatch deserializer.""" test_input = [range(1, 4), range(4, 7), range(7, 10)] def test_func(dstream): @@ -124,7 +124,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_filter_batch(self): - """Basic operation test for DStream.filter with batch deserializer""" + """Basic operation test for DStream.filter with batch deserializer.""" test_input = [range(1, 5), range(5, 9), range(9, 13)] def test_func(dstream): @@ -134,7 +134,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_filter_unbatch(self): - """Basic operation test for DStream.filter with unbatch deserializer""" + """Basic operation test for DStream.filter with unbatch deserializer.""" test_input = [range(1, 4), range(4, 7), range(7, 10)] def test_func(dstream): @@ -144,7 +144,7 @@ def test_func(dstream): 
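For reference, the expected outputs in the map/flatMap/filter tests above are just the tested function applied to each input batch independently. A pure-Python sketch (no Spark required; Python 2 semantics, where `range`, `map`, and `filter` all return lists):

test_input = [range(1, 5), range(5, 9), range(9, 13)]

# DStream.map(lambda x: str(x)), batch by batch
expected_map = [map(str, batch) for batch in test_input]
# DStream.flatMap(lambda x: (x, x * 2)), batch by batch
expected_flatmap = [[y for x in batch for y in (x, x * 2)] for batch in test_input]
# DStream.filter(lambda x: x % 2 == 0), batch by batch
expected_filter = [[x for x in batch if x % 2 == 0] for batch in test_input]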
self.assertEqual(expected_output, output) def test_count_batch(self): - """Basic operation test for DStream.count with batch deserializer""" + """Basic operation test for DStream.count with batch deserializer.""" test_input = [range(1, 5), range(1, 10), range(1, 20)] def test_func(dstream): @@ -154,7 +154,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_count_unbatch(self): - """Basic operation test for DStream.count with unbatch deserializer""" + """Basic operation test for DStream.count with unbatch deserializer.""" test_input = [[], [1], range(1, 3), range(1, 4)] def test_func(dstream): @@ -164,7 +164,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_reduce_batch(self): - """Basic operation test for DStream.reduce with batch deserializer""" + """Basic operation test for DStream.reduce with batch deserializer.""" test_input = [range(1, 5), range(5, 9), range(9, 13)] def test_func(dstream): @@ -174,7 +174,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_reduce_unbatch(self): - """Basic operation test for DStream.reduce with unbatch deserializer""" + """Basic operation test for DStream.reduce with unbatch deserializer.""" test_input = [[1], range(1, 3), range(1, 4)] def test_func(dstream): @@ -184,7 +184,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_reduceByKey_batch(self): - """Basic operation test for DStream.reduceByKey with batch deserializer""" + """Basic operation test for DStream.reduceByKey with batch deserializer.""" test_input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)], [("", 1),("", 1), ("", 1), ("", 1)], [(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]] @@ -198,7 +198,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_reduceByKey_unbatch(self): - """Basic operation test for DStream.reduceByKey with unbatch deserilizer""" + """Basic operation test for DStream.reduceByKey with unbatch deserializer.""" test_input = [[("a", 1), ("a", 1), ("b", 1)], [("", 1), ("", 1)], []] def test_func(dstream): @@ -210,44 +210,49 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_mapValues_batch(self): - """Basic operation test for DStream.mapValues with batch deserializer""" + """Basic operation test for DStream.mapValues with batch deserializer.""" test_input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], - [("", 4), (1, 1), (2, 2), (3, 3)]] + [("", 4), (1, 1), (2, 2), (3, 3)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] def test_func(dstream): return dstream.mapValues(lambda x: x + 10) expected_output = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)], - [("", 14), (1, 11), (2, 12), (3, 13)]] + [("", 14), (1, 11), (2, 12), (3, 13)], + [(1, 11), (2, 11), (3, 11), (4, 11)]] output = self._run_stream(test_input, test_func, expected_output) for result in (output, expected_output): self._sort_result_based_on_key(result) self.assertEqual(expected_output, output) def test_mapValues_unbatch(self): - """Basic operation test for DStream.mapValues with unbatch deserializer""" - test_input = [[("a", 2), ("b", 1)], [("", 2)], []] + """Basic operation test for DStream.mapValues with unbatch deserializer.""" + test_input = [[("a", 2), ("b", 1)], [("", 2)], [], [(1, 1), (2, 2)]] def test_func(dstream): return dstream.mapValues(lambda x: x + 10) - expected_output = [[("a", 12), ("b", 11)], [("", 12)], []] + expected_output = [[("a", 12), ("b", 11)], [("", 12)], [], [(1, 11), (2, 12)]] output = self._run_stream(test_input, test_func, expected_output) + 
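The sort step just below exists because key/value results can come back in any order across partitions, so each output batch is sorted by key before comparing with the expected output. The body of `_sort_result_based_on_key` is not shown in this patch; a plausible standalone sketch:

def sort_result_based_on_key(outputs):
    # Sort each batch in place by key so the comparison ignores partition ordering.
    for batch in outputs:
        batch.sort(key=lambda kv: kv[0])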
for result in (output, expected_output): + self._sort_result_based_on_key(result) self.assertEqual(expected_output, output) def test_flatMapValues_batch(self): - """Basic operation test for DStream.flatMapValues with batch deserializer""" - test_input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], [("", 4), (1, 1), (2, 1), (3, 1)]] + """Basic operation test for DStream.flatMapValues with batch deserializer.""" + test_input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)], + [("", 4), (1, 1), (2, 1), (3, 1)], + [(1, 1), (2, 1), (3, 1), (4, 1)]] def test_func(dstream): return dstream.flatMapValues(lambda x: (x, x + 10)) - expected_output = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), - ("c", 1), ("c", 11), ("d", 1), ("d", 11)], - [("", 4), ("", 14), (1, 1), (1, 11), - (2, 1), (2, 11), (3, 1), (3, 11)]] + expected_output = [[("a", 2), ("a", 12), ("b", 2), ("b", 12), ("c", 1), ("c", 11), ("d", 1), ("d", 11)], + [("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)], + [(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]] output = self._run_stream(test_input, test_func, expected_output) self.assertEqual(expected_output, output) def test_flatMapValues_unbatch(self): - """Basic operation test for DStream.flatMapValues with unbatch deserializer""" + """Basic operation test for DStream.flatMapValues with unbatch deserializer.""" test_input = [[("a", 2), ("b", 1)], [("", 2)], []] def test_func(dstream): @@ -257,7 +262,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_glom_batch(self): - """Basic operation test for DStream.glom with batch deserializer""" + """Basic operation test for DStream.glom with batch deserializer.""" test_input = [range(1, 5), range(5, 9), range(9, 13)] numSlices = 2 @@ -268,7 +273,7 @@ def test_func(dstream): self.assertEqual(expected_output, output) def test_glom_unbatach(self): - """Basic operation test for DStream.glom with unbatch deserialiser""" + """Basic operation test for DStream.glom with unbatch deserializer.""" test_input = [range(1, 4), range(4, 7), range(7, 10)] numSlices = 2 @@ -385,7 +390,7 @@ def add(a, b): return a + str(b) def test_combineByKey_unbatch(self): """Basic operation test for DStream.combineByKey with unbatch deserializer.""" - test_input = [[(1, 1), (2, 1), (3 ,1)], [(1, 1), (1, 1), ("", 1)], [("a", 1), ("a", 1), ("b", 1)]] + test_input = [[(1, 1), (2, 1), (3, 1)], [(1, 1), (1, 1), ("", 1)], [("a", 1), ("a", 1), ("b", 1)]] def test_func(dstream): def add(a, b): return a + str(b) @@ -414,8 +419,8 @@ def _run_stream(self, test_input, test_func, expected_output, numSlices=None): """ Start stream and return the result. @param test_input: dataset for the test. This should be list of lists. - @param test_func: wrapped test_function. This function should return PythonDstream object. - @param expexted_output: expected output for this testcase. + @param test_func: wrapped test_function. This function should return PythonDStream object. + @param expected_output: expected output for this testcase. @param numSlices: the number of slices in the rdd in the dstream. """ # Generate input stream with user-defined input. @@ -436,21 +441,22 @@ def _run_stream(self, test_input, test_func, expected_output, numSlices=None): if (current_time - start_time) > self.timeout: break # StreamingContext.awaitTermination is not used to wait because - # if py4j server is called every 50 milliseconds, it gets an error + # if py4j server is called every 50 milliseconds, it gets an error. 
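The comment above explains why the harness polls instead of calling StreamingContext.awaitTermination; the loop that follows amounts to a bounded poll. Sketched as a standalone helper (`wait_for` is a hypothetical name, not part of this patch):

import time


def wait_for(condition, timeout, interval=0.05):
    # Poll condition() every `interval` seconds until it holds or `timeout` seconds elapse.
    start = time.time()
    while time.time() - start < timeout:
        if condition():
            return True
        time.sleep(interval)
    return False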
time.sleep(0.05) - # Check if the output is the same length of expexted output. + # Check if the output is the same length of expected output. if len(expected_output) == len(result): break return result +#TODO: add testcase for saveAs* + class TestSaveAsFilesSuite(PySparkStreamingTestCase): def setUp(self): PySparkStreamingTestCase.setUp(self) self.timeout = 10 # seconds self.numInputPartitions = 2 - self.result = list() def tearDown(self): PySparkStreamingTestCase.tearDown(self)
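One possible shape for the saveAs* test case flagged in the TODO above, offered only as a sketch: run `saveAsTextFiles` against a temporary prefix and assert that per-batch output directories (whose names start with that prefix) appear on disk. The helper and the commented usage below are assumptions, not part of this patch.

import glob
import os


def saved_output_dirs(prefix):
    # Each completed batch should leave an output directory whose name starts with the prefix.
    return sorted(d for d in glob.glob(prefix + "-*") if os.path.isdir(d))

# Inside a future TestSaveAsFilesSuite case, roughly:
#     prefix = os.path.join(self.tempdir, "result")   # self.tempdir is hypothetical
#     test_func = lambda dstream: dstream.saveAsTextFiles(prefix)
#     ... run the stream as in _run_stream, then ...
#     self.assertTrue(len(saved_output_dirs(prefix)) > 0)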