Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into aen_wx_1878
Browse files Browse the repository at this point in the history
  • Loading branch information
aednichols committed Oct 16, 2024
2 parents 22f6851 + 358a156 commit eafe5cf
Show file tree
Hide file tree
Showing 13 changed files with 642 additions and 172 deletions.
4 changes: 2 additions & 2 deletions CODEOWNERS
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# These owners will be the default owners for everything in
# the repo. Unless a later match takes precedence,
# @broadinstitute/dsp-batch will be requested for
# @broadinstitute/dsp-analysis will be requested for
# review when someone opens a pull request.
* @broadinstitute/dsp-batch
* @broadinstitute/dsp-analysis
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
package cromwell.backend.standard.pollmonitoring

import akka.actor.{Actor, ActorRef}
import cromwell.backend.{BackendJobDescriptor, BackendWorkflowDescriptor, Platform}
import cromwell.backend.validation.{
Expand All @@ -9,6 +10,7 @@ import cromwell.backend.validation.{
ValidatedRuntimeAttributes
}
import cromwell.core.logging.JobLogger
import cromwell.services.cost.InstantiatedVmInfo
import cromwell.services.metadata.CallMetadataKeys
import cromwell.services.metrics.bard.BardEventing.BardEventRequest
import cromwell.services.metrics.bard.model.TaskSummaryEvent
Expand All @@ -26,7 +28,7 @@ case class PollMonitorParameters(
jobDescriptor: BackendJobDescriptor,
validatedRuntimeAttributes: ValidatedRuntimeAttributes,
platform: Option[Platform],
logger: Option[JobLogger]
logger: JobLogger
)

/**
Expand All @@ -42,6 +44,9 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
// Time that the user VM started spending money.
def extractStartTimeFromRunState(pollStatus: PollResultType): Option[OffsetDateTime]

// Used to kick off a cost calculation
def extractVmInfoFromRunState(pollStatus: PollResultType): Option[InstantiatedVmInfo]

// Time that the user VM stopped spending money.
def extractEndTimeFromRunState(pollStatus: PollResultType): Option[OffsetDateTime]

Expand Down Expand Up @@ -99,6 +104,7 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
Option.empty
private var vmStartTime: Option[OffsetDateTime] = Option.empty
private var vmEndTime: Option[OffsetDateTime] = Option.empty
protected var vmCostPerHour: Option[BigDecimal] = Option.empty

def processPollResult(pollStatus: PollResultType): Unit = {
// Make sure jobStartTime remains the earliest event time ever seen
Expand All @@ -122,8 +128,16 @@ trait PollResultMonitorActor[PollResultType] extends Actor {
tellMetadata(Map(CallMetadataKeys.VmEndTime -> end))
}
}
// If we don't yet have a cost per hour and we can extract VM info, send a cost request to the catalog service.
// We expect it to reply with an answer, which is handled in receive.
// NB: Due to the nature of async code, we may send a few cost requests before we get a response back.
if (vmCostPerHour.isEmpty) {
extractVmInfoFromRunState(pollStatus).foreach(handleVmCostLookup)
}
}

def handleVmCostLookup(vmInfo: InstantiatedVmInfo): Unit

// When a job finishes, the bard actor needs to know about the timing in order to record metrics.
// Cost related metadata should already have been handled in processPollResult.
def handleAsyncJobFinish(terminalStateName: String): Unit =
Expand Down
4 changes: 3 additions & 1 deletion core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,11 @@ services {
}
}

CostCatalogService {
// When enabled, Cromwell will store vmCostPerHour metadata for GCP tasks
GcpCostCatalogService {
class = "cromwell.services.cost.GcpCostCatalogService"
config {
enabled = false
catalogExpirySeconds = 86400
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package cromwell.services.cost
import com.typesafe.config.Config
import net.ceedubs.ficus.Ficus._

final case class CostCatalogConfig(catalogExpirySeconds: Int)
final case class CostCatalogConfig(enabled: Boolean, catalogExpirySeconds: Int)

object CostCatalogConfig {
def apply(config: Config): CostCatalogConfig = CostCatalogConfig(config.as[Int]("catalogExpirySeconds"))
def apply(config: Config): CostCatalogConfig =
CostCatalogConfig(config.as[Boolean]("enabled"), config.as[Int]("catalogExpirySeconds"))
}
Original file line number Diff line number Diff line change
@@ -1,31 +1,121 @@
package cromwell.services.cost

import akka.actor.{Actor, ActorRef}
import cats.implicits.catsSyntaxValidatedId
import com.google.`type`.Money
import com.google.cloud.billing.v1._
import com.typesafe.config.Config
import com.typesafe.scalalogging.LazyLogging
import common.util.StringUtil.EnhancedToStringable
import common.validation.ErrorOr._
import common.validation.ErrorOr.ErrorOr
import cromwell.services.ServiceRegistryActor.ServiceRegistryMessage
import cromwell.services.cost.GcpCostCatalogService.{COMPUTE_ENGINE_SERVICE_NAME, DEFAULT_CURRENCY_CODE}
import cromwell.util.GracefulShutdownHelper.ShutdownCommand

import java.time.{Duration, Instant}
import scala.jdk.CollectionConverters.IterableHasAsScala
import java.time.temporal.ChronoUnit.SECONDS
import scala.util.Using

case class CostCatalogKey(machineType: Option[MachineType],
usageType: Option[UsageType],
machineCustomization: Option[MachineCustomization],
resourceGroup: Option[ResourceGroup]
case class CostCatalogKey(machineType: MachineType,
usageType: UsageType,
machineCustomization: MachineCustomization,
resourceType: ResourceType,
region: String
)

object CostCatalogKey {

  // Only SKUs we know how to use are supported. This is brittle: the structured fields on a SKU do not carry
  // enough information by themselves, so we have to lean on the human-readable description.
  //
  // N1: We usually run custom machines, but SKUs only exist for predefined ones; we fall back to those SKUs.
  // N2 and N2D: We only use custom machines.

  // Description filter selecting just the SKUs we are interested in.
  // NB: Update this when new machine types are added or the cost catalog descriptions change.
  final val expectedSku =
    (".*?N1 Predefined Instance (Core|Ram) .*|" +
      ".*?N2 Custom Instance (Core|Ram) .*|" +
      ".*?N2D AMD Custom Instance (Core|Ram) .*").r

  /**
    * Builds every catalog key a SKU can be stored under: one key per service region listed on the SKU.
    * Yields an empty list when the description does not match `expectedSku`, or when any of the
    * machine type / resource type / usage type / customization cannot be derived from the SKU.
    */
  def apply(sku: Sku): List[CostCatalogKey] = {
    val descriptionIsSupported = expectedSku.findFirstIn(sku.getDescription).nonEmpty
    if (!descriptionIsSupported) Nil
    else
      for {
        machine <- MachineType.fromSku(sku).toList
        resource <- ResourceType.fromSku(sku).toList
        usage <- UsageType.fromSku(sku).toList
        customization <- MachineCustomization.fromSku(sku).toList
        serviceRegion <- sku.getServiceRegionsList.asScala.toList
      } yield CostCatalogKey(machine, usage, customization, resource, serviceRegion)
  }

  /**
    * Builds the key used to look up the SKU pricing one resource (`resourceType`) of a running VM.
    * Invalid when the VM's machine type string cannot be mapped to a known machine type.
    */
  def apply(instantiatedVmInfo: InstantiatedVmInfo, resourceType: ResourceType): ErrorOr[CostCatalogKey] = {
    val parsedMachineType = MachineType.fromGoogleMachineTypeString(instantiatedVmInfo.machineType)
    parsedMachineType.map { knownMachineType =>
      val usage = UsageType.fromBoolean(instantiatedVmInfo.preemptible)
      val customization = MachineCustomization.fromMachineTypeString(instantiatedVmInfo.machineType)
      CostCatalogKey(knownMachineType, usage, customization, resourceType, instantiatedVmInfo.region)
    }
  }
}

