Skip to content

Commit

Permalink
Pod task revamped (flyteorg#162)
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrina Rogan authored Mar 13, 2021
1 parent d110778 commit 4f22a59
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 32 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.0.0
github.com/aws/aws-sdk-go-v2/service/athena v1.0.0
github.com/coocood/freecache v1.1.1
github.com/flyteorg/flyteidl v0.18.15
github.com/flyteorg/flyteidl v0.18.17
github.com/flyteorg/flytestdlib v0.3.13
github.com/go-logr/zapr v0.4.0 // indirect
github.com/go-test/deep v1.0.7
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5Kwzbycv
github.com/fatih/color v1.10.0 h1:s36xzo75JdqLaaWoiEHk767eHiwo0598uUxyfiPkDsg=
github.com/fatih/color v1.10.0/go.mod h1:ELkj/draVOlAH/xkhN6mQ50Qd0MPOk5AAr3maGEBuJM=
github.com/fatih/structtag v1.2.0/go.mod h1:mBJUNpUnHmRKrKlQQlmCrh5PuhftFbNv8Ys4/aAZl94=
github.com/flyteorg/flyteidl v0.18.15 h1:sXrlwTRaRjQsXYMNrY/S930SKdKtu4XnpNFEu8I4tn4=
github.com/flyteorg/flyteidl v0.18.15/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI=
github.com/flyteorg/flyteidl v0.18.17 h1:74pPZ9PzITuzq+CgjMPb9EcFI5bVkf8mM5m4xmmlTmY=
github.com/flyteorg/flyteidl v0.18.17/go.mod h1:b5Fq4Z8a5b0mF6pEwTd48ufvikUGVkWSjZiMT0ZtqKI=
github.com/flyteorg/flytestdlib v0.3.13 h1:5ioA/q3ixlyqkFh5kDaHgmPyTP/AHtqq1K/TIbVLUzM=
github.com/flyteorg/flytestdlib v0.3.13/go.mod h1:Tz8JCECAbX6VWGwFT6cmEQ+RJpZ/6L9pswu3fzWs220=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk=
Expand Down
18 changes: 18 additions & 0 deletions go/tasks/pluginmachinery/utils/marshal_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,21 @@ func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) {
}
return structObj, nil
}

// Don't use this if the unmarshalled obj is a proto message.
func UnmarshalStructToObj(structObj *structpb.Struct, obj interface{}) error {
if structObj == nil {
return fmt.Errorf("nil Struct Object passed")
}

jsonObj, err := json.Marshal(structObj)
if err != nil {
return err
}

if err = json.Unmarshal(jsonObj, obj); err != nil {
return err
}

return nil
}
54 changes: 54 additions & 0 deletions go/tasks/pluginmachinery/utils/marshal_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package utils

import (
"encoding/json"
"testing"

"github.com/go-test/deep"
structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/stretchr/testify/assert"
v1 "k8s.io/api/core/v1"
)

func TestUnmarshalStructToObj(t *testing.T) {
t.Run("no nil structs allowed", func(t *testing.T) {
var podSpec v1.PodSpec
err := UnmarshalStructToObj(nil, &podSpec)
assert.EqualError(t, err, "nil Struct Object passed")
})
podSpec := v1.PodSpec{
Containers: []v1.Container{
{
Name: "a container",
},
{
Name: "another container",
},
},
}

b, err := json.Marshal(podSpec)
if err != nil {
t.Fatal(err)
}

structObj := &structpb.Struct{}
if err := json.Unmarshal(b, structObj); err != nil {
t.Fatal(err)
}

t.Run("no nil pointers as obj allowed", func(t *testing.T) {
var nilPodspec *v1.PodSpec
err := UnmarshalStructToObj(structObj, nilPodspec)
assert.EqualError(t, err, "json: Unmarshal(nil *v1.PodSpec)")
})

t.Run("happy case", func(t *testing.T) {
var podSpecObj v1.PodSpec
err := UnmarshalStructToObj(structObj, &podSpecObj)
assert.NoError(t, err)
if diff := deep.Equal(podSpecObj, podSpec); diff != nil {
t.Errorf("UnmarshalStructToObj() got = %v, want %v, diff: %v", podSpecObj, podSpec, diff)
}
})
}
68 changes: 39 additions & 29 deletions go/tasks/plugins/k8s/sidecar/sidecar.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package sidecar

