Skip to content

Commit

Permalink
feat: Part 2 - Add validation checks for TuningSpec, DataSource, Data…
Browse files Browse the repository at this point in the history
…Destination (#304)

Open to feedback. Here are the validation checks for the new CRD.

Code Coverage: 87.7% - workspace_validation.go


[coverage.txt](https://github.com/Azure/kaito/files/14661460/coverage.txt)
  • Loading branch information
ishaansehgal99 authored Mar 20, 2024
1 parent 2485c58 commit 4ba337d
Show file tree
Hide file tree
Showing 9 changed files with 986 additions and 49 deletions.
14 changes: 7 additions & 7 deletions api/v1alpha1/workspace_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ const (
ModelImageAccessModePrivate ModelImageAccessMode = "private"
)

// ResourceSpec desicribes the resource requirement of running the workload.
// ResourceSpec describes the resource requirement of running the workload.
// If the number of nodes in the cluster that meet the InstanceType and
// LabelSelector requirements is smaller than the Count, controller
// will provision new nodes before deploying the workload.
Expand Down Expand Up @@ -51,7 +51,7 @@ type PresetMeta struct {
// AccessMode specifies whether the containerized model image is accessible via public registry
// or private registry. This field defaults to "public" if not specified.
// If this field is "private", user needs to provide the private image information in PresetOptions.
// +bebuilder:default:="public"
// +kubebuilder:default:="public"
// +optional
AccessMode ModelImageAccessMode `json:"accessMode,omitempty"`
}
Expand Down Expand Up @@ -106,7 +106,7 @@ type DataSource struct {
// URLs specifies the links to the public data sources. E.g., files in a public github repository.
// +optional
URLs []string `json:"urls,omitempty"`
// The directory in the hsot that contains the data.
// The directory in the host that contains the data.
// +optional
HostPath string `json:"hostPath,omitempty"`
// The name of the image that contains the source data. The assumption is that the source data locates in the
Expand Down Expand Up @@ -150,9 +150,9 @@ type TuningSpec struct {
// +optional
Config string `json:"config,omitempty"`
// Input describes the input used by the tuning method.
Input *DataSource `json:"input,omitempty"`
Input *DataSource `json:"input"`
// Output specified where to store the tuning output.
Output *DataDestination `json:"output,omitempty"`
Output *DataDestination `json:"output"`
}

// WorkspaceStatus defines the observed state of Workspace
Expand Down Expand Up @@ -181,8 +181,8 @@ type Workspace struct {
metav1.ObjectMeta `json:"metadata,omitempty"`

Resource ResourceSpec `json:"resource,omitempty"`
Inference InferenceSpec `json:"inference,omitempty"`
Tuning TuningSpec `json:"tuning,omitempty"`
Inference *InferenceSpec `json:"inference,omitempty"`
Tuning *TuningSpec `json:"tuning,omitempty"`
Status WorkspaceStatus `json:"status,omitempty"`
}

Expand Down
179 changes: 174 additions & 5 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"context"
"fmt"
"reflect"
"sort"
"strings"

"github.com/azure/kaito/pkg/utils/plugin"
Expand Down Expand Up @@ -35,16 +36,184 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
if base == nil {
klog.InfoS("Validate creation", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
errs = errs.Also(
w.Inference.validateCreate().ViaField("inference"),
w.Resource.validateCreate(w.Inference).ViaField("resource"),
w.validateCreate().ViaField("spec"),
// TODO: Consider validate resource based on Tuning Spec
w.Resource.validateCreate(*w.Inference).ViaField("resource"),
)
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Inference.validateCreate().ViaField("inference"))
}
if w.Tuning != nil {
errs = errs.Also(w.Tuning.validateCreate().ViaField("tuning"))
}
} else {
klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
old := base.(*Workspace)
errs = errs.Also(
w.validateUpdate(old).ViaField("spec"),
w.Resource.validateUpdate(&old.Resource).ViaField("resource"),
w.Inference.validateUpdate(&old.Inference).ViaField("inference"),
)
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Inference.validateUpdate(old.Inference).ViaField("inference"))
}
if w.Tuning != nil {
errs = errs.Also(w.Tuning.validateUpdate(old.Tuning).ViaField("tuning"))
}
}
return errs
}

// validateCreate checks the workspace-level invariant enforced at creation
// time: exactly one of the Inference and Tuning sections must be present.
// It returns nil when the invariant holds.
func (w *Workspace) validateCreate() (errs *apis.FieldError) {
	hasInference, hasTuning := w.Inference != nil, w.Tuning != nil
	switch {
	case !hasInference && !hasTuning:
		// Neither section was provided.
		errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, not neither", ""))
	case hasInference && hasTuning:
		// Both sections were provided.
		errs = errs.Also(apis.ErrGeneric("Either Inference or Tuning must be specified, but not both", ""))
	}
	return errs
}