/**
 * Request for the hourly cost of the VM described by `vmInfo`.
 * When the cost catalog service is enabled it answers by sending a `GcpCostLookupResponse` to `replyTo`;
 * when it is disabled the request is dropped without a reply.
 */
case class GcpCostLookupRequest(vmInfo: InstantiatedVmInfo, replyTo: ActorRef) extends ServiceRegistryMessage {
// Same name as the service's entry under `services` in reference.conf.
// NOTE(review): presumably what the service registry uses to route this message - confirm.
override def serviceName: String = "GcpCostCatalogService"
}
/** Reply to a `GcpCostLookupRequest`: echoes the VM info alongside the calculated cost per hour, or the errors that prevented calculating it. */
case class GcpCostLookupResponse(vmInfo: InstantiatedVmInfo, calculatedCost: ErrorOr[BigDecimal])
/** Value stored in the cost catalog map; currently wraps the complete `Sku` object returned by Google. */
case class CostCatalogValue(catalogObject: Sku)
/** A processed cost catalog together with the instant it was fetched, so that its age can be checked against the configured lifetime. */
case class ExpiringGcpCostCatalog(catalog: Map[CostCatalogKey, CostCatalogValue], fetchTime: Instant)
object ExpiringGcpCostCatalog {
// Placeholder catalog with no entries. Its fetch time is `Instant.MIN`, so its computed age exceeds any
// configured lifetime and the first lookup triggers a real fetch.
def empty: ExpiringGcpCostCatalog = ExpiringGcpCostCatalog(Map.empty, Instant.MIN)
}

