diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index 7f585eede..225b747d4 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -33,6 +33,15 @@ var gpuCountRequirement string var totalGPUMemoryRequirement string var perGPUMemoryRequirement string +var invalidSourceName string + +func init() { + // Define a invalid source name longer than 253 + for i := 0; i < 32; i++ { + invalidSourceName += "Adapter1" + } +} + type testModel struct{} func (*testModel) GetInferenceParameters() *model.PresetParam { @@ -626,6 +635,18 @@ func TestAdapterSpecValidateCreateorUpdate(t *testing.T) { errContent: "Strength value for Adapter 'Adapter-1' must be between 0 and 1", expectErrs: true, }, + { + name: "Invalid Source Name, longer than 253", + adapterSpec: &AdapterSpec{ + Source: &DataSource{ + Name: invalidSourceName, + Image: "fake.kaito.com/kaito-image:0.0.1", + }, + Strength: &ValidStrength, + }, + errContent: "Name of Adapter must be a valid DNS subdomain value", + expectErrs: true, + }, { name: "Valid Adapter", adapterSpec: &AdapterSpec{ @@ -633,7 +654,6 @@ func TestAdapterSpecValidateCreateorUpdate(t *testing.T) { Name: "Adapter-1", Image: "fake.kaito.com/kaito-image:0.0.1", }, - Strength: &ValidStrength, }, errContent: "", expectErrs: false, @@ -709,6 +729,31 @@ func TestInferenceSpecValidateUpdate(t *testing.T) { errContent: "field cannot be unset/set if it was set/unset", expectErrs: true, }, + { + name: "Template Set", + newInference: &InferenceSpec{ + Template: &v1.PodTemplateSpec{}, + Adapters: []AdapterSpec{ + { + Source: &DataSource{ + Name: "Adapter-1", + Image: "fake.kaito.com/kaito-image:0.0.1", + }, + }, + { + Source: &DataSource{ + Name: "Adapter-1", + Image: "fake.kaito.com/kaito-image:0.0.6", + }, + }, + }, + }, + oldInference: &InferenceSpec{ + Template: nil, + }, + errContent: "field cannot be unset/set if it was set/unset", + expectErrs: true, + }, { name: "Valid Update", newInference: &InferenceSpec{ @@ -1117,8 +1162,8 @@ func TestDataSourceValidateCreate(t *testing.T) { { name: "All fields specified", dataSource: &DataSource{ - URLs: []string{"http://example.com/data"}, - Image: "aimodels.azurecr.io/data-image:latest", + URLs: []string{"http://example.com/data"}, + Image: "aimodels.azurecr.io/data-image:latest", }, wantErr: true, errField: "Exactly one of URLs, Volume, or Image must be specified", @@ -1152,13 +1197,13 @@ func TestDataSourceValidateUpdate(t *testing.T) { { name: "No changes", oldSource: &DataSource{ - URLs: []string{"http://example.com/data1", "http://example.com/data2"}, + URLs: []string{"http://example.com/data1", "http://example.com/data2"}, // Volume: &v1.VolumeSource{}, Image: "data-image:latest", ImagePullSecrets: []string{"secret1", "secret2"}, }, newSource: &DataSource{ - URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter + URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter // Volume: &v1.VolumeSource{}, Image: "data-image:latest", ImagePullSecrets: []string{"secret2", "secret1"}, // Note the different order, should not matter diff --git a/pkg/inference/preset-inferences_test.go b/pkg/inference/preset-inferences_test.go index 2173a166c..390522a7b 100644 --- a/pkg/inference/preset-inferences_test.go +++ b/pkg/inference/preset-inferences_test.go @@ -9,6 +9,7 @@ import ( "strings" "testing" + "github.com/azure/kaito/api/v1alpha1" "github.com/azure/kaito/pkg/utils/test" "github.com/azure/kaito/pkg/model" @@ -19,14 +20,18 @@ import ( v1 "k8s.io/apimachinery/pkg/apis/meta/v1" ) +var ValidStrength string = "0.5" + func TestCreatePresetInference(t *testing.T) { test.RegisterTestModel() testcases := map[string]struct { - nodeCount int - modelName string - callMocks func(c *test.MockClient) - workload string - expectedCmd string + nodeCount int + modelName string + callMocks func(c *test.MockClient) + workload string + expectedCmd string + hasAdapters bool + expectedVolume string }{ "test-model": { @@ -39,6 +44,7 @@ func TestCreatePresetInference(t *testing.T) { // No BaseCommand, TorchRunParams, TorchRunRdzvParams, or ModelRunParams // So expected cmd consists of shell command and inference file expectedCmd: "/bin/sh -c inference_api.py", + hasAdapters: false, }, "test-distributed-model": { @@ -50,6 +56,19 @@ func TestCreatePresetInference(t *testing.T) { }, workload: "StatefulSet", expectedCmd: "/bin/sh -c inference_api.py", + hasAdapters: false, + }, + + "test-model-with-adapters": { + nodeCount: 1, + modelName: "test-model", + callMocks: func(c *test.MockClient) { + c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil) + }, + workload: "Deployment", + expectedCmd: "/bin/sh -c inference_api.py", + hasAdapters: true, + expectedVolume: "adapter-volume", }, } @@ -61,6 +80,18 @@ func TestCreatePresetInference(t *testing.T) { workspace := test.MockWorkspaceWithPreset workspace.Resource.Count = &tc.nodeCount + if tc.hasAdapters { + workspace.Inference.Adapters = []v1alpha1.AdapterSpec{ + { + Source: &v1alpha1.DataSource{ + Name: "Adapter-1", + Image: "fake.kaito.com/kaito-image:0.0.1", + }, + Strength: &ValidStrength, + }, + } + } + useHeadlessSvc := false var inferenceObj *model.PresetParam @@ -113,6 +144,20 @@ func TestCreatePresetInference(t *testing.T) { if !reflect.DeepEqual(params, expectedParams) { t.Errorf("%s parameters are not expected, got %s, expect %s ", k, params, expectedParams) } + + // Check for adapter volume + if tc.hasAdapters { + found := false + for _, volume := range createdObject.(*appsv1.Deployment).Spec.Template.Spec.Volumes { + if volume.Name == tc.expectedVolume { + found = true + break + } + } + if !found { + t.Errorf("%s: expected adapter volume %s not found", k, tc.expectedVolume) + } + } }) } }