Skip to content

Commit

Permalink
feat: Update controllerRevision and deployment lifecycle (#559)
Browse files Browse the repository at this point in the history
**Reason for Change**:
1. Change "workspace.kaito.io/revision" to revision number, and make the
hash result as "workspace.kaito.io/hash"
2. Sync contropllerrevision and workspace at the beginning of the
Reconcile
3. Update the tests for the changes

**Requirements**

- [ ] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:

Signed-off-by: Bangqi Zhu <bangqizhu@microsoft.com>
Co-authored-by: Bangqi Zhu <bangqizhu@microsoft.com>
  • Loading branch information
bangqipropel and Bangqi Zhu authored Aug 23, 2024
1 parent 8e8581e commit 1afb924
Show file tree
Hide file tree
Showing 8 changed files with 287 additions and 245 deletions.
3 changes: 3 additions & 0 deletions api/v1alpha1/workspace_labels.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,7 @@ const (

// LabelWorkspaceName is the label for workspace namespace.
LabelWorkspaceNamespace = KAITOPrefix + "workspacenamespace"

// WorkspaceRevisionAnnotation is the Annotations for revision number
WorkspaceRevisionAnnotation = "workspace.kaito.io/revision"
)
226 changes: 97 additions & 129 deletions pkg/controllers/workspace_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
"encoding/hex"
"encoding/json"
"fmt"
"reflect"
"sort"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -52,10 +52,10 @@ import (
)

const (
gpuSkuPrefix = "Standard_N"
nodePluginInstallTimeout = 60 * time.Second
WorkspaceRevisionAnnotation = "workspace.kaito.io/revision"
WorkspaceNameLabel = "workspace.kaito.io/name"
gpuSkuPrefix = "Standard_N"
nodePluginInstallTimeout = 60 * time.Second
WorkspaceHashAnnotation = "workspace.kaito.io/hash"
WorkspaceNameLabel = "workspace.kaito.io/name"
)

type WorkspaceReconciler struct {
Expand Down Expand Up @@ -85,6 +85,10 @@ func (c *WorkspaceReconciler) Reconcile(ctx context.Context, req reconcile.Reque

klog.InfoS("Reconciling", "workspace", req.NamespacedName)

if err := c.syncControllerRevision(ctx, workspaceObj); err != nil {
return reconcile.Result{}, err
}

if err := c.ensureFinalizer(ctx, workspaceObj); err != nil {
return reconcile.Result{}, err
}
Expand All @@ -105,11 +109,6 @@ func (c *WorkspaceReconciler) Reconcile(ctx context.Context, req reconcile.Reque
return result, err
}

if err := c.updateControllerRevision(ctx, workspaceObj); err != nil {
klog.ErrorS(err, "failed to update ControllerRevision", "workspace", klog.KObj(workspaceObj))
return reconcile.Result{}, nil
}

return result, nil
}

Expand Down Expand Up @@ -213,20 +212,14 @@ func (c *WorkspaceReconciler) deleteWorkspace(ctx context.Context, wObj *kaitov1

return c.garbageCollectWorkspace(ctx, wObj)
}

func (c *WorkspaceReconciler) updateControllerRevision(ctx context.Context, wObj *kaitov1alpha1.Workspace) error { // TODO: Move non-updateControllerRevision related logic to separate functions
func (c *WorkspaceReconciler) syncControllerRevision(ctx context.Context, wObj *kaitov1alpha1.Workspace) error {
currentHash := computeHash(wObj)
annotations := wObj.GetAnnotations()
if annotations == nil {
annotations = make(map[string]string)
} // nil checking.

latestHash, exists := annotations[WorkspaceRevisionAnnotation]
if exists && latestHash == currentHash {
return nil
}

jsonData, err := json.Marshal(wObj)
if err != nil {
return fmt.Errorf("failed to marshal revision data: %w", err)
}
revisionNum := int64(1)

revisions := &appsv1.ControllerRevisionList{}
if err := c.List(ctx, revisions, client.InNamespace(wObj.Namespace), client.MatchingLabels{WorkspaceNameLabel: wObj.Name}); err != nil {
Expand All @@ -236,36 +229,28 @@ func (c *WorkspaceReconciler) updateControllerRevision(ctx context.Context, wObj
return revisions.Items[i].Revision < revisions.Items[j].Revision
})

revisionNum := int64(1)
var latestRevision *appsv1.ControllerRevision
var previousWObj *kaitov1alpha1.Workspace

if len(revisions.Items) > 0 {
revisionNum = revisions.Items[len(revisions.Items)-1].Revision + 1
for i := range revisions.Items {
if revisions.Items[i].Annotations[WorkspaceRevisionAnnotation] == latestHash {
latestRevision = &revisions.Items[i]
break
}
}
if latestRevision != nil {
previousWObj = &kaitov1alpha1.Workspace{}
if err := json.Unmarshal(latestRevision.Data.Raw, previousWObj); err != nil {
return fmt.Errorf("failed to unmarshal previous workspace object: %w", err)
}
}
jsonData, err := marshalSelectedFields(wObj)
if err != nil {
return fmt.Errorf("failed to marshal revision data: %w", err)
}

if len(revisions.Items) > 0 {
latestRevision = &revisions.Items[len(revisions.Items)-1]

revisionNum = latestRevision.Revision + 1
}
newRevision := &appsv1.ControllerRevision{
ObjectMeta: metav1.ObjectMeta{
Name: fmt.Sprintf("%s-%s", wObj.Name, currentHash[:8]),
Name: fmt.Sprintf("%s-%s", wObj.Name, currentHash[:5]),
Namespace: wObj.Namespace,
Annotations: map[string]string{
WorkspaceHashAnnotation: currentHash,
},
Labels: map[string]string{
WorkspaceNameLabel: wObj.Name,
},
Annotations: map[string]string{
WorkspaceRevisionAnnotation: currentHash,
},
OwnerReferences: []metav1.OwnerReference{
*metav1.NewControllerRef(wObj, kaitov1alpha1.GroupVersion.WithKind("Workspace")),
},
Expand All @@ -274,75 +259,57 @@ func (c *WorkspaceReconciler) updateControllerRevision(ctx context.Context, wObj
Data: runtime.RawExtension{Raw: jsonData},
}

if annotations == nil {
annotations = make(map[string]string)
} // nil checking.

annotations[WorkspaceRevisionAnnotation] = currentHash
annotations[WorkspaceHashAnnotation] = currentHash
wObj.SetAnnotations(annotations)
deployment := &appsv1.Deployment{}
if wObj.Inference != nil {
if previousWObj == nil || !compareAdapters(previousWObj.Inference.Adapters, wObj.Inference.Adapters) {
if err := c.Get(ctx, types.NamespacedName{
Name: wObj.Name,
Namespace: wObj.Namespace,
}, deployment); err != nil {
if !errors.IsNotFound(err) {
klog.ErrorS(err, "failed to get deployment", "deplyment", wObj.Name)
}
return client.IgnoreNotFound(err)
}
if deployment.Annotations == nil {
deployment.Annotations = make(map[string]string)
controllerRevision := &appsv1.ControllerRevision{}
if err := c.Get(ctx, types.NamespacedName{
Name: newRevision.Name,
Namespace: newRevision.Namespace,
}, controllerRevision); err != nil {
if errors.IsNotFound(err) {

if err := c.Create(ctx, newRevision); err != nil {
return fmt.Errorf("failed to create new ControllerRevision: %w", err)
} else {
annotations[kaitov1alpha1.WorkspaceRevisionAnnotation] = strconv.FormatInt(revisionNum, 10)
}

if hash, exists := deployment.Annotations[WorkspaceRevisionAnnotation]; !exists || (hash != currentHash) {

var volumes []corev1.Volume
var volumeMounts []corev1.VolumeMount
shmVolume, shmVolumeMount := utils.ConfigSHMVolume(*wObj.Resource.Count)
if shmVolume.Name != "" {
volumes = append(volumes, shmVolume)
}
if shmVolumeMount.Name != "" {
volumeMounts = append(volumeMounts, shmVolumeMount)
}

if len(wObj.Inference.Adapters) > 0 {
adapterVolume, adapterVolumeMount := utils.ConfigAdapterVolume()
volumes = append(volumes, adapterVolume)
volumeMounts = append(volumeMounts, adapterVolumeMount)
}

initContainers, envs := resources.GenerateInitContainers(wObj, volumeMounts)
spec := &deployment.Spec
spec.Template.Spec.InitContainers = initContainers
spec.Template.Spec.Containers[0].Env = envs
spec.Template.Spec.Containers[0].VolumeMounts = volumeMounts
deployment.Annotations[WorkspaceRevisionAnnotation] = currentHash
spec.Template.Spec.Volumes = volumes

if err := c.Update(ctx, deployment); err != nil {
return fmt.Errorf("failed to update deployment: %w", err)
if len(revisions.Items) > consts.MaxRevisionHistoryLimit {
if err := c.Delete(ctx, &revisions.Items[0]); err != nil {
return fmt.Errorf("failed to delete old revision: %w", err)
}
}

} else {
return fmt.Errorf("failed to get controller revision: %w", err)
}
} else {
if controllerRevision.Annotations[WorkspaceHashAnnotation] != newRevision.Annotations[WorkspaceHashAnnotation] {
return fmt.Errorf("revision name conflicts, the hash values are different")
}
annotations[kaitov1alpha1.WorkspaceRevisionAnnotation] = strconv.FormatInt(controllerRevision.Revision, 10)
}
annotations[WorkspaceHashAnnotation] = currentHash
wObj.SetAnnotations(annotations)

if err := c.Update(ctx, wObj); err != nil {
return fmt.Errorf("failed to update Workspace annotations: %w", err)
}
return nil
}

if err := c.Create(ctx, newRevision); err != nil {
return fmt.Errorf("failed to create new ControllerRevision: %w", err)
func marshalSelectedFields(wObj *kaitov1alpha1.Workspace) ([]byte, error) {
partialMap := map[string]interface{}{
"resource": wObj.Resource,
"inference": wObj.Inference,
"tuning": wObj.Tuning,
}

if len(revisions.Items) > consts.MaxRevisionHistoryLimit {
if err := c.Delete(ctx, &revisions.Items[0]); err != nil {
return fmt.Errorf("failed to delete old revision: %w", err)
}
jsonData, err := json.Marshal(partialMap)
if err != nil {
return nil, fmt.Errorf("failed to marshal selected fields: %w", err)
}
return nil

return jsonData, nil
}

func computeHash(w *kaitov1alpha1.Workspace) string {
Expand All @@ -354,35 +321,6 @@ func computeHash(w *kaitov1alpha1.Workspace) string {
return hex.EncodeToString(hasher.Sum(nil))
}

func compareAdapters(oldAdapters, newAdapters []kaitov1alpha1.AdapterSpec) bool {
// If both slices are nil or empty, they are equal
if len(oldAdapters) == 0 && len(newAdapters) == 0 {
return true
}

// If only one of the slices is nil or empty, they are not equal
if len(oldAdapters) != len(newAdapters) {
return false
}

oldAdaptersMap := make(map[string]kaitov1alpha1.AdapterSpec)
for _, adapter := range oldAdapters {
key := fmt.Sprintf("%s-%s", adapter.Source.Name, *adapter.Strength)
oldAdaptersMap[key] = adapter
}

for _, adapter := range newAdapters {
key := fmt.Sprintf("%s-%s", adapter.Source.Name, *adapter.Strength)
oldAdapter, found := oldAdaptersMap[key]
if !found || !reflect.DeepEqual(oldAdapter, adapter) {
return false
}
delete(oldAdaptersMap, key)
}

return len(oldAdaptersMap) == 0
}

func (c *WorkspaceReconciler) selectWorkspaceNodes(qualified []*corev1.Node, preferred []string, previous []string, count int) []*corev1.Node {

sort.Slice(qualified, func(i, j int) bool {
Expand Down Expand Up @@ -822,25 +760,55 @@ func (c *WorkspaceReconciler) applyInference(ctx context.Context, wObj *kaitov1a

inferenceParam := model.GetInferenceParameters()

// TODO: we only do create if it does not exist for preset model. Need to document it.

var existingObj client.Object
if model.SupportDistributedInference() {
existingObj = &appsv1.StatefulSet{}
} else {
existingObj = &appsv1.Deployment{}

}

revisionStr := wObj.Annotations[kaitov1alpha1.WorkspaceRevisionAnnotation]
if err = resources.GetResource(ctx, wObj.Name, wObj.Namespace, c.Client, existingObj); err == nil {
klog.InfoS("An inference workload already exists for workspace", "workspace", klog.KObj(wObj))
if !model.SupportDistributedInference() {
deployment := existingObj.(*appsv1.Deployment)
if deployment.Annotations[kaitov1alpha1.WorkspaceRevisionAnnotation] != revisionStr {
var volumes []corev1.Volume
var volumeMounts []corev1.VolumeMount
shmVolume, shmVolumeMount := utils.ConfigSHMVolume(*wObj.Resource.Count)
if shmVolume.Name != "" {
volumes = append(volumes, shmVolume)
}
if shmVolumeMount.Name != "" {
volumeMounts = append(volumeMounts, shmVolumeMount)
}

if len(wObj.Inference.Adapters) > 0 {
adapterVolume, adapterVolumeMount := utils.ConfigAdapterVolume()
volumes = append(volumes, adapterVolume)
volumeMounts = append(volumeMounts, adapterVolumeMount)
}
initContainers, envs := resources.GenerateInitContainers(wObj, volumeMounts)
spec := &deployment.Spec

spec.Template.Spec.InitContainers = initContainers
spec.Template.Spec.Containers[0].Env = envs
spec.Template.Spec.Containers[0].VolumeMounts = volumeMounts
deployment.Annotations[kaitov1alpha1.WorkspaceRevisionAnnotation] = revisionStr
spec.Template.Spec.Volumes = volumes

if err := c.Update(ctx, deployment); err != nil {
return
}
}
}
if err = resources.CheckResourceStatus(existingObj, c.Client, inferenceParam.ReadinessTimeout); err != nil {
return
}
} else if apierrors.IsNotFound(err) {
var workloadObj client.Object
// Need to create a new workload
workloadObj, err = inference.CreatePresetInference(ctx, wObj, inferenceParam, model.SupportDistributedInference(), c.Client)
workloadObj, err = inference.CreatePresetInference(ctx, wObj, revisionStr, inferenceParam, model.SupportDistributedInference(), c.Client)
if err != nil {
return
}
Expand Down
Loading

0 comments on commit 1afb924

Please sign in to comment.