diff --git a/api/v1alpha1/sku_config.go b/api/v1alpha1/sku_config.go deleted file mode 100644 index ebe926b31..000000000 --- a/api/v1alpha1/sku_config.go +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT license. - -package v1alpha1 - -import ( - "strings" - - "github.com/azure/kaito/pkg/utils/plugin" -) - -type GPUConfig struct { - SKU string - SupportedOS []string - GPUDriver string - GPUCount int - GPUMem int -} - -func isValidPreset(preset string) bool { - return plugin.KaitoModelRegister.Has(preset) -} - -func getSupportedSKUs() string { - skus := make([]string, 0, len(SupportedGPUConfigs)) - for sku := range SupportedGPUConfigs { - skus = append(skus, sku) - } - return strings.Join(skus, ", ") -} - -var SupportedGPUConfigs = map[string]GPUConfig{ - "Standard_NC6": {SKU: "Standard_NC6", GPUCount: 1, GPUMem: 12, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"}, - "Standard_NC12": {SKU: "Standard_NC12", GPUCount: 2, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"}, - "Standard_NC24": {SKU: "Standard_NC24", GPUCount: 4, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"}, - "Standard_NC24r": {SKU: "Standard_NC24r", GPUCount: 4, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia470CudaDriver"}, - "Standard_NV6": {SKU: "Standard_NV6", GPUCount: 1, GPUMem: 8, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_NV12": {SKU: "Standard_NV12", GPUCount: 2, GPUMem: 16, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_NV24": {SKU: "Standard_NV24", GPUCount: 4, GPUMem: 32, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_NV12s_v3": {SKU: "Standard_NV12s_v3", GPUCount: 1, GPUMem: 8, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_NV24s_v3": {SKU: "Standard_NV24s_v3", GPUCount: 2, GPUMem: 16, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_NV48s_v3": {SKU: "Standard_NV48s_v3", GPUCount: 4, GPUMem: 32, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - // "Standard_NV24r": {SKU: "Standard_NV24r", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_ND6s": {SKU: "Standard_ND6s", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_ND12s": {SKU: "Standard_ND12s", GPUCount: 2, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_ND24s": {SKU: "Standard_ND24s", GPUCount: 4, GPUMem: 96, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_ND24rs": {SKU: "Standard_ND24rs", GPUCount: 4, GPUMem: 96, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC6s_v2": {SKU: "Standard_NC6s_v2", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC12s_v2": {SKU: "Standard_NC12s_v2", GPUCount: 2, GPUMem: 32, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC24s_v2": {SKU: "Standard_NC24s_v2", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC24rs_v2": {SKU: "Standard_NC24rs_v2", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC6s_v3": {SKU: "Standard_NC6s_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC12s_v3": {SKU: "Standard_NC12s_v3", GPUCount: 2, GPUMem: 32, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC24s_v3": {SKU: "Standard_NC24s_v3", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC24rs_v3": {SKU: "Standard_NC24rs_v3", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - // "Standard_ND40s_v3": {SKU: "Standard_ND40s_v3", GPUCount: x, GPUMem: x, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_ND40rs_v2": {SKU: "Standard_ND40rs_v2", GPUCount: 8, GPUMem: 256, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC4as_T4_v3": {SKU: "Standard_NC4as_T4_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC8as_T4_v3": {SKU: "Standard_NC8as_T4_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC16as_T4_v3": {SKU: "Standard_NC16as_T4_v3", GPUCount: 1, GPUMem: 16, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC64as_T4_v3": {SKU: "Standard_NC64as_T4_v3", GPUCount: 4, GPUMem: 64, SupportedOS: []string{"Mariner", "Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_ND96asr_v4": {SKU: "Standard_ND96asr_v4", GPUCount: 8, GPUMem: 320, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - // "Standard_ND112asr_A100_v4": {SKU: "Standard_ND112asr_A100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - // "Standard_ND120asr_A100_v4": {SKU: "Standard_ND120asr_A100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_ND96amsr_A100_v4": {SKU: "Standard_ND96amsr_A100_v4", GPUCount: 8, GPUMem: 640, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - // "Standard_ND112amsr_A100_v4": {SKU: "Standard_ND112amsr_A100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - // "Standard_ND120amsr_A100_v4": {SKU: "Standard_ND120amsr_A100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC24ads_A100_v4": {SKU: "Standard_NC24ads_A100_v4", GPUCount: 1, GPUMem: 80, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC48ads_A100_v4": {SKU: "Standard_NC48ads_A100_v4", GPUCount: 2, GPUMem: 160, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - "Standard_NC96ads_A100_v4": {SKU: "Standard_NC96ads_A100_v4", GPUCount: 4, GPUMem: 320, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - // "Standard_NCads_A100_v4": {SKU: "Standard_NCads_A100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - /*GPU Mem based on A10-24 Spec - TODO: Need to confirm GPU Mem*/ - // "Standard_NC8ads_A10_v4": {SKU: "Standard_NC8ads_A10_v4", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - // "Standard_NC16ads_A10_v4": {SKU: "Standard_NC16ads_A10_v4", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - // "Standard_NC32ads_A10_v4": {SKU: "Standard_NC32ads_A10_v4", GPUCount: 2, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - /* SKUs with GPU Partition are treated as 1 GPU - https://learn.microsoft.com/en-us/azure/virtual-machines/nvA10v5-series*/ - "Standard_NV6ads_A10_v5": {SKU: "Standard_NV6ads_A10_v5", GPUCount: 1, GPUMem: 4, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_NV12ads_A10_v5": {SKU: "Standard_NV12ads_A10_v5", GPUCount: 1, GPUMem: 8, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_NV18ads_A10_v5": {SKU: "Standard_NV18ads_A10_v5", GPUCount: 1, GPUMem: 12, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_NV36ads_A10_v5": {SKU: "Standard_NV36ads_A10_v5", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_NV36adms_A10_v5": {SKU: "Standard_NV36adms_A10_v5", GPUCount: 1, GPUMem: 24, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - "Standard_NV72ads_A10_v5": {SKU: "Standard_NV72ads_A10_v5", GPUCount: 2, GPUMem: 48, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia510GridDriver"}, - // "Standard_ND96ams_v4": {SKU: "Standard_ND96ams_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, - // "Standard_ND96ams_A100_v4": {SKU: "Standard_ND96ams_A100_v4", GPUCount: x, GPUMem: x, SupportedOS: []string{"Ubuntu"}, GPUDriver: "Nvidia525CudaDriver"}, -} diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go index 5e682838b..07e45c9b7 100644 --- a/api/v1alpha1/workspace_validation.go +++ b/api/v1alpha1/workspace_validation.go @@ -6,6 +6,7 @@ package v1alpha1 import ( "context" "fmt" + "os" "reflect" "regexp" "strconv" @@ -168,7 +169,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri // Currently require a preset to specified, in future we can consider defining a template if r.Preset == nil { errs = errs.Also(apis.ErrMissingField("Preset")) - } else if presetName := string(r.Preset.Name); !isValidPreset(presetName) { + } else if presetName := string(r.Preset.Name); !utils.IsValidPreset(presetName) { errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName")) } return errs @@ -295,8 +296,15 @@ func (r *ResourceSpec) validateCreateWithInference(inference *InferenceSpec) (er } instanceType := string(r.InstanceType) - // Check if instancetype exists in our SKUs map - if skuConfig, exists := SupportedGPUConfigs[instanceType]; exists { + skuHandler, err := utils.GetSKUHandler() + if err != nil { + errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get SKU handler: %v", err), "instanceType")) + return errs + } + gpuConfigs := skuHandler.GetGPUConfigs() + + // Check if instancetype exists in our SKUs map for the particular cloud provider + if skuConfig, exists := gpuConfigs[instanceType]; exists { if presetName != "" { model := plugin.KaitoModelRegister.MustGet(presetName) // InferenceSpec has been validated so the name is valid. @@ -350,9 +358,10 @@ func (r *ResourceSpec) validateCreateWithInference(inference *InferenceSpec) (er } } } else { - // Check for other instance types pattern matches - if !strings.HasPrefix(instanceType, N_SERIES_PREFIX) && !strings.HasPrefix(instanceType, D_SERIES_PREFIX) { - errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, getSupportedSKUs()), "instanceType")) + provider := os.Getenv("CLOUD_PROVIDER") + // Check for other instance types pattern matches if cloud provider is Azure + if provider != consts.AzureCloudName || (!strings.HasPrefix(instanceType, N_SERIES_PREFIX) && !strings.HasPrefix(instanceType, D_SERIES_PREFIX)) { + errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, skuHandler.GetSupportedSKUs()), "instanceType")) } } @@ -398,7 +407,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) { if i.Preset != nil { presetName := string(i.Preset.Name) // Validate preset name - if !isValidPreset(presetName) { + if !utils.IsValidPreset(presetName) { errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName")) } // Validate private preset has private image specified diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 35e1a3a30..bf9f8ea3e 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -7,8 +7,6 @@ import ( "context" "fmt" "os" - "reflect" - "sort" "strings" "testing" @@ -234,7 +232,7 @@ func TestResourceSpecValidateCreate(t *testing.T) { { name: "Insufficient total GPU memory", resourceSpec: &ResourceSpec{ - InstanceType: "Standard_NC6", + InstanceType: "Standard_NV6", Count: pointerToInt(1), }, modelGPUCount: "1", @@ -263,7 +261,7 @@ func TestResourceSpecValidateCreate(t *testing.T) { { name: "Insufficient per GPU memory", resourceSpec: &ResourceSpec{ - InstanceType: "Standard_NC6", + InstanceType: "Standard_NV6", Count: pointerToInt(2), }, modelGPUCount: "1", @@ -320,7 +318,7 @@ func TestResourceSpecValidateCreate(t *testing.T) { { name: "Tuning validation with single node", resourceSpec: &ResourceSpec{ - InstanceType: "Standard_NC6", + InstanceType: "Standard_NC6s_v3", Count: pointerToInt(1), }, errContent: "", @@ -330,7 +328,7 @@ func TestResourceSpecValidateCreate(t *testing.T) { { name: "Tuning validation with multinode", resourceSpec: &ResourceSpec{ - InstanceType: "Standard_NC6", + InstanceType: "Standard_NC6s_v3", Count: pointerToInt(2), }, errContent: "Tuning does not currently support multinode configurations", @@ -339,6 +337,8 @@ func TestResourceSpecValidateCreate(t *testing.T) { }, } + os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { if tc.validateTuning { @@ -1438,49 +1438,3 @@ func TestDataDestinationValidateUpdate(t *testing.T) { }) } } - -func TestGetSupportedSKUs(t *testing.T) { - tests := []struct { - name string - gpuConfigs map[string]GPUConfig - expectedResult []string // changed to a slice for deterministic ordering - }{ - { - name: "no SKUs supported", - gpuConfigs: map[string]GPUConfig{}, - expectedResult: []string{""}, - }, - { - name: "one SKU supported", - gpuConfigs: map[string]GPUConfig{ - "Standard_NC6": {SKU: "Standard_NC6"}, - }, - expectedResult: []string{"Standard_NC6"}, - }, - { - name: "multiple SKUs supported", - gpuConfigs: map[string]GPUConfig{ - "Standard_NC6": {SKU: "Standard_NC6"}, - "Standard_NC12": {SKU: "Standard_NC12"}, - }, - expectedResult: []string{"Standard_NC6", "Standard_NC12"}, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - SupportedGPUConfigs = tc.gpuConfigs - - resultSlice := strings.Split(getSupportedSKUs(), ", ") - sort.Strings(resultSlice) - - // Sort the expectedResult for comparison - expectedResultSlice := tc.expectedResult - sort.Strings(expectedResultSlice) - - if !reflect.DeepEqual(resultSlice, expectedResultSlice) { - t.Errorf("getSupportedSKUs() = %v, want %v", resultSlice, expectedResultSlice) - } - }) - } -} diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index ef55fed6a..6acd8aade 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -129,26 +129,6 @@ func (in *EmbeddingSpec) DeepCopy() *EmbeddingSpec { return out } -// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. -func (in *GPUConfig) DeepCopyInto(out *GPUConfig) { - *out = *in - if in.SupportedOS != nil { - in, out := &in.SupportedOS, &out.SupportedOS - *out = make([]string, len(*in)) - copy(*out, *in) - } -} - -// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new GPUConfig. -func (in *GPUConfig) DeepCopy() *GPUConfig { - if in == nil { - return nil - } - out := new(GPUConfig) - in.DeepCopyInto(out) - return out -} - // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *InferenceServiceSpec) DeepCopyInto(out *InferenceServiceSpec) { *out = *in diff --git a/pkg/tuning/preset-tuning.go b/pkg/tuning/preset-tuning.go index 4e3f15a14..ad81115a5 100644 --- a/pkg/tuning/preset-tuning.go +++ b/pkg/tuning/preset-tuning.go @@ -55,7 +55,10 @@ var ( ) func getInstanceGPUCount(sku string) int { - gpuConfig, exists := kaitov1alpha1.SupportedGPUConfigs[sku] + skuHandler, _ := utils.GetSKUHandler() + gpuConfigs := skuHandler.GetGPUConfigs() + + gpuConfig, exists := gpuConfigs[sku] if !exists { return 1 } diff --git a/pkg/tuning/preset-tuning_test.go b/pkg/tuning/preset-tuning_test.go index bf8bd9f95..99344ddcb 100644 --- a/pkg/tuning/preset-tuning_test.go +++ b/pkg/tuning/preset-tuning_test.go @@ -24,13 +24,6 @@ import ( "k8s.io/utils/pointer" ) -// Mocking the SupportedGPUConfigs to be used in test scenarios. -var mockSupportedGPUConfigs = map[string]kaitov1alpha1.GPUConfig{ - "sku1": {GPUCount: 2}, - "sku2": {GPUCount: 4}, - "sku3": {GPUCount: 0}, -} - func normalize(s string) string { return strings.Join(strings.Fields(s), " ") } @@ -54,18 +47,19 @@ func saveEnv(key string) func() { } func TestGetInstanceGPUCount(t *testing.T) { - kaitov1alpha1.SupportedGPUConfigs = mockSupportedGPUConfigs + os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + testcases := map[string]struct { sku string expectedGPUCount int }{ "SKU Exists With Multiple GPUs": { - sku: "sku1", - expectedGPUCount: 2, + sku: "Standard_NC24s_v3", + expectedGPUCount: 4, }, - "SKU Exists With Zero GPUs": { - sku: "sku3", - expectedGPUCount: 0, + "SKU Exists With One GPU": { + sku: "Standard_NC6s_v3", + expectedGPUCount: 1, }, "SKU Does Not Exist": { sku: "sku_unknown", diff --git a/pkg/utils/common-preset.go b/pkg/utils/common-preset.go index c9f5f8dd0..df96276d4 100644 --- a/pkg/utils/common-preset.go +++ b/pkg/utils/common-preset.go @@ -3,6 +3,7 @@ package utils import ( + "github.com/azure/kaito/pkg/utils/plugin" corev1 "k8s.io/api/core/v1" ) @@ -149,3 +150,7 @@ func ConfigAdapterVolume() (corev1.Volume, corev1.VolumeMount) { } return volume, volumeMount } + +func IsValidPreset(preset string) bool { + return plugin.KaitoModelRegister.Has(preset) +}