Skip to content

Commit

Permalink
[WX-1260] Acquire sas token from task runner (#7241)
Browse files Browse the repository at this point in the history
Co-authored-by: Adam Nichols <anichols@broadinstitute.org>
  • Loading branch information
THWiseman and aednichols authored Nov 13, 2023
1 parent 110ca3e commit 2a93f28
Show file tree
Hide file tree
Showing 13 changed files with 455 additions and 55 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package cromwell.backend.standard

import java.io.IOException

import akka.actor.{Actor, ActorLogging, ActorRef}
import akka.event.LoggingReceive
import cats.implicits._
Expand Down Expand Up @@ -329,7 +328,7 @@ trait StandardAsyncExecutionActor
}

/** Any custom code that should be run within commandScriptContents before the instantiated command. */
def scriptPreamble: String = ""
def scriptPreamble: ErrorOr[String] = "".valid

def cwd: Path = commandDirectory
def rcPath: Path = cwd./(jobPaths.returnCodeFilename)
Expand Down Expand Up @@ -427,10 +426,12 @@ trait StandardAsyncExecutionActor
|find . -type d -exec sh -c '[ -z "$$(ls -A '"'"'{}'"'"')" ] && touch '"'"'{}'"'"'/.file' \\;
|)""".stripMargin)

val errorOrPreamble: ErrorOr[String] = scriptPreamble

// The `tee` trickery below is to be able to redirect to known filenames for CWL while also streaming
// stdout and stderr for PAPI to periodically upload to cloud storage.
// https://stackoverflow.com/questions/692000/how-do-i-write-stderr-to-a-file-while-using-tee-with-a-pipe
(errorOrDirectoryOutputs, errorOrGlobFiles).mapN((directoryOutputs, globFiles) =>
(errorOrDirectoryOutputs, errorOrGlobFiles, errorOrPreamble).mapN((directoryOutputs, globFiles, preamble) =>
s"""|#!$jobShell
|DOCKER_OUTPUT_DIR_LINK
|cd ${cwd.pathAsString}
Expand Down Expand Up @@ -464,7 +465,7 @@ trait StandardAsyncExecutionActor
|)
|mv $rcTmpPath $rcPath
|""".stripMargin
.replace("SCRIPT_PREAMBLE", scriptPreamble)
.replace("SCRIPT_PREAMBLE", preamble)
.replace("ENVIRONMENT_VARIABLES", environmentVariables)
.replace("INSTANTIATED_COMMAND", commandString)
.replace("SCRIPT_EPILOGUE", scriptEpilogue)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import java.nio.file.spi.FileSystemProvider
import java.time.temporal.ChronoUnit
import java.time.{Duration, OffsetDateTime}
import java.util.UUID
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Success, Try}

Expand Down Expand Up @@ -160,12 +161,14 @@ object BlobSasTokenGenerator {
*/
def createBlobTokenGenerator(workspaceManagerClient: WorkspaceManagerApiClientProvider,
overrideWsmAuthToken: Option[String]): BlobSasTokenGenerator = {
WSMBlobSasTokenGenerator(workspaceManagerClient, overrideWsmAuthToken)
new WSMBlobSasTokenGenerator(workspaceManagerClient, overrideWsmAuthToken)
}

}

case class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClientProvider,
case class WSMTerraCoordinates(wsmEndpoint: String, workspaceId: UUID, containerResourceId: UUID)

class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClientProvider,
overrideWsmAuthToken: Option[String]) extends BlobSasTokenGenerator {

/**
Expand All @@ -178,17 +181,14 @@ case class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClient
* @return an AzureSasCredential for accessing a blob container
*/
def generateBlobSasToken(endpoint: EndpointURL, container: BlobContainerName): Try[AzureSasCredential] = {
val wsmAuthToken: Try[String] = overrideWsmAuthToken match {
case Some(t) => Success(t)
case None => AzureCredentials.getAccessToken(None).toTry
}
val wsmAuthToken: Try[String] = getWsmAuth
container.workspaceId match {
// If this is a Terra workspace, request a token from WSM
case Success(workspaceId) => {
(for {
wsmAuth <- wsmAuthToken
wsmAzureResourceClient = wsmClientProvider.getControlledAzureResourceApi(wsmAuth)
resourceId <- getContainerResourceId(workspaceId, container, wsmAuth)
resourceId <- getContainerResourceId(workspaceId, container, Option(wsmAuth))
sasToken <- wsmAzureResourceClient.createAzureStorageContainerSasToken(workspaceId, resourceId)
} yield sasToken).recoverWith {
// If the storage account was still not found in WSM, this may be a public filesystem
Expand All @@ -201,9 +201,59 @@ case class WSMBlobSasTokenGenerator(wsmClientProvider: WorkspaceManagerApiClient
}
}

def getContainerResourceId(workspaceId: UUID, container: BlobContainerName, wsmAuth : String): Try[UUID] = {
val wsmResourceClient = wsmClientProvider.getResourceApi(wsmAuth)
wsmResourceClient.findContainerResourceId(workspaceId, container)
// Local cache of WSM container resource IDs, keyed by blob container, so repeat lookups skip the WSM request.
// NOTE(review): mutable.HashMap is not synchronized -- confirm this generator is not shared across threads.
private val cachedContainerResourceIds = new mutable.HashMap[BlobContainerName, UUID]()

/**
  * Resolve the WSM resource ID for a blob container, consulting the local cache first.
  *
  * On a cache miss the ID is requested from WSM and, if found, stored in the cache.
  *
  * @param workspaceId        ID of the workspace the container belongs to.
  * @param container          The blob container whose resource ID is wanted.
  * @param precomputedWsmAuth Optionally provide a WSM auth token to avoid acquiring it twice in generateBlobSasToken.
  *                           If the ID is not cached and no token is provided, a new one is acquired as necessary.
  * @return Success with the resource ID, or a Failure holding a NoSuchElementException whose cause is the
  *         underlying auth or WSM lookup failure.
  */
private def getContainerResourceId(workspaceId: UUID, container: BlobContainerName, precomputedWsmAuth: Option[String]): Try[UUID] = {
  cachedContainerResourceIds.get(container) match {
    case Some(id) => Success(id) // cache hit
    case None => // cache miss
      val auth: Try[String] = precomputedWsmAuth.map(Success(_)).getOrElse(getWsmAuth)
      val lookup: Try[UUID] = for {
        wsmAuth <- auth
        wsmResourceApi = wsmClientProvider.getResourceApi(wsmAuth)
        resourceId <- wsmResourceApi.findContainerResourceId(workspaceId, container)
      } yield resourceId
      lookup match {
        case Success(id) =>
          cachedContainerResourceIds.update(container, id) // NB: Modifying cache state here.
          Success(id)
        case Failure(cause) =>
          // Keep the exception type and message callers already see, but attach the real reason
          // (auth failure, WSM error, ...) so it is not lost from logs and stack traces.
          Failure(new NoSuchElementException("Could not retrieve container resource ID from WSM").initCause(cause))
      }
  }
}

/**
  * Produce the auth token used for WSM requests.
  *
  * The override token wins when one was supplied; otherwise a token is requested
  * from AzureCredentials at call time.
  */
private def getWsmAuth: Try[String] = {
  // Declared as a def so the credential request only happens when no override exists.
  def freshToken: Try[String] = AzureCredentials.getAccessToken(None).toTry
  overrideWsmAuthToken.fold(freshToken)(Success(_))
}

/**
  * Extract a workspace ID from the name of the container holding the given blob path.
  *
  * The container name must start with "sc-"; everything after that prefix is parsed as a UUID.
  *
  * @param blobPath Path whose container name is inspected.
  * @return The parsed UUID, or a Failure when the prefix is missing or the remainder is not a valid UUID.
  */
private def parseTerraWorkspaceIdFromPath(blobPath: BlobPath): Try[UUID] = {
  val terraPrefix = "sc-"
  blobPath.container.value match {
    case containerName if containerName.startsWith(terraPrefix) =>
      Try(UUID.fromString(containerName.substring(terraPrefix.length)))
    case _ =>
      Failure(new Exception("Could not parse workspace ID from storage container. Are you sure this is a file in a Terra Workspace?"))
  }
}

/**
  * Return a REST endpoint that will reply with a sas token for the blob storage container associated with the provided blob path.
  *
  * @param blobPath      A blob path of a file living in a blob container that WSM knows about (likely a workspace container).
  * @param tokenDuration How long the token should last after being generated. When None, no `sasExpirationDuration`
  *                      query parameter is sent, so the lifetime is whatever WSM applies on its own
  *                      (NOTE(review): previously documented here as 8 hours, with a 24h maximum -- confirm against WSM).
  * @return The endpoint URL, or a Failure when the workspace ID cannot be parsed from the container name or the
  *         container resource ID cannot be resolved.
  *
  * NOTE: Resolving the container resource ID may require REST requests (auth plus WSM lookup) the first time a
  * container is seen; afterwards the ID is served from the local cache.
  */
def getWSMSasFetchEndpoint(blobPath: BlobPath, tokenDuration: Option[Duration] = None): Try[String] = {
  val wsmEndpoint = wsmClientProvider.getBaseWorkspaceManagerUrl
  // Interpolate the Long second count directly: narrowing it with intValue would silently wrap
  // for durations longer than Int.MaxValue seconds.
  val lifetimeQueryParameters: String =
    tokenDuration.fold("")(d => s"?sasExpirationDuration=${d.toSeconds}")
  for {
    workspaceId <- parseTerraWorkspaceIdFromPath(blobPath)
    containerResourceId <- getContainerResourceId(workspaceId, blobPath.container, None)
  } yield {
    val terraCoordinates = WSMTerraCoordinates(wsmEndpoint, workspaceId, containerResourceId)
    s"${terraCoordinates.wsmEndpoint}/api/workspaces/v1/${terraCoordinates.workspaceId.toString}/resources/controlled/azure/storageContainer/${terraCoordinates.containerResourceId.toString}/getSasToken${lifetimeQueryParameters}"
  }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ case class BlobPath private[blob](pathString: String, endpoint: EndpointURL, con
* @return Path string relative to the container root.
*/
def pathWithoutContainer : String = pathString


def getFilesystemManager: BlobFileSystemManager = fsm

override def getSymlinkSafePath(options: LinkOption*): Path = toAbsolutePath

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import scala.util.Try
/**
  * Supplies Workspace Manager (WSM) API clients, plus the base URL of the WSM service.
  */
trait WorkspaceManagerApiClientProvider {
// Client for controlled Azure resource operations (e.g. creating container sas tokens), using the given access token.
def getControlledAzureResourceApi(token: String): WsmControlledAzureResourceApi
// Client for WSM resource lookups (e.g. finding a container's resource ID), using the given access token.
def getResourceApi(token: String): WsmResourceApi
// Base URL of the Workspace Manager service as a plain string, for callers that build endpoint URLs themselves.
def getBaseWorkspaceManagerUrl: String
}

class HttpWorkspaceManagerClientProvider(baseWorkspaceManagerUrl: WorkspaceManagerURL) extends WorkspaceManagerApiClientProvider {
Expand All @@ -40,6 +41,7 @@ class HttpWorkspaceManagerClientProvider(baseWorkspaceManagerUrl: WorkspaceManag
apiClient.setAccessToken(token)
WsmControlledAzureResourceApi(new ControlledAzureResourceApi(apiClient))
}
def getBaseWorkspaceManagerUrl: String = baseWorkspaceManagerUrl.value
}

case class WsmResourceApi(resourcesApi : ResourceApi) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -663,12 +663,12 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
private val DockerMonitoringLogPath: Path = GcpBatchWorkingDisk.MountPoint.resolve(gcpBatchCallPaths.batchMonitoringLogFilename)
private val DockerMonitoringScriptPath: Path = GcpBatchWorkingDisk.MountPoint.resolve(gcpBatchCallPaths.batchMonitoringScriptFilename)

override def scriptPreamble: String = {
override def scriptPreamble: ErrorOr[String] = {
if (monitoringOutput.isDefined) {
s"""|touch $DockerMonitoringLogPath
|chmod u+x $DockerMonitoringScriptPath
|$DockerMonitoringScriptPath > $DockerMonitoringLogPath &""".stripMargin
} else ""
|$DockerMonitoringScriptPath > $DockerMonitoringLogPath &""".stripMargin.valid
} else "".valid
}

private[actors] def generateInputs(jobDescriptor: BackendJobDescriptor): Set[GcpBatchInput] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,12 +380,12 @@ class PipelinesApiAsyncBackendJobExecutionActor(override val standardParams: Sta

private lazy val isDockerImageCacheUsageRequested = runtimeAttributes.useDockerImageCache.getOrElse(useDockerImageCache(jobDescriptor.workflowDescriptor))

override def scriptPreamble: String = {
override def scriptPreamble: ErrorOr[String] = {
if (monitoringOutput.isDefined) {
s"""|touch $DockerMonitoringLogPath
|chmod u+x $DockerMonitoringScriptPath
|$DockerMonitoringScriptPath > $DockerMonitoringLogPath &""".stripMargin
} else ""
}.valid else "".valid
}

override def globParentDirectory(womGlobFile: WomGlobFile): Path = {
Expand Down
Loading

0 comments on commit 2a93f28

Please sign in to comment.