diff --git a/go.mod b/go.mod index 4a826a9917..157098a6a5 100644 --- a/go.mod +++ b/go.mod @@ -13,7 +13,7 @@ require ( github.com/aws/aws-sdk-go-v2/config v1.0.0 github.com/aws/aws-sdk-go-v2/service/athena v1.0.0 github.com/coocood/freecache v1.1.1 - github.com/flyteorg/flyteidl v0.18.15 + github.com/flyteorg/flyteidl v0.18.17 github.com/flyteorg/flytestdlib v0.3.13 github.com/go-logr/zapr v0.4.0 // indirect github.com/go-test/deep v1.0.7 diff --git a/go.sum b/go.sum index 9fd1c04b87..7fe008fb26 100644 --- a/go.sum +++ b/go.sum @@ -228,8 +228,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg= github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM= github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94= -github.com/flyteorg/flyteidl v0.18.15 h1:sXrlwTRaRjQsXYMNrY/S930SKdKtu4XnpNFEu8I4tn4= -github.com/flyteorg/flyteidl v0.18.15/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI= +github.com/flyteorg/flyteidl v0.18.17 h1:74pPZ9PzITuzq+CgjMPb9EcFI5bVkf8mM5m4xmmlTmY= +github.com/flyteorg/flyteidl v0.18.17/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI= github.com/flyteorg/flytestdlib v0.3.13 h1:5ioA/q3ixlyqkFh5kDaHgmPyTP/AHtqq1K/TIbVLUzM= github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220= github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= diff --git a/go/tasks/pluginmachinery/utils/marshal_utils.go b/go/tasks/pluginmachinery/utils/marshal_utils.go index ceb563b5d0..51e36fe395 100755 --- a/go/tasks/pluginmachinery/utils/marshal_utils.go +++ b/go/tasks/pluginmachinery/utils/marshal_utils.go @@ -68,3 +68,21 @@ func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { } return structObj, nil } + +// Don't use this if the unmarshalled obj is a proto message. +func UnmarshalStructToObj(structObj *structpb.Struct, obj interface{}) error { + if structObj == nil { + return fmt.Errorf("nil Struct Object passed") + } + + jsonObj, err := json.Marshal(structObj) + if err != nil { + return err + } + + if err = json.Unmarshal(jsonObj, obj); err != nil { + return err + } + + return nil +} diff --git a/go/tasks/pluginmachinery/utils/marshal_utils_test.go b/go/tasks/pluginmachinery/utils/marshal_utils_test.go new file mode 100644 index 0000000000..abe1b7d2a2 --- /dev/null +++ b/go/tasks/pluginmachinery/utils/marshal_utils_test.go @@ -0,0 +1,54 @@ +package utils + +import ( + "encoding/json" + "testing" + + "github.com/go-test/deep" + structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/stretchr/testify/assert" + v1 "k8s.io/api/core/v1" +) + +func TestUnmarshalStructToObj(t *testing.T) { + t.Run("no nil structs allowed", func(t *testing.T) { + var podSpec v1.PodSpec + err := UnmarshalStructToObj(nil, &podSpec) + assert.EqualError(t, err, "nil Struct Object passed") + }) + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "a container", + }, + { + Name: "another container", + }, + }, + } + + b, err := json.Marshal(podSpec) + if err != nil { + t.Fatal(err) + } + + structObj := &structpb.Struct{} + if err := json.Unmarshal(b, structObj); err != nil { + t.Fatal(err) + } + + t.Run("no nil pointers as obj allowed", func(t *testing.T) { + var nilPodspec *v1.PodSpec + err := UnmarshalStructToObj(structObj, nilPodspec) + assert.EqualError(t, err, "json: Unmarshal(nil *v1.PodSpec)") + }) + + t.Run("happy case", func(t *testing.T) { + var podSpecObj v1.PodSpec + err := UnmarshalStructToObj(structObj, &podSpecObj) + assert.NoError(t, err) + if diff := deep.Equal(podSpecObj, podSpec); diff != nil { + t.Errorf("UnmarshalStructToObj() got = %v, want %v, diff: %v", podSpecObj, podSpec, diff) + } + }) +} diff --git a/go/tasks/plugins/k8s/sidecar/sidecar.go b/go/tasks/plugins/k8s/sidecar/sidecar.go index 643c0360de..fecec11697 100755 --- a/go/tasks/plugins/k8s/sidecar/sidecar.go +++ b/go/tasks/plugins/k8s/sidecar/sidecar.go @@ -2,10 +2,10 @@ package sidecar import ( "context" - "encoding/json" "fmt" - structpb "github.com/golang/protobuf/ptypes/struct" + "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils" + "sigs.k8s.io/controller-runtime/pkg/client" "github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template" @@ -22,7 +22,7 @@ import ( const ( sidecarTaskType = "sidecar" - primaryContainerKey = "primary" + primaryContainerKey = "primary_container_name" ) type sidecarResourceHandler struct{} @@ -73,45 +73,55 @@ type sidecarJob struct { PrimaryContainerName string } -func unmarshalSidecarCustom(structObj *structpb.Struct, sidecarJob *sidecarJob) error { - if structObj == nil { - return fmt.Errorf("nil Struct Object passed") - } - - jsonObj, err := json.Marshal(structObj) - if err != nil { - return err - } - - if err = json.Unmarshal(jsonObj, sidecarJob); err != nil { - return err - } - - return nil -} - func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) { - sidecarJob := sidecarJob{} + var podSpec k8sv1.PodSpec + var primaryContainerName string + task, err := taskCtx.TaskReader().Read(ctx) if err != nil { return nil, errors.Errorf(errors.BadTaskSpecification, "TaskSpecification cannot be read, Err: [%v]", err.Error()) } - err = unmarshalSidecarCustom(task.GetCustom(), &sidecarJob) - if err != nil { - return nil, errors.Errorf(errors.BadTaskSpecification, - "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) + if task.TaskTypeVersion == 0 { + sidecarJob := sidecarJob{} + err := utils.UnmarshalStructToObj(task.GetCustom(), &sidecarJob) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error()) + } + if sidecarJob.PodSpec == nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification, nil PodSpec [%v]", task.GetCustom()) + } + podSpec = *sidecarJob.PodSpec + primaryContainerName = sidecarJob.PrimaryContainerName + } else { + err := utils.UnmarshalStructToObj(task.GetCustom(), &podSpec) + if err != nil { + return nil, errors.Errorf(errors.BadTaskSpecification, + "Unable to unmarshal task custom [%v], Err: [%v]", task.GetCustom(), err.Error()) + } + if len(task.GetConfig()) == 0 { + return nil, errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification, config needs to be non-empty and include missing [%s] key", primaryContainerKey) + } + var ok bool + primaryContainerName, ok = task.GetConfig()[primaryContainerKey] + if !ok { + return nil, errors.Errorf(errors.BadTaskSpecification, + "invalid TaskSpecification, config missing [%s] key in [%v]", primaryContainerKey, task.GetConfig()) + } } - pod := flytek8s.BuildPodWithSpec(sidecarJob.PodSpec) + pod := flytek8s.BuildPodWithSpec(&podSpec) // Set the restart policy to *not* inherit from the default so that a completed pod doesn't get caught in a // CrashLoopBackoff after the initial job completion. pod.Spec.RestartPolicy = k8sv1.RestartPolicyNever - // We want to Also update the serviceAccount to the serviceaccount of the workflow + // We want to also update the serviceAccount to the serviceaccount of the workflow pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount() - pod, err = validateAndFinalizePod(ctx, taskCtx, sidecarJob.PrimaryContainerName, *pod) + pod, err = validateAndFinalizePod(ctx, taskCtx, primaryContainerName, *pod) if err != nil { return nil, err } @@ -120,7 +130,7 @@ func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx plugins pod.Annotations = make(map[string]string, 1) } - pod.Annotations[primaryContainerKey] = sidecarJob.PrimaryContainerName + pod.Annotations[primaryContainerKey] = primaryContainerName return pod, nil } diff --git a/go/tasks/plugins/k8s/sidecar/sidecar_test.go b/go/tasks/plugins/k8s/sidecar/sidecar_test.go index f8cd2b58e2..00ca6c9e83 100755 --- a/go/tasks/plugins/k8s/sidecar/sidecar_test.go +++ b/go/tasks/plugins/k8s/sidecar/sidecar_test.go @@ -112,6 +112,160 @@ func getDummySidecarTaskContext(taskTemplate *core.TaskTemplate, resources *v1.R return taskCtx } +func TestBuildSidecarResource_TaskType1(t *testing.T) { + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "primary container", + Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"}, + Resources: v1.ResourceRequirements{ + Limits: v1.ResourceList{ + "cpu": resource.MustParse("2"), + "memory": resource.MustParse("200Mi"), + }, + Requests: v1.ResourceList{ + "cpu": resource.MustParse("1"), + "memory": resource.MustParse("100Mi"), + }, + }, + VolumeMounts: []v1.VolumeMount{ + { + Name: "volume mount", + }, + }, + }, + { + Name: "secondary container", + }, + }, + Volumes: []v1.Volume{ + { + Name: "dshm", + }, + }, + Tolerations: []v1.Toleration{ + { + Key: "my toleration key", + Value: "my toleration value", + }, + }, + } + + b, err := json.Marshal(podSpec) + if err != nil { + t.Fatal(err) + } + + structObj := &structpb.Struct{} + if err := json.Unmarshal(b, structObj); err != nil { + t.Fatal(err) + } + + task := core.TaskTemplate{ + Custom: structObj, + TaskTypeVersion: 1, + Config: map[string]string{ + primaryContainerKey: "primary container", + }, + } + + tolGPU := v1.Toleration{ + Key: "flyte/gpu", + Value: "dedicated", + Operator: v1.TolerationOpEqual, + Effect: v1.TaintEffectNoSchedule, + } + + tolStorage := v1.Toleration{ + Key: "storage", + Value: "dedicated", + Operator: v1.TolerationOpExists, + Effect: v1.TaintEffectNoSchedule, + } + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + ResourceTolerations: map[v1.ResourceName][]v1.Toleration{ + v1.ResourceStorage: {tolStorage}, + ResourceNvidiaGPU: {tolGPU}, + }, + DefaultCPURequest: "1024m", + DefaultMemoryRequest: "1024Mi", + })) + handler := &sidecarResourceHandler{} + taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) + res, err := handler.BuildResource(context.TODO(), taskCtx) + assert.Nil(t, err) + assert.EqualValues(t, map[string]string{ + primaryContainerKey: "primary container", + }, res.GetAnnotations()) + + // Assert volumes & volume mounts are preserved + assert.Len(t, res.(*v1.Pod).Spec.Volumes, 1) + assert.Equal(t, "dshm", res.(*v1.Pod).Spec.Volumes[0].Name) + + assert.Len(t, res.(*v1.Pod).Spec.Containers[0].VolumeMounts, 1) + assert.Equal(t, "volume mount", res.(*v1.Pod).Spec.Containers[0].VolumeMounts[0].Name) + + // Assert user-specified tolerations don't get overridden + assert.Len(t, res.(*v1.Pod).Spec.Tolerations, 1) + for _, tol := range res.(*v1.Pod).Spec.Tolerations { + if tol.Key == "my toleration key" { + assert.Equal(t, tol.Value, "my toleration value") + } else { + t.Fatalf("unexpected toleration [%+v]", tol) + } + } + +} + +func TestBuildSideResource_TaskType1_InvalidSpec(t *testing.T) { + podSpec := v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "primary container", + }, + { + Name: "secondary container", + }, + }, + } + + b, err := json.Marshal(podSpec) + if err != nil { + t.Fatal(err) + } + + structObj := &structpb.Struct{} + if err := json.Unmarshal(b, structObj); err != nil { + t.Fatal(err) + } + + task := core.TaskTemplate{ + Custom: structObj, + TaskTypeVersion: 1, + } + + assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{ + ResourceTolerations: map[v1.ResourceName][]v1.Toleration{ + v1.ResourceStorage: {}, + ResourceNvidiaGPU: {}, + }, + DefaultCPURequest: "1024m", + DefaultMemoryRequest: "1024Mi", + })) + handler := &sidecarResourceHandler{} + taskCtx := getDummySidecarTaskContext(&task, resourceRequirements) + _, err = handler.BuildResource(context.TODO(), taskCtx) + assert.EqualError(t, err, "[BadTaskSpecification] invalid TaskSpecification, config needs to be non-empty and include missing [primary_container_name] key") + + task.Config = map[string]string{ + "foo": "bar", + } + taskCtx = getDummySidecarTaskContext(&task, resourceRequirements) + _, err = handler.BuildResource(context.TODO(), taskCtx) + assert.EqualError(t, err, "[BadTaskSpecification] invalid TaskSpecification, config missing [primary_container_name] key in [map[foo:bar]]") + +} + func TestBuildSidecarResource(t *testing.T) { dir, err := os.Getwd() if err != nil {