// validateUpdate rejects updates that add or remove the Inference or Tuning
// section relative to the previous version of the workspace (old).
// The contents of each section are validated by the section's own
// validateUpdate method, not here.
func (w *Workspace) validateUpdate(old *Workspace) (errs *apis.FieldError) {
	// A section is "toggled" when it is present in exactly one of the two versions.
	inferenceToggled := (old.Inference == nil) != (w.Inference == nil)
	if inferenceToggled {
		errs = errs.Also(apis.ErrGeneric("Inference field cannot be toggled once set", "inference"))
	}

	tuningToggled := (old.Tuning == nil) != (w.Tuning == nil)
	if tuningToggled {
		errs = errs.Also(apis.ErrGeneric("Tuning field cannot be toggled once set", "tuning"))
	}
	return errs
}

// validateCreate validates a TuningSpec at creation time. Errors are
// accumulated (not short-circuited) in this order: Input, Output, Preset,
// Method.
func (r *TuningSpec) validateCreate() (errs *apis.FieldError) {
	// Input is mandatory; when present it is validated as a DataSource.
	if r.Input != nil {
		errs = errs.Also(r.Input.validateCreate().ViaField("Input"))
	} else {
		errs = errs.Also(apis.ErrMissingField("Input"))
	}

	// Output is mandatory; when present it is validated as a DataDestination.
	if r.Output != nil {
		errs = errs.Also(r.Output.validateCreate().ViaField("Output"))
	} else {
		errs = errs.Also(apis.ErrMissingField("Output"))
	}

	// A preset is currently required; in the future a template could be
	// considered as an alternative.
	switch {
	case r.Preset == nil:
		errs = errs.Also(apis.ErrMissingField("Preset"))
	case !isValidPreset(string(r.Preset.Name)):
		name := string(r.Preset.Name)
		errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported tuning preset name %s", name), "presetName"))
	}

	// The method is matched case-insensitively against the supported methods.
	method := strings.ToLower(string(r.Method))
	supported := method == string(TuningMethodLora) || method == string(TuningMethodQLora)
	if !supported {
		errs = errs.Also(apis.ErrInvalidValue(r.Method, "Method"))
	}
	return errs
}

// validateUpdate validates an updated TuningSpec (r) against its previous
// version (old). Input and Output must still be present, and Preset and
// Method (compared case-insensitively) must not change.
//
// old may be nil: the caller invokes this method whenever the new workspace
// has a Tuning section, including when the previous workspace had none. In
// that case r is compared against an empty TuningSpec, so every populated
// immutable field is reported as changed instead of dereferencing a nil
// pointer.
func (r *TuningSpec) validateUpdate(old *TuningSpec) (errs *apis.FieldError) {
	if old == nil {
		old = &TuningSpec{}
	}

	switch {
	case r.Input == nil:
		errs = errs.Also(apis.ErrMissingField("Input"))
	case old.Input == nil:
		// No previous input to compare against; validate the new one as if
		// it were being created.
		errs = errs.Also(r.Input.validateCreate().ViaField("Input"))
	default:
		errs = errs.Also(r.Input.validateUpdate(old.Input, true).ViaField("Input"))
	}

	switch {
	case r.Output == nil:
		errs = errs.Also(apis.ErrMissingField("Output"))
	case old.Output == nil:
		// No previous output to compare against; validate the new one as if
		// it were being created.
		errs = errs.Also(r.Output.validateCreate().ViaField("Output"))
	default:
		errs = errs.Also(r.Output.validateUpdate(old.Output).ViaField("Output"))
	}

	// Preset is a pointer, so a deep comparison is needed.
	if !reflect.DeepEqual(old.Preset, r.Preset) {
		errs = errs.Also(apis.ErrGeneric("Preset cannot be changed", "Preset"))
	}

	// Method is immutable, ignoring case.
	if strings.ToLower(string(old.Method)) != strings.ToLower(string(r.Method)) {
		errs = errs.Also(apis.ErrGeneric("Method cannot be changed", "Method"))
	}
	// Consider supporting config fields changing
	return errs
}

// validateCreate ensures that exactly one of URLs, HostPath, or Image is set
// on the DataSource. It returns nil when that is the case.
func (r *DataSource) validateCreate() (errs *apis.FieldError) {
	// One flag per mutually exclusive source kind.
	provided := []bool{
		len(r.URLs) > 0,
		r.HostPath != "",
		r.Image != "",
	}

	count := 0
	for _, isSet := range provided {
		if isSet {
			count++
		}
	}

	if count == 1 {
		return errs
	}
	return errs.Also(apis.ErrGeneric("Exactly one of URLs, HostPath, or Image must be specified", "URLs", "HostPath", "Image"))
}

