Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WX-1318 gcp batch: Add GPU driver install #7235

Merged
merged 3 commits into from
Oct 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -829,12 +829,7 @@ class GcpBatchAsyncBackendJobExecutionActor(override val standardParams: Standar
_ <- evaluateRuntimeAttributes
_ <- uploadScriptFile()
customLabels <- Future.fromTry(GcpLabel.fromWorkflowOptions(workflowDescriptor.workflowOptions))
_ = customLabels.foreach(x => println(s"ZZZ Custom Labels - $x"))
batchParameters <- generateInputOutputParameters
_ = batchParameters.fileInputParameters.foreach(x => println(s"ZZZ File InputParameters - $x"))
_ = batchParameters.jobInputParameters.foreach(x => println(s"ZZZ InputParameters - $x"))
_ = batchParameters.fileOutputParameters.foreach(x => println(s"ZZZ File OutputParameters - $x"))
_ = batchParameters.jobOutputParameters.foreach(x => println(s"ZZZ OutputParameters - $x"))
createParameters = createBatchParameters(batchParameters, customLabels)
drsLocalizationManifestCloudPath = jobPaths.callExecutionRoot / GcpBatchJobPaths.DrsLocalizationManifestName
_ <- uploadDrsLocalizationManifest(createParameters, drsLocalizationManifestCloudPath)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
package cromwell.backend.google.batch.api

import com.google.cloud.batch.v1.AllocationPolicy.Accelerator
import com.google.cloud.batch.v1.{DeleteJobRequest, GetJobRequest, JobName}
import cromwell.backend.google.batch.models.GcpBatchConfigurationAttributes.GcsTransferConfiguration
import cromwell.backend.google.batch.models.GcpBatchRequest
import cromwell.backend.google.batch.runnable._
import cromwell.backend.google.batch.util.BatchUtilityConversions
import com.google.cloud.batch.v1.AllocationPolicy.{AttachedDisk, InstancePolicy, InstancePolicyOrTemplate, LocationPolicy, NetworkInterface, NetworkPolicy, ProvisioningModel}
import com.google.cloud.batch.v1.AllocationPolicy._
import com.google.cloud.batch.v1.LogsPolicy.Destination
import com.google.cloud.batch.v1.{AllocationPolicy, ComputeResource, CreateJobRequest, Job, LogsPolicy, Runnable, ServiceAccount, TaskGroup, TaskSpec, Volume}
import com.google.cloud.batch.v1.{AllocationPolicy, ComputeResource, CreateJobRequest, DeleteJobRequest, GetJobRequest, Job, JobName, LogsPolicy, Runnable, ServiceAccount, TaskGroup, TaskSpec, Volume}
import com.google.protobuf.Duration
import cromwell.backend.google.batch.io.GcpBatchAttachedDisk
import cromwell.backend.google.batch.models.VpcAndSubnetworkProjectLabelValues
import cromwell.backend.google.batch.models.GcpBatchConfigurationAttributes.GcsTransferConfiguration
import cromwell.backend.google.batch.models.{GcpBatchRequest, VpcAndSubnetworkProjectLabelValues}
import cromwell.backend.google.batch.runnable._
import cromwell.backend.google.batch.util.BatchUtilityConversions

import scala.jdk.CollectionConverters._

Expand Down Expand Up @@ -61,10 +58,11 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe
.build
}

