Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Automate the adapters manifests #463

Merged
merged 16 commits into from
Jun 12, 2024
77 changes: 76 additions & 1 deletion api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
"reflect"
"regexp"
"sort"
"strconv"
"strings"

"github.com/azure/kaito/pkg/utils"
Expand All @@ -27,6 +28,7 @@

DefaultLoraConfigMapTemplate = "lora-params-template"
DefaultQloraConfigMapTemplate = "qlora-params-template"
MaxAdaptersNumber = 10
)

func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType {
Expand Down Expand Up @@ -58,7 +60,6 @@
w.Resource.validateUpdate(&old.Resource).ViaField("resource"),
)
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference"))
}
if w.Tuning != nil {
Expand Down Expand Up @@ -89,6 +90,44 @@
return errs
}

func ValidateDNSSubdomain(name string) bool {
var dnsSubDomainRegexp = regexp.MustCompile(`^(?i:[a-z0-9]([-a-z0-9]*[a-z0-9])?)$`)
if len(name) < 1 || len(name) > 253 {
return false

Check warning on line 96 in api/v1alpha1/workspace_validation.go

View check run for this annotation

Codecov / codecov/patch

api/v1alpha1/workspace_validation.go#L96

Added line #L96 was not covered by tests
}
return dnsSubDomainRegexp.MatchString(name)
}

func (r *AdapterSpec) validateCreateorUpdate() (errs *apis.FieldError) {
if r.Source == nil {
errs = errs.Also(apis.ErrMissingField("Source"))
} else {
errs = errs.Also(r.Source.validateCreate().ViaField("Adapters"))

if r.Source.Name == "" {
errs = errs.Also(apis.ErrMissingField("Name of Adapter field must be specified"))
} else if !ValidateDNSSubdomain(r.Source.Name) {
errs = errs.Also(apis.ErrMissingField("Name of Adapter must be a valid DNS subdomain value"))

Check warning on line 110 in api/v1alpha1/workspace_validation.go

View check run for this annotation

Codecov / codecov/patch

api/v1alpha1/workspace_validation.go#L110

Added line #L110 was not covered by tests
}
if r.Source.Image == "" {
errs = errs.Also(apis.ErrMissingField("Image of Adapter field must be specified"))

Check warning on line 113 in api/v1alpha1/workspace_validation.go

View check run for this annotation

Codecov / codecov/patch

api/v1alpha1/workspace_validation.go#L113

Added line #L113 was not covered by tests
}
if r.Strength == nil {
var defaultStrength = "1.0"
r.Strength = &defaultStrength

Check warning on line 117 in api/v1alpha1/workspace_validation.go

View check run for this annotation

Codecov / codecov/patch

api/v1alpha1/workspace_validation.go#L116-L117

Added lines #L116 - L117 were not covered by tests
}
strength, err := strconv.ParseFloat(*r.Strength, 64)
if err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Invalid strength value for Adapter '%s': %v", r.Source.Name, err), "adapter"))
}
if strength < 0 || strength > 1.0 {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Strength value for Adapter '%s' must be between 0 and 1", r.Source.Name), "adapter"))
}

}
return errs
}

func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace string) (errs *apis.FieldError) {
methodLowerCase := strings.ToLower(string(r.Method))
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
Expand Down Expand Up @@ -346,6 +385,16 @@
}
// Note: we don't enforce private access mode to have image secrets, in case anonymous pulling is enabled
}
if len(i.Adapters) > MaxAdaptersNumber {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Number of Adapters exceeds the maximum limit, maximum of %s allowed", strconv.Itoa(MaxAdaptersNumber))))
}

// check if adapter names are duplicate
if len(i.Adapters) > 0 {
nameMap := make(map[string]string)
errs = errs.Also(validateDuplicateName(i.Adapters, nameMap))
}

return errs
}

