Skip to content

Commit

Permalink
fix: Update controller-gen version (#429)
Browse files Browse the repository at this point in the history
**Reason for Change**:
- Bumb controller-gen version to 0.15.0
- Add make generate step in both unit tests and e2e tests pipelines.

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:

---------

Signed-off-by: Heba Elayoty <hebaelayoty@gmail.com>
Co-authored-by: ishaansehgal99 <ishaanforthewin@gmail.com>
  • Loading branch information
helayoty and ishaansehgal99 authored May 22, 2024
1 parent b65dd2e commit 0dca449
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 33 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/e2e-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ jobs:
inlineScript: |
az identity create --name gpuIdentity --resource-group ${{ env.CLUSTER_NAME }}
- name: Generate APIs
run: |
make generate
- name: build KAITO image
if: ${{ !inputs.isRelease }}
shell: bash
Expand Down
4 changes: 3 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ jobs:
uses: actions/setup-go@0c52d547c9bc32b1aa3301fd7a9cb496313a4491 # v5.0.0
with:
go-version: ${{ env.GO_VERSION }}

- name: Generate APIs
run: |
make generate
- name: Run unit tests & Generate coverage
run: |
make unit-test
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ CONTROLLER_GEN ?= $(LOCALBIN)/controller-gen
ENVTEST ?= $(LOCALBIN)/setup-envtest

## Tool Versions
CONTROLLER_TOOLS_VERSION ?= v0.12.0
CONTROLLER_TOOLS_VERSION ?= v0.15.0

.PHONY: controller-gen
controller-gen: $(CONTROLLER_GEN) ## Download controller-gen locally if necessary. If wrong version is installed, it will be overwritten.
Expand Down
119 changes: 91 additions & 28 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"gopkg.in/yaml.v2"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"knative.dev/pkg/apis"
"sigs.k8s.io/controller-runtime/pkg/client"
)
Expand All @@ -22,13 +23,13 @@ type Config struct {
}

type TrainingConfig struct {
ModelConfig map[string]interface{} `yaml:"ModelConfig"`
TokenizerParams map[string]interface{} `yaml:"TokenizerParams"`
QuantizationConfig map[string]interface{} `yaml:"QuantizationConfig"`
LoraConfig map[string]interface{} `yaml:"LoraConfig"`
TrainingArguments map[string]interface{} `yaml:"TrainingArguments"`
DatasetConfig map[string]interface{} `yaml:"DatasetConfig"`
DataCollator map[string]interface{} `yaml:"DataCollator"`
ModelConfig map[string]runtime.RawExtension `yaml:"ModelConfig"`
TokenizerParams map[string]runtime.RawExtension `yaml:"TokenizerParams"`
QuantizationConfig map[string]runtime.RawExtension `yaml:"QuantizationConfig"`
LoraConfig map[string]runtime.RawExtension `yaml:"LoraConfig"`
TrainingArguments map[string]runtime.RawExtension `yaml:"TrainingArguments"`
DatasetConfig map[string]runtime.RawExtension `yaml:"DatasetConfig"`
DataCollator map[string]runtime.RawExtension `yaml:"DataCollator"`
}

