Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Simple Configmap Validation Checks - Part 6 #355

Merged
merged 19 commits into from
Apr 19, 2024
167 changes: 167 additions & 0 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package v1alpha1

import (
"context"
"fmt"
"reflect"

"github.com/azure/kaito/pkg/k8sclient"
"github.com/azure/kaito/pkg/utils"
"gopkg.in/yaml.v2"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"knative.dev/pkg/apis"
"sigs.k8s.io/controller-runtime/pkg/client"
)

// Config and TrainingConfig describe the layout of the training_config.yaml
// document stored in a tuning ConfigMap.
type (
	// Config is the document root; everything lives under the
	// "training_config" key.
	Config struct {
		TrainingConfig TrainingConfig `yaml:"training_config"`
	}

	// TrainingConfig lists the sections allowed under "training_config".
	// Every section is kept as a free-form map, so only the section names
	// (the yaml tags below) are fixed by this type.
	TrainingConfig struct {
		ModelConfig        map[string]any `yaml:"ModelConfig"`
		TokenizerParams    map[string]any `yaml:"TokenizerParams"`
		QuantizationConfig map[string]any `yaml:"QuantizationConfig"`
		LoraConfig         map[string]any `yaml:"LoraConfig"`
		TrainingArguments  map[string]any `yaml:"TrainingArguments"`
		DatasetConfig      map[string]any `yaml:"DatasetConfig"`
		DataCollator       map[string]any `yaml:"DataCollator"`
	}
)

// validateNilOrBool accepts an unset (nil) value or a bool and rejects every
// other dynamic type with an error naming the offending type.
func validateNilOrBool(value interface{}) error {
	switch value.(type) {
	case nil, bool:
		// Unset and boolean settings are both acceptable.
		return nil
	default:
		return fmt.Errorf("value must be either nil or a boolean, got type %T", value)
	}
}

// validateMethodViaConfigMap checks that the quantization settings found in the
// ConfigMap's 'training_config.yaml' are compatible with the tuning method.
//
// methodLowerCase is compared against TuningMethodLora and TuningMethodQLora.
// The rules enforced are:
//   - 'load_in_4bit' and 'load_in_8bit' must each be unset or a boolean;
//   - they must not both be true;
//   - for lora neither may be true;
//   - for qlora at least one must be true (so a missing QuantizationConfig
//     section is an error for qlora).
//
// It returns nil when every rule holds.
func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError {
	rawYAML, found := cm.Data["training_config.yaml"]
	if !found {
		return apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' does not contain 'training_config.yaml' in namespace '%s'", cm.Name, cm.Namespace), "config")
	}

	var parsed Config
	if err := yaml.Unmarshal([]byte(rawYAML), &parsed); err != nil {
		return apis.ErrGeneric(fmt.Sprintf("Failed to parse 'training_config.yaml' in ConfigMap '%s' in namespace '%s': %v", cm.Name, cm.Namespace, err), "config")
	}

	quant := parsed.TrainingConfig.QuantizationConfig
	if quant == nil {
		// Without a QuantizationConfig section nothing can be enabled, which
		// only matters for qlora.
		if methodLowerCase == string(TuningMethodQLora) {
			return apis.ErrMissingField(fmt.Sprintf("For method 'qlora', either 'load_in_4bit' or 'load_in_8bit' must be true in ConfigMap '%s'", cm.Name), "QuantizationConfig")
		}
		return nil
	}

	// Look both settings up inside the QuantizationConfig section.
	rawFourBit, _ := utils.SearchMap(quant, "load_in_4bit")
	rawEightBit, _ := utils.SearchMap(quant, "load_in_8bit")

	if err := validateNilOrBool(rawFourBit); err != nil {
		return apis.ErrInvalidValue(err.Error(), "load_in_4bit")
	}
	if err := validateNilOrBool(rawEightBit); err != nil {
		return apis.ErrInvalidValue(err.Error(), "load_in_8bit")
	}

	// An unset value fails the assertion and therefore counts as false.
	fourBit, _ := rawFourBit.(bool)
	eightBit, _ := rawEightBit.(bool)

	if fourBit && eightBit {
		return apis.ErrGeneric(fmt.Sprintf("Cannot set both 'load_in_4bit' and 'load_in_8bit' to true in ConfigMap '%s'", cm.Name), "QuantizationConfig")
	}

	anyEnabled := fourBit || eightBit
	switch methodLowerCase {
	case string(TuningMethodLora):
		if anyEnabled {
			return apis.ErrGeneric(fmt.Sprintf("For method 'lora', 'load_in_4bit' or 'load_in_8bit' in ConfigMap '%s' must not be true", cm.Name), "QuantizationConfig")
		}
	case string(TuningMethodQLora):
		if !anyEnabled {
			return apis.ErrMissingField(fmt.Sprintf("For method 'qlora', either 'load_in_4bit' or 'load_in_8bit' must be true in ConfigMap '%s'", cm.Name), "QuantizationConfig")
		}
	}
	return nil
}

