feat: Initialize Fine-Tuning Interface and Core Methods - Part 3 (#308)

ishaansehgal99 authored Apr 1, 2024
1 parent ee9101a commit 08dd1f4

Showing 32 changed files with 357 additions and 94 deletions.
29 changes: 25 additions & 4 deletions api/v1alpha1/workspace_validation_test.go

@@ -21,8 +21,15 @@ var perGPUMemoryRequirement string
 
 type testModel struct{}
 
-func (*testModel) GetInferenceParameters() *model.PresetInferenceParam {
-	return &model.PresetInferenceParam{
+func (*testModel) GetInferenceParameters() *model.PresetParam {
+	return &model.PresetParam{
+		GPUCountRequirement:       gpuCountRequirement,
+		TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
+		PerGPUMemoryRequirement:   perGPUMemoryRequirement,
+	}
+}
+func (*testModel) GetTuningParameters() *model.PresetParam {
+	return &model.PresetParam{
 		GPUCountRequirement:       gpuCountRequirement,
 		TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
 		PerGPUMemoryRequirement:   perGPUMemoryRequirement,
@@ -31,11 +38,22 @@ func (*testModel) GetInferenceParameters() *model.PresetInferenceParam
 func (*testModel) SupportDistributedInference() bool {
 	return false
 }
+func (*testModel) SupportTuning() bool {
+	return true
+}
 
 type testModelPrivate struct{}
 
-func (*testModelPrivate) GetInferenceParameters() *model.PresetInferenceParam {
-	return &model.PresetInferenceParam{
+func (*testModelPrivate) GetInferenceParameters() *model.PresetParam {
+	return &model.PresetParam{
+		ImageAccessMode:           "private",
+		GPUCountRequirement:       gpuCountRequirement,
+		TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
+		PerGPUMemoryRequirement:   perGPUMemoryRequirement,
+	}
+}
+func (*testModelPrivate) GetTuningParameters() *model.PresetParam {
+	return &model.PresetParam{
 		ImageAccessMode:           "private",
 		GPUCountRequirement:       gpuCountRequirement,
 		TotalGPUMemoryRequirement: totalGPUMemoryRequirement,
@@ -45,6 +63,9 @@ func (*testModelPrivate) GetInferenceParameters() *model.PresetInferenceParam
 func (*testModelPrivate) SupportDistributedInference() bool {
 	return false
 }
+func (*testModelPrivate) SupportTuning() bool {
+	return true
+}
 
 func RegisterValidationTestModels() {
 	var test testModel
20 changes: 20 additions & 0 deletions examples/fine-tuning/kaito_workspace_tuning_falcon_7b.yaml

@@ -0,0 +1,20 @@
+apiVersion: kaito.sh/v1alpha1
+kind: Workspace
+metadata:
+  name: workspace-tuning-falcon-7b
+spec:
+  resource:
+    instanceType: "Standard_NC12s_v3"
+    labelSelector:
+      matchLabels:
+        app: tuning-falcon-7b
+  tuning:
+    preset:
+      name: falcon-7b
+    method: lora
+    config: tuning-config-map # ConfigMap containing tuning arguments
+    input:
+      name: tuning-data
+      hostPath: /path/to/your/input/data # dataset on node
+    output:
+      hostPath: /path/to/store/output # tuning output
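
The config field above names a ConfigMap that carries the tuning arguments. For illustration, here is a minimal Go sketch of how a controller could resolve that reference; the helper name, and the assumption that the tuning spec exposes the ConfigMap name as Tuning.Config, are mine rather than this commit's:

package tuning

import (
	"context"

	kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
	corev1 "k8s.io/api/core/v1"
	"sigs.k8s.io/controller-runtime/pkg/client"
)

// getTuningConfigMap is a hypothetical helper: it fetches the ConfigMap named
// in the workspace tuning spec (e.g. "tuning-config-map" in the example above)
// from the workspace's namespace.
func getTuningConfigMap(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace) (*corev1.ConfigMap, error) {
	cm := &corev1.ConfigMap{}
	key := client.ObjectKey{Namespace: wObj.Namespace, Name: wObj.Tuning.Config}
	if err := kubeClient.Get(ctx, key, cm); err != nil {
		return nil, err
	}
	return cm, nil
}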
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
65 changes: 54 additions & 11 deletions pkg/controllers/workspace_controller.go

@@ -9,8 +9,8 @@ import (
 	"strings"
 	"time"
 
-	appsv1 "k8s.io/api/apps/v1"
-	"k8s.io/utils/clock"
+	"github.com/azure/kaito/pkg/tuning"
+	batchv1 "k8s.io/api/batch/v1"
 
 	"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
 	kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
@@ -21,13 +21,15 @@ import (
 	"github.com/azure/kaito/pkg/utils/plugin"
 	"github.com/go-logr/logr"
 	"github.com/samber/lo"
+	appsv1 "k8s.io/api/apps/v1"
 	corev1 "k8s.io/api/core/v1"
-	"k8s.io/apimachinery/pkg/api/errors"
+	apierrors "k8s.io/apimachinery/pkg/api/errors"
 	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
 	"k8s.io/apimachinery/pkg/runtime"
 	"k8s.io/client-go/tools/record"
 	"k8s.io/klog/v2"
+	"k8s.io/utils/clock"
 	ctrl "sigs.k8s.io/controller-runtime"
 	"sigs.k8s.io/controller-runtime/pkg/client"
 	"sigs.k8s.io/controller-runtime/pkg/controller"
@@ -109,16 +111,22 @@ func (c *WorkspaceReconciler) addOrUpdateWorkspace(ctx context.Context, wObj *ka
 		return reconcile.Result{}, err
 	}
 
-	if err = c.applyInference(ctx, wObj); err != nil {
-		if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse,
-			"workspaceFailed", err.Error()); updateErr != nil {
-			klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj))
-			return reconcile.Result{}, updateErr
+	if wObj.Tuning != nil {
+		if err = c.applyTuning(ctx, wObj); err != nil {
+			return reconcile.Result{}, err
+		}
+	}
+	if wObj.Inference != nil {
+		if err = c.applyInference(ctx, wObj); err != nil {
+			if updateErr := c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionFalse,
+				"workspaceFailed", err.Error()); updateErr != nil {
+				klog.ErrorS(updateErr, "failed to update workspace status", "workspace", klog.KObj(wObj))
+				return reconcile.Result{}, updateErr
+			}
+			return reconcile.Result{}, err
 		}
-		return reconcile.Result{}, err
 	}
 
-	// TODO apply TrainingSpec
 	if err = c.updateStatusConditionIfNotMatch(ctx, wObj, kaitov1alpha1.WorkspaceConditionTypeReady, metav1.ConditionTrue,
 		"workspaceReady", "workspace is ready"); err != nil {
 		klog.ErrorS(err, "failed to update workspace status", "workspace", klog.KObj(wObj))
@@ -423,6 +431,41 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al
 	return nil
 }
 
+func (c *WorkspaceReconciler) applyTuning(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
+	var err error
+	func() {
+		if wObj.Tuning.Preset != nil {
+			presetName := string(wObj.Tuning.Preset.Name)
+			model := plugin.KaitoModelRegister.MustGet(presetName)
+
+			tuningParam := model.GetTuningParameters()
+			existingObj := &batchv1.Job{}
+			if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil {
+				klog.InfoS("A tuning workload already exists for workspace", "workspace", klog.KObj(wObj))
+				if err = resources.CheckResourceStatus(existingObj, c.Client, tuningParam.ReadinessTimeout); err != nil {
+					return
+				}
+			} else if apierrors.IsNotFound(err) {
+				var workloadObj client.Object
+				// Need to create a new workload
+				workloadObj, err = tuning.CreatePresetTuning(ctx, wObj, tuningParam, c.Client)
+				if err != nil {
+					return
+				}
+				if err = resources.CheckResourceStatus(workloadObj, c.Client, tuningParam.ReadinessTimeout); err != nil {
+					return
+				}
+			}
+		}
+	}()
+
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
 // applyInference applies inference spec.
 func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
 	var err error
@@ -455,7 +498,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a
 
 		if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil {
 			klog.InfoS("An inference workload already exists for workspace", "workspace", klog.KObj(wObj))
-			if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.DeploymentTimeout); err != nil {
+			if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.ReadinessTimeout); err != nil {
 				return
 			}
 		} else if apierrors.IsNotFound(err) {
@@ -465,7 +508,7 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a
 			if err != nil {
 				return
 			}
-			if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.DeploymentTimeout); err != nil {
+			if err = resources.CheckResourceStatus(workloadObj, c.Client, inferenceParam.ReadinessTimeout); err != nil {
 				return
 			}
 		}
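
Two details stand out in this controller change: the tuning workload is modeled as a batchv1.Job (existingObj is a *batchv1.Job) while inference keeps using Deployments and StatefulSets, and both paths now share the renamed ReadinessTimeout. For orientation, a rough sketch of the kind of completion check CheckResourceStatus might run against such a Job; the real helper lives in pkg/resources and is not shown in this diff, so treat this as an assumption:

package resources

import (
	batchv1 "k8s.io/api/batch/v1"
	corev1 "k8s.io/api/core/v1"
)

// jobSucceeded is a hypothetical predicate: it reports whether a tuning Job
// has recorded a JobComplete condition with status True.
func jobSucceeded(job *batchv1.Job) bool {
	for _, cond := range job.Status.Conditions {
		if cond.Type == batchv1.JobComplete && cond.Status == corev1.ConditionTrue {
			return true
		}
	}
	return false
}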
10 changes: 5 additions & 5 deletions pkg/inference/preset-inferences.go

@@ -67,7 +67,7 @@
 	}
 )
 
-func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) error {
+func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient client.Client, wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) error {
 	existingService := &corev1.Service{}
 	err := resources.GetResource(ctx, wObj.Name, wObj.Namespace, kubeClient, existingService)
 	if err != nil {
@@ -92,7 +92,7 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl
 	return nil
 }
 
-func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) (string, []corev1.LocalObjectReference) {
+func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) (string, []corev1.LocalObjectReference) {
 	imageName := string(workspaceObj.Inference.Preset.Name)
 	imageTag := inferenceObj.Tag
 	imagePullSecretRefs := []corev1.LocalObjectReference{}
@@ -110,7 +110,7 @@ func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, in
 }
 
 func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
-	inferenceObj *model.PresetInferenceParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) {
+	inferenceObj *model.PresetParam, supportDistributedInference bool, kubeClient client.Client) (client.Object, error) {
 	if inferenceObj.TorchRunParams != nil && supportDistributedInference {
 		if err := updateTorchParamsForDistributedInference(ctx, kubeClient, workspaceObj, inferenceObj); err != nil {
 			klog.ErrorS(err, "failed to update torch params", "workspace", workspaceObj)
@@ -141,7 +141,7 @@
 // torchrun <TORCH_PARAMS> <OPTIONAL_RDZV_PARAMS> baseCommand <MODEL_PARAMS>
 // and sets the GPU resources required for inference.
 // Returns the command and resource configuration.
-func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetInferenceParam) ([]string, corev1.ResourceRequirements) {
+func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam) ([]string, corev1.ResourceRequirements) {
 	torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams)
 	torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams)
 	modelCommand := buildCommandStr(InferenceFile, inferenceObj.ModelRunParams)
@@ -159,7 +159,7 @@
 	return commands, resourceRequirements
 }
 
-func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetInferenceParam) ([]corev1.Volume, []corev1.VolumeMount) {
+func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) ([]corev1.Volume, []corev1.VolumeMount) {
 	volume := []corev1.Volume{}
 	volumeMount := []corev1.VolumeMount{}
 
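
prepareInferenceParameters leans on buildCommandStr, whose body is outside this diff. A plausible reconstruction of its behavior, assuming each map entry is appended as a --key=value flag (the exact flag syntax and empty-value handling are assumptions):

package inference

import (
	"fmt"
	"strings"
)

// buildCommandStrSketch mirrors what buildCommandStr appears to do: start from
// a base command and append each map entry as a --key=value flag, emitting
// bare --key switches for empty values. Go map iteration order is unspecified,
// so the real helper may well sort keys first.
func buildCommandStrSketch(baseCommand string, params map[string]string) string {
	cmd := baseCommand
	for k, v := range params {
		if v == "" {
			cmd = fmt.Sprintf("%s --%s", cmd, k)
		} else {
			cmd = fmt.Sprintf("%s --%s=%s", cmd, k, v)
		}
	}
	return strings.TrimSpace(cmd)
}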
2 changes: 1 addition & 1 deletion pkg/inference/preset-inferences_test.go

@@ -62,7 +62,7 @@ func TestCreatePresetInference(t *testing.T) {
 
 			useHeadlessSvc := false
 
-			var inferenceObj *model.PresetInferenceParam
+			var inferenceObj *model.PresetParam
 			model := plugin.KaitoModelRegister.MustGet(tc.modelName)
 			inferenceObj = model.GetInferenceParameters()
 
24 changes: 13 additions & 11 deletions pkg/model/interface.go

@@ -7,27 +7,29 @@ import (
 )
 
 type Model interface {
-	GetInferenceParameters() *PresetInferenceParam
+	GetInferenceParameters() *PresetParam
+	GetTuningParameters() *PresetParam
 	SupportDistributedInference() bool //If true, the model workload will be a StatefulSet, using the torch elastic runtime framework.
+	SupportTuning() bool
 }
 
-// PresetInferenceParam defines the preset inference parameters for a model.
-type PresetInferenceParam struct {
+// PresetParam defines the preset parameters for a model.
+type PresetParam struct {
 	ModelFamilyName           string // The name of the model family.
 	ImageAccessMode           string // Defines whether the Image is Public or Private.
 	DiskStorageRequirement    string // Disk storage requirements for the model.
-	GPUCountRequirement       string // Number of GPUs required for the inference.
-	TotalGPUMemoryRequirement string // Total GPU memory required for the inference.
+	GPUCountRequirement       string // Number of GPUs required for the Preset.
+	TotalGPUMemoryRequirement string // Total GPU memory required for the Preset.
 	PerGPUMemoryRequirement   string // GPU memory required per GPU.
 	TorchRunParams            map[string]string // Parameters for configuring the torchrun command.
-	TorchRunRdzvParams        map[string]string // Optional rendezvous parameters for distributed inference using torchrun (elastic).
-	ModelRunParams            map[string]string // Parameters for running the model inference.
-	// DeploymentTimeout defines the maximum duration for pulling the Preset image.
+	TorchRunRdzvParams        map[string]string // Optional rendezvous parameters for distributed training/inference using torchrun (elastic).
+	// BaseCommand is the initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
+	BaseCommand    string
+	ModelRunParams map[string]string // Parameters for running the model training/inference.
+	// ReadinessTimeout defines the maximum duration for creating the workload.
 	// This timeout accommodates the size of the image, ensuring pull completion
 	// even under slower network conditions or unforeseen delays.
-	DeploymentTimeout time.Duration
-	// BaseCommand is the initial command (e.g., 'torchrun', 'accelerate launch') used in the command line.
-	BaseCommand string
+	ReadinessTimeout time.Duration
 	// WorldSize defines the number of processes required for distributed inference.
 	WorldSize int
 	Tag       string // The model image tag
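
To make the widened interface concrete, here is a hedged sketch of a preset implementing all four methods. The falcon7b type and its parameter values are illustrative only and do not reproduce the commit's actual preset code:

package presets

import (
	"time"

	"github.com/azure/kaito/pkg/model"
)

// falcon7b is a hypothetical preset implementing the updated Model interface.
type falcon7b struct{}

func (*falcon7b) GetInferenceParameters() *model.PresetParam {
	return &model.PresetParam{
		ModelFamilyName:           "falcon",
		GPUCountRequirement:       "1",
		TotalGPUMemoryRequirement: "14Gi",
		BaseCommand:               "accelerate launch",
		ReadinessTimeout:          30 * time.Minute,
	}
}

func (*falcon7b) GetTuningParameters() *model.PresetParam {
	return &model.PresetParam{
		ModelFamilyName:     "falcon",
		GPUCountRequirement: "2",
		BaseCommand:         "accelerate launch",
		ReadinessTimeout:    30 * time.Minute,
	}
}

func (*falcon7b) SupportDistributedInference() bool { return false }
func (*falcon7b) SupportTuning() bool               { return true }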
21 changes: 21 additions & 0 deletions pkg/tuning/preset-tuning-types.go

@@ -0,0 +1,21 @@
+package tuning
+
+import corev1 "k8s.io/api/core/v1"
+
+const (
+	DefaultNumProcesses = "1"
+	DefaultNumMachines  = "1"
+	DefaultMachineRank  = "0"
+	DefaultGPUIds       = "all"
+)
+
+var (
+	DefaultAccelerateParams = map[string]string{
+		"num_processes": DefaultNumProcesses,
+		"num_machines":  DefaultNumMachines,
+		"machine_rank":  DefaultMachineRank,
+		"gpu_ids":       DefaultGPUIds,
+	}
+
+	DefaultImagePullSecrets = []corev1.LocalObjectReference{}
+)
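
These defaults pair with the 'accelerate launch' BaseCommand called out in PresetParam. Below is a small sketch of how they might be flattened into a launch command, with keys sorted so the output is deterministic; the rendering function is an assumption, not part of this commit:

package tuning

import (
	"fmt"
	"sort"
	"strings"
)

// renderAccelerateCommand is hypothetical: applied to DefaultAccelerateParams
// it yields "accelerate launch --gpu_ids=all --machine_rank=0
// --num_machines=1 --num_processes=1".
func renderAccelerateCommand(params map[string]string) string {
	keys := make([]string, 0, len(params))
	for k := range params {
		keys = append(keys, k)
	}
	sort.Strings(keys) // deterministic flag order

	parts := []string{"accelerate launch"}
	for _, k := range keys {
		parts = append(parts, fmt.Sprintf("--%s=%s", k, params[k]))
	}
	return strings.Join(parts, " ")
}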
14 changes: 14 additions & 0 deletions pkg/tuning/preset-tuning.go

@@ -0,0 +1,14 @@
+package tuning
+
+import (
+	"context"
+	kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
+	"github.com/azure/kaito/pkg/model"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+func CreatePresetTuning(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
+	tuningObj *model.PresetParam, kubeClient client.Client) (client.Object, error) {
+	// TODO
+	return nil, nil
+}
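
CreatePresetTuning is deliberately left as a stub in this commit (Part 3 initializes the interface; the Job construction lands later). For orientation only, a hedged sketch of the kind of Job it might eventually create; every detail below (names, image, command assembly) is an assumption:

package tuning

import (
	"context"

	kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
	"github.com/azure/kaito/pkg/model"
	batchv1 "k8s.io/api/batch/v1"
	corev1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"sigs.k8s.io/controller-runtime/pkg/client"
)

// createPresetTuningSketch is a hypothetical elaboration of the TODO above,
// not the commit's code: it builds a one-shot Job named after the workspace
// and submits it through the controller-runtime client.
func createPresetTuningSketch(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
	tuningObj *model.PresetParam, kubeClient client.Client) (client.Object, error) {
	job := &batchv1.Job{
		ObjectMeta: metav1.ObjectMeta{
			Name:      workspaceObj.Name,
			Namespace: workspaceObj.Namespace,
		},
		Spec: batchv1.JobSpec{
			Template: corev1.PodTemplateSpec{
				Spec: corev1.PodSpec{
					RestartPolicy: corev1.RestartPolicyNever,
					Containers: []corev1.Container{{
						Name:    workspaceObj.Name,
						Image:   "registry.example.com/kaito-tuning:" + tuningObj.Tag, // placeholder image reference
						Command: []string{"/bin/sh", "-c", tuningObj.BaseCommand},
					}},
				},
			},
		},
	}
	if err := kubeClient.Create(ctx, job); err != nil {
		return nil, err
	}
	return job, nil
}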
30 changes: 24 additions & 6 deletions pkg/utils/testModel.go

@@ -12,27 +12,45 @@ import (
 
 type testModel struct{}
 
-func (*testModel) GetInferenceParameters() *model.PresetInferenceParam {
-	return &model.PresetInferenceParam{
+func (*testModel) GetInferenceParameters() *model.PresetParam {
+	return &model.PresetParam{
 		GPUCountRequirement: "1",
-		DeploymentTimeout:   time.Duration(30) * time.Minute,
+		ReadinessTimeout:    time.Duration(30) * time.Minute,
+	}
+}
+func (*testModel) GetTuningParameters() *model.PresetParam {
+	return &model.PresetParam{
+		GPUCountRequirement: "1",
+		ReadinessTimeout:    time.Duration(30) * time.Minute,
 	}
 }
 func (*testModel) SupportDistributedInference() bool {
 	return false
 }
+func (*testModel) SupportTuning() bool {
+	return true
+}
 
 type testDistributedModel struct{}
 
-func (*testDistributedModel) GetInferenceParameters() *model.PresetInferenceParam {
-	return &model.PresetInferenceParam{
+func (*testDistributedModel) GetInferenceParameters() *model.PresetParam {
+	return &model.PresetParam{
 		GPUCountRequirement: "1",
-		DeploymentTimeout:   time.Duration(30) * time.Minute,
+		ReadinessTimeout:    time.Duration(30) * time.Minute,
+	}
+}
+func (*testDistributedModel) GetTuningParameters() *model.PresetParam {
+	return &model.PresetParam{
+		GPUCountRequirement: "1",
+		ReadinessTimeout:    time.Duration(30) * time.Minute,
 	}
 }
 func (*testDistributedModel) SupportDistributedInference() bool {
 	return true
 }
+func (*testDistributedModel) SupportTuning() bool {
+	return true
+}
 
 func RegisterTestModel() {
 	var test testModel
8 changes: 4 additions & 4 deletions presets/models/falcon/README.md

@@ -1,10 +1,10 @@
 ## Supported Models
 |Model name| Model source | Sample workspace|Kubernetes Workload|Distributed inference|
 |----|:----:|:----:| :----: |:----: |
-|falcon-7b-instruct |[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false|
-|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/kaito_workspace_falcon_7b.yaml)|Deployment| false|
-|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false|
-|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/kaito_workspace_falcon_40b.yaml)|Deployment| false|
+|falcon-7b-instruct |[tiiuae](https://huggingface.co/tiiuae/falcon-7b-instruct)|[link](../../../examples/inference/kaito_workspace_falcon_7b-instruct.yaml)|Deployment| false|
+|falcon-7b |[tiiuae](https://huggingface.co/tiiuae/falcon-7b) |[link](../../../examples/inference/kaito_workspace_falcon_7b.yaml)|Deployment| false|
+|falcon-40b-instruct|[tiiuae](https://huggingface.co/tiiuae/falcon-40b-instruct) |[link](../../../examples/inference/kaito_workspace_falcon_40b-instruct.yaml)|Deployment| false|
+|falcon-40b |[tiiuae](https://huggingface.co/tiiuae/falcon-40b)|[link](../../../examples/inference/kaito_workspace_falcon_40b.yaml)|Deployment| false|
 
 ## Image Source
 - **Public**: Kaito maintainers manage the lifecycle of the inference service images that contain model weights. The images are available in Microsoft Container Registry (MCR).