object GcpCostCatalogService {
  // Can be gleaned by using googleClient.listServices
  private val COMPUTE_ENGINE_SERVICE_NAME = "services/6F81-5844-456A"

  // ISO 4217 https://developers.google.com/adsense/management/appendix/currencies
  private val DEFAULT_CURRENCY_CODE = "USD"

  /**
    * Returns the last entry in the SKU's pricing-info list, which this service treats as the most recent one.
    * NOTE(review): the underlying indexed getter is expected to throw when the SKU carries no pricing info;
    * the price calculations below check for that case before calling this.
    */
  def getMostRecentPricingInfo(sku: Sku): PricingInfo = {
    val mostRecentPricingInfoIndex = sku.getPricingInfoCount - 1
    sku.getPricingInfo(mostRecentPricingInfoIndex)
  }

  /**
    * Converts a `Money` into an exact decimal amount.
    * `nanos` counts billionths (10^-9) of one currency unit, so it is applied with a decimal scale of 9.
    * Staying in BigDecimal avoids both the former `10e-9` (= 1e-8) scaling mistake and Double rounding.
    */
  private def moneyToBigDecimal(money: Money): BigDecimal =
    BigDecimal(money.getUnits) + BigDecimal(money.getNanos.toLong, 9)

  /**
    * Extracts the price of one usage unit from the most recent pricing info of `sku`.
    *
    * NB: Ignoring "TieredRates" here (the idea that stuff gets cheaper the more you use); only the first tier
    * is read. Technically we should determine which tier(s) apply, but in practice CPU cores and RAM have
    * not been seen with more than a single tier.
    *
    * @param sku               SKU to price
    * @param expectedUsageUnit usage unit the SKU must be priced in (e.g. "h")
    * @param resourceLabel     name of the resource used in the unit-mismatch error message
    * @return the unit price, or an error if the SKU has no pricing info, uses an unexpected usage unit,
    *         or has no tiered rates
    */
  private def unitPricePerHour(sku: Sku, expectedUsageUnit: String, resourceLabel: String): ErrorOr[BigDecimal] =
    if (sku.getPricingInfoCount == 0) {
      s"SKU '${sku.getDescription}' has no pricing info".invalidNel
    } else {
      val pricingExpression = getMostRecentPricingInfo(sku).getPricingExpression
      val usageUnit = pricingExpression.getUsageUnit
      if (usageUnit != expectedUsageUnit) {
        s"Expected usage units of ${resourceLabel} to be '${expectedUsageUnit}'. Got ${usageUnit}".invalidNel
      } else if (pricingExpression.getTieredRatesCount == 0) {
        s"SKU '${sku.getDescription}' has no tiered rates".invalidNel
      } else {
        moneyToBigDecimal(pricingExpression.getTieredRates(0).getUnitPrice).validNel
      }
    }

  /**
    * Hourly price of `coreCount` cores priced by `cpuSku`.
    * See: https://cloud.google.com/billing/v1/how-tos/catalog-api
    *
    * @return price per hour, or an error if the SKU is not priced per hour ('h') or carries no usable price
    */
  def calculateCpuPricePerHour(cpuSku: Sku, coreCount: Int): ErrorOr[BigDecimal] =
    unitPricePerHour(cpuSku, "h", "CPUs").map(pricePerCore => pricePerCore * coreCount)

  /**
    * Hourly price of `ramGbCount` GiB of memory priced by `ramSku`.
    *
    * @return price per hour, or an error if the SKU is not priced per GiB-hour ('GiBy.h') or carries no usable price
    */
  def calculateRamPricePerHour(ramSku: Sku, ramGbCount: Double): ErrorOr[BigDecimal] =
    unitPricePerHour(ramSku, "GiBy.h", "RAM").map(pricePerGb => pricePerGb * BigDecimal(ramGbCount))
}

