Skip to content

Commit

Permalink
feat: add inference config api (#791)
Browse files Browse the repository at this point in the history
**Reason for Change**:

- API change: add a `config` field to the inference spec (`workspace.inference.config`).
- Generate a default config by copying it from the release namespace when
no user config is specified.

**Requirements**

- [x] added unit tests and e2e tests (if applicable).

---------

Signed-off-by: jerryzhuang <zhuangqhc@gmail.com>
Co-authored-by: Fei Guo <guofei@microsoft.com>
  • Loading branch information
zhuangqh and Fei-Guo authored Dec 26, 2024
1 parent 9dd17a3 commit d9dc364
Show file tree
Hide file tree
Showing 22 changed files with 419 additions and 182 deletions.
4 changes: 2 additions & 2 deletions api/v1alpha1/params_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ func getStructInstances(s any) map[string]any {
return instances
}

func validateConfigMapSchema(cm *corev1.ConfigMap) *apis.FieldError {
func validateTuningConfigMapSchema(cm *corev1.ConfigMap) *apis.FieldError {
trainingConfigData, ok := cm.Data["training_config.yaml"]
if !ok {
return apis.ErrMissingField("training_config.yaml in ConfigMap")
Expand Down Expand Up @@ -263,7 +263,7 @@ func (r *TuningSpec) validateConfigMap(ctx context.Context, namespace string, me
errs = errs.Also(apis.ErrGeneric(fmt.Sprintf("Failed to get ConfigMap '%s' in namespace '%s': %v", r.Config, namespace, err), "config"))
}
} else {
if err := validateConfigMapSchema(&cm); err != nil {
if err := validateTuningConfigMapSchema(&cm); err != nil {
errs = errs.Also(err)
}
if err := validateMethodViaConfigMap(&cm, methodLowerCase); err != nil {
Expand Down
4 changes: 4 additions & 0 deletions api/v1alpha1/workspace_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ type InferenceSpec struct {
// +kubebuilder:validation:Schemaless
// +optional
Template *v1.PodTemplateSpec `json:"template,omitempty"`
// Config specifies the name of a custom ConfigMap that contains inference arguments.
// If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
// +optional
Config string `json:"config,omitempty"`
// Adapters are integrated into the base model for inference.
// Users can specify multiple adapters for the model and the respective weight of using each of them.
// +optional
Expand Down
44 changes: 39 additions & 5 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@ import (
"strconv"
"strings"

"github.com/kaito-project/kaito/pkg/k8sclient"
"github.com/kaito-project/kaito/pkg/utils/consts"
"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/kaito-project/kaito/pkg/utils"
"github.com/kaito-project/kaito/pkg/utils/plugin"

admissionregistrationv1 "k8s.io/api/admissionregistration/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/util/validation"
Expand All @@ -29,9 +33,10 @@ const (
N_SERIES_PREFIX = "Standard_N"
D_SERIES_PREFIX = "Standard_D"

DefaultLoraConfigMapTemplate = "lora-params-template"
DefaultQloraConfigMapTemplate = "qlora-params-template"
MaxAdaptersNumber = 10
DefaultLoraConfigMapTemplate = "lora-params-template"
DefaultQloraConfigMapTemplate = "qlora-params-template"
DefaultInferenceConfigTemplate = "inference-params-template"
MaxAdaptersNumber = 10
)

func (w *Workspace) SupportedVerbs() []admissionregistrationv1.OperationType {
Expand All @@ -53,7 +58,7 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Resource.validateCreateWithInference(w.Inference).ViaField("resource"),
w.Inference.validateCreate().ViaField("inference"))
w.Inference.validateCreate(ctx, w.Namespace).ViaField("inference"))
}
if w.Tuning != nil {
// TODO: Add validate resource based on Tuning Spec
Expand Down Expand Up @@ -390,7 +395,7 @@ func (r *ResourceSpec) validateUpdate(old *ResourceSpec) (errs *apis.FieldError)
return errs
}

func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
func (i *InferenceSpec) validateCreate(ctx context.Context, namespace string) (errs *apis.FieldError) {
// Check if both Preset and Template are not set
if i.Preset == nil && i.Template == nil {
errs = errs.Also(apis.ErrMissingField("Preset or Template must be specified"))
Expand Down Expand Up @@ -428,6 +433,35 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
errs = errs.Also(validateDuplicateName(i.Adapters, nameMap))
}

if i.Config != "" {
errs = errs.Also(i.validateConfigMap(ctx, namespace))
}

return errs
}

// validateConfigMap checks that the ConfigMap referenced by i.Config can be
// fetched from the given namespace and that it contains an
// "inference_config.yaml" entry.
//
// The lookup goes through the package-level k8sclient.Client; when that client
// is nil a generic error is returned without contacting the API server. Lookup
// failures are reported against the "config" field, with a dedicated message
// for the not-found case. Only the presence of the key is verified here, not
// the content stored under it.
func (i *InferenceSpec) validateConfigMap(ctx context.Context, namespace string) (errs *apis.FieldError) {
	if k8sclient.Client == nil {
		return errs.Also(apis.ErrGeneric("Failed to obtain client from context.Context"))
	}

	configMap := corev1.ConfigMap{}
	key := client.ObjectKey{Name: i.Config, Namespace: namespace}
	if getErr := k8sclient.Client.Get(ctx, key, &configMap); getErr != nil {
		var msg string
		if errors.IsNotFound(getErr) {
			msg = fmt.Sprintf("ConfigMap '%s' specified in 'config' not found in namespace '%s'", i.Config, namespace)
		} else {
			msg = fmt.Sprintf("Failed to get ConfigMap '%s' in namespace '%s': %v", i.Config, namespace, getErr)
		}
		return errs.Also(apis.ErrGeneric(msg, "config"))
	}

	// Presence check only: the content of the entry is not inspected here.
	if _, found := configMap.Data["inference_config.yaml"]; !found {
		return apis.ErrMissingField("inference_config.yaml in ConfigMap")
	}

	return errs
}

Expand Down
69 changes: 68 additions & 1 deletion api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,31 @@ func TestResourceSpecValidateUpdate(t *testing.T) {

func TestInferenceSpecValidateCreate(t *testing.T) {
RegisterValidationTestModels()
ctx := context.Background()

// Create fake client with default ConfigMap
scheme := runtime.NewScheme()
_ = v1.AddToScheme(scheme)
client := fake.NewClientBuilder().WithScheme(scheme).WithRuntimeObjects(
&v1.ConfigMap{
ObjectMeta: metav1.ObjectMeta{
Name: "valid-config",
},
Data: map[string]string{
"inference_config.yaml": "a: b",
},
},
&v1.ConfigMap{
ObjectMeta: metav1.ObjectMeta{
Name: "missing-key-config",
},
Data: map[string]string{
"other_key": "some value",
},
},
).Build()
k8sclient.SetGlobalClient(client)

tests := []struct {
name string
inferenceSpec *InferenceSpec
Expand Down Expand Up @@ -638,6 +663,48 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
errContent: "Duplicate adapter source name found:",
expectErrs: false,
},
{
name: "Config specified but ConfigMap not found",
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: ModelImageAccessModePublic,
},
},
Config: "nonexistent-config",
},
errContent: "ConfigMap 'nonexistent-config' specified in 'config' not found in namespace",
expectErrs: true,
},
{
name: "Config specified with valid ConfigMap",
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: ModelImageAccessModePublic,
},
},
Config: "valid-config",
},
errContent: "",
expectErrs: false,
},
{
name: "ConfigMap missing required inference_config.yaml",
inferenceSpec: &InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
AccessMode: ModelImageAccessModePublic,
},
},
Config: "missing-key-config",
},
errContent: "missing field(s): inference_config.yaml in ConfigMap",
expectErrs: true,
},
}

