Skip to content

Commit

Permalink
feat: Update deployment object when adapter changes are detected (#540)
Browse files Browse the repository at this point in the history
**Reason for Change**:
1. Update deployment object when adapter changes are detected
2. Include logic for comparing adapters 
3. Add unit test for deployment update

---------

Signed-off-by: Bangqi Zhu <bangqizhu@microsoft.com>
Co-authored-by: Bangqi Zhu <bangqizhu@microsoft.com>
  • Loading branch information
bangqipropel and Bangqi Zhu authored Jul 30, 2024
1 parent b6f13ed commit c54af47
Show file tree
Hide file tree
Showing 4 changed files with 380 additions and 22 deletions.
82 changes: 79 additions & 3 deletions pkg/controllers/workspace_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"reflect"
"sort"
"strings"
"time"

"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"

"github.com/azure/kaito/pkg/featuregates"
"github.com/azure/kaito/pkg/nodeclaim"
"github.com/azure/kaito/pkg/tuning"
Expand All @@ -37,6 +39,7 @@ import (
apierrors "k8s.io/apimachinery/pkg/api/errors"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
"k8s.io/utils/clock"
Expand Down Expand Up @@ -213,8 +216,7 @@ func (c *WorkspaceReconciler) updateControllerRevision(ctx context.Context, wObj
return nil
}

data := map[string]string{"hash": currentHash}
jsonData, err := json.Marshal(data)
jsonData, err := json.Marshal(wObj)
if err != nil {
return fmt.Errorf("failed to marshal revision data: %w", err)
}
Expand All @@ -228,9 +230,23 @@ func (c *WorkspaceReconciler) updateControllerRevision(ctx context.Context, wObj
})

revisionNum := int64(1)
var latestRevision *appsv1.ControllerRevision
var previousWObj *kaitov1alpha1.Workspace

if len(revisions.Items) > 0 {
revisionNum = revisions.Items[len(revisions.Items)-1].Revision + 1
for i := range revisions.Items {
if revisions.Items[i].Annotations[WorkspaceRevisionAnnotation] == latestHash {
latestRevision = &revisions.Items[i]
break
}
}
if latestRevision != nil {
previousWObj = &kaitov1alpha1.Workspace{}
if err := json.Unmarshal(latestRevision.Data.Raw, previousWObj); err != nil {
return fmt.Errorf("failed to unmarshal previous workspace object: %w", err)
}
}
}

newRevision := &appsv1.ControllerRevision{
Expand All @@ -253,6 +269,37 @@ func (c *WorkspaceReconciler) updateControllerRevision(ctx context.Context, wObj

annotations[WorkspaceRevisionAnnotation] = currentHash
wObj.SetAnnotations(annotations)
deployment := &appsv1.Deployment{}
if wObj.Inference != nil {
if previousWObj == nil || !compareAdapters(previousWObj.Inference.Adapters, wObj.Inference.Adapters) {
if err := c.Get(ctx, types.NamespacedName{
Name: wObj.Name,
Namespace: wObj.Namespace,
}, deployment); err != nil {
if !errors.IsNotFound(err) {
klog.ErrorS(err, "failed to get deployment", "deplyment", wObj.Name)
}
return client.IgnoreNotFound(err)
}
if deployment.Annotations == nil {
deployment.Annotations = make(map[string]string)
}

if hash, exists := deployment.Annotations[WorkspaceRevisionAnnotation]; !exists || (hash != currentHash) {

initContainers, envs := resources.GenerateInitContainers(wObj)
spec := &deployment.Spec
spec.Template.Spec.InitContainers = initContainers
spec.Template.Spec.Containers[0].Env = envs
deployment.Annotations[WorkspaceRevisionAnnotation] = currentHash

if err := c.Update(ctx, deployment); err != nil {
return fmt.Errorf("failed to update deployment: %w", err)
}
}

}
}
if err := c.Update(ctx, wObj); err != nil {
return fmt.Errorf("failed to update Workspace annotations: %w", err)
}
Expand All @@ -278,6 +325,35 @@ func computeHash(w *kaitov1alpha1.Workspace) string {
return hex.EncodeToString(hasher.Sum(nil))
}

func compareAdapters(oldAdapters, newAdapters []kaitov1alpha1.AdapterSpec) bool {
// If both slices are nil or empty, they are equal
if len(oldAdapters) == 0 && len(newAdapters) == 0 {
return true
}

// If only one of the slices is nil or empty, they are not equal
if len(oldAdapters) != len(newAdapters) {
return false
}

oldAdaptersMap := make(map[string]kaitov1alpha1.AdapterSpec)
for _, adapter := range oldAdapters {
key := fmt.Sprintf("%s-%s", adapter.Source.Name, *adapter.Strength)
oldAdaptersMap[key] = adapter
}

for _, adapter := range newAdapters {
key := fmt.Sprintf("%s-%s", adapter.Source.Name, *adapter.Strength)
oldAdapter, found := oldAdaptersMap[key]
if !found || !reflect.DeepEqual(oldAdapter, adapter) {
return false
}
delete(oldAdaptersMap, key)
}

return len(oldAdaptersMap) == 0
}

func (c *WorkspaceReconciler) selectWorkspaceNodes(qualified []*corev1.Node, preferred []string, previous []string, count int) []*corev1.Node {

sort.Slice(qualified, func(i, j int) bool {
Expand Down
113 changes: 110 additions & 3 deletions pkg/controllers/workspace_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package controllers

import (
"context"
"encoding/json"
"errors"
"fmt"
"os"
Expand All @@ -17,6 +18,7 @@ import (
"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
awsv1beta1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1"
"github.com/azure/kaito/api/v1alpha1"
kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/featuregates"
"github.com/azure/kaito/pkg/machine"
"github.com/azure/kaito/pkg/nodeclaim"
Expand All @@ -27,6 +29,7 @@ import (
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"knative.dev/pkg/apis"
"sigs.k8s.io/controller-runtime/pkg/client"
Expand Down Expand Up @@ -981,13 +984,20 @@ func TestUpdateControllerRevision(t *testing.T) {
c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(errors.New("failed to create ControllerRevision"))
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).
Run(func(args mock.Arguments) {
dep := args.Get(2).(*appsv1.Deployment)
*dep = test.MockDeploymentWithAnnotationsAndContainer1
}).
Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
},
workspace: test.MockWorkspaceFailToCreateCR,
expectedError: errors.New("failed to create new ControllerRevision: failed to create ControllerRevision"),
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Update", 1)
c.AssertNumberOfCalls(t, "Update", 2)
c.AssertNumberOfCalls(t, "Delete", 0)
},
},
Expand All @@ -996,13 +1006,20 @@ func TestUpdateControllerRevision(t *testing.T) {
c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).
Run(func(args mock.Arguments) {
dep := args.Get(2).(*appsv1.Deployment)
*dep = test.MockDeploymentWithAnnotationsAndContainer2
}).
Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
},
workspace: test.MockWorkspaceSuccessful,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Update", 1)
c.AssertNumberOfCalls(t, "Update", 2)
c.AssertNumberOfCalls(t, "Delete", 0)
},
},
Expand All @@ -1029,17 +1046,66 @@ func TestUpdateControllerRevision(t *testing.T) {
c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).
Run(func(args mock.Arguments) {
dep := args.Get(2).(*appsv1.Deployment)
*dep = test.MockDeploymentWithAnnotationsAndContainer2
}).
Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
c.On("Delete", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
},
workspace: test.MockWorkspaceWithDeleteOldCR,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Update", 1)
c.AssertNumberOfCalls(t, "Update", 2)
c.AssertNumberOfCalls(t, "Delete", 1)
},
},
"Deployment updated when adapters change": {
callMocks: func(c *test.MockClient) {
revisions := &appsv1.ControllerRevisionList{}
jsonData, _ := json.Marshal(test.MockWorkspaceWithComputeHash)
revision := &appsv1.ControllerRevision{
ObjectMeta: v1.ObjectMeta{
Name: "revision-1",
Annotations: test.MockWorkspaceWithComputeHash.Annotations,
},
Revision: int64(1),
Data: runtime.RawExtension{Raw: jsonData},
}
revisions.Items = append(revisions.Items, *revision)

relevantMap := c.CreateMapWithType(revisions)

for _, obj := range revisions.Items {
m := obj
objKey := client.ObjectKeyFromObject(&m)
relevantMap[objKey] = &m
}

c.CreateOrUpdateObjectInMap(&test.MockDeploymentUpdated)

c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil)

c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)

c.On("Update", mock.IsType(context.Background()), mock.IsType(&appsv1.Deployment{}), mock.Anything).Return(nil)
},
workspace: test.MockWorkspaceWithUpdatedDeployment,
expectedError: nil,
verifyCalls: func(c *test.MockClient) {
c.AssertNumberOfCalls(t, "List", 1)
c.AssertNumberOfCalls(t, "Create", 1)
c.AssertNumberOfCalls(t, "Update", 2) // one for workspace and one for deployment
c.AssertNumberOfCalls(t, "Get", 1)
c.AssertNumberOfCalls(t, "Delete", 0)
},
},
}
for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
Expand All @@ -1064,3 +1130,44 @@ func TestUpdateControllerRevision(t *testing.T) {
})
}
}

