Skip to content

Commit

Permalink
feat: Simple Configmap Validation Checks - Part 6 (#355)
Browse files Browse the repository at this point in the history
**Reason for Change**:
This PR adds two simple checks to the tuning ConfigMap. The first makes
sure that, based on the method specified (LoRA or QLoRA), the correct
params are included. The second makes sure that all the sections
specified in the ConfigMap are recognized.
  • Loading branch information
ishaansehgal99 authored Apr 19, 2024
1 parent bf4acba commit 6be8a0d
Show file tree
Hide file tree
Showing 5 changed files with 273 additions and 18 deletions.
167 changes: 167 additions & 0 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

package v1alpha1

import (
"context"
"fmt"
"reflect"

"github.com/azure/kaito/pkg/k8sclient"
"github.com/azure/kaito/pkg/utils"
"gopkg.in/yaml.v2"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"knative.dev/pkg/apis"
"sigs.k8s.io/controller-runtime/pkg/client"
)

type (
	// Config is the document root that 'training_config.yaml' is decoded
	// into; everything of interest sits under the "training_config" key.
	Config struct {
		TrainingConfig TrainingConfig `yaml:"training_config"`
	}

	// TrainingConfig lists the sections accepted under "training_config".
	// Each section is kept as a free-form key/value map; the yaml tag on a
	// field is the section name as it appears in the document.
	TrainingConfig struct {
		ModelConfig        map[string]any `yaml:"ModelConfig"`
		TokenizerParams    map[string]any `yaml:"TokenizerParams"`
		QuantizationConfig map[string]any `yaml:"QuantizationConfig"`
		LoraConfig         map[string]any `yaml:"LoraConfig"`
		TrainingArguments  map[string]any `yaml:"TrainingArguments"`
		DatasetConfig      map[string]any `yaml:"DatasetConfig"`
		DataCollator       map[string]any `yaml:"DataCollator"`
	}
)

// validateNilOrBool accepts an unset (nil) value or a bool and rejects every
// other dynamic type with an error that names the offending type.
func validateNilOrBool(value interface{}) error {
	switch value.(type) {
	case nil, bool:
		// Either the key was absent or it holds a proper boolean.
		return nil
	default:
		return fmt.Errorf("value must be either nil or a boolean, got type %T", value)
	}
}

// validateMethodViaConfigMap checks that the quantization flags found in the
// ConfigMap's 'training_config.yaml' agree with the requested tuning method.
//
// Rules enforced:
//   - 'load_in_4bit' and 'load_in_8bit', when present, must be booleans;
//   - they may not both be true;
//   - for the LoRA method neither may be true;
//   - for the QLoRA method at least one must be true (so a missing
//     QuantizationConfig section is an error for QLoRA).
//
// methodLowerCase is compared directly against the TuningMethod constants, so
// the caller is expected to have lower-cased it already. Returns nil when the
// ConfigMap is consistent with the method.
func validateMethodViaConfigMap(cm *corev1.ConfigMap, methodLowerCase string) *apis.FieldError {
	rawYAML, found := cm.Data["training_config.yaml"]
	if !found {
		return apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' does not contain 'training_config.yaml' in namespace '%s'", cm.Name, cm.Namespace), "config")
	}

	var parsed Config
	if err := yaml.Unmarshal([]byte(rawYAML), &parsed); err != nil {
		return apis.ErrGeneric(fmt.Sprintf("Failed to parse 'training_config.yaml' in ConfigMap '%s' in namespace '%s': %v", cm.Name, cm.Namespace, err), "config")
	}

	// Shared by both places where QLoRA turns out to have no quantization enabled.
	qloraNeedsQuantization := func() *apis.FieldError {
		return apis.ErrMissingField(fmt.Sprintf("For method 'qlora', either 'load_in_4bit' or 'load_in_8bit' must be true in ConfigMap '%s'", cm.Name), "QuantizationConfig")
	}

	quantSection := parsed.TrainingConfig.QuantizationConfig
	if quantSection == nil {
		// No QuantizationConfig section at all: only QLoRA objects to that.
		if methodLowerCase == string(TuningMethodQLora) {
			return qloraNeedsQuantization()
		}
		return nil
	}

	// Look the two flags up inside the QuantizationConfig section. The second
	// result of SearchMap is deliberately ignored: a key that is not found is
	// handled below as a nil value.
	raw4bit, _ := utils.SearchMap(quantSection, "load_in_4bit")
	raw8bit, _ := utils.SearchMap(quantSection, "load_in_8bit")

	if err := validateNilOrBool(raw4bit); err != nil {
		return apis.ErrInvalidValue(err.Error(), "load_in_4bit")
	}
	if err := validateNilOrBool(raw8bit); err != nil {
		return apis.ErrInvalidValue(err.Error(), "load_in_8bit")
	}

	// After the checks above each value is nil or bool; nil asserts to false.
	use4bit, _ := raw4bit.(bool)
	use8bit, _ := raw8bit.(bool)

	if use4bit && use8bit {
		return apis.ErrGeneric(fmt.Sprintf("Cannot set both 'load_in_4bit' and 'load_in_8bit' to true in ConfigMap '%s'", cm.Name), "QuantizationConfig")
	}

	switch methodLowerCase {
	case string(TuningMethodLora):
		if use4bit || use8bit {
			return apis.ErrGeneric(fmt.Sprintf("For method 'lora', 'load_in_4bit' or 'load_in_8bit' in ConfigMap '%s' must not be true", cm.Name), "QuantizationConfig")
		}
	case string(TuningMethodQLora):
		if !(use4bit || use8bit) {
			return qloraNeedsQuantization()
		}
	}
	return nil
}

// getStructInstances returns, for every field of a config struct that carries
// a named `yaml` tag, a fresh empty value of that field's type, keyed by the
// YAML name.
//
// s may be a struct or a pointer to a struct. Any other input (including nil)
// yields an empty map instead of a reflect panic. Tag options such as
// ",omitempty" are stripped so the key is the bare YAML name; fields tagged
// "-" or tagged without an explicit name are skipped, as untagged fields are.
// Map-typed fields get an initialized empty map; all other field types get
// their zero value.
func getStructInstances(s any) map[string]any {
	t := reflect.TypeOf(s)
	if t != nil && t.Kind() == reflect.Ptr {
		t = t.Elem() // Dereference pointer to get the struct type
	}
	if t == nil || t.Kind() != reflect.Struct {
		// NumField panics on anything that is not a struct.
		return map[string]any{}
	}

	instances := make(map[string]any, t.NumField())
	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)

		// Keep only the name part of the tag, dropping ",omitempty" etc.
		name := field.Tag.Get("yaml")
		for j := 0; j < len(name); j++ {
			if name[j] == ',' {
				name = name[:j]
				break
			}
		}
		if name == "" || name == "-" {
			continue
		}

		if field.Type.Kind() == reflect.Map {
			instances[name] = reflect.MakeMap(field.Type).Interface()
		} else {
			// MakeMap panics for non-map types; fall back to the zero value.
			instances[name] = reflect.Zero(field.Type).Interface()
		}
	}

	return instances
}

