Skip to content

Commit

Permalink
feat: RAGEngine update and validation (#747)
Browse files Browse the repository at this point in the history
**Reason for Change**:
Add Update part of RAGEngine controller and validation

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:

Signed-off-by: Bangqi Zhu <bangqizhu@microsoft.com>
Co-authored-by: Bangqi Zhu <bangqizhu@microsoft.com>
  • Loading branch information
bangqipropel and Bangqi Zhu authored Dec 11, 2024
1 parent c3be988 commit b099c66
Show file tree
Hide file tree
Showing 4 changed files with 260 additions and 1 deletion.
7 changes: 7 additions & 0 deletions api/v1alpha1/ragengine_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ func (w *RAGEngine) Validate(ctx context.Context) (errs *apis.FieldError) {
if base == nil {
klog.InfoS("Validate creation", "ragengine", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
errs = errs.Also(w.validateCreate().ViaField("spec"))
} else {
klog.InfoS("Validate update", "ragengine", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
old := base.(*RAGEngine)
errs = errs.Also(
w.validateCreate().ViaField("spec"),
w.Spec.Compute.validateUpdate(old.Spec.Compute).ViaField("resource"),
)
}
return errs
}
Expand Down
20 changes: 19 additions & 1 deletion pkg/ragengine/controllers/ragengine_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/go-logr/logr"
kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
"github.com/kaito-project/kaito/pkg/featuregates"
"github.com/kaito-project/kaito/pkg/ragengine/manifests"
"github.com/kaito-project/kaito/pkg/utils"
"github.com/kaito-project/kaito/pkg/utils/consts"
"github.com/kaito-project/kaito/pkg/utils/machine"
Expand Down Expand Up @@ -152,8 +153,22 @@ func (c *RAGEngineReconciler) applyRAG(ctx context.Context, ragEngineObj *kaitov

if err = resources.GetResource(ctx, ragEngineObj.Name, ragEngineObj.Namespace, c.Client, deployment); err == nil {
klog.InfoS("An inference workload already exists for ragengine", "ragengine", klog.KObj(ragEngineObj))
return
if deployment.Annotations[kaitov1alpha1.RAGEngineRevisionAnnotation] != revisionStr {

envs := manifests.RAGSetEnv(ragEngineObj)

spec := &deployment.Spec
// Currently, all CRD changes are only passed through environment variables (env)
spec.Template.Spec.Containers[0].Env = envs
deployment.Annotations[kaitov1alpha1.RAGEngineRevisionAnnotation] = revisionStr

if err := c.Update(ctx, deployment); err != nil {
return
}
}
if err = resources.CheckResourceStatus(deployment, c.Client, time.Duration(10)*time.Minute); err != nil {
return
}
} else if apierrors.IsNotFound(err) {
var workloadObj client.Object
// Need to create a new workload
Expand Down Expand Up @@ -577,6 +592,9 @@ func (c *RAGEngineReconciler) SetupWithManager(mgr ctrl.Manager) error {
}
builder := ctrl.NewControllerManagedBy(mgr).
For(&kaitov1alpha1.RAGEngine{}).
Owns(&appsv1.ControllerRevision{}).
Owns(&appsv1.Deployment{}).
Watches(&v1alpha5.Machine{}, c.watchMachines()).
WithOptions(controller.Options{MaxConcurrentReconciles: 5})
if featuregates.FeatureGates[consts.FeatureFlagKarpenter] {
builder.Watches(&v1beta1.NodeClaim{}, c.watchNodeClaims()) // watches for nodeClaim with labels indicating ragengine name.
Expand Down
128 changes: 128 additions & 0 deletions pkg/ragengine/controllers/ragengine_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,131 @@ func TestUpdateControllerRevision1(t *testing.T) {
})
}
}

func TestApplyRAG(t *testing.T) {
test.RegisterTestModel()
testcases := map[string]struct {
callMocks func(c *test.MockClient)
ragengine v1alpha1.RAGEngine
expectedError error
verifyCalls func(c *test.MockClient)
}{

"Fail because associated workload with ragengine cannot be retrieved": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(errors.New("Failed to get resource"))
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
},
ragengine: *test.MockRAGEngineWithRevision1,
expectedError: errors.New("Failed to get resource"),
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 0)
c.AssertNumberOfCalls(t, "Get", 5)
c.AssertNumberOfCalls(t, "Delete", 0)
},
},

"Create preset inference because inference workload did not exist": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(test.NotFoundError()).Times(4)
c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil).Run(func(args mock.Arguments) {
depObj := &appsv1.Deployment{}
key := client.ObjectKey{Namespace: "kaito", Name: "testRAGEngine"}
c.GetObjectFromMap(depObj, key)
depObj.Status.ReadyReplicas = 1
c.CreateOrUpdateObjectInMap(depObj)
})
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)

