Skip to content

Commit

Permalink
fix token flow (apache#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
Gschiavon authored and Marcos P committed Aug 14, 2017
1 parent 8200918 commit bade4d4
Showing 1 changed file with 74 additions and 51 deletions.
125 changes: 74 additions & 51 deletions core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import java.text.ParseException
import scala.annotation.tailrec
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import scala.util.Properties

import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.fs.Path
import org.apache.hadoop.security.UserGroupInformation
Expand All @@ -41,11 +40,11 @@ import org.apache.ivy.core.settings.IvySettings
import org.apache.ivy.plugins.matcher.GlobPatternMatcher
import org.apache.ivy.plugins.repository.file.FileRepository
import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBiblioResolver}

import org.apache.spark.{SPARK_REVISION, SPARK_VERSION, SparkException, SparkUserAppException}
import org.apache.spark.{SPARK_BRANCH, SPARK_BUILD_DATE, SPARK_BUILD_USER, SPARK_REPO_URL}
import org.apache.spark.api.r.RUtils
import org.apache.spark.deploy.rest._
import org.apache.spark.internal.Logging
import org.apache.spark.launcher.SparkLauncher
import org.apache.spark.scheduler.{KerberosUser, KerberosUtil}
import org.apache.spark.security.{ConfigSecurity, VaultHelper}
Expand Down Expand Up @@ -665,64 +664,88 @@ object SparkSubmit {
}


val mesosRoleEnv = (sys.env.get("VAULT_ROLE_ID"),
sys.env.get("VAULT_SECRET_ID"))

val sparkRoleOpts = (args.sparkProperties.get("spark.secret.roleID"),
args.sparkProperties.get("spark.secret.secretID"))

val tempToken = args.sparkProperties.get("spark.secret.vault.tempToken")
val tempToken = (args.sparkProperties.get("spark.secret.vault.tempToken"),
sys.env.get("VAULT_TEMP_TOKEN")) match {
case (Some(property), env) => Option(property)
case (property, Some(env)) => Option(env)
case _ => None
}

val sysEnvToken = sys.env.get("VAULT_TEMP_TOKEN")
val roleSecret = (args.sparkProperties.get("spark.secret.roleID"),
args.sparkProperties.get("spark.secret.secretID"),
sys.env.get("VAULT_ROLE_ID"),
sys.env.get("VAULT_SECRET_ID")) match {
case (Some(roleProperty), Some(secretProperty), roleEnv, secretEnv) =>
Option(roleProperty, secretProperty)
case (roleProperty, secretProperty, Some(roleEnv), Some(secretEnv)) =>
Option(roleEnv, secretEnv)
case _ => None
}

val vaultProtocol = args.sparkProperties.get("spark.secret.vault.protocol")
val vaultHosts = args.sparkProperties.get("spark.secret.vault.hosts")
val vaultHost = args.sparkProperties.get("spark.secret.vault.hosts")
val vaultPort = args.sparkProperties.get("spark.secret.vault.port")

val (pincipal, keytab) =
(mesosRoleEnv, sparkRoleOpts, tempToken, sysEnvToken,
vaultProtocol, vaultHosts, vaultPort) match {

case ((roleIdEnv, secretIdEnv), (roleIdProp, secretIdProp), _, _,
Some(protocol), Some(hosts), Some(port))
if ((roleIdEnv.isDefined || roleIdProp.isDefined) &&
(secretIdEnv.isDefined || secretIdProp.isDefined)) =>
val vaultUrl = s"$protocol://${hosts.split(",")
.map(host => s"$host:$port").mkString(",")}"

val roleId = roleIdEnv.getOrElse(roleIdProp.get)
val secretId = secretIdEnv.getOrElse(secretIdProp.get)
val vaultToken = VaultHelper.getTokenFromAppRole (vaultUrl, roleId, secretId)
val environment = ConfigSecurity.prepareEnvironment(
Option(vaultToken), Option (vaultUrl) )
val principal = environment.get ("principal").getOrElse (args.principal)
val keytab = environment.get ("keytabPath").getOrElse (args.keytab)

environment.foreach {
case (key, value) => sysProps.put (key, value)
}
(principal, keytab)
val vaultUrlParams = (vaultProtocol, vaultHost, vaultPort)
val vaultUrl = buildVaultUrl(vaultUrlParams)
lazy val vaultToken = getToken(tempToken, roleSecret, vaultUrl)

case (_, _, tempTokenProp, tempTokenEnv, Some(protocol), Some(hosts), Some(port))
if (tempTokenProp.isDefined || tempTokenEnv.isDefined) =>
val vaultUrl = s"$protocol://${hosts.split(",")
.map(host => s"$host:${port}").mkString(",")}"
val tempToken = tempTokenProp.getOrElse(tempTokenEnv.get)
val vaultToken = VaultHelper.getRealToken (vaultUrl, tempToken)
val environment = ConfigSecurity.prepareEnvironment(
Option (vaultToken), Option (vaultUrl))
val principal = environment.get ("principal").getOrElse (args.principal)
val keytab = environment.get ("keytabPath").getOrElse (args.keytab)

environment.foreach {
case (key, value) => sysProps.put (key, value)
}
(principal, keytab)
val (principal, keytab) =
if (vaultUrl.nonEmpty && vaultToken.isDefined) {
val environment = ConfigSecurity.prepareEnvironment(
Option (vaultToken.get), Option(vaultUrl))
val principal = environment.getOrElse("principal", args.principal)
val keytab = environment.getOrElse("keytabPath", args.keytab)

environment.foreach {
case (key, value) => sysProps.put(key, value)
}
(principal, keytab)

case _ => (args.principal, args.keytab)
} else {
(args.principal, args.keytab)
}

(childArgs, childClasspath, sysProps, childMainClass, pincipal, keytab)
(childArgs, childClasspath, sysProps, childMainClass, principal, keytab)
}

/**
*
* @param tempToken Temporal token, either Property one or Environment one
* @param roleSecret Role and Secret ID, either Property one or Environment one
* @param vaultUrl a Vault Url protocol://vaultHost:vaultPort
* @return An option of a token
*/
private def getToken(tempToken: Option[String],
roleSecret: Option[(String, String)],
vaultUrl: String): Option[String] = {

(tempToken, roleSecret) match {
case (Some(tempToken), _) => Some(VaultHelper.getRealToken(vaultUrl, tempToken))
case (_, Some((role, secret))) =>
Some(VaultHelper.getTokenFromAppRole(vaultUrl, role, secret))
case _ => None
}
}

/**
*
* @param vaultUrlParams Is composed of Vault Protocol,
* Vault Host and Vault Port
* @return a Vault Url protocol://vaultHost:vaultPort
*/
private def buildVaultUrl(vaultUrlParams: (Option[String],
Option[String],
Option[String])): String = {

val vaultUrl = vaultUrlParams match {
case (Some(protocol), Some(hosts), Some(port)) =>
s"${protocol}://${
hosts.split(",")
.map(host => s"$host:${port}").mkString(",")}"
case _ => ""
}
vaultUrl
}

/**
Expand Down

0 comments on commit bade4d4

Please sign in to comment.