diff --git a/pkg/utils/accelerators/tpu.go b/pkg/utils/accelerators/tpu.go index aa7b3056..de4e8f54 100644 --- a/pkg/utils/accelerators/tpu.go +++ b/pkg/utils/accelerators/tpu.go @@ -157,7 +157,7 @@ func addTPUVariablesSubGroup(pod *corev1.Pod) error { }, corev1.EnvVar{ Name: TpuName, - Value: fmt.Sprint(pod.Name), + Value: fmt.Sprint(leaderName), }, ) return nil @@ -218,7 +218,7 @@ func AddTPUVariables(pod *corev1.Pod, size int) error { }, corev1.EnvVar{ Name: TpuName, - Value: fmt.Sprint(pod.Name), + Value: fmt.Sprint(leaderName), }, ) return nil diff --git a/pkg/utils/accelerators/tpu_test.go b/pkg/utils/accelerators/tpu_test.go index 5f286ffd..c18d5010 100644 --- a/pkg/utils/accelerators/tpu_test.go +++ b/pkg/utils/accelerators/tpu_test.go @@ -74,7 +74,7 @@ func TestAddTPUVariables(t *testing.T) { hasWorkerIndexLabelKey: true, expectedTpuWorkerHostNames: "test-sample-1.default,test-sample-1-1.default,test-sample-1-2.default,test-sample-1-3.default,test-sample-1-4.default", expectedTpuWorkerId: "3", - expectedTpuName: "test-sample-1-3", + expectedTpuName: "test-sample-1", }, } @@ -133,7 +133,7 @@ func TestAddTPUVariablesSubGroup(t *testing.T) { }, expectedTpuWorkerId: "3", expectedTpuWorkerHostNames: "test-sample-1.default,test-sample-1-1.default,test-sample-1-2.default,test-sample-1-3.default,test-sample-1-4.default", - expectedTpuName: "test-sample-1-3", + expectedTpuName: "test-sample-1", }, { name: "Leader requests TPU resources, worker with subgroup index > 0", @@ -154,7 +154,7 @@ func TestAddTPUVariablesSubGroup(t *testing.T) { }, expectedTpuWorkerId: "3", expectedTpuWorkerHostNames: "test-sample-1-4.default,test-sample-1-5.default,test-sample-1-6.default,test-sample-1-7.default", - expectedTpuName: "test-sample-1-7", + expectedTpuName: "test-sample-1", }, { name: "Leader does not request TPU resources, worker with subgroup index > 0", @@ -174,7 +174,7 @@ func TestAddTPUVariablesSubGroup(t *testing.T) { }, expectedTpuWorkerId: "0", expectedTpuWorkerHostNames: "test-sample-1-5.default,test-sample-1-6.default,test-sample-1-7.default,test-sample-1-8.default", - expectedTpuName: "test-sample-1-5", + expectedTpuName: "test-sample-1", }, } for _, tc := range tests {