Skip to content

Commit

Permalink
Bug fix in Sagemaker plugin (flyteorg#110)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ketan Umare authored Aug 11, 2020
1 parent 8dd7a24 commit 99fe14b
Show file tree
Hide file tree
Showing 10 changed files with 593 additions and 22 deletions.
11 changes: 2 additions & 9 deletions go/tasks/plugins/array/awsbatch/job_definition.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
pluginErrors "github.com/lyft/flyteplugins/go/tasks/errors"
"github.com/lyft/flyteplugins/go/tasks/plugins/array/awsbatch/config"
arrayCore "github.com/lyft/flyteplugins/go/tasks/plugins/array/core"
awsUtils "github.com/lyft/flyteplugins/go/tasks/plugins/awsutils"
"github.com/lyft/flytestdlib/errors"
"github.com/lyft/flytestdlib/logger"

Expand All @@ -23,14 +24,6 @@ func getContainerImage(_ context.Context, task *core.TaskTemplate) string {
return ""
}

// getRole returns the value stored in annotations under roleAnnotationKey.
// It returns the empty string when roleAnnotationKey is empty or when
// annotations holds no entry for that key. The context argument is unused.
func getRole(_ context.Context, roleAnnotationKey string, annotations map[string]string) string {
	// Without a configured key there is nothing to look up.
	if roleAnnotationKey == "" {
		return ""
	}

	// Indexing a nil or key-less map yields the zero value, i.e. "".
	return annotations[roleAnnotationKey]
}

// urlRegex matches a whole string made of slash-separated segments.
// Capture group 1 is an optional first segment, capture group 2 is filled by
// any further leading segments (matched lazily), and capture group 3 is the
// final segment up to, but not including, an optional suffix that starts with
// '@' or ':' and contains no '/'.
var urlRegex = regexp.MustCompile(`^(?:([^/]+)/)?(?:([^/]+)/)*?([^@:/]+)(?:[@:][^/]+)?$`)

// Gets the repository part of the container image url
Expand All @@ -57,7 +50,7 @@ func EnsureJobDefinition(ctx context.Context, tCtx pluginCore.TaskExecutionConte
return nil, errors.Errorf(pluginErrors.BadTaskSpecification, "Tasktemplate does not contain a container image.")
}

role := getRole(ctx, cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata().GetAnnotations())
role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, tCtx.TaskExecutionMetadata().GetAnnotations())