func TestCompareAdapters(t *testing.T) {
testcases := map[string]struct {
oldAdapters []kaitov1alpha1.AdapterSpec
newAdapters []kaitov1alpha1.AdapterSpec
expectedResult bool
}{
"Both slices are empty": {
oldAdapters: []kaitov1alpha1.AdapterSpec{},
newAdapters: []kaitov1alpha1.AdapterSpec{},
expectedResult: true,
},
"One slice is empty": {
oldAdapters: []kaitov1alpha1.AdapterSpec{},
newAdapters: test.Adapters1,
expectedResult: false,
},
"Different lengths": {
oldAdapters: test.Adapters1,
newAdapters: test.Adapters2,
expectedResult: false,
},
"Different contents": {
oldAdapters: test.Adapters2,
newAdapters: test.Adapters4,
expectedResult: false,
},
"Same length and contents": {
oldAdapters: test.Adapters2,
newAdapters: test.Adapters3,
expectedResult: true,
},
}

for name, tc := range testcases {
t.Run(name, func(t *testing.T) {
result := compareAdapters(tc.oldAdapters, tc.newAdapters)
assert.Equal(t, tc.expectedResult, result, "Expected result did not match actual result")
})
}
}
41 changes: 25 additions & 16 deletions pkg/resources/manifests.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,22 +275,7 @@ func GenerateDeploymentManifest(ctx context.Context, workspaceObj *kaitov1alpha1
initContainers := []corev1.Container{}
envs := []corev1.EnvVar{}
if len(workspaceObj.Inference.Adapters) > 0 {
for _, adapter := range workspaceObj.Inference.Adapters {
// TODO: accept Volumes and url link to pull images
initContainer := corev1.Container{
Name: adapter.Source.Name,
Image: adapter.Source.Image,
Command: []string{"/bin/sh", "-c", fmt.Sprintf("mkdir -p /mnt/adapter/%s && cp -r /data/* /mnt/adapter/%s", adapter.Source.Name, adapter.Source.Name)},
VolumeMounts: volumeMount,
ImagePullPolicy: corev1.PullAlways,
}
initContainers = append(initContainers, initContainer)
env := corev1.EnvVar{
Name: adapter.Source.Name,
Value: *adapter.Strength,
}
envs = append(envs, env)
}
initContainers, envs = GenerateInitContainers(workspaceObj)
}

return &appsv1.Deployment{
Expand Down Expand Up @@ -349,6 +334,30 @@ func GenerateDeploymentManifest(ctx context.Context, workspaceObj *kaitov1alpha1
}
}

func GenerateInitContainers(wObj *kaitov1alpha1.Workspace) ([]corev1.Container, []corev1.EnvVar) {
initContainers := []corev1.Container{}
envs := []corev1.EnvVar{}
if len(wObj.Inference.Adapters) > 0 {
for _, adapter := range wObj.Inference.Adapters {
initContainer := corev1.Container{
Name: adapter.Source.Name,
Image: adapter.Source.Image,
Command: []string{"/bin/sh", "-c", fmt.Sprintf("mkdir -p /mnt/adapter/%s && cp -r /data/* /mnt/adapter/%s", adapter.Source.Name, adapter.Source.Name)},
VolumeMounts: []corev1.VolumeMount{},
ImagePullPolicy: corev1.PullAlways,
}
initContainers = append(initContainers, initContainer)
env := corev1.EnvVar{
Name: adapter.Source.Name,
Value: *adapter.Strength,
}
envs = append(envs, env)
}

}
return initContainers, envs
}

func GenerateDeploymentManifestWithPodTemplate(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace, tolerations []corev1.Toleration) *appsv1.Deployment {
nodeRequirements := make([]corev1.NodeSelectorRequirement, 0, len(workspaceObj.Resource.LabelSelector.MatchLabels))
for key, value := range workspaceObj.Resource.LabelSelector.MatchLabels {
Expand Down
Loading

0 comments on commit c54af47

Please sign in to comment.