Revert "[SPARK-45302][PYTHON] Remove PID communication between Python…
Browse files Browse the repository at this point in the history
…workers when no demon is used"

### What changes were proposed in this pull request?

This PR reverts #43087.
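
For context, #43087 removed the handshake in which a Python worker launched without the daemon reports its own PID back to the JVM over the worker socket (the JVM instead relied on `Process.pid()` and a `Long` PID). Reverting it restores that handshake, as the diff below shows. A minimal sketch of the restored worker-side behavior, using the same helpers and environment variables that appear in the changed files:

```python
import os

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import write_int

# Connect back to the JVM-side PythonWorkerFactory; the port and secret are
# handed to the worker process through environment variables.
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, sock) = local_connect_and_auth(java_port, auth_secret)

# Restored handshake: report our PID as a 32-bit int so the JVM can track
# (and, if needed, kill) this worker; the JVM side reads it with readInt().
write_int(os.getpid(), sock_file)
sock_file.flush()
```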

### Why are the changes needed?

To clean up those workers ahead of an upcoming refactoring. I will bring this change back as part of that refactoring PR.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46195 from HyukjinKwon/SPARK-45302-revert.

Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
HyukjinKwon committed Apr 25, 2024
1 parent ea37c86 commit c6aaa18
Showing 17 changed files with 44 additions and 19 deletions.
6 changes: 3 additions & 3 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -142,7 +142,7 @@ class SparkEnv (
workerModule: String,
daemonModule: String,
envVars: Map[String, String],
- useDaemon: Boolean): (PythonWorker, Option[Long]) = {
+ useDaemon: Boolean): (PythonWorker, Option[Int]) = {
synchronized {
val key = PythonWorkersKey(pythonExec, workerModule, daemonModule, envVars)
val workerFactory = pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(
@@ -161,7 +161,7 @@
pythonExec: String,
workerModule: String,
envVars: Map[String, String],
- useDaemon: Boolean): (PythonWorker, Option[Long]) = {
+ useDaemon: Boolean): (PythonWorker, Option[Int]) = {
createPythonWorker(
pythonExec, workerModule, PythonWorkerFactory.defaultDaemonModule, envVars, useDaemon)
}
@@ -170,7 +170,7 @@
pythonExec: String,
workerModule: String,
daemonModule: String,
- envVars: Map[String, String]): (PythonWorker, Option[Long]) = {
+ envVars: Map[String, String]): (PythonWorker, Option[Int]) = {
val useDaemon = conf.get(Python.PYTHON_USE_DAEMON)
createPythonWorker(
pythonExec, workerModule, daemonModule, envVars, useDaemon)
@@ -88,7 +88,7 @@ private object BasePythonRunner {

private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")

- private def faultHandlerLogPath(pid: Long): Path = {
+ private def faultHandlerLogPath(pid: Int): Path = {
new File(faultHandlerLogDir, pid.toString).toPath
}
}
@@ -204,7 +204,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](

envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))

- val (worker: PythonWorker, pid: Option[Long]) = env.createPythonWorker(
+ val (worker: PythonWorker, pid: Option[Int]) = env.createPythonWorker(
pythonExec, workerModule, daemonModule, envVars.asScala.toMap)
// Whether is the worker released into idle pool or closed. When any codes try to release or
// close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make
@@ -257,7 +257,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
- pid: Option[Long],
+ pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[OUT]

@@ -465,7 +465,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
- pid: Option[Long],
+ pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext)
extends Iterator[OUT] {
@@ -842,7 +842,7 @@ private[spark] class PythonRunner(
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
- pid: Option[Long],
+ pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
new ReaderIterator(
@@ -92,7 +92,7 @@ private[spark] class PythonWorkerFactory(
envVars.getOrElse("PYTHONPATH", ""),
sys.env.getOrElse("PYTHONPATH", ""))

- def create(): (PythonWorker, Option[Long]) = {
+ def create(): (PythonWorker, Option[Int]) = {
if (useDaemon) {
self.synchronized {
// Pull from idle workers until we one that is alive, otherwise create a new one.
@@ -102,7 +102,7 @@
if (workerHandle.isAlive()) {
try {
worker.selectionKey.interestOps(SelectionKey.OP_READ | SelectionKey.OP_WRITE)
- return (worker, Some(workerHandle.pid()))
+ return (worker, Some(workerHandle.pid().toInt))
} catch {
case c: CancelledKeyException => /* pass */
}
@@ -122,9 +122,9 @@
* processes itself to avoid the high cost of forking from Java. This currently only works
* on UNIX-based systems.
*/
- private def createThroughDaemon(): (PythonWorker, Option[Long]) = {
+ private def createThroughDaemon(): (PythonWorker, Option[Int]) = {

- def createWorker(): (PythonWorker, Option[Long]) = {
+ def createWorker(): (PythonWorker, Option[Int]) = {
val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
// These calls are blocking.
val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
@@ -165,7 +165,7 @@
/**
* Launch a worker by executing worker.py (by default) directly and telling it to connect to us.
*/
- private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Long]) = {
+ private[spark] def createSimpleWorker(blockingMode: Boolean): (PythonWorker, Option[Int]) = {
var serverSocketChannel: ServerSocketChannel = null
try {
serverSocketChannel = ServerSocketChannel.open()
@@ -209,7 +209,8 @@
"Timed out while waiting for the Python worker to connect back")
}
authHelper.authClient(socketChannel.socket())
- val pid = workerProcess.toHandle.pid()
+ // TODO: When we drop JDK 8, we can just use workerProcess.pid()
+ val pid = new DataInputStream(Channels.newInputStream(socketChannel)).readInt()
if (pid < 0) {
throw new IllegalStateException("Python failed to launch worker with code " + pid)
}
@@ -405,7 +406,7 @@
daemonWorkers.get(worker).foreach { processHandle =>
// tell daemon to kill worker by pid
val output = new DataOutputStream(daemon.getOutputStream)
- output.writeLong(processHandle.pid())
+ output.writeInt(processHandle.pid().toInt)
output.flush()
daemon.getOutputStream.flush()
}
4 changes: 2 additions & 2 deletions python/pyspark/daemon.py
@@ -28,7 +28,7 @@
from socket import AF_INET, AF_INET6, SOCK_STREAM, SOMAXCONN
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT

- from pyspark.serializers import read_long, write_int, write_with_length, UTF8Deserializer
+ from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer

if len(sys.argv) > 1:
import importlib
@@ -139,7 +139,7 @@ def handle_sigterm(*args):

if 0 in ready_fds:
try:
- worker_pid = read_long(stdin_bin)
+ worker_pid = read_int(stdin_bin)
except EOFError:
# Spark told us to exit by closing stdin
shutdown(0)
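
The daemon-side change above switches the kill-request framing back to a 32-bit int read from the daemon's stdin, matching `output.writeInt(...)` in the `PythonWorkerFactory` hunk earlier. A simplified sketch of that loop (not the full `daemon.py` logic):

```python
import os
import select
import signal
import sys

from pyspark.serializers import read_int

# daemon.py reads kill requests from the stdin pipe connected to the JVM.
stdin_bin = os.fdopen(sys.stdin.fileno(), "rb", 4)

while True:
    # Wake up when the JVM has written a worker PID to our stdin (fd 0).
    ready_fds = select.select([0], [], [], 1)[0]
    if 0 in ready_fds:
        try:
            worker_pid = read_int(stdin_bin)  # matches writeInt on the JVM side
        except EOFError:
            break  # Spark told us to exit by closing stdin
        try:
            os.kill(worker_pid, signal.SIGKILL)
        except OSError:
            pass  # the worker may have already exited
```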
@@ -96,4 +96,6 @@ def process(df_id, batch_id): # type: ignore[no-untyped-def]
(sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
# There could be a long time between each micro batch.
sock.settimeout(None)
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
@@ -110,4 +110,6 @@ def process(listener_event_str, listener_event_type): # type: ignore[no-untyped-def]
(sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
# There could be a long time between each listener event.
sock.settimeout(None)
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
@@ -163,4 +163,6 @@ def main(infile: IO, outfile: IO) -> None:
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
3 changes: 3 additions & 0 deletions python/pyspark/sql/worker/analyze_udtf.py
@@ -264,4 +264,7 @@ def invalid_analyze_result_field(field_name: str, expected_field: str) -> PySpar
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
2 changes: 2 additions & 0 deletions python/pyspark/sql/worker/commit_data_source_write.py
@@ -117,4 +117,6 @@ def main(infile: IO, outfile: IO) -> None:
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
2 changes: 2 additions & 0 deletions python/pyspark/sql/worker/create_data_source.py
@@ -187,4 +187,6 @@ def main(infile: IO, outfile: IO) -> None:
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
2 changes: 2 additions & 0 deletions python/pyspark/sql/worker/lookup_data_sources.py
@@ -95,4 +95,6 @@ def main(infile: IO, outfile: IO) -> None:
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
2 changes: 2 additions & 0 deletions python/pyspark/sql/worker/plan_data_source_read.py
@@ -299,4 +299,6 @@ def batched(iterator: Iterator, n: int) -> Iterator:
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
2 changes: 2 additions & 0 deletions python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -137,4 +137,6 @@ def main(infile: IO, outfile: IO) -> None:
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
2 changes: 2 additions & 0 deletions python/pyspark/sql/worker/write_into_data_source.py
@@ -229,4 +229,6 @@ def batch_to_rows() -> Iterator[Row]:
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
3 changes: 3 additions & 0 deletions python/pyspark/worker.py
@@ -1868,4 +1868,7 @@ def process():
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+ # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
+ write_int(os.getpid(), sock_file)
+ sock_file.flush()
main(sock_file, sock_file)
@@ -49,7 +49,7 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
- pid: Option[Long],
+ pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[OUT] = {

@@ -80,7 +80,7 @@ abstract class BasePythonUDFRunner(
startTime: Long,
env: SparkEnv,
worker: PythonWorker,
- pid: Option[Long],
+ pid: Option[Int],
releasedOrClosed: AtomicBoolean,
context: TaskContext): Iterator[Array[Byte]] = {
new ReaderIterator(