cacheKey := definition.NewCacheKey(role, containerImage)
if existingArn, found := definitionCache.Get(cacheKey); found {
Expand Down
11 changes: 11 additions & 0 deletions go/tasks/plugins/awsutils/awsutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package awsutils

import "context"

// GetRole looks up the role recorded in annotations under roleAnnotationKey.
// The result is the empty string when roleAnnotationKey is empty or when
// annotations has no entry for it. The context argument is unused.
func GetRole(_ context.Context, roleAnnotationKey string, annotations map[string]string) string {
	var role string

	// Only consult the annotations when a lookup key has been configured.
	if roleAnnotationKey != "" {
		role = annotations[roleAnnotationKey]
	}

	return role
}
9 changes: 5 additions & 4 deletions go/tasks/plugins/k8s/sagemaker/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const sagemakerConfigSectionKey = "sagemaker"

var (
defaultConfig = Config{
RoleArn: "default",
RoleArn: "default_role",
Region: "us-east-1",
// https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html
PrebuiltAlgorithms: []PrebuiltAlgorithmConfig{
Expand All @@ -19,8 +19,8 @@ var (
Region: "us-east-1",
VersionConfigs: []VersionConfig{
{
Version: "0.91",
Image: "811284229777.dkr.ecr.us-east-1.amazonaws.com/xgboost:latest",
Version: "0.90",
Image: "683313688378.dkr.ecr.us-east-1.amazonaws.com/sagemaker-xgboost:0.90-2-cpu-py3",
},
},
},
Expand All @@ -32,10 +32,11 @@ var (
sagemakerConfigSection = pluginsConfig.MustRegisterSubSection(sagemakerConfigSectionKey, &defaultConfig)
)

// Config holds the Sagemaker plugin configs. It is the type of defaultConfig,
// which is registered under the "sagemaker" config sub-section.
type Config struct {
// RoleArn is the role the SageMaker plugin uses to communicate with the
// SageMaker service. The job builders fall back to it when the role looked
// up through RoleAnnotationKey is empty.
RoleArn string `json:"roleArn" pflag:",The role the SageMaker plugin uses to communicate with the SageMaker service"`
// Region is the AWS region the SageMaker plugin communicates to.
Region string `json:"region" pflag:",The AWS region the SageMaker plugin communicates to"`
// RoleAnnotationKey is the map key used to look up the role from task annotations.
RoleAnnotationKey string `json:"roleAnnotationKey" pflag:",Map key to use to lookup role from task annotations."`
// PrebuiltAlgorithms is a list of PrebuiltAlgorithm configs.
PrebuiltAlgorithms []PrebuiltAlgorithmConfig `json:"prebuiltAlgorithms" pflag:"-,A List of PrebuiltAlgorithm configs"`
}
type PrebuiltAlgorithmConfig struct {
Expand Down
46 changes: 44 additions & 2 deletions go/tasks/plugins/k8s/sagemaker/sagemaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"strings"
"time"

awsUtils "github.com/lyft/flyteplugins/go/tasks/plugins/awsutils"

hpojobController "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/hyperparametertuningjob"
trainingjobController "github.com/aws/amazon-sagemaker-operator-for-k8s/controllers/trainingjob"
"github.com/lyft/flytestdlib/logger"
Expand Down Expand Up @@ -65,6 +67,9 @@ func (m awsSagemakerPlugin) BuildResourceForTrainingJob(
if err != nil {
return nil, errors.Wrapf(err, "invalid TrainingJob task specification: not able to unmarshal the custom field to [%s]", m.TaskType)
}
if sagemakerTrainingJob.GetTrainingJobResourceConfig() == nil {
return nil, errors.Errorf("Required field [TrainingJobResourceConfig] of the TrainingJob does not exist")
}

taskInput, err := taskCtx.InputReader().Get(ctx)
if err != nil {
Expand All @@ -78,10 +83,16 @@ func (m awsSagemakerPlugin) BuildResourceForTrainingJob(
if !ok {
return nil, errors.Errorf("Required input not specified: [train]")
}
if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil {
return nil, errors.Errorf("[train] Input is required and should be of Type [Scalar.Blob]")
}
validatePathLiteral, ok := inputLiterals["validation"]
if !ok {
return nil, errors.Errorf("Required input not specified: [validation]")
}
if validatePathLiteral.GetScalar() == nil || validatePathLiteral.GetScalar().GetBlob() == nil {
return nil, errors.Errorf("[validation] Input is required and should be of Type [Scalar.Blob]")
}
staticHyperparamsLiteral, ok := inputLiterals["static_hyperparameters"]
if !ok {
return nil, errors.Errorf("Required input not specified: [static_hyperparameters]")
Expand All @@ -106,6 +117,9 @@ func (m awsSagemakerPlugin) BuildResourceForTrainingJob(

cfg := config.GetSagemakerConfig()

if sagemakerTrainingJob.GetAlgorithmSpecification() == nil {
return nil, errors.Errorf("Required field [AlgorithmSpecification] does not exist")
}
var metricDefinitions []commonv1.MetricDefinition
idlMetricDefinitions := sagemakerTrainingJob.GetAlgorithmSpecification().GetMetricDefinitions()
for _, md := range idlMetricDefinitions {
Expand All @@ -120,6 +134,10 @@ func (m awsSagemakerPlugin) BuildResourceForTrainingJob(

inputModeString := strings.Title(strings.ToLower(sagemakerTrainingJob.GetAlgorithmSpecification().GetInputMode().String()))

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations())
if role == "" {
role = cfg.RoleArn
}
trainingJob := &trainingjobv1.TrainingJob{
Spec: trainingjobv1.TrainingJobSpec{
AlgorithmSpecification: &commonv1.AlgorithmSpecification{
Expand Down Expand Up @@ -168,7 +186,7 @@ func (m awsSagemakerPlugin) BuildResourceForTrainingJob(
VolumeSizeInGB: ToInt64Ptr(sagemakerTrainingJob.GetTrainingJobResourceConfig().GetVolumeSizeInGb()),
VolumeKmsKeyId: ToStringPtr(""), // TODO: Not yet supported. Need to add to proto and flytekit in the future
},
RoleArn: ToStringPtr(cfg.RoleArn),
RoleArn: ToStringPtr(role),
Region: ToStringPtr(cfg.Region),
StoppingCondition: &commonv1.StoppingCondition{
MaxRuntimeInSeconds: ToInt64Ptr(86400), // TODO: decide how to coordinate this and Flyte's timeout
Expand Down Expand Up @@ -199,6 +217,15 @@ func (m awsSagemakerPlugin) BuildResourceForHyperparameterTuningJob(
if err != nil {
return nil, errors.Wrapf(err, "invalid HyperparameterTuningJob task specification: not able to unmarshal the custom field to [%s]", hyperparameterTuningJobTaskType)
}
if sagemakerHPOJob.GetTrainingJob() == nil {
return nil, errors.Errorf("Required field [TrainingJob] of the HyperparameterTuningJob does not exist")
}
if sagemakerHPOJob.GetTrainingJob().GetAlgorithmSpecification() == nil {
return nil, errors.Errorf("Required field [AlgorithmSpecification] of the HyperparameterTuningJob's underlying TrainingJob does not exist")
}
if sagemakerHPOJob.GetTrainingJob().GetTrainingJobResourceConfig() == nil {
return nil, errors.Errorf("Required field [TrainingJobResourceConfig] of the HyperparameterTuningJob's underlying TrainingJob does not exist")
}

taskInput, err := taskCtx.InputReader().Get(ctx)
if err != nil {
Expand All @@ -212,10 +239,16 @@ func (m awsSagemakerPlugin) BuildResourceForHyperparameterTuningJob(
if !ok {
return nil, errors.Errorf("Required input not specified: [train]")
}
if trainPathLiteral.GetScalar() == nil || trainPathLiteral.GetScalar().GetBlob() == nil {
return nil, errors.Errorf("[train] Input is required and should be of Type [Scalar.Blob]")
}
validatePathLiteral, ok := inputLiterals["validation"]
if !ok {
return nil, errors.Errorf("Required input not specified: [validation]")
}
if validatePathLiteral.GetScalar() == nil || validatePathLiteral.GetScalar().GetBlob() == nil {
return nil, errors.Errorf("[validation] Input is required and should be of Type [Scalar.Blob]")
}
staticHyperparamsLiteral, ok := inputLiterals["static_hyperparameters"]
if !ok {
return nil, errors.Errorf("Required input not specified: [static_hyperparameters]")
Expand All @@ -240,6 +273,10 @@ func (m awsSagemakerPlugin) BuildResourceForHyperparameterTuningJob(
return nil, errors.Wrapf(err, "failed to convert hyperparameter tuning job config literal to spec type")
}

if hpoJobConfig.GetTuningObjective() == nil {
return nil, errors.Errorf("Required field [TuningObjective] does not exist")
}

// Deleting the conflicting static hyperparameters: if a hyperparameter exist in both the map of static hyperparameter
// and the map of the tunable hyperparameter inside the Hyperparameter Tuning Job Config, we delete the entry
// in the static map and let the one in the map of the tunable hyperparameters take precedence
Expand Down Expand Up @@ -279,6 +316,11 @@ func (m awsSagemakerPlugin) BuildResourceForHyperparameterTuningJob(
tuningObjectiveTypeString := strings.Title(strings.ToLower(hpoJobConfig.GetTuningObjective().GetObjectiveType().String()))
trainingJobEarlyStoppingTypeString := strings.Title(strings.ToLower(hpoJobConfig.TrainingJobEarlyStoppingType.String()))

role := awsUtils.GetRole(ctx, cfg.RoleAnnotationKey, taskCtx.TaskExecutionMetadata().GetAnnotations())
if role == "" {
role = cfg.RoleArn
}

hpoJob := &hpojobv1.HyperparameterTuningJob{
Spec: hpojobv1.HyperparameterTuningJobSpec{
HyperParameterTuningJobName: &taskName,
Expand Down Expand Up @@ -336,7 +378,7 @@ func (m awsSagemakerPlugin) BuildResourceForHyperparameterTuningJob(
VolumeSizeInGB: ToInt64Ptr(sagemakerHPOJob.GetTrainingJob().GetTrainingJobResourceConfig().GetVolumeSizeInGb()),
VolumeKmsKeyId: ToStringPtr(""), // TODO: Not yet supported. Need to add to proto and flytekit in the future
},
RoleArn: ToStringPtr(cfg.RoleArn),
RoleArn: ToStringPtr(role),
StoppingCondition: &commonv1.StoppingCondition{
MaxRuntimeInSeconds: ToInt64Ptr(86400),
MaxWaitTimeInSeconds: nil,
Expand Down
Loading

0 comments on commit 99fe14b

Please sign in to comment.