From 007298bd1916f1cede3a026313bf56ab90a10a8d Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sat, 16 Aug 2014 20:18:10 -0700
Subject: [PATCH] Allow environment variables to be mocked in tests.

---
 .../main/scala/org/apache/spark/SparkConf.scala    |  6 ++++++
 .../scala/org/apache/spark/executor/Executor.scala | 14 +++++++-------
 2 files changed, 13 insertions(+), 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 13f0bff7ee507..a7f3810911bbd 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -210,6 +210,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
     new SparkConf(false).setAll(settings)
   }
 
+  /**
+   * By using this instead of System.getenv(), environment variables can be mocked
+   * in unit tests.
+   */
+  private[spark] def getenv(name: String): String = System.getenv(name)
+
   /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not
     * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */
   private[spark] def validateSettings() {
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index fb3f7bd54bbfa..f25c8167fd3b6 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -66,10 +66,10 @@ private[spark] class Executor(
   // to what Yarn on this system said was available. This will be used later when SparkEnv
   // created.
   if (java.lang.Boolean.valueOf(
-      System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))) {
-    conf.set("spark.local.dir", getYarnLocalDirs())
-  } else if (sys.env.contains("SPARK_LOCAL_DIRS")) {
-    conf.set("spark.local.dir", sys.env("SPARK_LOCAL_DIRS"))
+      System.getProperty("SPARK_YARN_MODE", conf.getenv("SPARK_YARN_MODE")))) {
+    conf.set("spark.local.dir", getYarnLocalDirs(conf))
+  } else if (conf.getenv("SPARK_LOCAL_DIRS") != null) {
+    conf.set("spark.local.dir", conf.getenv("SPARK_LOCAL_DIRS"))
   }
 
   if (!isLocal) {
@@ -135,12 +135,12 @@ private[spark] class Executor(
   }
 
   /** Get the Yarn approved local directories. */
-  private def getYarnLocalDirs(): String = {
+  private def getYarnLocalDirs(conf: SparkConf): String = {
     // Hadoop 0.23 and 2.x have different Environment variable names for the
     // local dirs, so lets check both. We assume one of the 2 is set.
     // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
-    val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
-      .getOrElse(Option(System.getenv("LOCAL_DIRS"))
+    val localDirs = Option(conf.getenv("YARN_LOCAL_DIRS"))
+      .getOrElse(Option(conf.getenv("LOCAL_DIRS"))
      .getOrElse(""))
 
     if (localDirs.isEmpty) {
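
For context, a minimal sketch of how a test could use the new getenv hook. The class name MockEnvSparkConf is an illustrative assumption and is not part of this patch; the sketch only relies on the private[spark] SparkConf.getenv method added above, so it must live under the org.apache.spark package to override that member.

package org.apache.spark

// Hypothetical test double (not part of this patch): a SparkConf whose
// getenv() reads from a supplied map instead of the real JVM environment.
class MockEnvSparkConf(env: Map[String, String])
  extends SparkConf(loadDefaults = false) {

  // Mirror System.getenv's contract: return null when the variable is unset.
  override private[spark] def getenv(name: String): String =
    env.getOrElse(name, null)
}

// Example usage in a test body:
//   val conf = new MockEnvSparkConf(Map("SPARK_LOCAL_DIRS" -> "/tmp/spark-test"))
//   assert(conf.getenv("SPARK_LOCAL_DIRS") == "/tmp/spark-test")
//   assert(conf.getenv("UNSET_VAR") == null)

Because Executor now reads environment variables through the conf it is given (conf.getenv rather than System.getenv or sys.env), passing such a conf lets a test exercise the YARN and SPARK_LOCAL_DIRS branches without actually setting process-level environment variables.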