Skip to content

Commit

Permalink
feat: Update karpenter nodeclass (#570)
Browse files Browse the repository at this point in the history
**Reason for Change**:
- Enable the workspace controller to manage Karpenter provider objects so it
can create the NodeClass.
- Remove the node affinity condition that applies to gpu-provisioner and
would not allow Karpenter.
- Add scheme registration for Karpenter CRDs.
- Clean up the NodeClass code.
- Add unit tests for NodeClass.
- Update the suite parameter to `node_provisioner`

**Requirements**

- [x] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

**Notes for Reviewers**:

---------

Signed-off-by: Heba Elayoty <hebaelayoty@gmail.com>
  • Loading branch information
helayoty authored Aug 20, 2024
1 parent 3db011e commit 9d42673
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 64 deletions.
28 changes: 14 additions & 14 deletions .github/workflows/e2e-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ on:
git_sha:
type: string
required: true
nodeprovisioner:
node_provisioner:
type: string
required: true
tag:
Expand Down Expand Up @@ -39,7 +39,7 @@ on:
jobs:
e2e-tests:
runs-on: ubuntu-latest
name: e2e-tests-${{ inputs.nodeprovisioner }}
name: e2e-tests-${{ inputs.node_provisioner }}
permissions:
contents: read
id-token: write # This is required for requesting the JWT
Expand Down Expand Up @@ -68,8 +68,8 @@ jobs:
fi
echo "VERSION=${rand}" >> $GITHUB_ENV
echo "CLUSTER_NAME=${{ inputs.nodeprovisioner }}${rand}" >> $GITHUB_ENV
echo "REGISTRY=${{ inputs.nodeprovisioner }}${rand}.azurecr.io" >> $GITHUB_ENV
echo "CLUSTER_NAME=${{ inputs.node_provisioner }}${rand}" >> $GITHUB_ENV
echo "REGISTRY=${{ inputs.node_provisioner }}${rand}.azurecr.io" >> $GITHUB_ENV
echo "RUN_LLAMA_13B=false" >> $GITHUB_ENV
- name: Set Registry
Expand Down Expand Up @@ -114,7 +114,7 @@ jobs:
uses: azure/CLI@v2.0.0
with:
inlineScript: |
az identity create --name ${{ inputs.nodeprovisioner }}Identity --resource-group ${{ env.CLUSTER_NAME }}
az identity create --name ${{ inputs.node_provisioner }}Identity --resource-group ${{ env.CLUSTER_NAME }}
- name: Generate APIs
run: |
Expand Down Expand Up @@ -146,7 +146,7 @@ jobs:
- name: create cluster
shell: bash
run: |
if [ "${{ inputs.nodeprovisioner }}" == "gpuprovisioner" ]; then
if [ "${{ inputs.node_provisioner }}" == "gpuprovisioner" ]; then
make create-aks-cluster
else
make create-aks-cluster-for-karpenter
Expand All @@ -165,18 +165,18 @@ jobs:
tenant-id: ${{ secrets.E2E_TENANT_ID }}
subscription-id: ${{ secrets.E2E_SUBSCRIPTION_ID }}

- name: Create Identities and Permissions for ${{ inputs.nodeprovisioner }}
- name: Create Identities and Permissions for ${{ inputs.node_provisioner }}
shell: bash
run: |
make generate-identities
env:
AZURE_RESOURCE_GROUP: ${{ env.CLUSTER_NAME }}
AZURE_CLUSTER_NAME: ${{ env.CLUSTER_NAME }}
TEST_SUITE: ${{ inputs.nodeprovisioner }}
TEST_SUITE: ${{ inputs.node_provisioner }}
AZURE_SUBSCRIPTION_ID: ${{ secrets.E2E_SUBSCRIPTION_ID }}

- name: Install gpu-provisioner helm chart
if: ${{ inputs.nodeprovisioner == 'gpuprovisioner' }}
if: ${{ inputs.node_provisioner == 'gpuprovisioner' }}
shell: bash
run: |
make gpu-provisioner-helm
Expand All @@ -188,7 +188,7 @@ jobs:
GPU_PROVISIONER_VERSION: ${{ vars.GPU_PROVISIONER_VERSION }}

- name: Install karpenter Azure provider helm chart
if: ${{ inputs.nodeprovisioner == 'azkarpenter' }}
if: ${{ inputs.node_provisioner == 'azkarpenter' }}
shell: bash
run: |
make azure-karpenter-helm
Expand Down Expand Up @@ -225,7 +225,7 @@ jobs:
AZURE_CLUSTER_NAME: ${{ env.CLUSTER_NAME }}
REGISTRY: ${{ env.REGISTRY }}
VERSION: ${{ env.VERSION }}
TEST_SUITE: ${{ inputs.nodeprovisioner }}
TEST_SUITE: ${{ inputs.node_provisioner }}

# Retrieve E2E ACR credentials and create Kubernetes secret
- name: Set up E2E ACR Credentials and Secret
Expand Down Expand Up @@ -255,9 +255,9 @@ jobs:
--docker-username=${{ secrets.E2E_ACR_AMRT_USERNAME }} \
--docker-password=${{ secrets.E2E_ACR_AMRT_PASSWORD }}
- name: Log ${{ inputs.nodeprovisioner }}
- name: Log ${{ inputs.node_provisioner }}
run: |
if [ "${{ inputs.nodeprovisioner }}" == "gpuprovisioner" ]; then
if [ "${{ inputs.node_provisioner }}" == "gpuprovisioner" ]; then
kubectl logs -n "${{ env.GPU_PROVISIONER_NAMESPACE }}" -l app.kubernetes.io/name=gpu-provisioner -c controller
else
kubectl logs -n "${{ env.KARPENTER_NAMESPACE }}" -l app.kubernetes.io/name=karpenter -c controller
Expand All @@ -276,7 +276,7 @@ jobs:
REGISTRY: ${{ env.REGISTRY }}
AI_MODELS_REGISTRY: ${{ secrets.E2E_ACR_AMRT_USERNAME }}.azurecr.io
AI_MODELS_REGISTRY_SECRET: ${{ secrets.E2E_AMRT_SECRET_NAME }}
TEST_SUITE: ${{ inputs.nodeprovisioner }}
TEST_SUITE: ${{ inputs.node_provisioner }}
E2E_ACR_REGISTRY: ${{ env.CLUSTER_NAME }}.azurecr.io
E2E_ACR_REGISTRY_SECRET: ${{ env.CLUSTER_NAME }}-acr-secret

Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/kaito-e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
suite: [ gpuprovisioner ]
node-provisioner: [ gpuprovisioner ]
permissions:
contents: read
id-token: write
Expand All @@ -29,7 +29,7 @@ jobs:
with:
git_sha: ${{ github.event.pull_request.head.sha }}
k8s_version: ${{ vars.AKS_K8S_VERSION }}
nodeprovisioner: ${{ matrix.suite }}
node_provisioner: ${{ matrix.node-provisioner }}
secrets:
E2E_CLIENT_ID: ${{ secrets.AZURE_CLIENT_ID }}
E2E_TENANT_ID: ${{ secrets.AZURE_TENANT_ID }}
Expand Down
3 changes: 3 additions & 0 deletions charts/kaito/workspace/templates/clusterrole.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ rules:
- apiGroups: ["karpenter.sh"]
resources: ["machines", "machines/status", "nodeclaims", "nodeclaims/status"]
verbs: ["get","list","watch","create", "delete", "update", "patch"]
- apiGroups: [ "karpenter.azure.com" ]
resources: [ "aksnodeclasses"]
verbs: [ "get","list","watch","create", "delete", "update", "patch" ]
- apiGroups: ["admissionregistration.k8s.io"]
resources: ["validatingwebhookconfigurations"]
verbs: ["get","list","watch"]
Expand Down
4 changes: 0 additions & 4 deletions charts/kaito/workspace/templates/nvidia-device-plugin-ds.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,6 @@ spec:
operator: NotIn
values:
- virtual-kubelet
- key: karpenter.sh/provisioner-name
operator: Exists
- key: kaito.sh/machine-type
operator: Exists
tolerations:
# Allow this pod to be rescheduled while the node is in "critical add-ons only" mode.
# This, along with the annotation above marks this pod as a critical add-on.
Expand Down
55 changes: 23 additions & 32 deletions cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import (
"syscall"
"time"

azurev1alpha2 "github.com/Azure/karpenter-provider-azure/pkg/apis/v1alpha2"
awsv1beta1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1"
"github.com/azure/kaito/pkg/featuregates"
"github.com/azure/kaito/pkg/k8sclient"
"github.com/azure/kaito/pkg/nodeclaim"
"github.com/azure/kaito/pkg/utils/consts"
"sigs.k8s.io/controller-runtime/pkg/client"
metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
"sigs.k8s.io/karpenter/pkg/apis/v1beta1"

Expand All @@ -24,7 +25,6 @@ import (
"github.com/azure/kaito/pkg/webhooks"
"k8s.io/klog/v2"
"knative.dev/pkg/injection/sharedmain"
"knative.dev/pkg/signals"
"knative.dev/pkg/webhook"

// Import all Kubernetes client auth plugins (e.g. Azure, GCP, OIDC, etc.)
Expand Down Expand Up @@ -59,10 +59,12 @@ var (

func init() {
utilruntime.Must(clientgoscheme.AddToScheme(scheme))

utilruntime.Must(kaitov1alpha1.AddToScheme(scheme))
utilruntime.Must(v1alpha5.SchemeBuilder.AddToScheme(scheme))
utilruntime.Must(v1beta1.SchemeBuilder.AddToScheme(scheme))
utilruntime.Must(azurev1alpha2.SchemeBuilder.AddToScheme(scheme))
utilruntime.Must(awsv1beta1.SchemeBuilder.AddToScheme(scheme))

//+kubebuilder:scaffold:scheme
klog.InitFlags(nil)
}
Expand All @@ -89,6 +91,8 @@ func main() {

ctrl.SetLogger(zap.New(zap.UseFlagOptions(&opts)))

ctx := withShutdownSignal(context.Background())

mgr, err := ctrl.NewManager(ctrl.GetConfigOrDie(), ctrl.Options{
Scheme: scheme,
Metrics: metricsserver.Options{
Expand Down Expand Up @@ -117,8 +121,12 @@ func main() {
k8sclient.SetGlobalClient(mgr.GetClient())
kClient := k8sclient.GetGlobalClient()

workspaceReconciler := controllers.NewWorkspaceReconciler(k8sclient.GetGlobalClient(),
mgr.GetScheme(), log.Log.WithName("controllers").WithName("Workspace"), mgr.GetEventRecorderFor("KAITO-Workspace-controller"))
workspaceReconciler := controllers.NewWorkspaceReconciler(
kClient,
mgr.GetScheme(),
log.Log.WithName("controllers").WithName("Workspace"),
mgr.GetEventRecorderFor("KAITO-Workspace-controller"),
)

if err = workspaceReconciler.SetupWithManager(mgr); err != nil {
klog.ErrorS(err, "unable to create controller", "controller", "Workspace")
Expand All @@ -142,7 +150,7 @@ func main() {
klog.ErrorS(err, "unable to parse the webhook port number")
exitWithErrorFunc()
}
ctx := webhook.WithOptions(signals.NewContext(), webhook.Options{
ctx := webhook.WithOptions(ctx, webhook.Options{
ServiceName: os.Getenv(WebhookServiceName),
Port: p,
SecretName: "workspace-webhook-cert",
Expand All @@ -153,9 +161,16 @@ func main() {

// wait 2 seconds to allow reconciling webhookconfiguration and service endpoint.
time.Sleep(2 * time.Second)
}

if err = featuregates.ParseAndValidateFeatureGates(featureGates); err != nil {
klog.ErrorS(err, "unable to set `feature-gates` flag")
if err := featuregates.ParseAndValidateFeatureGates(featureGates); err != nil {
klog.ErrorS(err, "unable to set `feature-gates` flag")
exitWithErrorFunc()
}

if featuregates.FeatureGates[consts.FeatureFlagKarpenter] {
err = nodeclaim.CheckNodeClass(ctx, kClient)
if err != nil {
exitWithErrorFunc()
}
}
Expand All @@ -165,20 +180,6 @@ func main() {
klog.ErrorS(err, "problem running manager")
exitWithErrorFunc()
}
ctx := withShutdownSignal(context.Background())

// check if Karpenter NodeClass is available. If not, the controller will create it automatically.
if featuregates.FeatureGates[consts.FeatureFlagKarpenter] {
cloud := GetCloudProviderName()
if !nodeclaim.IsNodeClassAvailable(ctx, cloud, kClient) {
klog.Infof("NodeClass is not available, creating NodeClass")
if err := nodeclaim.CreateKarpenterNodeClass(ctx, kClient); err != nil {
if client.IgnoreAlreadyExists(err) != nil {
exitWithErrorFunc()
}
}
}
}
}

// withShutdownSignal returns a copy of the parent context that will close if
Expand All @@ -196,13 +197,3 @@ func withShutdownSignal(ctx context.Context) context.Context {
}()
return nctx
}

// GetCloudProviderName returns the cloud provider name from the environment variable.
// If the environment variable is not set, the controller will exit with an error.
func GetCloudProviderName() string {
cloudProvider := os.Getenv("CLOUD_PROVIDER")
if cloudProvider == "" {
exitWithErrorFunc()
}
return cloudProvider
}
5 changes: 5 additions & 0 deletions pkg/controllers/workspace_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ func TestCreateAndValidateMachineNode(t *testing.T) {
},
"An Azure nodeClaim is successfully created": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
Expand All @@ -385,6 +386,7 @@ func TestCreateAndValidateMachineNode(t *testing.T) {
},
"An AWS nodeClaim is successfully created": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&awsv1beta1.EC2NodeClass{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&awsv1beta1.EC2NodeClass{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
Expand All @@ -403,6 +405,7 @@ func TestCreateAndValidateMachineNode(t *testing.T) {
},
"Node is not created because nodeClaim creation fails": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
Expand Down Expand Up @@ -472,6 +475,7 @@ func TestCreateAndValidateNodeClaimNode(t *testing.T) {
}{
"Node is not created because nodeClaim creation fails": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
Expand All @@ -492,6 +496,7 @@ func TestCreateAndValidateNodeClaimNode(t *testing.T) {
},
"A nodeClaim is successfully created": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(nil)
c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
Expand Down
35 changes: 27 additions & 8 deletions pkg/nodeclaim/nodeclaim.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"os"
"time"

azurev1alpha2 "github.com/Azure/karpenter-provider-azure/pkg/apis/v1alpha2"
awsv1beta1 "github.com/aws/karpenter-provider-aws/pkg/apis/v1beta1"
kaitov1alpha1 "github.com/azure/kaito/api/v1alpha1"
"github.com/azure/kaito/pkg/featuregates"
"github.com/azure/kaito/pkg/utils/consts"
"github.com/samber/lo"
v1 "k8s.io/api/core/v1"
Expand Down Expand Up @@ -190,13 +190,12 @@ func CreateNodeClaim(ctx context.Context, nodeClaimObj *v1beta1.NodeClaim, kubeC
return retry.OnError(retry.DefaultBackoff, func(err error) bool {
return err.Error() != ErrorInstanceTypesUnavailable
}, func() error {
if featuregates.FeatureGates[consts.FeatureFlagKarpenter] {
err := CreateKarpenterNodeClass(ctx, kubeClient)
if err != nil {
return err
}
err := CheckNodeClass(ctx, kubeClient)
if err != nil {
return err
}
err := kubeClient.Create(ctx, nodeClaimObj, &client.CreateOptions{})

err = kubeClient.Create(ctx, nodeClaimObj, &client.CreateOptions{})
if err != nil {
return err
}
Expand All @@ -221,13 +220,16 @@ func CreateNodeClaim(ctx context.Context, nodeClaimObj *v1beta1.NodeClaim, kubeC
// CreateKarpenterNodeClass creates a nodeClass object for Karpenter.
func CreateKarpenterNodeClass(ctx context.Context, kubeClient client.Client) error {
cloudName := os.Getenv("CLOUD_PROVIDER")
klog.InfoS("CreateKarpenterNodeClass", "cloudName", cloudName)

if cloudName == consts.AzureCloudName {
nodeClassObj := GenerateAKSNodeClassManifest(ctx)
return kubeClient.Create(ctx, nodeClassObj, &client.CreateOptions{})
} else { //aws
} else if cloudName == consts.AWSCloudName {
nodeClassObj := GenerateEC2NodeClassManifest(ctx)
return kubeClient.Create(ctx, nodeClassObj, &client.CreateOptions{})
} else {
return errors.New("unsupported cloud provider " + cloudName)
}
}

Expand Down Expand Up @@ -334,3 +336,20 @@ func IsNodeClassAvailable(ctx context.Context, cloudName string, kubeClient clie
klog.Error("unsupported cloud provider ", cloudName)
return false
}

// CheckNodeClass checks if Karpenter NodeClass is available. If not, the controller will create it automatically.
// This is only applicable when Karpenter feature flag is enabled.
func CheckNodeClass(ctx context.Context, kClient client.Client) error {
cloudProvider := os.Getenv("CLOUD_PROVIDER")
if cloudProvider == "" {
return errors.New("CLOUD_PROVIDER environment variable cannot be empty")
}
if !IsNodeClassAvailable(ctx, cloudProvider, kClient) {
klog.Infof("NodeClass is not available, creating NodeClass")
if err := CreateKarpenterNodeClass(ctx, kClient); err != nil && client.IgnoreAlreadyExists(err) != nil {
klog.ErrorS(err, "unable to create NodeClass")
return errors.New("error while creating NodeClass")
}
}
return nil
}
Loading

0 comments on commit 9d42673

Please sign in to comment.