[SPARK-1087] Move python traceback utilities into new traceback_utils.py file. #2385

Status: Closed · wants to merge 2 commits (showing changes from 1 commit)
5 changes: 3 additions & 2 deletions python/pyspark/context.py
@@ -33,6 +33,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark import rdd
from pyspark.rdd import RDD
+from pyspark.traceback_utils import extract_concise_traceback

from py4j.java_collections import ListConverter

@@ -99,8 +100,8 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
...
ValueError:...
"""
-        if rdd._extract_concise_traceback() is not None:
-            self._callsite = rdd._extract_concise_traceback()
+        if extract_concise_traceback() is not None:
+            self._callsite = extract_concise_traceback()
         else:
Review comment (Contributor): it's better to only call extract_concise_traceback() once, such as:

    self._callsite = extract_concise_traceback()
    if self._callsite is None:
        xxxx

(A sketch of this pattern follows this hunk.)
            tempNamedTuple = namedtuple("Callsite", "function file linenum")
            self._callsite = tempNamedTuple(function=None, file=None, linenum=None)
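A minimal sketch of the single-call pattern the review comment above suggests, assuming the surrounding SparkContext.__init__ body (namedtuple is already imported in context.py, as the fallback lines in this hunk show):

    # Call extract_concise_traceback() once and reuse the result,
    # rather than invoking it twice as the diff above does.
    self._callsite = extract_concise_traceback()
    if self._callsite is None:
        # No usable traceback: fall back to an empty Callsite.
        tempNamedTuple = namedtuple("Callsite", "function file linenum")
        self._callsite = tempNamedTuple(function=None, file=None, linenum=None)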
58 changes: 3 additions & 55 deletions python/pyspark/rdd.py
@@ -18,13 +18,11 @@
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
-from collections import namedtuple
from itertools import chain, ifilter, imap
import operator
import os
import sys
import shlex
-import traceback
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
@@ -45,6 +43,7 @@
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
    get_used_memory, ExternalSorter
+from pyspark.traceback_utils import JavaStackTrace

from py4j.java_collections import ListConverter, MapConverter

@@ -81,57 +80,6 @@ def portable_hash(x):
    return hash(x)


-def _extract_concise_traceback():
-    """
-    This function returns the traceback info for a callsite, returns a dict
-    with function name, file name and line number
-    """
-    tb = traceback.extract_stack()
-    callsite = namedtuple("Callsite", "function file linenum")
-    if len(tb) == 0:
-        return None
-    file, line, module, what = tb[len(tb) - 1]
-    sparkpath = os.path.dirname(file)
-    first_spark_frame = len(tb) - 1
-    for i in range(0, len(tb)):
-        file, line, fun, what = tb[i]
-        if file.startswith(sparkpath):
-            first_spark_frame = i
-            break
-    if first_spark_frame == 0:
-        file, line, fun, what = tb[0]
-        return callsite(function=fun, file=file, linenum=line)
-    sfile, sline, sfun, swhat = tb[first_spark_frame]
-    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
-    return callsite(function=sfun, file=ufile, linenum=uline)
-
-_spark_stack_depth = 0
-
-
-class _JavaStackTrace(object):
-
-    def __init__(self, sc):
-        tb = _extract_concise_traceback()
-        if tb is not None:
-            self._traceback = "%s at %s:%s" % (
-                tb.function, tb.file, tb.linenum)
-        else:
-            self._traceback = "Error! Could not extract traceback info"
-        self._context = sc
-
-    def __enter__(self):
-        global _spark_stack_depth
-        if _spark_stack_depth == 0:
-            self._context._jsc.setCallSite(self._traceback)
-        _spark_stack_depth += 1
-
-    def __exit__(self, type, value, tb):
-        global _spark_stack_depth
-        _spark_stack_depth -= 1
-        if _spark_stack_depth == 0:
-            self._context._jsc.setCallSite(None)
-

class BoundedFloat(float):
    """
    Bounded value is generated by approximate job, with confidence and low
@@ -704,7 +652,7 @@ def collect(self):
"""
Return a list that contains all of the elements in this RDD.
"""
with _JavaStackTrace(self.context) as st:
with JavaStackTrace(self.context) as st:
bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))

@@ -1515,7 +1463,7 @@ def add_shuffle_key(split, iterator):

        keyed = self.mapPartitionsWithIndex(add_shuffle_key)
        keyed._bypass_serializer = True
-        with _JavaStackTrace(self.context) as st:
+        with JavaStackTrace(self.context) as st:
            pairRDD = self.ctx._jvm.PairwiseRDD(
                keyed._jrdd.rdd()).asJavaPairRDD()
            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
80 changes: 80 additions & 0 deletions python/pyspark/traceback_utils.py
@@ -0,0 +1,80 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import namedtuple
import os
import traceback


__all__ = ["extract_concise_traceback", "SparkContext"]
Review comment (Contributor Author): Looks like I also need to put JavaStackTrace here instead of SparkContext.

Review comment (Contributor): These are just internal interfaces, so it's fine not to list everything here. If the list is kept, though, it should name JavaStackTrace.
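For illustration, the corrected export list the reviewers are describing would presumably read:

    __all__ = ["extract_concise_traceback", "JavaStackTrace"]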


def extract_concise_traceback():
    """
    This function returns the traceback info for a callsite, as a
    Callsite namedtuple with function name, file name and line number.
    """
    tb = traceback.extract_stack()
    callsite = namedtuple("Callsite", "function file linenum")
    if len(tb) == 0:
        return None
    file, line, module, what = tb[len(tb) - 1]
    sparkpath = os.path.dirname(file)
    first_spark_frame = len(tb) - 1
    for i in range(0, len(tb)):
        file, line, fun, what = tb[i]
        if file.startswith(sparkpath):
            first_spark_frame = i
            break
    if first_spark_frame == 0:
        file, line, fun, what = tb[0]
        return callsite(function=fun, file=file, linenum=line)
    sfile, sline, sfun, swhat = tb[first_spark_frame]
    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
    return callsite(function=sfun, file=ufile, linenum=uline)


class JavaStackTrace(object):
    """
    Helper for setting the spark context call site.

    Example usage:
        from pyspark.traceback_utils import JavaStackTrace
        with JavaStackTrace(<relevant SparkContext>) as st:
            <a Spark call>
    """

    _spark_stack_depth = 0

    def __init__(self, sc):
        tb = extract_concise_traceback()
        if tb is not None:
            self._traceback = "%s at %s:%s" % (
                tb.function, tb.file, tb.linenum)
        else:
            self._traceback = "Error! Could not extract traceback info"
        self._context = sc

    def __enter__(self):
        if JavaStackTrace._spark_stack_depth == 0:
            self._context._jsc.setCallSite(self._traceback)
        JavaStackTrace._spark_stack_depth += 1

    def __exit__(self, type, value, tb):
        JavaStackTrace._spark_stack_depth -= 1
        if JavaStackTrace._spark_stack_depth == 0:
            self._context._jsc.setCallSite(None)