diff --git a/api/v1alpha1/workspace_validation.go b/api/v1alpha1/workspace_validation.go
index 9c09586ec..d6808beaa 100644
--- a/api/v1alpha1/workspace_validation.go
+++ b/api/v1alpha1/workspace_validation.go
@@ -9,6 +9,7 @@ import (
 	"reflect"
 	"regexp"
 	"sort"
+	"strconv"
 	"strings"
 
 	"github.com/azure/kaito/pkg/utils"
@@ -27,6 +28,7 @@ const (
 
 	DefaultLoraConfigMapTemplate  = "lora-params-template"
 	DefaultQloraConfigMapTemplate = "qlora-params-template"
+	MaxAdaptersNumber             = 10
 )
 
 func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType {
@@ -58,7 +60,6 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
 				w.Resource.validateUpdate(&old.Resource).ViaField("resource"),
 			)
 			if w.Inference != nil {
-				// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
 				errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference"))
 			}
 			if w.Tuning != nil {
@@ -89,6 +90,44 @@ func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) {
 	return errs
 }
 
+func ValidateDNSSubdomain(name string) bool {
+	var dnsSubDomainRegexp = regexp.MustCompile(`^(?i:[a-z0-9]([-a-z0-9]*[a-z0-9])?)$`)
+	if len(name) < 1 || len(name) > 253 {
+		return false
+	}
+	return dnsSubDomainRegexp.MatchString(name)
+}
+
+func (r *AdapterSpec) validateCreateorUpdate() (errs *apis.FieldError) {
+	if r.Source == nil {
+		errs = errs.Also(apis.ErrMissingField("Source"))
+	} else {
+		errs = errs.Also(r.Source.validateCreate().ViaField("Adapters"))
+
+		if r.Source.Name == "" {
+			errs = errs.Also(apis.ErrMissingField("Name of Adapter field must be specified"))
+		} else if !ValidateDNSSubdomain(r.Source.Name) {
+			errs = errs.Also(apis.ErrMissingField("Name of Adapter must be a valid DNS subdomain value"))
+		}
+		if r.Source.Image == "" {
+			errs = errs.Also(apis.ErrMissingField("Image of Adapter field must be specified"))
+		}
+		if r.Strength == nil {
+			var defaultStrength = "1.0"
+			r.Strength = &defaultStrength
+		}
+		strength, err := strconv.ParseFloat(*r.Strength, 64)
+		if err != nil {
+			errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Invalid strength value for Adapter '%s': %v", r.Source.Name, err), "adapter"))
+		}
+		if strength < 0 || strength > 1.0 {
+			errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Strength value for Adapter '%s' must be between 0 and 1", r.Source.Name), "adapter"))
+		}
+
+	}
+	return errs
+}
+
 func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace string) (errs *apis.FieldError) {
 	methodLowerCase := strings.ToLower(string(r.Method))
 	if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
@@ -346,6 +385,16 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
 		}
 		// Note: we don't enforce private access mode to have image secrets, in case anonymous pulling is enabled
 	}
+	if len(i.Adapters) > MaxAdaptersNumber {
+		errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Number of Adapters exceeds the maximum limit, maximum of %s allowed", strconv.Itoa(MaxAdaptersNumber))))
+	}
+
+	// check if adapter names are duplicate
+	if len(i.Adapters) > 0 {
+		nameMap := make(map[string]bool)
+		errs = errs.Also(validateDuplicateName(i.Adapters, nameMap))
+	}
+
 	return errs
 }
 
@@ -358,5 +407,27 @@ func (i *InferenceSpec) validateUpdate(old *InferenceSpec) (errs *apis.FieldErro
 		errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "template"))
 	}
 