func validateNilOrBool(value interface{}) error {
Expand All @@ -41,6 +42,58 @@ func validateNilOrBool(value interface{}) error {
return fmt.Errorf("value must be either nil or a boolean, got type %T", value)
}

// UnmarshalYAML custom method
func (t *TrainingConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
var raw map[string]interface{}
if err := unmarshal(&raw); err != nil {
return err
}

// This function converts a map[string]interface{} to a map[string]runtime.RawExtension.
// It does this by setting the raw marshalled data of the unmarshalled YAML to
// be the raw data of the runtime.RawExtension object.
handleRawExtension := func(raw map[string]interface{}, field string) (map[string]runtime.RawExtension, error) {
var target map[string]runtime.RawExtension
if value, found := raw[field]; found {
delete(raw, field)
var ext runtime.RawExtension
data, err := yaml.Marshal(value)
if err != nil {
return nil, err
}
ext.Raw = data
if target == nil {
target = make(map[string]runtime.RawExtension)
}
target[field] = ext
}
return target, nil
}

fields := []struct {
name string
target *map[string]runtime.RawExtension
}{
{"ModelConfig", &t.ModelConfig},
{"TokenizerParams", &t.TokenizerParams},
{"QuantizationConfig", &t.QuantizationConfig},
{"LoraConfig", &t.LoraConfig},
{"TrainingArguments", &t.TrainingArguments},
{"DatasetConfig", &t.DatasetConfig},
{"DataCollator", &t.DataCollator},
}

var err error
for _, field := range fields {
*field.target, err = handleRawExtension(raw, field.name)
if err != nil {
return err
}
}

return nil
}

func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError {
trainingConfigYAML, ok := cm.Data["training_config.yaml"]
if !ok {
Expand All @@ -55,31 +108,41 @@ func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *a
// Validate QuantizationConfig if it exists
quantConfig := config.TrainingConfig.QuantizationConfig
if quantConfig != nil {
// Dynamic field search for quantization settings within ModelConfig
loadIn4bit, _ := utils.SearchMap(quantConfig, "load_in_4bit")
loadIn8bit, _ := utils.SearchMap(quantConfig, "load_in_8bit")
quantConfigRaw, quantConfigExists := quantConfig["QuantizationConfig"]
if quantConfigExists {
// Dynamic field search for quantization settings within ModelConfig
loadIn4bit, _, err := utils.SearchRawExtension(quantConfigRaw, "load_in_4bit")
if err != nil {
return apis.ErrInvalidValue(err.Error(), "load_in_4bit")
}
loadIn8bit, _, err := utils.SearchRawExtension(quantConfigRaw, "load_in_8bit")
if err != nil {
return apis.ErrInvalidValue(err.Error(), "load_in_8bit")
}

// Validate both loadIn4bit and loadIn8bit
if err := validateNilOrBool(loadIn4bit); err != nil {
return apis.ErrInvalidValue(err.Error(), "load_in_4bit")
}
if err := validateNilOrBool(loadIn8bit); err != nil {
return apis.ErrInvalidValue(err.Error(), "load_in_8bit")
}
// Validate both loadIn4bit and loadIn8bit
if err := validateNilOrBool(loadIn4bit); err != nil {
return apis.ErrInvalidValue(err.Error(), "load_in_4bit")
}
if err := validateNilOrBool(loadIn8bit); err != nil {
return apis.ErrInvalidValue(err.Error(), "load_in_8bit")
}

loadIn4bitBool, _ := loadIn4bit.(bool)
loadIn8bitBool, _ := loadIn8bit.(bool)
loadIn4bitBool, _ := loadIn4bit.(bool)
loadIn8bitBool, _ := loadIn8bit.(bool)

if loadIn4bitBool && loadIn8bitBool {
return apis.ErrGeneric(fmt.Sprintf("Cannot set both 'load_in_4bit' and 'load_in_8bit' to true in ConfigMap '%s'", cm.Name), "QuantizationConfig")
}
if methodLowerCase == string(TuningMethodLora) {
if loadIn4bitBool || loadIn8bitBool {
return apis.ErrGeneric(fmt.Sprintf("For method 'lora', 'load_in_4bit' or 'load_in_8bit' in ConfigMap '%s' must not be true", cm.Name), "QuantizationConfig")
// Validation Logic
if loadIn4bitBool && loadIn8bitBool {
return apis.ErrGeneric(fmt.Sprintf("Cannot set both 'load_in_4bit' and 'load_in_8bit' to true in ConfigMap '%s'", cm.Name), "QuantizationConfig")
}
} else if methodLowerCase == string(TuningMethodQLora) {
if !loadIn4bitBool && !loadIn8bitBool {
return apis.ErrMissingField(fmt.Sprintf("For method 'qlora', either 'load_in_4bit' or 'load_in_8bit' must be true in ConfigMap '%s'", cm.Name), "QuantizationConfig")
if methodLowerCase == string(TuningMethodLora) {
if loadIn4bitBool || loadIn8bitBool {
return apis.ErrGeneric(fmt.Sprintf("For method 'lora', 'load_in_4bit' or 'load_in_8bit' in ConfigMap '%s' must not be true", cm.Name), "QuantizationConfig")
}
} else if methodLowerCase == string(TuningMethodQLora) {
if !loadIn4bitBool && !loadIn8bitBool {
return apis.ErrMissingField(fmt.Sprintf("For method 'qlora', either 'load_in_4bit' or 'load_in_8bit' must be true in ConfigMap '%s'", cm.Name), "QuantizationConfig")
}
}
}
} else if methodLowerCase == string(TuningMethodQLora) {
Expand Down
60 changes: 59 additions & 1 deletion api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,53 @@ func defaultConfigMapManifest() *v1.ConfigMap {
}
}

func qloraConfigMapManifest() *v1.ConfigMap {
return &v1.ConfigMap{
ObjectMeta: metav1.ObjectMeta{
Name: DefaultQloraConfigMapTemplate,
Namespace: DefaultReleaseNamespace,
},
Data: map[string]string{
"training_config.yaml": `training_config:
ModelConfig:
torch_dtype: "bfloat16"
local_files_only: true
device_map: "auto"
TokenizerParams:
padding: true
truncation: true
QuantizationConfig:
load_in_4bit: true
bnb_4bit_quant_type: "nf4"
bnb_4bit_compute_dtype: "bfloat16"
bnb_4bit_use_double_quant: true
LoraConfig:
r: 16
lora_alpha: 32
target_modules: "query_key_value"
lora_dropout: 0.05
bias: "none"
TrainingArguments:
output_dir: "."
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
save_strategy: "epoch"
DatasetConfig:
shuffle_dataset: true
train_test_split: 1
DataCollator:
mlm: true`,
},
}
}

func TestResourceSpecValidateCreate(t *testing.T) {
RegisterValidationTestModels()
tests := []struct {
Expand Down Expand Up @@ -700,7 +747,7 @@ func TestTuningSpecValidateCreate(t *testing.T) {
// Create fake client with default ConfigMap
scheme := runtime.NewScheme()
_ = v1.AddToScheme(scheme)
client := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(defaultConfigMapManifest()).Build()
client := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(defaultConfigMapManifest(), qloraConfigMapManifest()).Build()
k8sclient.SetGlobalClient(client)
// Include client in ctx
ctx := context.Background()
Expand All @@ -722,6 +769,17 @@ func TestTuningSpecValidateCreate(t *testing.T) {
wantErr: false,
errFields: nil,
},
{
name: "Verify QLoRA Config",
tuningSpec: &TuningSpec{
Input: &DataSource{Name: "valid-input", Volume: &v1.VolumeSource{}},
Output: &DataDestination{Volume: &v1.VolumeSource{}},
Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}},
Method: TuningMethodQLora,
},
wantErr: false,
errFields: nil,
},
{
name: "Missing Input",
tuningSpec: &TuningSpec{
Expand Down
83 changes: 81 additions & 2 deletions api/v1alpha1/zz_generated.deepcopy.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 0dca449

Please sign in to comment.