Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: increase adapter test coverage #477

Merged
merged 3 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ var gpuCountRequirement string
var totalGPUMemoryRequirement string
var perGPUMemoryRequirement string

var invalidSourceName string

func init() {
// Define a invalid source name longer than 253
for i := 0; i < 32; i++ {
invalidSourceName += "Adapter1"
}
}

type testModel struct{}

func (*testModel) GetInferenceParameters() *model.PresetParam {
Expand Down Expand Up @@ -626,14 +635,25 @@ func TestAdapterSpecValidateCreateorUpdate(t *testing.T) {
errContent: "Strength value for Adapter 'Adapter-1' must be between 0 and 1",
expectErrs: true,
},
{
name: "Invalid Source Name, longer than 253",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Name: invalidSourceName,
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &ValidStrength,
},
errContent: "Name of Adapter must be a valid DNS subdomain value",
expectErrs: true,
},
{
name: "Valid Adapter",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &ValidStrength,
},
errContent: "",
expectErrs: false,
Expand Down Expand Up @@ -709,6 +729,31 @@ func TestInferenceSpecValidateUpdate(t *testing.T) {
errContent: "field cannot be unset/set if it was set/unset",
expectErrs: true,
},
{
name: "Template Set",
newInference: &InferenceSpec{
Template: &v1.PodTemplateSpec{},
Adapters: []AdapterSpec{
{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
},
{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.6",
},
},
},
},
oldInference: &InferenceSpec{
Template: nil,
},
errContent: "field cannot be unset/set if it was set/unset",
expectErrs: true,
},
{
name: "Valid Update",
newInference: &InferenceSpec{
Expand Down Expand Up @@ -1117,8 +1162,8 @@ func TestDataSourceValidateCreate(t *testing.T) {
{
name: "All fields specified",
dataSource: &DataSource{
URLs: []string{"http://example.com/data"},
Image: "aimodels.azurecr.io/data-image:latest",
URLs: []string{"http://example.com/data"},
Image: "aimodels.azurecr.io/data-image:latest",
},
wantErr: true,
errField: "Exactly one of URLs, Volume, or Image must be specified",
Expand Down Expand Up @@ -1152,13 +1197,13 @@ func TestDataSourceValidateUpdate(t *testing.T) {
{
name: "No changes",
oldSource: &DataSource{
URLs: []string{"http://example.com/data1", "http://example.com/data2"},
URLs: []string{"http://example.com/data1", "http://example.com/data2"},
// Volume: &v1.VolumeSource{},
Image: "data-image:latest",
ImagePullSecrets: []string{"secret1", "secret2"},
},
newSource: &DataSource{
URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter
URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter
// Volume: &v1.VolumeSource{},
Image: "data-image:latest",
ImagePullSecrets: []string{"secret2", "secret1"}, // Note the different order, should not matter
Expand Down
55 changes: 50 additions & 5 deletions pkg/inference/preset-inferences_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"
"testing"

"github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/utils/test"

"github.com/azure/kaito/pkg/model"
Expand All @@ -19,14 +20,18 @@ import (
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

var ValidStrength string = "0.5"

func TestCreatePresetInference(t *testing.T) {
test.RegisterTestModel()
testcases := map[string]struct {
nodeCount int
modelName string
callMocks func(c *test.MockClient)
workload string
expectedCmd string
nodeCount int
modelName string
callMocks func(c *test.MockClient)
workload string
expectedCmd string
hasAdapters bool
expectedVolume string
}{

"test-model": {
Expand All @@ -39,6 +44,7 @@ func TestCreatePresetInference(t *testing.T) {
// No BaseCommand, TorchRunParams, TorchRunRdzvParams, or ModelRunParams
// So expected cmd consists of shell command and inference file
expectedCmd: "/bin/sh -c inference_api.py",
hasAdapters: false,
},

"test-distributed-model": {
Expand All @@ -50,6 +56,19 @@ func TestCreatePresetInference(t *testing.T) {
},
workload: "StatefulSet",
expectedCmd: "/bin/sh -c inference_api.py",
hasAdapters: false,
},

"test-model-with-adapters": {
nodeCount: 1,
modelName: "test-model",
callMocks: func(c *test.MockClient) {
c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
},
workload: "Deployment",
expectedCmd: "/bin/sh -c inference_api.py",
hasAdapters: true,
expectedVolume: "adapter-volume",
},
}

Expand All @@ -61,6 +80,18 @@ func TestCreatePresetInference(t *testing.T) {
workspace := test.MockWorkspaceWithPreset
workspace.Resource.Count = &tc.nodeCount

if tc.hasAdapters {
workspace.Inference.Adapters = []v1alpha1.AdapterSpec{
{
Source: &v1alpha1.DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &ValidStrength,
},
}
}

useHeadlessSvc := false

var inferenceObj *model.PresetParam
Expand Down Expand Up @@ -113,6 +144,20 @@ func TestCreatePresetInference(t *testing.T) {
if !reflect.DeepEqual(params, expectedParams) {
t.Errorf("%s parameters are not expected, got %s, expect %s ", k, params, expectedParams)
}

// Check for adapter volume
if tc.hasAdapters {
found := false
for _, volume := range createdObject.(*appsv1.Deployment).Spec.Template.Spec.Volumes {
if volume.Name == tc.expectedVolume {
found = true
break
}
}
if !found {
t.Errorf("%s: expected adapter volume %s not found", k, tc.expectedVolume)
}
}
})
}
}
Expand Down
Loading