c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)

c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
},
ragengine: *test.MockRAGEngineWithRevision1,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Get", 7)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
expectedError: nil,
},

"Apply inference from existing workload": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).
Run(func(args mock.Arguments) {
dep := args.Get(2).(*appsv1.Deployment)
*dep = test.MockRAGDeploymentUpdated
}).
Return(nil)

c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
},
ragengine: *test.MockRAGEngineWithRevision1,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 0)
c.AssertNumberOfCalls(t, "Get", 3)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
},

"Update deployment with new configuration": {
callMocks: func(c *test.MockClient) {
// Mocking existing Deployment object
c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).
Run(func(args mock.Arguments) {
dep := args.Get(2).(*appsv1.Deployment)
*dep = test.MockRAGDeploymentUpdated
}).
Return(nil)

c.On("Update", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)

c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
},
ragengine: *test.MockRAGEngineWithPreset,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 0)
c.AssertNumberOfCalls(t, "Get", 3)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 1)
},
},
}

for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
mockClient := test.NewClient()
tc.callMocks(mockClient)

reconciler := &RAGEngineReconciler{
Client: mockClient,
Scheme: test.NewTestScheme(),
}
ctx := context.Background()

err := reconciler.applyRAG(ctx, &tc.ragengine)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
} else {
assert.Equal(t, tc.expectedError.Error(), err.Error())
}
if tc.verifyCalls != nil {
tc.verifyCalls(mockClient)
}
})
}
}
106 changes: 106 additions & 0 deletions pkg/utils/test/testUtils.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,58 @@ var (
},
},
}
MockRAGEngineWithRevision1 = &v1alpha1.RAGEngine{
ObjectMeta: metav1.ObjectMeta{
Name: "testRAGEngine",
Namespace: "kaito",
Annotations: map[string]string{v1alpha1.RAGEngineRevisionAnnotation: "1"},
},
Spec: &v1alpha1.RAGEngineSpec{
Compute: &v1alpha1.ResourceSpec{
Count: &gpuNodeCount,
InstanceType: "Standard_NC12s_v3",
LabelSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{
"ragengine.kaito.io/name": "testRAGEngine",
},
},
},
Embedding: &v1alpha1.EmbeddingSpec{
Local: &v1alpha1.LocalEmbeddingSpec{
ModelID: "BAAI/bge-small-en-v1.5",
},
},
InferenceService: &v1alpha1.InferenceServiceSpec{
URL: "http://localhost:5000/chat",
},
},
}
MockRAGEngineWithRevision2 = &v1alpha1.RAGEngine{
ObjectMeta: metav1.ObjectMeta{
Name: "testRAGEngine",
Namespace: "kaito",
Annotations: map[string]string{v1alpha1.RAGEngineRevisionAnnotation: "2"},
},
Spec: &v1alpha1.RAGEngineSpec{
Compute: &v1alpha1.ResourceSpec{
Count: &gpuNodeCount,
InstanceType: "Standard_NC12s_v3",
LabelSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{
"ragengine.kaito.io/name": "testRAGEngine",
},
},
},
Embedding: &v1alpha1.EmbeddingSpec{
Local: &v1alpha1.LocalEmbeddingSpec{
ModelID: "BAAI/bge-small-en-v1.5",
},
},
InferenceService: &v1alpha1.InferenceServiceSpec{
URL: "http://localhost:5000/chat",
},
},
}
)

var MockRAGEngineWithPresetHash = "14485768c1b67a529a71e3c87d9f2e6c1ed747534dea07e268e93475a5e21e"
Expand Down Expand Up @@ -620,6 +672,60 @@ var (
},
}
)
var MockRAGDeploymentUpdated = appsv1.Deployment{
ObjectMeta: metav1.ObjectMeta{
Name: "testRAGEngine",
Namespace: "kaito",
Annotations: map[string]string{v1alpha1.RAGEngineRevisionAnnotation: "1"},
},
Spec: appsv1.DeploymentSpec{
Replicas: &numRep,
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"app": "test-app",
},
},
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "test-container",
Image: "nginx:latest",
Ports: []corev1.ContainerPort{
{
ContainerPort: 80,
Protocol: corev1.ProtocolTCP,
},
},
Env: []corev1.EnvVar{
{
Name: "ENV_VAR_NAME",
Value: "ENV_VAR_VALUE",
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: "volume-name",
MountPath: "/mount/path",
},
},
},
},
Volumes: []corev1.Volume{
{
Name: "volume-name",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
},
},
},
},
},
Status: appsv1.DeploymentStatus{
ReadyReplicas: 1,
},
}

var (
MockWorkspaceWithInferenceTemplate = &v1alpha1.Workspace{
Expand Down

0 comments on commit b099c66

Please sign in to comment.