Expand All @@ -357,6 +406,32 @@
if (i.Template != nil && old.Template == nil) || (i.Template == nil && old.Template != nil) {
errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "template"))
}
nameMap := make(map[string]string)
for _, adapter := range old.Adapters {
nameMap[adapter.Source.Name] = adapter.Source.Image

Check warning on line 411 in api/v1alpha1/workspace_validation.go

View check run for this annotation

Codecov / codecov/patch

api/v1alpha1/workspace_validation.go#L411

Added line #L411 was not covered by tests
}

// check if adapter names are duplicate
for _, adapter := range i.Adapters {
errs = errs.Also(adapter.validateCreateorUpdate())

Check warning on line 416 in api/v1alpha1/workspace_validation.go

View check run for this annotation

Codecov / codecov/patch

api/v1alpha1/workspace_validation.go#L416

Added line #L416 was not covered by tests
}

// check if adapter names are duplicate
if len(i.Adapters) > 0 {
errs = errs.Also(validateDuplicateName(i.Adapters, nameMap))

Check warning on line 421 in api/v1alpha1/workspace_validation.go

View check run for this annotation

Codecov / codecov/patch

api/v1alpha1/workspace_validation.go#L421

Added line #L421 was not covered by tests
}
return errs
}

func validateDuplicateName(adapters []AdapterSpec, nameMap map[string]string) (errs *apis.FieldError) {
for _, adapter := range adapters {
if previousAdapterImage, ok := nameMap[adapter.Source.Name]; ok {
if previousAdapterImage != adapter.Source.Image {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Duplicate adapter source name found: %s", adapter.Source.Name)))
}
} else {
nameMap[adapter.Source.Name] = adapter.Source.Image
}
}
return errs
}
142 changes: 141 additions & 1 deletion api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package v1alpha1

import (
"context"
"fmt"
"os"
"reflect"
"sort"
Expand All @@ -24,6 +25,10 @@ import (

const DefaultReleaseNamespace = "kaito-workspace"

var ValidStrength string = "0.5"
var InvalidStrength1 string = "invalid"
var InvalidStrength2 string = "1.5"

var gpuCountRequirement string
var totalGPUMemoryRequirement string
var perGPUMemoryRequirement string
Expand Down Expand Up @@ -474,6 +479,56 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
errContent: "This preset only supports private AccessMode, AccessMode must be private to continue",
expectErrs: true,
},
{
name: "Adapeters more than 10",
inferenceSpec: func() *InferenceSpec {
spec := &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: ModelImageAccessModePublic,
},
},
}
for i := 1; i <= 11; i++ {
spec.Adapters = append(spec.Adapters, AdapterSpec{
Source: &DataSource{
Name: fmt.Sprintf("Adapter-%d", i),
Image: fmt.Sprintf("fake.kaito.com/kaito-image:0.0.%d", i),
},
Strength: &ValidStrength,
})
}
return spec
}(),
errContent: "Number of Adapters exceeds the maximum limit, maximum of 10 allowed",
expectErrs: true,
},
{
name: "Adapeters names are duplicated",
inferenceSpec: func() *InferenceSpec {
spec := &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: ModelImageAccessModePublic,
},
},
}
for i := 1; i <= 2; i++ {
spec.Adapters = append(spec.Adapters, AdapterSpec{
Source: &DataSource{
Name: "Adapter",
Image: fmt.Sprintf("fake.kaito.com/kaito-image:0.0.%d", i),
},
Strength: &ValidStrength,
})
}
return spec
}(),
errContent: "",
expectErrs: true,
},
{
name: "Valid Preset",
inferenceSpec: &InferenceSpec{
Expand All @@ -484,7 +539,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
},
},
},
errContent: "",
errContent: "Duplicate adapter source name found:",
expectErrs: false,
},
}
Expand Down Expand Up @@ -520,6 +575,91 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
}
}

