Skip to content

Commit

Permalink
feat: Tuning Resource Validation Check (#484)
Browse files: browse the repository at this point in the history
**Reason for Change**:
Tuning Resource Validation Check

---------

Signed-off-by: Ishaan Sehgal <ishaanforthewin@gmail.com>
  • Loading branch information
ishaansehgal99 authored Jun 27, 2024
1 parent 174f6f5 commit f70dcc5
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 38 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ create-acr: ## Create test ACR
create-aks-cluster: ## Create test AKS cluster (with msi, oidc, and workload identity enabled)
az aks create --name $(AZURE_CLUSTER_NAME) --resource-group $(AZURE_RESOURCE_GROUP) --location $(AZURE_LOCATION) \
--attach-acr $(AZURE_ACR_NAME) --kubernetes-version $(AKS_K8S_VERSION) --node-count 1 --generate-ssh-keys \
--enable-managed-identity --enable-workload-identity --enable-oidc-issuer -o none
--enable-managed-identity --enable-workload-identity --enable-oidc-issuer --node-vm-size Standard_D2s_v3 -o none

.PHONY: create-aks-cluster-with-kaito
create-aks-cluster-with-kaito: ## Create test AKS cluster (with msi, oidc and kaito enabled)
Expand Down
16 changes: 12 additions & 4 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
errs = errs.Also(w.validateCreate().ViaField("spec"))
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Resource.validateCreate(*w.Inference).ViaField("resource"),
errs = errs.Also(w.Resource.validateCreateWithInference(w.Inference).ViaField("resource"),
w.Inference.validateCreate().ViaField("inference"))
}
if w.Tuning != nil {
// TODO: Add validate resource based on Tuning Spec
errs = errs.Also(w.Tuning.validateCreate(ctx, w.Namespace).ViaField("tuning"))
errs = errs.Also(w.Resource.validateCreateWithTuning(w.Tuning).ViaField("resource"),
w.Tuning.validateCreate(ctx, w.Namespace).ViaField("tuning"))
}
} else {
klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
Expand Down Expand Up @@ -299,7 +300,14 @@ func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.Field
return errs
}

