Skip to content

Commit

Permalink
fix: unstable testing order causing flaky test (#799)
Browse files Browse the repository at this point in the history
fix flaky tests.

Signed-off-by: jerryzhuang <zhuangqhc@gmail.com>
  • Loading branch information
zhuangqh authored Dec 30, 2024
1 parent 2baa9a4 commit 82451cb
Show file tree
Hide file tree
Showing 11 changed files with 76 additions and 85 deletions.
3 changes: 1 addition & 2 deletions api/v1alpha1/ragengine_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package v1alpha1

import (
"os"
"strings"
"testing"

Expand Down Expand Up @@ -97,7 +96,7 @@ func TestRAGEngineValidateCreate(t *testing.T) {
wantErr: false,
},
}
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.ragEngine.validateCreate()
Expand Down
8 changes: 3 additions & 5 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package v1alpha1
import (
"context"
"fmt"
"os"
"strings"
"testing"

Expand Down Expand Up @@ -365,7 +364,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
},
}

os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
Expand Down Expand Up @@ -1012,7 +1011,7 @@ func TestWorkspaceValidateName(t *testing.T) {
},
}
RegisterValidationTestModels()
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
tests := []struct {
name string
workspaceName string
Expand Down Expand Up @@ -1132,8 +1131,7 @@ func TestWorkspaceValidateUpdate(t *testing.T) {
func TestTuningSpecValidateCreate(t *testing.T) {
RegisterValidationTestModels()
// Set ReleaseNamespace Env
os.Setenv(consts.DefaultReleaseNamespaceEnvVar, DefaultReleaseNamespace)
defer os.Unsetenv(consts.DefaultReleaseNamespaceEnvVar)
t.Setenv(consts.DefaultReleaseNamespaceEnvVar, DefaultReleaseNamespace)

// Create fake client with default ConfigMap
scheme := runtime.NewScheme()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
name: workspace-qwen-2.5-coder-7b-instruct
name: workspace-qwen-2-5-coder-7b-instruct
resource:
instanceType: "Standard_NC24ads_A100_v4"
labelSelector:
matchLabels:
apps: qwen-2.5-coder
apps: qwen-2-5-coder
inference:
preset:
name: qwen2.5-coder-7b-instruct
4 changes: 2 additions & 2 deletions pkg/ragengine/controllers/preset-rag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package controllers

import (
"context"
"os"
"strings"
"testing"

Expand Down Expand Up @@ -37,7 +36,8 @@ func TestCreatePresetRAG(t *testing.T) {

for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)

mockClient := test.NewClient()
tc.callMocks(mockClient)

Expand Down
19 changes: 13 additions & 6 deletions pkg/ragengine/controllers/ragengine_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"encoding/json"
"errors"
"fmt"
"os"
"testing"
"time"

Expand Down Expand Up @@ -270,6 +269,7 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) {
test.RegisterTestModel()
testcases := map[string]struct {
callMocks func(c *test.MockClient)
cloudProvider string
objectConditions apis.Conditions
ragengine v1alpha1.RAGEngine
karpenterFeatureGates bool
Expand All @@ -281,6 +281,7 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha5.Machine{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Node{}), mock.Anything).Return(nil)
},
cloudProvider: consts.AzureCloudName,
objectConditions: apis.Conditions{
{
Type: apis.ConditionReady,
Expand All @@ -297,8 +298,8 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) {
c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Node{}), mock.Anything).Return(nil)
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
},
cloudProvider: consts.AzureCloudName,
objectConditions: apis.Conditions{
{
Type: apis.ConditionReady,
Expand All @@ -316,8 +317,8 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) {
c.On("Create", mock.IsType(context.Background()), mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1beta1.NodeClaim{}), mock.Anything).Return(nil)
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Node{}), mock.Anything).Return(nil)
os.Setenv("CLOUD_PROVIDER", "aws")
},
cloudProvider: consts.AWSCloudName,
objectConditions: apis.Conditions{
{
Type: apis.ConditionReady,
Expand Down Expand Up @@ -348,6 +349,11 @@ func TestCreateAndValidateMachineNodeforRAGEngine(t *testing.T) {
}
}

if tc.cloudProvider != "" {
t.Setenv("CLOUD_PROVIDER", tc.cloudProvider)

}

tc.callMocks(mockClient)

reconciler := &RAGEngineReconciler{
Expand Down Expand Up @@ -550,7 +556,6 @@ func TestApplyRAG(t *testing.T) {

c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
c.StatusMock.On("Update", mock.IsType(context.Background()), mock.IsType(&v1alpha1.RAGEngine{}), mock.Anything).Return(nil)
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
},
ragengine: *test.MockRAGEngineWithRevision1,
verifyCalls: func(c *test.MockClient) {
Expand All @@ -568,7 +573,7 @@ func TestApplyRAG(t *testing.T) {
c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).
Run(func(args mock.Arguments) {
dep := args.Get(2).(*appsv1.Deployment)
*dep = test.MockRAGDeploymentUpdated
*dep = *test.MockRAGDeploymentUpdated.DeepCopy()
}).
Return(nil)

Expand All @@ -592,7 +597,7 @@ func TestApplyRAG(t *testing.T) {
c.On("Get", mock.Anything, mock.Anything, mock.IsType(&appsv1.Deployment{}), mock.Anything).
Run(func(args mock.Arguments) {
dep := args.Get(2).(*appsv1.Deployment)
*dep = test.MockRAGDeploymentUpdated
*dep = *test.MockRAGDeploymentUpdated.DeepCopy()
}).
Return(nil)

Expand All @@ -615,6 +620,8 @@ func TestApplyRAG(t *testing.T) {

for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)

mockClient := test.NewClient()
tc.callMocks(mockClient)

Expand Down
32 changes: 20 additions & 12 deletions pkg/utils/nodeclaim/nodeclaim_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package nodeclaim
import (
"context"
"errors"
"os"
"testing"

azurev1alpha2 "github.com/Azure/karpenter-provider-azure/pkg/apis/v1alpha2"
Expand Down Expand Up @@ -77,7 +76,7 @@ func TestCreateNodeClaim(t *testing.T) {

mockNodeClaim := &test.MockNodeClaim
mockNodeClaim.Status.Conditions = tc.nodeClaimConditions
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
err := CreateNodeClaim(context.Background(), mockNodeClaim, mockClient)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
Expand Down Expand Up @@ -178,7 +177,8 @@ func TestWaitForPendingNodeClaims(t *testing.T) {
func TestGenerateNodeClaimManifest(t *testing.T) {
t.Run("Should generate a nodeClaim object from the given workspace when cloud provider set to azure", func(t *testing.T) {
mockWorkspace := test.MockWorkspaceWithPreset
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)

nodeClaim := GenerateNodeClaimManifest(context.Background(), "0", mockWorkspace)

assert.Check(t, nodeClaim != nil, "NodeClaim must not be nil")
Expand All @@ -196,7 +196,8 @@ func TestGenerateNodeClaimManifest(t *testing.T) {

t.Run("Should generate a nodeClaim object from the given workspace when cloud provider set to aws", func(t *testing.T) {
mockWorkspace := test.MockWorkspaceWithPreset
os.Setenv("CLOUD_PROVIDER", "aws")
t.Setenv("CLOUD_PROVIDER", consts.AWSCloudName)

nodeClaim := GenerateNodeClaimManifest(context.Background(), "0", mockWorkspace)

assert.Check(t, nodeClaim != nil, "NodeClaim must not be nil")
Expand Down Expand Up @@ -241,7 +242,8 @@ func TestGenerateAKSNodeClassManifest(t *testing.T) {

func TestGenerateEC2NodeClassManifest(t *testing.T) {
t.Run("Should generate a valid EC2NodeClass object with correct name and annotations", func(t *testing.T) {
os.Setenv("CLUSTER_NAME", "test-cluster")
t.Setenv("CLUSTER_NAME", "test-cluster")

nodeClass := GenerateEC2NodeClassManifest(context.Background())

assert.Check(t, nodeClass != nil, "EC2NodeClass must not be nil")
Expand All @@ -252,7 +254,8 @@ func TestGenerateEC2NodeClassManifest(t *testing.T) {
})

t.Run("Should generate a valid EC2NodeClass object with correct subnet and security group selectors", func(t *testing.T) {
os.Setenv("CLUSTER_NAME", "test-cluster")
t.Setenv("CLUSTER_NAME", "test-cluster")

nodeClass := GenerateEC2NodeClassManifest(context.Background())

assert.Check(t, nodeClass != nil, "EC2NodeClass must not be nil")
Expand All @@ -261,7 +264,8 @@ func TestGenerateEC2NodeClassManifest(t *testing.T) {
})

t.Run("Should handle missing CLUSTER_NAME environment variable", func(t *testing.T) {
os.Unsetenv("CLUSTER_NAME")
t.Setenv("CLUSTER_NAME", "")

nodeClass := GenerateEC2NodeClassManifest(context.Background())

assert.Check(t, nodeClass != nil, "EC2NodeClass must not be nil")
Expand All @@ -273,7 +277,8 @@ func TestGenerateEC2NodeClassManifest(t *testing.T) {

func TestCreateKarpenterNodeClass(t *testing.T) {
t.Run("Should create AKSNodeClass when cloud provider is Azure", func(t *testing.T) {
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)

mockClient := test.NewClient()
mockClient.On("Create", mock.IsType(context.Background()), mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(nil)

Expand All @@ -283,8 +288,9 @@ func TestCreateKarpenterNodeClass(t *testing.T) {
})

t.Run("Should create EC2NodeClass when cloud provider is AWS", func(t *testing.T) {
os.Setenv("CLOUD_PROVIDER", consts.AWSCloudName)
os.Setenv("CLUSTER_NAME", "test-cluster")
t.Setenv("CLOUD_PROVIDER", consts.AWSCloudName)
t.Setenv("CLUSTER_NAME", "test-cluster")

mockClient := test.NewClient()
mockClient.On("Create", mock.IsType(context.Background()), mock.IsType(&awsv1beta1.EC2NodeClass{}), mock.Anything).Return(nil)

Expand All @@ -294,15 +300,17 @@ func TestCreateKarpenterNodeClass(t *testing.T) {
})

t.Run("Should return error when cloud provider is unsupported", func(t *testing.T) {
os.Setenv("CLOUD_PROVIDER", "unsupported")
t.Setenv("CLOUD_PROVIDER", "unsupported")

mockClient := test.NewClient()

err := CreateKarpenterNodeClass(context.Background(), mockClient)
assert.Error(t, err, "unsupported cloud provider unsupported")
})

t.Run("Should return error when Create call fails", func(t *testing.T) {
os.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)
t.Setenv("CLOUD_PROVIDER", consts.AzureCloudName)

mockClient := test.NewClient()
mockClient.On("Create", mock.IsType(context.Background()), mock.IsType(&azurev1alpha2.AKSNodeClass{}), mock.Anything).Return(errors.New("create failed"))

Expand Down
28 changes: 9 additions & 19 deletions pkg/utils/resources/resources_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package resources
import (
"context"
"errors"
"os"
"testing"
"time"

Expand Down Expand Up @@ -275,15 +274,13 @@ func TestEnsureInferenceConfigMap(t *testing.T) {
}

testcases := map[string]struct {
setupEnv func()
callMocks func(c *test.MockClient)
userProvided client.ObjectKey
expectedError string
callMocks func(c *test.MockClient)
releaseNamespace string
userProvided client.ObjectKey
expectedError string
}{
"Config already exists in workspace namespace": {
setupEnv: func() {
os.Setenv(consts.DefaultReleaseNamespaceEnvVar, "release-namespace")
},
releaseNamespace: "release-namespace",
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.ConfigMap{}), mock.Anything).Return(nil)
},
Expand All @@ -303,9 +300,7 @@ func TestEnsureInferenceConfigMap(t *testing.T) {
expectedError: "failed to get release namespace: failed to determine release namespace from file /var/run/secrets/kubernetes.io/serviceaccount/namespace and env var RELEASE_NAMESPACE",
},
"Config doesn't exist in namespace": {
setupEnv: func() {
os.Setenv(consts.DefaultReleaseNamespaceEnvVar, "release-namespace")
},
releaseNamespace: "release-namespace",
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.ConfigMap{}), mock.Anything).Return(apierrors.NewNotFound(schema.GroupResource{}, "inference-config-template"))
},
Expand All @@ -316,9 +311,7 @@ func TestEnsureInferenceConfigMap(t *testing.T) {
expectedError: "user specified ConfigMap inference-config-template not found in namespace workspace-namespace",
},
"Generate default config": {
setupEnv: func() {
os.Setenv(consts.DefaultReleaseNamespaceEnvVar, "release-namespace")
},
releaseNamespace: "release-namespace",
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.ConfigMap{}), mock.Anything).
Return(apierrors.NewNotFound(schema.GroupResource{}, "inference-params-template")).Times(4)
Expand All @@ -342,11 +335,8 @@ func TestEnsureInferenceConfigMap(t *testing.T) {

for name, tc := range testcases {
t.Run(name, func(t *testing.T) {
cleanupEnv := test.SaveEnv(consts.DefaultReleaseNamespaceEnvVar)
defer cleanupEnv()

if tc.setupEnv != nil {
tc.setupEnv()
if tc.releaseNamespace != "" {
t.Setenv(consts.DefaultReleaseNamespaceEnvVar, tc.releaseNamespace)
}

mockClient := test.NewClient()
Expand Down
20 changes: 0 additions & 20 deletions pkg/utils/test/testUtils.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
package test

import (
"os"

"github.com/aws/karpenter-core/pkg/apis/v1alpha5"
"github.com/kaito-project/kaito/api/v1alpha1"
"github.com/kaito-project/kaito/pkg/model"
Expand Down Expand Up @@ -950,21 +948,3 @@ func NotFoundError() error {
func IsAlreadyExistsError() error {
return &apierrors.StatusError{ErrStatus: metav1.Status{Reason: metav1.StatusReasonAlreadyExists}}
}

// Saves state of current env, and returns function to restore to saved state
func SaveEnv(key string) func() {
envVal, envExists := os.LookupEnv(key)
return func() {
if envExists {
err := os.Setenv(key, envVal)
if err != nil {
return
}
} else {
err := os.Unsetenv(key)
if err != nil {
return
}
}
}
}
Loading

0 comments on commit 82451cb

Please sign in to comment.