diff --git a/pkg/controller.v1/mpi/mpijob_controller_test.go b/pkg/controller.v1/mpi/mpijob_controller_test.go index bf867deb6d..4efa2d3710 100644 --- a/pkg/controller.v1/mpi/mpijob_controller_test.go +++ b/pkg/controller.v1/mpi/mpijob_controller_test.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "strings" - "time" common "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" . "github.com/onsi/ginkgo/v2" @@ -31,6 +30,7 @@ import ( ctrl "sigs.k8s.io/controller-runtime" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/util/testutil" ) const ( @@ -129,12 +129,6 @@ func newMPIJobWithLauncher(name string, replicas *int32, pusPerReplica int64, re } var _ = Describe("MPIJob controller", func() { - // Define utility constants for object names and testing timeouts/durations and intervals. - const ( - timeout = 10 * time.Second - interval = 1000 * time.Millisecond - ) - Context("Test launcher is GPU launcher", func() { It("Should pass GPU Launcher verification", func() { By("By creating MPIJobs with various resource configuration") @@ -194,7 +188,7 @@ var _ = Describe("MPIJob controller", func() { } launcherCreated.Status.Phase = corev1.PodSucceeded return testK8sClient.Status().Update(ctx, launcherCreated) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) created := &kubeflowv1.MPIJob{} launcherStatus := &common.ReplicaStatus{ @@ -208,7 +202,7 @@ var _ = Describe("MPIJob controller", func() { return false } return ReplicaStatusMatch(created.Status.ReplicaStatuses, kubeflowv1.MPIJobReplicaTypeLauncher, launcherStatus) - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) }) }) @@ -236,7 +230,7 @@ var _ = Describe("MPIJob controller", func() { } launcherCreated.Status.Phase = corev1.PodFailed return testK8sClient.Status().Update(ctx, launcherCreated) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) launcherStatus := &common.ReplicaStatus{ Active: 0, @@ -250,7 +244,7 @@ var _ = Describe("MPIJob controller", func() { return false } return ReplicaStatusMatch(created.Status.ReplicaStatuses, kubeflowv1.MPIJobReplicaTypeLauncher, launcherStatus) - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) }) }) @@ -280,7 +274,7 @@ var _ = Describe("MPIJob controller", func() { } launcherCreated.Status.Phase = corev1.PodSucceeded return testK8sClient.Status().Update(ctx, launcherCreated) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) created := &kubeflowv1.MPIJob{} launcherStatus := &common.ReplicaStatus{ @@ -294,7 +288,7 @@ var _ = Describe("MPIJob controller", func() { return false } return ReplicaStatusMatch(created.Status.ReplicaStatuses, kubeflowv1.MPIJobReplicaTypeWorker, launcherStatus) - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) }) }) @@ -324,7 +318,7 @@ var _ = Describe("MPIJob controller", func() { } launcherCreated.Status.Phase = corev1.PodRunning return testK8sClient.Status().Update(ctx, launcherCreated) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) for i := 0; i < int(replicas); i++ { name := fmt.Sprintf("%s-%d", mpiJob.Name+workerSuffix, i) @@ -340,7 +334,7 @@ var _ = Describe("MPIJob controller", func() { } workerCreated.Status.Phase = corev1.PodPending return testK8sClient.Status().Update(ctx, workerCreated) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) } key := types.NamespacedName{ @@ -366,7 +360,7 @@ var _ = Describe("MPIJob controller", func() { return ReplicaStatusMatch(created.Status.ReplicaStatuses, kubeflowv1.MPIJobReplicaTypeLauncher, launcherStatus) && ReplicaStatusMatch(created.Status.ReplicaStatuses, kubeflowv1.MPIJobReplicaTypeWorker, workerStatus) - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) }) }) @@ -396,7 +390,7 @@ var _ = Describe("MPIJob controller", func() { } launcherCreated.Status.Phase = corev1.PodRunning return testK8sClient.Status().Update(ctx, launcherCreated) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) for i := 0; i < int(replicas); i++ { name := fmt.Sprintf("%s-%d", mpiJob.Name+workerSuffix, i) @@ -412,7 +406,7 @@ var _ = Describe("MPIJob controller", func() { } workerCreated.Status.Phase = corev1.PodRunning return testK8sClient.Status().Update(ctx, workerCreated) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) } key := types.NamespacedName{ @@ -438,7 +432,7 @@ var _ = Describe("MPIJob controller", func() { return ReplicaStatusMatch(created.Status.ReplicaStatuses, kubeflowv1.MPIJobReplicaTypeLauncher, launcherStatus) && ReplicaStatusMatch(created.Status.ReplicaStatuses, kubeflowv1.MPIJobReplicaTypeWorker, workerStatus) - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) }) }) @@ -470,7 +464,7 @@ var _ = Describe("MPIJob controller", func() { } workerCreated.Status.Phase = corev1.PodRunning return testK8sClient.Status().Update(ctx, workerCreated) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) } launcherKey := types.NamespacedName{ @@ -481,7 +475,7 @@ var _ = Describe("MPIJob controller", func() { Eventually(func() bool { err := testK8sClient.Get(ctx, launcherKey, launcher) return err != nil - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) key := types.NamespacedName{ Namespace: metav1.NamespaceDefault, @@ -506,7 +500,7 @@ var _ = Describe("MPIJob controller", func() { return ReplicaStatusMatch(created.Status.ReplicaStatuses, kubeflowv1.MPIJobReplicaTypeLauncher, launcherStatus) && ReplicaStatusMatch(created.Status.ReplicaStatuses, kubeflowv1.MPIJobReplicaTypeWorker, workerStatus) - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) }) }) @@ -552,7 +546,7 @@ var _ = Describe("MPIJob controller", func() { Eventually(func() error { _, err := reconciler.Reconcile(ctx, req) return err - }, timeout, interval).Should(MatchError(expectedErr)) + }, testutil.Timeout, testutil.Interval).Should(MatchError(expectedErr)) }) }) @@ -585,7 +579,7 @@ var _ = Describe("MPIJob controller", func() { Eventually(func() error { _, err := reconciler.Reconcile(ctx, req) return err - }, timeout, interval).Should(MatchError(expectedErr)) + }, testutil.Timeout, testutil.Interval).Should(MatchError(expectedErr)) }) }) @@ -615,7 +609,7 @@ var _ = Describe("MPIJob controller", func() { Eventually(func() error { _, err := reconciler.Reconcile(ctx, req) return err - }, timeout, interval).Should(MatchError(expectedErr)) + }, testutil.Timeout, testutil.Interval).Should(MatchError(expectedErr)) }) }) @@ -645,7 +639,7 @@ var _ = Describe("MPIJob controller", func() { Eventually(func() error { _, err := reconciler.Reconcile(ctx, req) return err - }, timeout, interval).Should(MatchError(expectedErr)) + }, testutil.Timeout, testutil.Interval).Should(MatchError(expectedErr)) }) }) @@ -675,7 +669,7 @@ var _ = Describe("MPIJob controller", func() { Eventually(func() error { _, err := reconciler.Reconcile(ctx, req) return err - }, timeout, interval).Should(MatchError(expectedErr)) + }, testutil.Timeout, testutil.Interval).Should(MatchError(expectedErr)) }) }) @@ -705,7 +699,7 @@ var _ = Describe("MPIJob controller", func() { Eventually(func() error { _, err := reconciler.Reconcile(ctx, req) return err - }, timeout, interval).Should(MatchError(expectedErr)) + }, testutil.Timeout, testutil.Interval).Should(MatchError(expectedErr)) }) }) diff --git a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_test.go b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_test.go index 3cf121b2ac..77b65406a7 100644 --- a/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_test.go +++ b/pkg/controller.v1/paddlepaddle/paddlepaddle_controller_test.go @@ -17,7 +17,6 @@ package paddle import ( "context" "fmt" - "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -28,13 +27,12 @@ import ( commonv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/util/testutil" ) var _ = Describe("PaddleJob controller", func() { // Define utility constants for object names and testing timeouts/durations and intervals. const ( - timeout = time.Second * 10 - interval = time.Millisecond * 250 expectedPort = int32(8080) ) @@ -99,20 +97,20 @@ var _ = Describe("PaddleJob controller", func() { Eventually(func() bool { err := testK8sClient.Get(ctx, key, created) return err == nil - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) masterKey := types.NamespacedName{Name: fmt.Sprintf("%s-master-0", name), Namespace: namespace} masterPod := &corev1.Pod{} Eventually(func() bool { err := testK8sClient.Get(ctx, masterKey, masterPod) return err == nil - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) masterSvc := &corev1.Service{} Eventually(func() bool { err := testK8sClient.Get(ctx, masterKey, masterSvc) return err == nil - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) // Check the pod port. Expect(masterPod.Spec.Containers[0].Ports).To(ContainElement(corev1.ContainerPort{ @@ -156,7 +154,7 @@ var _ = Describe("PaddleJob controller", func() { } return created.Status.ReplicaStatuses != nil && created.Status. ReplicaStatuses[kubeflowv1.PaddleJobReplicaTypeMaster].Succeeded == 1 - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) // Check if the job is succeeded. cond := getCondition(created.Status, commonv1.JobSucceeded) Expect(cond.Status).To(Equal(corev1.ConditionTrue)) diff --git a/pkg/controller.v1/pytorch/pytorchjob_controller_test.go b/pkg/controller.v1/pytorch/pytorchjob_controller_test.go index 99f2b2107c..39ab652c52 100644 --- a/pkg/controller.v1/pytorch/pytorchjob_controller_test.go +++ b/pkg/controller.v1/pytorch/pytorchjob_controller_test.go @@ -17,7 +17,6 @@ package pytorch import ( "context" "fmt" - "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -28,13 +27,12 @@ import ( commonv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" + "github.com/kubeflow/training-operator/pkg/util/testutil" ) var _ = Describe("PyTorchJob controller", func() { - // Define utility constants for object names and testing timeouts/durations and intervals. + // Define utility constants for object names. const ( - timeout = time.Second * 10 - interval = time.Millisecond * 250 expectedPort = int32(8080) ) @@ -99,20 +97,20 @@ var _ = Describe("PyTorchJob controller", func() { Eventually(func() bool { err := testK8sClient.Get(ctx, key, created) return err == nil - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) masterKey := types.NamespacedName{Name: fmt.Sprintf("%s-master-0", name), Namespace: namespace} masterPod := &corev1.Pod{} Eventually(func() bool { err := testK8sClient.Get(ctx, masterKey, masterPod) return err == nil - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) masterSvc := &corev1.Service{} Eventually(func() bool { err := testK8sClient.Get(ctx, masterKey, masterSvc) return err == nil - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) // Check the pod port. Expect(masterPod.Spec.Containers[0].Ports).To(ContainElement(corev1.ContainerPort{ @@ -159,7 +157,7 @@ var _ = Describe("PyTorchJob controller", func() { } return created.Status.ReplicaStatuses != nil && created.Status. ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeMaster].Succeeded == 1 - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) // Check if the job is succeeded. cond := getCondition(created.Status, commonv1.JobSucceeded) Expect(cond.Status).To(Equal(corev1.ConditionTrue)) @@ -222,20 +220,20 @@ var _ = Describe("PyTorchJob controller", func() { Eventually(func() bool { err := testK8sClient.Get(ctx, key, created) return err == nil - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) workerKey := types.NamespacedName{Name: fmt.Sprintf("%s-worker-0", name), Namespace: namespace} pod := &corev1.Pod{} Eventually(func() bool { err := testK8sClient.Get(ctx, workerKey, pod) return err == nil - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) svc := &corev1.Service{} Eventually(func() bool { err := testK8sClient.Get(ctx, workerKey, svc) return err == nil - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) // Check pod port. Expect(pod.Spec.Containers[0].Ports).To(ContainElement(corev1.ContainerPort{ @@ -287,7 +285,7 @@ var _ = Describe("PyTorchJob controller", func() { } return created.Status.ReplicaStatuses != nil && created.Status. ReplicaStatuses[kubeflowv1.PyTorchJobReplicaTypeWorker].Succeeded == 1 - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) // Check if the job is succeeded. cond := getCondition(created.Status, commonv1.JobSucceeded) Expect(cond.Status).To(Equal(corev1.ConditionTrue)) diff --git a/pkg/controller.v1/tensorflow/job_test.go b/pkg/controller.v1/tensorflow/job_test.go index 869b116fcd..f377c325dc 100644 --- a/pkg/controller.v1/tensorflow/job_test.go +++ b/pkg/controller.v1/tensorflow/job_test.go @@ -37,15 +37,10 @@ import ( kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" tftestutil "github.com/kubeflow/training-operator/pkg/controller.v1/tensorflow/testutil" + "github.com/kubeflow/training-operator/pkg/util/testutil" ) var _ = Describe("TFJob controller", func() { - // Define utility constants for object names and testing timeouts/durations and intervals. - const ( - timeout = 10 * time.Second - interval = 1000 * time.Millisecond - ) - Context("Test Add TFJob", func() { It("should get the exact TFJob", func() { By("submitting an TFJob") @@ -75,7 +70,7 @@ var _ = Describe("TFJob controller", func() { Eventually(func() error { job := &kubeflowv1.TFJob{} return reconciler.Get(ctx, key, job) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) Expect(testK8sClient.Delete(ctx, tfJob)).Should(Succeed()) Expect(testK8sClient.Delete(ctx, decoyJob)).Should(Succeed()) @@ -139,7 +134,7 @@ var _ = Describe("TFJob controller", func() { } return nil - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) }) }) @@ -562,7 +557,7 @@ var _ = Describe("TFJob controller", func() { var updatedTFJob kubeflowv1.TFJob Eventually(func() error { return reconciler.Get(ctx, client.ObjectKeyFromObject(tc.tfJob), &updatedTFJob) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) initializeReplicaStatuses(&updatedTFJob.Status, kubeflowv1.TFJobReplicaTypeWorker) @@ -586,7 +581,7 @@ var _ = Describe("TFJob controller", func() { var getTFJob kubeflowv1.TFJob Expect(reconciler.Get(ctx, client.ObjectKeyFromObject(tc.tfJob), &getTFJob)).Should(Succeed()) return getTFJob.Status.ReplicaStatuses[kubeflowv1.TFJobReplicaTypeWorker] - }, timeout, interval).ShouldNot(BeNil()) + }, testutil.Timeout, testutil.Interval).ShouldNot(BeNil()) ttl := updatedTFJob.Spec.RunPolicy.TTLSecondsAfterFinished if ttl != nil { @@ -607,7 +602,7 @@ var _ = Describe("TFJob controller", func() { return err } return fmt.Errorf("job %s still remains", name) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) } }) }) @@ -653,11 +648,11 @@ var _ = Describe("Test for controller.v1/common", func() { Eventually(func() bool { gotErr := testK8sClient.Get(ctx, client.ObjectKeyFromObject(tc.tfJob), &kubeflowv1.TFJob{}) return errors.IsNotFound(gotErr) - }).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) } else { Eventually(func() error { return testK8sClient.Get(ctx, client.ObjectKeyFromObject(tc.tfJob), &kubeflowv1.TFJob{}) - }).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) } }, Entry("TFJob shouldn't be removed since TTL is nil", &cleanUpCases{ @@ -771,7 +766,7 @@ var _ = Describe("Test for controller.v1/common", func() { svc := &corev1.Service{} Expect(testK8sClient.Get(ctx, client.ObjectKeyFromObject(wantSvc), svc)).Should(Succeed()) return svc - }).Should(BeComparableTo(wantSvc, + }, testutil.Timeout, testutil.Interval).Should(BeComparableTo(wantSvc, cmpopts.IgnoreFields(metav1.ObjectMeta{}, "UID", "ResourceVersion", "Generation", "CreationTimestamp", "ManagedFields"))) } }, diff --git a/pkg/controller.v1/tensorflow/pod_test.go b/pkg/controller.v1/tensorflow/pod_test.go index 3c3a830217..860ffc8226 100644 --- a/pkg/controller.v1/tensorflow/pod_test.go +++ b/pkg/controller.v1/tensorflow/pod_test.go @@ -18,10 +18,8 @@ import ( "context" "fmt" "os" - "time" commonv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - "github.com/kubeflow/training-operator/pkg/core" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" @@ -33,14 +31,11 @@ import ( kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" tftestutil "github.com/kubeflow/training-operator/pkg/controller.v1/tensorflow/testutil" + "github.com/kubeflow/training-operator/pkg/core" + "github.com/kubeflow/training-operator/pkg/util/testutil" ) var _ = Describe("TFJob controller", func() { - const ( - timeout = 10 * time.Second - interval = 1000 * time.Millisecond - ) - Context("Test ClusterSpec", func() { It("should generate desired cluster spec", func() { type tc struct { @@ -264,7 +259,7 @@ var _ = Describe("TFJob controller", func() { return fmt.Errorf("pod status is not Failed") } return nil - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) _ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy) @@ -276,7 +271,7 @@ var _ = Describe("TFJob controller", func() { return noPod.GetDeletionTimestamp() != nil } return errors.IsNotFound(err) - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) }) }) @@ -326,7 +321,7 @@ var _ = Describe("TFJob controller", func() { return fmt.Errorf("expecting %d Pods while got %d", 3, len(podList.Items)) } return nil - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) _ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy) @@ -341,7 +336,7 @@ var _ = Describe("TFJob controller", func() { return false } return errors.IsNotFound(err) - }, timeout, interval).Should(BeTrue()) + }, testutil.Timeout, testutil.Interval).Should(BeTrue()) }) }) @@ -388,7 +383,7 @@ var _ = Describe("TFJob controller", func() { return fmt.Errorf("before reconciling, expecting %d Pods while got %d", 1, len(podList.Items)) } return nil - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) _ = reconciler.ReconcileJobs(tfJob, tfJob.Spec.TFReplicaSpecs, tfJob.Status, &tfJob.Spec.RunPolicy) @@ -412,7 +407,7 @@ var _ = Describe("TFJob controller", func() { return fmt.Errorf("after reconciling, expecting %d Pods while got %d", 3, len(podList.Items)) } return nil - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) }) }) @@ -538,7 +533,7 @@ var _ = Describe("TFJob controller", func() { len(podList.Items), tt.tfJob.GetName(), totalExpectedPodCount) } return nil - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) got, err := reconciler.IsWorker0Completed(tt.tfJob, tt.replicas) diff --git a/pkg/controller.v1/tensorflow/status_test.go b/pkg/controller.v1/tensorflow/status_test.go index 92affc52ae..d2cf9aaa80 100644 --- a/pkg/controller.v1/tensorflow/status_test.go +++ b/pkg/controller.v1/tensorflow/status_test.go @@ -17,10 +17,8 @@ package tensorflow import ( "context" "fmt" - "time" commonv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - "github.com/kubeflow/training-operator/pkg/util" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" @@ -32,15 +30,11 @@ import ( kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" tftestutil "github.com/kubeflow/training-operator/pkg/controller.v1/tensorflow/testutil" + "github.com/kubeflow/training-operator/pkg/util" + "github.com/kubeflow/training-operator/pkg/util/testutil" ) var _ = Describe("TFJob controller", func() { - // Define utility constants for object names and testing timeouts/durations and intervals. - const ( - timeout = 10 * time.Second - interval = 1000 * time.Millisecond - ) - Context("Test Failed", func() { It("should update TFJob with failed status", func() { By("creating a TFJob with replicaStatues initialized") @@ -442,7 +436,7 @@ var _ = Describe("TFJob controller", func() { len(podList.Items), c.tfJob.GetName(), totalExpectedPodCount) } return nil - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) _ = reconciler.ReconcileJobs(c.tfJob, c.tfJob.Spec.TFReplicaSpecs, c.tfJob.Status, &c.tfJob.Spec.RunPolicy) @@ -469,12 +463,6 @@ func setStatusForTest(tfJob *kubeflowv1.TFJob, rtype commonv1.ReplicaType, faile } basicLabels := reconciler.GenLabels(tfJob.GetName()) - - const ( - timeout = 10 * time.Second - interval = 1000 * time.Millisecond - ) - ctx := context.Background() Expect(rtype).Should(BeElementOf([]kubeflowv1.ReplicaType{ @@ -519,7 +507,7 @@ func setStatusForTest(tfJob *kubeflowv1.TFJob, rtype commonv1.ReplicaType, faile } return client.Status().Update(ctx, po) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) updateJobReplicaStatuses(&tfJob.Status, rtype, po) @@ -558,7 +546,7 @@ func setStatusForTest(tfJob *kubeflowv1.TFJob, rtype commonv1.ReplicaType, faile } return client.Status().Update(ctx, po) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) updateJobReplicaStatuses(&tfJob.Status, rtype, po) index++ @@ -582,7 +570,7 @@ func setStatusForTest(tfJob *kubeflowv1.TFJob, rtype commonv1.ReplicaType, faile po.Status.Phase = corev1.PodRunning return client.Status().Update(ctx, po) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) updateJobReplicaStatuses(&tfJob.Status, rtype, po) index++ diff --git a/pkg/controller.v1/tensorflow/suite_test.go b/pkg/controller.v1/tensorflow/suite_test.go index f9cac98f6f..dde7b29c11 100644 --- a/pkg/controller.v1/tensorflow/suite_test.go +++ b/pkg/controller.v1/tensorflow/suite_test.go @@ -23,6 +23,7 @@ import ( kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/kubeflow/training-operator/pkg/controller.v1/common" + "github.com/kubeflow/training-operator/pkg/util/testutil" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -55,10 +56,6 @@ func TestAPIs(t *testing.T) { } var _ = BeforeSuite(func() { - const ( - timeout = 10 * time.Second - interval = 1000 * time.Millisecond - ) logf.SetLogger(zap.New(zap.WriteTo(GinkgoWriter), zap.UseDevMode(true))) testCtx, testCancel = context.WithCancel(context.TODO()) @@ -108,7 +105,7 @@ var _ = BeforeSuite(func() { return fmt.Errorf("cannot get at lease one namespace, got %d", len(nsList.Items)) } return nil - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) }) var _ = AfterSuite(func() { diff --git a/pkg/controller.v1/tensorflow/testutil/pod.go b/pkg/controller.v1/tensorflow/testutil/pod.go index f172b0c44b..99ca3b9817 100644 --- a/pkg/controller.v1/tensorflow/testutil/pod.go +++ b/pkg/controller.v1/tensorflow/testutil/pod.go @@ -18,7 +18,6 @@ import ( "context" "fmt" "strings" - "time" commonv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" . "github.com/onsi/gomega" @@ -26,6 +25,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" + + "github.com/kubeflow/training-operator/pkg/util/testutil" ) const ( @@ -74,8 +75,6 @@ func NewPodList(count int32, status corev1.PodPhase, job metav1.Object, typ comm func SetPodsStatuses(client client.Client, job metav1.Object, typ commonv1.ReplicaType, pendingPods, activePods, succeededPods, failedPods int32, restartCounts []int32, refs []metav1.OwnerReference, basicLabels map[string]string) { - timeout := 10 * time.Second - interval := 1000 * time.Millisecond var index int32 taskMap := map[corev1.PodPhase]int32{ corev1.PodFailed: failedPods, @@ -105,7 +104,7 @@ func SetPodsStatuses(client client.Client, job metav1.Object, typ commonv1.Repli po.Status.ContainerStatuses = []corev1.ContainerStatus{{RestartCount: restartCounts[i]}} } return client.Status().Update(ctx, po) - }, timeout, interval).Should(BeNil()) + }, testutil.Timeout, testutil.Interval).Should(BeNil()) } index += desiredCount } diff --git a/pkg/util/testutil/constants.go b/pkg/util/testutil/constants.go new file mode 100644 index 0000000000..f935731fcf --- /dev/null +++ b/pkg/util/testutil/constants.go @@ -0,0 +1,8 @@ +package testutil + +import "time" + +const ( + Timeout = 30 * time.Second + Interval = 250 * time.Millisecond +)