Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: RAGEngine update and validation #747

Merged
merged 2 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions api/v1alpha1/ragengine_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ func (w *RAGEngine) Validate(ctx context.Context) (errs *apis.FieldError) {
if base == nil {
klog.InfoS("Validate creation", "ragengine", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
errs = errs.Also(w.validateCreate().ViaField("spec"))
} else {
klog.InfoS("Validate update", "ragengine", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
old := base.(*RAGEngine)
errs = errs.Also(
w.validateCreate().ViaField("spec"),
w.Spec.Compute.validateUpdate(old.Spec.Compute).ViaField("resource"),
)
}
return errs
}
Expand Down
20 changes: 19 additions & 1 deletion pkg/ragengine/controllers/ragengine_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/go-logr/logr"
kaitov1alpha1 "github.com/kaito-project/kaito/api/v1alpha1"
"github.com/kaito-project/kaito/pkg/featuregates"
"github.com/kaito-project/kaito/pkg/ragengine/manifests"
"github.com/kaito-project/kaito/pkg/utils"
"github.com/kaito-project/kaito/pkg/utils/consts"
"github.com/kaito-project/kaito/pkg/utils/machine"
Expand Down Expand Up @@ -152,8 +153,22 @@ func (c *RAGEngineReconciler) applyRAG(ctx context.Context, ragEngineObj *kaitov

if err = resources.GetResource(ctx, ragEngineObj.Name, ragEngineObj.Namespace, c.Client, deployment); err == nil {
klog.InfoS("An inference workload already exists for ragengine", "ragengine", klog.KObj(ragEngineObj))
return
if deployment.Annotations[kaitov1alpha1.RAGEngineRevisionAnnotation] != revisionStr {

envs := manifests.RAGSetEnv(ragEngineObj)

spec := &deployment.Spec
// Currently, all CRD changes are only passed through environment variables (env)
spec.Template.Spec.Containers[0].Env = envs
deployment.Annotations[kaitov1alpha1.RAGEngineRevisionAnnotation] = revisionStr

if err := c.Update(ctx, deployment); err != nil {
return
}
}
if err = resources.CheckResourceStatus(deployment, c.Client, time.Duration(10)*time.Minute); err != nil {
return
}
} else if apierrors.IsNotFound(err) {
var workloadObj client.Object
// Need to create a new workload
Expand Down Expand Up @@ -577,6 +592,9 @@ func (c *RAGEngineReconciler) SetupWithManager(mgr ctrl.Manager) error {
}
builder := ctrl.NewControllerManagedBy(mgr).
For(&kaitov1alpha1.RAGEngine{}).
Owns(&appsv1.ControllerRevision{}).
Owns(&appsv1.Deployment{}).
Watches(&v1alpha5.Machine{}, c.watchMachines()).
WithOptions(controller.Options{MaxConcurrentReconciles: 5})
if featuregates.FeatureGates[consts.FeatureFlagKarpenter] {
builder.Watches(&v1beta1.NodeClaim{}, c.watchNodeClaims()) // watches for nodeClaim with labels indicating ragengine name.
Expand Down
128 changes: 128 additions & 0 deletions pkg/ragengine/controllers/ragengine_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,131 @@ func TestUpdateControllerRevision1(t *testing.T) {
})
}
}

func TestApplyRAG(t *testing.T) {
test.RegisterTestModel()
testcases := map[string]struct {
callMocks func(c *test.MockClient)
ragengine v1alpha1.RAGEngine
expectedError error
verifyCalls func(c *test.MockClient)
}{

"Fail because associated workload with ragengine cannot be retrieved": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(errors.New("Failed to get resource"))
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
},
ragengine: *test.MockRAGEngineWithRevision1,
expectedError: errors.New("Failed to get resource"),
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 0)
c.AssertNumberOfCalls(t, "Get", 5)
c.AssertNumberOfCalls(t, "Delete", 0)
},
},

"Create preset inference because inference workload did not exist": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(test.NotFoundError()).Times(4)
c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil).Run(func(args mock.Arguments) {
depObj := &appsv1.Deployment{}
key := client.ObjectKey{Namespace: "kaito", Name: "testRAGEngine"}
c.GetObjectFromMap(depObj, key)
depObj.Status.ReadyReplicas = 1
c.CreateOrUpdateObjectInMap(depObj)
})
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)

c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)

c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
},
ragengine: *test.MockRAGEngineWithRevision1,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Get", 7)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
expectedError: nil,
},

"Apply inference from existing workload": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).
Run(func(args mock.Arguments) {
dep := args.Get(2).(*appsv1.Deployment)
*dep = test.MockRAGDeploymentUpdated
}).
Return(nil)