import (
"context"
"encoding/json"
"fmt"

structpb "github.com/golang/protobuf/ptypes/struct"
"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/utils"

"sigs.k8s.io/controller-runtime/pkg/client"

"github.com/flyteorg/flyteplugins/go/tasks/pluginmachinery/core/template"
Expand All @@ -22,7 +22,7 @@ import (

const (
sidecarTaskType = "sidecar"
primaryContainerKey = "primary"
primaryContainerKey = "primary_container_name"
)

type sidecarResourceHandler struct{}
Expand Down Expand Up @@ -73,45 +73,55 @@ type sidecarJob struct {
PrimaryContainerName string
}

func unmarshalSidecarCustom(structObj *structpb.Struct, sidecarJob *sidecarJob) error {
if structObj == nil {
return fmt.Errorf("nil Struct Object passed")
}

jsonObj, err := json.Marshal(structObj)
if err != nil {
return err
}

if err = json.Unmarshal(jsonObj, sidecarJob); err != nil {
return err
}

return nil
}

func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCore.TaskExecutionContext) (client.Object, error) {
sidecarJob := sidecarJob{}
var podSpec k8sv1.PodSpec
var primaryContainerName string

task, err := taskCtx.TaskReader().Read(ctx)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification,
"TaskSpecification cannot be read, Err: [%v]", err.Error())
}
err = unmarshalSidecarCustom(task.GetCustom(), &sidecarJob)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification,
"invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error())
if task.TaskTypeVersion == 0 {
sidecarJob := sidecarJob{}
err := utils.UnmarshalStructToObj(task.GetCustom(), &sidecarJob)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification,
"invalid TaskSpecification [%v], Err: [%v]", task.GetCustom(), err.Error())
}
if sidecarJob.PodSpec == nil {
return nil, errors.Errorf(errors.BadTaskSpecification,
"invalid TaskSpecification, nil PodSpec [%v]", task.GetCustom())
}
podSpec = *sidecarJob.PodSpec
primaryContainerName = sidecarJob.PrimaryContainerName
} else {
err := utils.UnmarshalStructToObj(task.GetCustom(), &podSpec)
if err != nil {
return nil, errors.Errorf(errors.BadTaskSpecification,
"Unable to unmarshal task custom [%v], Err: [%v]", task.GetCustom(), err.Error())
}
if len(task.GetConfig()) == 0 {
return nil, errors.Errorf(errors.BadTaskSpecification,
"invalid TaskSpecification, config needs to be non-empty and include missing [%s] key", primaryContainerKey)
}
var ok bool
primaryContainerName, ok = task.GetConfig()[primaryContainerKey]
if !ok {
return nil, errors.Errorf(errors.BadTaskSpecification,
"invalid TaskSpecification, config missing [%s] key in [%v]", primaryContainerKey, task.GetConfig())
}
}

pod := flytek8s.BuildPodWithSpec(sidecarJob.PodSpec)
pod := flytek8s.BuildPodWithSpec(&podSpec)
// Set the restart policy to *not* inherit from the default so that a completed pod doesn't get caught in a
// CrashLoopBackoff after the initial job completion.
pod.Spec.RestartPolicy = k8sv1.RestartPolicyNever

// We want to Also update the serviceAccount to the serviceaccount of the workflow
// We want to also update the serviceAccount to the serviceaccount of the workflow
pod.Spec.ServiceAccountName = taskCtx.TaskExecutionMetadata().GetK8sServiceAccount()

pod, err = validateAndFinalizePod(ctx, taskCtx, sidecarJob.PrimaryContainerName, *pod)
pod, err = validateAndFinalizePod(ctx, taskCtx, primaryContainerName, *pod)
if err != nil {
return nil, err
}
Expand All @@ -120,7 +130,7 @@ func (sidecarResourceHandler) BuildResource(ctx context.Context, taskCtx plugins
pod.Annotations = make(map[string]string, 1)
}

pod.Annotations[primaryContainerKey] = sidecarJob.PrimaryContainerName
pod.Annotations[primaryContainerKey] = primaryContainerName

return pod, nil
}
Expand Down
154 changes: 154 additions & 0 deletions go/tasks/plugins/k8s/sidecar/sidecar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,160 @@ func getDummySidecarTaskContext(taskTemplate *core.TaskTemplate, resources *v1.R
return taskCtx
}

