diff --git a/api/v1alpha1/labels.go b/api/v1alpha1/labels.go index 409c08ef8..7e807c29a 100644 --- a/api/v1alpha1/labels.go +++ b/api/v1alpha1/labels.go @@ -27,4 +27,7 @@ const ( // WorkspaceRevisionAnnotation is the Annotations for revision number WorkspaceRevisionAnnotation = "workspace.kaito.io/revision" + + // RAGEngineRevisionAnnotation is the Annotations for revision number + RAGEngineRevisionAnnotation = "ragengine.kaito.io/revision" ) diff --git a/pkg/ragengine/controllers/ragengine_controller.go b/pkg/ragengine/controllers/ragengine_controller.go index 57d3cf8d4..54adf7642 100644 --- a/pkg/ragengine/controllers/ragengine_controller.go +++ b/pkg/ragengine/controllers/ragengine_controller.go @@ -5,7 +5,12 @@ package controllers import ( "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" "fmt" + "sort" + "strconv" "strings" "time" @@ -19,12 +24,14 @@ import ( "github.com/kaito-project/kaito/pkg/utils" "github.com/kaito-project/kaito/pkg/utils/consts" "github.com/samber/lo" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/errors" apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/sets" "k8s.io/client-go/tools/record" "k8s.io/klog/v2" @@ -39,6 +46,12 @@ import ( "sigs.k8s.io/karpenter/pkg/apis/v1beta1" ) +const ( + RAGEngineHashAnnotation = "ragengine.kaito.io/hash" + RAGEngineNameLabel = "ragengine.kaito.io/name" + revisionHashSuffix = 5 +) + type RAGEngineReconciler struct { client.Client Log logr.Logger @@ -75,6 +88,10 @@ func (c *RAGEngineReconciler) Reconcile(ctx context.Context, req reconcile.Reque return c.deleteRAGEngine(ctx, ragEngineObj) } + if err := c.syncControllerRevision(ctx, ragEngineObj); err != nil { + return reconcile.Result{}, err + } + result, err := c.addRAGEngine(ctx, ragEngineObj) if err != nil { return result, err @@ -114,6 +131,99 @@ func (c *RAGEngineReconciler) deleteRAGEngine(ctx context.Context, ragEngineObj return c.garbageCollectRAGEngine(ctx, ragEngineObj) } +func (c *RAGEngineReconciler) syncControllerRevision(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error { + currentHash := computeHash(ragEngineObj) + + annotations := ragEngineObj.GetAnnotations() + if annotations == nil { + annotations = make(map[string]string) + } // nil checking. + + revisionNum := int64(1) + + revisions := &appsv1.ControllerRevisionList{} + if err := c.List(ctx, revisions, client.InNamespace(ragEngineObj.Namespace), client.MatchingLabels{RAGEngineNameLabel: ragEngineObj.Name}); err != nil { + return fmt.Errorf("failed to list revisions: %w", err) + } + sort.Slice(revisions.Items, func(i, j int) bool { + return revisions.Items[i].Revision < revisions.Items[j].Revision + }) + + var latestRevision *appsv1.ControllerRevision + + jsonData, err := json.Marshal(ragEngineObj.Spec) + if err != nil { + return fmt.Errorf("failed to marshal selected fields: %w", err) + } + + if len(revisions.Items) > 0 { + latestRevision = &revisions.Items[len(revisions.Items)-1] + + revisionNum = latestRevision.Revision + 1 + } + newRevision := &appsv1.ControllerRevision{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("%s-%s", ragEngineObj.Name, currentHash[:revisionHashSuffix]), + Namespace: ragEngineObj.Namespace, + Annotations: map[string]string{ + RAGEngineHashAnnotation: currentHash, + }, + Labels: map[string]string{ + RAGEngineNameLabel: ragEngineObj.Name, + }, + OwnerReferences: []metav1.OwnerReference{ + *metav1.NewControllerRef(ragEngineObj, kaitov1alpha1.GroupVersion.WithKind("RAGEngine")), + }, + }, + Revision: revisionNum, + Data: runtime.RawExtension{Raw: jsonData}, + } + + annotations[RAGEngineHashAnnotation] = currentHash + ragEngineObj.SetAnnotations(annotations) + controllerRevision := &appsv1.ControllerRevision{} + if err := c.Get(ctx, types.NamespacedName{ + Name: newRevision.Name, + Namespace: newRevision.Namespace, + }, controllerRevision); err != nil { + if errors.IsNotFound(err) { + + if err := c.Create(ctx, newRevision); err != nil { + return fmt.Errorf("failed to create new ControllerRevision: %w", err) + } else { + annotations[kaitov1alpha1.RAGEngineRevisionAnnotation] = strconv.FormatInt(revisionNum, 10) + } + + if len(revisions.Items) > consts.MaxRevisionHistoryLimit { + if err := c.Delete(ctx, &revisions.Items[0]); err != nil { + return fmt.Errorf("failed to delete old revision: %w", err) + } + } + } else { + return fmt.Errorf("failed to get controller revision: %w", err) + } + } else { + if controllerRevision.Annotations[RAGEngineHashAnnotation] != newRevision.Annotations[RAGEngineHashAnnotation] { + return fmt.Errorf("revision name conflicts, the hash values are different") + } + annotations[kaitov1alpha1.RAGEngineRevisionAnnotation] = strconv.FormatInt(controllerRevision.Revision, 10) + } + annotations[RAGEngineHashAnnotation] = currentHash + ragEngineObj.SetAnnotations(annotations) + + if err := c.Update(ctx, ragEngineObj); err != nil { + return fmt.Errorf("failed to update RAGEngine annotations: %w", err) + } + return nil +} + +func computeHash(ragEngineObj *kaitov1alpha1.RAGEngine) string { + hasher := sha256.New() + encoder := json.NewEncoder(hasher) + encoder.Encode(ragEngineObj.Spec) + return hex.EncodeToString(hasher.Sum(nil)) +} + // applyRAGEngineResource applies RAGEngine resource spec. func (c *RAGEngineReconciler) applyRAGEngineResource(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error { diff --git a/pkg/ragengine/controllers/ragengine_controller_test.go b/pkg/ragengine/controllers/ragengine_controller_test.go index 4c34c994a..6c69b1d5e 100644 --- a/pkg/ragengine/controllers/ragengine_controller_test.go +++ b/pkg/ragengine/controllers/ragengine_controller_test.go @@ -5,7 +5,9 @@ package controllers import ( "context" + "encoding/json" "errors" + "fmt" "os" "testing" "time" @@ -19,8 +21,11 @@ import ( "github.com/kaito-project/kaito/pkg/utils/test" "github.com/stretchr/testify/mock" "gotest.tools/assert" + appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "knative.dev/pkg/apis" "sigs.k8s.io/controller-runtime/pkg/client" @@ -362,3 +367,144 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) { }) } } + +func TestUpdateControllerRevision1(t *testing.T) { + testcases := map[string]struct { + callMocks func(c *test.MockClient) + ragengine v1alpha1.RAGEngine + expectedError error + verifyCalls func(c *test.MockClient) + }{ + + "No new revision needed": { + callMocks: func(c *test.MockClient) { + c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything, mock.Anything).Return(nil) + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.ControllerRevision{}), mock.Anything). + Run(func(args mock.Arguments) { + dep := args.Get(2).(*appsv1.ControllerRevision) + *dep = appsv1.ControllerRevision{ + ObjectMeta: v1.ObjectMeta{ + Annotations: map[string]string{ + RAGEngineHashAnnotation: "7985249e078eb041e38c10c3637032b2d352616c609be8542a779460d3ff1d67", + }, + }, + } + }). + Return(nil) + c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything). + Return(nil) + }, + ragengine: test.MockRAGEngineWithComputeHash, + expectedError: nil, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 1) + c.AssertNumberOfCalls(t, "Create", 0) + c.AssertNumberOfCalls(t, "Get", 1) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 1) + }, + }, + + "Fail to create ControllerRevision": { + callMocks: func(c *test.MockClient) { + c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything, mock.Anything).Return(nil) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(errors.New("failed to create ControllerRevision")) + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.ControllerRevision{}), mock.Anything). + Return(apierrors.NewNotFound(appsv1.Resource("ControllerRevision"), test.MockRAGEngineFailToCreateCR.Name)) + c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything). + Return(nil) + }, + ragengine: test.MockRAGEngineFailToCreateCR, + expectedError: errors.New("failed to create new ControllerRevision: failed to create ControllerRevision"), + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 1) + c.AssertNumberOfCalls(t, "Create", 1) + c.AssertNumberOfCalls(t, "Get", 1) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 0) + }, + }, + "Successfully create new ControllerRevision": { + callMocks: func(c *test.MockClient) { + c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything, mock.Anything).Return(nil) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil) + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.ControllerRevision{}), mock.Anything). + Return(apierrors.NewNotFound(appsv1.Resource("ControllerRevision"), test.MockRAGEngineFailToCreateCR.Name)) + c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything). + Return(nil) + }, + ragengine: test.MockRAGEngineSuccessful, + expectedError: nil, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 1) + c.AssertNumberOfCalls(t, "Create", 1) + c.AssertNumberOfCalls(t, "Get", 1) + c.AssertNumberOfCalls(t, "Delete", 0) + c.AssertNumberOfCalls(t, "Update", 1) + }, + }, + + "Successfully delete old ControllerRevision": { + callMocks: func(c *test.MockClient) { + revisions := &appsv1.ControllerRevisionList{} + jsonData, _ := json.Marshal(test.MockRAGEngineWithUpdatedDeployment) + + for i := 0; i <= consts.MaxRevisionHistoryLimit; i++ { + revision := &appsv1.ControllerRevision{ + ObjectMeta: v1.ObjectMeta{ + Name: fmt.Sprintf("revision-%d", i), + }, + Revision: int64(i), + Data: runtime.RawExtension{Raw: jsonData}, + } + revisions.Items = append(revisions.Items, *revision) + } + relevantMap := c.CreateMapWithType(revisions) + + for _, obj := range revisions.Items { + m := obj + objKey := client.ObjectKeyFromObject(&m) + relevantMap[objKey] = &m + } + c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything, mock.Anything).Return(nil) + c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil) + c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.ControllerRevision{}), mock.Anything). + Return(apierrors.NewNotFound(appsv1.Resource("ControllerRevision"), test.MockRAGEngineFailToCreateCR.Name)) + c.On("Delete", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil) + c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything). + Return(nil) + }, + ragengine: test.MockRAGEngineWithDeleteOldCR, + expectedError: nil, + verifyCalls: func(c *test.MockClient) { + c.AssertNumberOfCalls(t, "List", 1) + c.AssertNumberOfCalls(t, "Create", 1) + c.AssertNumberOfCalls(t, "Get", 1) + c.AssertNumberOfCalls(t, "Delete", 1) + c.AssertNumberOfCalls(t, "Update", 1) + }, + }, + } + for k, tc := range testcases { + t.Run(k, func(t *testing.T) { + mockClient := test.NewClient() + tc.callMocks(mockClient) + + reconciler := &RAGEngineReconciler{ + Client: mockClient, + Scheme: test.NewTestScheme(), + } + ctx := context.Background() + + err := reconciler.syncControllerRevision(ctx, &tc.ragengine) + if tc.expectedError == nil { + assert.Check(t, err == nil, "Not expected to return error") + } else { + assert.Equal(t, tc.expectedError.Error(), err.Error()) + } + if tc.verifyCalls != nil { + tc.verifyCalls(mockClient) + } + }) + } +} diff --git a/pkg/utils/test/testUtils.go b/pkg/utils/test/testUtils.go index 89fdc96bf..d51c3958e 100644 --- a/pkg/utils/test/testUtils.go +++ b/pkg/utils/test/testUtils.go @@ -148,6 +148,35 @@ var ( } ) +var ( + MockRAGEngineWithDeleteOldCR = v1alpha1.RAGEngine{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testRAGEngine", + Namespace: "kaito", + Annotations: map[string]string{ + "workspace.kaito.io/hash": "14485768c1b67a529a71e3c87d9f2e6c1ed747534dea07e268e93475a5e21e", + "workspace.kaito.io/revision": "1", + }, + }, + Spec: &v1alpha1.RAGEngineSpec{ + Compute: &v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "ragengine.kaito.io/name": "testRAGEngine", + }, + }, + }, + Embedding: &v1alpha1.EmbeddingSpec{ + Local: &v1alpha1.LocalEmbeddingSpec{ + ModelID: "BAAI/bge-small-en-v1.5", + }, + }, + }, + } +) + var ( MockWorkspaceFailToCreateCR = v1alpha1.Workspace{ ObjectMeta: metav1.ObjectMeta{ @@ -176,6 +205,28 @@ var ( } ) +var ( + MockRAGEngineFailToCreateCR = v1alpha1.RAGEngine{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testRAGEngine-failedtocreateCR", + Namespace: "kaito", + Annotations: map[string]string{ + "ragengine.kaito.io/revision": "1", + }, + }, + Spec: &v1alpha1.RAGEngineSpec{ + Compute: &v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "ragengine.kaito.io/name": "testRAGEngine", + }, + }, + }}, + } +) + var ( MockWorkspaceSuccessful = v1alpha1.Workspace{ ObjectMeta: metav1.ObjectMeta{ @@ -204,6 +255,28 @@ var ( } ) +var ( + MockRAGEngineSuccessful = v1alpha1.RAGEngine{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testRAGEngine-successful", + Namespace: "kaito", + Annotations: map[string]string{ + "ragengine.kaito.io/revision": "0", + }, + }, + Spec: &v1alpha1.RAGEngineSpec{ + Compute: &v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "ragengine.kaito.io/name": "testRAGEngine", + }, + }, + }}, + } +) + var ( MockWorkspaceWithComputeHash = v1alpha1.Workspace{ ObjectMeta: metav1.ObjectMeta{ @@ -233,6 +306,29 @@ var ( } ) +var ( + MockRAGEngineWithComputeHash = v1alpha1.RAGEngine{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testRAGEngine", + Namespace: "kaito", + Annotations: map[string]string{ + "ragengine.kaito.io/hash": "7985249e078eb041e38c10c3637032b2d352616c609be8542a779460d3ff1d67", + "ragengine.kaito.io/revision": "1", + }, + }, + Spec: &v1alpha1.RAGEngineSpec{ + Compute: &v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "ragengine.kaito.io/name": "testRAGEngine", + }, + }, + }}, + } +) + var ( MockWorkspaceUpdateCR = v1alpha1.Workspace{ ObjectMeta: metav1.ObjectMeta{ @@ -300,6 +396,29 @@ var ( } ) +var ( + MockRAGEngineWithUpdatedDeployment = v1alpha1.RAGEngine{ + ObjectMeta: metav1.ObjectMeta{ + Name: "testRAGEngine", + Namespace: "kaito", + Annotations: map[string]string{ + "ragengine.kaito.io/hash": "7985249e078eb041e38c10c3637032b2d352616c609be8542a779460d3ff1d67", + "ragengine.kaito.io/revision": "1", + }, + }, + Spec: &v1alpha1.RAGEngineSpec{ + Compute: &v1alpha1.ResourceSpec{ + Count: &gpuNodeCount, + InstanceType: "Standard_NC12s_v3", + LabelSelector: &metav1.LabelSelector{ + MatchLabels: map[string]string{ + "ragengine.kaito.io/name": "testRAGEngine", + }, + }, + }}, + } +) + var ( numRep = int32(1) MockDeploymentUpdated = appsv1.Deployment{ diff --git a/pkg/workspace/controllers/workspace_controller.go b/pkg/workspace/controllers/workspace_controller.go index 1ad5bc93a..bc65a3d63 100644 --- a/pkg/workspace/controllers/workspace_controller.go +++ b/pkg/workspace/controllers/workspace_controller.go @@ -53,10 +53,9 @@ import ( ) const ( - nodePluginInstallTimeout = 60 * time.Second - WorkspaceHashAnnotation = "workspace.kaito.io/hash" - WorkspaceNameLabel = "workspace.kaito.io/name" - revisionHashSuffix = 5 + WorkspaceHashAnnotation = "workspace.kaito.io/hash" + WorkspaceNameLabel = "workspace.kaito.io/name" + revisionHashSuffix = 5 ) type WorkspaceReconciler struct {