// getStructInstances returns one fresh instance per yaml-tagged field of the
// struct s (or of the struct s points to), keyed by the field's yaml name.
//
// Map-typed fields yield an empty, non-nil map; fields of any other type yield
// their zero value. Only the part of the yaml tag before the first comma is
// used as the key, so options such as ",omitempty" never leak into it. Fields
// with no yaml name, or tagged "-", are skipped.
//
// A nil or non-struct s yields an empty map rather than a reflect panic.
func getStructInstances(s any) map[string]any {
	t := reflect.TypeOf(s)
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem() // Dereference pointer to get the struct type
	}
	if t == nil || t.Kind() != reflect.Struct {
		// NumField would panic on anything that is not a struct.
		return map[string]any{}
	}

	instances := make(map[string]any, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		name := field.Tag.Get("yaml")
		// Strip tag options: "name,omitempty" -> "name".
		for j := 0; j < len(name); j++ {
			if name[j] == ',' {
				name = name[:j]
				break
			}
		}
		if name == "" || name == "-" {
			continue
		}

		if field.Type.Kind() == reflect.Map {
			instances[name] = reflect.MakeMap(field.Type).Interface()
		} else {
			// reflect.MakeMap panics for non-map types; fall back to the
			// field type's zero value.
			instances[name] = reflect.Zero(field.Type).Interface()
		}
	}

	return instances
}

// validateConfigMapSchema verifies that the ConfigMap's 'training_config.yaml'
// parses as YAML, has a 'training_config' mapping at its root, and that every
// key of that mapping is one of the sections declared on TrainingConfig.
//
// It returns nil when the schema is acceptable, otherwise a FieldError
// describing the first problem found. Because Go map iteration order is
// unspecified, which offending section is reported first may vary when
// several are invalid.
func validateConfigMapSchema(cm *corev1.ConfigMap) *apis.FieldError {
	trainingConfigData, ok := cm.Data["training_config.yaml"]
	if !ok {
		return apis.ErrMissingField("training_config.yaml in ConfigMap")
	}

	var rawConfig map[string]interface{}
	if err := yaml.Unmarshal([]byte(trainingConfigData), &rawConfig); err != nil {
		return apis.ErrInvalidValue(err.Error(), "training_config.yaml")
	}

	// Extract the actual training configuration map. Nested mappings are
	// asserted as map[interface{}]interface{}, so their keys are not
	// guaranteed to be strings.
	trainingConfigMap, ok := rawConfig["training_config"].(map[interface{}]interface{})
	if !ok {
		return apis.ErrInvalidValue("Expected 'training_config' key to contain a map", "training_config.yaml")
	}

	// The yaml names of TrainingConfig's fields are the recognized sections.
	recognizedSections := getStructInstances(TrainingConfig{})

	for section := range trainingConfigMap {
		// A YAML key such as `1:` or `yes:` decodes to a non-string value;
		// reject it instead of panicking on an unchecked type assertion.
		sectionStr, isString := section.(string)
		if !isString {
			return apis.ErrInvalidValue(fmt.Sprintf("Section name must be a string, got %v of type %T", section, section), "training_config.yaml")
		}
		if _, known := recognizedSections[sectionStr]; !known {
			return apis.ErrInvalidValue(fmt.Sprintf("Unrecognized section: %s", section), "training_config.yaml")
		}
	}
	return nil
}