func TestAdapterSpecValidateCreateorUpdate(t *testing.T) {
RegisterValidationTestModels()
tests := []struct {
name string
adapterSpec *AdapterSpec
errContent string // Content expected error to include, if any
expectErrs bool
}{
{
name: "Missing Source",
adapterSpec: &AdapterSpec{
Strength: &ValidStrength,
},
errContent: "Source",
expectErrs: true,
},
{
name: "Missing Source Name",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &ValidStrength,
},
errContent: "Name of Adapter field must be specified",
expectErrs: true,
},
{
name: "Invalid Strength, not a number",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &InvalidStrength1,
},
errContent: "Invalid strength value for Adapter 'Adapter-1'",
expectErrs: true,
},
{
name: "Invalid Strength, larger than 1",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &InvalidStrength2,
},
errContent: "Strength value for Adapter 'Adapter-1' must be between 0 and 1",
expectErrs: true,
},
{
name: "Valid Adapter",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &ValidStrength,
},
errContent: "",
expectErrs: false,
},
}

// Run the tests
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
errs := tc.adapterSpec.validateCreateorUpdate()
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tc.expectErrs)
}

// If there is an error and errContent is not empty, check that the error contains the expected content.
if hasErrs && tc.errContent != "" {
errMsg := errs.Error()
if !strings.Contains(errMsg, tc.errContent) {
t.Errorf("validateUpdate() error message = %v, expected to contain = %v", errMsg, tc.errContent)
}
}
})
}
}

func TestInferenceSpecValidateUpdate(t *testing.T) {
tests := []struct {
name string
Expand Down
1 change: 1 addition & 0 deletions examples/inference/kaito_workspace_falcon_7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ resource:
inference:
preset:
name: "falcon-7b"

19 changes: 19 additions & 0 deletions examples/inference/kaito_workspace_falcon_7b_with_adapters.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
name: workspace-falcon-7b
resource:
instanceType: "Standard_NC12s_v3"
labelSelector:
matchLabels:
apps: falcon-7b
inference:
preset:
name: "falcon-7b"
adapters:
- source:
name: "falcon-7b-mental"
urls:
- "https://huggingface.co/dfurman/Falcon-7B-Chat-v0.1/blob/main/finetune_falcon7b_oasst1_with_bnb_peft.ipynb"
strength: "0.2"

16 changes: 9 additions & 7 deletions pkg/inference/preset-inferences.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ package inference
import (
"context"
"fmt"
"github.com/azure/kaito/pkg/utils"
"os"
"strconv"

"github.com/azure/kaito/pkg/utils"

kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/resources"
Expand Down Expand Up @@ -56,13 +57,14 @@ var (
tolerations = []corev1.Toleration{
{
Effect: corev1.TaintEffectNoSchedule,
Operator: corev1.TolerationOpEqual,
Key: resources.GPUString,
Operator: corev1.TolerationOpExists,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why change from corev1.TolerationOpEqual to corev1.TolerationOpExists

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The origin one was different from the yaml file, and it will make the adapter cannot find the node

Key: resources.CapacityNvidiaGPU,
Copy link
Collaborator

@ishaansehgal99 ishaansehgal99 Jun 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similiar question on why change to resources.CapacityNvidiaGPU here. these tolerations should ideally not need to be changed...

},
{
Effect: corev1.TaintEffectNoSchedule,
Value: resources.GPUString,
Key: "sku",
Effect: corev1.TaintEffectNoSchedule,
Value: resources.GPUString,
Key: "sku",
Operator: corev1.TolerationOpEqual,
},
}
)
Expand Down Expand Up @@ -120,7 +122,7 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work

var volumes []corev1.Volume
var volumeMounts []corev1.VolumeMount
shmVolume, shmVolumeMount := utils.ConfigSHMVolume(*workspaceObj.Resource.Count)
shmVolume, shmVolumeMount := utils.ConfigSHMVolume(*workspaceObj.Resource.Count + len(workspaceObj.Inference.Adapters))
if shmVolume.Name != "" {
volumes = append(volumes, shmVolume)
}
Expand Down
Loading
Loading