From 2a93f2810452b6c4b9ded2aa9bb2c183f930025b Mon Sep 17 00:00:00 2001 From: Tom Wiseman Date: Mon, 13 Nov 2023 13:19:39 -0500 Subject: [PATCH] [WX-1260] Acquire sas token from task runner (#7241) Co-authored-by: Adam Nichols --- .../StandardAsyncExecutionActor.scala | 9 +- .../blob/BlobFileSystemManager.scala | 70 ++++++-- .../filesystems/blob/BlobPathBuilder.scala | 5 +- .../WorkspaceManagerApiClientProvider.scala | 2 + ...cpBatchAsyncBackendJobExecutionActor.scala | 6 +- ...inesApiAsyncBackendJobExecutionActor.scala | 4 +- .../TesAsyncBackendJobExecutionActor.scala | 161 ++++++++++++++++-- .../impl/tes/TesRuntimeAttributes.scala | 39 ++++- .../cromwell/backend/impl/tes/TesTask.scala | 42 +++-- ...TesAsyncBackendJobExecutionActorSpec.scala | 155 +++++++++++++++++ .../impl/tes/TesInitializationActorSpec.scala | 4 +- .../impl/tes/TesRuntimeAttributesSpec.scala | 12 ++ .../backend/impl/tes/TesTaskSpec.scala | 1 + 13 files changed, 455 insertions(+), 55 deletions(-) create mode 100644 supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesAsyncBackendJobExecutionActorSpec.scala diff --git a/backend/src/main/scala/cromwell/backend/standard/StandardAsyncExecutionActor.scala b/backend/src/main/scala/cromwell/backend/standard/StandardAsyncExecutionActor.scala index c98e429c63d..59eb7f08269 100644 --- a/backend/src/main/scala/cromwell/backend/standard/StandardAsyncExecutionActor.scala +++ b/backend/src/main/scala/cromwell/backend/standard/StandardAsyncExecutionActor.scala @@ -1,7 +1,6 @@ package cromwell.backend.standard import java.io.IOException - import akka.actor.{Actor, ActorLogging, ActorRef} import akka.event.LoggingReceive import cats.implicits._ @@ -329,7 +328,7 @@ trait StandardAsyncExecutionActor } /** Any custom code that should be run within commandScriptContents before the instantiated command. */ - def scriptPreamble: String = "" + def scriptPreamble: ErrorOr[String] = "".valid def cwd: Path = commandDirectory def rcPath: Path = cwd./(jobPaths.returnCodeFilename) @@ -427,10 +426,12 @@ trait StandardAsyncExecutionActor |find . -type d -exec sh -c '[ -z "$$(ls -A '"'"'{}'"'"')" ] && touch '"'"'{}'"'"'/.file' \\; |)""".stripMargin) + val errorOrPreamble: ErrorOr[String] = scriptPreamble + // The `tee` trickery below is to be able to redirect to known filenames for CWL while also streaming // stdout and stderr for PAPI to periodically upload to cloud storage. // https://stackoverflow.com/questions/692000/how-do-i-write-stderr-to-a-file-while-using-tee-with-a-pipe - (errorOrDirectoryOutputs, errorOrGlobFiles).mapN((directoryOutputs, globFiles) => + (errorOrDirectoryOutputs, errorOrGlobFiles, errorOrPreamble).mapN((directoryOutputs, globFiles, preamble) => s"""|#!$jobShell |DOCKER_OUTPUT_DIR_LINK |cd ${cwd.pathAsString} @@ -464,7 +465,7 @@ trait StandardAsyncExecutionActor |) |mv $rcTmpPath $rcPath |""".stripMargin - .replace("SCRIPT_PREAMBLE", scriptPreamble) + .replace("SCRIPT_PREAMBLE", preamble) .replace("ENVIRONMENT_VARIABLES", environmentVariables) .replace("INSTANTIATED_COMMAND", commandString) .replace("SCRIPT_EPILOGUE", scriptEpilogue) diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala index e3de6783d85..8f03dbe7e33 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobFileSystemManager.scala @@ -15,6 +15,7 @@ import java.nio.file.spi.FileSystemProvider import java.time.temporal.ChronoUnit import java.time.{Duration, OffsetDateTime} import java.util.UUID +import scala.collection.mutable import scala.jdk.CollectionConverters._ import scala.util.{Failure, Success, Try} @@ -160,12 +161,14 @@ object BlobSasTokenGenerator { */ def createBlobTokenGenerator(workspaceManagerClient: WorkspaceManagerApiClientProvider, overrideWsmAuthToken: Option[String]): BlobSasTokenGenerator = { - WSMBlobSasTokenGenerator(workspaceManagerClient, overrideWsmAuthToken) + new WSMBlobSasTokenGenerator(workspaceManagerClient, overrideWsmAuthToken) } } -case class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClientProvider, +case class WSMTerraCoordinates(wsmEndpoint: String, workspaceId: UUID, containerResourceId: UUID) + +class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClientProvider, overrideWsmAuthToken: Option[String]) extends BlobSasTokenGenerator { /** @@ -178,17 +181,14 @@ case class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClient * @return an AzureSasCredential for accessing a blob container */ def generateBlobSasToken(endpoint: EndpointURL, container: BlobContainerName): Try[AzureSasCredential] = { - val wsmAuthToken: Try[String] = overrideWsmAuthToken match { - case Some(t) => Success(t) - case None => AzureCredentials.getAccessToken(None).toTry - } + val wsmAuthToken: Try[String] = getWsmAuth container.workspaceId match { // If this is a Terra workspace, request a token from WSM case Success(workspaceId) => { (for { wsmAuth <- wsmAuthToken wsmAzureResourceClient = wsmClientProvider.getControlledAzureResourceApi(wsmAuth) - resourceId <- getContainerResourceId(workspaceId, container, wsmAuth) + resourceId <- getContainerResourceId(workspaceId, container, Option(wsmAuth)) sasToken <- wsmAzureResourceClient.createAzureStorageContainerSasToken(workspaceId, resourceId) } yield sasToken).recoverWith { // If the storage account was still not found in WSM, this may be a public filesystem @@ -201,9 +201,59 @@ case class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClient } } - def getContainerResourceId(workspaceId: UUID, container: BlobContainerName, wsmAuth : String): Try[UUID] = { - val wsmResourceClient = wsmClientProvider.getResourceApi(wsmAuth) - wsmResourceClient.findContainerResourceId(workspaceId, container) + private val cachedContainerResourceIds = new mutable.HashMap[BlobContainerName, UUID]() + + // Optionally provide wsmAuth to avoid acquiring it twice in generateBlobSasToken. + // In the case that the resourceId is not cached and no auth is provided, this function will acquire a new auth as necessary. + private def getContainerResourceId(workspaceId: UUID, container: BlobContainerName, precomputedWsmAuth: Option[String]): Try[UUID] = { + cachedContainerResourceIds.get(container) match { + case Some(id) => Try(id) //cache hit + case _ => { //cache miss + val auth: Try[String] = precomputedWsmAuth.map(auth => Try(auth)).getOrElse(getWsmAuth) + val resourceId = for { + wsmAuth <- auth + wsmResourceApi = wsmClientProvider.getResourceApi(wsmAuth) + resourceId <- wsmResourceApi.findContainerResourceId(workspaceId, container) + } yield resourceId + resourceId.map(id => cachedContainerResourceIds.put(container, id)) //NB: Modifying cache state here. + cachedContainerResourceIds.get(container) match { + case Some(uuid) => Try(uuid) + case _ => Failure(new NoSuchElementException("Could not retrieve container resource ID from WSM")) + } + } + } + } + + private def getWsmAuth: Try[String] = { + overrideWsmAuthToken match { + case Some(t) => Success(t) + case None => AzureCredentials.getAccessToken(None).toTry + } + } + + private def parseTerraWorkspaceIdFromPath(blobPath: BlobPath): Try[UUID] = { + if (blobPath.container.value.startsWith("sc-")) Try(UUID.fromString(blobPath.container.value.substring(3))) + else Failure(new Exception("Could not parse workspace ID from storage container. Are you sure this is a file in a Terra Workspace?")) + } + + /** + * Return a REST endpoint that will reply with a sas token for the blob storage container associated with the provided blob path. + * @param blobPath A blob path of a file living in a blob container that WSM knows about (likely a workspace container). + * @param tokenDuration How long will the token last after being generated. Default is 8 hours. Sas tokens won't last longer than 24h. + * NOTE: If a blobPath is provided for a file in a container other than what this token generator was constructed for, + * this function will make two REST requests. Otherwise, the relevant data is already cached locally. + */ + def getWSMSasFetchEndpoint(blobPath: BlobPath, tokenDuration: Option[Duration] = None): Try[String] = { + val wsmEndpoint = wsmClientProvider.getBaseWorkspaceManagerUrl + val lifetimeQueryParameters: String = tokenDuration.map(d => s"?sasExpirationDuration=${d.toSeconds.intValue}").getOrElse("") + val terraInfo: Try[WSMTerraCoordinates] = for { + workspaceId <- parseTerraWorkspaceIdFromPath(blobPath) + containerResourceId <- getContainerResourceId(workspaceId, blobPath.container, None) + coordinates = WSMTerraCoordinates(wsmEndpoint, workspaceId, containerResourceId) + } yield coordinates + terraInfo.map{terraCoordinates => + s"${terraCoordinates.wsmEndpoint}/api/workspaces/v1/${terraCoordinates.workspaceId.toString}/resources/controlled/azure/storageContainer/${terraCoordinates.containerResourceId.toString}/getSasToken${lifetimeQueryParameters}" + } } } diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala index 3aa26eb3c11..3acb99857e0 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/BlobPathBuilder.scala @@ -185,6 +185,9 @@ case class BlobPath private[blob](pathString: String, endpoint: EndpointURL, con * @return Path string relative to the container root. */ def pathWithoutContainer : String = pathString - + + def getFilesystemManager: BlobFileSystemManager = fsm + override def getSymlinkSafePath(options: LinkOption*): Path = toAbsolutePath + } diff --git a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/WorkspaceManagerApiClientProvider.scala b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/WorkspaceManagerApiClientProvider.scala index 276738c98b6..490d0fcc704 100644 --- a/filesystems/blob/src/main/scala/cromwell/filesystems/blob/WorkspaceManagerApiClientProvider.scala +++ b/filesystems/blob/src/main/scala/cromwell/filesystems/blob/WorkspaceManagerApiClientProvider.scala @@ -20,6 +20,7 @@ import scala.util.Try trait WorkspaceManagerApiClientProvider { def getControlledAzureResourceApi(token: String): WsmControlledAzureResourceApi def getResourceApi(token: String): WsmResourceApi + def getBaseWorkspaceManagerUrl: String } class HttpWorkspaceManagerClientProvider(baseWorkspaceManagerUrl: WorkspaceManagerURL) extends WorkspaceManagerApiClientProvider { @@ -40,6 +41,7 @@ class HttpWorkspaceManagerClientProvider(baseWorkspaceManagerUrl: WorkspaceManag apiClient.setAccessToken(token) WsmControlledAzureResourceApi(new ControlledAzureResourceApi(apiClient)) } + def getBaseWorkspaceManagerUrl: String = baseWorkspaceManagerUrl.value } case class WsmResourceApi(resourcesApi : ResourceApi) { diff --git a/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/GcpBatchAsyncBackendJobExecutionActor.scala b/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/GcpBatchAsyncBackendJobExecutionActor.scala index ae228ad503b..766a8f2552f 100644 --- a/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/GcpBatchAsyncBackendJobExecutionActor.scala +++ b/supportedBackends/google/batch/src/main/scala/cromwell/backend/google/batch/actors/GcpBatchAsyncBackendJobExecutionActor.scala @@ -663,12 +663,12 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar private val DockerMonitoringLogPath: Path = GcpBatchWorkingDisk.MountPoint.resolve(gcpBatchCallPaths.batchMonitoringLogFilename) private val DockerMonitoringScriptPath: Path = GcpBatchWorkingDisk.MountPoint.resolve(gcpBatchCallPaths.batchMonitoringScriptFilename) - override def scriptPreamble: String = { + override def scriptPreamble: ErrorOr[String] = { if (monitoringOutput.isDefined) { s"""|touch $DockerMonitoringLogPath |chmod u+x $DockerMonitoringScriptPath - |$DockerMonitoringScriptPath > $DockerMonitoringLogPath &""".stripMargin - } else "" + |$DockerMonitoringScriptPath > $DockerMonitoringLogPath &""".stripMargin.valid + } else "".valid } private[actors] def generateInputs(jobDescriptor: BackendJobDescriptor): Set[GcpBatchInput] = { diff --git a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PipelinesApiAsyncBackendJobExecutionActor.scala b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PipelinesApiAsyncBackendJobExecutionActor.scala index 745e11bee35..942838f8125 100644 --- a/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PipelinesApiAsyncBackendJobExecutionActor.scala +++ b/supportedBackends/google/pipelines/common/src/main/scala/cromwell/backend/google/pipelines/common/PipelinesApiAsyncBackendJobExecutionActor.scala @@ -380,12 +380,12 @@ class PipelinesApiAsyncBackendJobExecutionActor(override val standardParams: Sta private lazy val isDockerImageCacheUsageRequested = runtimeAttributes.useDockerImageCache.getOrElse(useDockerImageCache(jobDescriptor.workflowDescriptor)) - override def scriptPreamble: String = { + override def scriptPreamble: ErrorOr[String] = { if (monitoringOutput.isDefined) { s"""|touch $DockerMonitoringLogPath |chmod u+x $DockerMonitoringScriptPath |$DockerMonitoringScriptPath > $DockerMonitoringLogPath &""".stripMargin - } else "" + }.valid else "".valid } override def globParentDirectory(womGlobFile: WomGlobFile): Path = { diff --git a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesAsyncBackendJobExecutionActor.scala b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesAsyncBackendJobExecutionActor.scala index ad8daca6ec7..100ed6137e9 100644 --- a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesAsyncBackendJobExecutionActor.scala +++ b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesAsyncBackendJobExecutionActor.scala @@ -1,10 +1,5 @@ package cromwell.backend.impl.tes -import common.exception.AggregatedMessageException - -import java.io.FileNotFoundException -import java.nio.file.FileAlreadyExistsException -import cats.syntax.apply._ import akka.http.scaladsl.Http import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ import akka.http.scaladsl.marshalling.Marshal @@ -13,23 +8,31 @@ import akka.http.scaladsl.model._ import akka.http.scaladsl.unmarshalling.{Unmarshal, Unmarshaller} import akka.stream.ActorMaterializer import akka.util.ByteString +import cats.implicits._ +import common.collections.EnhancedCollections._ +import common.exception.AggregatedMessageException import common.validation.ErrorOr.ErrorOr import common.validation.Validation._ import cromwell.backend.BackendJobLifecycleActor import cromwell.backend.async.{AbortedExecutionHandle, ExecutionHandle, FailedNonRetryableExecutionHandle, PendingExecutionHandle} +import cromwell.backend.impl.tes.TesAsyncBackendJobExecutionActor.{determineWSMSasEndpointFromInputs, generateLocalizedSasScriptPreamble} import cromwell.backend.impl.tes.TesResponseJsonFormatter._ import cromwell.backend.standard.{StandardAsyncExecutionActor, StandardAsyncExecutionActorParams, StandardAsyncJob} +import cromwell.core.logging.JobLogger import cromwell.core.path.{DefaultPathBuilder, Path} -import cromwell.core.retry.SimpleExponentialBackoff import cromwell.core.retry.Retry._ -import cromwell.filesystems.blob.BlobPath +import cromwell.core.retry.SimpleExponentialBackoff +import cromwell.filesystems.blob.{BlobPath, WSMBlobSasTokenGenerator} import cromwell.filesystems.drs.{DrsPath, DrsResolver} -import wom.values.WomFile import net.ceedubs.ficus.Ficus._ +import wom.values.WomFile +import java.io.FileNotFoundException +import java.nio.file.FileAlreadyExistsException +import java.time.Duration +import java.time.temporal.ChronoUnit import scala.concurrent.Future -import scala.util.{Failure, Success} - +import scala.util.{Failure, Success, Try} sealed trait TesRunStatus { def isTerminal: Boolean def sysLogs: Seq[String] = Seq.empty[String] @@ -59,6 +62,110 @@ case object Cancelled extends TesRunStatus { object TesAsyncBackendJobExecutionActor { val JobIdKey = "tes_job_id" + + def generateLocalizedSasScriptPreamble(environmentVariableName: String, getSasWsmEndpoint: String) : String = { + // BEARER_TOKEN: https://learn.microsoft.com/en-us/azure/active-directory/managed-identities-azure-resources/how-to-use-vm-token#get-a-token-using-http + // NB: Scala string interpolation and bash variable substitution use similar syntax. $$ is an escaped $, much like \\ is an escaped \. + s""" + |### BEGIN ACQUIRE LOCAL SAS TOKEN ### + |# Function to check if a command exists on this machine + |command_exists() { + | command -v "$$1" > /dev/null 2>&1 + |} + | + |# Check if curl exists; install if not + |if ! command_exists curl; then + | if command_exists apt-get; then + | apt-get -y update && apt-get -y install curl + | if [ $$? -ne 0 ]; then + | echo "Error: Failed to install curl via apt-get." + | exit 1 + | fi + | else + | echo "Error: apt-get is not available, and curl is not installed." + | exit 1 + | fi + |fi + | + |# Check if jq exists; install if not + |if ! command_exists jq; then + | if command_exists apt-get; then + | apt-get -y update && apt-get -y install jq + | if [ $$? -ne 0 ]; then + | echo "Error: Failed to install jq via apt-get." + | exit 1 + | fi + | else + | echo "Error: apt-get is not available, and jq is not installed." + | exit 1 + | fi + |fi + | + |# Acquire bearer token, relying on the User Assigned Managed Identity of this VM. + |echo Acquiring Bearer Token using User Assigned Managed Identity... + |BEARER_TOKEN=$$(curl 'http://169.254.169.254/metadata/identity/oauth2/token?api-version=2018-02-01&resource=https%3A%2F%2Fmanagement.azure.com%2F' -H Metadata:true -s | jq .access_token) + | + |# Remove the leading and trailing quotes + |BEARER_TOKEN="$${BEARER_TOKEN#\\"}" + |BEARER_TOKEN="$${BEARER_TOKEN%\\"}" + | + |# Use the precomputed endpoint from cromwell + WSM to acquire a sas token + |echo Requesting sas token from WSM... + |sas_response_json=$$(curl -s \\ + | --retry 3 \\ + | --retry-delay 2 \\ + | -X POST "$getSasWsmEndpoint" \\ + | -H "Content-Type: application/json" \\ + | -H "accept: */*" \\ + | -H "Authorization: Bearer $${BEARER_TOKEN}") + | + |# Store token as environment variable + |export $environmentVariableName=$$(echo "$${sas_response_json}" | jq -r '.token') + | + |# Echo the first characters for logging/debugging purposes. "null" indicates something went wrong. + |echo Saving sas token: $${$environmentVariableName:0:4}**** to environment variable $environmentVariableName... + |### END ACQUIRE LOCAL SAS TOKEN ### + |""".stripMargin + } + + private def maybeConvertToBlob(pathToTest: Try[Path]): Try[BlobPath] = { + pathToTest.collect { case blob: BlobPath => blob } + } + + /** + * Computes an endpoint that can be used to retrieve a sas token for a particular blob storage container. + * This assumes that some of the task inputs are blob files, all blob files are in the same container, and we can get a sas + * token for this container from WSM. + * The task VM will use the user assigned managed identity that it is running as in order to authenticate. + * @param taskInputs The inputs to this particular TesTask. If any are blob files, the first will be used to + * determine the storage container to retrieve the sas token for. + * @param pathGetter A function to convert string filepath into a cromwell Path object. + * @param blobConverter A function to convert a Path into a Blob path, if possible. Provided for testing purposes. + * @return A URL endpoint that, when called with proper authentication, will return a sas token. + * Returns 'None' if one should not be used for this task. + */ + def determineWSMSasEndpointFromInputs(taskInputs: List[Input], + pathGetter: String => Try[Path], + logger: JobLogger, + blobConverter: Try[Path] => Try[BlobPath] = maybeConvertToBlob): Try[String] = { + // Collect all of the inputs that are valid blob paths + val blobFiles = taskInputs + .collect{ case Input(_, _, Some(url), _, _, _) => blobConverter(pathGetter(url)) } + .collect{ case Success(blob) => blob } + + // Log if not all input files live in the same container. + if (blobFiles.map(_.container).distinct.size > 1) { + logger.info(s"While parsing blob inputs, found more than one container. Generating SAS token based on first file in the list.") + } + + // We use the first blob file in the list to determine the correct blob container. + blobFiles.headOption.map{blobPath => + blobPath.getFilesystemManager.blobTokenGenerator match { + case wsmGenerator: WSMBlobSasTokenGenerator => wsmGenerator.getWSMSasFetchEndpoint(blobPath, Some(Duration.of(24, ChronoUnit.HOURS))) + case _ => Failure(new UnsupportedOperationException("Blob file does not have an associated WSMBlobSasTokenGenerator")) + } + }.getOrElse(Failure(new NoSuchElementException("Could not infer blob storage container from task inputs: No valid blob files provided."))) + } } class TesAsyncBackendJobExecutionActor(override val standardParams: StandardAsyncExecutionActorParams) @@ -71,7 +178,6 @@ class TesAsyncBackendJobExecutionActor(override val standardParams: StandardAsyn override type StandardAsyncRunState = TesRunStatus def statusEquivalentTo(thiz: StandardAsyncRunState)(that: StandardAsyncRunState): Boolean = thiz == that - override lazy val pollBackOff: SimpleExponentialBackoff = tesConfiguration.pollBackoff override lazy val executeOrRecoverBackOff: SimpleExponentialBackoff = tesConfiguration.executeOrRecoverBackoff @@ -90,6 +196,38 @@ class TesAsyncBackendJobExecutionActor(override val standardParams: StandardAsyn ) } + /** + * This script preamble is bash code that is executed at the start of a task inside the user's container. + * It is executed directly before the user's instantiated command is, which gives cromwell a chance to adjust the + * container environment before the actual task runs. See commandScriptContents in StandardAsyncExecutionActor for more context. + * + * For TES tasks, we sometimes want to acquire and save an azure sas token to an environment variable. + * If the user provides a value for runtimeAttributes.localizedSasEnvVar, we will add the relevant bash code to the preamble + * that acquires/exports the sas token to an environment variable. Once there, it will be visible to the user's task code. + * + * If runtimeAttributes.localizedSasEnvVar is provided in the WDL (and determineWSMSasEndpointFromInputs is successful), + * we will export the sas token to an environment variable named to be the value of runtimeAttributes.localizedSasEnvVar. + * Otherwise, we won't alter the preamble. + * + * See determineWSMSasEndpointFromInputs to see how we use taskInputs to infer *which* container to get a sas token for. + * + * @return Bash code to run at the start of a task. + */ + override def scriptPreamble: ErrorOr[String] = { + runtimeAttributes.localizedSasEnvVar match { + case Some(environmentVariableName) => { // Case: user wants a sas token. Return the computed preamble or die trying. + val workflowName = workflowDescriptor.callable.name + val callInputFiles = jobDescriptor.fullyQualifiedInputs.safeMapValues { + _.collectAsSeq { case w: WomFile => w } + } + val taskInputs: List[Input] = TesTask.buildTaskInputs(callInputFiles, workflowName, mapCommandLineWomFile) + val computedEndpoint = determineWSMSasEndpointFromInputs(taskInputs, getPath, jobLogger) + computedEndpoint.map(endpoint => generateLocalizedSasScriptPreamble(environmentVariableName, endpoint)) + }.toErrorOr + case _ => "".valid // Case: user doesn't want a sas token. Empty preamble is the correct preamble. + } + } + override def mapCommandLineWomFile(womFile: WomFile): WomFile = { womFile.mapFile(value => (getPath(value), asAdHocFile(womFile)) match { @@ -173,7 +311,6 @@ class TesAsyncBackendJobExecutionActor(override val standardParams: StandardAsyn } override def executeAsync(): Future[ExecutionHandle] = { - // create call exec dir tesJobPaths.callExecutionRoot.createPermissionedDirectories() val taskMessageFuture = createTaskMessage().fold( diff --git a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesRuntimeAttributes.scala b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesRuntimeAttributes.scala index c5b3c4df66d..48ade7b234a 100644 --- a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesRuntimeAttributes.scala +++ b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesRuntimeAttributes.scala @@ -1,5 +1,6 @@ package cromwell.backend.impl.tes +import cats.data.Validated import cats.syntax.validated._ import com.typesafe.config.Config import common.validation.ErrorOr.ErrorOr @@ -15,6 +16,8 @@ import wom.format.MemorySize import wom.types.{WomIntegerType, WomStringType} import wom.values._ +import java.util.regex.Pattern + case class TesRuntimeAttributes(continueOnReturnCode: ContinueOnReturnCode, dockerImage: String, dockerWorkingDir: Option[String], @@ -23,13 +26,14 @@ case class TesRuntimeAttributes(continueOnReturnCode: ContinueOnReturnCode, memory: Option[MemorySize], disk: Option[MemorySize], preemptible: Boolean, + localizedSasEnvVar: Option[String], backendParameters: Map[String, Option[String]]) object TesRuntimeAttributes { - val DockerWorkingDirKey = "dockerWorkingDir" val DiskSizeKey = "disk" val PreemptibleKey = "preemptible" + val LocalizedSasKey = "azureSasEnvironmentVariable" private def cpuValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[Int Refined Positive] = CpuValidation.optional @@ -47,8 +51,8 @@ object TesRuntimeAttributes { private val dockerValidation: RuntimeAttributesValidation[String] = DockerValidation.instance private val dockerWorkingDirValidation: OptionalRuntimeAttributesValidation[String] = DockerWorkingDirValidation.optional - private def preemptibleValidation(runtimeConfig: Option[Config]) = PreemptibleValidation.default(runtimeConfig) + private def localizedSasValidation: OptionalRuntimeAttributesValidation[String] = LocalizedSasValidation.optional def runtimeAttributesBuilder(backendRuntimeConfig: Option[Config]): StandardValidatedRuntimeAttributesBuilder = // !! NOTE !! If new validated attributes are added to TesRuntimeAttributes, be sure to include @@ -62,6 +66,7 @@ object TesRuntimeAttributes { dockerValidation, dockerWorkingDirValidation, preemptibleValidation(backendRuntimeConfig), + localizedSasValidation ) def makeBackendParameters(runtimeAttributes: Map[String, WomValue], @@ -124,8 +129,10 @@ object TesRuntimeAttributes { RuntimeAttributesValidation.extract(failOnStderrValidation(backendRuntimeConfig), validatedRuntimeAttributes) val continueOnReturnCode: ContinueOnReturnCode = RuntimeAttributesValidation.extract(continueOnReturnCodeValidation(backendRuntimeConfig), validatedRuntimeAttributes) - val preemptible: Boolean = + val preemptible: Boolean = { RuntimeAttributesValidation.extract(preemptibleValidation(backendRuntimeConfig), validatedRuntimeAttributes) + } + val localizedSas: Option[String] = RuntimeAttributesValidation.extractOption(localizedSasValidation.key, validatedRuntimeAttributes) // !! NOTE !! If new validated attributes are added to TesRuntimeAttributes, be sure to include // their validations here so that they will be handled correctly with backendParameters. @@ -139,7 +146,8 @@ object TesRuntimeAttributes { diskSizeCompatValidation(backendRuntimeConfig), failOnStderrValidation(backendRuntimeConfig), continueOnReturnCodeValidation(backendRuntimeConfig), - preemptibleValidation(backendRuntimeConfig) + preemptibleValidation(backendRuntimeConfig), + localizedSasValidation ) // BT-458 any strings included in runtime attributes that aren't otherwise used should be @@ -156,6 +164,7 @@ object TesRuntimeAttributes { memory, disk, preemptible, + localizedSas, backendParameters ) } @@ -218,3 +227,25 @@ class PreemptibleValidation extends BooleanRuntimeAttributesValidation(TesRuntim override protected def missingValueMessage: String = s"Expecting $key runtime attribute to be an Integer, Boolean, or a String with values of 'true' or 'false'" } + +object LocalizedSasValidation { + lazy val instance: RuntimeAttributesValidation[String] = new LocalizedSasValidation + lazy val optional: OptionalRuntimeAttributesValidation[String] = instance.optional +} + +class LocalizedSasValidation extends StringRuntimeAttributesValidation(TesRuntimeAttributes.LocalizedSasKey) { + private def isValidBashVariableName(str: String): Boolean = { + // require string be only letters, numbers, and underscores + val pattern = Pattern.compile("^[a-zA-Z0-9_]+$", Pattern.CASE_INSENSITIVE) + val matcher = pattern.matcher(str) + matcher.find + } + + override protected def invalidValueMessage(value: WomValue): String = { + s"Invalid Runtime Attribute value for ${TesRuntimeAttributes.LocalizedSasKey}. Value must be a string containing only letters, numbers, and underscores." + } + + override protected def validateValue: PartialFunction[WomValue, ErrorOr[String]] = { + case WomString(value) => if(isValidBashVariableName(value)) value.validNel else Validated.invalidNel(invalidValueMessage(WomString(value))) + } +} diff --git a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala index 4a6e77641e3..d775367ac74 100644 --- a/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala +++ b/supportedBackends/tes/src/main/scala/cromwell/backend/impl/tes/TesTask.scala @@ -16,6 +16,8 @@ import wom.callable.Callable.OutputDefinition import wom.expression.NoIoFunctionSet import wom.values._ +import scala.collection.immutable.Map + final case class WorkflowExecutionIdentityConfig(value: String) {override def toString: String = value.toString} final case class WorkflowExecutionIdentityOption(value: String) {override def toString: String = value} final case class TesTask(jobDescriptor: BackendJobDescriptor, @@ -79,24 +81,7 @@ final case class TesTask(jobDescriptor: BackendJobDescriptor, } lazy val inputs: Seq[Input] = { - val result = (callInputFiles ++ writeFunctionFiles).flatMap { - case (fullyQualifiedName, files) => files.flatMap(_.flattenFiles).zipWithIndex.map { - case (f, index) => - val inputType = f match { - case _: WomUnlistedDirectory => "DIRECTORY" - case _: WomSingleFile => "FILE" - case _: WomGlobFile => "FILE" - } - Input( - name = Option(fullyQualifiedName + "." + index), - description = Option(workflowName + "." + fullyQualifiedName + "." + index), - url = Option(f.value), - path = mapCommandLineWomFile(f).value, - `type` = Option(inputType), - content = None - ) - } - }.toList ++ Seq(commandScript) + val result = TesTask.buildTaskInputs(callInputFiles ++ writeFunctionFiles, workflowName, mapCommandLineWomFile) ++ Seq(commandScript) jobLogger.info(s"Calculated TES inputs (found ${result.size}): " + result.mkString(System.lineSeparator(),System.lineSeparator(),System.lineSeparator())) result } @@ -288,6 +273,27 @@ object TesTask { ) } + def buildTaskInputs(taskFiles: Map[FullyQualifiedName, Seq[WomFile]], workflowName: String, womMapFn: WomFile => WomFile): List[Input] = { + taskFiles.flatMap { + case (fullyQualifiedName, files) => files.flatMap(_.flattenFiles).zipWithIndex.map { + case (f, index) => + val inputType = f match { + case _: WomUnlistedDirectory => "DIRECTORY" + case _: WomSingleFile => "FILE" + case _: WomGlobFile => "FILE" + } + Input( + name = Option(fullyQualifiedName + "." + index), + description = Option(workflowName + "." + fullyQualifiedName + "." + index), + url = Option(f.value), + path = womMapFn(f).value, + `type` = Option(inputType), + content = None + ) + } + }.toList + } + def makeTags(workflowDescriptor: BackendWorkflowDescriptor): Map[String, Option[String]] = { // In addition to passing through any workflow labels, include relevant workflow ids as tags. val baseTags = workflowDescriptor.customLabels.asMap.map { case (k, v) => (k, Option(v)) } diff --git a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesAsyncBackendJobExecutionActorSpec.scala b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesAsyncBackendJobExecutionActorSpec.scala new file mode 100644 index 00000000000..a28fce3d445 --- /dev/null +++ b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesAsyncBackendJobExecutionActorSpec.scala @@ -0,0 +1,155 @@ +package cromwell.backend.impl.tes + +import common.mock.MockSugar +import cromwell.core.logging.JobLogger +import cromwell.core.path.NioPath +import cromwell.filesystems.blob.{BlobFileSystemManager, BlobPath, WSMBlobSasTokenGenerator} +import org.mockito.ArgumentMatchers.any +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.time.Duration +import java.time.temporal.ChronoUnit +import scala.util.{Failure, Try} + +class TesAsyncBackendJobExecutionActorSpec extends AnyFlatSpec with Matchers with MockSugar { + behavior of "TesAsyncBackendJobExecutionActor" + + val fullyQualifiedName = "this.name.is.more.than.qualified" + val workflowName = "mockWorkflow" + val someBlobUrl = "https://lz813a3d637adefec2c6e88f.blob.core.windows.net/sc-d8143fd8-aa07-446d-9ba0-af72203f1794/nyxp6c/tes-internal/configuration/supported-vm-sizes" + val someNotBlobUrl = "https://www.google.com/path/to/exile" + var index = 0 + + val blobInput_0 = Input( + name = Option(fullyQualifiedName + "." + index), + description = Option(workflowName + "." + fullyQualifiedName + "." + index), + url = Option(someBlobUrl), + path = someBlobUrl, + `type` = Option("FILE"), + content = None + ) + index = index+1 + + val blobInput_1 = Input( + name = Option(fullyQualifiedName + "." + index), + description = Option(workflowName + "." + fullyQualifiedName + "." + index), + url = Option(someBlobUrl), + path = someBlobUrl, + `type` = Option("FILE"), + content = None + ) + index = index+1 + + val notBlobInput_1 = Input( + name = Option(fullyQualifiedName + "." + index), + description = Option(workflowName + "." + fullyQualifiedName + "." + index), + url = Option(someNotBlobUrl + index), + path = someNotBlobUrl + index, + `type` = Option("FILE"), + content = None + ) + index = index+1 + + val notBlobInput_2 = Input( + name = Option(fullyQualifiedName + "." + index), + description = Option(workflowName + "." + fullyQualifiedName + "." + index), + url = Option(someNotBlobUrl + index), + path = someNotBlobUrl + index, + `type` = Option("FILE"), + content = None + ) + + // Mock blob path functionality. + val testWsmEndpoint = "https://wsm.mock.com/endpoint" + val testWorkspaceId = "e58ed763-928c-4155-0000-fdbaaadc15f3" + val testContainerResourceId = "e58ed763-928c-4155-1111-fdbaaadc15f3" + + def generateMockWsmTokenGenerator: WSMBlobSasTokenGenerator = { + val mockTokenGenerator = mock[WSMBlobSasTokenGenerator] + val expectedTokenDuration: Duration = Duration.of(24, ChronoUnit.HOURS) + mockTokenGenerator.getWSMSasFetchEndpoint(any[BlobPath], any[Option[Duration]]) returns Try(s"$testWsmEndpoint/api/workspaces/v1/$testWorkspaceId/resources/controlled/azure/storageContainer/$testContainerResourceId/getSasToken?sasExpirationDuration=${expectedTokenDuration.getSeconds.toInt}") + mockTokenGenerator + } + def generateMockFsm: BlobFileSystemManager = { + val mockFsm: BlobFileSystemManager = mock[BlobFileSystemManager] + val mockGenerator: WSMBlobSasTokenGenerator = generateMockWsmTokenGenerator + mockFsm.blobTokenGenerator returns mockGenerator + mockFsm + } + //path to a blob file + def generateMockBlobPath: BlobPath = { + val mockBlobPath = mock[BlobPath] + mockBlobPath.pathAsString returns someBlobUrl + + val mockFsm = generateMockFsm + mockBlobPath.getFilesystemManager returns mockFsm + + val mockNioPath: NioPath = mock[NioPath] + mockBlobPath.nioPath returns mockNioPath + mockBlobPath + } + + //Path to a file that isn't a blob file + def generateMockDefaultPath: cromwell.core.path.Path = { + val mockDefaultPath: cromwell.core.path.Path = mock[cromwell.core.path.Path] + mockDefaultPath.pathAsString returns someNotBlobUrl + mockDefaultPath + } + def pathGetter(pathString: String): Try[cromwell.core.path.Path] = { + val mockBlob: BlobPath = generateMockBlobPath + val mockDefault: cromwell.core.path.Path = generateMockDefaultPath + if(pathString.contains(someBlobUrl)) Try(mockBlob) else Try(mockDefault) + } + + def blobConverter(pathToConvert: Try[cromwell.core.path.Path]): Try[BlobPath] = { + val mockBlob: BlobPath = generateMockBlobPath + if(pathToConvert.get.pathAsString.contains(someBlobUrl)) Try(mockBlob) else Failure(new Exception("failed")) + } + + it should "not return sas endpoint when no blob paths are provided" in { + val mockLogger: JobLogger = mock[JobLogger] + val emptyInputs: List[Input] = List() + val bloblessInputs: List[Input] = List(notBlobInput_1, notBlobInput_2) + TesAsyncBackendJobExecutionActor.determineWSMSasEndpointFromInputs(emptyInputs, pathGetter, mockLogger, blobConverter).isFailure shouldBe true + TesAsyncBackendJobExecutionActor.determineWSMSasEndpointFromInputs(bloblessInputs, pathGetter, mockLogger, blobConverter).isFailure shouldBe true + } + + it should "return a sas endpoint based on inputs when blob paths are provided" in { + val mockLogger: JobLogger = mock[JobLogger] + val expectedTokenLifetimeSeconds = 24 * 60 * 60 //assert that cromwell asks for 24h token duration. + val expected = s"$testWsmEndpoint/api/workspaces/v1/$testWorkspaceId/resources/controlled/azure/storageContainer/$testContainerResourceId/getSasToken?sasExpirationDuration=${expectedTokenLifetimeSeconds}" + val blobInput: List[Input] = List(blobInput_0) + val blobInputs: List[Input] = List(blobInput_0, blobInput_1) + val mixedInputs: List[Input] = List(notBlobInput_1, blobInput_0, blobInput_1) + TesAsyncBackendJobExecutionActor.determineWSMSasEndpointFromInputs(blobInput, pathGetter, mockLogger, blobConverter).get shouldEqual expected + TesAsyncBackendJobExecutionActor.determineWSMSasEndpointFromInputs(blobInputs, pathGetter, mockLogger, blobConverter).get shouldEqual expected + TesAsyncBackendJobExecutionActor.determineWSMSasEndpointFromInputs(mixedInputs, pathGetter, mockLogger, blobConverter).get shouldEqual expected + } + + it should "contain expected strings in the bash script" in { + val mockEnvironmentVariableNameFromWom = "mock_env_var_for_storing_sas_token" + val expectedEndpoint = s"$testWsmEndpoint/api/workspaces/v1/$testWorkspaceId/resources/controlled/azure/storageContainer/$testContainerResourceId/getSasToken" + + val beginSubstring = "### BEGIN ACQUIRE LOCAL SAS TOKEN ###" + val endSubstring = "### END ACQUIRE LOCAL SAS TOKEN ###" + val curlCommandSubstring = + s""" + |sas_response_json=$$(curl -s \\ + | --retry 3 \\ + | --retry-delay 2 \\ + | -X POST "$expectedEndpoint" \\ + | -H "Content-Type: application/json" \\ + | -H "accept: */*" \\ + | -H "Authorization: Bearer $${BEARER_TOKEN}") + |""".stripMargin + val exportCommandSubstring = s"""export $mockEnvironmentVariableNameFromWom=$$(echo "$${sas_response_json}" | jq -r '.token')""" + + val generatedBashScript = TesAsyncBackendJobExecutionActor.generateLocalizedSasScriptPreamble(mockEnvironmentVariableNameFromWom, expectedEndpoint) + + generatedBashScript should include (beginSubstring) + generatedBashScript should include (endSubstring) + generatedBashScript should include (curlCommandSubstring) + generatedBashScript should include (exportCommandSubstring) + } +} diff --git a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesInitializationActorSpec.scala b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesInitializationActorSpec.scala index 731dd3c6c70..a081f26c910 100644 --- a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesInitializationActorSpec.scala +++ b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesInitializationActorSpec.scala @@ -63,6 +63,7 @@ class TesInitializationActorSpec extends TestKitSuite | # The keys below have been commented out as they are optional runtime attributes. | # dockerWorkingDir | # docker + | # azureSasEnvironmentVariable |} |""".stripMargin @@ -107,6 +108,7 @@ class TesInitializationActorSpec extends TestKitSuite } def nonStringErrorMessage(key: String) = s"Workflow option $key must be a string" + val bothRequiredErrorMessage = s"Workflow options ${TesWorkflowOptionKeys.WorkflowExecutionIdentity} and ${TesWorkflowOptionKeys.DataAccessIdentity} are both required if one is provided" "fail when WorkflowExecutionIdentity is not a string and DataAccessIdentity is missing" in { @@ -120,7 +122,7 @@ class TesInitializationActorSpec extends TestKitSuite case InitializationFailed(failure) => val expectedMsg = nonStringErrorMessage(TesWorkflowOptionKeys.WorkflowExecutionIdentity) if (!(failure.getMessage.contains(expectedMsg) && - failure.getMessage.contains(bothRequiredErrorMessage))) { + failure.getMessage.contains(bothRequiredErrorMessage))) { fail(s"Exception message did not contain both '$expectedMsg' and '$bothRequiredErrorMessage'. Was '$failure'") } } diff --git a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesRuntimeAttributesSpec.scala b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesRuntimeAttributesSpec.scala index e1984fb65dc..830e0cbe70c 100644 --- a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesRuntimeAttributesSpec.scala +++ b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesRuntimeAttributesSpec.scala @@ -25,6 +25,7 @@ class TesRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeoutSpec None, None, false, + None, Map.empty ) @@ -71,6 +72,17 @@ class TesRuntimeAttributesSpec extends AnyWordSpecLike with CromwellTimeoutSpec assertSuccess(runtimeAttributes, expectedRuntimeAttributes) } + "validate a valid azureSasEnvironmentVariable entry" in { + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), TesRuntimeAttributes.LocalizedSasKey -> WomString("THIS_IS_VALID")) + val expectedRuntimeAttributes = expectedDefaultsPlusUbuntuDocker.copy(localizedSasEnvVar = Some("THIS_IS_VALID")) + assertSuccess(runtimeAttributes, expectedRuntimeAttributes) + } + + "fail to validate an invalid azureSasEnvironmentVariable entry" in { + val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), TesRuntimeAttributes.LocalizedSasKey -> WomString("THIS IS INVALID")) + assertFailure(runtimeAttributes, "Value must be a string containing only letters, numbers, and underscores.") + } + "convert a positive integer preemptible entry to true boolean" in { val runtimeAttributes = Map("docker" -> WomString("ubuntu:latest"), "preemptible" -> WomInteger(3)) val expectedRuntimeAttributes = expectedDefaultsPlusUbuntuDocker.copy(preemptible = true) diff --git a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala index b7887b29944..a5fd3a3a7e2 100644 --- a/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala +++ b/supportedBackends/tes/src/test/scala/cromwell/backend/impl/tes/TesTaskSpec.scala @@ -31,6 +31,7 @@ class TesTaskSpec None, None, false, + None, Map.empty ) val internalPathPrefix = Option("mock/path/to/tes/task")