+	// validate each adapter spec
+	for _, adapter := range i.Adapters {
+		errs = errs.Also(adapter.validateCreateorUpdate())
+	}
+
+	// check if adapter names are duplicate
+
+	if len(i.Adapters) > 0 {
+		nameMap := make(map[string]bool)
+		errs = errs.Also(validateDuplicateName(i.Adapters, nameMap))
+	}
+	return errs
+}
+
+func validateDuplicateName(adapters []AdapterSpec, nameMap map[string]bool) (errs *apis.FieldError) {
+	for _, adapter := range adapters {
+		if _, ok := nameMap[adapter.Source.Name]; ok {
+			errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Duplicate adapter source name found: %s", adapter.Source.Name)))
+		} else {
+			nameMap[adapter.Source.Name] = true
+		}
+	}
 	return errs
 }
diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go
index 29775ded1..1a1bfdd4a 100644
--- a/api/v1alpha1/workspace_validation_test.go
+++ b/api/v1alpha1/workspace_validation_test.go
@@ -5,6 +5,7 @@ package v1alpha1
 
 import (
 	"context"
+	"fmt"
 	"os"
 	"reflect"
 	"sort"
@@ -24,6 +25,10 @@ import (
 
 const DefaultReleaseNamespace = "kaito-workspace"
 
+var ValidStrength string = "0.5"
+var InvalidStrength1 string = "invalid"
+var InvalidStrength2 string = "1.5"
+
 var gpuCountRequirement string
 var totalGPUMemoryRequirement string
 var perGPUMemoryRequirement string
@@ -474,6 +479,56 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
 			errContent: "This preset only supports private AccessMode, AccessMode must be private to continue",
 			expectErrs: true,
 		},
+		{
+			name: "Adapters more than 10",
+			inferenceSpec: func() *InferenceSpec {
+				spec := &InferenceSpec{
+					Preset: &PresetSpec{
+						PresetMeta: PresetMeta{
+							Name:       ModelName("test-validation"),
+							AccessMode: ModelImageAccessModePublic,
+						},
+					},
+				}
+				for i := 1; i <= 11; i++ {
+					spec.Adapters = append(spec.Adapters, AdapterSpec{
+						Source: &DataSource{
+							Name:  fmt.Sprintf("Adapter-%d", i),
+							Image: fmt.Sprintf("fake.kaito.com/kaito-image:0.0.%d", i),
+						},
+						Strength: &ValidStrength,
+					})
+				}
+				return spec
+			}(),
+			errContent: "Number of Adapters exceeds the maximum limit, maximum of 10 allowed",
+			expectErrs: true,
+		},
+		{
+			name: "Adapter names are duplicated",
+			inferenceSpec: func() *InferenceSpec {
+				spec := &InferenceSpec{
+					Preset: &PresetSpec{
+						PresetMeta: PresetMeta{
+							Name:       ModelName("test-validation"),
+							AccessMode: ModelImageAccessModePublic,
+						},
+					},
+				}
+				for i := 1; i <= 2; i++ {
+					spec.Adapters = append(spec.Adapters, AdapterSpec{
+						Source: &DataSource{
+							Name:  "Adapter",
+							Image: fmt.Sprintf("fake.kaito.com/kaito-image:0.0.%d", i),
+						},
+						Strength: &ValidStrength,
+					})
+				}
+				return spec
+			}(),
+			errContent: "Duplicate adapter source name found:",
+			expectErrs: true,
+		},
 		{
 			name: "Valid Preset",
 			inferenceSpec: &InferenceSpec{
@@ -520,6 +575,91 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
 	}
 }
 
+func TestAdapterSpecValidateCreateorUpdate(t *testing.T) {
+	RegisterValidationTestModels()
+	tests := []struct {
+		name        string
+		adapterSpec *AdapterSpec
+		errContent  string // Content expected error to include, if any
+		expectErrs  bool
+	}{
+		{
+			name: "Missing Source",
+			adapterSpec: &AdapterSpec{
+				Strength: &ValidStrength,
+			},
+			errContent: "Source",
+			expectErrs: true,
+		},
+		{
+			name: "Missing Source Name",
+			adapterSpec: &AdapterSpec{
+				Source: &DataSource{
+					Image: "fake.kaito.com/kaito-image:0.0.1",
+				},
+				Strength: &ValidStrength,
+			},
+			errContent: "Name of Adapter field must be specified",
+			expectErrs: true,
+		},
"Invalid Strength, not a number", + adapterSpec: &AdapterSpec{ + Source: &DataSource{ + Name: "Adapter-1", + Image: "fake.kaito.com/kaito-image:0.0.1", + }, + Strength: &InvalidStrength1, + }, + errContent: "Invalid strength value for Adapter 'Adapter-1'", + expectErrs: true, + }, + { + name: "Invalid Strength, larger than 1", + adapterSpec: &AdapterSpec{ + Source: &DataSource{ + Name: "Adapter-1", + Image: "fake.kaito.com/kaito-image:0.0.1", + }, + Strength: &InvalidStrength2, + }, + errContent: "Strength value for Adapter 'Adapter-1' must be between 0 and 1", + expectErrs: true, + }, + { + name: "Valid Adapter", + adapterSpec: &AdapterSpec{ + Source: &DataSource{ + Name: "Adapter-1", + Image: "fake.kaito.com/kaito-image:0.0.1", + }, + Strength: &ValidStrength, + }, + errContent: "", + expectErrs: false, + }, + } + + // Run the tests + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + errs := tc.adapterSpec.validateCreateorUpdate() + hasErrs := errs != nil + if hasErrs != tc.expectErrs { + t.Errorf("validateUpdate() errors = %v, expectErrs %v", errs, tc.expectErrs) + } + + // If there is an error and errContent is not empty, check that the error contains the expected content. + if hasErrs && tc.errContent != "" { + errMsg := errs.Error() + if !strings.Contains(errMsg, tc.errContent) { + t.Errorf("validateUpdate() error message = %v, expected to contain = %v", errMsg, tc.errContent) + } + } + }) + } +} + func TestInferenceSpecValidateUpdate(t *testing.T) { tests := []struct { name string diff --git a/examples/inference/kaito_workspace_falcon_7b.yaml b/examples/inference/kaito_workspace_falcon_7b.yaml index 4eaf1590d..afb813757 100644 --- a/examples/inference/kaito_workspace_falcon_7b.yaml +++ b/examples/inference/kaito_workspace_falcon_7b.yaml @@ -10,3 +10,4 @@ resource: inference: preset: name: "falcon-7b" + \ No newline at end of file diff --git a/examples/inference/kaito_workspace_falcon_7b_with_adapters.yaml b/examples/inference/kaito_workspace_falcon_7b_with_adapters.yaml new file mode 100644 index 000000000..e2ce58dec --- /dev/null +++ b/examples/inference/kaito_workspace_falcon_7b_with_adapters.yaml @@ -0,0 +1,18 @@ +apiVersion: kaito.sh/v1alpha1 +kind: Workspace +metadata: + name: workspace-falcon-7b +resource: + instanceType: "Standard_NC12s_v3" + labelSelector: + matchLabels: + apps: falcon-7b +inference: + preset: + name: "falcon-7b" + adapters: + - source: + name: "falcon-7b-adapter" + image: "" + strength: "0.2" + \ No newline at end of file diff --git a/pkg/inference/preset-inferences.go b/pkg/inference/preset-inferences.go index d681ddaf4..8ffa22e79 100644 --- a/pkg/inference/preset-inferences.go +++ b/pkg/inference/preset-inferences.go @@ -5,10 +5,11 @@ package inference import ( "context" "fmt" - "github.com/azure/kaito/pkg/utils" "os" "strconv" + "github.com/azure/kaito/pkg/utils" + kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1" "github.com/azure/kaito/pkg/model" "github.com/azure/kaito/pkg/resources" @@ -56,13 +57,14 @@ var ( tolerations = []corev1.Toleration{ { Effect: corev1.TaintEffectNoSchedule, - Operator: corev1.TolerationOpEqual, - Key: resources.GPUString, + Operator: corev1.TolerationOpExists, + Key: resources.CapacityNvidiaGPU, }, { - Effect: corev1.TaintEffectNoSchedule, - Value: resources.GPUString, - Key: "sku", + Effect: corev1.TaintEffectNoSchedule, + Value: resources.GPUString, + Key: "sku", + Operator: corev1.TolerationOpEqual, }, } ) @@ -127,6 +129,13 @@ func CreatePresetInference(ctx context.Context, workspaceObj 
 	if shmVolumeMount.Name != "" {
 		volumeMounts = append(volumeMounts, shmVolumeMount)
 	}
+
+	if len(workspaceObj.Inference.Adapters) > 0 {
+		adapterVolume, adapterVolumeMount := utils.ConfigAdapterVolume()
+		volumes = append(volumes, adapterVolume)
+		volumeMounts = append(volumeMounts, adapterVolumeMount)
+	}
+
 	commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj)
 	image, imagePullSecrets := GetInferenceImageInfo(ctx, workspaceObj, inferenceObj)