// validateConfigMap fetches the ConfigMap named by r.ConfigTemplate from the
// given namespace using the global k8sclient.Client and, when it exists, runs
// both the schema check and the method/quantization check against it.
//
// All problems found are accumulated into the returned FieldError; nil means
// the ConfigMap was found and passed both checks.
func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, methodLowerCase string) (errs *apis.FieldError) {
	if k8sclient.Client == nil {
		return errs.Also(apis.ErrGeneric("Failed to obtain client from context.Context"))
	}

	var cm corev1.ConfigMap
	key := client.ObjectKey{Name: r.ConfigTemplate, Namespace: namespace}
	getErr := k8sclient.Client.Get(ctx, key, &cm)

	switch {
	case getErr == nil:
		// Run both validations so the caller sees every problem at once.
		if schemaErr := validateConfigMapSchema(&cm); schemaErr != nil {
			errs = errs.Also(schemaErr)
		}
		if methodErr := validateMethodViaConfigMap(&cm, methodLowerCase); methodErr != nil {
			errs = errs.Also(methodErr)
		}
	case errors.IsNotFound(getErr):
		errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' specified in 'config' not found in namespace '%s'", r.ConfigTemplate, namespace), "config"))
	default:
		errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get ConfigMap '%s' in namespace '%s': %v", r.ConfigTemplate, namespace, getErr), "config"))
	}
	return errs
}
46 changes: 32 additions & 14 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"sort"
"strings"

"github.com/azure/kaito/pkg/utils"
"github.com/azure/kaito/pkg/utils/plugin"

admissionregistrationv1 "k8s.io/api/admissionregistration/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -21,6 +23,9 @@ import (
const (
// N_SERIES_PREFIX and D_SERIES_PREFIX are the instance-type name prefixes
// accepted by the fallback check in ResourceSpec.validateCreate; an instance
// type matching neither prefix is reported as unsupported there.
N_SERIES_PREFIX = "Standard_N"
D_SERIES_PREFIX = "Standard_D"

// DefaultLoraConfigMap and DefaultQloraConfigMap are the names of the template
// ConfigMaps. When TuningSpec.ConfigTemplate equals one of them,
// TuningSpec.validateCreate looks the ConfigMap up in the release namespace
// instead of the workspace namespace.
DefaultLoraConfigMap = "lora-params"
DefaultQloraConfigMap = "qlora-params"
)

func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType {
Expand All @@ -31,21 +36,18 @@ func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType {
}

func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {

base := apis.GetBaseline(ctx)
if base == nil {
klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
errs = errs.Also(
w.validateCreate().ViaField("spec"),
// TODO: Consider validate resource based on Tuning Spec
w.Resource.validateCreate(*w.Inference).ViaField("resource"),
)
errs = errs.Also(w.validateCreate().ViaField("spec"))
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Inference.validateCreate().ViaField("inference"))
errs = errs.Also(w.Resource.validateCreate(*w.Inference).ViaField("resource"),
w.Inference.validateCreate().ViaField("inference"))
}
if w.Tuning != nil {
errs = errs.Also(w.Tuning.validateCreate().ViaField("tuning"))
// TODO: Add validate resource based on Tuning Spec
errs = errs.Also(w.Tuning.validateCreate(ctx, w.Namespace).ViaField("tuning"))
}
} else {
klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
Expand Down Expand Up @@ -86,7 +88,27 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) {
return errs
}

func (r *TuningSpec) validateCreate() (errs *apis.FieldError) {
func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace string) (errs *apis.FieldError) {
methodLowerCase := strings.ToLower(string(r.Method))
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method"))
}
if r.ConfigTemplate == "" {
klog.InfoS("Tuning config not specified. Using default.")
} else if r.ConfigTemplate == DefaultLoraConfigMap || r.ConfigTemplate == DefaultQloraConfigMap {
klog.InfoS("Template config specified")
releaseNamespace, err := utils.GetReleaseNamespace()
if err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to determine release namespace: %v", err), "namespace"))
}
if err := r.validateConfigMap(ctx, releaseNamespace, methodLowerCase); err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
} else {
if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase); err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
}
if r.Input == nil {
errs = errs.Also(apis.ErrMissingField("Input"))
} else {
Expand All @@ -103,10 +125,6 @@ func (r *TuningSpec) validateCreate() (errs *apis.FieldError) {
} else if presetName := string(r.Preset.Name); !isValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName"))
}
methodLowerCase := strings.ToLower(string(r.Method))
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method"))
}
return errs
}