// validateConfigMapSchema verifies that every section listed under the
// "training_config" key of the ConfigMap's 'training_config.yaml' entry is one
// of the sections declared (via yaml tags) on TrainingConfig.
//
// It returns a FieldError when the entry is missing, is not valid YAML, has no
// map under "training_config", or contains a section key that is either not a
// string or not recognized. It returns nil when all sections are recognized.
func validateConfigMapSchema(cm *corev1.ConfigMap) *apis.FieldError {
	trainingConfigData, ok := cm.Data["training_config.yaml"]
	if !ok {
		return apis.ErrMissingField("training_config.yaml in ConfigMap")
	}

	var rawConfig map[string]interface{}
	if err := yaml.Unmarshal([]byte(trainingConfigData), &rawConfig); err != nil {
		return apis.ErrInvalidValue(err.Error(), "training_config.yaml")
	}

	// Extract the actual training configuration map. The assertion targets
	// map[interface{}]interface{}, the type the decoder is expected to produce
	// for nested mappings, so the section keys are not guaranteed to be strings.
	trainingConfigMap, ok := rawConfig["training_config"].(map[interface{}]interface{})
	if !ok {
		return apis.ErrInvalidValue("Expected 'training_config' key to contain a map", "training_config.yaml")
	}

	sectionStructs := getStructInstances(TrainingConfig{})
	recognizedSections := make([]string, 0, len(sectionStructs))
	for section := range sectionStructs {
		recognizedSections = append(recognizedSections, section)
	}

	// Check if valid sections. A non-string key (e.g. `1:` or `true:`) can
	// never match a section name, so report it instead of panicking on the
	// type assertion.
	for section := range trainingConfigMap {
		sectionStr, isString := section.(string)
		if !isString || !utils.Contains(recognizedSections, sectionStr) {
			return apis.ErrInvalidValue(fmt.Sprintf("Unrecognized section: %v", section), "training_config.yaml")
		}
	}
	return nil
}