private def createInstancePolicy(cpuPlatform: String, spotModel: ProvisioningModel, accelerators: Option[Accelerator.Builder], attachedDisks: List[AttachedDisk]) = {
private def createInstancePolicy(cpuPlatform: String, spotModel: ProvisioningModel, accelerators: Option[Accelerator.Builder], attachedDisks: List[AttachedDisk]): InstancePolicy.Builder = {

//set GPU count to 0 if not included in workflow
val gpuAccelerators = accelerators.getOrElse(Accelerator.newBuilder.setCount(0).setType(""))
val gpuAccelerators = accelerators.getOrElse(Accelerator.newBuilder.setCount(0).setType("")) // TODO: Driver version

val instancePolicy = InstancePolicy
.newBuilder
.setProvisioningModel(spotModel)
Expand All @@ -83,7 +81,6 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe

}


private def createNetworkPolicy(networkInterface: NetworkInterface): NetworkPolicy = {
NetworkPolicy
.newBuilder
Expand Down Expand Up @@ -113,22 +110,29 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe

}

private def createAllocationPolicy(data: GcpBatchRequest, locationPolicy: LocationPolicy, instancePolicy: InstancePolicy, networkPolicy: NetworkPolicy, serviceAccount: ServiceAccount) = {
AllocationPolicy
private def createAllocationPolicy(data: GcpBatchRequest, locationPolicy: LocationPolicy, instancePolicy: InstancePolicy, networkPolicy: NetworkPolicy, serviceAccount: ServiceAccount, accelerators: Option[Accelerator.Builder]) = {

val allocationPolicy = AllocationPolicy
.newBuilder
.setLocation(locationPolicy)
.setNetwork(networkPolicy)
.putLabels("cromwell-workflow-id", toLabel(data.workflowId.toString)) //label for workflow from WDL
.putLabels("goog-batch-worker", "true")
.putAllLabels((data.createParameters.googleLabels.map(label => label.key -> label.value).toMap.asJava))
.setServiceAccount(serviceAccount)
.addInstances(InstancePolicyOrTemplate
.newBuilder
.setPolicy(instancePolicy)
.build)
.build
.buildPartial()

val gpuAccelerators = accelerators.getOrElse(Accelerator.newBuilder.setCount(0).setType(""))

//add GPUs if GPU count is greater than or equal to 1
if (gpuAccelerators.getCount >= 1) {
allocationPolicy.toBuilder.addInstances(InstancePolicyOrTemplate.newBuilder.setPolicy(instancePolicy).setInstallGpuDrivers(true).build)
} else {
allocationPolicy.toBuilder.addInstances(InstancePolicyOrTemplate.newBuilder.setPolicy(instancePolicy).build)
}
}


override def submitRequest(data: GcpBatchRequest): CreateJobRequest = {

val batchAttributes = data.gcpBatchParameters.batchAttributes
Expand Down Expand Up @@ -160,10 +164,6 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe
// Batch defaults to 1 task
val taskCount: Long = 1

println(f"command script container path ${data.createParameters.commandScriptContainerPath}")
println(f"cloud workflow root ${data.createParameters.cloudWorkflowRoot}")
println(f"all parameters:\n ${data.createParameters.allParameters.mkString("\n")}")

// parse preemption value and set value for Spot. Spot is replacement for preemptible
val spotModel = toProvisioningModel(runtimeAttributes.preemptible)

Expand Down Expand Up @@ -205,11 +205,11 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe
val taskGroup: TaskGroup = createTaskGroup(taskCount, taskSpec)
val instancePolicy = createInstancePolicy(cpuPlatform, spotModel, accelerators, allDisks)
val locationPolicy = LocationPolicy.newBuilder.addAllowedLocations(zones).build
val allocationPolicy = createAllocationPolicy(data, locationPolicy, instancePolicy.build, networkPolicy, gcpSa)
val allocationPolicy = createAllocationPolicy(data, locationPolicy, instancePolicy.build, networkPolicy, gcpSa, accelerators)
val job = Job
.newBuilder
.addTaskGroups(taskGroup)
.setAllocationPolicy(allocationPolicy)
.setAllocationPolicy(allocationPolicy.build())
.putLabels("submitter", "cromwell") // label to signify job submitted by cromwell for larger tracking purposes within GCP batch
.putLabels("goog-batch-worker", "true")
.putAllLabels((data.createParameters.googleLabels.map(label => label.key -> label.value).toMap.asJava))
Expand All @@ -218,9 +218,6 @@ class GcpBatchRequestFactoryImpl()(implicit gcsTransferConfiguration: GcsTransfe
.setDestination(Destination.CLOUD_LOGGING)
.build)

println(f"job shell ${data.createParameters.jobShell}")
println(f"script container path ${data.createParameters.commandScriptContainerPath}")
println(f"labels ${data.createParameters.googleLabels}")

CreateJobRequest
.newBuilder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import wom.values.{WomArray, WomBoolean, WomInteger, WomString, WomValue}

object GpuResource {

val DefaultNvidiaDriverVersion = "418.87.00"

final case class GpuType(name: String) {
override def toString: String = name
}
Expand Down Expand Up @@ -99,6 +101,9 @@ object GcpBatchRuntimeAttributes {
private def cpuPlatformValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[String] = cpuPlatformValidationInstance
private def gpuTypeValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[GpuType] = GpuTypeValidation.optional

val GpuDriverVersionKey = "nvidiaDriverVersion"
private def gpuDriverValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[String] = new StringRuntimeAttributesValidation(GpuDriverVersionKey).optional

private def gpuCountValidation(runtimeConfig: Option[Config]): OptionalRuntimeAttributesValidation[Int Refined Positive] = GpuValidation.optional
private def gpuMinValidation(runtimeConfig: Option[Config]):OptionalRuntimeAttributesValidation[Int Refined Positive] = GpuValidation.optionalMin

Expand Down Expand Up @@ -159,6 +164,7 @@ object GcpBatchRuntimeAttributes {
StandardValidatedRuntimeAttributesBuilder.default(runtimeConfig).withValidation(
gpuCountValidation(runtimeConfig),
gpuTypeValidation(runtimeConfig),
gpuDriverValidation(runtimeConfig),
cpuValidation(runtimeConfig),
cpuPlatformValidation(runtimeConfig),
cpuMinValidation(runtimeConfig),
Expand Down Expand Up @@ -189,8 +195,9 @@ object GcpBatchRuntimeAttributes {
.extractOption(gpuTypeValidation(runtimeAttrsConfig).key, validatedRuntimeAttributes)
lazy val gpuCount: Option[Int Refined Positive] = RuntimeAttributesValidation
.extractOption(gpuCountValidation(runtimeAttrsConfig).key, validatedRuntimeAttributes)
lazy val gpuDriver: Option[String] = RuntimeAttributesValidation.extractOption(gpuDriverValidation(runtimeAttrsConfig).key, validatedRuntimeAttributes)

val gpuResource: Option[GpuResource] = if (gpuType.isDefined || gpuCount.isDefined) {
val gpuResource: Option[GpuResource] = if (gpuType.isDefined || gpuCount.isDefined || gpuDriver.isDefined) {
Option(GpuResource(gpuType.getOrElse(GpuType.DefaultGpuType), gpuCount
.getOrElse(GpuType.DefaultGpuCount)))
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ trait UserRunnable {

def userRunnables(createParameters: CreateBatchJobParameters, volumes: List[Volume]): List[Runnable] = {

println(f"job shell ${createParameters.jobShell}")
println(f"script container path ${createParameters.commandScriptContainerPath}")

val userRunnable = RunnableBuilder.userRunnable(
docker = createParameters.dockerImage,
scriptContainerPath = createParameters.commandScriptContainerPath.pathAsString,
Expand Down