diff --git a/api/v1alpha1/ragengine_validation.go b/api/v1alpha1/ragengine_validation.go index 9ffff50cf..d045dc53a 100644 --- a/api/v1alpha1/ragengine_validation.go +++ b/api/v1alpha1/ragengine_validation.go @@ -31,6 +31,13 @@ func (w *RAGEngine) Validate(ctx context.Context) (errs *apis.FieldError) { if base == nil { klog.InfoS("Validate creation", "ragengine", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) errs = errs.Also(w.validateCreate().ViaField("spec")) + } else { + klog.InfoS("Validate update", "ragengine", fmt.Sprintf("%s/%s", w.Namespace, w.Name)) + old := base.(*RAGEngine) + errs = errs.Also( + w.validateCreate().ViaField("spec"), + w.Spec.Compute.validateUpdate(old.Spec.Compute).ViaField("resource"), + ) } return errs } diff --git a/pkg/ragengine/controllers/ragengine_controller.go b/pkg/ragengine/controllers/ragengine_controller.go index a3cb31f8e..349b5d1c1 100644 --- a/pkg/ragengine/controllers/ragengine_controller.go +++ b/pkg/ragengine/controllers/ragengine_controller.go @@ -18,6 +18,7 @@ import ( "github.com/go-logr/logr" kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1" "github.com/kaito-project/kaito/pkg/featuregates" + "github.com/kaito-project/kaito/pkg/ragengine/manifests" "github.com/kaito-project/kaito/pkg/utils" "github.com/kaito-project/kaito/pkg/utils/consts" "github.com/kaito-project/kaito/pkg/utils/machine" @@ -152,8 +153,22 @@ func (c *RAGEngineReconciler) applyRAG(ctx context.Context, ragEngineObj *kaitov if err = resources.GetResource(ctx, ragEngineObj.Name, ragEngineObj.Namespace, c.Client, deployment); err == nil { klog.InfoS("An inference workload already exists for ragengine", "ragengine", klog.KObj(ragEngineObj)) - return + if deployment.Annotations[kaitov1alpha1.RAGEngineRevisionAnnotation] != revisionStr { + envs := manifests.RAGSetEnv(ragEngineObj) + + spec := &deployment.Spec + // Currently, all CRD changes are only passed through environment variables (env) + spec.Template.Spec.Containers[0].Env = envs + deployment.Annotations[kaitov1alpha1.RAGEngineRevisionAnnotation] = revisionStr + + if err := c.Update(ctx, deployment); err != nil { + return + } + } + if err = resources.CheckResourceStatus(deployment, c.Client, time.Duration(10)*time.Minute); err != nil { + return + } } else if apierrors.IsNotFound(err) { var workloadObj client.Object // Need to create a new workload @@ -577,6 +592,9 @@ func (c *RAGEngineReconciler) SetupWithManager(mgr ctrl.Manager) error { } builder := ctrl.NewControllerManagedBy(mgr). For(&kaitov1alpha1.RAGEngine{}). + Owns(&appsv1.ControllerRevision{}). + Owns(&appsv1.Deployment{}). + Watches(&v1alpha5.Machine{}, c.watchMachines()). WithOptions(controller.Options{MaxConcurrentReconciles: 5}) if featuregates.FeatureGates[consts.FeatureFlagKarpenter] { builder.Watches(&v1beta1.NodeClaim{}, c.watchNodeClaims()) // watches for nodeClaim with labels indicating ragengine name. diff --git a/pkg/ragengine/controllers/ragengine_controller_test.go b/pkg/ragengine/controllers/ragengine_controller_test.go index 6c69b1d5e..407c4665d 100644 --- a/pkg/ragengine/controllers/ragengine_controller_test.go +++ b/pkg/ragengine/controllers/ragengine_controller_test.go @@ -508,3 +508,131 @@ func TestUpdateControllerRevision1(t *testing.T) { }) } } + +func TestApplyRAG(t *testing.T) { + test.RegisterTestModel() + testcases := map[string]struct { + callMocks func(c *test.MockClient) + ragengine v1alpha1.RAGEngine + expectedError error + verifyCalls func(c *test.MockClient) + }{ + + "Fail because associated workload with ragengine cannot be retrieved": { + callMocks: func(c *test.MockClient) { + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(errors.New("Failed to get resource")) + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil) + c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil) + }, + ragengine: *test.MockRAGEngineWithRevision1, + expectedError: errors.New("Failed to get resource"), + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 0) + c.AssertNumberOfCalls(t, "Get", 5) + c.AssertNumberOfCalls(t, "Delete", 0) + }, + }, + + "Create preset inference because inference workload did not exist": { + callMocks: func(c *test.MockClient) { + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(test.NotFoundError()).Times(4) + c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil).Run(func(args mock.Arguments) { + depObj := &appsv1.Deployment{} + key := client.ObjectKey{Namespace: "kaito", Name: "testRAGEngine"} + c.GetObjectFromMap(depObj, key) + depObj.Status.ReadyReplicas = 1 + c.CreateOrUpdateObjectInMap(depObj) + }) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil) + + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil) + + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil) + c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil) + os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + }, + ragengine: *test.MockRAGEngineWithRevision1, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 1) + c.AssertNumberOfCalls(t, "Get", 7) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 0) + }, + expectedError: nil, + }, + + "Apply inference from existing workload": { + callMocks: func(c *test.MockClient) { + c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything). + Run(func(args mock.Arguments) { + dep := args.Get(2).(*appsv1.Deployment) + *dep = test.MockRAGDeploymentUpdated + }). + Return(nil) + + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil) + c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil) + }, + ragengine: *test.MockRAGEngineWithRevision1, + expectedError: nil, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 0) + c.AssertNumberOfCalls(t, "Get", 3) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 0) + }, + }, + + "Update deployment with new configuration": { + callMocks: func(c *test.MockClient) { + // Mocking existing Deployment object + c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything). + Run(func(args mock.Arguments) { + dep := args.Get(2).(*appsv1.Deployment) + *dep = test.MockRAGDeploymentUpdated + }). + Return(nil) + + c.On("Update", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil) + + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil) + c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil) + }, + ragengine: *test.MockRAGEngineWithPreset, + expectedError: nil, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 0) + c.AssertNumberOfCalls(t, "Create", 0) + c.AssertNumberOfCalls(t, "Get", 3) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 1) + }, + }, + } + + for k, tc := range testcases { + t.Run(k, func(t *testing.T) { + mockClient := test.NewClient() + tc.callMocks(mockClient) + + reconciler := &RAGEngineReconciler{ + Client: mockClient, + Scheme: test.NewTestScheme(), + } + ctx := context.Background() + + err := reconciler.applyRAG(ctx, &tc.ragengine) + if tc.expectedError == nil { + assert.Check(t, err == nil, "Not expected to return error") + } else { + assert.Equal(t, tc.expectedError.Error(), err.Error()) + } + if tc.verifyCalls != nil { + tc.verifyCalls(mockClient) + } + }) + } +} diff --git a/pkg/utils/test/testUtils.go b/pkg/utils/test/testUtils.go index 72d5db880..5ad728345 100644 --- a/pkg/utils/test/testUtils.go +++ b/pkg/utils/test/testUtils.go @@ -225,6 +225,58 @@ var ( }, }, } + MockRAGEngineWithRevision1 = &v1alpha1.RAGEngine{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testRAGEngine", + Namespace: "kaito", + Annotations: map[string]string{v1alpha1.RAGEngineRevisionAnnotation: "1"}, + }, + Spec: &v1alpha1.RAGEngineSpec{ + Compute: &v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "ragengine.kaito.io/name": "testRAGEngine", + }, + }, + }, + Embedding: &v1alpha1.EmbeddingSpec{ + Local: &v1alpha1.LocalEmbeddingSpec{ + ModelID: "BAAI/bge-small-en-v1.5", + }, + }, + InferenceService: &v1alpha1.InferenceServiceSpec{ + URL: "http://localhost:5000/chat", + }, + }, + } + MockRAGEngineWithRevision2 = &v1alpha1.RAGEngine{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testRAGEngine", + Namespace: "kaito", + Annotations: map[string]string{v1alpha1.RAGEngineRevisionAnnotation: "2"}, + }, + Spec: &v1alpha1.RAGEngineSpec{ + Compute: &v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "ragengine.kaito.io/name": "testRAGEngine", + }, + }, + }, + Embedding: &v1alpha1.EmbeddingSpec{ + Local: &v1alpha1.LocalEmbeddingSpec{ + ModelID: "BAAI/bge-small-en-v1.5", + }, + }, + InferenceService: &v1alpha1.InferenceServiceSpec{ + URL: "http://localhost:5000/chat", + }, + }, + } ) var MockRAGEngineWithPresetHash = "14485768c1b67a529a71e3c87d9f2e6c1ed747534dea07e268e93475a5e21e" @@ -620,6 +672,60 @@ var ( }, } ) +var MockRAGDeploymentUpdated = appsv1.Deployment{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testRAGEngine", + Namespace: "kaito", + Annotations: map[string]string{v1alpha1.RAGEngineRevisionAnnotation: "1"}, + }, + Spec: appsv1.DeploymentSpec{ + Replicas: &numRep, + Template: corev1.PodTemplateSpec{ + ObjectMeta: metav1.ObjectMeta{ + Labels: map[string]string{ + "app": "test-app", + }, + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{ + { + Name: "test-container", + Image: "nginx:latest", + Ports: []corev1.ContainerPort{ + { + ContainerPort: 80, + Protocol: corev1.ProtocolTCP, + }, + }, + Env: []corev1.EnvVar{ + { + Name: "ENV_VAR_NAME", + Value: "ENV_VAR_VALUE", + }, + }, + VolumeMounts: []corev1.VolumeMount{ + { + Name: "volume-name", + MountPath: "/mount/path", + }, + }, + }, + }, + Volumes: []corev1.Volume{ + { + Name: "volume-name", + VolumeSource: corev1.VolumeSource{ + EmptyDir: &corev1.EmptyDirVolumeSource{}, + }, + }, + }, + }, + }, + }, + Status: appsv1.DeploymentStatus{ + ReadyReplicas: 1, + }, +} var ( MockWorkspaceWithInferenceTemplate = &v1alpha1.Workspace{