From 82451cbe9519e70fd26ba51de32a4234fcf8a6d8 Mon Sep 17 00:00:00 2001 From: jerryzhuang Date: Mon, 30 Dec 2024 13:42:33 +1100 Subject: [PATCH] fix: unstable testing order causing flaky test (#799) fix flaky tests. Signed-off-by: jerryzhuang --- api/v1alpha1/ragengine_validation_test.go | 3 +- api/v1alpha1/workspace_validation_test.go | 8 ++--- ..._workspace_qwen_2.5_coder_7b-instruct.yaml | 4 +-- pkg/ragengine/controllers/preset-rag_test.go | 4 +-- .../controllers/ragengine_controller_test.go | 19 +++++++---- pkg/utils/nodeclaim/nodeclaim_test.go | 32 ++++++++++++------- pkg/utils/resources/resources_test.go | 28 ++++++---------- pkg/utils/test/testUtils.go | 20 ------------ .../controllers/workspace_controller_test.go | 25 +++++++++++---- .../inference/preset-inferences_test.go | 4 +-- pkg/workspace/tuning/preset-tuning_test.go | 14 +++----- 11 files changed, 76 insertions(+), 85 deletions(-) diff --git a/api/v1alpha1/ragengine_validation_test.go b/api/v1alpha1/ragengine_validation_test.go index 37e101c64..611e5b722 100644 --- a/api/v1alpha1/ragengine_validation_test.go +++ b/api/v1alpha1/ragengine_validation_test.go @@ -4,7 +4,6 @@ package v1alpha1 import ( - "os" "strings" "testing" @@ -97,7 +96,7 @@ func TestRAGEngineValidateCreate(t *testing.T) { wantErr: false, }, } - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { err := tt.ragEngine.validateCreate() diff --git a/api/v1alpha1/workspace_validation_test.go b/api/v1alpha1/workspace_validation_test.go index a541ca816..8d73a544e 100644 --- a/api/v1alpha1/workspace_validation_test.go +++ b/api/v1alpha1/workspace_validation_test.go @@ -6,7 +6,6 @@ package v1alpha1 import ( "context" "fmt" - "os" "strings" "testing" @@ -365,7 +364,7 @@ func TestResourceSpecValidateCreate(t *testing.T) { }, } - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { @@ -1012,7 +1011,7 @@ func TestWorkspaceValidateName(t *testing.T) { }, } RegisterValidationTestModels() - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) tests := []struct { name string workspaceName string @@ -1132,8 +1131,7 @@ func TestWorkspaceValidateUpdate(t *testing.T) { func TestTuningSpecValidateCreate(t *testing.T) { RegisterValidationTestModels() // Set ReleaseNamespace Env - os.Setenv(consts.DefaultReleaseNamespaceEnvVar, DefaultReleaseNamespace) - defer os.Unsetenv(consts.DefaultReleaseNamespaceEnvVar) + t.Setenv(consts.DefaultReleaseNamespaceEnvVar, DefaultReleaseNamespace) // Create fake client with default ConfigMap scheme := runtime.NewScheme() diff --git a/examples/inference/kaito_workspace_qwen_2.5_coder_7b-instruct.yaml b/examples/inference/kaito_workspace_qwen_2.5_coder_7b-instruct.yaml index 15586cdae..c01f759a0 100644 --- a/examples/inference/kaito_workspace_qwen_2.5_coder_7b-instruct.yaml +++ b/examples/inference/kaito_workspace_qwen_2.5_coder_7b-instruct.yaml @@ -1,12 +1,12 @@ apiVersion: kaito.sh/v1alpha1 kind: Workspace metadata: - name: workspace-qwen-2.5-coder-7b-instruct + name: workspace-qwen-2-5-coder-7b-instruct resource: instanceType: "Standard_NC24ads_A100_v4" labelSelector: matchLabels: - apps: qwen-2.5-coder + apps: qwen-2-5-coder inference: preset: name: qwen2.5-coder-7b-instruct diff --git a/pkg/ragengine/controllers/preset-rag_test.go b/pkg/ragengine/controllers/preset-rag_test.go index eeee4a500..3d35c2501 100644 --- a/pkg/ragengine/controllers/preset-rag_test.go +++ b/pkg/ragengine/controllers/preset-rag_test.go @@ -4,7 +4,6 @@ package controllers import ( "context" - "os" "strings" "testing" @@ -37,7 +36,8 @@ func TestCreatePresetRAG(t *testing.T) { for k, tc := range testcases { t.Run(k, func(t *testing.T) { - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + mockClient := test.NewClient() tc.callMocks(mockClient) diff --git a/pkg/ragengine/controllers/ragengine_controller_test.go b/pkg/ragengine/controllers/ragengine_controller_test.go index 135056c67..a9d1f1aff 100644 --- a/pkg/ragengine/controllers/ragengine_controller_test.go +++ b/pkg/ragengine/controllers/ragengine_controller_test.go @@ -8,7 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "os" "testing" "time" @@ -270,6 +269,7 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) { test.RegisterTestModel() testcases := map[string]struct { callMocks func(c *test.MockClient) + cloudProvider string objectConditions apis.Conditions ragengine v1alpha1.RAGEngine karpenterFeatureGates bool @@ -281,6 +281,7 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) { c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha5.Machine{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Node{}), mock.Anything).Return(nil) }, + cloudProvider: consts.AzureCloudName, objectConditions: apis.Conditions{ { Type: apis.ConditionReady, @@ -297,8 +298,8 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) { c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Node{}), mock.Anything).Return(nil) - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) }, + cloudProvider: consts.AzureCloudName, objectConditions: apis.Conditions{ { Type: apis.ConditionReady, @@ -316,8 +317,8 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) { c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Node{}), mock.Anything).Return(nil) - os.Setenv("CLOUD_PROVIDER", "aws") }, + cloudProvider: consts.AWSCloudName, objectConditions: apis.Conditions{ { Type: apis.ConditionReady, @@ -348,6 +349,11 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) { } } + if tc.cloudProvider != "" { + t.Setenv("CLOUD_PROVIDER", tc.cloudProvider) + + } + tc.callMocks(mockClient) reconciler := &RAGEngineReconciler{ @@ -550,7 +556,6 @@ func TestApplyRAG(t *testing.T) { c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil) c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil) - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) }, ragengine: *test.MockRAGEngineWithRevision1, verifyCalls: func(c *test.MockClient) { @@ -568,7 +573,7 @@ func TestApplyRAG(t *testing.T) { c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything). Run(func(args mock.Arguments) { dep := args.Get(2).(*appsv1.Deployment) - *dep = test.MockRAGDeploymentUpdated + *dep = *test.MockRAGDeploymentUpdated.DeepCopy() }). Return(nil) @@ -592,7 +597,7 @@ func TestApplyRAG(t *testing.T) { c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything). Run(func(args mock.Arguments) { dep := args.Get(2).(*appsv1.Deployment) - *dep = test.MockRAGDeploymentUpdated + *dep = *test.MockRAGDeploymentUpdated.DeepCopy() }). Return(nil) @@ -615,6 +620,8 @@ func TestApplyRAG(t *testing.T) { for k, tc := range testcases { t.Run(k, func(t *testing.T) { + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + mockClient := test.NewClient() tc.callMocks(mockClient) diff --git a/pkg/utils/nodeclaim/nodeclaim_test.go b/pkg/utils/nodeclaim/nodeclaim_test.go index bf387f866..4f3443da7 100644 --- a/pkg/utils/nodeclaim/nodeclaim_test.go +++ b/pkg/utils/nodeclaim/nodeclaim_test.go @@ -5,7 +5,6 @@ package nodeclaim import ( "context" "errors" - "os" "testing" azurev1alpha2 "github.com/Azure/karpenter-provider-azure/pkg/apis/v1alpha2" @@ -77,7 +76,7 @@ func TestCreateNodeClaim(t *testing.T) { mockNodeClaim := &test.MockNodeClaim mockNodeClaim.Status.Conditions = tc.nodeClaimConditions - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) err := CreateNodeClaim(context.Background(), mockNodeClaim, mockClient) if tc.expectedError == nil { assert.Check(t, err == nil, "Not expected to return error") @@ -178,7 +177,8 @@ func TestWaitForPendingNodeClaims(t *testing.T) { func TestGenerateNodeClaimManifest(t *testing.T) { t.Run("Should generate a nodeClaim object from the given workspace when cloud provider set to azure", func(t *testing.T) { mockWorkspace := test.MockWorkspaceWithPreset - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + nodeClaim := GenerateNodeClaimManifest(context.Background(), "0", mockWorkspace) assert.Check(t, nodeClaim != nil, "NodeClaim must not be nil") @@ -196,7 +196,8 @@ func TestGenerateNodeClaimManifest(t *testing.T) { t.Run("Should generate a nodeClaim object from the given workspace when cloud provider set to aws", func(t *testing.T) { mockWorkspace := test.MockWorkspaceWithPreset - os.Setenv("CLOUD_PROVIDER", "aws") + t.Setenv("CLOUD_PROVIDER", consts.AWSCloudName) + nodeClaim := GenerateNodeClaimManifest(context.Background(), "0", mockWorkspace) assert.Check(t, nodeClaim != nil, "NodeClaim must not be nil") @@ -241,7 +242,8 @@ func TestGenerateAKSNodeClassManifest(t *testing.T) { func TestGenerateEC2NodeClassManifest(t *testing.T) { t.Run("Should generate a valid EC2NodeClass object with correct name and annotations", func(t *testing.T) { - os.Setenv("CLUSTER_NAME", "test-cluster") + t.Setenv("CLUSTER_NAME", "test-cluster") + nodeClass := GenerateEC2NodeClassManifest(context.Background()) assert.Check(t, nodeClass != nil, "EC2NodeClass must not be nil") @@ -252,7 +254,8 @@ func TestGenerateEC2NodeClassManifest(t *testing.T) { }) t.Run("Should generate a valid EC2NodeClass object with correct subnet and security group selectors", func(t *testing.T) { - os.Setenv("CLUSTER_NAME", "test-cluster") + t.Setenv("CLUSTER_NAME", "test-cluster") + nodeClass := GenerateEC2NodeClassManifest(context.Background()) assert.Check(t, nodeClass != nil, "EC2NodeClass must not be nil") @@ -261,7 +264,8 @@ func TestGenerateEC2NodeClassManifest(t *testing.T) { }) t.Run("Should handle missing CLUSTER_NAME environment variable", func(t *testing.T) { - os.Unsetenv("CLUSTER_NAME") + t.Setenv("CLUSTER_NAME", "") + nodeClass := GenerateEC2NodeClassManifest(context.Background()) assert.Check(t, nodeClass != nil, "EC2NodeClass must not be nil") @@ -273,7 +277,8 @@ func TestGenerateEC2NodeClassManifest(t *testing.T) { func TestCreateKarpenterNodeClass(t *testing.T) { t.Run("Should create AKSNodeClass when cloud provider is Azure", func(t *testing.T) { - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + mockClient := test.NewClient() mockClient.On("Create", mock.IsType(context.Background()), mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(nil) @@ -283,8 +288,9 @@ func TestCreateKarpenterNodeClass(t *testing.T) { }) t.Run("Should create EC2NodeClass when cloud provider is AWS", func(t *testing.T) { - os.Setenv("CLOUD_PROVIDER", consts.AWSCloudName) - os.Setenv("CLUSTER_NAME", "test-cluster") + t.Setenv("CLOUD_PROVIDER", consts.AWSCloudName) + t.Setenv("CLUSTER_NAME", "test-cluster") + mockClient := test.NewClient() mockClient.On("Create", mock.IsType(context.Background()), mock.IsType(&awsv1beta1.EC2NodeClass{}), mock.Anything).Return(nil) @@ -294,7 +300,8 @@ func TestCreateKarpenterNodeClass(t *testing.T) { }) t.Run("Should return error when cloud provider is unsupported", func(t *testing.T) { - os.Setenv("CLOUD_PROVIDER", "unsupported") + t.Setenv("CLOUD_PROVIDER", "unsupported") + mockClient := test.NewClient() err := CreateKarpenterNodeClass(context.Background(), mockClient) @@ -302,7 +309,8 @@ func TestCreateKarpenterNodeClass(t *testing.T) { }) t.Run("Should return error when Create call fails", func(t *testing.T) { - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + mockClient := test.NewClient() mockClient.On("Create", mock.IsType(context.Background()), mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(errors.New("create failed")) diff --git a/pkg/utils/resources/resources_test.go b/pkg/utils/resources/resources_test.go index 5de68a894..67e73bd74 100644 --- a/pkg/utils/resources/resources_test.go +++ b/pkg/utils/resources/resources_test.go @@ -5,7 +5,6 @@ package resources import ( "context" "errors" - "os" "testing" "time" @@ -275,15 +274,13 @@ func TestEnsureInferenceConfigMap(t *testing.T) { } testcases := map[string]struct { - setupEnv func() - callMocks func(c *test.MockClient) - userProvided client.ObjectKey - expectedError string + callMocks func(c *test.MockClient) + releaseNamespace string + userProvided client.ObjectKey + expectedError string }{ "Config already exists in workspace namespace": { - setupEnv: func() { - os.Setenv(consts.DefaultReleaseNamespaceEnvVar, "release-namespace") - }, + releaseNamespace: "release-namespace", callMocks: func(c *test.MockClient) { c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.ConfigMap{}), mock.Anything).Return(nil) }, @@ -303,9 +300,7 @@ func TestEnsureInferenceConfigMap(t *testing.T) { expectedError: "failed to get release namespace: failed to determine release namespace from file /var/run/secrets/kubernetes.io/serviceaccount/namespace and env var RELEASE_NAMESPACE", }, "Config doesn't exist in namespace": { - setupEnv: func() { - os.Setenv(consts.DefaultReleaseNamespaceEnvVar, "release-namespace") - }, + releaseNamespace: "release-namespace", callMocks: func(c *test.MockClient) { c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.ConfigMap{}), mock.Anything).Return(apierrors.NewNotFound(schema.GroupResource{}, "inference-config-template")) }, @@ -316,9 +311,7 @@ func TestEnsureInferenceConfigMap(t *testing.T) { expectedError: "user specified ConfigMap inference-config-template not found in namespace workspace-namespace", }, "Generate default config": { - setupEnv: func() { - os.Setenv(consts.DefaultReleaseNamespaceEnvVar, "release-namespace") - }, + releaseNamespace: "release-namespace", callMocks: func(c *test.MockClient) { c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.ConfigMap{}), mock.Anything). Return(apierrors.NewNotFound(schema.GroupResource{}, "inference-params-template")).Times(4) @@ -342,11 +335,8 @@ func TestEnsureInferenceConfigMap(t *testing.T) { for name, tc := range testcases { t.Run(name, func(t *testing.T) { - cleanupEnv := test.SaveEnv(consts.DefaultReleaseNamespaceEnvVar) - defer cleanupEnv() - - if tc.setupEnv != nil { - tc.setupEnv() + if tc.releaseNamespace != "" { + t.Setenv(consts.DefaultReleaseNamespaceEnvVar, tc.releaseNamespace) } mockClient := test.NewClient() diff --git a/pkg/utils/test/testUtils.go b/pkg/utils/test/testUtils.go index 980acca91..5ad728345 100644 --- a/pkg/utils/test/testUtils.go +++ b/pkg/utils/test/testUtils.go @@ -4,8 +4,6 @@ package test import ( - "os" - "github.com/aws/karpenter-core/pkg/apis/v1alpha5" "github.com/kaito-project/kaito/api/v1alpha1" "github.com/kaito-project/kaito/pkg/model" @@ -950,21 +948,3 @@ func NotFoundError() error { func IsAlreadyExistsError() error { return &apierrors.StatusError{ErrStatus: metav1.Status{Reason: metav1.StatusReasonAlreadyExists}} } - -// Saves state of current env, and returns function to restore to saved state -func SaveEnv(key string) func() { - envVal, envExists := os.LookupEnv(key) - return func() { - if envExists { - err := os.Setenv(key, envVal) - if err != nil { - return - } - } else { - err := os.Unsetenv(key) - if err != nil { - return - } - } - } -} diff --git a/pkg/workspace/controllers/workspace_controller_test.go b/pkg/workspace/controllers/workspace_controller_test.go index 9a8c93165..be4494e02 100644 --- a/pkg/workspace/controllers/workspace_controller_test.go +++ b/pkg/workspace/controllers/workspace_controller_test.go @@ -8,7 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "os" "reflect" "sort" "testing" @@ -329,6 +328,7 @@ func TestCreateAndValidateMachineNode(t *testing.T) { workspace v1alpha1.Workspace karpenterFeatureGates bool expectedError error + cloudProvider string }{ "Node is not created because machine creation fails": { callMocks: func(c *test.MockClient) { @@ -369,8 +369,8 @@ func TestCreateAndValidateMachineNode(t *testing.T) { c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Node{}), mock.Anything).Return(nil) - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) }, + cloudProvider: consts.AzureCloudName, objectConditions: apis.Conditions{ { Type: apis.ConditionReady, @@ -388,8 +388,8 @@ func TestCreateAndValidateMachineNode(t *testing.T) { c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Node{}), mock.Anything).Return(nil) - os.Setenv("CLOUD_PROVIDER", "aws") }, + cloudProvider: consts.AWSCloudName, objectConditions: apis.Conditions{ { Type: apis.ConditionReady, @@ -408,8 +408,8 @@ func TestCreateAndValidateMachineNode(t *testing.T) { c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil) c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil) - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) }, + cloudProvider: consts.AzureCloudName, objectConditions: apis.Conditions{ { Type: v1beta1.Launched, @@ -441,6 +441,11 @@ func TestCreateAndValidateMachineNode(t *testing.T) { } } + if tc.cloudProvider != "" { + t.Setenv("CLOUD_PROVIDER", tc.cloudProvider) + + } + tc.callMocks(mockClient) reconciler := &WorkspaceReconciler{ @@ -465,6 +470,7 @@ func TestCreateAndValidateNodeClaimNode(t *testing.T) { test.RegisterTestModel() testcases := map[string]struct { callMocks func(c *test.MockClient) + cloudProvider string karpenterFeatureGates bool nodeClaimConditions apis.Conditions workspace v1alpha1.Workspace @@ -478,8 +484,8 @@ func TestCreateAndValidateNodeClaimNode(t *testing.T) { c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil) c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.Workspace{}), mock.Anything).Return(nil) - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) }, + cloudProvider: consts.AzureCloudName, karpenterFeatureGates: true, nodeClaimConditions: apis.Conditions{ { @@ -499,6 +505,7 @@ func TestCreateAndValidateNodeClaimNode(t *testing.T) { c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil) c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Node{}), mock.Anything).Return(nil) }, + cloudProvider: consts.AzureCloudName, karpenterFeatureGates: true, nodeClaimConditions: apis.Conditions{ { @@ -522,6 +529,11 @@ func TestCreateAndValidateNodeClaimNode(t *testing.T) { mockClient.CreateOrUpdateObjectInMap(mockNodeClaim) } + if tc.cloudProvider != "" { + t.Setenv("CLOUD_PROVIDER", tc.cloudProvider) + + } + tc.callMocks(mockClient) featuregates.FeatureGates[consts.FeatureFlagKarpenter] = tc.karpenterFeatureGates @@ -693,7 +705,8 @@ func TestApplyInferenceWithPreset(t *testing.T) { } ctx := context.Background() - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + err := reconciler.applyInference(ctx, &tc.workspace) if tc.expectedError == nil { assert.Check(t, err == nil, fmt.Sprintf("Not expected to return error: %v", err)) diff --git a/pkg/workspace/inference/preset-inferences_test.go b/pkg/workspace/inference/preset-inferences_test.go index 308277dea..9b44f6713 100644 --- a/pkg/workspace/inference/preset-inferences_test.go +++ b/pkg/workspace/inference/preset-inferences_test.go @@ -5,7 +5,6 @@ package inference import ( "context" - "os" "reflect" "strings" "testing" @@ -126,7 +125,8 @@ func TestCreatePresetInference(t *testing.T) { for k, tc := range testcases { t.Run(k, func(t *testing.T) { - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + mockClient := test.NewClient() tc.callMocks(mockClient) diff --git a/pkg/workspace/tuning/preset-tuning_test.go b/pkg/workspace/tuning/preset-tuning_test.go index 5216ce399..8d7be2657 100644 --- a/pkg/workspace/tuning/preset-tuning_test.go +++ b/pkg/workspace/tuning/preset-tuning_test.go @@ -5,7 +5,6 @@ package tuning import ( "context" - "os" "strings" "testing" @@ -25,7 +24,7 @@ func normalize(s string) string { } func TestGetInstanceGPUCount(t *testing.T) { - os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) testcases := map[string]struct { sku string @@ -54,12 +53,6 @@ func TestGetInstanceGPUCount(t *testing.T) { } func TestGetTuningImageInfo(t *testing.T) { - // Setting up test environment - originalRegistryName := os.Getenv("PRESET_REGISTRY_NAME") - defer func() { - os.Setenv("PRESET_REGISTRY_NAME", originalRegistryName) // Reset after tests - }() - testcases := map[string]struct { registryName string wObj *kaitov1alpha1.Workspace @@ -102,7 +95,8 @@ func TestGetTuningImageInfo(t *testing.T) { for name, tc := range testcases { t.Run(name, func(t *testing.T) { - os.Setenv("PRESET_REGISTRY_NAME", tc.registryName) + t.Setenv("PRESET_REGISTRY_NAME", tc.registryName) + result, _ := GetTuningImageInfo(context.Background(), tc.wObj, tc.presetObj) assert.Equal(t, tc.expected, result) }) @@ -347,6 +341,8 @@ func TestPrepareTuningParameters(t *testing.T) { for name, tc := range testcases { t.Run(name, func(t *testing.T) { + t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName) + commands, resources := prepareTuningParameters(ctx, tc.workspaceObj, tc.modelCommand, tc.tuningObj, "2") assert.Equal(t, tc.expectedCommands, commands) assert.Equal(t, tc.expectedRequirements.Requests, resources.Requests)