Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Refactor Naming Conventions, Update Dependencies, Enhance Examples, and Add Volume Validation Check #470

Merged
merged 19 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,18 @@ package v1alpha1
import (
"context"
"fmt"
"path/filepath"
"reflect"
"strings"

"github.com/azure/kaito/pkg/k8sclient"
"github.com/azure/kaito/pkg/utils"
"gopkg.in/yaml.v2"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/runtime"
"knative.dev/pkg/apis"
"path/filepath"
"reflect"
"sigs.k8s.io/controller-runtime/pkg/client"
"strings"
)

type Config struct {
Expand Down Expand Up @@ -257,9 +258,9 @@ func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, me
err := k8sclient.Client.Get(ctx, client.ObjectKey{Name: configMapName, Namespace: namespace}, &cm)
if err != nil {
if errors.IsNotFound(err) {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' specified in 'config' not found in namespace '%s'", r.ConfigTemplate, namespace), "config"))
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("ConfigMap '%s' specified in 'config' not found in namespace '%s'", r.Config, namespace), "config"))
} else {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get ConfigMap '%s' in namespace '%s': %v", r.ConfigTemplate, namespace, err), "config"))
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get ConfigMap '%s' in namespace '%s': %v", r.Config, namespace, err), "config"))
}
} else {
if err := validateConfigMapSchema(&cm); err != nil {
Expand Down
9 changes: 4 additions & 5 deletions api/v1alpha1/workspace_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,11 @@ type TuningSpec struct {
// Method specifies the Parameter-Efficient Fine-Tuning(PEFT) method, such as lora, qlora, used for the tuning.
// +optional
Method TuningMethod `json:"method,omitempty"`
// ConfigTemplate specifies the name of the configmap that contains the basic tuning arguments.
// A separate configmap will be generated based on the ConfigTemplate and the preset model name, and used by
// the tuning Job. If specified, the configmap needs to be in the same namespace of the workspace custom resource.
// If not specified, a default ConfigTemplate is used based on the specified tuning method.
// Config specifies the name of a custom ConfigMap that contains tuning arguments.
// If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
// If not specified, a default Config is used based on the specified tuning method.
// +optional
ConfigTemplate string `json:"configTemplate,omitempty"`
Config string `json:"config,omitempty"`
// Input describes the input used by the tuning method.
Input *DataSource `json:"input"`
// Output specifies where to store the tuning output.
Expand Down
15 changes: 12 additions & 3 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
if methodLowerCase != string(TuningMethodLora) && methodLowerCase != string(TuningMethodQLora) {
errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method"))
}
if r.ConfigTemplate == "" {
if r.Config == "" {
klog.InfoS("Tuning config not specified. Using default based on method.")
releaseNamespace, err := utils.GetReleaseNamespace()
if err != nil {
Expand All @@ -149,7 +149,7 @@ func (r *TuningSpec) validateCreate(ctx context.Context, workspaceNamespace stri
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
} else {
if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, r.ConfigTemplate); err != nil {
if err := r.validateConfigMap(ctx, workspaceNamespace, methodLowerCase, r.Config); err != nil {
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to evaluate validateConfigMap: %v", err), "Config"))
}
}
Expand Down Expand Up @@ -200,6 +200,7 @@ func (r *DataSource) validateCreate() (errs *apis.FieldError) {
sourcesSpecified++
}
if r.Volume != nil {
errs = errs.Also(apis.ErrInvalidValue("Volume support is not implemented yet", "Volume"))
sourcesSpecified++
}
// Regex checks for a / and a colon followed by a tag
Expand All @@ -223,6 +224,9 @@ func (r *DataSource) validateUpdate(old *DataSource, isTuning bool) (errs *apis.
if isTuning && !reflect.DeepEqual(old.Name, r.Name) {
errs = errs.Also(apis.ErrInvalidValue("During tuning Name field cannot be changed once set", "Name"))
}
if r.Volume != nil {
errs = errs.Also(apis.ErrInvalidValue("Volume support is not implemented yet", "Volume"))
}
oldURLs := make([]string, len(old.URLs))
copy(oldURLs, old.URLs)
sort.Strings(oldURLs)
Expand Down Expand Up @@ -255,7 +259,9 @@ func (r *DataSource) validateUpdate(old *DataSource, isTuning bool) (errs *apis.

func (r *DataDestination) validateCreate() (errs *apis.FieldError) {
destinationsSpecified := 0
// TODO: Implement Volumes
if r.Volume != nil {
errs = errs.Also(apis.ErrInvalidValue("Volume support is not implemented yet", "Volume"))
destinationsSpecified++
}
if r.Image != "" {
Expand All @@ -279,7 +285,10 @@ func (r *DataDestination) validateCreate() (errs *apis.FieldError) {
}

func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.FieldError) {
// TODO: Check if the Volume is changed.
// TODO: Implement Volumes
if r.Volume != nil {
errs = errs.Also(apis.ErrInvalidValue("Volume support is not implemented yet", "Volume"))
}
if old.Image != r.Image {
errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image"))
}
Expand Down
9 changes: 4 additions & 5 deletions charts/kaito/workspace/crds/kaito.sh_workspaces.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,11 @@ spec:
type: object
tuning:
properties:
configTemplate:
config:
description: |-
ConfigTemplate specifies the name of the configmap that contains the basic tuning arguments.
A separate configmap will be generated based on the ConfigTemplate and the preset model name, and used by
the tuning Job. If specified, the configmap needs to be in the same namespace of the workspace custom resource.
If not specified, a default ConfigTemplate is used based on the specified tuning method.
Config specifies the name of a custom ConfigMap that contains tuning arguments.
If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
If not specified, a default Config is used based on the specified tuning method.
type: string
input:
description: Input describes the input used by the tuning method.
Expand Down
4 changes: 2 additions & 2 deletions cmd/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ import (
_ "github.com/azure/kaito/presets/models/llama2"
_ "github.com/azure/kaito/presets/models/llama2chat"
_ "github.com/azure/kaito/presets/models/mistral"
_ "github.com/azure/kaito/presets/models/phi-2"
_ "github.com/azure/kaito/presets/models/phi-3"
_ "github.com/azure/kaito/presets/models/phi2"
_ "github.com/azure/kaito/presets/models/phi3"
)
9 changes: 4 additions & 5 deletions config/crd/bases/kaito.sh_workspaces.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,11 @@ spec:
type: object
tuning:
properties:
configTemplate:
config:
description: |-
ConfigTemplate specifies the name of the configmap that contains the basic tuning arguments.
A separate configmap will be generated based on the ConfigTemplate and the preset model name, and used by
the tuning Job. If specified, the configmap needs to be in the same namespace of the workspace custom resource.
If not specified, a default ConfigTemplate is used based on the specified tuning method.
Config specifies the name of a custom ConfigMap that contains tuning arguments.
If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
If not specified, a default Config is used based on the specified tuning method.
type: string
input:
description: Input describes the input used by the tuning method.
Expand Down
20 changes: 0 additions & 20 deletions examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml

This file was deleted.

19 changes: 19 additions & 0 deletions examples/fine-tuning/kaito_workspace_tuning_phi_3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Example Workspace that fine-tunes the phi-3-mini-128k-instruct preset with the
# qlora method. Training data is fetched from the URL under tuning.input.urls and
# the tuning result is pushed as an image to the registry under tuning.output.
# The *_HERE values are placeholders to be replaced before applying.
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
  name: workspace-tuning-phi-3
resource:
  instanceType: "Standard_NC6s_v3"
  labelSelector:
    matchLabels:
      app: tuning-phi-3
tuning:
  preset:
    name: phi-3-mini-128k-instruct
  method: qlora
  input:
    urls:
      - "https://huggingface.co/datasets/philschmid/dolly-15k-oai-style/resolve/main/data/train-00000-of-00001-54e3756291ca09c6.parquet?download=true"
  output:
    image: "ACR_REPO_HERE.azurecr.io/ADAPTER_HERE:0.0.1" # Tuning Output ACR Path
    imagePushSecret: ACR_REGISTRY_SECRET_HERE
13 changes: 13 additions & 0 deletions examples/inference/kaito_workspace_phi_3.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Example Workspace that serves the phi-3-mini-4k-instruct preset for inference
# on a Standard_NC6s_v3 instance.
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
  name: workspace-phi-3-mini
resource:
  instanceType: "Standard_NC6s_v3"
  labelSelector:
    matchLabels:
      apps: phi-3
inference:
  preset:
    name: phi-3-mini-4k-instruct
    # Note: This configuration also works with the phi-3-mini-128k-instruct preset
17 changes: 17 additions & 0 deletions examples/inference/kaito_workspace_phi_3_with_adapters.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Example Workspace that serves the phi-3-mini-128k-instruct preset for inference
# with one adapter loaded from a container image. The ACR_REPO_HERE / ADAPTER_HERE
# values are placeholders to be replaced before applying.
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
  name: workspace-phi-3-mini-adapter
resource:
  instanceType: "Standard_NC6s_v3"
  labelSelector:
    matchLabels:
      apps: phi-3-adapter
inference:
  preset:
    name: phi-3-mini-128k-instruct
  adapters:
    - source:
        name: "phi-3-adapter"
        image: "ACR_REPO_HERE.azurecr.io/ADAPTER_HERE:0.0.1"
      strength: "1.0"
9 changes: 5 additions & 4 deletions pkg/tuning/preset-tuning.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package tuning
import (
"context"
"fmt"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/utils/pointer"
"knative.dev/pkg/apis"
"os"
"path/filepath"
"strings"

"k8s.io/apimachinery/pkg/runtime"
"k8s.io/utils/pointer"
"knative.dev/pkg/apis"

"k8s.io/apimachinery/pkg/api/resource"

kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
Expand Down Expand Up @@ -96,7 +97,7 @@ func GetDataSrcImageInfo(ctx context.Context, wObj *kaitov1alpha1.Workspace) (st
// - If not, check the release namespace and copy it to the target namespace if found.
func EnsureTuningConfigMap(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
kubeClient client.Client) (*corev1.ConfigMap, error) {
tuningConfigMapName := workspaceObj.Tuning.ConfigTemplate
tuningConfigMapName := workspaceObj.Tuning.Config
if tuningConfigMapName == "" {
if workspaceObj.Tuning.Method == kaitov1alpha1.TuningMethodLora {
tuningConfigMapName = kaitov1alpha1.DefaultLoraConfigMapTemplate
Expand Down
6 changes: 3 additions & 3 deletions pkg/tuning/preset-tuning_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func TestEnsureTuningConfigMap(t *testing.T) {
},
workspaceObj: &kaitov1alpha1.Workspace{
Tuning: &kaitov1alpha1.TuningSpec{
ConfigTemplate: "config-template",
Config: "config-template",
},
},
expectedError: "",
Expand All @@ -207,7 +207,7 @@ func TestEnsureTuningConfigMap(t *testing.T) {
},
workspaceObj: &kaitov1alpha1.Workspace{
Tuning: &kaitov1alpha1.TuningSpec{
ConfigTemplate: "config-template",
Config: "config-template",
},
},
expectedError: "failed to get release namespace: failed to determine release namespace from file /var/run/secrets/kubernetes.io/serviceaccount/namespace and env var RELEASE_NAMESPACE",
Expand All @@ -221,7 +221,7 @@ func TestEnsureTuningConfigMap(t *testing.T) {
},
workspaceObj: &kaitov1alpha1.Workspace{
Tuning: &kaitov1alpha1.TuningSpec{
ConfigTemplate: "config-template",
Config: "config-template",
},
},
expectedError: "failed to get ConfigMap from template namespace: \"config-template\" not found",
Expand Down
2 changes: 1 addition & 1 deletion presets/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The current supported model families with preset configurations are listed below
| [llama2](./models/llama2) | v0.0.1+|
| [llama2chat](./models/llama2chat) | v0.0.1+|
| [mistral](./models/mistral) | v0.2.0+|
| [phi2](./models/phi-2) | v0.2.0+|
| [phi2](./models/phi2) | v0.2.0+|

## Validation
Each preset model has its own hardware requirements in terms of GPU count and GPU memory defined in the respective `model.go` file. Kaito controller performs a validation check of whether the specified SKU and node count are sufficient to run the model or not. In case the provided SKU is not in the known list, the controller bypasses the validation check which means users need to ensure the model can run with the provided SKU.
Expand Down
3 changes: 2 additions & 1 deletion presets/inference/text-generation/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ accelerate==0.30.1
fastapi>=0.111.0,<0.112.0 # Allow patch updates
pydantic>=2.7.1,<2.8 # Allow patch updates
uvicorn[standard]>=0.29.0,<0.30.0 # Allow patch updates
peft
peft==0.11.1
numpy==1.22.4

# Utility libraries
bitsandbytes==0.42.0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from unittest.mock import patch

import numpy
import pytest
from fastapi.testclient import TestClient
from transformers import AutoTokenizer
Expand All @@ -13,8 +14,8 @@
sys.path.append(parent_dir)

@pytest.fixture(params=[
{"pipeline": "text-generation", "model_path": "stanford-crfm/alias-gpt2-small-x21"},
{"pipeline": "conversational", "model_path": "stanford-crfm/alias-gpt2-small-x21"},
{"pipeline": "text-generation", "model_path": "stanford-crfm/alias-gpt2-small-x21", "device": "cpu"},
{"pipeline": "conversational", "model_path": "stanford-crfm/alias-gpt2-small-x21", "device": "cpu"},
])
def configured_app(request):
original_argv = sys.argv.copy()
Expand All @@ -23,6 +24,7 @@ def configured_app(request):
'program_name',
'--pipeline', request.param['pipeline'],
'--pretrained_model_name_or_path', request.param['model_path'],
'--device_map', request.param['device'],
'--allow_remote_files', 'True'
]
sys.argv = test_args
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
## Supported Models
|Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference|
|----|:----:|:----:| :----: |:----: |
|phi-2 |[microsoft](https://huggingface.co/microsoft/phi-2)|[link](../../../examples/inference/kaito_workspace_phi-2.yaml)|Deployment| false|
|phi-2 |[microsoft](https://huggingface.co/microsoft/phi-2)|[link](../../../examples/inference/kaito_workspace_phi_2.yaml)|Deployment| false|


## Image Source
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package phi_2
package phi2

import (
"time"
Expand Down
Loading
Loading