From 10ba6e1111c834dd8bb0f5a6630cd9315e9f9efb Mon Sep 17 00:00:00 2001
From: Aaron Staple
Date: Sat, 13 Sep 2014 21:13:22 -0700
Subject: [PATCH 1/2] [SPARK-1087] Move python traceback utilities into new traceback_utils.py file.

---
 python/pyspark/context.py         |  5 +-
 python/pyspark/rdd.py             | 58 ++--------------------
 python/pyspark/traceback_utils.py | 80 +++++++++++++++++++++++++++++++
 3 files changed, 86 insertions(+), 57 deletions(-)
 create mode 100644 python/pyspark/traceback_utils.py

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 3ab98e262df31..277564761addd 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -33,6 +33,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
 from pyspark.rdd import RDD
+from pyspark.traceback_utils import extract_concise_traceback
 
 from py4j.java_collections import ListConverter
 
@@ -99,8 +100,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
             ...
         ValueError:...
         """
-        if rdd._extract_concise_traceback() is not None:
-            self._callsite = rdd._extract_concise_traceback()
+        if extract_concise_traceback() is not None:
+            self._callsite = extract_concise_traceback()
         else:
             tempNamedTuple = namedtuple("Callsite", "function file linenum")
             self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6ad5ab2a2d1ae..625c7108e5a69 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,13 +18,11 @@
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from collections import namedtuple
 from itertools import chain, ifilter, imap
 import operator
 import os
 import sys
 import shlex
-import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -45,6 +43,7 @@
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
     get_used_memory, ExternalSorter
+from pyspark.traceback_utils import JavaStackTrace
 
 from py4j.java_collections import ListConverter, MapConverter
 
@@ -81,57 +80,6 @@ def portable_hash(x):
     return hash(x)
 
 
-def _extract_concise_traceback():
-    """
-    This function returns the traceback info for a callsite, returns a dict
-    with function name, file name and line number
-    """
-    tb = traceback.extract_stack()
-    callsite = namedtuple("Callsite", "function file linenum")
-    if len(tb) == 0:
-        return None
-    file, line, module, what = tb[len(tb) - 1]
-    sparkpath = os.path.dirname(file)
-    first_spark_frame = len(tb) - 1
-    for i in range(0, len(tb)):
-        file, line, fun, what = tb[i]
-        if file.startswith(sparkpath):
-            first_spark_frame = i
-            break
-    if first_spark_frame == 0:
-        file, line, fun, what = tb[0]
-        return callsite(function=fun, file=file, linenum=line)
-    sfile, sline, sfun, swhat = tb[first_spark_frame]
-    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
-    return callsite(function=sfun, file=ufile, linenum=uline)
-
-_spark_stack_depth = 0
-
-
-class _JavaStackTrace(object):
-
-    def __init__(self, sc):
-        tb = _extract_concise_traceback()
-        if tb is not None:
-            self._traceback = "%s at %s:%s" % (
-                tb.function, tb.file, tb.linenum)
-        else:
-            self._traceback = "Error! Could not extract traceback info"
-        self._context = sc
-
-    def __enter__(self):
-        global _spark_stack_depth
-        if _spark_stack_depth == 0:
-            self._context._jsc.setCallSite(self._traceback)
-        _spark_stack_depth += 1
-
-    def __exit__(self, type, value, tb):
-        global _spark_stack_depth
-        _spark_stack_depth -= 1
-        if _spark_stack_depth == 0:
-            self._context._jsc.setCallSite(None)
-
-
 class BoundedFloat(float):
     """
     Bounded value is generated by approximate job, with confidence and low
@@ -704,7 +652,7 @@ def collect(self):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        with _JavaStackTrace(self.context) as st:
+        with JavaStackTrace(self.context) as st:
             bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
@@ -1515,7 +1463,7 @@ def add_shuffle_key(split, iterator):
 
         keyed = self.mapPartitionsWithIndex(add_shuffle_key)
         keyed._bypass_serializer = True
-        with _JavaStackTrace(self.context) as st:
+        with JavaStackTrace(self.context) as st:
             pairRDD = self.ctx._jvm.PairwiseRDD(
                 keyed._jrdd.rdd()).asJavaPairRDD()
             partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
diff --git a/python/pyspark/traceback_utils.py b/python/pyspark/traceback_utils.py
new file mode 100644
index 0000000000000..5ee27e3e50bfa
--- /dev/null
+++ b/python/pyspark/traceback_utils.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+import os
+import traceback
+
+
+__all__ = ["extract_concise_traceback", "SparkContext"]
+
+
+def extract_concise_traceback():
+    """
+    This function returns the traceback info for a callsite, returns a dict
+    with function name, file name and line number
+    """
+    tb = traceback.extract_stack()
+    callsite = namedtuple("Callsite", "function file linenum")
+    if len(tb) == 0:
+        return None
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return callsite(function=fun, file=file, linenum=line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
+    return callsite(function=sfun, file=ufile, linenum=uline)
+
+
+class JavaStackTrace(object):
+    """
+    Helper for setting the spark context call site.
+
+    Example usage:
+    from pyspark.context import JavaStackTrace
+    with JavaStackTrace() as st:
+
+    """
+
+    _spark_stack_depth = 0
+
+    def __init__(self, sc):
+        tb = extract_concise_traceback()
+        if tb is not None:
+            self._traceback = "%s at %s:%s" % (
+                tb.function, tb.file, tb.linenum)
+        else:
+            self._traceback = "Error! Could not extract traceback info"
+        self._context = sc
+
+    def __enter__(self):
+        if JavaStackTrace._spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        JavaStackTrace._spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        JavaStackTrace._spark_stack_depth -= 1
+        if JavaStackTrace._spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)