// validateConfigMap fetches the ConfigMap configMapName from namespace using
// the package-level k8sclient.Client and validates it for the given tuning
// method.
//
// Parameters:
//   - ctx: passed through to the client Get call.
//   - namespace: namespace the ConfigMap is looked up in.
//   - methodLowerCase: tuning method, already lower-cased by the caller.
//   - configMapName: name of the ConfigMap to fetch; this is the name reported
//     in lookup errors, which may differ from r.ConfigTemplate when the caller
//     falls back to a default template.
//
// Returns the accumulated errors: a single lookup error if the client is
// unset or the ConfigMap cannot be fetched, otherwise any findings from the
// schema check and the method check. Returns nil when everything is valid.
func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, methodLowerCase string, configMapName string) (errs *apis.FieldError) {
	if k8sclient.Client == nil {
		return errs.Also(apis.ErrGeneric("Failed to obtain client from context.Context"))
	}

	var cm corev1.ConfigMap
	if err := k8sclient.Client.Get(ctx, client.ObjectKey{Name: configMapName, Namespace: namespace}, &cm); err != nil {
		if errors.IsNotFound(err) {
			return errs.Also(apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' specified in 'config' not found in namespace '%s'", configMapName, namespace), "config"))
		}
		return errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get ConfigMap '%s' in namespace '%s': %v", configMapName, namespace, err), "config"))
	}

	// Run both checks so the user sees every problem in one pass.
	if err := validateConfigMapSchema(&cm); err != nil {
		errs = errs.Also(err)
	}
	if err := validateMethodViaConfigMap(&cm, methodLowerCase); err != nil {
		errs = errs.Also(err)
	}
	return errs
}
50 changes: 36 additions & 14 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ import (
"sort"
"strings"

"github.com/azure/kaito/pkg/utils"
"github.com/azure/kaito/pkg/utils/plugin"

admissionregistrationv1 "k8s.io/api/admissionregistration/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand All @@ -21,6 +23,9 @@ import (
const (
N_SERIES_PREFIX = "Standard_N"
D_SERIES_PREFIX = "Standard_D"

DefaultLoraConfigMapTemplate = "lora-params-template"
DefaultQloraConfigMapTemplate = "qlora-params-template"
)

func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType {
Expand All @@ -31,21 +36,18 @@ func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType {
}

func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {

base := apis.GetBaseline(ctx)
if base == nil {
klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
errs = errs.Also(
w.validateCreate().ViaField("spec"),
// TODO: Consider validate resource based on Tuning Spec
w.Resource.validateCreate(*w.Inference).ViaField("resource"),
)
errs = errs.Also(w.validateCreate().ViaField("spec"))
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Inference.validateCreate().ViaField("inference"))
errs = errs.Also(w.Resource.validateCreate(*w.Inference).ViaField("resource"),
w.Inference.validateCreate().ViaField("inference"))
}
if w.Tuning != nil {
errs = errs.Also(w.Tuning.validateCreate().ViaField("tuning"))
// TODO: Add validate resource based on Tuning Spec
errs = errs.Also(w.Tuning.validateCreate(ctx, w.Namespace).ViaField("tuning"))
}
} else {
klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
Expand Down Expand Up @@ -86,7 +88,31 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) {
return errs
}

func (r *TuningSpec) validateCreate() (errs *apis.FieldError) {
func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace string) (errs *apis.FieldError) {
methodLowerCase := strings.ToLower(string(r.Method))
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method"))
}
if r.ConfigTemplate == "" {
klog.InfoS("Tuning config not specified. Using default based on method.")
releaseNamespace, err := utils.GetReleaseNamespace()
if err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to determine release namespace: %v", err), "namespace"))
}
defaultConfigMapTemplateName := ""
if methodLowerCase == string(TuningMethodLora) {
defaultConfigMapTemplateName = DefaultLoraConfigMapTemplate
} else if methodLowerCase == string(TuningMethodQLora) {
defaultConfigMapTemplateName = DefaultQloraConfigMapTemplate
}
if err := r.validateConfigMap(ctx, releaseNamespace, methodLowerCase, defaultConfigMapTemplateName); err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
} else {
if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, r.ConfigTemplate); err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
}
if r.Input == nil {
errs = errs.Also(apis.ErrMissingField("Input"))
} else {
Expand All @@ -103,10 +129,6 @@ func (r *TuningSpec) validateCreate() (errs *apis.FieldError) {
} else if presetName := string(r.Preset.Name); !isValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", presetName), "presetName"))
}
methodLowerCase := strings.ToLower(string(r.Method))
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method"))
}
return errs
}

