chore: Factoring out reusable presets logic - Part 4 #332

Merged (4 commits) on Apr 2, 2024
Changes from all commits
92 changes: 29 additions & 63 deletions pkg/inference/preset-inferences.go
@@ -5,6 +5,7 @@ package inference
 import (
     "context"
     "fmt"
+    "github.com/azure/kaito/pkg/utils"
     "os"
     "strconv"
 
@@ -19,10 +20,9 @@ import (
 )
 
 const (
-    ProbePath              = "/healthz"
-    Port5000               = int32(5000)
-    InferenceFile          = "inference_api.py"
-    DefaultVolumeMountPath = "/dev/shm"
+    ProbePath     = "/healthz"
+    Port5000      = int32(5000)
+    InferenceFile = "inference_api.py"
 )
 
 var (
@@ -92,21 +92,21 @@ func updateTorchParamsForDistributedInference(ctx context.Context, kubeClient cl
     return nil
 }
 
-func GetImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) (string, []corev1.LocalObjectReference) {
-    imageName := string(workspaceObj.Inference.Preset.Name)
-    imageTag := inferenceObj.Tag
+func GetInferenceImageInfo(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, presetObj *model.PresetParam) (string, []corev1.LocalObjectReference) {
     imagePullSecretRefs := []corev1.LocalObjectReference{}
-    if inferenceObj.ImageAccessMode == "private" {
-        imageName = string(workspaceObj.Inference.Preset.PresetOptions.Image)
+    if presetObj.ImageAccessMode == "private" {
+        imageName := workspaceObj.Inference.Preset.PresetOptions.Image
         for _, secretName := range workspaceObj.Inference.Preset.PresetOptions.ImagePullSecrets {
             imagePullSecretRefs = append(imagePullSecretRefs, corev1.LocalObjectReference{Name: secretName})
         }
         return imageName, imagePullSecretRefs
+    } else {
+        imageName := string(workspaceObj.Inference.Preset.Name)
+        imageTag := presetObj.Tag
+        registryName := os.Getenv("PRESET_REGISTRY_NAME")
+        imageName = fmt.Sprintf("%s/kaito-%s:%s", registryName, imageName, imageTag)
+        return imageName, imagePullSecretRefs
     }
-
-    registryName := os.Getenv("PRESET_REGISTRY_NAME")
-    imageName = registryName + fmt.Sprintf("/kaito-%s:%s", imageName, imageTag)
-    return imageName, imagePullSecretRefs
 }
 
 func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace,
@@ -118,17 +118,25 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
         }
     }
 
-    volume, volumeMount := configVolume(workspaceObj, inferenceObj)
+    var volumes []corev1.Volume
+    var volumeMounts []corev1.VolumeMount
+    volume, volumeMount := utils.ConfigSHMVolume(workspaceObj)
+    if volume.Name != "" {
+        volumes = append(volumes, volume)
+    }
+    if volumeMount.Name != "" {
+        volumeMounts = append(volumeMounts, volumeMount)
+    }
     commands, resourceReq := prepareInferenceParameters(ctx, inferenceObj)
-    image, imagePullSecrets := GetImageInfo(ctx, workspaceObj, inferenceObj)
+    image, imagePullSecrets := GetInferenceImageInfo(ctx, workspaceObj, inferenceObj)
 
     var depObj client.Object
     if supportDistributedInference {
         depObj = resources.GenerateStatefulSetManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands,
-            containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount)
+            containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts)
     } else {
         depObj = resources.GenerateDeploymentManifest(ctx, workspaceObj, image, imagePullSecrets, *workspaceObj.Resource.Count, commands,
-            containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volume, volumeMount)
+            containerPorts, livenessProbe, readinessProbe, resourceReq, tolerations, volumes, volumeMounts)
     }
     err := resources.CreateResource(ctx, depObj, kubeClient)
     if client.IgnoreAlreadyExists(err) != nil {
@@ -142,10 +150,10 @@ func CreatePresetInference(ctx context.Context, workspaceObj *kaitov1alpha1.Work
 // and sets the GPU resources required for inference.
 // Returns the command and resource configuration.
 func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetParam) ([]string, corev1.ResourceRequirements) {
-    torchCommand := buildCommandStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams)
-    torchCommand = buildCommandStr(torchCommand, inferenceObj.TorchRunRdzvParams)
-    modelCommand := buildCommandStr(InferenceFile, inferenceObj.ModelRunParams)
-    commands := shellCommand(torchCommand + " " + modelCommand)
+    torchCommand := utils.BuildCmdStr(inferenceObj.BaseCommand, inferenceObj.TorchRunParams)
+    torchCommand = utils.BuildCmdStr(torchCommand, inferenceObj.TorchRunRdzvParams)
+    modelCommand := utils.BuildCmdStr(InferenceFile, inferenceObj.ModelRunParams)
+    commands := utils.ShellCmd(torchCommand + " " + modelCommand)
 
     resourceRequirements := corev1.ResourceRequirements{
         Requests: corev1.ResourceList{
@@ -158,45 +166,3 @@ func prepareInferenceParameters(ctx context.Context, inferenceObj *model.PresetP
 
     return commands, resourceRequirements
 }
-
-func configVolume(wObj *kaitov1alpha1.Workspace, inferenceObj *model.PresetParam) ([]corev1.Volume, []corev1.VolumeMount) {
-    volume := []corev1.Volume{}
-    volumeMount := []corev1.VolumeMount{}
-
-    // Signifies multinode inference requirement
-    if *wObj.Resource.Count > 1 {
-        // Append share memory volume to any existing volumes
-        volume = append(volume, corev1.Volume{
-            Name: "dshm",
-            VolumeSource: corev1.VolumeSource{
-                EmptyDir: &corev1.EmptyDirVolumeSource{
-                    Medium: "Memory",
-                },
-            },
-        })
-
-        volumeMount = append(volumeMount, corev1.VolumeMount{
-            Name:      volume[0].Name,
-            MountPath: DefaultVolumeMountPath,
-        })
-    }
-
-    return volume, volumeMount
-}
-
-func shellCommand(command string) []string {
-    return []string{
-        "/bin/sh",
-        "-c",
-        command,
-    }
-}
-
-func buildCommandStr(baseCommand string, torchRunParams map[string]string) string {
-    updatedBaseCommand := baseCommand
-    for key, value := range torchRunParams {
-        updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
-    }
-
-    return updatedBaseCommand
-}
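
Aside: a minimal sketch (not part of the diff) of the two image-resolution paths that the renamed GetInferenceImageInfo implements above. The registry, preset name, and tag values are invented for illustration.

package main

import "fmt"

func main() {
    // Public preset: the image reference is composed from the
    // PRESET_REGISTRY_NAME environment variable, the preset name,
    // and the preset tag.
    registry := "mcr.example.io" // assumed value of PRESET_REGISTRY_NAME
    presetName, tag := "falcon-7b-instruct", "0.0.1"
    fmt.Printf("%s/kaito-%s:%s\n", registry, presetName, tag)
    // mcr.example.io/kaito-falcon-7b-instruct:0.0.1

    // Private preset: the workspace supplies the full reference through
    // PresetOptions.Image plus any image pull secrets; nothing is composed.
    fmt.Println("myregistry.example.io/custom-model:latest") // assumed value
}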
72 changes: 72 additions & 0 deletions pkg/utils/common-preset.go
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+package utils
+
+import (
+    "fmt"
+    kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
+    corev1 "k8s.io/api/core/v1"
+)
+
+const (
+    DefaultVolumeMountPath = "/dev/shm"
+)
+
+func ConfigSHMVolume(wObj *kaitov1alpha1.Workspace) (corev1.Volume, corev1.VolumeMount) {
+    volume := corev1.Volume{}
+    volumeMount := corev1.VolumeMount{}
+
+    // Signifies multinode inference requirement
+    if *wObj.Resource.Count > 1 {
+        // Append share memory volume to any existing volumes
+        volume = corev1.Volume{
+            Name: "dshm",
+            VolumeSource: corev1.VolumeSource{
+                EmptyDir: &corev1.EmptyDirVolumeSource{
+                    Medium: "Memory",
+                },
+            },
+        }
+
+        volumeMount = corev1.VolumeMount{
+            Name:      volume.Name,
+            MountPath: DefaultVolumeMountPath,
+        }
+    }
+
+    return volume, volumeMount
+}
+
+func ConfigDataVolume() ([]corev1.Volume, []corev1.VolumeMount) {
+    var volumes []corev1.Volume
+    var volumeMounts []corev1.VolumeMount
+    volumes = append(volumes, corev1.Volume{
+        Name: "data-volume",
+        VolumeSource: corev1.VolumeSource{
+            EmptyDir: &corev1.EmptyDirVolumeSource{},
+        },
+    })
+
+    volumeMounts = append(volumeMounts, corev1.VolumeMount{
+        Name:      "data-volume",
+        MountPath: "/data",
+    })
+    return volumes, volumeMounts
+}
+
+func ShellCmd(command string) []string {
+    return []string{
+        "/bin/sh",
+        "-c",
+        command,
+    }
+}
+
+func BuildCmdStr(baseCommand string, torchRunParams map[string]string) string {
+    updatedBaseCommand := baseCommand
+    for key, value := range torchRunParams {
+        updatedBaseCommand = fmt.Sprintf("%s --%s=%s", updatedBaseCommand, key, value)
+    }
+
+    return updatedBaseCommand
+}
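
A minimal caller-side sketch of the relocated helpers, mirroring how prepareInferenceParameters composes the container entrypoint; the torchrun and model parameters here are invented for illustration.

package main

import (
    "fmt"

    "github.com/azure/kaito/pkg/utils"
)

func main() {
    // BuildCmdStr appends each map entry as a --key=value flag. Map iteration
    // order is unspecified in Go, so flag order can vary across runs.
    torchCmd := utils.BuildCmdStr("torchrun", map[string]string{"nnodes": "2"})
    modelCmd := utils.BuildCmdStr("inference_api.py", map[string]string{"pipeline": "text-generation"})

    // ShellCmd wraps the combined command for /bin/sh -c execution.
    fmt.Println(utils.ShellCmd(torchCmd + " " + modelCmd))
    // [/bin/sh -c torchrun --nnodes=2 inference_api.py --pipeline=text-generation]
}

Note that ConfigSHMVolume returns zero-value structs for single-node workspaces, which is why CreatePresetInference checks volume.Name != "" before appending. The "dshm" EmptyDir with Medium: Memory backs /dev/shm, presumably because multinode inference needs more shared memory than the container runtime's default allocation.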