Skip to content

Commit

Permalink
chore: increase adapter test coverage (#477)
Browse files Browse the repository at this point in the history
**Reason for Change**:
<!-- What does this PR improve or fix in Kaito? Why is it needed? -->

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:

---------

Signed-off-by: Bangqi Zhu <bangqizhu@microsoft.com>
Co-authored-by: Bangqi Zhu <bangqizhu@microsoft.com>
  • Loading branch information
bangqipropel and Bangqi Zhu authored Jun 20, 2024
1 parent d849274 commit 7b874eb
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 10 deletions.
55 changes: 50 additions & 5 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ var gpuCountRequirement string
var totalGPUMemoryRequirement string
var perGPUMemoryRequirement string

var invalidSourceName string

func init() {
// Define a invalid source name longer than 253
for i := 0; i < 32; i++ {
invalidSourceName += "Adapter1"
}
}

type testModel struct{}

func (*testModel) GetInferenceParameters() *model.PresetParam {
Expand Down Expand Up @@ -626,14 +635,25 @@ func TestAdapterSpecValidateCreateorUpdate(t *testing.T) {
errContent: "Strength value for Adapter 'Adapter-1' must be between 0 and 1",
expectErrs: true,
},
{
name: "Invalid Source Name, longer than 253",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Name: invalidSourceName,
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &ValidStrength,
},
errContent: "Name of Adapter must be a valid DNS subdomain value",
expectErrs: true,
},
{
name: "Valid Adapter",
adapterSpec: &AdapterSpec{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &ValidStrength,
},
errContent: "",
expectErrs: false,
Expand Down Expand Up @@ -709,6 +729,31 @@ func TestInferenceSpecValidateUpdate(t *testing.T) {
errContent: "field cannot be unset/set if it was set/unset",
expectErrs: true,
},
{
name: "Template Set",
newInference: &InferenceSpec{
Template: &v1.PodTemplateSpec{},
Adapters: []AdapterSpec{
{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
},
{
Source: &DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.6",
},
},
},
},
oldInference: &InferenceSpec{
Template: nil,
},
errContent: "field cannot be unset/set if it was set/unset",
expectErrs: true,
},
{
name: "Valid Update",
newInference: &InferenceSpec{
Expand Down Expand Up @@ -1117,8 +1162,8 @@ func TestDataSourceValidateCreate(t *testing.T) {
{
name: "All fields specified",
dataSource: &DataSource{
URLs: []string{"http://example.com/data"},
Image: "aimodels.azurecr.io/data-image:latest",
URLs: []string{"http://example.com/data"},
Image: "aimodels.azurecr.io/data-image:latest",
},
wantErr: true,
errField: "Exactly one of URLs, Volume, or Image must be specified",
Expand Down Expand Up @@ -1152,13 +1197,13 @@ func TestDataSourceValidateUpdate(t *testing.T) {
{
name: "No changes",
oldSource: &DataSource{
URLs: []string{"http://example.com/data1", "http://example.com/data2"},
URLs: []string{"http://example.com/data1", "http://example.com/data2"},
// Volume: &v1.VolumeSource{},
Image: "data-image:latest",
ImagePullSecrets: []string{"secret1", "secret2"},
},
newSource: &DataSource{
URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter
URLs: []string{"http://example.com/data2", "http://example.com/data1"}, // Note the different order, should not matter
// Volume: &v1.VolumeSource{},
Image: "data-image:latest",
ImagePullSecrets: []string{"secret2", "secret1"}, // Note the different order, should not matter
Expand Down
55 changes: 50 additions & 5 deletions pkg/inference/preset-inferences_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"strings"
"testing"

"github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/utils/test"

"github.com/azure/kaito/pkg/model"
Expand All @@ -19,14 +20,18 @@ import (
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)

var ValidStrength string = "0.5"

func TestCreatePresetInference(t *testing.T) {
test.RegisterTestModel()
testcases := map[string]struct {
nodeCount int
modelName string
callMocks func(c *test.MockClient)
workload string
expectedCmd string
nodeCount int
modelName string
callMocks func(c *test.MockClient)
workload string
expectedCmd string
hasAdapters bool
expectedVolume string
}{

"test-model": {
Expand All @@ -39,6 +44,7 @@ func TestCreatePresetInference(t *testing.T) {
// No BaseCommand, TorchRunParams, TorchRunRdzvParams, or ModelRunParams
// So expected cmd consists of shell command and inference file
expectedCmd: "/bin/sh -c inference_api.py",
hasAdapters: false,
},

"test-distributed-model": {
Expand All @@ -50,6 +56,19 @@ func TestCreatePresetInference(t *testing.T) {
},
workload: "StatefulSet",
expectedCmd: "/bin/sh -c inference_api.py",
hasAdapters: false,
},

"test-model-with-adapters": {
nodeCount: 1,
modelName: "test-model",
callMocks: func(c *test.MockClient) {
c.On("Create", mock.IsType(context.TODO()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
},
workload: "Deployment",
expectedCmd: "/bin/sh -c inference_api.py",
hasAdapters: true,
expectedVolume: "adapter-volume",
},
}

Expand All @@ -61,6 +80,18 @@ func TestCreatePresetInference(t *testing.T) {
workspace := test.MockWorkspaceWithPreset
workspace.Resource.Count = &tc.nodeCount

if tc.hasAdapters {
workspace.Inference.Adapters = []v1alpha1.AdapterSpec{
{
Source: &v1alpha1.DataSource{
Name: "Adapter-1",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
Strength: &ValidStrength,
},
}
}

useHeadlessSvc := false

var inferenceObj *model.PresetParam
Expand Down Expand Up @@ -113,6 +144,20 @@ func TestCreatePresetInference(t *testing.T) {
if !reflect.DeepEqual(params, expectedParams) {
t.Errorf("%s parameters are not expected, got %s, expect %s ", k, params, expectedParams)
}

// Check for adapter volume
if tc.hasAdapters {
found := false
for _, volume := range createdObject.(*appsv1.Deployment).Spec.Template.Spec.Volumes {
if volume.Name == tc.expectedVolume {
found = true
break
}
}
if !found {
t.Errorf("%s: expected adapter volume %s not found", k, tc.expectedVolume)
}
}
})
}
}
Expand Down

0 comments on commit 7b874eb

Please sign in to comment.