for _, tc := range tests {
Expand All @@ -654,7 +721,7 @@ func TestInferenceSpecValidateCreate(t *testing.T) {
}
}()
}
errs := tc.inferenceSpec.validateCreate()
errs := tc.inferenceSpec.validateCreate(ctx, "")
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
t.Errorf("validateCreate() errors = %v, expectErrs %v", errs, tc.expectErrs)
Expand Down
5 changes: 5 additions & 0 deletions charts/kaito/workspace/crds/kaito.sh_workspaces.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ spec:
type: string
type: object
type: array
config:
description: |-
Config specifies the name of a custom ConfigMap that contains inference arguments.
If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
type: string
preset:
description: Preset describes the base model that will be deployed
with preset configurations.
Expand Down
5 changes: 5 additions & 0 deletions config/crd/bases/kaito.sh_workspaces.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ spec:
type: string
type: object
type: array
config:
description: |-
Config specifies the name of a custom ConfigMap that contains inference arguments.
If specified, the ConfigMap must be in the same namespace as the Workspace custom resource.
type: string
preset:
description: Preset describes the base model that will be deployed
with preset configurations.
Expand Down
9 changes: 8 additions & 1 deletion pkg/model/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
package model

import (
"path"
"time"

"github.com/kaito-project/kaito/pkg/utils"
corev1 "k8s.io/api/core/v1"
)