Expand Down Expand Up @@ -247,7 +269,7 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field
}
}
} else {
// Check for other instancetypes pattern matches
// Check for other instance types pattern matches
if !strings.HasPrefix(instanceType, N_SERIES_PREFIX) && !strings.HasPrefix(instanceType, D_SERIES_PREFIX) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported instance type %s. Supported SKUs: %s", instanceType, getSupportedSKUs()), "instanceType"))
}
Expand Down
70 changes: 68 additions & 2 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,26 @@
package v1alpha1

import (
"context"
"os"
"reflect"
"sort"
"strings"
"testing"

"github.com/azure/kaito/pkg/model"
"github.com/azure/kaito/pkg/k8sclient"
"github.com/azure/kaito/pkg/utils"
"github.com/azure/kaito/pkg/utils/plugin"
"k8s.io/apimachinery/pkg/runtime"

"github.com/azure/kaito/pkg/model"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"sigs.k8s.io/controller-runtime/pkg/client/fake"
)

const DefaultReleaseNamespace = "kaito-workspace"

var gpuCountRequirement string
var totalGPUMemoryRequirement string
var perGPUMemoryRequirement string
Expand Down Expand Up @@ -84,6 +93,50 @@ func pointerToInt(i int) *int {
return &i
}

// defaultConfigMapManifest returns a ConfigMap fixture named
// DefaultLoraConfigMapTemplate in DefaultReleaseNamespace. Its
// "training_config.yaml" entry holds a sample training configuration with
// 'load_in_4bit' set to false.
func defaultConfigMapManifest() *v1.ConfigMap {
return &v1.ConfigMap{
ObjectMeta: metav1.ObjectMeta{
Name: DefaultLoraConfigMapTemplate,
Namespace: DefaultReleaseNamespace, // Replace this with the appropriate namespace variable if dynamic
},
Data: map[string]string{
"training_config.yaml": `training_config:
ModelConfig:
torch_dtype: "bfloat16"
local_files_only: true
device_map: "auto"
TokenizerParams:
padding: true
truncation: true
QuantizationConfig:
load_in_4bit: false
LoraConfig:
r: 16
lora_alpha: 32
target_modules: "query_key_value"
lora_dropout: 0.05
bias: "none"
TrainingArguments:
output_dir: "."
num_train_epochs: 4
auto_find_batch_size: true
ddp_find_unused_parameters: false
save_strategy: "epoch"
DatasetConfig:
shuffle_dataset: true
train_test_split: 1
DataCollator:
mlm: true`,
},
}
}

func TestResourceSpecValidateCreate(t *testing.T) {
RegisterValidationTestModels()
tests := []struct {
Expand Down Expand Up @@ -640,6 +693,18 @@ func TestWorkspaceValidateUpdate(t *testing.T) {

func TestTuningSpecValidateCreate(t *testing.T) {
RegisterValidationTestModels()
// Set ReleaseNamespace Env
os.Setenv(utils.DefaultReleaseNamespaceEnvVar, DefaultReleaseNamespace)
defer os.Unsetenv(utils.DefaultReleaseNamespaceEnvVar)

// Create fake client with default ConfigMap
scheme := runtime.NewScheme()
_ = v1.AddToScheme(scheme)
client := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(defaultConfigMapManifest()).Build()
k8sclient.SetGlobalClient(client)
// Include client in ctx
ctx := context.Background()

tests := []struct {
name string
tuningSpec *TuningSpec
Expand Down Expand Up @@ -713,7 +778,7 @@ func TestTuningSpecValidateCreate(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
errs := tt.tuningSpec.validateCreate()
errs := tt.tuningSpec.validateCreate(ctx, "WORKSPACE_NAMESPACE")
hasErrs := errs != nil

if hasErrs != tt.wantErr {
Expand Down Expand Up @@ -993,6 +1058,7 @@ func TestDataDestinationValidateCreate(t *testing.T) {
dataDestination: &DataDestination{
Volume: &v1.VolumeSource{},
Image: "data-image:latest",
ImagePushSecret: "imagePushSecret",
},
wantErr: false,
},
Expand Down
2 changes: 1 addition & 1 deletion charts/kaito/workspace/templates/lora-params.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
apiVersion: v1
kind: ConfigMap
metadata:
name: lora-params
name: lora-params-template
namespace: {{ .Release.Namespace }}
data:
training_config.yaml: |
Expand Down
Loading

0 comments on commit 6be8a0d

Please sign in to comment.