[SPARK-22340][PYTHON] Add a mode to pin Python thread into JVM's
## What changes were proposed in this pull request?

This PR proposes to add the **single threading model design (pinned thread model)**, an experimental mode that syncs threads between the PVM and JVM. See https://www.py4j.org/advanced_topics.html#using-single-threading-model-pinned-thread

### Multi threading model

Currently, PySpark uses this model. Threads on the PVM and JVM are independent. For instance, callbacks are received and the relevant Python code is executed in a separate Python thread. JVM threads are reused when possible.

Py4J creates a new thread every time a command is received and no thread is available. See the model we currently use: https://www.py4j.org/advanced_topics.html#the-multi-threading-model

One problem with this model is that threads on the PVM and JVM cannot be synced out of the box. This causes issues, in particular in code paths that depend on threading on the JVM side. See:
https://github.com/apache/spark/blob/7056e004ee566fabbb9b22ddee2de55ef03260db/core/src/main/scala/org/apache/spark/SparkContext.scala#L334
Because JVM threads are reused, job groups set in Python threads cannot be isolated per thread, as described in the JIRA.
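For illustration, here is a minimal sketch of how the reuse bites — assuming a running `pyspark` shell with `sc` in scope; the group names and job are made up:

```python
import threading

def run_job(group):
    # Each PVM thread sets its own job group...
    sc.setJobGroup(group, "job for " + group)
    sc.parallelize(range(4)).count()
    # ...but Py4J may serve many PVM threads from one reused JVM thread,
    # so the group (a JVM thread-local property) can be overwritten by
    # another thread, making sc.cancelJobGroup(group) unreliable.

threads = [threading.Thread(target=run_job, args=("group-%d" % i,)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```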

### Single threading model design (pinned thread model)

This mode pins and syncs the threads on the PVM and JVM to work around the problem above. For instance, callbacks are received and the relevant Python code is executed in the same Python thread. See https://www.py4j.org/advanced_topics.html#the-single-threading-model

Even though this mode can sync threads on the PVM and JVM for other thread-related code paths, it seems to introduce another problem: local properties are not inherited by child threads, as shown below. (Since multi-threading mode still creates new threads when existing threads are busy, I suspect this issue already exists there when multiple jobs are submitted concurrently; in single-threading mode, however, it can always be reproduced.)

```bash
$ PYSPARK_PIN_THREAD=true ./bin/pyspark
```

```python
import threading

spark.sparkContext.setLocalProperty("a", "hi")
def print_prop():
    print(spark.sparkContext.getLocalProperty("a"))

threading.Thread(target=print_prop).start()
```

```
None
```

Unlike the Scala side:

```scala
spark.sparkContext.setLocalProperty("a", "hi")
new Thread(new Runnable {
  def run() = println(spark.sparkContext.getLocalProperty("a"))
}).start()
```

```
hi
```

This behaviour could potentially cause subtle issues, but this PR does not target a fix for it for now since the mode is experimental.
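Until that is addressed, the properties can be copied by hand — a minimal sketch, assuming a `pyspark` shell with `spark` in scope and an illustrative key `"a"`:

```python
import threading

sc = spark.sparkContext
sc.setLocalProperty("a", "hi")

# Snapshot the properties you care about while still in the parent thread...
inherited = {key: sc.getLocalProperty(key) for key in ["a"]}

def child():
    # ...and re-set them at the top of the child thread.
    for key, value in inherited.items():
        if value is not None:
            sc.setLocalProperty(key, value)
    print(sc.getLocalProperty("a"))  # prints "hi" instead of None

threading.Thread(target=child).start()
```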

### How does this PR fix it?

Basically, there are two types of Py4J servers: `GatewayServer` and `ClientServer`. The former is for multi-threading and the latter for single-threading. This PR adds a switch to use the latter.

On the Scala side:
The logic for selecting a server is encapsulated in `Py4JServer`, which is used in `PythonRunner` for Spark submit and in `PythonGatewayServer` for the Spark shell. Each uses `ClientServer` when `PYSPARK_PIN_THREAD` is `true` and `GatewayServer` otherwise.

On the Python side:
A simple if-else switches which server to talk to: `ClientServer` when `PYSPARK_PIN_THREAD` is `true` and `GatewayServer` otherwise.

This is disabled by default for now.
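With the mode enabled, per-thread job groups behave as intended — a sketch, assuming a shell started with `PYSPARK_PIN_THREAD=true` and `sc` in scope:

```python
import threading
import time

def long_job():
    # In pinned thread mode this PVM thread has its own JVM thread,
    # so the group applies only to jobs submitted from this thread.
    sc.setJobGroup("cancel-me", "a long job to cancel")
    try:
        sc.parallelize(range(4), 4).map(lambda x: time.sleep(30) or x).count()
    except Exception:
        pass  # the cancellation surfaces here as a Py4J error

t = threading.Thread(target=long_job)
t.start()
time.sleep(5)                   # give the job time to get scheduled
sc.cancelJobGroup("cancel-me")  # cancels only that thread's job
t.join()
```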

## How was this patch tested?

Manually tested. This can be tested via:

```bash
PYSPARK_PIN_THREAD=true ./bin/pyspark
```

and/or

```bash
cd python
./run-tests --python-executables=python --testnames "pyspark.tests.test_pin_thread"
```

Also ran the Jenkins tests with `PYSPARK_PIN_THREAD` enabled.

Closes #24898 from HyukjinKwon/pinned-thread.

Authored-by: HyukjinKwon <gurwls223@apache.org>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
HyukjinKwon committed Nov 7, 2019
1 parent da848b1 commit 4ec04e5
Showing 11 changed files with 418 additions and 96 deletions.
69 changes: 69 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/Py4JServer.scala
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.python

import java.net.InetAddress
import java.util.Locale

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

/**
* A wrapper for both GatewayServer and ClientServer to pin Python threads to JVM threads.
*/
private[spark] class Py4JServer(sparkConf: SparkConf) extends Logging {
private[spark] val secret: String = Utils.createSecret(sparkConf)

// Launch a Py4J gateway or client server for the process to connect to; this will let it see our
// Java system properties and such
private val localhost = InetAddress.getLoopbackAddress()
private[spark] val server = if (sys.env.getOrElse(
"PYSPARK_PIN_THREAD", "false").toLowerCase(Locale.ROOT) == "true") {
new py4j.ClientServer.ClientServerBuilder()
.authToken(secret)
.javaPort(0)
.javaAddress(localhost)
.build()
} else {
new py4j.GatewayServer.GatewayServerBuilder()
.authToken(secret)
.javaPort(0)
.javaAddress(localhost)
.callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
.build()
}

def start(): Unit = server match {
case clientServer: py4j.ClientServer => clientServer.startServer()
case gatewayServer: py4j.GatewayServer => gatewayServer.start()
case other => throw new RuntimeException(s"Unexpected Py4J server ${other.getClass}")
}

def getListeningPort: Int = server match {
case clientServer: py4j.ClientServer => clientServer.getJavaServer.getListeningPort
case gatewayServer: py4j.GatewayServer => gatewayServer.getListeningPort
case other => throw new RuntimeException(s"Unexpected Py4J server ${other.getClass}")
}

def shutdown(): Unit = server match {
case clientServer: py4j.ClientServer => clientServer.shutdown()
case gatewayServer: py4j.GatewayServer => gatewayServer.shutdown()
case other => throw new RuntimeException(s"Unexpected Py4J server ${other.getClass}")
}
}
core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
@@ -18,42 +18,28 @@
package org.apache.spark.api.python

import java.io.{DataOutputStream, File, FileOutputStream}
import java.net.InetAddress
import java.nio.charset.StandardCharsets.UTF_8
import java.nio.file.Files

import py4j.GatewayServer

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils

/**
* Process that starts a Py4J GatewayServer on an ephemeral port.
* Process that starts a Py4J server on an ephemeral port.
*
* This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
*/
private[spark] object PythonGatewayServer extends Logging {
initializeLogIfNecessary(true)

def main(args: Array[String]): Unit = {
val secret = Utils.createSecret(new SparkConf())

// Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
// with the same secret, in case the app needs callbacks from the JVM to the underlying
// python processes.
val localhost = InetAddress.getLoopbackAddress()
val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
.authToken(secret)
.javaPort(0)
.javaAddress(localhost)
.callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
.build()
val sparkConf = new SparkConf()
val gatewayServer: Py4JServer = new Py4JServer(sparkConf)

gatewayServer.start()
val boundPort: Int = gatewayServer.getListeningPort
if (boundPort == -1) {
logError("GatewayServer failed to bind; exiting")
logError(s"${gatewayServer.server.getClass} failed to bind; exiting")
System.exit(1)
} else {
logDebug(s"Started PythonGatewayServer on port $boundPort")
@@ -68,7 +54,7 @@ private[spark] object PythonGatewayServer extends Logging {
val dos = new DataOutputStream(new FileOutputStream(tmpPath))
dos.writeInt(boundPort)

val secretBytes = secret.getBytes(UTF_8)
val secretBytes = gatewayServer.secret.getBytes(UTF_8)
dos.writeInt(secretBytes.length)
dos.write(secretBytes, 0, secretBytes.length)
dos.close()
18 changes: 5 additions & 13 deletions core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -18,15 +18,15 @@
package org.apache.spark.deploy

import java.io.File
import java.net.{InetAddress, URI}
import java.net.URI
import java.nio.file.Files

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.util.Try

import org.apache.spark.{SparkConf, SparkUserAppException}
import org.apache.spark.api.python.PythonUtils
import org.apache.spark.api.python.{Py4JServer, PythonUtils}
import org.apache.spark.internal.config._
import org.apache.spark.util.{RedirectThread, Utils}

@@ -40,7 +40,6 @@ object PythonRunner {
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)
val sparkConf = new SparkConf()
val secret = Utils.createSecret(sparkConf)
val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
.orElse(sparkConf.get(PYSPARK_PYTHON))
.orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
@@ -51,15 +50,8 @@
val formattedPythonFile = formatPath(pythonFile)
val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles))

// Launch a Py4J gateway server for the process to connect to; this will let it see our
// Java system properties and such
val localhost = InetAddress.getLoopbackAddress()
val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder()
.authToken(secret)
.javaPort(0)
.javaAddress(localhost)
.callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
.build()
val gatewayServer = new Py4JServer(sparkConf)

val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.start() })
thread.setName("py4j-gateway-init")
thread.setDaemon(true)
@@ -86,7 +78,7 @@
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
env.put("PYSPARK_GATEWAY_SECRET", secret)
env.put("PYSPARK_GATEWAY_SECRET", gatewayServer.secret)
// pass conf spark.pyspark.python to python process, the only way to pass info to
// python process is through environment variable.
sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
18 changes: 18 additions & 0 deletions docs/job-scheduling.md
@@ -287,3 +287,21 @@ users can set the `spark.sql.thriftserver.scheduler.pool` variable:
{% highlight SQL %}
SET spark.sql.thriftserver.scheduler.pool=accounting;
{% endhighlight %}

## Concurrent Jobs in PySpark

By default, PySpark does not synchronize PVM threads with JVM threads, and launching
multiple jobs in multiple PVM threads does not guarantee that each job runs in its own
corresponding JVM thread. Due to this limitation, it is impossible to set a different job group
via `sc.setJobGroup` in a separate PVM thread, which also makes it impossible to cancel the
job via `sc.cancelJobGroup` later.

In order to synchronize PVM threads with JVM threads, you should set the `PYSPARK_PIN_THREAD` environment
variable to `true`. This pinned thread mode gives each PVM thread one corresponding JVM thread.

However, it currently cannot inherit the local properties from the parent thread, although it isolates
each thread with its own local properties. To work around this, you should manually copy the local
properties from the parent thread and set them in the child thread when you create another thread in
the PVM; a sketch of this pattern follows below.

Note that `PYSPARK_PIN_THREAD` is currently experimental and not recommended for use in production.
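The manual copy described above can be wrapped once and reused — one possible pattern, where `PropagatingThread` is a hypothetical helper and not a PySpark API:

```python
import threading

class PropagatingThread(threading.Thread):
    # Hypothetical helper: snapshots chosen local properties in the
    # parent thread and restores them in the child before running.
    def __init__(self, sc, keys, *args, **kwargs):
        super(PropagatingThread, self).__init__(*args, **kwargs)
        self._sc = sc
        # This runs in the parent thread, at construction time.
        self._props = {k: sc.getLocalProperty(k) for k in keys}

    def run(self):
        # This runs in the child thread, before the target.
        for k, v in self._props.items():
            if v is not None:
                self._sc.setLocalProperty(k, v)
        super(PropagatingThread, self).run()

# Usage: properties set in the parent become visible in the child, e.g.
#   sc.setLocalProperty("spark.scheduler.pool", "accounting")
#   PropagatingThread(sc, ["spark.scheduler.pool"], target=my_job).start()
```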

77 changes: 74 additions & 3 deletions python/pyspark/context.py
@@ -1009,14 +1009,61 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
in Thread.interrupt() being called on the job's executor threads. This is useful to help
ensure that the tasks are actually stopped in a timely manner, but is off by default due
to HDFS-1208, where HDFS may respond to Thread.interrupt() by marking nodes as dead.
"""
.. note:: Currently, setting a group ID (set to local properties) with a thread does
not properly work. Internally threads on PVM and JVM are not synced, and JVM thread
can be reused for multiple threads on PVM, which fails to isolate local properties
for each thread on PVM. To work around this, you can set `PYSPARK_PIN_THREAD` to
`'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
from the parent thread although it isolates each thread on PVM and JVM with its own
local properties. To work around this, you should manually copy and set the local
properties from the parent thread to the child thread when you create another thread.
"""
warnings.warn(
"Currently, setting a group ID (set to local properties) with a thread does "
"not properly work. "
"\n"
"Internally threads on PVM and JVM are not synced, and JVM thread can be reused "
"for multiple threads on PVM, which fails to isolate local properties for each "
"thread on PVM. "
"\n"
"To work around this, you can set PYSPARK_PIN_THREAD to true (see SPARK-22340). "
"However, note that it cannot inherit the local properties from the parent thread "
"although it isolates each thread on PVM and JVM with its own local properties. "
"\n"
"To work around this, you should manually copy and set the local properties from "
"the parent thread to the child thread when you create another thread.",
UserWarning)
self._jsc.setJobGroup(groupId, description, interruptOnCancel)

def setLocalProperty(self, key, value):
"""
Set a local property that affects jobs submitted from this thread, such as the
Spark fair scheduler pool.
"""
.. note:: Currently, setting a local property with a thread does
not properly work. Internally threads on PVM and JVM are not synced, and JVM thread
can be reused for multiple threads on PVM, which fails to isolate local properties
for each thread on PVM. To work around this, you can set `PYSPARK_PIN_THREAD` to
`'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
from the parent thread although it isolates each thread on PVM and JVM with its own
local properties. To work around this, you should manually copy and set the local
properties from the parent thread to the child thread when you create another thread.
"""
warnings.warn(
"Currently, setting a local property with a thread does not properly work. "
"\n"
"Internally threads on PVM and JVM are not synced, and JVM thread can be reused "
"for multiple threads on PVM, which fails to isolate local properties for each "
"thread on PVM. "
"\n"
"To work around this, you can set PYSPARK_PIN_THREAD to true (see SPARK-22340). "
"However, note that it cannot inherit the local properties from the parent thread "
"although it isolates each thread on PVM and JVM with its own local properties. "
"\n"
"To work around this, you should manually copy and set the local properties from "
"the parent thread to the child thread when you create another thread.",
UserWarning)
self._jsc.setLocalProperty(key, value)

def getLocalProperty(self, key):
@@ -1029,7 +1076,31 @@ def getLocalProperty(self, key):
def setJobDescription(self, value):
"""
Set a human readable description of the current job.
"""
.. note:: Currently, setting a job description (set to local properties) with a thread does
not properly work. Internally threads on PVM and JVM are not synced, and JVM thread
can be reused for multiple threads on PVM, which fails to isolate local properties
for each thread on PVM. To work around this, you can set `PYSPARK_PIN_THREAD` to
`'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
from the parent thread although it isolates each thread on PVM and JVM with its own
local properties. To work around this, you should manually copy and set the local
properties from the parent thread to the child thread when you create another thread.
"""
warnings.warn(
"Currently, setting a job description (set to local properties) with a thread does "
"not properly work. "
"\n"
"Internally threads on PVM and JVM are not synced, and JVM thread can be reused "
"for multiple threads on PVM, which fails to isolate local properties for each "
"thread on PVM. "
"\n"
"To work around this, you can set PYSPARK_PIN_THREAD to true (see SPARK-22340). "
"However, note that it cannot inherit the local properties from the parent thread "
"although it isolates each thread on PVM and JVM with its own local properties. "
"\n"
"To work around this, you should manually copy and set the local properties from "
"the parent thread to the child thread when you create another thread.",
UserWarning)
self._jsc.setJobDescription(value)

def sparkUser(self):
22 changes: 18 additions & 4 deletions python/pyspark/java_gateway.py
@@ -31,6 +31,7 @@
xrange = range

from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
from pyspark.find_spark_home import _find_spark_home
from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
from pyspark.util import _exception_message
@@ -125,10 +126,23 @@ def killChild():
Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
atexit.register(killChild)

# Connect to the gateway
gateway = JavaGateway(
gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
auto_convert=True))
# Connect to the gateway (or client server to pin the thread between JVM and Python)
if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
gateway = ClientServer(
java_parameters=JavaParameters(
port=gateway_port,
auth_token=gateway_secret,
auto_convert=True),
python_parameters=PythonParameters(
port=0,
eager_load=False))
else:
gateway = JavaGateway(
gateway_parameters=GatewayParameters(
port=gateway_port,
auth_token=gateway_secret,
auto_convert=True))

# Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr)
gateway.proc = proc

23 changes: 16 additions & 7 deletions python/pyspark/ml/tests/test_wrapper.py
@@ -24,6 +24,7 @@
from pyspark.ml.wrapper import _java2py, _py2java, JavaParams, JavaWrapper
from pyspark.testing.mllibutils import MLlibTestCase
from pyspark.testing.mlutils import SparkSessionTestCase
from pyspark.testing.utils import eventually


class JavaWrapperMemoryTests(SparkSessionTestCase):
@@ -50,19 +51,27 @@ def test_java_object_gets_detached(self):

model.__del__()

with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
model._java_obj.toString()
self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
def condition():
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
model._java_obj.toString()
self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
return True

eventually(condition, timeout=10, catch_assertions=True)

try:
summary.__del__()
except:
pass

with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
model._java_obj.toString()
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
summary._java_obj.toString()
def condition():
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
model._java_obj.toString()
with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
summary._java_obj.toString()
return True

eventually(condition, timeout=10, catch_assertions=True)


class WrapperTests(MLlibTestCase):