[SPARK-51267][CONNECT] Match local Spark Connect server logic between Python and Scala #50017

13 changes: 4 additions & 9 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -249,13 +249,12 @@ private[spark] class SparkSubmit extends Logging {
val childArgs = new ArrayBuffer[String]()
val childClasspath = new ArrayBuffer[String]()
val sparkConf = args.toSparkConf()
if (sparkConf.contains("spark.local.connect")) sparkConf.remove("spark.remote")
var childMainClass = ""

// Set the cluster manager
val clusterManager: Int = args.maybeMaster match {
case Some(v) =>
assert(args.maybeRemote.isEmpty || sparkConf.contains("spark.local.connect"))
assert(args.maybeRemote.isEmpty)
v match {
case "yarn" => YARN
case m if m.startsWith("spark") => STANDALONE
@@ -643,14 +642,11 @@ private[spark] class SparkSubmit extends Logging {
// All cluster managers
OptionAssigner(
// If remote is not set, sets the master,
// In local remote mode, starts the default master to start the server.
if (args.maybeRemote.isEmpty || sparkConf.contains("spark.local.connect")) args.master
if (args.maybeRemote.isEmpty) args.master
else args.maybeMaster.orNull,
ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.master"),
OptionAssigner(
// In local remote mode, do not set remote.
if (sparkConf.contains("spark.local.connect")) null
else args.maybeRemote.orNull, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.remote"),
args.maybeRemote.orNull, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.remote"),
OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
confKey = SUBMIT_DEPLOY_MODE.key),
OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.app.name"),
@@ -767,8 +763,7 @@ private[spark] class SparkSubmit extends Logging {
// In case of shells, spark.ui.showConsoleProgress can be true by default or by user. Except,
// when Spark Connect is in local mode, because Spark Connect supports its own progress
// reporting.
if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS) &&
!sparkConf.contains("spark.local.connect")) {
if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) {
sparkConf.set(UI_SHOW_CONSOLE_PROGRESS, true)
}

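The net effect of the `SparkSubmit` change above is that `spark.master` is only populated when `--remote` was not given, and `spark.remote` is now passed through unconditionally since the `spark.local.connect` special case is gone. A minimal Scala sketch of that resolution rule (illustrative only, not Spark source; the function name is made up):

```scala
// Sketch of the simplified master/remote assignment after this change.
def resolve(maybeMaster: Option[String], maybeRemote: Option[String]): Map[String, String] =
  maybeRemote match {
    // --remote wins; spark.master is not propagated to the conf.
    case Some(remote) => Map("spark.remote" -> remote)
    // No remote: behave like classic spark-submit and set spark.master.
    case None => maybeMaster.map(m => Map("spark.master" -> m)).getOrElse(Map.empty)
  }

// resolve(Some("local[4]"), None)  => Map("spark.master" -> "local[4]")
// resolve(None, Some("local"))     => Map("spark.remote" -> "local")
```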
core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -253,8 +253,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S
if (args.length == 0) {
printUsageAndExit(-1)
}
if (!sparkProperties.contains("spark.local.connect") &&
maybeRemote.isDefined && (maybeMaster.isDefined || deployMode != null)) {
if (maybeRemote.isDefined && (maybeMaster.isDefined || deployMode != null)) {
error("Remote cannot be specified with master and/or deploy mode.")
}
if (primaryResource == null) {
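With the `spark.local.connect` escape hatch removed, `--remote` is always mutually exclusive with `--master` and `--deploy-mode`. A simplified sketch of the check (illustrative; the real code reports the problem via `SparkSubmitArguments.error`, which prints usage and exits):

```scala
// Simplified sketch of the argument validation after this change.
def validateSubmitArgs(
    maybeRemote: Option[String],
    maybeMaster: Option[String],
    deployMode: Option[String]): Unit = {
  if (maybeRemote.isDefined && (maybeMaster.isDefined || deployMode.isDefined)) {
    throw new IllegalArgumentException(
      "Remote cannot be specified with master and/or deploy mode.")
  }
}

// validateSubmitArgs(Some("local"), None, None)                   // ok: local Connect mode
// validateSubmitArgs(Some("sc://host:15002"), Some("yarn"), None) // rejected
```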
launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -235,9 +235,8 @@ List<String> buildClassPath(String appClassPath) throws IOException {
addToClassPath(cp, f.toString());
}
}
// If we're in 'spark.local.connect', it should create a Spark Classic Spark Context
// that launches Spark Connect server.
if (isRemote && System.getenv("SPARK_LOCAL_CONNECT") == null) {

if (isRemote) {
for (File f: new File(jarsDir).listFiles()) {
// Exclude Spark Classic SQL and Spark Connect server jars
// if we're in Spark Connect Shell. Also exclude Spark SQL API and
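`buildClassPath` now applies the Connect-shell jar filtering whenever a remote is set, without consulting the removed `SPARK_LOCAL_CONNECT` environment variable. A rough Scala sketch of the filtering idea (the exclusion prefixes below are assumptions based on the comment in the diff, not the exact list):

```scala
import java.io.File

// Rough sketch: drop Spark Classic SQL / Connect server jars from the shell
// classpath when running as a Connect shell. Prefixes are illustrative guesses.
def connectShellJars(jarsDir: String): Seq[File] = {
  val excludedPrefixes = Seq("spark-sql_", "spark-connect_")
  Option(new File(jarsDir).listFiles()).toSeq.flatten
    .filterNot(f => excludedPrefixes.exists(f.getName.startsWith))
}
```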
launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java
@@ -47,7 +47,6 @@ public class SparkLauncher extends AbstractLauncher<SparkLauncher> {

/** The Spark remote. */
public static final String SPARK_REMOTE = "spark.remote";
public static final String SPARK_LOCAL_REMOTE = "spark.local.connect";

/** The Spark API mode. */
public static final String SPARK_API_MODE = "spark.api.mode";
launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -184,6 +184,10 @@ public List<String> buildCommand(Map<String, String> env)
}

List<String> buildSparkSubmitArgs() {
return buildSparkSubmitArgs(true);
}

List<String> buildSparkSubmitArgs(boolean includeRemote) {
List<String> args = new ArrayList<>();
OptionParser parser = new OptionParser(false);
final boolean isSpecialCommand;
@@ -210,7 +214,7 @@ List<String> buildSparkSubmitArgs() {
args.add(master);
}

if (remote != null) {
if (includeRemote && remote != null) {
args.add(parser.REMOTE);
args.add(remote);
}
@@ -226,8 +230,12 @@ }
}

for (Map.Entry<String, String> e : conf.entrySet()) {
args.add(parser.CONF);
args.add(String.format("%s=%s", e.getKey(), e.getValue()));
if (includeRemote ||
(!e.getKey().equalsIgnoreCase("spark.api.mode") &&
!e.getKey().equalsIgnoreCase("spark.remote"))) {
args.add(parser.CONF);
args.add(String.format("%s=%s", e.getKey(), e.getValue()));
}
}

if (propertiesFile != null) {
@@ -368,7 +376,8 @@ private List<String> buildPySparkShellCommand(Map<String, String> env) throws IO
// When launching the pyspark shell, the spark-submit arguments should be stored in the
// PYSPARK_SUBMIT_ARGS env variable.
appResource = PYSPARK_SHELL_RESOURCE;
constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS");
// Do not pass remote configurations to Spark Connect server via Py4J.
constructEnvVarArgs(env, "PYSPARK_SUBMIT_ARGS", false);

// Will pick up the binary executable in the following order
// 1. conf spark.pyspark.driver.python
@@ -391,8 +400,7 @@ private List<String> buildPySparkShellCommand(Map<String, String> env) throws IO
String masterStr = firstNonEmpty(master, conf.getOrDefault(SparkLauncher.SPARK_MASTER, null));
String deployStr = firstNonEmpty(
deployMode, conf.getOrDefault(SparkLauncher.DEPLOY_MODE, null));
if (!conf.containsKey(SparkLauncher.SPARK_LOCAL_REMOTE) &&
remoteStr != null && (masterStr != null || deployStr != null)) {
if (remoteStr != null && (masterStr != null || deployStr != null)) {
throw new IllegalStateException("Remote cannot be specified with master and/or deploy mode.");
}

@@ -423,7 +431,7 @@ private List<String> buildSparkRCommand(Map<String, String> env) throws IOExcept
// When launching the SparkR shell, store the spark-submit arguments in the SPARKR_SUBMIT_ARGS
// env variable.
appResource = SPARKR_SHELL_RESOURCE;
constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS");
constructEnvVarArgs(env, "SPARKR_SUBMIT_ARGS", true);

// Set shell.R as R_PROFILE_USER to load the SparkR package when the shell comes up.
String sparkHome = System.getenv("SPARK_HOME");
@@ -438,12 +446,13 @@ private List<String> buildSparkRCommand(Map<String, String> env) throws IOExcept

private void constructEnvVarArgs(
Map<String, String> env,
String submitArgsEnvVariable) throws IOException {
String submitArgsEnvVariable,
boolean includeRemote) throws IOException {
mergeEnvPathList(env, getLibPathEnvName(),
getEffectiveConfig().get(SparkLauncher.DRIVER_EXTRA_LIBRARY_PATH));

StringBuilder submitArgs = new StringBuilder();
for (String arg : buildSparkSubmitArgs()) {
for (String arg : buildSparkSubmitArgs(includeRemote)) {
if (submitArgs.length() > 0) {
submitArgs.append(" ");
}
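The new `includeRemote` flag exists so that remote-related settings never leak into `PYSPARK_SUBMIT_ARGS`: when it is false, `spark.remote` and `spark.api.mode` are dropped from the generated `--conf` pairs, while SparkR keeps the previous behavior by passing true. A small Scala sketch of that filtering (illustrative, not Spark source):

```scala
// Sketch of how conf entries are turned into --conf arguments, with the
// client-only keys dropped when includeRemote is false.
def confArgs(conf: Map[String, String], includeRemote: Boolean): Seq[String] =
  conf.toSeq
    .filter { case (k, _) =>
      includeRemote ||
        (!k.equalsIgnoreCase("spark.api.mode") && !k.equalsIgnoreCase("spark.remote"))
    }
    .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }

// confArgs(Map("spark.remote" -> "local", "spark.app.name" -> "demo"), includeRemote = false)
//   => Seq("--conf", "spark.app.name=demo")
```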
3 changes: 3 additions & 0 deletions project/MimaExcludes.scala
@@ -268,6 +268,9 @@ object MimaExcludes {
ProblemFilters.exclude[Problem]("org.sparkproject.spark_protobuf.protobuf.*"),
ProblemFilters.exclude[Problem]("org.apache.spark.sql.protobuf.utils.SchemaConverters.*"),

// SPARK-51267: Match local Spark Connect server logic between Python and Scala
ProblemFilters.exclude[MissingFieldProblem]("org.apache.spark.launcher.SparkLauncher.SPARK_LOCAL_REMOTE"),

(problem: Problem) => problem match {
case MissingClassProblem(cls) => !cls.fullName.startsWith("org.sparkproject.jpmml") &&
!cls.fullName.startsWith("org.sparkproject.dmg.pmml")
3 changes: 1 addition & 2 deletions python/pyspark/sql/connect/client/core.py
@@ -326,8 +326,7 @@ def default_port() -> int:
# This is only used in the test/development mode.
session = PySparkSession._instantiatedSession

# 'spark.local.connect' is set when we use the local mode in Spark Connect.
if session is not None and session.conf.get("spark.local.connect", "0") == "1":
if session is not None:
jvm = PySparkSession._instantiatedSession._jvm # type: ignore[union-attr]
return getattr(
getattr(
7 changes: 4 additions & 3 deletions python/pyspark/sql/connect/session.py
@@ -1044,8 +1044,10 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
# Configurations to be overwritten
overwrite_conf = opts
overwrite_conf["spark.master"] = master
overwrite_conf["spark.local.connect"] = "1"
os.environ["SPARK_LOCAL_CONNECT"] = "1"
if "spark.remote" in overwrite_conf:
del overwrite_conf["spark.remote"]
if "spark.api.mode" in overwrite_conf:
del overwrite_conf["spark.api.mode"]

# Configurations to be set if unset.
default_conf = {
@@ -1083,7 +1085,6 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
finally:
if origin_remote is not None:
os.environ["SPARK_REMOTE"] = origin_remote
del os.environ["SPARK_LOCAL_CONNECT"]
else:
raise PySparkRuntimeError(
errorClass="SESSION_OR_CONTEXT_EXISTS",
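On the Python side, `_start_connect_server` now scrubs the client-only settings instead of flagging local mode through `spark.local.connect`. Both clients end up doing the same thing; a minimal Scala sketch of the shared idea (illustrative names, not Spark source):

```scala
// Both the Python and Scala clients strip client-only keys before handing
// options to the locally launched server, and force the chosen master.
val clientOnlyKeys = Set("spark.remote", "spark.api.mode")

def serverOverwriteConf(opts: Map[String, String], master: String): Map[String, String] =
  (opts -- clientOnlyKeys) + ("spark.master" -> master)

// serverOverwriteConf(Map("spark.remote" -> "local", "spark.sql.shuffle.partitions" -> "4"), "local[*]")
//   => Map("spark.sql.shuffle.partitions" -> "4", "spark.master" -> "local[*]")
```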
SparkSession.scala (Spark Connect Scala client)
@@ -627,6 +627,17 @@ class SparkSession private[sql] (
}
allocator.close()
SparkSession.onSessionClose(this)
SparkSession.server.synchronized {
if (SparkSession.server.isDefined) {
// When local mode is in use, follow the regular Spark session's
// behavior by terminating the Spark Connect server,
// meaning that you can stop local mode, and restart the Spark Connect
// client with a different remote address.
new ProcessBuilder(SparkSession.maybeConnectStopScript.get.toString)
.start()
SparkSession.server = None
}
}
}

/** @inheritdoc */
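`stop()` now owns the shutdown half of the local-mode lifecycle: if this process previously forked `start-connect-server.sh`, stopping the session runs the matching stop script and clears the cached handle, so a later session can point at a different remote. A self-contained sketch of that pattern (illustrative, not Spark source):

```scala
import java.nio.file.{Files, Paths}

// Sketch of the stop-side logic: run sbin/stop-connect-server.sh if we started
// a local Connect server earlier, then forget the process handle.
object LocalConnectServer {
  private var server: Option[Process] = None

  def stopIfRunning(): Unit = synchronized {
    val stopScript = Option(System.getenv("SPARK_HOME"))
      .map(Paths.get(_, "sbin", "stop-connect-server.sh"))
    if (server.isDefined && stopScript.exists(Files.exists(_))) {
      new ProcessBuilder(stopScript.get.toString).start()
      server = None
    }
  }
}
```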
@@ -679,6 +690,10 @@ object SparkSession extends SparkSessionCompanion with Logging {
private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
private val maybeConnectStartScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
private val maybeConnectStopScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "stop-connect-server.sh"))
private[sql] val sparkOptions = sys.props.filter { p =>
p._1.startsWith("spark.") && p._2.nonEmpty
}.toMap
@@ -695,34 +710,37 @@ object SparkSession extends SparkSessionCompanion with Logging {
* Create a new Spark Connect server to connect locally.
*/
private[sql] def withLocalConnectServer[T](f: => T): T = {
synchronized {
lazy val isAPIModeConnect =
Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
.getOrElse("classic")
.toLowerCase(Locale.ROOT) == "connect"
val remoteString = sparkOptions
.get("spark.remote")
.orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
.orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
.orElse {
if (isAPIModeConnect) {
sparkOptions.get("spark.master").orElse(sys.env.get("MASTER"))
} else {
None
}
lazy val isAPIModeConnect =
Option(System.getProperty(org.apache.spark.sql.SparkSessionBuilder.API_MODE_KEY))
.getOrElse("classic")
.toLowerCase(Locale.ROOT) == "connect"
val remoteString = sparkOptions
.get("spark.remote")
.orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
.orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
.orElse {
if (isAPIModeConnect) {
sparkOptions.get("spark.master").orElse(sys.env.get("MASTER"))
} else {
None
}
}

val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))

server.synchronized {
if (server.isEmpty &&
(remoteString.exists(_.startsWith("local")) ||
(remoteString.isDefined && isAPIModeConnect)) &&
maybeConnectScript.exists(Files.exists(_))) {
maybeConnectStartScript.exists(Files.exists(_))) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
Seq(
maybeConnectStartScript.get.toString,
"--master",
remoteString.get) ++ (sparkOptions ++ Map(
"spark.sql.artifact.isolation.enabled" -> "true",
"spark.sql.artifact.isolation.alwaysApplyClassloader" -> "true"))
.filter(p => !p._1.startsWith("spark.remote"))
.filter(p => !p._1.startsWith("spark.api.mode"))
.flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
val pb = new ProcessBuilder(args: _*)
// So don't exclude spark-sql jar in classpath
@@ -737,14 +755,17 @@ object SparkSession extends SparkSessionCompanion with Logging {

// scalastyle:off runtimeaddshutdownhook
Runtime.getRuntime.addShutdownHook(new Thread() {
override def run(): Unit = if (server.isDefined) {
new ProcessBuilder(maybeConnectScript.get.toString)
.start()
override def run(): Unit = server.synchronized {
if (server.isDefined) {
new ProcessBuilder(maybeConnectStopScript.get.toString)
.start()
}
}
})
// scalastyle:on runtimeaddshutdownhook
}
}

f
}

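`withLocalConnectServer` builds the launch command for `sbin/start-connect-server.sh` from the captured `spark.*` options: the client-only keys are filtered out and the artifact-isolation settings are forced on before everything is flattened into `--conf` pairs. A compact Scala sketch of that assembly (illustrative helper, not Spark source):

```scala
import java.nio.file.Paths

// Sketch of how the start-connect-server.sh command line is assembled.
def connectServerCommand(
    sparkHome: String,
    master: String,
    sparkOptions: Map[String, String]): Seq[String] = {
  val script = Paths.get(sparkHome, "sbin", "start-connect-server.sh").toString
  val confArgs = (sparkOptions ++ Map(
      "spark.sql.artifact.isolation.enabled" -> "true",
      "spark.sql.artifact.isolation.alwaysApplyClassloader" -> "true"))
    .toSeq
    .filter { case (k, _) => !k.startsWith("spark.remote") && !k.startsWith("spark.api.mode") }
    .flatMap { case (k, v) => Seq("--conf", s"$k=$v") }
  Seq(script, "--master", master) ++ confArgs
}

// connectServerCommand("/opt/spark", "local[*]", Map("spark.remote" -> "local"))
// drops spark.remote and yields the script path, --master local[*], and the two
// isolation --conf pairs (order of the conf pairs may vary).
```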