Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [SKU modularization] remove sku_config from v1alpha1 and implement skuHandler interface #602

Merged
merged 2 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 0 additions & 85 deletions api/v1alpha1/sku_config.go

This file was deleted.

23 changes: 16 additions & 7 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import (
"context"
"fmt"
"os"
"reflect"
"regexp"
"strconv"
Expand Down Expand Up @@ -168,7 +169,7 @@
// Currently a preset must be specified; in the future we can consider defining a template
if r.Preset == nil {
errs = errs.Also(apis.ErrMissingField("Preset"))
} else if presetName := string(r.Preset.Name); !isValidPreset(presetName) {
} else if presetName := string(r.Preset.Name); !utils.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName"))
}
return errs
Expand Down Expand Up @@ -295,8 +296,15 @@
}
instanceType := string(r.InstanceType)

// Check if instancetype exists in our SKUs map
if skuConfig, exists := SupportedGPUConfigs[instanceType]; exists {
skuHandler, err := utils.GetSKUHandler()
if err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get SKU handler: %v", err), "instanceType"))
return errs

Check warning on line 302 in api/v1alpha1/workspace_validation.go

View check run for this annotation

Codecov / codecov/patch

api/v1alpha1/workspace_validation.go#L301-L302

Added lines #L301 - L302 were not covered by tests
}
gpuConfigs := skuHandler.GetGPUConfigs()

// Check if instancetype exists in our SKUs map for the particular cloud provider
if skuConfig, exists := gpuConfigs[instanceType]; exists {
if presetName != "" {
model := plugin.KaitoModelRegister.MustGet(presetName) // InferenceSpec has been validated so the name is valid.

Expand Down Expand Up @@ -350,9 +358,10 @@
}
}
} else {
// Check for other instance types pattern matches
if !strings.HasPrefix(instanceType, N_SERIES_PREFIX) && !strings.HasPrefix(instanceType, D_SERIES_PREFIX) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, getSupportedSKUs()), "instanceType"))
provider := os.Getenv("CLOUD_PROVIDER")
// Check for other instance types pattern matches if cloud provider is Azure
if provider != consts.AzureCloudName || (!strings.HasPrefix(instanceType, N_SERIES_PREFIX) && !strings.HasPrefix(instanceType, D_SERIES_PREFIX)) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, skuHandler.GetSupportedSKUs()), "instanceType"))
}
}

Expand Down Expand Up @@ -398,7 +407,7 @@
if i.Preset != nil {
presetName := string(i.Preset.Name)
// Validate preset name
if !isValidPreset(presetName) {
if !utils.IsValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
Expand Down
58 changes: 6 additions & 52 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ import (
"context"
"fmt"
"os"
"reflect"
"sort"
"strings"
"testing"

Expand Down Expand Up @@ -234,7 +232,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
{
name: "Insufficient total GPU memory",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
InstanceType: "Standard_NV6",
Count: pointerToInt(1),
},
modelGPUCount: "1",
Expand Down Expand Up @@ -263,7 +261,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
{
name: "Insufficient per GPU memory",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
InstanceType: "Standard_NV6",
Count: pointerToInt(2),
},
modelGPUCount: "1",
Expand Down Expand Up @@ -320,7 +318,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
{
name: "Tuning validation with single node",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
InstanceType: "Standard_NC6s_v3",
Count: pointerToInt(1),
},
errContent: "",
Expand All @@ -330,7 +328,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
{
name: "Tuning validation with multinode",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
InstanceType: "Standard_NC6s_v3",
Count: pointerToInt(2),
},
errContent: "Tuning does not currently support multinode configurations",
Expand All @@ -339,6 +337,8 @@ func TestResourceSpecValidateCreate(t *testing.T) {
},
}

os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if tc.validateTuning {
Expand Down Expand Up @@ -1438,49 +1438,3 @@ func TestDataDestinationValidateUpdate(t *testing.T) {
})
}
}

