Skip to content

Commit

Permalink
feat: RAG engine controller revision (#682)
Browse files Browse the repository at this point in the history
**Reason for Change**:
RAG engine controller revision

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:

Signed-off-by: Bangqi Zhu <bangqizhu@microsoft.com>
Co-authored-by: Bangqi Zhu <bangqizhu@microsoft.com>
  • Loading branch information
bangqipropel and Bangqi Zhu authored Nov 8, 2024
1 parent 79494a2 commit 2ecfdf1
Show file tree
Hide file tree
Showing 5 changed files with 381 additions and 4 deletions.
3 changes: 3 additions & 0 deletions api/v1alpha1/labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,7 @@ const (

// WorkspaceRevisionAnnotation is the Annotations for revision number
WorkspaceRevisionAnnotation = "workspace.kaito.io/revision"

// RAGEngineRevisionAnnotation is the Annotations for revision number
RAGEngineRevisionAnnotation = "ragengine.kaito.io/revision"
)
110 changes: 110 additions & 0 deletions pkg/ragengine/controllers/ragengine_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@ package controllers

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"sort"
"strconv"
"strings"
"time"

Expand All @@ -19,12 +24,14 @@ import (
"github.com/kaito-project/kaito/pkg/utils"
"github.com/kaito-project/kaito/pkg/utils/consts"
"github.com/samber/lo"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/errors"
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/sets"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
Expand All @@ -39,6 +46,12 @@ import (
"sigs.k8s.io/karpenter/pkg/apis/v1beta1"
)

const (
RAGEngineHashAnnotation = "ragengine.kaito.io/hash"
RAGEngineNameLabel = "ragengine.kaito.io/name"
revisionHashSuffix = 5
)

type RAGEngineReconciler struct {
client.Client
Log logr.Logger
Expand Down Expand Up @@ -75,6 +88,10 @@ func (c *RAGEngineReconciler) Reconcile(ctx context.Context, req reconcile.Reque
return c.deleteRAGEngine(ctx, ragEngineObj)
}

if err := c.syncControllerRevision(ctx, ragEngineObj); err != nil {
return reconcile.Result{}, err
}

result, err := c.addRAGEngine(ctx, ragEngineObj)
if err != nil {
return result, err
Expand Down Expand Up @@ -114,6 +131,99 @@ func (c *RAGEngineReconciler) deleteRAGEngine(ctx context.Context, ragEngineObj
return c.garbageCollectRAGEngine(ctx, ragEngineObj)
}

func (c *RAGEngineReconciler) syncControllerRevision(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error {
currentHash := computeHash(ragEngineObj)

annotations := ragEngineObj.GetAnnotations()
if annotations == nil {
annotations = make(map[string]string)
} // nil checking.

revisionNum := int64(1)

revisions := &appsv1.ControllerRevisionList{}
if err := c.List(ctx, revisions, client.InNamespace(ragEngineObj.Namespace), client.MatchingLabels{RAGEngineNameLabel: ragEngineObj.Name}); err != nil {
return fmt.Errorf("failed to list revisions: %w", err)
}
sort.Slice(revisions.Items, func(i, j int) bool {
return revisions.Items[i].Revision < revisions.Items[j].Revision
})

var latestRevision *appsv1.ControllerRevision

jsonData, err := json.Marshal(ragEngineObj.Spec)
if err != nil {
return fmt.Errorf("failed to marshal selected fields: %w", err)
}

if len(revisions.Items) > 0 {
latestRevision = &revisions.Items[len(revisions.Items)-1]

revisionNum = latestRevision.Revision + 1
}
newRevision := &appsv1.ControllerRevision{
ObjectMeta: metav1.ObjectMeta{
Name: fmt.Sprintf("%s-%s", ragEngineObj.Name, currentHash[:revisionHashSuffix]),
Namespace: ragEngineObj.Namespace,
Annotations: map[string]string{
RAGEngineHashAnnotation: currentHash,
},
Labels: map[string]string{
RAGEngineNameLabel: ragEngineObj.Name,
},
OwnerReferences: []metav1.OwnerReference{
*metav1.NewControllerRef(ragEngineObj, kaitov1alpha1.GroupVersion.WithKind("RAGEngine")),
},
},
Revision: revisionNum,
Data: runtime.RawExtension{Raw: jsonData},
}

annotations[RAGEngineHashAnnotation] = currentHash
ragEngineObj.SetAnnotations(annotations)
controllerRevision := &appsv1.ControllerRevision{}
if err := c.Get(ctx, types.NamespacedName{
Name: newRevision.Name,
Namespace: newRevision.Namespace,
}, controllerRevision); err != nil {
if errors.IsNotFound(err) {

if err := c.Create(ctx, newRevision); err != nil {
return fmt.Errorf("failed to create new ControllerRevision: %w", err)
} else {
annotations[kaitov1alpha1.RAGEngineRevisionAnnotation] = strconv.FormatInt(revisionNum, 10)
}

if len(revisions.Items) > consts.MaxRevisionHistoryLimit {
if err := c.Delete(ctx, &revisions.Items[0]); err != nil {
return fmt.Errorf("failed to delete old revision: %w", err)
}
}
} else {
return fmt.Errorf("failed to get controller revision: %w", err)
}
} else {
if controllerRevision.Annotations[RAGEngineHashAnnotation] != newRevision.Annotations[RAGEngineHashAnnotation] {
return fmt.Errorf("revision name conflicts, the hash values are different")
}
annotations[kaitov1alpha1.RAGEngineRevisionAnnotation] = strconv.FormatInt(controllerRevision.Revision, 10)
}
annotations[RAGEngineHashAnnotation] = currentHash
ragEngineObj.SetAnnotations(annotations)

if err := c.Update(ctx, ragEngineObj); err != nil {
return fmt.Errorf("failed to update RAGEngine annotations: %w", err)
}
return nil
}

func computeHash(ragEngineObj *kaitov1alpha1.RAGEngine) string {
hasher := sha256.New()
encoder := json.NewEncoder(hasher)
encoder.Encode(ragEngineObj.Spec)
return hex.EncodeToString(hasher.Sum(nil))
}

// applyRAGEngineResource applies RAGEngine resource spec.
func (c *RAGEngineReconciler) applyRAGEngineResource(ctx context.Context, ragEngineObj *kaitov1alpha1.RAGEngine) error {

Expand Down
146 changes: 146 additions & 0 deletions pkg/ragengine/controllers/ragengine_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ package controllers

import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
"testing"
"time"
Expand All @@ -19,8 +21,11 @@ import (
"github.com/kaito-project/kaito/pkg/utils/test"
"github.com/stretchr/testify/mock"
"gotest.tools/assert"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"knative.dev/pkg/apis"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -362,3 +367,144 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) {
})
}
}

func TestUpdateControllerRevision1(t *testing.T) {
testcases := map[string]struct {
callMocks func(c *test.MockClient)
ragengine v1alpha1.RAGEngine
expectedError error
verifyCalls func(c *test.MockClient)
}{

"No new revision needed": {
callMocks: func(c *test.MockClient) {
c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything, mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).
Run(func(args mock.Arguments) {
dep := args.Get(2).(*appsv1.ControllerRevision)
*dep = appsv1.ControllerRevision{
ObjectMeta: v1.ObjectMeta{
Annotations: map[string]string{
RAGEngineHashAnnotation: "7985249e078eb041e38c10c3637032b2d352616c609be8542a779460d3ff1d67",
},
},
}
}).
Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).
Return(nil)
},
ragengine: test.MockRAGEngineWithComputeHash,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 0)
c.AssertNumberOfCalls(t, "Get", 1)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 1)
},
},

