diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 51402a835aef5..d1d49153dafc5 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -604,13 +604,20 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging {
       .get("spark.remote")
       .orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
       .orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
+      .orElse {
+        if (isAPIModeConnect) {
+          Option(System.getProperty("spark.master")).orElse(sys.env.get("MASTER"))
+        } else {
+          None
+        }
+      }
 
     val maybeConnectScript =
       Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))
 
     if (server.isEmpty &&
-      maybeConnectScript.exists(Files.exists(_)) &&
-      (remoteString.exists(_.startsWith("local")) || isAPIModeConnect)) {
+        maybeConnectScript.exists(Files.exists(_)) &&
+        remoteString.isDefined) {
       server = Some {
         val args =
           Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
@@ -658,12 +665,19 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging {
     // Initialize the connection string of the Spark Connect client builder from SPARK_REMOTE
     // by default, if it exists. The connection string can be overridden using
     // the remote() function, as it takes precedence over the SPARK_REMOTE environment variable.
-    private val builder = SparkConnectClient.builder().loadFromEnvironment()
+    private var connectionString: Option[String] = None
+    private var interceptor: Option[ClientInterceptor] = None
     private var client: SparkConnectClient = _
+    private lazy val builder = {
+      val b = SparkConnectClient.builder()
+      connectionString.foreach(b.connectionString)
+      interceptor.foreach(b.interceptor)
+      b.loadFromEnvironment()
+    }
 
     /** @inheritdoc */
     def remote(connectionString: String): this.type = {
-      builder.connectionString(connectionString)
+      this.connectionString = Some(connectionString)
       this
     }
 
@@ -675,7 +689,7 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging {
      * @since 3.5.0
      */
     def interceptor(interceptor: ClientInterceptor): this.type = {
-      builder.interceptor(interceptor)
+      this.interceptor = Some(interceptor)
       this
     }
 
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index cf781a9ff7dc0..b96fa1523829a 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -527,7 +527,9 @@ protected boolean handle(String opt, String value) {
         checkArgument(value != null, "Missing argument to %s", CONF);
         String[] setConf = value.split("=", 2);
         checkArgument(setConf.length == 2, "Invalid argument to %s: %s", CONF, value);
-        if (setConf[0].equals("spark.remote")) {
+        if (setConf[0].equals("spark.remote") ||
+            (setConf[0].equals("spark.api.mode") &&
+             setConf[1].toLowerCase(Locale.ROOT).equals("connect"))) {
           isRemote = true;
         }
         conf.put(setConf[0], setConf[1]);
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
index e0dfac62847ea..043c183e62bef 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -190,6 +190,10 @@ abstract class BaseYarnClusterSuite extends SparkFunSuite with Matchers {
       .setPropertiesFile(propsFile)
       .addAppArgs(appArgs.toArray: _*)
 
+    extraConf.get("spark.api.mode").foreach { v =>
+      launcher.setConf("spark.api.mode", v)
+    }
+
     sparkArgs.foreach { case (name, value) =>
       if (value != null) {
         launcher.addSparkArg(name, value)
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 71843b7f90b1f..b74ff43fc22dc 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -86,6 +86,34 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     |    sc.stop()
     """.stripMargin
 
+  private val TEST_CONNECT_PYFILE = """
+    |import mod1, mod2
+    |import sys
+    |from operator import add
+    |
+    |from pyspark.sql import SparkSession
+    |from pyspark.sql.functions import udf
+    |if __name__ == "__main__":
+    |    if len(sys.argv) != 2:
+    |        print >> sys.stderr, "Usage: test.py [result file]"
+    |        exit(-1)
+    |    spark = SparkSession.builder.config(
+    |        "spark.api.mode", "connect").master("yarn").getOrCreate()
+    |    assert "connect" in str(spark)
+    |    status = open(sys.argv[1],'w')
+    |    result = "failure"
+    |    @udf
+    |    def test():
+    |        return mod1.func() * mod2.func()
+    |    df = spark.range(10).select(test())
+    |    cnt = df.count()
+    |    if cnt == 10:
+    |        result = "success"
+    |    status.write(result)
+    |    status.close()
+    |    spark.stop()
+    """.stripMargin
+
   private val TEST_PYMODULE = """
     |def func():
     |    return 42
@@ -162,6 +190,32 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     testWithAddJar(clientMode = false, s"local:$jarPath")
   }
 
+  test("run Scala application with Spark Connect in yarn-client mode") {
+    val result = File.createTempFile("result", null, tempDir)
+    val finalState = runSpark(
+      true,
+      mainClassName(YarnConnectTest.getClass()),
+      appArgs = Seq(result.getAbsolutePath()),
+      extraConf = Map("spark.api.mode" -> "connect"))
+    val output = new Object() {
+      override def toString: String = Files.asCharSource(result, StandardCharsets.UTF_8).read()
+    }
+    assert(finalState == SparkAppHandle.State.FINISHED, output)
+  }
+
+  test("run Scala application with Spark Connect in yarn-cluster mode") {
+    val result = File.createTempFile("result", null, tempDir)
+    val finalState = runSpark(
+      false,
+      mainClassName(YarnConnectTest.getClass()),
+      appArgs = Seq(result.getAbsolutePath()),
+      extraConf = Map("spark.api.mode" -> "connect"))
+    val output = new Object() {
+      override def toString: String = Files.asCharSource(result, StandardCharsets.UTF_8).read()
+    }
+    assert(finalState == SparkAppHandle.State.FINISHED, output)
+  }
+
   test("SPARK-35672: run Spark in yarn-client mode with additional jar using URI scheme 'local' " +
     "and gateway-replacement path") {
     // Use the original jar URL, but set up the gateway/replacement configs such that if
@@ -237,6 +291,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     testPySpark(false)
   }
 
+  test("run Python application with Spark Connect in yarn-client mode") {
+    testPySpark(true, extraConf = Map("spark.api.mode" -> "connect"), script = TEST_CONNECT_PYFILE)
+  }
+
+  test("run Python application with Spark Connect in yarn-cluster mode") {
+    testPySpark(false, extraConf = Map("spark.api.mode" -> "connect"), script = TEST_CONNECT_PYFILE)
+  }
+
   test("run Python application in yarn-cluster mode using " +
     "spark.yarn.appMasterEnv to override local envvar") {
     testPySpark(
@@ -370,10 +432,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
   private def testPySpark(
       clientMode: Boolean,
       extraConf: Map[String, String] = Map(),
-      extraEnv: Map[String, String] = Map()): Unit = {
+      extraEnv: Map[String, String] = Map(),
+      script: String = TEST_PYFILE): Unit = {
     assume(isPythonAvailable)
     val primaryPyFile = new File(tempDir, "test.py")
-    Files.asCharSink(primaryPyFile, StandardCharsets.UTF_8).write(TEST_PYFILE)
+    Files.asCharSink(primaryPyFile, StandardCharsets.UTF_8).write(script)
 
     // When running tests, let's not assume the user has built the assembly module, which also
     // creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the
@@ -712,6 +775,34 @@ private object YarnClasspathTest extends Logging {
 
 }
 
+private object YarnConnectTest extends Logging {
+  def main(args: Array[String]): Unit = {
+    val output = new java.io.PrintStream(new File(args(0)))
+    val clz = Utils.classForName("org.apache.spark.sql.SparkSession$")
+    val moduleField = clz.getDeclaredField("MODULE$")
+    val obj = moduleField.get(null)
+    var builder = clz.getMethod("builder").invoke(obj)
+    builder = builder.getClass().getMethod(
+      "config", classOf[String], classOf[String]).invoke(builder, "spark.api.mode", "connect")
+    builder = builder.getClass().getMethod("master", classOf[String]).invoke(builder, "yarn")
+    val session = builder.getClass().getMethod("getOrCreate").invoke(builder)
+
+    try {
+      // Check if the current session is a Spark Connect session.
+      session.getClass().getDeclaredField("client")
+      val df = session.getClass().getMethod("range", classOf[Long]).invoke(session, 10)
+      assert(df.getClass().getMethod("count").invoke(df) == 10)
+    } catch {
+      case e: Throwable =>
+        e.printStackTrace(new java.io.PrintStream(output))
+        throw e
+    } finally {
+      session.getClass().getMethod("stop").invoke(session)
+      output.close()
+    }
+  }
+}
+
 private object YarnAddJarTest extends Logging {
   def main(args: Array[String]): Unit = {
     if (args.length != 3) {