func TestBuildSidecarResource_TaskType1(t *testing.T) {
podSpec := v1.PodSpec{
Containers: []v1.Container{
{
Name: "primary container",
Args: []string{"pyflyte-execute", "--task-module", "tests.flytekit.unit.sdk.tasks.test_sidecar_tasks", "--task-name", "simple_sidecar_task", "--inputs", "{{.input}}", "--output-prefix", "{{.outputPrefix}}"},
Resources: v1.ResourceRequirements{
Limits: v1.ResourceList{
"cpu": resource.MustParse("2"),
"memory": resource.MustParse("200Mi"),
},
Requests: v1.ResourceList{
"cpu": resource.MustParse("1"),
"memory": resource.MustParse("100Mi"),
},
},
VolumeMounts: []v1.VolumeMount{
{
Name: "volume mount",
},
},
},
{
Name: "secondary container",
},
},
Volumes: []v1.Volume{
{
Name: "dshm",
},
},
Tolerations: []v1.Toleration{
{
Key: "my toleration key",
Value: "my toleration value",
},
},
}

b, err := json.Marshal(podSpec)
if err != nil {
t.Fatal(err)
}

structObj := &structpb.Struct{}
if err := json.Unmarshal(b, structObj); err != nil {
t.Fatal(err)
}

task := core.TaskTemplate{
Custom: structObj,
TaskTypeVersion: 1,
Config: map[string]string{
primaryContainerKey: "primary container",
},
}

tolGPU := v1.Toleration{
Key: "flyte/gpu",
Value: "dedicated",
Operator: v1.TolerationOpEqual,
Effect: v1.TaintEffectNoSchedule,
}

tolStorage := v1.Toleration{
Key: "storage",
Value: "dedicated",
Operator: v1.TolerationOpExists,
Effect: v1.TaintEffectNoSchedule,
}
assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
ResourceTolerations: map[v1.ResourceName][]v1.Toleration{
v1.ResourceStorage: {tolStorage},
ResourceNvidiaGPU: {tolGPU},
},
DefaultCPURequest: "1024m",
DefaultMemoryRequest: "1024Mi",
}))
handler := &sidecarResourceHandler{}
taskCtx := getDummySidecarTaskContext(&task, resourceRequirements)
res, err := handler.BuildResource(context.TODO(), taskCtx)
assert.Nil(t, err)
assert.EqualValues(t, map[string]string{
primaryContainerKey: "primary container",
}, res.GetAnnotations())

// Assert volumes & volume mounts are preserved
assert.Len(t, res.(*v1.Pod).Spec.Volumes, 1)
assert.Equal(t, "dshm", res.(*v1.Pod).Spec.Volumes[0].Name)

assert.Len(t, res.(*v1.Pod).Spec.Containers[0].VolumeMounts, 1)
assert.Equal(t, "volume mount", res.(*v1.Pod).Spec.Containers[0].VolumeMounts[0].Name)

// Assert user-specified tolerations don't get overridden
assert.Len(t, res.(*v1.Pod).Spec.Tolerations, 1)
for _, tol := range res.(*v1.Pod).Spec.Tolerations {
if tol.Key == "my toleration key" {
assert.Equal(t, tol.Value, "my toleration value")
} else {
t.Fatalf("unexpected toleration [%+v]", tol)
}
}

}

func TestBuildSideResource_TaskType1_InvalidSpec(t *testing.T) {
podSpec := v1.PodSpec{
Containers: []v1.Container{
{
Name: "primary container",
},
{
Name: "secondary container",
},
},
}

b, err := json.Marshal(podSpec)
if err != nil {
t.Fatal(err)
}

structObj := &structpb.Struct{}
if err := json.Unmarshal(b, structObj); err != nil {
t.Fatal(err)
}

task := core.TaskTemplate{
Custom: structObj,
TaskTypeVersion: 1,
}

assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
ResourceTolerations: map[v1.ResourceName][]v1.Toleration{
v1.ResourceStorage: {},
ResourceNvidiaGPU: {},
},
DefaultCPURequest: "1024m",
DefaultMemoryRequest: "1024Mi",
}))
handler := &sidecarResourceHandler{}
taskCtx := getDummySidecarTaskContext(&task, resourceRequirements)
_, err = handler.BuildResource(context.TODO(), taskCtx)
assert.EqualError(t, err, "[BadTaskSpecification] invalid TaskSpecification, config needs to be non-empty and include missing [primary_container_name] key")

task.Config = map[string]string{
"foo": "bar",
}
taskCtx = getDummySidecarTaskContext(&task, resourceRequirements)
_, err = handler.BuildResource(context.TODO(), taskCtx)
assert.EqualError(t, err, "[BadTaskSpecification] invalid TaskSpecification, config missing [primary_container_name] key in [map[foo:bar]]")

}

func TestBuildSidecarResource(t *testing.T) {
dir, err := os.Getwd()
if err != nil {
Expand Down

0 comments on commit 4f22a59

Please sign in to comment.