/**
Expand All @@ -36,37 +126,40 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service
extends Actor
with LazyLogging {

private val maxCatalogLifetime: Duration =
Duration.of(CostCatalogConfig(serviceConfig).catalogExpirySeconds.longValue, SECONDS)
private val costCatalogConfig = CostCatalogConfig(serviceConfig)

private var googleClient: Option[CloudCatalogClient] = Option.empty
private val maxCatalogLifetime: Duration =
Duration.of(costCatalogConfig.catalogExpirySeconds.longValue, SECONDS)

// Cached catalog. Refreshed lazily when older than maxCatalogLifetime.
private var costCatalog: Option[ExpiringGcpCostCatalog] = Option.empty
private var costCatalog: ExpiringGcpCostCatalog = ExpiringGcpCostCatalog.empty

/**
* Returns the SKU for a given key, if it exists
*/
def getSku(key: CostCatalogKey): Option[CostCatalogValue] = getOrFetchCachedCatalog().get(key)

protected def fetchNewCatalog: Iterable[Sku] = {
if (googleClient.isEmpty) {
// We use option rather than lazy here so that the client isn't created when it is told to shutdown (see receive override)
googleClient = Some(CloudCatalogClient.create)
protected def fetchSkuIterable(googleClient: CloudCatalogClient): Iterable[Sku] =
makeInitialWebRequest(googleClient).iterateAll().asScala

protected def makeCatalog(skus: Iterable[Sku]): ExpiringGcpCostCatalog =
ExpiringGcpCostCatalog(processCostCatalog(skus), Instant.now())

/**
  * Downloads the SKU list with a short-lived Google client and turns it into a timestamped catalog.
  * The client is closed by `Using.resource` once the catalog has been built. SKUs are obtained through
  * `fetchSkuIterable` rather than by repeating its body, so that hook stays the single place defining
  * how SKUs are read from the client.
  */
protected def fetchNewCatalog: ExpiringGcpCostCatalog =
  Using.resource(CloudCatalogClient.create) { googleClient =>
    makeCatalog(fetchSkuIterable(googleClient))
  }
makeInitialWebRequest(googleClient.get).iterateAll().asScala
}

def getCatalogAge: Duration =
Duration.between(costCatalog.map(c => c.fetchTime).getOrElse(Instant.ofEpochMilli(0)), Instant.now())
private def isCurrentCatalogExpired: Boolean = getCatalogAge.toNanos > maxCatalogLifetime.toNanos
def getCatalogAge: Duration = Duration.between(costCatalog.fetchTime, Instant.now())

private def isCurrentCatalogExpired: Boolean = getCatalogAge.toSeconds > maxCatalogLifetime.toSeconds

private def getOrFetchCachedCatalog(): Map[CostCatalogKey, CostCatalogValue] = {
if (costCatalog.isEmpty || isCurrentCatalogExpired) {
if (isCurrentCatalogExpired) {
logger.info("Fetching a new GCP public cost catalog.")
costCatalog = Some(ExpiringGcpCostCatalog(processCostCatalog(fetchNewCatalog), Instant.now()))
costCatalog = fetchNewCatalog
}
costCatalog.map(expiringCatalog => expiringCatalog.catalog).getOrElse(Map.empty)
costCatalog.catalog
}

/**
Expand All @@ -88,23 +181,63 @@ class GcpCostCatalogService(serviceConfig: Config, globalConfig: Config, service
* Ideally, we don't want to have an entire, unprocessed, cost catalog in memory at once since it's ~20MB.
*/
private def processCostCatalog(skus: Iterable[Sku]): Map[CostCatalogKey, CostCatalogValue] =
// TODO: Account for key collisions (same key can be in multiple regions)
// TODO: reduce memory footprint of returned map (don't store entire SKU object)
skus.foldLeft(Map.empty[CostCatalogKey, CostCatalogValue]) { case (acc, sku) =>
acc + convertSkuToKeyValuePair(sku)
val keys = CostCatalogKey(sku)

// We expect that every cost catalog key is unique, but changes to the SKUs returned by Google may
// break this assumption. Check and log an error if we find collisions.
val collisions = keys.flatMap(acc.get(_).toList).map(_.catalogObject.getDescription)
if (collisions.nonEmpty)
logger.error(
s"Found SKU key collision when adding ${sku.getDescription}, collides with ${collisions.mkString(", ")}"
)

acc ++ keys.map(k => (k, CostCatalogValue(sku)))
}

private def convertSkuToKeyValuePair(sku: Sku): (CostCatalogKey, CostCatalogValue) = CostCatalogKey(
machineType = MachineType.fromSku(sku),
usageType = UsageType.fromSku(sku),
machineCustomization = MachineCustomization.fromSku(sku),
resourceGroup = ResourceGroup.fromSku(sku)
) -> CostCatalogValue(sku)
/**
  * Finds the SKU that prices one resource (`resourceType`) of the given VM.
  * Invalid when no catalog key can be built for the VM or when the catalog has no matching entry.
  */
def lookUpSku(instantiatedVmInfo: InstantiatedVmInfo, resourceType: ResourceType): ErrorOr[Sku] =
  CostCatalogKey(instantiatedVmInfo, resourceType).flatMap { exactKey =>
    // As of Sept 2024 the cost catalog has no entries for custom N1 machines, so an N1 custom VM falls back
    // to the predefined N1 SKU. Only evaluated when the exact key is missing from the catalog.
    def predefinedN1Fallback: Option[CostCatalogValue] =
      (exactKey.machineType, exactKey.machineCustomization) match {
        case (N1, Custom) => getSku(exactKey.copy(machineCustomization = Predefined))
        case _ => None
      }

    getSku(exactKey).orElse(predefinedN1Fallback) match {
      case Some(catalogEntry) => catalogEntry.catalogObject.validNel
      case None => s"Failed to look up ${resourceType} SKU for ${instantiatedVmInfo}".invalidNel
    }
  }

// TODO consider caching this, answers won't change until we reload the SKUs
/**
 * Calculates the hourly cost of a VM as the sum of its CPU price and its RAM price.
 * Only CPU and RAM are priced here.
 *
 * Steps run in order: look up the CPU SKU, read the core count from the machine type string and price the
 * cores; then look up the RAM SKU, read the RAM size from the machine type string and price the memory.
 * A successful calculation is logged at info level with the SKU descriptions used.
 *
 * NOTE(review): the for-comprehension relies on the ErrorOr flatMap syntax imported from
 * common.validation.ErrorOr._, presumably stopping at the first invalid step - confirm.
 *
 * @param instantiatedVmInfo the VM whose machine type is used to look up and price its resources
 * @return the total cost per hour, or the error from whichever step failed
 */
def calculateVmCostPerHour(instantiatedVmInfo: InstantiatedVmInfo): ErrorOr[BigDecimal] =
for {
cpuSku <- lookUpSku(instantiatedVmInfo, Cpu)
coreCount <- MachineType.extractCoreCountFromMachineTypeString(instantiatedVmInfo.machineType)
cpuPricePerHour <- GcpCostCatalogService.calculateCpuPricePerHour(cpuSku, coreCount)
ramSku <- lookUpSku(instantiatedVmInfo, Ram)
ramMbCount <- MachineType.extractRamMbFromMachineTypeString(instantiatedVmInfo.machineType)
// Divide by a Double so fractional gigabytes are kept.
ramGbCount = ramMbCount / 1024d // need sub-integer resolution
ramPricePerHour <- GcpCostCatalogService.calculateRamPricePerHour(ramSku, ramGbCount)
totalCost = cpuPricePerHour + ramPricePerHour
_ = logger.info(
s"Calculated vmCostPerHour of ${totalCost} " +
s"(CPU ${cpuPricePerHour} for ${coreCount} cores [${cpuSku.getDescription}], " +
s"RAM ${ramPricePerHour} for ${ramGbCount} Gb [${ramSku.getDescription}]) " +
s"for ${instantiatedVmInfo}"
)
} yield totalCost

def serviceRegistryActor: ActorRef = serviceRegistry
override def receive: Receive = {
case GcpCostLookupRequest(vmInfo, replyTo) if costCatalogConfig.enabled =>
val calculatedCost = calculateVmCostPerHour(vmInfo)
val response = GcpCostLookupResponse(vmInfo, calculatedCost)
replyTo ! response
case GcpCostLookupRequest(_, _) => // do nothing if we're disabled
case ShutdownCommand =>
googleClient.foreach(client => client.shutdownNow())
context stop self
case other =>
logger.error(
Expand Down
Loading

0 comments on commit eafe5cf

Please sign in to comment.