Skip to content

Commit

Permalink
Handle invalid cluster recommendation for Dataproc
Browse files Browse the repository at this point in the history
Signed-off-by: Partho Sarthi <psarthi@nvidia.com>
  • Loading branch information
parthosa committed Feb 7, 2025
1 parent 2b9473b commit 9ec73e2
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 12 deletions.
35 changes: 32 additions & 3 deletions core/src/main/scala/com/nvidia/spark/rapids/tool/Platform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -485,8 +485,7 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
}

val dynamicAllocSettings = Platform.getDynamicAllocationSettings(sparkProperties)
recommendedNodeInstanceInfo = Some(recommendedNodeInstance)
recommendedClusterInfo = Some(RecommendedClusterInfo(
val recommendedCluster = RecommendedClusterInfo(
vendor = vendor,
coresPerExecutor = clusterConfig.coresPerExec,
numWorkerNodes = numWorkerNodes,
Expand All @@ -498,7 +497,15 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
dynamicAllocationMinExecutors = dynamicAllocSettings.min,
dynamicAllocationInitialExecutors = dynamicAllocSettings.initial,
workerNodeType = Some(recommendedNodeInstance.name)
))
)

validateRecommendedCluster(recommendedCluster) match {
case Right(validCluster) =>
recommendedNodeInstanceInfo = Some(recommendedNodeInstance)
recommendedClusterInfo = Some(validCluster)
case Left(reason) =>
logWarning(s"Failed to generate a cluster recommendation. Reason: $reason")
}

case None =>
logWarning("Failed to generate a cluster recommendation. " +
Expand All @@ -513,6 +520,17 @@ abstract class Platform(var gpuDevice: Option[GpuDevice],
"Cluster properties are missing and event log does not contain cluster information.")
}
}

/**
* Validates the recommended cluster configuration. This can be overridden by
* subclasses to provide platform-specific validation.
* @param recommendedClusterInfo Recommended cluster configuration
* @return Either a failure message or the valid recommended cluster configuration
*/
protected def validateRecommendedCluster(
recommendedClusterInfo: RecommendedClusterInfo): Either[String, RecommendedClusterInfo] = {
Right(recommendedClusterInfo)
}
}

abstract class DatabricksPlatform(gpuDevice: Option[GpuDevice],
Expand Down Expand Up @@ -592,10 +610,21 @@ class DataprocPlatform(gpuDevice: Option[GpuDevice],
override val defaultGpuDevice: GpuDevice = T4Gpu
override def isPlatformCSP: Boolean = true
override def maxGpusSupported: Int = 4
private val minWorkerNodes = 2

override def getInstanceByResourcesMap: Map[(Int, Int), InstanceInfo] = {
PlatformInstanceTypes.DATAPROC_BY_GPUS_CORES
}

override def validateRecommendedCluster(
recommendedClusterInfo: RecommendedClusterInfo): Either[String, RecommendedClusterInfo] = {
if (recommendedClusterInfo.numWorkerNodes < minWorkerNodes) {
Left(s"Requested number of worker nodes (${recommendedClusterInfo.numWorkerNodes}) " +
s"is less than the minimum required ($minWorkerNodes) by the platform.")
} else {
Right(recommendedClusterInfo)
}
}
}

class DataprocServerlessPlatform(gpuDevice: Option[GpuDevice],
Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -1619,7 +1619,12 @@ class QualificationSuite extends BaseTestSuite {
expectedClusterInfoMap.foreach { case (eventlogPath, expectedClusterInfo) =>
test(s"test cluster information JSON - $eventlogPath") {
val logFile = s"$logDir/cluster_information/$eventlogPath"
runQualificationAndTestClusterInfo(logFile, PlatformNames.DEFAULT, expectedClusterInfo)
val actualClusterInfo =
runQualificationAndGetClusterSummary(logFile, PlatformNames.DEFAULT)
.flatMap(_.clusterInfo)
assert(actualClusterInfo == expectedClusterInfo,
s"Actual cluster info does not match the expected cluster info. " +
s"Expected: $expectedClusterInfo, Actual: $actualClusterInfo")
}
}

Expand Down Expand Up @@ -1688,16 +1693,22 @@ class QualificationSuite extends BaseTestSuite {

expectedPlatformClusterInfoMap.foreach { case (platform, expectedClusterInfo) =>
test(s"test cluster information JSON for platform - $platform ") {
val logFile = s"$logDir/cluster_information/platform/$platform"
runQualificationAndTestClusterInfo(logFile, platform, Some(expectedClusterInfo))
val logFile = s"$logDir/cluster_information/platform/valid/$platform"
val actualClusterInfo =
runQualificationAndGetClusterSummary(logFile, platform)
.flatMap(_.clusterInfo)
assert(actualClusterInfo.contains(expectedClusterInfo),
s"Actual cluster info does not match the expected cluster info. " +
s"Expected: $expectedClusterInfo, Actual: $actualClusterInfo")
}
}

/**
* Runs the qualification tool and verifies cluster information against expected values.
* Runs the qualification tool and returns the cluster summary.
*/
private def runQualificationAndTestClusterInfo(eventlogPath: String, platform: String,
expectedClusterInfo: Option[ExistingClusterInfo]): Unit = {
private def runQualificationAndGetClusterSummary(
eventlogPath: String, platform: String): Option[ClusterSummary] = {
var clusterSummary: Option[ClusterSummary] = None
TrampolineUtil.withTempDir { outPath =>
val baseArgs = Array("--output-directory", outPath.getAbsolutePath, "--platform", platform)
val appArgs = new QualificationArgs(baseArgs :+ eventlogPath)
Expand All @@ -1714,10 +1725,9 @@ class QualificationSuite extends BaseTestSuite {
// Read output JSON and create a set of (event log, cluster info)
val outputResultFile = s"$outPath/${QualOutputWriter.LOGFILE_NAME}/" +
s"${QualOutputWriter.LOGFILE_NAME}_cluster_information.json"
val actualClusterInfo = readJson(outputResultFile).headOption.flatMap(_.clusterInfo)
assert(actualClusterInfo == expectedClusterInfo,
"Actual cluster info does not match the expected cluster info.")
clusterSummary = readJson(outputResultFile).headOption
}
clusterSummary
}

test("test cluster information generation is disabled") {
Expand All @@ -1740,6 +1750,18 @@ class QualificationSuite extends BaseTestSuite {
}
}

// TODO: This should be extended for validating the recommended cluster information
// for other platforms.
test(s"test invalid recommended cluster information JSON for platform - dataproc") {
val logFile = s"$logDir/cluster_information/platform/invalid/dataproc.zstd"
val actualRecommendedClusterInfo =
runQualificationAndGetClusterSummary(logFile, PlatformNames.DATAPROC)
.flatMap(_.recommendedClusterInfo)
assert(actualRecommendedClusterInfo.isEmpty,
"Recommended cluster info is expected to be empty. " +
s"Actual: $actualRecommendedClusterInfo")
}

test("test status report generation for wildcard event log") {
val logFiles = Array(
s"$logDir/cluster_information/eventlog_3node*") // correct wildcard event log with 3 matches
Expand Down

0 comments on commit 9ec73e2

Please sign in to comment.