c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
},
ragengine: *test.MockRAGEngineWithRevision1,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 0)
c.AssertNumberOfCalls(t, "Get", 3)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 0)
},
},

"Update deployment with new configuration": {
callMocks: func(c *test.MockClient) {
// Mocking existing Deployment object
c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).
Run(func(args mock.Arguments) {
dep := args.Get(2).(*appsv1.Deployment)
*dep = test.MockRAGDeploymentUpdated
}).
Return(nil)

c.On("Update", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)

c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
},
ragengine: *test.MockRAGEngineWithPreset,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 0)
c.AssertNumberOfCalls(t, "Create", 0)
c.AssertNumberOfCalls(t, "Get", 3)
c.AssertNumberOfCalls(t, "Delete", 0)
c.AssertNumberOfCalls(t, "Update", 1)
},
},
}

for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
mockClient := test.NewClient()
tc.callMocks(mockClient)

reconciler := &RAGEngineReconciler{
Client: mockClient,
Scheme: test.NewTestScheme(),
}
ctx := context.Background()

err := reconciler.applyRAG(ctx, &tc.ragengine)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
} else {
assert.Equal(t, tc.expectedError.Error(), err.Error())
}
if tc.verifyCalls != nil {
tc.verifyCalls(mockClient)
}
})
}
}
106 changes: 106 additions & 0 deletions pkg/utils/test/testUtils.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,58 @@ var (
},
},
}
MockRAGEngineWithRevision1 = &v1alpha1.RAGEngine{
ObjectMeta: metav1.ObjectMeta{
Name: "testRAGEngine",
Namespace: "kaito",
Annotations: map[string]string{v1alpha1.RAGEngineRevisionAnnotation: "1"},
},
Spec: &v1alpha1.RAGEngineSpec{
Compute: &v1alpha1.ResourceSpec{
Count: &gpuNodeCount,
InstanceType: "Standard_NC12s_v3",
LabelSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{
"ragengine.kaito.io/name": "testRAGEngine",
},
},
},
Embedding: &v1alpha1.EmbeddingSpec{
Local: &v1alpha1.LocalEmbeddingSpec{
ModelID: "BAAI/bge-small-en-v1.5",
},
},
InferenceService: &v1alpha1.InferenceServiceSpec{
URL: "http://localhost:5000/chat",
},
},
}
MockRAGEngineWithRevision2 = &v1alpha1.RAGEngine{
ObjectMeta: metav1.ObjectMeta{
Name: "testRAGEngine",
Namespace: "kaito",
Annotations: map[string]string{v1alpha1.RAGEngineRevisionAnnotation: "2"},
},
Spec: &v1alpha1.RAGEngineSpec{
Compute: &v1alpha1.ResourceSpec{
Count: &gpuNodeCount,
InstanceType: "Standard_NC12s_v3",
LabelSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{
"ragengine.kaito.io/name": "testRAGEngine",
},
},
},
Embedding: &v1alpha1.EmbeddingSpec{
Local: &v1alpha1.LocalEmbeddingSpec{
ModelID: "BAAI/bge-small-en-v1.5",
},
},
InferenceService: &v1alpha1.InferenceServiceSpec{
URL: "http://localhost:5000/chat",
},
},
}
)

var MockRAGEngineWithPresetHash = "14485768c1b67a529a71e3c87d9f2e6c1ed747534dea07e268e93475a5e21e"
Expand Down Expand Up @@ -620,6 +672,60 @@ var (
},
}
)
var MockRAGDeploymentUpdated = appsv1.Deployment{
ObjectMeta: metav1.ObjectMeta{
Name: "testRAGEngine",
Namespace: "kaito",
Annotations: map[string]string{v1alpha1.RAGEngineRevisionAnnotation: "1"},
},
Spec: appsv1.DeploymentSpec{
Replicas: &numRep,
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"app": "test-app",
},
},
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "test-container",
Image: "nginx:latest",
Ports: []corev1.ContainerPort{
{
ContainerPort: 80,
Protocol: corev1.ProtocolTCP,
},
},
Env: []corev1.EnvVar{
{
Name: "ENV_VAR_NAME",
Value: "ENV_VAR_VALUE",
},
},
VolumeMounts: []corev1.VolumeMount{
{
Name: "volume-name",
MountPath: "/mount/path",
},
},
},
},
Volumes: []corev1.Volume{
{
Name: "volume-name",
VolumeSource: corev1.VolumeSource{
EmptyDir: &corev1.EmptyDirVolumeSource{},
},
},
},
},
},
},
Status: appsv1.DeploymentStatus{
ReadyReplicas: 1,
},
}

var (
MockWorkspaceWithInferenceTemplate = &v1alpha1.Workspace{
Expand Down
Loading