diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index db62aeb4e7..e8252090df 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -348,6 +348,15 @@ func ApplyFlytePodConfiguration(ctx context.Context, tCtx pluginsCore.TaskExecut IncludeConsoleURL: hasExternalLinkType(taskTemplate), } + // iterate over the initContainers first + for index := range podSpec.InitContainers { + var resourceMode = ResourceCustomizationModeEnsureExistingResourcesInRange + + if err := AddFlyteCustomizationsToContainer(ctx, templateParameters, resourceMode, &podSpec.InitContainers[index]); err != nil { + return nil, nil, err + } + } + resourceRequests := make([]v1.ResourceRequirements, 0, len(podSpec.Containers)) var primaryContainer *v1.Container for index, container := range podSpec.Containers { diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go b/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go index 5d89e2f0ec..f9d49b2448 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/container_test.go @@ -49,6 +49,13 @@ func dummyContainerTaskTemplate(command []string, args []string) *core.TaskTempl func dummyContainerTaskTemplateWithPodSpec(command []string, args []string) *core.TaskTemplate { podSpec := v1.PodSpec{ + InitContainers: []v1.Container{ + v1.Container{ + Name: "test-image", + Command: command, + Args: args, + }, + }, Containers: []v1.Container{ v1.Container{ Name: "test-image", @@ -174,24 +181,28 @@ func TestContainerTaskExecutor_BuildResource(t *testing.T) { taskTemplate *core.TaskTemplate taskMetadata pluginsCore.TaskExecutionMetadata expectServiceAccount string + checkInitContainer bool }{ { name: "BuildResource", taskTemplate: dummyContainerTaskTemplate(command, args), taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, true, ""), expectServiceAccount: serviceAccount, + checkInitContainer: false, }, { name: "BuildResource_PodTemplate", taskTemplate: dummyContainerTaskTemplateWithPodSpec(command, args), taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, true, ""), expectServiceAccount: podTemplateServiceAccount, + checkInitContainer: true, }, { name: "BuildResource_SecurityContext", taskTemplate: dummyContainerTaskTemplate(command, args), taskMetadata: dummyContainerTaskMetadata(containerResourceRequirements, nil, false, ""), expectServiceAccount: securityContextServiceAccount, + checkInitContainer: false, }, } for _, tc := range testCases { @@ -213,6 +224,11 @@ func TestContainerTaskExecutor_BuildResource(t *testing.T) { assert.Equal(t, command, j.Spec.Containers[0].Command) assert.Equal(t, []string{"test-data-reference"}, j.Spec.Containers[0].Args) + if tc.checkInitContainer { + assert.Equal(t, command, j.Spec.InitContainers[0].Command) + assert.Equal(t, []string{"test-data-reference"}, j.Spec.InitContainers[0].Args) + } + assert.Equal(t, tc.expectServiceAccount, j.Spec.ServiceAccountName) }) }