"Fail to create ControllerRevision": {
callMocks: func(c *test.MockClient) {
c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything, mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(errors.New("failed to create ControllerRevision"))
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).
Return(apierrors.NewNotFound(appsv1.Resource("ControllerRevision"), test.MockRAGEngineFailToCreateCR.Name))
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).
Return(nil)
},
ragengine: test.MockRAGEngineFailToCreateCR,
expectedError: errors.New("failed to create new ControllerRevision: failed to create ControllerRevision"),
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Get", 1)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
},
"Successfully create new ControllerRevision": {
callMocks: func(c *test.MockClient) {
c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything, mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).
Return(apierrors.NewNotFound(appsv1.Resource("ControllerRevision"), test.MockRAGEngineFailToCreateCR.Name))
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).
Return(nil)
},
ragengine: test.MockRAGEngineSuccessful,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Get", 1)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 1)
},
},

"Successfully delete old ControllerRevision": {
callMocks: func(c *test.MockClient) {
revisions := &appsv1.ControllerRevisionList{}
jsonData, _ := json.Marshal(test.MockRAGEngineWithUpdatedDeployment)

for i := 0; i <= consts.MaxRevisionHistoryLimit; i++ {
revision := &appsv1.ControllerRevision{
ObjectMeta: v1.ObjectMeta{
Name: fmt.Sprintf("revision-%d", i),
},
Revision: int64(i),
Data: runtime.RawExtension{Raw: jsonData},
}
revisions.Items = append(revisions.Items, *revision)
}
relevantMap := c.CreateMapWithType(revisions)

for _, obj := range revisions.Items {
m := obj
objKey := client.ObjectKeyFromObject(&m)
relevantMap[objKey] = &m
}
c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything, mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).
Return(apierrors.NewNotFound(appsv1.Resource("ControllerRevision"), test.MockRAGEngineFailToCreateCR.Name))
c.On("Delete", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).
Return(nil)
},
ragengine: test.MockRAGEngineWithDeleteOldCR,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Get", 1)
c.AssertNumberOfCalls(t, "Delete", 1)
c.AssertNumberOfCalls(t, "Update", 1)
},
},
}
for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
mockClient := test.NewClient()
tc.callMocks(mockClient)

reconciler := &RAGEngineReconciler{
Client: mockClient,
Scheme: test.NewTestScheme(),
}
ctx := context.Background()

err := reconciler.syncControllerRevision(ctx, &tc.ragengine)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
} else {
assert.Equal(t, tc.expectedError.Error(), err.Error())
}
if tc.verifyCalls != nil {
tc.verifyCalls(mockClient)
}
})
}
}
Loading

0 comments on commit 2ecfdf1

Please sign in to comment.