Skip to content

Commit

Permalink
fix: Update memory requirement checks using resource.quantity (#539)
Browse files Browse the repository at this point in the history
**Reason for Change**:
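Fix the units used in the GPU requirement checks: SKU GPU memory (specified in GiB) is now converted to bytes and compared against the model's requirements as `resource.Quantity` values, instead of comparing raw integers against values scaled with `resource.Giga`.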
  • Loading branch information
ishaansehgal99 authored Jul 29, 2024
1 parent 5413920 commit b6f13ed
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 11 deletions.
50 changes: 40 additions & 10 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package v1alpha1
import (
"context"
"fmt"
"github.com/azure/kaito/pkg/utils/consts"
"reflect"
"regexp"
"sort"
Expand Down Expand Up @@ -318,25 +319,54 @@ func (r *ResourceSpec) validateCreateWithInference(inference *InferenceSpec) (er
if skuConfig, exists := SupportedGPUConfigs[instanceType]; exists {
if presetName != "" {
model := plugin.KaitoModelRegister.MustGet(presetName) // InferenceSpec has been validated so the name is valid.
// Validate GPU count for given SKU

machineCount := *r.Count
totalNumGPUs := machineCount * skuConfig.GPUCount
totalGPUMem := machineCount * skuConfig.GPUMem * skuConfig.GPUCount
machineTotalNumGPUs := resource.NewQuantity(int64(machineCount*skuConfig.GPUCount), resource.DecimalSI)
machinePerGPUMemory := resource.NewQuantity(int64(skuConfig.GPUMem/skuConfig.GPUCount)*consts.GiBToBytes, resource.BinarySI) // Ensure it's per GPU
machineTotalGPUMem := resource.NewQuantity(int64(machineCount*skuConfig.GPUMem)*consts.GiBToBytes, resource.BinarySI) // Total GPU memory

modelGPUCount := resource.MustParse(model.GetInferenceParameters().GPUCountRequirement)
modelPerGPUMemory := resource.MustParse(model.GetInferenceParameters().PerGPUMemoryRequirement)
modelTotalGPUMemory := resource.MustParse(model.GetInferenceParameters().TotalGPUMemoryRequirement)

// Separate the checks for specific error messages
if int64(totalNumGPUs) < modelGPUCount.Value() {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient number of GPUs: Instance type %s provides %d, but preset %s requires at least %d", instanceType, totalNumGPUs, presetName, modelGPUCount.Value()), "instanceType"))
if machineTotalNumGPUs.Cmp(modelGPUCount) < 0 {
errs = errs.Also(apis.ErrInvalidValue(
fmt.Sprintf(
"Insufficient number of GPUs: Instance type %s provides %s, but preset %s requires at least %d",
instanceType,
machineTotalNumGPUs.String(),
presetName,
modelGPUCount.Value(),
),
"instanceType",
))
}
skuPerGPUMemory := skuConfig.GPUMem / skuConfig.GPUCount
if int64(skuPerGPUMemory) < modelPerGPUMemory.ScaledValue(resource.Giga) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient per GPU memory: Instance type %s provides %d per GPU, but preset %s requires at least %d per GPU", instanceType, skuPerGPUMemory, presetName, modelPerGPUMemory.ScaledValue(resource.Giga)), "instanceType"))

if machinePerGPUMemory.Cmp(modelPerGPUMemory) < 0 {
errs = errs.Also(apis.ErrInvalidValue(
fmt.Sprintf(
"Insufficient per GPU memory: Instance type %s provides %s per GPU, but preset %s requires at least %s per GPU",
instanceType,
machinePerGPUMemory.String(),
presetName,
modelPerGPUMemory.String(),
),
"instanceType",
))
}
if int64(totalGPUMem) < modelTotalGPUMemory.ScaledValue(resource.Giga) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Insufficient total GPU memory: Instance type %s has a total of %d, but preset %s requires at least %d", instanceType, totalGPUMem, presetName, modelTotalGPUMemory.ScaledValue(resource.Giga)), "instanceType"))

if machineTotalGPUMem.Cmp(modelTotalGPUMemory) < 0 {
errs = errs.Also(apis.ErrInvalidValue(
fmt.Sprintf(
"Insufficient total GPU memory: Instance type %s has a total of %s, but preset %s requires at least %s",
instanceType,
machineTotalGPUMem.String(),
presetName,
modelTotalGPUMemory.String(),
),
"instanceType",
))
}
}
} else {
Expand Down
16 changes: 15 additions & 1 deletion api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
validateTuning bool // To indicate if we are testing tuning validation
}{
{
name: "Valid resource",
name: "Valid Resource",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_ND96asr_v4",
Count: pointerToInt(1),
Expand All @@ -217,6 +217,20 @@ func TestResourceSpecValidateCreate(t *testing.T) {
expectErrs: false,
validateTuning: false,
},
{
name: "Valid Resource - SKU Capacity == Model Requirement",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC12s_v3",
Count: pointerToInt(1),
},
modelGPUCount: "1",
modelPerGPUMemory: "16Gi",
modelTotalGPUMemory: "16Gi",
preset: true,
errContent: "",
expectErrs: false,
validateTuning: false,
},
{
name: "Insufficient total GPU memory",
resourceSpec: &ResourceSpec{
Expand Down
1 change: 1 addition & 0 deletions pkg/utils/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ const (
GPUString = "gpu"
SKUString = "sku"
MaxRevisionHistoryLimit = 10
GiBToBytes = 1024 * 1024 * 1024 // Conversion factor from GiB to bytes
)

0 comments on commit b6f13ed

Please sign in to comment.