From 10ba6e1111c834dd8bb0f5a6630cd9315e9f9efb Mon Sep 17 00:00:00 2001
From: Aaron Staple
Date: Sat, 13 Sep 2014 21:13:22 -0700
Subject: [PATCH] [SPARK-1087] Move python traceback utilities into new
 traceback_utils.py file.

---
 python/pyspark/context.py         |  5 +-
 python/pyspark/rdd.py             | 58 ++--------------------
 python/pyspark/traceback_utils.py | 80 +++++++++++++++++++++++++++++++
 3 files changed, 86 insertions(+), 57 deletions(-)
 create mode 100644 python/pyspark/traceback_utils.py

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 3ab98e262df31..277564761addd 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -33,6 +33,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark import rdd
 from pyspark.rdd import RDD
+from pyspark.traceback_utils import extract_concise_traceback
 
 from py4j.java_collections import ListConverter
 
@@ -99,8 +100,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
             ...
         ValueError:...
         """
-        if rdd._extract_concise_traceback() is not None:
-            self._callsite = rdd._extract_concise_traceback()
+        if extract_concise_traceback() is not None:
+            self._callsite = extract_concise_traceback()
         else:
             tempNamedTuple = namedtuple("Callsite", "function file linenum")
             self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 6ad5ab2a2d1ae..625c7108e5a69 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,13 +18,11 @@
 from base64 import standard_b64encode as b64enc
 import copy
 from collections import defaultdict
-from collections import namedtuple
 from itertools import chain, ifilter, imap
 import operator
 import os
 import sys
 import shlex
-import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -45,6 +43,7 @@
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
     get_used_memory, ExternalSorter
+from pyspark.traceback_utils import JavaStackTrace
 
 from py4j.java_collections import ListConverter, MapConverter
 
@@ -81,57 +80,6 @@ def portable_hash(x):
     return hash(x)
 
 
-def _extract_concise_traceback():
-    """
-    This function returns the traceback info for a callsite, returns a dict
-    with function name, file name and line number
-    """
-    tb = traceback.extract_stack()
-    callsite = namedtuple("Callsite", "function file linenum")
-    if len(tb) == 0:
-        return None
-    file, line, module, what = tb[len(tb) - 1]
-    sparkpath = os.path.dirname(file)
-    first_spark_frame = len(tb) - 1
-    for i in range(0, len(tb)):
-        file, line, fun, what = tb[i]
-        if file.startswith(sparkpath):
-            first_spark_frame = i
-            break
-    if first_spark_frame == 0:
-        file, line, fun, what = tb[0]
-        return callsite(function=fun, file=file, linenum=line)
-    sfile, sline, sfun, swhat = tb[first_spark_frame]
-    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
-    return callsite(function=sfun, file=ufile, linenum=uline)
-
-_spark_stack_depth = 0
-
-
-class _JavaStackTrace(object):
-
-    def __init__(self, sc):
-        tb = _extract_concise_traceback()
-        if tb is not None:
-            self._traceback = "%s at %s:%s" % (
-                tb.function, tb.file, tb.linenum)
-        else:
-            self._traceback = "Error! Could not extract traceback info"
-        self._context = sc
-
-    def __enter__(self):
-        global _spark_stack_depth
-        if _spark_stack_depth == 0:
-            self._context._jsc.setCallSite(self._traceback)
-        _spark_stack_depth += 1
-
-    def __exit__(self, type, value, tb):
-        global _spark_stack_depth
-        _spark_stack_depth -= 1
-        if _spark_stack_depth == 0:
-            self._context._jsc.setCallSite(None)
-
-
 class BoundedFloat(float):
     """
     Bounded value is generated by approximate job, with confidence and low
@@ -704,7 +652,7 @@ def collect(self):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        with _JavaStackTrace(self.context) as st:
+        with JavaStackTrace(self.context) as st:
             bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
@@ -1515,7 +1463,7 @@ def add_shuffle_key(split, iterator):
 
         keyed = self.mapPartitionsWithIndex(add_shuffle_key)
         keyed._bypass_serializer = True
-        with _JavaStackTrace(self.context) as st:
+        with JavaStackTrace(self.context) as st:
             pairRDD = self.ctx._jvm.PairwiseRDD(
                 keyed._jrdd.rdd()).asJavaPairRDD()
             partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
diff --git a/python/pyspark/traceback_utils.py b/python/pyspark/traceback_utils.py
new file mode 100644
index 0000000000000..5ee27e3e50bfa
--- /dev/null
+++ b/python/pyspark/traceback_utils.py
@@ -0,0 +1,80 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+import os
+import traceback
+
+
+__all__ = ["extract_concise_traceback", "JavaStackTrace"]
+
+
+def extract_concise_traceback():
+    """
+    This function returns the traceback info for a callsite as a namedtuple
+    with function name, file name and line number, or None if unavailable.
+    """
+    tb = traceback.extract_stack()
+    callsite = namedtuple("Callsite", "function file linenum")
+    if len(tb) == 0:
+        return None
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return callsite(function=fun, file=file, linenum=line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
+    return callsite(function=sfun, file=ufile, linenum=uline)
+
+
+class JavaStackTrace(object):
+    """
+    Helper for setting the spark context call site.
+
+    Example usage:
+        from pyspark.traceback_utils import JavaStackTrace
+        with JavaStackTrace(sc) as st:
+            ...
+    """
+
+    _spark_stack_depth = 0
+
+    def __init__(self, sc):
+        tb = extract_concise_traceback()
+        if tb is not None:
+            self._traceback = "%s at %s:%s" % (
+                tb.function, tb.file, tb.linenum)
+        else:
+            self._traceback = "Error! Could not extract traceback info"
+        self._context = sc
+
+    def __enter__(self):
+        if JavaStackTrace._spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        JavaStackTrace._spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        JavaStackTrace._spark_stack_depth -= 1
+        if JavaStackTrace._spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)
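
Illustrative usage sketch of the refactored helpers: `JavaStackTrace` keeps a
class-level depth counter, so nested `with` blocks set the JVM call site only
at the outermost entry and clear it only at the outermost exit. This snippet
assumes a local SparkContext; `sc` and `rdd` are stand-in names for this
example, not identifiers introduced by the patch.

    from pyspark import SparkContext
    from pyspark.traceback_utils import JavaStackTrace, extract_concise_traceback

    sc = SparkContext("local", "callsite-demo")
    rdd = sc.parallelize([1, 2, 3])

    # extract_concise_traceback() walks the current stack and reports the
    # user frame that called into Spark code (None if the stack is empty).
    site = extract_concise_traceback()
    if site is not None:
        print("%s at %s:%s" % (site.function, site.file, site.linenum))

    # Nested blocks share the one counter, so only the outermost block
    # talks to the JVM via setCallSite.
    with JavaStackTrace(sc):        # depth 0 -> 1: setCallSite(traceback)
        with JavaStackTrace(sc):    # depth 1 -> 2: no JVM call
            rdd.collect()           # Java-side events carry the outer call site
        # depth 2 -> 1: call site unchanged
    # depth 1 -> 0: setCallSite(None)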