func TestGetSupportedSKUs(t *testing.T) {
tests := []struct {
name string
gpuConfigs map[string]GPUConfig
expectedResult []string // changed to a slice for deterministic ordering
}{
{
name: "no SKUs supported",
gpuConfigs: map[string]GPUConfig{},
expectedResult: []string{""},
},
{
name: "one SKU supported",
gpuConfigs: map[string]GPUConfig{
"Standard_NC6": {SKU: "Standard_NC6"},
},
expectedResult: []string{"Standard_NC6"},
},
{
name: "multiple SKUs supported",
gpuConfigs: map[string]GPUConfig{
"Standard_NC6": {SKU: "Standard_NC6"},
"Standard_NC12": {SKU: "Standard_NC12"},
},
expectedResult: []string{"Standard_NC6", "Standard_NC12"},
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
SupportedGPUConfigs = tc.gpuConfigs

resultSlice := strings.Split(getSupportedSKUs(), ", ")
sort.Strings(resultSlice)

// Sort the expectedResult for comparison
expectedResultSlice := tc.expectedResult
sort.Strings(expectedResultSlice)

if !reflect.DeepEqual(resultSlice, expectedResultSlice) {
t.Errorf("getSupportedSKUs() = %v, want %v", resultSlice, expectedResultSlice)
}
})
}
}
20 changes: 0 additions & 20 deletions api/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pkg/tuning/preset-tuning.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ var (
)

func getInstanceGPUCount(sku string) int {
gpuConfig, exists := kaitov1alpha1.SupportedGPUConfigs[sku]
skuHandler, _ := utils.GetSKUHandler()
gpuConfigs := skuHandler.GetGPUConfigs()

gpuConfig, exists := gpuConfigs[sku]
if !exists {
return 1
}
Expand Down
20 changes: 7 additions & 13 deletions pkg/tuning/preset-tuning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,6 @@ import (
"k8s.io/utils/pointer"
)

// Mocking the SupportedGPUConfigs to be used in test scenarios.
var mockSupportedGPUConfigs = map[string]kaitov1alpha1.GPUConfig{
"sku1": {GPUCount: 2},
"sku2": {GPUCount: 4},
"sku3": {GPUCount: 0},
}

// normalize splits s on whitespace and rejoins the resulting fields with a
// single space, so leading/trailing whitespace is dropped and every internal
// run of whitespace collapses to one space.
func normalize(s string) string {
	words := strings.Fields(s)
	return strings.Join(words, " ")
}
Expand All @@ -54,18 +47,19 @@ func saveEnv(key string) func() {
}

func TestGetInstanceGPUCount(t *testing.T) {
kaitov1alpha1.SupportedGPUConfigs = mockSupportedGPUConfigs
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)

testcases := map[string]struct {
sku string
expectedGPUCount int
}{
"SKU Exists With Multiple GPUs": {
sku: "sku1",
expectedGPUCount: 2,
sku: "Standard_NC24s_v3",
expectedGPUCount: 4,
},
"SKU Exists With Zero GPUs": {
sku: "sku3",
expectedGPUCount: 0,
"SKU Exists With One GPU": {
sku: "Standard_NC6s_v3",
expectedGPUCount: 1,
},
"SKU Does Not Exist": {
sku: "sku_unknown",
Expand Down
5 changes: 5 additions & 0 deletions pkg/utils/common-preset.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package utils

import (
"github.com/azure/kaito/pkg/utils/plugin"
corev1 "k8s.io/api/core/v1"
)

Expand Down Expand Up @@ -149,3 +150,7 @@
}
return volume, volumeMount
}

// IsValidPreset reports whether preset is known to plugin.KaitoModelRegister;
// it simply returns the result of the register's Has method for that name.
func IsValidPreset(preset string) bool {
return plugin.KaitoModelRegister.Has(preset)

Check warning on line 155 in pkg/utils/common-preset.go

View check run for this annotation

Codecov / codecov/patch

pkg/utils/common-preset.go#L154-L155

Added lines #L154 - L155 were not covered by tests
}
Loading