type Model interface {
Expand All @@ -21,6 +23,8 @@ type RuntimeName string
const (
RuntimeNameHuggingfaceTransformers RuntimeName = "transformers"
RuntimeNameVLLM RuntimeName = "vllm"

ConfigfileNameVLLM = "inference_config.yaml"
)

// PresetParam defines the preset inference parameters for a model.
Expand Down Expand Up @@ -133,7 +137,7 @@ func (v *VLLMParam) DeepCopy() VLLMParam {

// builds the container command:
// eg. torchrun <TORCH_PARAMS> <OPTIONAL_RDZV_PARAMS> baseCommand <MODEL_PARAMS>
func (p *PresetParam) GetInferenceCommand(runtime RuntimeName, skuNumGPUs string) []string {
func (p *PresetParam) GetInferenceCommand(runtime RuntimeName, skuNumGPUs string, configVolume *corev1.VolumeMount) []string {
switch runtime {
case RuntimeNameHuggingfaceTransformers:
torchCommand := utils.BuildCmdStr(p.Transformers.BaseCommand, p.Transformers.TorchRunParams, p.Transformers.TorchRunRdzvParams)
Expand All @@ -146,6 +150,9 @@ func (p *PresetParam) GetInferenceCommand(runtime RuntimeName, skuNumGPUs string
if !p.DisableTensorParallelism {
p.VLLM.ModelRunParams["tensor-parallel-size"] = skuNumGPUs
}
if configVolume != nil {
p.VLLM.ModelRunParams["kaito-config-file"] = path.Join(configVolume.MountPath, ConfigfileNameVLLM)
}
modelCommand := utils.BuildCmdStr(p.VLLM.BaseCommand, p.VLLM.ModelRunParams)
return utils.ShellCmd(modelCommand)
default:
Expand Down
70 changes: 70 additions & 0 deletions pkg/utils/resources/resources.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"fmt"
"time"

"github.com/kaito-project/kaito/pkg/utils"
appsv1 "k8s.io/api/apps/v1"
batchv1 "k8s.io/api/batch/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
"k8s.io/client-go/util/retry"
"k8s.io/klog/v2"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -102,3 +104,71 @@ func CheckResourceStatus(obj client.Object, kubeClient client.Client, timeoutDur
}
}
}

// EnsureConfigOrCopyFromDefault resolves the ConfigMap to use and returns it.
// It handles two scenarios:
//
//  1. User provided config (userProvided.Name is non-empty):
//     - Fetch it from userProvided.Namespace.
//     - If not found, return an error, as this is user-specified.
//
//  2. No user config specified:
//     - Use the default config template named systemDefault.Name.
//     - If a ConfigMap with that name already exists in userProvided.Namespace,
//     return it unchanged.
//     - Otherwise fetch the template from systemDefault.Namespace (falling back
//     to utils.GetReleaseNamespace when that namespace is empty) and create
//     a copy of it in userProvided.Namespace.
//
// Errors from the release-namespace lookup, the template fetch, and the create
// call are wrapped with %w, so callers can still inspect the underlying cause
// (for example with errors.Is / errors.As) while seeing the added context.
func EnsureConfigOrCopyFromDefault(ctx context.Context, kubeClient client.Client,
	userProvided, systemDefault client.ObjectKey,
) (*corev1.ConfigMap, error) {

	// If user specified a config, use that.
	if userProvided.Name != "" {
		userCM := &corev1.ConfigMap{}
		err := GetResource(ctx, userProvided.Name, userProvided.Namespace, kubeClient, userCM)
		if err != nil {
			if errors.IsNotFound(err) {
				return nil, fmt.Errorf("user specified ConfigMap %s not found in namespace %s",
					userProvided.Name, userProvided.Namespace)
			}
			return nil, err
		}

		return userCM, nil
	}

	// Check if the default ConfigMap already exists in the target namespace.
	existingCM := &corev1.ConfigMap{}
	err := GetResource(ctx, systemDefault.Name, userProvided.Namespace, kubeClient, existingCM)
	if err == nil {
		klog.Infof("Default ConfigMap already exists in target namespace: %s, no action taken.", userProvided.Namespace)
		return existingCM, nil
	}
	if !errors.IsNotFound(err) {
		return nil, err
	}

	// Not found in the target namespace: copy the default template from the
	// release namespace. systemDefault is passed by value, so filling in its
	// namespace here does not affect the caller.
	if systemDefault.Namespace == "" {
		releaseNamespace, err := utils.GetReleaseNamespace()
		if err != nil {
			return nil, fmt.Errorf("failed to get release namespace: %w", err)
		}
		systemDefault.Namespace = releaseNamespace
	}

	templateCM := &corev1.ConfigMap{}
	err = GetResource(ctx, systemDefault.Name, systemDefault.Namespace, kubeClient, templateCM)
	if err != nil {
		return nil, fmt.Errorf("failed to get default ConfigMap from template namespace: %w", err)
	}

	// Retarget the fetched template at the destination namespace and clear the
	// identity metadata that must not be set on a create request.
	templateCM.Namespace = userProvided.Namespace
	templateCM.ResourceVersion = ""
	templateCM.UID = ""

	err = CreateResource(ctx, templateCM, kubeClient)
	if err != nil {
		return nil, fmt.Errorf("failed to create default ConfigMap in target namespace %s: %w",
			userProvided.Namespace, err)
	}

	return templateCM, nil
}
Loading

0 comments on commit d9dc364

Please sign in to comment.