Skip to content

Commit

Permalink
feat: Add controllerrevision for workspaceController (#524)
Browse files Browse the repository at this point in the history
**Reason for Change**:
Add a ControllerRevision for the workspace controller to track workspace revisions.
Add unit tests for the ControllerRevision logic.

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:

---------

Signed-off-by: Bangqi Zhu <bangqizhu@microsoft.com>
Co-authored-by: Bangqi Zhu <bangqizhu@microsoft.com>
  • Loading branch information
bangqipropel and Bangqi Zhu authored Jul 22, 2024
1 parent 4d5cd2d commit 2ccc93d
Show file tree
Hide file tree
Showing 8 changed files with 342 additions and 13 deletions.
3 changes: 3 additions & 0 deletions charts/kaito/workspace/templates/clusterrole.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ rules:
- apiGroups: [ "apps" ]
resources: ["deployments" ]
verbs: ["get","list","watch","create", "delete","update", "patch"]
- apiGroups: [ "apps" ]
resources: ["controllerrevisions" ]
verbs: [ "get","list","watch","create", "delete","update", "patch"]
- apiGroups: [ "apps" ]
resources: [ "statefulsets" ]
verbs: [ "get","list","watch","create", "delete","update", "patch" ]
Expand Down
11 changes: 5 additions & 6 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"knative.dev/pkg/injection/sharedmain"
"knative.dev/pkg/signals"
"knative.dev/pkg/webhook"

// Import all Kubernetes client auth plugins (e.g. Azure, GCP, OIDC, etc.)
// to ensure that exec-entrypoint and run can make use of them.
_ "k8s.io/client-go/plugin/pkg/client/auth"
Expand Down Expand Up @@ -116,12 +117,10 @@ func main() {
k8sclient.SetGlobalClient(mgr.GetClient())
kClient := k8sclient.GetGlobalClient()

if err = (&controllers.WorkspaceReconciler{
Client: kClient,
Log: log.Log.WithName("controllers").WithName("Workspace"),
Scheme: mgr.GetScheme(),
Recorder: mgr.GetEventRecorderFor("KAITO-Workspace-controller"),
}).SetupWithManager(mgr); err != nil {
workspaceReconciler := controllers.NewWorkspaceReconciler(k8sclient.GetGlobalClient(),
mgr.GetScheme(), log.Log.WithName("controllers").WithName("Workspace"), mgr.GetEventRecorderFor("KAITO-Workspace-controller"))

if err = workspaceReconciler.SetupWithManager(mgr); err != nil {
klog.ErrorS(err, "unable to create controller", "controller", "Workspace")
exitWithErrorFunc()
}
Expand Down
105 changes: 102 additions & 3 deletions pkg/controllers/workspace_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ package controllers

import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"sigs.k8s.io/controller-runtime/pkg/controller/controllerutil"
"sort"
Expand Down Expand Up @@ -46,8 +49,10 @@ import (
)

const (
gpuSkuPrefix = "Standard_N"
nodePluginInstallTimeout = 60 * time.Second
gpuSkuPrefix = "Standard_N"
nodePluginInstallTimeout = 60 * time.Second
WorkspaceRevisionAnnotation = "workspace.kaito.io/revision"
WorkspaceNameLabel = "workspace.kaito.io/name"
)

type WorkspaceReconciler struct {
Expand All @@ -57,6 +62,15 @@ type WorkspaceReconciler struct {
Recorder record.EventRecorder
}

// NewWorkspaceReconciler builds a WorkspaceReconciler from the given client,
// scheme, logger and event recorder and returns a pointer to it.
func NewWorkspaceReconciler(client client.Client, scheme *runtime.Scheme, log logr.Logger, Recorder record.EventRecorder) *WorkspaceReconciler {
	reconciler := new(WorkspaceReconciler)
	reconciler.Client = client
	reconciler.Scheme = scheme
	reconciler.Log = log
	reconciler.Recorder = Recorder
	return reconciler
}

func (c *WorkspaceReconciler) Reconcile(ctx context.Context, req reconcile.Request) (reconcile.Result, error) {
workspaceObj := &kaitov1alpha1.Workspace{}
if err := c.Client.Get(ctx, req.NamespacedName, workspaceObj); err != nil {
Expand All @@ -83,7 +97,17 @@ func (c *WorkspaceReconciler) Reconcile(ctx context.Context, req reconcile.Reque
}
}

return c.addOrUpdateWorkspace(ctx, workspaceObj)
result, err := c.addOrUpdateWorkspace(ctx, workspaceObj)
if err != nil {
return result, err
}

if err := c.updateControllerRevision(ctx, workspaceObj); err != nil {
klog.ErrorS(err, "failed to update ControllerRevision", "workspace", klog.KObj(workspaceObj))
return reconcile.Result{}, nil
}

return result, nil
}

func (c *WorkspaceReconciler) ensureFinalizer(ctx context.Context, workspaceObj *kaitov1alpha1.Workspace) error {
Expand Down Expand Up @@ -180,6 +204,80 @@ func (c *WorkspaceReconciler) deleteWorkspace(ctx context.Context, wObj *kaitov1
return c.garbageCollectWorkspace(ctx, wObj)
}

// updateControllerRevision records a new ControllerRevision for wObj whenever
// the hash returned by computeHash differs from the value stored under the
// WorkspaceRevisionAnnotation annotation of the workspace.
//
// When the annotation already holds the current hash the function returns nil
// without issuing any API call. Otherwise it:
//  1. lists the revisions in the workspace namespace labelled with the
//     workspace name,
//  2. writes the new hash into the workspace annotations and updates the
//     workspace,
//  3. creates a ControllerRevision whose Revision is one above the highest
//     listed revision (or 1 when none exist),
//  4. deletes the revision with the lowest Revision number when the number of
//     listed revisions exceeded consts.MaxRevisionHistoryLimit.
//
// Any failing API call is returned wrapped with a short description.
func (c *WorkspaceReconciler) updateControllerRevision(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
	currentHash := computeHash(wObj)
	annotations := wObj.GetAnnotations()

	// Reading from a nil map is safe, so the lookup works even when the
	// workspace carries no annotations at all.
	latestHash, exists := annotations[WorkspaceRevisionAnnotation]
	if exists && latestHash == currentHash {
		return nil
	}

	data := map[string]string{"hash": currentHash}
	jsonData, err := json.Marshal(data)
	if err != nil {
		return fmt.Errorf("failed to marshal revision data: %w", err)
	}

	revisions := &appsv1.ControllerRevisionList{}
	if err := c.List(ctx, revisions, client.InNamespace(wObj.Namespace), client.MatchingLabels{WorkspaceNameLabel: wObj.Name}); err != nil {
		return fmt.Errorf("failed to list revisions: %w", err)
	}
	// Order ascending by revision number: the last item is the newest and the
	// first item is the oldest.
	sort.Slice(revisions.Items, func(i, j int) bool {
		return revisions.Items[i].Revision < revisions.Items[j].Revision
	})

	revisionNum := int64(1)
	if len(revisions.Items) > 0 {
		revisionNum = revisions.Items[len(revisions.Items)-1].Revision + 1
	}

	newRevision := &appsv1.ControllerRevision{
		ObjectMeta: metav1.ObjectMeta{
			// The name embeds the first 8 hex characters of the hash.
			Name:      fmt.Sprintf("%s-%s", wObj.Name, currentHash[:8]),
			Namespace: wObj.Namespace,
			Labels: map[string]string{
				WorkspaceNameLabel: wObj.Name,
			},
			Annotations: map[string]string{
				WorkspaceRevisionAnnotation: currentHash,
			},
			OwnerReferences: []metav1.OwnerReference{
				*metav1.NewControllerRef(wObj, kaitov1alpha1.GroupVersion.WithKind("Workspace")),
			},
		},
		Revision: revisionNum,
		Data:     runtime.RawExtension{Raw: jsonData},
	}

	// A workspace without annotations yields a nil map; assigning into a nil
	// map panics, so allocate one before storing the hash.
	if annotations == nil {
		annotations = make(map[string]string, 1)
	}
	annotations[WorkspaceRevisionAnnotation] = currentHash
	wObj.SetAnnotations(annotations)
	// NOTE(review): the annotation is persisted before the revision is
	// created. If Create below fails, a later call sees a matching hash and
	// returns early, so the revision for this hash is not retried.
	if err := c.Update(ctx, wObj); err != nil {
		return fmt.Errorf("failed to update Workspace annotations: %w", err)
	}

	if err := c.Create(ctx, newRevision); err != nil {
		return fmt.Errorf("failed to create new ControllerRevision: %w", err)
	}

	// Trim history: at most one (the oldest) revision is removed per call.
	if len(revisions.Items) > consts.MaxRevisionHistoryLimit {
		if err := c.Delete(ctx, &revisions.Items[0]); err != nil {
			return fmt.Errorf("failed to delete old revision: %w", err)
		}
	}
	return nil
}

// computeHash returns the hex-encoded SHA-256 digest of the workspace's
// Resource, Inference and Tuning fields, JSON-encoded in that order and
// streamed directly into the hasher.
func computeHash(w *kaitov1alpha1.Workspace) string {
	digest := sha256.New()
	enc := json.NewEncoder(digest)
	for _, section := range []interface{}{w.Resource, w.Inference, w.Tuning} {
		// The function has no error return, so an encoding failure is
		// discarded here exactly as before.
		_ = enc.Encode(section)
	}
	return hex.EncodeToString(digest.Sum(nil))
}

func (c *WorkspaceReconciler) selectWorkspaceNodes(qualified []*corev1.Node, preferred []string, previous []string, count int) []*corev1.Node {

sort.Slice(qualified, func(i, j int) bool {
Expand Down Expand Up @@ -679,6 +777,7 @@ func (c *WorkspaceReconciler) SetupWithManager(mgr ctrl.Manager) error {

builder := ctrl.NewControllerManagedBy(mgr).
For(&kaitov1alpha1.Workspace{}).
Owns(&appsv1.ControllerRevision{}).
Owns(&appsv1.Deployment{}).
Owns(&appsv1.StatefulSet{}).
Owns(&batchv1.Job{}).
Expand Down
110 changes: 110 additions & 0 deletions pkg/controllers/workspace_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package controllers
import (
"context"
"errors"
"fmt"
"os"
"reflect"
"sort"
Expand Down Expand Up @@ -954,3 +955,112 @@ func TestApplyWorkspaceResource(t *testing.T) {
})
}
}

// TestUpdateControllerRevision exercises updateControllerRevision against a
// mock client. Each case registers the mock expectations it needs, runs the
// method on a fixture workspace, compares the returned error with the expected
// one and finally asserts how many times List, Create, Update and Delete were
// invoked on the mock.
func TestUpdateControllerRevision(t *testing.T) {
	testcases := map[string]struct {
		callMocks     func(c *test.MockClient)
		workspace     v1alpha1.Workspace
		expectedError error
		verifyCalls   func(c *test.MockClient)
	}{
		"No new revision needed": {
			callMocks: func(c *test.MockClient) {
				// Registered with the same three matchers as the other cases so
				// that an unexpected List call surfaces this error.
				c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(errors.New("should not be called"))
			},
			workspace:     test.MockWorkspaceWithComputeHash,
			expectedError: nil,
			verifyCalls: func(c *test.MockClient) {
				c.AssertNumberOfCalls(t, "List", 0)
				c.AssertNumberOfCalls(t, "Create", 0)
				c.AssertNumberOfCalls(t, "Update", 0)
				c.AssertNumberOfCalls(t, "Delete", 0)
			},
		},
		"Fail to create ControllerRevision": {
			callMocks: func(c *test.MockClient) {
				c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil)
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(errors.New("failed to create ControllerRevision"))
				c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil)
			},
			workspace:     test.MockWorkspaceFailToCreateCR,
			expectedError: errors.New("failed to create new ControllerRevision: failed to create ControllerRevision"),
			verifyCalls: func(c *test.MockClient) {
				c.AssertNumberOfCalls(t, "List", 1)
				c.AssertNumberOfCalls(t, "Create", 1)
				c.AssertNumberOfCalls(t, "Update", 1)
				c.AssertNumberOfCalls(t, "Delete", 0)
			},
		},
		"Successfully create new ControllerRevision": {
			callMocks: func(c *test.MockClient) {
				c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil)
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
				c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil)
			},
			workspace:     test.MockWorkspaceSuccessful,
			expectedError: nil,
			verifyCalls: func(c *test.MockClient) {
				c.AssertNumberOfCalls(t, "List", 1)
				c.AssertNumberOfCalls(t, "Create", 1)
				c.AssertNumberOfCalls(t, "Update", 1)
				c.AssertNumberOfCalls(t, "Delete", 0)
			},
		},
		"Successfully delete old ControllerRevision": {
			callMocks: func(c *test.MockClient) {
				// Seed one more revision than the history limit so that the
				// oldest one has to be deleted.
				revisions := &appsv1.ControllerRevisionList{}
				for i := 0; i <= consts.MaxRevisionHistoryLimit; i++ {
					revision := &appsv1.ControllerRevision{
						ObjectMeta: v1.ObjectMeta{
							Name: fmt.Sprintf("revision-%d", i),
						},
						Revision: int64(i),
					}
					revisions.Items = append(revisions.Items, *revision)
				}
				relevantMap := c.CreateMapWithType(revisions)

				for _, obj := range revisions.Items {
					// Copy the loop variable so every map entry points at its
					// own revision.
					m := obj
					objKey := client.ObjectKeyFromObject(&m)
					relevantMap[objKey] = &m
				}

				c.On("List", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevisionList{}), mock.Anything).Return(nil)
				c.On("Create", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
				c.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil)
				c.On("Delete", mock.IsType(context.Background()), mock.IsType(&appsv1.ControllerRevision{}), mock.Anything).Return(nil)
			},
			workspace:     test.MockWorkspaceWithDeleteOldCR,
			expectedError: nil,
			verifyCalls: func(c *test.MockClient) {
				c.AssertNumberOfCalls(t, "List", 1)
				c.AssertNumberOfCalls(t, "Create", 1)
				c.AssertNumberOfCalls(t, "Update", 1)
				c.AssertNumberOfCalls(t, "Delete", 1)
			},
		},
	}
	for k, tc := range testcases {
		t.Run(k, func(t *testing.T) {
			mockClient := test.NewClient()
			tc.callMocks(mockClient)

			reconciler := &WorkspaceReconciler{
				Client: mockClient,
				Scheme: test.NewTestScheme(),
			}
			ctx := context.Background()

			err := reconciler.updateControllerRevision(ctx, &tc.workspace)
			switch {
			case tc.expectedError == nil:
				assert.Check(t, err == nil, "Not expected to return error")
			case err == nil:
				// Report a missing error as a test failure instead of
				// dereferencing a nil error below.
				t.Errorf("expected error %q, got nil", tc.expectedError.Error())
			default:
				assert.Equal(t, tc.expectedError.Error(), err.Error())
			}
			if tc.verifyCalls != nil {
				tc.verifyCalls(mockClient)
			}
		})
	}
}
1 change: 1 addition & 0 deletions pkg/utils/consts/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ const (
AWSCloudName = "aws"
GPUString = "gpu"
SKUString = "sku"
MaxRevisionHistoryLimit = 10
)
10 changes: 9 additions & 1 deletion pkg/utils/test/mockClient.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
"github.com/stretchr/testify/mock"
appsv1 "k8s.io/api/apps/v1"
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/meta"
"k8s.io/apimachinery/pkg/runtime"
Expand Down Expand Up @@ -83,7 +84,6 @@ func (m *MockClient) Get(ctx context.Context, key types.NamespacedName, obj k8sC
}

func (m *MockClient) List(ctx context.Context, list k8sClient.ObjectList, opts ...k8sClient.ListOption) error {

v := reflect.ValueOf(list).Elem()
newList := m.getObjectListFromMap(list)
v.Set(reflect.ValueOf(newList).Elem())
Expand Down Expand Up @@ -121,6 +121,14 @@ func (m *MockClient) getObjectListFromMap(list k8sClient.ObjectList) k8sClient.O
}
}
return nodeClaimList
case *appsv1.ControllerRevisionList:
controllerRevisionList := &appsv1.ControllerRevisionList{}
for _, obj := range relevantMap {
if m, ok := obj.(*appsv1.ControllerRevision); ok {
controllerRevisionList.Items = append(controllerRevisionList.Items, *m)
}
}
return controllerRevisionList
}
//add additional object lists as needed
return nil
Expand Down
Loading

0 comments on commit 2ccc93d

Please sign in to comment.