Expand Down Expand Up @@ -247,7 +265,7 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field
}
}
} else {
// Check for other instancetypes pattern matches
// Check for other instance types pattern matches
if !strings.HasPrefix(instanceType, N_SERIES_PREFIX) && !strings.HasPrefix(instanceType, D_SERIES_PREFIX) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, getSupportedSKUs()), "instanceType"))
}
Expand Down
71 changes: 69 additions & 2 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,26 @@
package v1alpha1

import (
"context"
"os"
"reflect"
"sort"
"strings"
"testing"

"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/k8sclient"
"github.com/azure/kaito/pkg/utils"
"github.com/azure/kaito/pkg/utils/plugin"
"k8s.io/apimachinery/pkg/runtime"

"github.com/azure/kaito/pkg/model"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

const DefaultReleaseNamespace = "kaito-workspace"

var gpuCountRequirement string
var totalGPUMemoryRequirement string
var perGPUMemoryRequirement string
Expand Down Expand Up @@ -84,6 +93,50 @@ func pointerToInt(i int) *int {
return &i
}

// defaultConfigMapManifest builds the ConfigMap fixture used by the validation
// tests: it is named DefaultLoraConfigMap, lives in DefaultReleaseNamespace, and
// carries a sample 'training_config.yaml' whose QuantizationConfig sets
// load_in_4bit to false.
//
// NOTE(review): the YAML literal below shows no nesting indentation in this
// view — confirm the sections are indented under 'training_config' in the file.
func defaultConfigMapManifest() *v1.ConfigMap {
return &v1.ConfigMap{
ObjectMeta: metav1.ObjectMeta{
Name: DefaultLoraConfigMap,
Namespace: DefaultReleaseNamespace, // Replace this with the appropriate namespace variable if dynamic
},
Data: map[string]string{
"training_config.yaml": `training_config:
ModelConfig:
torch_dtype: "bfloat16"
local_files_only: true
device_map: "auto"

TokenizerParams:
padding: true
truncation: true

QuantizationConfig:
load_in_4bit: false

LoraConfig:
r: 16
lora_alpha: 32
target_modules: "query_key_value"
lora_dropout: 0.05
bias: "none"

TrainingArguments:
output_dir: "."
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
save_strategy: "epoch"

DatasetConfig:
shuffle_dataset: true
train_test_split: 1

DataCollator:
mlm: true`,
},
}
}

func TestResourceSpecValidateCreate(t *testing.T) {
RegisterValidationTestModels()
tests := []struct {
Expand Down Expand Up @@ -640,6 +693,18 @@ func TestWorkspaceValidateUpdate(t *testing.T) {

func TestTuningSpecValidateCreate(t *testing.T) {
RegisterValidationTestModels()
// Set ReleaseNamespace Env
os.Setenv(utils.DefaultReleaseNamespaceEnvVar, DefaultReleaseNamespace)
defer os.Unsetenv(utils.DefaultReleaseNamespaceEnvVar)

// Create fake client with default ConfigMap
scheme := runtime.NewScheme()
_ = v1.AddToScheme(scheme)
client := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(defaultConfigMapManifest()).Build()
k8sclient.SetGlobalClient(client)
// Include client in ctx
ctx := context.Background()

tests := []struct {
name string
tuningSpec *TuningSpec
Expand All @@ -653,6 +718,7 @@ func TestTuningSpecValidateCreate(t *testing.T) {
Output: &DataDestination{Volume: &v1.VolumeSource{}},
Preset: &PresetSpec{PresetMeta: PresetMeta{Name: ModelName("test-validation")}},
Method: TuningMethodLora,
ConfigTemplate: DefaultLoraConfigMap,
},
wantErr: false,
errFields: nil,
Expand Down Expand Up @@ -713,7 +779,7 @@ func TestTuningSpecValidateCreate(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
errs := tt.tuningSpec.validateCreate()
errs := tt.tuningSpec.validateCreate(ctx, "WORKSPACE_NAMESPACE")
hasErrs := errs != nil

if hasErrs != tt.wantErr {
Expand Down Expand Up @@ -993,6 +1059,7 @@ func TestDataDestinationValidateCreate(t *testing.T) {
dataDestination: &DataDestination{
Volume: &v1.VolumeSource{},
Image: "data-image:latest",
ImagePushSecret: "imagePushSecret",
},
wantErr: false,
},
Expand Down
Loading