// validateUpdate verifies that the immutable fields of a DataSource did not
// change between the previous version (old) and the updated one (r).
//
// URLs and ImagePullSecrets are compared as unordered collections: both sides
// are copied and sorted first, so reordering entries is not a change and the
// caller's slices are never mutated. Name is only checked when isTuning is
// true.
//
// A nil old is treated as an empty DataSource, so every populated field on r
// is reported as changed instead of causing a nil-pointer dereference.
func (r *DataSource) validateUpdate(old *DataSource, isTuning bool) (errs *apis.FieldError) {
	if old == nil {
		old = &DataSource{}
	}

	// sortedCopy returns a sorted copy of in, leaving in untouched.
	sortedCopy := func(in []string) []string {
		out := make([]string, len(in))
		copy(out, in)
		sort.Strings(out)
		return out
	}

	if isTuning && !reflect.DeepEqual(old.Name, r.Name) {
		errs = errs.Also(apis.ErrInvalidValue("During tuning Name field cannot be changed once set", "Name"))
	}
	if !reflect.DeepEqual(sortedCopy(old.URLs), sortedCopy(r.URLs)) {
		errs = errs.Also(apis.ErrInvalidValue("URLs field cannot be changed once set", "URLs"))
	}
	if old.HostPath != r.HostPath {
		errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath"))
	}
	if old.Image != r.Image {
		errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image"))
	}
	if !reflect.DeepEqual(sortedCopy(old.ImagePullSecrets), sortedCopy(r.ImagePullSecrets)) {
		errs = errs.Also(apis.ErrInvalidValue("ImagePullSecrets field cannot be changed once set", "ImagePullSecrets"))
	}
	return errs
}

// validateCreate ensures that a DataDestination names at least one target:
// HostPath, Image, or both.
//
// The error is built with apis.ErrGeneric so that the human-readable message
// and the offending field paths stay separate. apis.ErrMissingField treats
// its arguments as field paths, so passing the sentence there would make the
// sentence itself a path and have it prefixed by any later ViaField call.
func (r *DataDestination) validateCreate() (errs *apis.FieldError) {
	if r.HostPath == "" && r.Image == "" {
		errs = errs.Also(apis.ErrGeneric("At least one of HostPath or Image must be specified", "HostPath", "Image"))
	}
	return errs
}

// validateUpdate verifies that HostPath, Image, and ImagePushSecret did not
// change between the previous version (old) and the updated one (r).
//
// A nil old is treated as an empty DataDestination, so every populated field
// on r is reported as changed instead of causing a nil-pointer dereference.
func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.FieldError) {
	if old == nil {
		old = &DataDestination{}
	}

	if old.HostPath != r.HostPath {
		errs = errs.Also(apis.ErrInvalidValue("HostPath field cannot be changed once set", "HostPath"))
	}
	if old.Image != r.Image {
		errs = errs.Also(apis.ErrInvalidValue("Image field cannot be changed once set", "Image"))
	}
	if old.ImagePushSecret != r.ImagePushSecret {
		errs = errs.Also(apis.ErrInvalidValue("ImagePushSecret field cannot be changed once set", "ImagePushSecret"))
	}
	return errs
}
Expand Down Expand Up @@ -131,7 +300,7 @@ func (i *InferenceSpec) validateCreate() (errs *apis.FieldError) {
presetName := string(i.Preset.Name)
// Validate preset name
if !isValidPreset(presetName) {
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported preset name %s", presetName), "presetName"))
errs = errs.Also(apis.ErrInvalidValue(fmt.Sprintf("Unsupported inference preset name %s", presetName), "presetName"))
}
// Validate private preset has private image specified
if plugin.KaitoModelRegister.MustGet(string(i.Preset.Name)).GetInferenceParameters().ImageAccessMode == "private" &&
Expand All @@ -151,7 +320,7 @@ func (i *InferenceSpec) validateUpdate(old *InferenceSpec) (errs *apis.FieldErro
if !reflect.DeepEqual(i.Preset, old.Preset) {
errs = errs.Also(apis.ErrGeneric("field is immutable", "preset"))
}
//inference.template can be changed, but cannot be unset.
// inference.template can be changed, but cannot be set/unset.
if (i.Template != nil && old.Template == nil) || (i.Template == nil && old.Template != nil) {
errs = errs.Also(apis.ErrGeneric("field cannot be unset/set if it was set/unset", "template"))
}
Expand Down
Loading

0 comments on commit 4ba337d

Please sign in to comment.