diff --git a/.github/workflows/e2e-workflow.yml b/.github/workflows/e2e-workflow.yml index aafb0de36..79793c989 100644 --- a/.github/workflows/e2e-workflow.yml +++ b/.github/workflows/e2e-workflow.yml @@ -115,6 +115,10 @@ jobs: inlineScript: | az identity create --name gpuIdentity --resource-group ${{ env.CLUSTER_NAME }} + - name: Generate APIs + run: | + make generate + - name: build KAITO image if: ${{ !inputs.isRelease }} shell: bash diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 478110f97..a47383007 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -39,7 +39,9 @@ jobs: uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5.0.0 with: go-version: ${{ env.GO_VERSION }} - + - name: Generate APIs + run: | + make generate - name: Run unit tests & Generate coverage run: | make unit-test diff --git a/Makefile b/Makefile index ee3033938..410219d8e 100644 --- a/Makefile +++ b/Makefile @@ -229,7 +229,7 @@ CONTROLLER_GEN ?= $(LOCALBIN)/controller-gen ENVTEST ?= $(LOCALBIN)/setup-envtest ## Tool Versions -CONTROLLER_TOOLS_VERSION ?= v0.12.0 +CONTROLLER_TOOLS_VERSION ?= v0.15.0 .PHONY: controller-gen controller-gen: $(CONTROLLER_GEN) ## Download controller-gen locally if necessary. If wrong version is installed, it will be overwritten. diff --git a/api/v1alpha1/params_validation.go b/api/v1alpha1/params_validation.go index cd6debaac..4543d5db7 100644 --- a/api/v1alpha1/params_validation.go +++ b/api/v1alpha1/params_validation.go @@ -13,6 +13,7 @@ import ( "gopkg.in/yaml.v2" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" + "k8s.io/apimachinery/pkg/runtime" "knative.dev/pkg/apis" "sigs.k8s.io/controller-runtime/pkg/client" ) @@ -22,13 +23,13 @@ type Config struct { } type TrainingConfig struct { - ModelConfig map[string]interface{} `yaml:"ModelConfig"` - TokenizerParams map[string]interface{} `yaml:"TokenizerParams"` - QuantizationConfig map[string]interface{} `yaml:"QuantizationConfig"` - LoraConfig map[string]interface{} `yaml:"LoraConfig"` - TrainingArguments map[string]interface{} `yaml:"TrainingArguments"` - DatasetConfig map[string]interface{} `yaml:"DatasetConfig"` - DataCollator map[string]interface{} `yaml:"DataCollator"` + ModelConfig map[string]runtime.RawExtension `yaml:"ModelConfig"` + TokenizerParams map[string]runtime.RawExtension `yaml:"TokenizerParams"` + QuantizationConfig map[string]runtime.RawExtension `yaml:"QuantizationConfig"` + LoraConfig map[string]runtime.RawExtension `yaml:"LoraConfig"` + TrainingArguments map[string]runtime.RawExtension `yaml:"TrainingArguments"` + DatasetConfig map[string]runtime.RawExtension `yaml:"DatasetConfig"` + DataCollator map[string]runtime.RawExtension `yaml:"DataCollator"` } func validateNilOrBool(value interface{}) error { @@ -41,6 +42,58 @@ func validateNilOrBool(value interface{}) error { return fmt.Errorf("value must be either nil or a boolean, got type %T", value) } +// UnmarshalYAML custom method +func (t *TrainingConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + var raw map[string]interface{} + if err := unmarshal(&raw); err != nil { + return err + } + + // This function converts a map[string]interface{} to a map[string]runtime.RawExtension. + // It does this by setting the raw marshalled data of the unmarshalled YAML to + // be the raw data of the runtime.RawExtension object. + handleRawExtension := func(raw map[string]interface{}, field string) (map[string]runtime.RawExtension, error) { + var target map[string]runtime.RawExtension + if value, found := raw[field]; found { + delete(raw, field) + var ext runtime.RawExtension + data, err := yaml.Marshal(value) + if err != nil { + return nil, err + } + ext.Raw = data + if target == nil { + target = make(map[string]runtime.RawExtension) + } + target[field] = ext + } + return target, nil + } + + fields := []struct { + name string + target *map[string]runtime.RawExtension + }{ + {"ModelConfig", &t.ModelConfig}, + {"TokenizerParams", &t.TokenizerParams}, + {"QuantizationConfig", &t.QuantizationConfig}, + {"LoraConfig", &t.LoraConfig}, + {"TrainingArguments", &t.TrainingArguments}, + {"DatasetConfig", &t.DatasetConfig}, + {"DataCollator", &t.DataCollator}, + } + + var err error + for _, field := range fields { + *field.target, err = handleRawExtension(raw, field.name) + if err != nil { + return err + } + } + + return nil +} + func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError { trainingConfigYAML, ok := cm.Data["training_config.yaml"] if !ok { @@ -55,31 +108,41 @@ func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *a // Validate QuantizationConfig if it exists quantConfig := config.TrainingConfig.QuantizationConfig if quantConfig != nil { - // Dynamic field search for quantization settings within ModelConfig - loadIn4bit, _ := utils.SearchMap(quantConfig, "load_in_4bit") - loadIn8bit, _ := utils.SearchMap(quantConfig, "load_in_8bit") + quantConfigRaw, quantConfigExists := quantConfig["QuantizationConfig"] + if quantConfigExists { + // Dynamic field search for quantization settings within ModelConfig + loadIn4bit, _, err := utils.SearchRawExtension(quantConfigRaw, "load_in_4bit") + if err != nil { + return apis.ErrInvalidValue(err.Error(), "load_in_4bit") + } + loadIn8bit, _, err := utils.SearchRawExtension(quantConfigRaw, "load_in_8bit") + if err != nil { + return apis.ErrInvalidValue(err.Error(), "load_in_8bit") + } - // Validate both loadIn4bit and loadIn8bit - if err := validateNilOrBool(loadIn4bit); err != nil { - return apis.ErrInvalidValue(err.Error(), "load_in_4bit") - } - if err := validateNilOrBool(loadIn8bit); err != nil { - return apis.ErrInvalidValue(err.Error(), "load_in_8bit") - } + // Validate both loadIn4bit and loadIn8bit + if err := validateNilOrBool(loadIn4bit); err != nil { + return apis.ErrInvalidValue(err.Error(), "load_in_4bit") + } + if err := validateNilOrBool(loadIn8bit); err != nil { + return apis.ErrInvalidValue(err.Error(), "load_in_8bit") + } - loadIn4bitBool, _ := loadIn4bit.(bool) - loadIn8bitBool, _ := loadIn8bit.(bool) + loadIn4bitBool, _ := loadIn4bit.(bool) + loadIn8bitBool, _ := loadIn8bit.(bool) - if loadIn4bitBool && loadIn8bitBool { - return apis.ErrGeneric(fmt.Sprintf("Cannot set both 'load_in_4bit' and 'load_in_8bit' to true in ConfigMap '%s'", cm.Name), "QuantizationConfig") - } - if methodLowerCase == string(TuningMethodLora) { - if loadIn4bitBool || loadIn8bitBool { - return apis.ErrGeneric(fmt.Sprintf("For method 'lora', 'load_in_4bit' or 'load_in_8bit' in ConfigMap '%s' must not be true", cm.Name), "QuantizationConfig") + // Validation Logic + if loadIn4bitBool && loadIn8bitBool { + return apis.ErrGeneric(fmt.Sprintf("Cannot set both 'load_in_4bit' and 'load_in_8bit' to true in ConfigMap '%s'", cm.Name), "QuantizationConfig") } - } else if methodLowerCase == string(TuningMethodQLora) { - if !loadIn4bitBool && !loadIn8bitBool { - return apis.ErrMissingField(fmt.Sprintf("For method 'qlora', either 'load_in_4bit' or 'load_in_8bit' must be true in ConfigMap '%s'", cm.Name), "QuantizationConfig") + if methodLowerCase == string(TuningMethodLora) { + if loadIn4bitBool || loadIn8bitBool { + return apis.ErrGeneric(fmt.Sprintf("For method 'lora', 'load_in_4bit' or 'load_in_8bit' in ConfigMap '%s' must not be true", cm.Name), "QuantizationConfig") + } + } else if methodLowerCase == string(TuningMethodQLora) { + if !loadIn4bitBool && !loadIn8bitBool { + return apis.ErrMissingField(fmt.Sprintf("For method 'qlora', either 'load_in_4bit' or 'load_in_8bit' must be true in ConfigMap '%s'", cm.Name), "QuantizationConfig") + } } } } else if methodLowerCase == string(TuningMethodQLora) { diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 65d667009..d84cb9906 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -137,6 +137,53 @@ func defaultConfigMapManifest() *v1.ConfigMap { } } +func qloraConfigMapManifest() *v1.ConfigMap { + return &v1.ConfigMap{ + ObjectMeta: metav1.ObjectMeta{ + Name: DefaultQloraConfigMapTemplate, + Namespace: DefaultReleaseNamespace, + }, + Data: map[string]string{ + "training_config.yaml": `training_config: + ModelConfig: + torch_dtype: "bfloat16" + local_files_only: true + device_map: "auto" + + TokenizerParams: + padding: true + truncation: true + + QuantizationConfig: + load_in_4bit: true + bnb_4bit_quant_type: "nf4" + bnb_4bit_compute_dtype: "bfloat16" + bnb_4bit_use_double_quant: true + + LoraConfig: + r: 16 + lora_alpha: 32 + target_modules: "query_key_value" + lora_dropout: 0.05 + bias: "none" + + TrainingArguments: + output_dir: "." + num_train_epochs: 4 + auto_find_batch_size: true + ddp_find_unused_parameters: false + save_strategy: "epoch" + + DatasetConfig: + shuffle_dataset: true + train_test_split: 1 + + DataCollator: + mlm: true`, + }, + } +} + func TestResourceSpecValidateCreate(t *testing.T) { RegisterValidationTestModels() tests := []struct { @@ -700,7 +747,7 @@ func TestTuningSpecValidateCreate(t *testing.T) { // Create fake client with default ConfigMap scheme := runtime.NewScheme() _ = v1.AddToScheme(scheme) - client := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(defaultConfigMapManifest()).Build() + client := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(defaultConfigMapManifest(), qloraConfigMapManifest()).Build() k8sclient.SetGlobalClient(client) // Include client in ctx ctx := context.Background() @@ -722,6 +769,17 @@ func TestTuningSpecValidateCreate(t *testing.T) { wantErr: false, errFields: nil, }, + { + name: "Verify QLoRA Config", + tuningSpec: &TuningSpec{ + Input: &DataSource{Name: "valid-input", Volume: &v1.VolumeSource{}}, + Output: &DataDestination{Volume: &v1.VolumeSource{}}, + Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}}, + Method: TuningMethodQLora, + }, + wantErr: false, + errFields: nil, + }, { name: "Missing Input", tuningSpec: &TuningSpec{ diff --git a/api/v1alpha1/zz_generated.deepcopy.go b/api/v1alpha1/zz_generated.deepcopy.go index 112c94b1b..8d0bf3782 100644 --- a/api/v1alpha1/zz_generated.deepcopy.go +++ b/api/v1alpha1/zz_generated.deepcopy.go @@ -1,5 +1,4 @@ //go:build !ignore_autogenerated -// +build !ignore_autogenerated // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. @@ -11,7 +10,7 @@ package v1alpha1 import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/apis/meta/v1" - runtime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime" ) // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. @@ -39,6 +38,22 @@ func (in *AdapterSpec) DeepCopy() *AdapterSpec { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *Config) DeepCopyInto(out *Config) { + *out = *in + in.TrainingConfig.DeepCopyInto(&out.TrainingConfig) +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new Config. +func (in *Config) DeepCopy() *Config { + if in == nil { + return nil + } + out := new(Config) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *DataDestination) DeepCopyInto(out *DataDestination) { *out = *in @@ -223,6 +238,70 @@ func (in *ResourceSpec) DeepCopy() *ResourceSpec { return out } +// DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. +func (in *TrainingConfig) DeepCopyInto(out *TrainingConfig) { + *out = *in + if in.ModelConfig != nil { + in, out := &in.ModelConfig, &out.ModelConfig + *out = make(map[string]runtime.RawExtension, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } + if in.TokenizerParams != nil { + in, out := &in.TokenizerParams, &out.TokenizerParams + *out = make(map[string]runtime.RawExtension, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } + if in.QuantizationConfig != nil { + in, out := &in.QuantizationConfig, &out.QuantizationConfig + *out = make(map[string]runtime.RawExtension, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } + if in.LoraConfig != nil { + in, out := &in.LoraConfig, &out.LoraConfig + *out = make(map[string]runtime.RawExtension, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } + if in.TrainingArguments != nil { + in, out := &in.TrainingArguments, &out.TrainingArguments + *out = make(map[string]runtime.RawExtension, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } + if in.DatasetConfig != nil { + in, out := &in.DatasetConfig, &out.DatasetConfig + *out = make(map[string]runtime.RawExtension, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } + if in.DataCollator != nil { + in, out := &in.DataCollator, &out.DataCollator + *out = make(map[string]runtime.RawExtension, len(*in)) + for key, val := range *in { + (*out)[key] = *val.DeepCopy() + } + } +} + +// DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new TrainingConfig. +func (in *TrainingConfig) DeepCopy() *TrainingConfig { + if in == nil { + return nil + } + out := new(TrainingConfig) + in.DeepCopyInto(out) + return out +} + // DeepCopyInto is an autogenerated deepcopy function, copying the receiver, writing into out. in must be non-nil. func (in *TuningSpec) DeepCopyInto(out *TuningSpec) { *out = *in diff --git a/pkg/utils/common.go b/pkg/utils/common.go index f2bf7a368..8c021b5fa 100644 --- a/pkg/utils/common.go +++ b/pkg/utils/common.go @@ -5,7 +5,9 @@ package utils import ( "fmt" + "gopkg.in/yaml.v2" "io/ioutil" + "k8s.io/apimachinery/pkg/runtime" "os" "github.com/azure/kaito/pkg/utils/consts" @@ -28,6 +30,21 @@ func SearchMap(m map[string]interface{}, key string) (value interface{}, exists return nil, false } +// SearchRawExtension performs a search for a key within a runtime.RawExtension. +func SearchRawExtension(raw runtime.RawExtension, key string) (interface{}, bool, error) { + var data map[string]interface{} + if err := yaml.Unmarshal(raw.Raw, &data); err != nil { + return nil, false, fmt.Errorf("failed to unmarshal runtime.RawExtension: %w", err) + } + + result, found := data[key] + if !found { + return nil, false, nil + } + + return result, true, nil +} + func MergeConfigMaps(baseMap, overrideMap map[string]string) map[string]string { merged := make(map[string]string) for k, v := range baseMap {