[SPARK-32010][PYTHON][CORE] Add InheritableThread for local properties and fixing a thread leak issue in pinned thread mode #28968

Closed
wants to merge 1 commit
8 changes: 3 additions & 5 deletions docs/job-scheduling.md
@@ -297,11 +297,9 @@ via `sc.setJobGroup` in a separate PVM thread, which also disallows cancelling the job
later.

In order to synchronize PVM threads with JVM threads, you should set the `PYSPARK_PIN_THREAD` environment variable
to `true`. This pinned thread mode allows one PVM thread to have one corresponding JVM thread.

However, currently it cannot inherit the local properties from the parent thread although it isolates
each thread with its own local properties. To work around this, you should manually copy and set the
local properties from the parent thread to the child thread when you create another thread in PVM.
to `true`. This pinned thread mode allows one PVM thread to have one corresponding JVM thread. With this mode,
using `pyspark.InheritableThread` is recommended so that a PVM thread can inherit the inheritable attributes
Member: typo: interitable -> inheritable

such as local properties in a JVM thread.

Note that `PYSPARK_PIN_THREAD` is currently experimental and not recommended for use in production.
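
As an illustration (not part of this diff), a minimal sketch of the intended usage; the property key `"mykey"` and the child job are placeholders:

```python
import os
# Must be set before the JVM gateway launches, i.e. before creating SparkContext.
os.environ["PYSPARK_PIN_THREAD"] = "true"

from pyspark import SparkContext, InheritableThread

sc = SparkContext.getOrCreate()
sc.setLocalProperty("mykey", "myvalue")  # set on the parent PVM thread

def child():
    # With InheritableThread, the child thread sees the parent's local properties.
    assert sc.getLocalProperty("mykey") == "myvalue"

t = InheritableThread(target=child)
t.start()
t.join()
```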

5 changes: 4 additions & 1 deletion python/pyspark/__init__.py
@@ -42,6 +42,8 @@
A :class:`TaskContext` that provides extra info and tooling for barrier execution.
- :class:`BarrierTaskInfo`:
Information about a barrier task.
- :class:`InheritableThread`:
An inheritable thread to use in Spark when the pinned thread mode is on.
"""

from functools import wraps
@@ -51,6 +53,7 @@
from pyspark.context import SparkContext
from pyspark.rdd import RDD, RDDBarrier
from pyspark.files import SparkFiles
from pyspark.util import InheritableThread
from pyspark.storagelevel import StorageLevel
from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
@@ -118,5 +121,5 @@ def wrapper(self, *args, **kwargs):
"SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
"Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
"StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler", "TaskContext",
"RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo",
"RDDBarrier", "BarrierTaskContext", "BarrierTaskInfo", "InheritableThread",
]
18 changes: 12 additions & 6 deletions python/pyspark/context.py
@@ -1013,8 +1013,10 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
.. note:: Currently, setting a group ID (set to local properties) with multiple threads
does not properly work. Internally threads on PVM and JVM are not synced, and JVM
thread can be reused for multiple threads on PVM, which fails to isolate local
properties for each thread on PVM. To work around this, you can use
:meth:`RDD.collectWithJobGroup` for now.
properties for each thread on PVM.

To avoid this, enable the pinned thread mode by setting the ``PYSPARK_PIN_THREAD``
environment variable to ``true`` and use :class:`pyspark.InheritableThread`.
"""
self._jsc.setJobGroup(groupId, description, interruptOnCancel)

@@ -1026,8 +1028,10 @@ def setLocalProperty(self, key, value):
.. note:: Currently, setting a local property with multiple threads does not properly work.
Internally threads on PVM and JVM are not synced, and JVM thread
can be reused for multiple threads on PVM, which fails to isolate local properties
for each thread on PVM. To work around this, you can use
:meth:`RDD.collectWithJobGroup`.
for each thread on PVM.

To avoid this, enable the pinned thread mode by setting the ``PYSPARK_PIN_THREAD``
environment variable to ``true`` and use :class:`pyspark.InheritableThread`.
"""
self._jsc.setLocalProperty(key, value)

@@ -1045,8 +1049,10 @@ def setJobDescription(self, value):
.. note:: Currently, setting a job description (set to local properties) with multiple
threads does not properly work. Internally threads on PVM and JVM are not synced,
and JVM thread can be reused for multiple threads on PVM, which fails to isolate
local properties for each thread on PVM. To work around this, you can use
:meth:`RDD.collectWithJobGroup` for now.
local properties for each thread on PVM.

To avoid this, enable the pinned thread mode by setting the ``PYSPARK_PIN_THREAD``
environment variable to ``true`` and use :class:`pyspark.InheritableThread`.
"""
self._jsc.setJobDescription(value)

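For concreteness, a hedged sketch of the pattern these notes recommend; the group ID `"group-a"`, the description, and the job body are illustrative assumptions, not part of this diff:

```python
from pyspark import SparkContext, InheritableThread

sc = SparkContext.getOrCreate()

def grouped_job():
    # Local properties set here stay isolated to this PVM thread and its
    # corresponding JVM thread; inherited ones were copied from the parent.
    sc.setJobGroup("group-a", "an interruptible job group", interruptOnCancel=True)
    sc.parallelize(range(10)).count()

t = InheritableThread(target=grouped_job)
t.start()
sc.cancelJobGroup("group-a")  # may cancel the job if it is still running
t.join()
```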
10 changes: 8 additions & 2 deletions python/pyspark/rdd.py
@@ -859,12 +859,18 @@ def collect(self):

def collectWithJobGroup(self, groupId, description, interruptOnCancel=False):
"""
.. note:: Experimental

When collecting the RDD, use this method to specify a job group.

.. note:: Deprecated in 3.1.0. Use :class:`pyspark.InheritableThread` with
the pinned thread mode enabled.

.. versionadded:: 3.0.0
"""
warnings.warn(
"Deprecated in 3.1, Use pyspark.InheritableThread with "
"the pinned thread mode enabled.",
DeprecationWarning)

with SCCallSiteSync(self.context) as css:
sock_info = self.ctx._jvm.PythonRDD.collectAndServeWithJobGroup(
self._jrdd.rdd(), groupId, description, interruptOnCancel)
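A possible migration path, sketched under the assumption that the pinned thread mode is enabled; the RDD and group details are placeholders:

```python
from pyspark import SparkContext, InheritableThread

sc = SparkContext.getOrCreate()
rdd = sc.parallelize(range(100))

# Deprecated: result = rdd.collectWithJobGroup("group-a", "my description")

# Replacement: run the collect in an InheritableThread and set the job
# group inside it; the group applies to this thread's JVM counterpart.
result = []

def collect_with_group():
    sc.setJobGroup("group-a", "my description")
    result.extend(rdd.collect())

t = InheritableThread(target=collect_with_group)
t.start()
t.join()
```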
23 changes: 22 additions & 1 deletion python/pyspark/tests/test_pin_thread.py
@@ -20,7 +20,7 @@
import threading
import unittest

from pyspark import SparkContext, SparkConf
from pyspark import SparkContext, SparkConf, InheritableThread


class PinThreadTests(unittest.TestCase):
@@ -143,6 +143,27 @@ def run_job(job_group, index):
is_job_cancelled[i],
"Thread {i}: Job in group B did not succeeded.".format(i=i))

def test_inheritable_local_property(self):
self.sc.setLocalProperty("a", "hi")
expected = []

def get_inner_local_prop():
expected.append(self.sc.getLocalProperty("b"))

def get_outer_local_prop():
expected.append(self.sc.getLocalProperty("a"))
self.sc.setLocalProperty("b", "hello")
t2 = InheritableThread(target=get_inner_local_prop)
t2.start()
t2.join()

t1 = InheritableThread(target=get_outer_local_prop)
t1.start()
t1.join()

self.assertEqual(self.sc.getLocalProperty("b"), None)
self.assertEqual(expected, ["hi", "hello"])


if __name__ == "__main__":
import unittest
61 changes: 61 additions & 0 deletions python/pyspark/util.py
@@ -16,10 +16,13 @@
# limitations under the License.
#

import threading
import re
import sys
import traceback

from py4j.clientserver import ClientServer

__all__ = []


@@ -114,6 +117,64 @@ def _parse_memory(s):
raise ValueError("invalid format: " + s)
return int(float(s[:-1]) * units[s[-1].lower()])


class InheritableThread(threading.Thread):
"""
Thread that is recommended to be used in PySpark instead of :class:`threading.Thread`
when the pinned thread mode is enabled. The usage of this class is exactly the same as
:class:`threading.Thread`, but it correctly inherits the inheritable properties specific
to JVM threads, such as ``InheritableThreadLocal``.

Also, note that the pinned thread mode does not close the connection from Python
to the JVM when a thread finishes on the Python side. With this class, Python
garbage-collects the Python thread instance and also closes the connection,
which finishes the JVM thread correctly.

When the pinned thread mode is off, this works as :class:`threading.Thread`.

.. note:: Experimental

.. versionadded:: 3.1.0
"""
def __init__(self, target, *args, **kwargs):
from pyspark import SparkContext

sc = SparkContext._active_spark_context

if isinstance(sc._gateway, ClientServer):
# This is the case where the pinned thread mode (PYSPARK_PIN_THREAD) is on.
properties = sc._jsc.sc().getLocalProperties().clone()
Member (@viirya, Aug 29, 2020): Why do we need to clone? Doesn't sc.localProperties get cloned in childValue already?

Member Author: Actually we're mimicking that behaviour here, because the thread in the JVM does not respect the inheritance: the thread is always separately created via the JVM gateway, whereas on the Scala/Java side we can keep the inheritance by creating a thread within a thread.
self._sc = sc

def copy_local_properties(*a, **k):
sc._jsc.sc().setLocalProperties(properties)
return target(*a, **k)

super(InheritableThread, self).__init__(
target=copy_local_properties, *args, **kwargs)
else:
super(InheritableThread, self).__init__(target=target, *args, **kwargs)

def __del__(self):
from pyspark import SparkContext

if isinstance(SparkContext._gateway, ClientServer):
thread_connection = self._sc._jvm._gateway_client.thread_connection.connection()
if thread_connection is not None:
connections = self._sc._jvm._gateway_client.deque

# Reuse the lock for Py4J in PySpark
with SparkContext._lock:
for i in range(len(connections)):
if connections[i] is thread_connection:
connections[i].close()
del connections[i]
break
else:
# Just in case the connection was not closed but removed from the queue.
thread_connection.close()


if __name__ == "__main__":
import doctest
(failure_count, test_count) = doctest.testmod()
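
For context on the thread-leak fix above, a small sketch (assuming the pinned thread mode is on and a SparkContext is active; the loop and `touch_jvm` are illustrative): each pinned-mode Python thread owns its own Py4J connection to the JVM, and `__del__` closes that connection when the Python thread object is garbage-collected, so finished threads no longer leave JVM threads behind.

```python
from pyspark import SparkContext, InheritableThread

sc = SparkContext.getOrCreate()

def touch_jvm():
    # Under pinned thread mode, this call runs on a dedicated JVM thread
    # reached through this Python thread's own Py4J connection.
    sc.setLocalProperty("tag", "demo")

for _ in range(100):
    t = InheritableThread(target=touch_jvm)
    t.start()
    t.join()
    del t  # drop the reference so __del__ can close this thread's JVM connection
```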