Skip to content

Commit

Permalink
Add a YARN test
Browse files Browse the repository at this point in the history
  • Loading branch information
HyukjinKwon committed Dec 11, 2024
1 parent 0fabc3a commit bc35f57
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -604,13 +604,20 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging {
.get("spark.remote")
.orElse(Option(System.getProperty("spark.remote"))) // Set from Spark Submit
.orElse(sys.env.get(SparkConnectClient.SPARK_REMOTE))
.orElse {
if (isAPIModeConnect) {
Option(System.getProperty("spark.master")).orElse(sys.env.get("MASTER"))
} else {
None
}
}

val maybeConnectScript =
Option(System.getenv("SPARK_HOME")).map(Paths.get(_, "sbin", "start-connect-server.sh"))

if (server.isEmpty &&
maybeConnectScript.exists(Files.exists(_)) &&
(remoteString.exists(_.startsWith("local")) || isAPIModeConnect)) {
maybeConnectScript.exists(Files.exists(_)) &&
remoteString.isDefined) {
server = Some {
val args =
Seq(maybeConnectScript.get.toString, "--master", remoteString.get) ++ sparkOptions
Expand Down Expand Up @@ -658,12 +665,19 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging {
// Initialize the connection string of the Spark Connect client builder from SPARK_REMOTE
// by default, if it exists. The connection string can be overridden using
// the remote() function, as it takes precedence over the SPARK_REMOTE environment variable.
private val builder = SparkConnectClient.builder().loadFromEnvironment()
private var connectionString: Option[String] = None
private var interceptor: Option[ClientInterceptor] = None
private var client: SparkConnectClient = _
private lazy val builder = {
val b = SparkConnectClient.builder()
connectionString.foreach(b.connectionString)
interceptor.foreach(b.interceptor)
b.loadFromEnvironment()
}

/** @inheritdoc */
def remote(connectionString: String): this.type = {
builder.connectionString(connectionString)
this.connectionString = Some(connectionString)
this
}

Expand All @@ -675,7 +689,7 @@ object SparkSession extends api.BaseSparkSessionCompanion with Logging {
* @since 3.5.0
*/
def interceptor(interceptor: ClientInterceptor): this.type = {
builder.interceptor(interceptor)
this.interceptor = Some(interceptor)
this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,9 @@ protected boolean handle(String opt, String value) {
checkArgument(value != null, "Missing argument to %s", CONF);
String[] setConf = value.split("=", 2);
checkArgument(setConf.length == 2, "Invalid argument to %s: %s", CONF, value);
if (setConf[0].equals("spark.remote")) {
if (setConf[0].equals("spark.remote") ||
(setConf[0].equals("spark.api.mode") &&
setConf[1].toLowerCase(Locale.ROOT).equals("connect"))) {
isRemote = true;
}
conf.put(setConf[0], setConf[1]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,10 @@ abstract class BaseYarnClusterSuite extends SparkFunSuite with Matchers {
.setPropertiesFile(propsFile)
.addAppArgs(appArgs.toArray: _*)

extraConf.get("spark.api.mode").foreach { v =>
launcher.setConf("spark.api.mode", v)
}

sparkArgs.foreach { case (name, value) =>
if (value != null) {
launcher.addSparkArg(name, value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,34 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
| sc.stop()
""".stripMargin

private val TEST_CONNECT_PYFILE = """
|import mod1, mod2
|import sys
|from operator import add
|
|from pyspark.sql import SparkSession
|from pyspark.sql.functions import udf
|if __name__ == "__main__":
| if len(sys.argv) != 2:
| print >> sys.stderr, "Usage: test.py [result file]"
| exit(-1)
| spark = SparkSession.builder.config(
| "spark.api.mode", "connect").master("yarn").getOrCreate()
| assert "connect" in str(spark)
| status = open(sys.argv[1],'w')
| result = "failure"
| @udf
| def test():
| return mod1.func() * mod2.func()
| df = spark.range(10).select(test())
| cnt = df.count()
| if cnt == 10:
| result = "success"
| status.write(result)
| status.close()
| spark.stop()
""".stripMargin

private val TEST_PYMODULE = """
|def func():
| return 42
Expand Down Expand Up @@ -162,6 +190,32 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
testWithAddJar(clientMode = false, s"local:$jarPath")
}

test("run Scala application with Spark Connect in yarn-client mode") {
val result = File.createTempFile("result", null, tempDir)
val finalState = runSpark(
true,
mainClassName(YarnConnectTest.getClass()),
appArgs = Seq(result.getAbsolutePath()),
extraConf = Map("spark.api.mode" -> "connect"))
val output = new Object() {
override def toString: String = Files.asCharSource(result, StandardCharsets.UTF_8).read()
}
assert(finalState == SparkAppHandle.State.FINISHED, output)
}

test("run Scala application with Spark Connect in yarn-cluster mode") {
val result = File.createTempFile("result", null, tempDir)
val finalState = runSpark(
false,
mainClassName(YarnConnectTest.getClass()),
appArgs = Seq(result.getAbsolutePath()),
extraConf = Map("spark.api.mode" -> "connect"))
val output = new Object() {
override def toString: String = Files.asCharSource(result, StandardCharsets.UTF_8).read()
}
assert(finalState == SparkAppHandle.State.FINISHED, output)
}

test("SPARK-35672: run Spark in yarn-client mode with additional jar using URI scheme 'local' " +
"and gateway-replacement path") {
// Use the original jar URL, but set up the gateway/replacement configs such that if
Expand Down Expand Up @@ -237,6 +291,14 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
testPySpark(false)
}

test("run Python application with Spark Connect in yarn-client mode") {
testPySpark(true, extraConf = Map("spark.api.mode" -> "connect"), script = TEST_CONNECT_PYFILE)
}

test("run Python application with Spark Connect in yarn-cluster mode") {
testPySpark(false, extraConf = Map("spark.api.mode" -> "connect"), script = TEST_CONNECT_PYFILE)
}

test("run Python application in yarn-cluster mode using " +
"spark.yarn.appMasterEnv to override local envvar") {
testPySpark(
Expand Down Expand Up @@ -370,10 +432,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
private def testPySpark(
clientMode: Boolean,
extraConf: Map[String, String] = Map(),
extraEnv: Map[String, String] = Map()): Unit = {
extraEnv: Map[String, String] = Map(),
script: String = TEST_PYFILE): Unit = {
assume(isPythonAvailable)
val primaryPyFile = new File(tempDir, "test.py")
Files.asCharSink(primaryPyFile, StandardCharsets.UTF_8).write(TEST_PYFILE)
Files.asCharSink(primaryPyFile, StandardCharsets.UTF_8).write(script)

// When running tests, let's not assume the user has built the assembly module, which also
// creates the pyspark archive. Instead, let's use PYSPARK_ARCHIVES_PATH to point at the
Expand Down Expand Up @@ -712,6 +775,34 @@ private object YarnClasspathTest extends Logging {

}

private object YarnConnectTest extends Logging {
def main(args: Array[String]): Unit = {
val output = new java.io.PrintStream(new File(args(0)))
val clz = Utils.classForName("org.apache.spark.sql.SparkSession$")
val moduleField = clz.getDeclaredField("MODULE$")
val obj = moduleField.get(null)
var builder = clz.getMethod("builder").invoke(obj)
builder = builder.getClass().getMethod(
"config", classOf[String], classOf[String]).invoke(builder, "spark.api.mode", "connect")
builder = builder.getClass().getMethod("master", classOf[String]).invoke(builder, "yarn")
val session = builder.getClass().getMethod("getOrCreate").invoke(builder)

try {
// Check if the current session is a Spark Connect session.
session.getClass().getDeclaredField("client")
val df = session.getClass().getMethod("range", classOf[Long]).invoke(session, 10)
assert(df.getClass().getMethod("count").invoke(df) == 10)
} catch {
case e: Throwable =>
e.printStackTrace(new java.io.PrintStream(output))
throw e
} finally {
session.getClass().getMethod("stop").invoke(session)
output.close()
}
}
}

private object YarnAddJarTest extends Logging {
def main(args: Array[String]): Unit = {
if (args.length != 3) {
Expand Down

0 comments on commit bc35f57

Please sign in to comment.