func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) {
// validateCreateWithTuning validates the resource spec of a workspace that
// requests tuning. Tuning is limited to a single node, so a Count greater
// than 1 is reported as an invalid value on the "count" field.
//
// The tuning argument is not inspected yet; it is accepted so the signature
// parallels validateCreateWithInference.
//
// It returns nil when no problem is found.
func (r *ResourceSpec) validateCreateWithTuning(tuning *TuningSpec) (errs *apis.FieldError) {
	// Count is a pointer; guard the dereference so an unset value cannot
	// panic the validation path.
	// NOTE(review): presumably API defaulting populates Count before this
	// runs — confirm; a nil Count is treated here as "not multinode".
	if r.Count != nil && *r.Count > 1 {
		errs = errs.Also(apis.ErrInvalidValue("Tuning does not currently support multinode configurations. Please set the node count to 1. Future support with DeepSpeed will allow this.", "count"))
	}
	return errs
}

func (r *ResourceSpec) validateCreateWithInference(inference *InferenceSpec) (errs *apis.FieldError) {
var presetName string
if inference.Preset != nil {
presetName = strings.ToLower(string(inference.Preset.Name))
Expand All @@ -308,7 +316,7 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field

// Check if instancetype exists in our SKUs map
if skuConfig, exists := SupportedGPUConfigs[instanceType]; exists {
if inference.Preset != nil {
if presetName != "" {
model := plugin.KaitoModelRegister.MustGet(presetName) // InferenceSpec has been validated so the name is valid.
// Validate GPU count for given SKU
machineCount := *r.Count
Expand Down
111 changes: 78 additions & 33 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
preset bool
errContent string // Content expect error to include, if any
expectErrs bool
validateTuning bool // To indicate if we are testing tuning validation
}{
{
name: "Valid resource",
Expand All @@ -214,6 +215,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
preset: true,
errContent: "",
expectErrs: false,
validateTuning: false,
},
{
name: "Insufficient total GPU memory",
Expand All @@ -227,6 +229,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
preset: true,
errContent: "Insufficient total GPU memory",
expectErrs: true,
validateTuning: false,
},

{
Expand All @@ -241,6 +244,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
preset: true,
errContent: "Insufficient number of GPUs",
expectErrs: true,
validateTuning: false,
},
{
name: "Insufficient per GPU memory",
Expand All @@ -254,6 +258,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
preset: true,
errContent: "Insufficient per GPU memory",
expectErrs: true,
validateTuning: false,
},

{
Expand All @@ -262,27 +267,30 @@ func TestResourceSpecValidateCreate(t *testing.T) {
InstanceType: "Standard_invalid_sku",
Count: pointerToInt(1),
},
errContent: "Unsupported instance",
expectErrs: true,
errContent: "Unsupported instance",
expectErrs: true,
validateTuning: false,
},
{
name: "Only Template set",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NV12s_v3",
Count: pointerToInt(1),
},
preset: false,
errContent: "",
expectErrs: false,
preset: false,
errContent: "",
expectErrs: false,
validateTuning: false,
},
{
name: "N-Prefix SKU",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_Nsku",
Count: pointerToInt(1),
},
errContent: "",
expectErrs: false,
errContent: "",
expectErrs: false,
validateTuning: false,
},

{
Expand All @@ -291,44 +299,81 @@ func TestResourceSpecValidateCreate(t *testing.T) {
InstanceType: "Standard_Dsku",
Count: pointerToInt(1),
},
errContent: "",
expectErrs: false,
errContent: "",
expectErrs: false,
validateTuning: false,
},
{
name: "Tuning validation with single node",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
Count: pointerToInt(1),
},
errContent: "",
expectErrs: false,
validateTuning: true,
},
{
name: "Tuning validation with multinode",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
Count: pointerToInt(2),
},
errContent: "Tuning does not currently support multinode configurations",
expectErrs: true,
validateTuning: true,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var spec InferenceSpec
if tc.validateTuning {
tuningSpec := &TuningSpec{}
errs := tc.resourceSpec.validateCreateWithTuning(tuningSpec)
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
t.Errorf("validateCreateWithTuning() errors = %v, expectErrs %v", errs, tc.expectErrs)
}

if tc.preset {
spec = InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
},
},
if hasErrs && tc.errContent != "" {
errMsg := errs.Error()
if !strings.Contains(errMsg, tc.errContent) {
t.Errorf("validateCreateWithTuning() error message = %v, expected to contain = %v", errMsg, tc.errContent)
}
}
} else {
spec = InferenceSpec{
Template: &v1.PodTemplateSpec{}, // Assuming a non-nil TemplateSpec implies it's set
var spec InferenceSpec

if tc.preset {
spec = InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
},
},
}
} else {
spec = InferenceSpec{
Template: &v1.PodTemplateSpec{}, // Assuming a non-nil TemplateSpec implies it's set
}
}
}

gpuCountRequirement = tc.modelGPUCount
totalGPUMemoryRequirement = tc.modelTotalGPUMemory
perGPUMemoryRequirement = tc.modelPerGPUMemory
gpuCountRequirement = tc.modelGPUCount
totalGPUMemoryRequirement = tc.modelTotalGPUMemory
perGPUMemoryRequirement = tc.modelPerGPUMemory

errs := tc.resourceSpec.validateCreate(spec)
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
t.Errorf("validateCreate() errors = %v, expectErrs %v", errs, tc.expectErrs)
}
errs := tc.resourceSpec.validateCreateWithInference(&spec)
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
t.Errorf("validateCreate() errors = %v, expectErrs %v", errs, tc.expectErrs)
}

// If there is an error and errContent is not empty, check that the error contains the expected content.
if hasErrs && tc.errContent != "" {
errMsg := errs.Error()
if !strings.Contains(errMsg, tc.errContent) {
t.Errorf("validateCreate() error message = %v, expected to contain = %v", errMsg, tc.errContent)
// If there is an error and errContent is not empty, check that the error contains the expected content.
if hasErrs && tc.errContent != "" {
errMsg := errs.Error()
if !strings.Contains(errMsg, tc.errContent) {
t.Errorf("validateCreate() error message = %v, expected to contain = %v", errMsg, tc.errContent)
}
}
}
})
Expand Down

0 comments on commit f70dcc5

Please sign in to comment.