From 7b3bb13976371cac60abe170db808389e9ba9cbd Mon Sep 17 00:00:00 2001
From: Aaron Staple
Date: Mon, 15 Sep 2014 09:56:57 -0700
Subject: [PATCH 2/2] Address review comments, cosmetic cleanups.

---
 python/pyspark/context.py         |  9 ++------
 python/pyspark/rdd.py             |  6 ++---
 python/pyspark/traceback_utils.py | 38 +++++++++++++++----------------
 3 files changed, 23 insertions(+), 30 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 277564761addd..ddd452b49cac6 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -20,7 +20,6 @@
 import sys
 from threading import Lock
 from tempfile import NamedTemporaryFile
-from collections import namedtuple
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -33,7 +32,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
 from pyspark.rdd import RDD
-from pyspark.traceback_utils import extract_concise_traceback
+from pyspark.traceback_utils import CallSite, first_spark_call
 
 from py4j.java_collections import ListConverter
 
@@ -100,11 +99,7 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
             ...
         ValueError:...
         """
-        if extract_concise_traceback() is not None:
-            self._callsite = extract_concise_traceback()
-        else:
-            tempNamedTuple = namedtuple("Callsite", "function file linenum")
-            self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
+        self._callsite = first_spark_call() or CallSite(None, None, None)
         SparkContext._ensure_initialized(self, gateway=gateway)
         try:
             self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 625c7108e5a69..21f182b0ff137 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -43,7 +43,7 @@
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
     get_used_memory, ExternalSorter
-from pyspark.traceback_utils import JavaStackTrace
+from pyspark.traceback_utils import SCCallSiteSync
 
 from py4j.java_collections import ListConverter, MapConverter
 
@@ -652,7 +652,7 @@ def collect(self):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        with JavaStackTrace(self.context) as st:
+        with SCCallSiteSync(self.context) as css:
             bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
@@ -1463,7 +1463,7 @@ def add_shuffle_key(split, iterator):
 
         keyed = self.mapPartitionsWithIndex(add_shuffle_key)
         keyed._bypass_serializer = True
-        with JavaStackTrace(self.context) as st:
+        with SCCallSiteSync(self.context) as css:
             pairRDD = self.ctx._jvm.PairwiseRDD(
                 keyed._jrdd.rdd()).asJavaPairRDD()
             partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
diff --git a/python/pyspark/traceback_utils.py b/python/pyspark/traceback_utils.py
index 5ee27e3e50bfa..bb8646df2b0bf 100644
--- a/python/pyspark/traceback_utils.py
+++ b/python/pyspark/traceback_utils.py
@@ -20,16 +20,14 @@
 import traceback
 
 
-__all__ = ["extract_concise_traceback", "SparkContext"]
+CallSite = namedtuple("CallSite", "function file linenum")
 
 
-def extract_concise_traceback():
+def first_spark_call():
     """
-    This function returns the traceback info for a callsite, returns a dict
-    with function name, file name and line number
+    Return a CallSite representing the first Spark call in the current call stack.
     """
     tb = traceback.extract_stack()
-    callsite = namedtuple("Callsite", "function file linenum")
     if len(tb) == 0:
         return None
     file, line, module, what = tb[len(tb) - 1]
@@ -42,39 +40,39 @@ def first_spark_call():
             break
     if first_spark_frame == 0:
         file, line, fun, what = tb[0]
-        return callsite(function=fun, file=file, linenum=line)
+        return CallSite(function=fun, file=file, linenum=line)
     sfile, sline, sfun, swhat = tb[first_spark_frame]
     ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
-    return callsite(function=sfun, file=ufile, linenum=uline)
+    return CallSite(function=sfun, file=ufile, linenum=uline)
 
 
-class JavaStackTrace(object):
+class SCCallSiteSync(object):
     """
     Helper for setting the spark context call site.
 
     Example usage:
-    from pyspark.context import JavaStackTrace
-    with JavaStackTrace() as st:
+    from pyspark.context import SCCallSiteSync
+    with SCCallSiteSync() as css:
 
     """
 
     _spark_stack_depth = 0
 
    def __init__(self, sc):
-        tb = extract_concise_traceback()
-        if tb is not None:
-            self._traceback = "%s at %s:%s" % (
-                tb.function, tb.file, tb.linenum)
+        call_site = first_spark_call()
+        if call_site is not None:
+            self._call_site = "%s at %s:%s" % (
+                call_site.function, call_site.file, call_site.linenum)
         else:
-            self._traceback = "Error! Could not extract traceback info"
+            self._call_site = "Error! Could not extract traceback info"
         self._context = sc
 
     def __enter__(self):
-        if JavaStackTrace._spark_stack_depth == 0:
-            self._context._jsc.setCallSite(self._traceback)
-        JavaStackTrace._spark_stack_depth += 1
+        if SCCallSiteSync._spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._call_site)
+        SCCallSiteSync._spark_stack_depth += 1
 
     def __exit__(self, type, value, tb):
-        JavaStackTrace._spark_stack_depth -= 1
-        if JavaStackTrace._spark_stack_depth == 0:
+        SCCallSiteSync._spark_stack_depth -= 1
+        if SCCallSiteSync._spark_stack_depth == 0:
             self._context._jsc.setCallSite(None)