diff --git a/pkg/common/util/util.go b/pkg/common/util/util.go index 28150c5f90..f635f48f4b 100644 --- a/pkg/common/util/util.go +++ b/pkg/common/util/util.go @@ -50,3 +50,11 @@ func GetReplicaTypes(specs map[commonv1.ReplicaType]*commonv1.ReplicaSpec) []com } return keys } +func GetSchedulerName(replicas map[commonv1.ReplicaType]*commonv1.ReplicaSpec) string { + for _, spec := range replicas { + if len(spec.Template.Spec.SchedulerName) > 0 { + return spec.Template.Spec.SchedulerName + } + } + return "" +} diff --git a/pkg/controller.v1/tensorflow/tfjob_controller.go b/pkg/controller.v1/tensorflow/tfjob_controller.go index a83d3600a6..8b99748535 100644 --- a/pkg/controller.v1/tensorflow/tfjob_controller.go +++ b/pkg/controller.v1/tensorflow/tfjob_controller.go @@ -829,12 +829,13 @@ func (r *TFJobReconciler) createNewPod(tfjob *tfv1.TFJob, rt, index string, spec // 1. if user has specified other scheduler, we report a warning without overriding any fields. // 2. if no SchedulerName is set for pods, then we set the SchedulerName to "volcano". if r.Config.EnableGangScheduling { - if !util.IsGangSchedulerSet(replicas, gangSchedulerName) { + podSchedulerName := util.GetSchedulerName(replicas) + if len(podSchedulerName) == 0 { + podTemplate.Spec.SchedulerName = gangSchedulerName + } else if strings.Compare(podSchedulerName, gangSchedulerName) != 0 { errMsg := "Another scheduler is specified when gang-scheduling is enabled and it will not be overwritten" logger.Warning(errMsg) r.Recorder.Event(tfjob, v1.EventTypeWarning, podTemplateSchedulerNameReason, errMsg) - } else { - podTemplate.Spec.SchedulerName = gangSchedulerName } if podTemplate.Annotations == nil {