diff --git a/pkg/resources/manifests.go b/pkg/resources/manifests.go
index 87cf65740..a2d6a8f01 100644
--- a/pkg/resources/manifests.go
+++ b/pkg/resources/manifests.go
@@ -270,6 +270,26 @@ func GenerateDeploymentManifest(ctx context.Context, workspaceObj *kaitov1alpha1
 	labelselector := &v1.LabelSelector{
 		MatchLabels: selector,
 	}
+	initContainers := []corev1.Container{}
+	envs := []corev1.EnvVar{}
+	if len(workspaceObj.Inference.Adapters) != 0 {
+		for _, adapter := range workspaceObj.Inference.Adapters {
+			// TODO: accept Volumes and url link to pull images
+			initContainer := corev1.Container{
+				Name:            adapter.Source.Name,
+				Image:           adapter.Source.Image,
+				Command:         []string{"/bin/sh", "-c", fmt.Sprintf("mkdir -p /mnt/adapter/%s && cp -r /data/* /mnt/adapter/%s", adapter.Source.Name, adapter.Source.Name)},
+				VolumeMounts:    volumeMount,
+				ImagePullPolicy: corev1.PullAlways,
+			}
+			initContainers = append(initContainers, initContainer)
+			env := corev1.EnvVar{
+				Name:  adapter.Source.Name,
+				Value: *adapter.Strength,
+			}
+			envs = append(envs, env)
+		}
+	}
 
 	return &appsv1.Deployment{
 		ObjectMeta: v1.ObjectMeta{
@@ -305,6 +325,7 @@ func GenerateDeploymentManifest(ctx context.Context, workspaceObj *kaitov1alpha1
 						},
 					},
 				},
+				InitContainers: initContainers,
 				Containers: []corev1.Container{
 					{
 						Name: workspaceObj.Name,
@@ -315,6 +336,7 @@ func GenerateDeploymentManifest(ctx context.Context, workspaceObj *kaitov1alpha1
 						ReadinessProbe: readinessProbe,
 						Ports:          containerPorts,
 						VolumeMounts:   volumeMount,
+						Env:            envs,
 					},
 				},
 				Tolerations: tolerations,
diff --git a/pkg/utils/common-preset.go b/pkg/utils/common-preset.go
index e70bc8b29..11358f4e0 100644
--- a/pkg/utils/common-preset.go
+++ b/pkg/utils/common-preset.go
@@ -10,6 +10,7 @@ const (
 	DefaultVolumeMountPath    = "/dev/shm"
 	DefaultConfigMapMountPath = "/mnt/config"
 	DefaultDataVolumePath     = "/mnt/data"
+	DefaultAdapterVolumePath  = "/mnt/adapter"
 )
 
 func ConfigResultsVolume(outputPath string) (corev1.Volume, corev1.VolumeMount) {
@@ -121,3 +122,23 @@ func ConfigDataVolume(hostPath *string) (corev1.Volume, corev1.VolumeMount) {
 	}
 	return volume, volumeMount
 }
+
+func ConfigAdapterVolume() (corev1.Volume, corev1.VolumeMount) {
+	var volume corev1.Volume
+	var volumeMount corev1.VolumeMount
+
+	volumeSource := corev1.VolumeSource{
+		EmptyDir: &corev1.EmptyDirVolumeSource{},
+	}
+
+	volume = corev1.Volume{
+		Name:         "adapter-volume",
+		VolumeSource: volumeSource,
+	}
+
+	volumeMount = corev1.VolumeMount{
+		Name:      "adapter-volume",
+		MountPath: DefaultAdapterVolumePath,
+	}
+	return volume, volumeMount
+}
diff --git a/presets/inference/text-generation/inference_api.py b/presets/inference/text-generation/inference_api.py
index eea6970de..6324373dd 100644
--- a/presets/inference/text-generation/inference_api.py
+++ b/presets/inference/text-generation/inference_api.py
@@ -8,14 +8,16 @@
 import psutil
 import torch
 import transformers
+import subprocess
 import uvicorn
 from fastapi import Body, FastAPI, HTTPException
 from fastapi.responses import Response
-from pydantic import BaseModel, Extra, Field
+from pydantic import BaseModel, Extra, Field, validator
 from transformers import (AutoModelForCausalLM, AutoTokenizer,
                           GenerationConfig, HfArgumentParser)
+from peft import PeftModel
 
-
+ADAPTERS_DIR = '/mnt/adapter'
 @dataclass
 class ModelConfig:
     """
@@ -23,6 +25,7 @@ class ModelConfig:
     """
     pipeline: str = field(metadata={"help": "The model pipeline for the pre-trained model"})
     pretrained_model_name_or_path: Optional[str] = field(default="/workspace/tfs/weights", metadata={"help": "Path to the pretrained model or model identifier from huggingface.co/models"})
+    combination_type: Optional[str] = field(default="svd", metadata={"help": "The combination type of multi adapters"})
     state_dict: Optional[Dict[str, Any]] = field(default=None, metadata={"help": "State dictionary for the model"})
     cache_dir: Optional[str] = field(default=None, metadata={"help": "Cache directory for the model"})
     from_tf: bool = field(default=False, metadata={"help": "Load model from a TensorFlow checkpoint"})
@@ -59,7 +62,7 @@ def process_additional_args(self, addt_args: List[str]):
         # Update the ModelConfig instance with the additional args
         self.__dict__.update(addt_args_dict)
 
-    def __post_init__(self):
+    def __post_init__(self):  # validate parameters
         """
         Post-initialization to validate some ModelConfig values
         """
@@ -70,6 +73,7 @@ def __post_init__(self):
         supported_pipelines = {"conversational", "text-generation"}
         if self.pipeline not in supported_pipelines:
             raise ValueError(f"Unsupported pipeline: {self.pipeline}")
+
 
 parser = HfArgumentParser(ModelConfig)
 args, additional_args = parser.parse_args_into_dataclasses(
@@ -81,10 +85,46 @@ def __post_init__(self):
 model_args = asdict(args)
 model_args["local_files_only"] = not model_args.pop('allow_remote_files')
 model_pipeline = model_args.pop('pipeline')
+combination_type = model_args.pop('combination_type')
 
 app = FastAPI()
 tokenizer = AutoTokenizer.from_pretrained(**model_args)
-model = AutoModelForCausalLM.from_pretrained(**model_args)
+base_model = AutoModelForCausalLM.from_pretrained(**model_args)
+
+def list_files(directory):
+    try:
+        result = subprocess.run(['ls', directory], capture_output=True, text=True)
+        if result.returncode == 0:
+            return result.stdout.strip().split('\n')
+        else:
+            return [f"Command execution failed with return code: {result.returncode}"]
+    except Exception as e:
+        return [f"An error occurred: {str(e)}"]
+if not os.path.exists(ADAPTERS_DIR):
+    model = base_model
+else:
+    output = os.listdir(ADAPTERS_DIR)
+    filtered_output = [s for s in output if s.strip()]
+    adapters_list = [f"{ADAPTERS_DIR}/{file}" for file in filtered_output]
+    filtered_adapters_list = [path for path in adapters_list if os.path.exists(os.path.join(path, "adapter_config.json"))]
+
+    adapter_names, weights = [], []
+    for adapter_path in filtered_adapters_list:
+        adapter_name = os.path.basename(adapter_path)
+        adapter_names.append(adapter_name)
+        weights.append(float(os.getenv(adapter_name)))
+    model = PeftModel.from_pretrained(base_model, filtered_adapters_list[0], adapter_name=adapter_names[0])
+    for i in range(1, len(filtered_adapters_list)):
+        model.load_adapter(filtered_adapters_list[i], adapter_names[i])
+
+    model.add_weighted_adapter(
+        adapters=adapter_names,
+        weights=weights,
+        adapter_name="combined_adapter",
+        combination_type=combination_type,
+    )
+    model.set_adapter("combined_adapter")  # activate the merged adapter so inference uses the weighted combination
+print("Model:", model)
 
 pipeline_kwargs = {
     "trust_remote_code": args.trust_remote_code,
diff --git a/presets/inference/text-generation/requirements.txt b/presets/inference/text-generation/requirements.txt
index 873e54879..feed2129e 100644
--- a/presets/inference/text-generation/requirements.txt
+++ b/presets/inference/text-generation/requirements.txt
@@ -18,3 +18,4 @@ psutil
 # For UTs
 pytest
 httpx
+peft
\ No newline at end of file
diff --git a/presets/test/manifests/falcon-7b-with-adapter/falcon-7b.yaml b/presets/test/manifests/falcon-7b-with-adapter/falcon-7b.yaml
new file mode 100644
index 000000000..3f2212ab6
--- /dev/null
+++ b/presets/test/manifests/falcon-7b-with-adapter/falcon-7b.yaml
@@ -0,0 +1,51 @@
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: falcon-7b
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: falcon
+  template:
+    metadata:
+      labels:
+        app: falcon
+    spec:
+      volumes:
+        - name: adapter-volume
+          emptyDir: {}
+      initContainers:
+        - name: falcon-7b-adapter
+          image:
+          imagePullPolicy: Always
+          command: ["/bin/sh", "-c", "mkdir -p /mnt/adapter/falcon-7b-adapter && cp -r /data/* /mnt/adapter/falcon-7b-adapter"]
+          volumeMounts:
+            - name: adapter-volume
+              mountPath: /mnt/adapter
+      containers:
+        - name: falcon-container
+          image:
+          command:
+            - /bin/sh
+            - -c
+            - accelerate launch --num_processes 1 --num_machines 1 --machine_rank 0 --gpu_ids all inference_api.py --pipeline text-generation --torch_dtype bfloat16
+          resources:
+            requests:
+              nvidia.com/gpu: 2
+            limits:
+              nvidia.com/gpu: 2  # Requesting 2 GPUs
+          volumeMounts:
+            - name: adapter-volume
+              mountPath: /mnt/adapter
+          env:
+            - name: falcon-7b-adapter
+              value: "0.2"
+      tolerations:
+        - effect: NoSchedule
+          value: gpu
+          key: sku
+          operator: Equal
+        - effect: NoSchedule
+          key: nvidia.com/gpu
+          operator: Exists