From e5b6ccbe185282144572cb7b6bd0695f2f79515e Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Tue, 10 Mar 2020 16:22:03 -0700 Subject: [PATCH 01/26] Presto plugin executor --- go.mod | 2 +- go.sum | 4 +- go/tasks/plugins/command/command_client.go | 13 + .../plugins/presto/client/presto_client.go | 67 +++ .../plugins/presto/client/presto_status.go | 41 ++ go/tasks/plugins/presto/config/config.go | 67 +++ go/tasks/plugins/presto/execution_state.go | 530 ++++++++++++++++++ go/tasks/plugins/presto/executions_cache.go | 165 ++++++ go/tasks/plugins/presto/executor.go | 166 ++++++ go/tasks/plugins/presto/executor_metrics.go | 33 ++ 10 files changed, 1085 insertions(+), 3 deletions(-) create mode 100644 go/tasks/plugins/command/command_client.go create mode 100644 go/tasks/plugins/presto/client/presto_client.go create mode 100644 go/tasks/plugins/presto/client/presto_status.go create mode 100644 go/tasks/plugins/presto/config/config.go create mode 100644 go/tasks/plugins/presto/execution_state.go create mode 100644 go/tasks/plugins/presto/executions_cache.go create mode 100644 go/tasks/plugins/presto/executor.go create mode 100644 go/tasks/plugins/presto/executor_metrics.go diff --git a/go.mod b/go.mod index e2952d643..7a2a8ae70 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,7 @@ require ( github.com/golang/protobuf v1.3.3 github.com/googleapis/gnostic v0.4.1 // indirect github.com/hashicorp/golang-lru v0.5.4 - github.com/lyft/flyteidl v0.17.6 + github.com/lyft/flyteidl v0.17.9 github.com/lyft/flytestdlib v0.3.2 github.com/magiconair/properties v1.8.1 github.com/mitchellh/mapstructure v1.1.2 diff --git a/go.sum b/go.sum index 0e2ec3d71..ca8dc526a 100644 --- a/go.sum +++ b/go.sum @@ -296,8 +296,8 @@ github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0 h1:NGL46+1RYcCXb3sShp0nQq github.com/lyft/api v0.0.0-20191031200350-b49a72c274e0/go.mod h1:/L5qH+AD540e7Cetbui1tuJeXdmNhO8jM6VkXeDdDhQ= github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f h1:PGuAMDzAen0AulUfaEhNQMYmUpa41pAVo3zHI+GJsCM= github.com/lyft/apimachinery v0.0.0-20191031200210-047e3ea32d7f/go.mod h1:llRdnznGEAqC3DcNm6yEj472xaFVfLM7hnYofMb12tQ= -github.com/lyft/flyteidl v0.17.6 h1:O0qpT6ya45e/92+E84uGOYa0ZsaFoE5ZfPoyJ6e1bEQ= -github.com/lyft/flyteidl v0.17.6/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= +github.com/lyft/flyteidl v0.17.9 h1:JXT9PovHqS9V3YN74x9zWT0kvIEL48c2uNoujF1KMes= +github.com/lyft/flyteidl v0.17.9/go.mod h1:/zQXxuHO11u/saxTTZc8oYExIGEShXB+xCB1/F1Cu20= github.com/lyft/flytestdlib v0.3.0 h1:nIkX4MlyYdcLLzaF35RI2P5BhARt+qMgHoFto8eVNzU= github.com/lyft/flytestdlib v0.3.0/go.mod h1:LJPPJlkFj+wwVWMrQT3K5JZgNhZi2mULsCG4ZYhinhU= github.com/lyft/flytestdlib v0.3.2 h1:bY6Y+Fg6Jdc7zY4GAYuR7t2hjWwynIdmRvtLcRNaGnw= diff --git a/go/tasks/plugins/command/command_client.go b/go/tasks/plugins/command/command_client.go new file mode 100644 index 000000000..9e22be992 --- /dev/null +++ b/go/tasks/plugins/command/command_client.go @@ -0,0 +1,13 @@ +package command + +import ( + "context" +) + +type CommandStatus string + +type CommandClient interface { + ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) + KillCommand(ctx context.Context, commandID string) error + GetCommandStatus(ctx context.Context, commandID string) (CommandStatus, error) +} diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go new file mode 100644 index 000000000..bd1ac30b9 --- /dev/null +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -0,0 +1,67 @@ +package client + +import ( + //"bytes" + "context" + "github.com/lyft/flyteplugins/go/tasks/plugins/command" + "net/http" + "net/url" + + "time" + + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" +) + +const ( + httpRequestTimeoutSecs = 30 + + AcceptHeaderKey = "Accept" + ContentTypeHeaderKey = "Content-Type" + ContentTypeJSON = "application/json" + ContentTypeTextPlain = "text/plain" + PrestoCatalogHeader = "X-Presto-Catalog" + PrestoRoutingGroupHeader = "X-Presto-Routing-Group" + PrestoSchemaHeader = "X-Presto-Schema" + PrestoSourceHeader = "X-Presto-Source" + PrestoUserHeader = "X-Presto-User" +) + +type prestoClient struct { + client *http.Client + environment *url.URL +} + +type PrestoExecuteArgs struct { + RoutingGroup string `json:"routing_group, omitempty"` + Catalog string `json:"catalog, omitempty"` + Schema string `json:"schema, omitempty"` + Source string `json:"source, omitempty"` +} +type PrestoExecuteResponse struct { + Id string + Status command.CommandStatus + NextUri string +} + +func (p *prestoClient) ExecuteCommand( + ctx context.Context, + queryStr string, + extraArgs interface{}) (interface{}, error) { + + return PrestoExecuteResponse{}, nil +} + +func (p *prestoClient) KillCommand(ctx context.Context, commandID string) error { + return nil +} + +func (p *prestoClient) GetCommandStatus(ctx context.Context, commandId string) (command.CommandStatus, error) { + return PrestoStatusUnknown, nil +} + +func NewPrestoClient(cfg *config.Config) command.CommandClient { + return &prestoClient{ + client: &http.Client{Timeout: httpRequestTimeoutSecs * time.Second}, + environment: cfg.Environment.ResolveReference(&cfg.Environment.URL), + } +} diff --git a/go/tasks/plugins/presto/client/presto_status.go b/go/tasks/plugins/presto/client/presto_status.go new file mode 100644 index 000000000..eeff09361 --- /dev/null +++ b/go/tasks/plugins/presto/client/presto_status.go @@ -0,0 +1,41 @@ +package client + +import ( + "context" + "github.com/lyft/flyteplugins/go/tasks/plugins/command" + "github.com/lyft/flytestdlib/logger" + "strings" +) + +// This type is meant only to encapsulate the response coming from Presto as a type, it is +// not meant to be stored locally. + +const ( + PrestoStatusUnknown command.CommandStatus = "UNKNOWN" + PrestoStatusQueued command.CommandStatus = "QUEUED" + PrestoStatusRunning command.CommandStatus = "RUNNING" + PrestoStatusFinished command.CommandStatus = "FINISHED" + PrestoStatusError command.CommandStatus = "FAILED" + PrestoStatusCancelled command.CommandStatus = "CANCELLED" +) + +var PrestoStatuses = map[command.CommandStatus]struct{}{ + PrestoStatusUnknown: {}, + PrestoStatusQueued: {}, + PrestoStatusRunning: {}, + PrestoStatusFinished: {}, + PrestoStatusError: {}, + PrestoStatusCancelled: {}, +} + +func NewPrestoStatus(ctx context.Context, state string) command.CommandStatus { + upperCased := strings.ToUpper(state) + if strings.Contains(upperCased, "FAILED") { + return PrestoStatusError + } else if _, ok := PrestoStatuses[command.CommandStatus(upperCased)]; ok { + return command.CommandStatus(upperCased) + } else { + logger.Warnf(ctx, "Invalid Presto Status found: %v", state) + return PrestoStatusUnknown + } +} diff --git a/go/tasks/plugins/presto/config/config.go b/go/tasks/plugins/presto/config/config.go new file mode 100644 index 000000000..1b3e141d4 --- /dev/null +++ b/go/tasks/plugins/presto/config/config.go @@ -0,0 +1,67 @@ +package config + +//go:generate pflags Config --default-var=defaultConfig + +import ( + "context" + "net/url" + + "github.com/lyft/flytestdlib/config" + "github.com/lyft/flytestdlib/logger" + + pluginsConfig "github.com/lyft/flyteplugins/go/tasks/config" +) + +const prestoConfigSectionKey = "presto" + +func UrlMustParse(s string) config.URL { + r, err := url.Parse(s) + if err != nil { + logger.Panicf(context.TODO(), "Bad Presto URL Specified as default, error: %s", err) + } + if r == nil { + logger.Panicf(context.TODO(), "Nil Presto URL specified.", err) + } + return config.URL{URL: *r} +} + +type RoutingGroupConfig struct { + Name string `json:"primaryLabel" pflag:",The name of a given Presto routing group"` + Limit int `json:"limit" pflag:",Resource quota (in the number of outstanding requests) of the routing group"` + ProjectScopeQuotaProportionCap float64 `json:"projectScopeQuotaProportionCap" pflag:",A floating point number between 0 and 1, specifying the maximum proportion of quotas allowed to allocate to a project in the routing group"` + NamespaceScopeQuotaProportionCap float64 `json:"namespaceScopeQuotaProportionCap" pflag:",A floating point number between 0 and 1, specifying the maximum proportion of quotas allowed to allocate to a namespace in the routing group"` +} + +var ( + defaultConfig = Config{ + Environment: UrlMustParse("https://prestoproxy-internal.lyft.net:443"), + DefaultRoutingGroup: "adhoc", + Workers: 15, + LruCacheSize: 2000, + AwsS3ShardFormatter: "s3://lyft-modelbuilder/{}/", + AwsS3ShardCount: 2, + RoutingGroupConfigs: []RoutingGroupConfig{{Name: "adhoc", Limit: 250}}, + } + + prestoConfigSection = pluginsConfig.MustRegisterSubSection(prestoConfigSectionKey, &defaultConfig) +) + +// Presto plugin configs +type Config struct { + Environment config.URL `json:"endpoint" pflag:",Endpoint for Presto to use"` + DefaultRoutingGroup string `json:"defaultRoutingGroup" pflag:",Default Presto routing group"` + Workers int `json:"workers" pflag:",Number of parallel workers to refresh the cache"` + LruCacheSize int `json:"lruCacheSize" pflag:",Size of the AutoRefreshCache"` + AwsS3ShardFormatter string `json:"awsS3ShardFormatter" pflag:", S3 bucket prefix where Presto results will be stored"` + AwsS3ShardCount int `json:"awsS3ShardStringLength" pflag:", Number of characters for the S3 bucket shard prefix"` + RoutingGroupConfigs []RoutingGroupConfig `json:"clusterConfigs" pflag:"-,A list of cluster configs. Each of the configs corresponds to a service cluster"` +} + +// Retrieves the current config value or default. +func GetPrestoConfig() *Config { + return prestoConfigSection.GetConfig().(*Config) +} + +func SetPrestoConfig(cfg *Config) error { + return prestoConfigSection.SetConfig(cfg) +} diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go new file mode 100644 index 000000000..230b1e798 --- /dev/null +++ b/go/tasks/plugins/presto/execution_state.go @@ -0,0 +1,530 @@ +package presto + +import ( + "context" + "crypto/rand" + "fmt" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + "strings" + + "time" + + "github.com/lyft/flytestdlib/cache" + + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + + "github.com/lyft/flyteplugins/go/tasks/errors" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flytestdlib/logger" + + "github.com/lyft/flyteplugins/go/tasks/plugins/command" +) + +type ExecutionPhase int + +const ( + PhaseNotStarted ExecutionPhase = iota + PhaseQueued // resource manager token gotten + PhaseSubmitted // Sent off to Presto + PhaseQuerySucceeded + PhaseQueryFailed +) + +func (p ExecutionPhase) String() string { + switch p { + case PhaseNotStarted: + return "PhaseNotStarted" + case PhaseQueued: + return "PhaseQueued" + case PhaseSubmitted: + return "PhaseSubmitted" + case PhaseQuerySucceeded: + return "PhaseQuerySucceeded" + case PhaseQueryFailed: + return "PhaseQueryFailed" + } + return "Bad Presto execution phase" +} + +type ExecutionState struct { + Phase ExecutionPhase + + // This will store the command ID from Presto + CommandId string `json:"command_id,omitempty"` + URI string `json:"uri,omitempty"` + + CurrentPrestoQuery PrestoQuery `json:"current_presto_query, omitempty"` + QueryCount int `json:"query_count,omitempty"` + + // This number keeps track of the number of failures within the sync function. Without this, what happens in + // the sync function is entirely opaque. Note that this field is completely orthogonal to Flyte system/node/task + // level retries, just errors from hitting the Presto API, inside the sync loop + SyncFailureCount int `json:"sync_failure_count,omitempty"` + + // In kicking off the Presto command, this is the number of failures + CreationFailureCount int `json:"creation_failure_count,omitempty"` + + // The time the execution first requests for an allocation token + AllocationTokenRequestStartTime time.Time `json:"allocation_token_request_start_time,omitempty"` +} + +type PrestoQuery struct { + Statement string `json:"statement, omitempty"` + ExtraArgs client.PrestoExecuteArgs `json:"extra_args, omitempty"` + TempTableName string `json:"temp_table_name, omitempty"` + ExternalTableName string `json:"external_table_name, omitempty"` + ExternalLocation string `json:"external_location"` +} + +// This is the main state iteration +func HandleExecutionState( + ctx context.Context, + tCtx core.TaskExecutionContext, + currentState ExecutionState, + prestoClient command.CommandClient, + executionsCache cache.AutoRefresh, + metrics PrestoExecutorMetrics) (ExecutionState, error) { + + var transformError error + var newState ExecutionState + + switch currentState.Phase { + case PhaseNotStarted: + newState, transformError = GetAllocationToken(ctx, tCtx, currentState, metrics) + + case PhaseQueued: + prestoQuery, err := GetNextQuery(ctx, tCtx, currentState) + if err != nil { + return ExecutionState{}, err + } + currentState.CurrentPrestoQuery = prestoQuery + newState, transformError = KickOffQuery(ctx, tCtx, currentState, prestoClient, executionsCache) + + case PhaseSubmitted: + newState, transformError = MonitorQuery(ctx, tCtx, currentState, executionsCache) + + case PhaseQuerySucceeded: + if currentState.QueryCount < 4 { + // If there are still Presto statements to execute, increment the query count, reset the phase to get a new + // allocation token, and continue executing the remaining statements + currentState.QueryCount += 1 + currentState.Phase = PhaseQueued + } + newState = currentState + transformError = nil + + case PhaseQueryFailed: + newState = currentState + transformError = nil + } + + return newState, transformError +} + +func GetAllocationToken( + ctx context.Context, + tCtx core.TaskExecutionContext, + currentState ExecutionState, + metric PrestoExecutorMetrics) (ExecutionState, error) { + + newState := ExecutionState{} + uniqueId := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + + routingGroup, err := composeResourceNamespaceWithRoutingGroup(ctx, tCtx) + if err != nil { + return newState, errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when requesting allocation token %s", uniqueId) + } + + resourceConstraintsSpec := createResourceConstraintsSpec(ctx, tCtx, routingGroup) + + allocationStatus, err := tCtx.ResourceManager().AllocateResource(ctx, routingGroup, uniqueId, resourceConstraintsSpec) + if err != nil { + logger.Errorf(ctx, "Resource manager failed for TaskExecId [%s] token [%s]. error %s", + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueId, err) + return newState, errors.Wrapf(errors.ResourceManagerFailure, err, "Error requesting allocation token %s", uniqueId) + } + logger.Infof(ctx, "Allocation result for [%s] is [%s]", uniqueId, allocationStatus) + + // Emitting the duration this execution has been waiting for a token allocation + if currentState.AllocationTokenRequestStartTime.IsZero() { + newState.AllocationTokenRequestStartTime = time.Now() + } else { + newState.AllocationTokenRequestStartTime = currentState.AllocationTokenRequestStartTime + } + waitTime := time.Since(newState.AllocationTokenRequestStartTime) + metric.ResourceWaitTime.Observe(waitTime.Seconds()) + + if allocationStatus == core.AllocationStatusGranted { + newState.Phase = PhaseQueued + } else if allocationStatus == core.AllocationStatusExhausted { + newState.Phase = PhaseNotStarted + } else if allocationStatus == core.AllocationStatusNamespaceQuotaExceeded { + newState.Phase = PhaseNotStarted + } else { + return newState, errors.Errorf(errors.ResourceManagerFailure, "Got bad allocation result [%s] for token [%s]", + allocationStatus, uniqueId) + } + + return newState, nil +} + +func composeResourceNamespaceWithRoutingGroup(ctx context.Context, tCtx core.TaskExecutionContext) (core.ResourceNamespace, error) { + routingGroup, _, _, _, err := GetQueryInfo(ctx, tCtx) + if err != nil { + return "", err + } + clusterPrimaryLabel := resolveRoutingGroup(ctx, routingGroup) + return core.ResourceNamespace(clusterPrimaryLabel), nil +} + +// This function is the link between the output written by the SDK, and the execution side. It extracts the query +// out of the task template. +func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) ( + routingGroup string, + catalog string, + schema string, + statement string, + err error) { + + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return "", "", "", "", err + } + + prestoQuery := plugins.PrestoQuery{} + if err := utils.UnmarshalStruct(taskTemplate.GetCustom(), &prestoQuery); err != nil { + return "", "", "", "", err + } + + if err := validatePrestoStatement(prestoQuery); err != nil { + return "", "", "", "", err + } + + routingGroup = prestoQuery.RoutingGroup + catalog = prestoQuery.Catalog + schema = prestoQuery.Schema + statement = prestoQuery.Statement + + logger.Debugf(ctx, "QueryInfo: query: [%v], routingGroup: [%v], catalog: [%v], schema: [%v]", statement, routingGroup, catalog, schema) + return +} + +func validatePrestoStatement(prestoJob plugins.PrestoQuery) error { + if prestoJob.Statement == "" { + return errors.Errorf(errors.BadTaskSpecification, + "Query could not be found. Please ensure that you are at least on Flytekit version 0.3.0 or later.") + } + return nil +} + +func resolveRoutingGroup(ctx context.Context, routingGroup string) string { + prestoCfg := config.GetPrestoConfig() + + if routingGroup == "" { + logger.Debugf(ctx, "Input routing group is an empty string; falling back to using the default routing group [%v]", prestoCfg.DefaultRoutingGroup) + return prestoCfg.DefaultRoutingGroup + } + + for _, routingGroupCfg := range prestoCfg.RoutingGroupConfigs { + if routingGroup == routingGroupCfg.Name { + logger.Debugf(ctx, "Found the Presto routing group: [%v]", routingGroupCfg.Name) + return routingGroup + } + } + + logger.Debugf(ctx, "Cannot find the routing group [%v] in configmap; "+ + "falling back to using the default routing group [%v]", routingGroup, prestoCfg.DefaultRoutingGroup) + return prestoCfg.DefaultRoutingGroup +} + +func createResourceConstraintsSpec(ctx context.Context, _ core.TaskExecutionContext, routingGroup core.ResourceNamespace) core.ResourceConstraintsSpec { + cfg := config.GetPrestoConfig() + constraintsSpec := core.ResourceConstraintsSpec{ + ProjectScopeResourceConstraint: nil, + NamespaceScopeResourceConstraint: nil, + } + if cfg.RoutingGroupConfigs == nil { + logger.Infof(ctx, "No routing group config is found. Returning an empty resource constraints spec") + return constraintsSpec + } + for _, routingGroupCfg := range cfg.RoutingGroupConfigs { + if routingGroupCfg.Name == string(routingGroup) { + constraintsSpec.ProjectScopeResourceConstraint = &core.ResourceConstraint{Value: int64(float64(routingGroupCfg.Limit) * routingGroupCfg.ProjectScopeQuotaProportionCap)} + constraintsSpec.NamespaceScopeResourceConstraint = &core.ResourceConstraint{Value: int64(float64(routingGroupCfg.Limit) * routingGroupCfg.NamespaceScopeQuotaProportionCap)} + break + } + } + logger.Infof(ctx, "Created a resource constraints spec: [%v]", constraintsSpec) + return constraintsSpec +} + +func GetNextQuery( + ctx context.Context, + tCtx core.TaskExecutionContext, + currentState ExecutionState) (PrestoQuery, error) { + + switch currentState.QueryCount { + case 0: + tempTableName := generateRandomString(32) + routingGroup, catalog, schema, statement, err := GetQueryInfo(ctx, tCtx) + if err != nil { + return PrestoQuery{}, err + } + + statement = fmt.Sprintf(`CREATE TABLE hive.flyte_temporary_tables.%s_temp AS %s`, tempTableName, statement) + + prestoQuery := PrestoQuery{ + Statement: statement, + ExtraArgs: client.PrestoExecuteArgs{ + RoutingGroup: resolveRoutingGroup(ctx, routingGroup), + Catalog: catalog, + Schema: schema, + }, + TempTableName: tempTableName + "_temp", + ExternalTableName: tempTableName + "_external", + } + + return prestoQuery, nil + + case 1: + cfg := config.GetPrestoConfig() + externalLocation := getExternalLocation(cfg.AwsS3ShardFormatter, cfg.AwsS3ShardCount) + + statement := fmt.Sprintf(` +CREATE TABLE hive.flyte_temporary_tables.%s (LIKE hive.flyte_temporary_tables.%s)" +WITH (format = 'PARQUET', external_location = '%s')`, + currentState.CurrentPrestoQuery.ExternalTableName, + currentState.CurrentPrestoQuery.TempTableName, + externalLocation, + ) + currentState.CurrentPrestoQuery.Statement = statement + return currentState.CurrentPrestoQuery, nil + + case 2: + statement := fmt.Sprintf(` +INSERT INTO hive.flyte_temporary_tables.%s +SELECT * +FROM hive.flyte_temporary_tables.%s`, + currentState.CurrentPrestoQuery.ExternalTableName, + currentState.CurrentPrestoQuery.TempTableName, + ) + currentState.CurrentPrestoQuery.Statement = statement + return currentState.CurrentPrestoQuery, nil + + case 3: + statement := fmt.Sprintf(`DROP TABLE %s`, currentState.CurrentPrestoQuery.TempTableName) + currentState.CurrentPrestoQuery.Statement = statement + return currentState.CurrentPrestoQuery, nil + + case 4: + statement := fmt.Sprintf(`DROP TABLE %s`, currentState.CurrentPrestoQuery.ExternalTableName) + currentState.CurrentPrestoQuery.Statement = statement + return currentState.CurrentPrestoQuery, nil + + default: + return currentState.CurrentPrestoQuery, nil + } +} + +func generateRandomString(length int) string { + const letters = "0123456789abcdefghijklmnopqrstuvwxyz" + bytes, err := generateRandomBytes(length) + if err != nil { + return "" + } + for i, b := range bytes { + bytes[i] = letters[b%byte(len(letters))] + } + return string(bytes) +} + +func generateRandomBytes(length int) ([]byte, error) { + b := make([]byte, length) + _, err := rand.Read(b) + // Note that err == nil only if we read len(b) bytes. + if err != nil { + return nil, err + } + + return b, nil +} + +func getExternalLocation(shardFormatter string, shardLength int) string { + shardCount := strings.Count(shardFormatter, "{}") + for i := 0; i < shardCount; i++ { + shardFormatter = strings.Replace(shardFormatter, "{}", generateRandomString(shardLength), 1) + } + + return shardFormatter + generateRandomString(32) + "/" +} + +func KickOffQuery( + ctx context.Context, + tCtx core.TaskExecutionContext, + currentState ExecutionState, + prestoClient command.CommandClient, + cache cache.AutoRefresh) (ExecutionState, error) { + + uniqueId := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + + statement := currentState.CurrentPrestoQuery.Statement + extraArgs := currentState.CurrentPrestoQuery.ExtraArgs + + response, err := prestoClient.ExecuteCommand(ctx, statement, extraArgs) + if err != nil { + // If we failed, we'll keep the NotStarted state + currentState.CreationFailureCount = currentState.CreationFailureCount + 1 + logger.Warnf(ctx, "Error creating Presto query for %s, failure counts %d. Error: %s", uniqueId, currentState.CreationFailureCount, err) + } else { + executeResponse := response.(client.PrestoExecuteResponse) + + // If we succeed, then store the command id returned from Presto, and update our state. Also, add to the + // AutoRefreshCache so we start getting updates for its status. + commandId := executeResponse.Id + logger.Infof(ctx, "Created Presto Id [%s] for token %s", commandId, uniqueId) + currentState.CommandId = commandId + currentState.Phase = PhaseSubmitted + currentState.URI = executeResponse.NextUri + + executionStateCacheItem := ExecutionStateCacheItem{ + ExecutionState: currentState, + Id: uniqueId, + } + + // The first time we put it in the cache, we know it won't have succeeded so we don't need to look at it + _, err := cache.GetOrCreate(uniqueId, executionStateCacheItem) + if err != nil { + // This means that our cache has fundamentally broken... return a system error + logger.Errorf(ctx, "Cache failed to GetOrCreate for execution [%s] cache key [%s], owner [%s]. Error %s", + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueId, + tCtx.TaskExecutionMetadata().GetOwnerReference(), err) + return currentState, err + } + } + + return currentState, nil +} + +func MonitorQuery( + ctx context.Context, + tCtx core.TaskExecutionContext, + currentState ExecutionState, + cache cache.AutoRefresh) (ExecutionState, error) { + + uniqueId := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + executionStateCacheItem := ExecutionStateCacheItem{ + ExecutionState: currentState, + Id: uniqueId, + } + + cachedItem, err := cache.GetOrCreate(uniqueId, executionStateCacheItem) + if err != nil { + // This means that our cache has fundamentally broken... return a system error + logger.Errorf(ctx, "Cache is broken on execution [%s] cache key [%s], owner [%s]. Error %s", + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueId, + tCtx.TaskExecutionMetadata().GetOwnerReference(), err) + return currentState, errors.Wrapf(errors.CacheFailed, err, "Error when GetOrCreate while monitoring") + } + + cachedExecutionState, ok := cachedItem.(ExecutionStateCacheItem) + if !ok { + logger.Errorf(ctx, "Error casting cache object into ExecutionState") + return currentState, errors.Errorf(errors.CacheFailed, "Failed to cast [%v]", cachedItem) + } + + // If there were updates made to the state, we'll have picked them up automatically. Nothing more to do. + return cachedExecutionState.ExecutionState, nil +} + +func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { + var phaseInfo core.PhaseInfo + t := time.Now() + + switch state.Phase { + case PhaseNotStarted: + phaseInfo = core.PhaseInfoNotReady(t, core.DefaultPhaseVersion, "Haven't received allocation token") + case PhaseQueued: + // TODO: Turn into config + if state.CreationFailureCount > 5 { + phaseInfo = core.PhaseInfoRetryableFailure("PrestoFailure", "Too many creation attempts", nil) + } else { + phaseInfo = core.PhaseInfoQueued(t, uint32(state.CreationFailureCount), "Waiting for Presto launch") + } + case PhaseSubmitted: + phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion, ConstructTaskInfo(state)) + + case PhaseQuerySucceeded: + phaseInfo = core.PhaseInfoSuccess(ConstructTaskInfo(state)) + + case PhaseQueryFailed: + phaseInfo = core.PhaseInfoFailure(errors.DownstreamSystemError, "Query failed", ConstructTaskInfo(state)) + } + + return phaseInfo +} + +func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { + logs := make([]*idlCore.TaskLog, 0, 1) + t := time.Now() + if e.CommandId != "" { + logs = append(logs, ConstructTaskLog(e)) + return &core.TaskInfo{ + Logs: logs, + OccurredAt: &t, + } + } + + return nil +} + +func ConstructTaskLog(e ExecutionState) *idlCore.TaskLog { + return &idlCore.TaskLog{ + Name: fmt.Sprintf("Status: %s [%s]", e.Phase, e.CommandId), + MessageFormat: idlCore.TaskLog_UNKNOWN, + Uri: e.URI, + } +} + +func Abort(ctx context.Context, currentState ExecutionState, client command.CommandClient) error { + // Cancel Presto query if non-terminal state + if !InTerminalState(currentState) && currentState.CommandId != "" { + err := client.KillCommand(ctx, currentState.CommandId) + if err != nil { + logger.Errorf(ctx, "Error terminating Presto command in Finalize [%s]", err) + return err + } + } + return nil +} + +func Finalize(ctx context.Context, tCtx core.TaskExecutionContext, _ ExecutionState) error { + // Release allocation token + uniqueId := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + routingGroup, err := composeResourceNamespaceWithRoutingGroup(ctx, tCtx) + if err != nil { + return errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when releasing allocation token %s", uniqueId) + } + + err = tCtx.ResourceManager().ReleaseResource(ctx, routingGroup, uniqueId) + + if err != nil { + logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", uniqueId, err) + return err + } + return nil +} + +func InTerminalState(e ExecutionState) bool { + return e.Phase == PhaseQuerySucceeded || e.Phase == PhaseQueryFailed +} + +func IsNotYetSubmitted(e ExecutionState) bool { + if e.Phase == PhaseNotStarted || e.Phase == PhaseQueued { + return true + } + return false +} diff --git a/go/tasks/plugins/presto/executions_cache.go b/go/tasks/plugins/presto/executions_cache.go new file mode 100644 index 000000000..5f2a1e304 --- /dev/null +++ b/go/tasks/plugins/presto/executions_cache.go @@ -0,0 +1,165 @@ +package presto + +import ( + "context" + "github.com/lyft/flyteplugins/go/tasks/plugins/command" + "time" + + "k8s.io/client-go/util/workqueue" + + "github.com/lyft/flytestdlib/cache" + + "github.com/lyft/flyteplugins/go/tasks/errors" + stdErrors "github.com/lyft/flytestdlib/errors" + + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client3" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" +) + +const ResyncDuration = 3 * time.Second + +const ( + BadPrestoReturnCodeError stdErrors.ErrorCode = "PRESTO_RETURNED_UNKNOWN" +) + +type PrestoExecutionsCache struct { + cache.AutoRefresh + prestoClient command.CommandClient + scope promutils.Scope + cfg *config.Config +} + +func NewPrestoExecutionsCache( + ctx context.Context, + prestoClient command.CommandClient, + cfg *config.Config, + scope promutils.Scope) (PrestoExecutionsCache, error) { + + q := PrestoExecutionsCache{ + prestoClient: prestoClient, + scope: scope, + cfg: cfg, + } + autoRefreshCache, err := cache.NewAutoRefreshCache("presto", q.SyncPrestoQuery, workqueue.DefaultControllerRateLimiter(), ResyncDuration, cfg.Workers, cfg.LruCacheSize, scope) + if err != nil { + logger.Errorf(ctx, "Could not create AutoRefreshCache in PrestoExecutor. [%s]", err) + return q, errors.Wrapf(errors.CacheFailed, err, "Error creating AutoRefreshCache") + } + q.AutoRefresh = autoRefreshCache + return q, nil +} + +type ExecutionStateCacheItem struct { + ExecutionState + + // This ID is the cache key and so will need to be unique across all objects in the cache (it will probably be + // unique across all of Flyte) and needs to be deterministic. + // This will also be used as the allocation token for now. + Id string `json:"id"` +} + +func (e ExecutionStateCacheItem) ID() string { + return e.Id +} + +// This basically grab an updated status from the Presto API and stores it in the cache +// All other handling should be in the synchronous loop. +func (p *PrestoExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache.Batch) ( + updatedBatch []cache.ItemSyncResponse, err error) { + + resp := make([]cache.ItemSyncResponse, 0, len(batch)) + for _, query := range batch { + // Cast the item back to the thing we want to work with. + executionStateCacheItem, ok := query.GetItem().(ExecutionStateCacheItem) + if !ok { + logger.Errorf(ctx, "Sync loop - Error casting cache object into ExecutionState") + return nil, errors.Errorf(errors.CacheFailed, "Failed to cast [%v]", batch[0].GetID()) + } + + if executionStateCacheItem.CommandId == "" { + logger.Warnf(ctx, "Sync loop - CommandID is blank for [%s] skipping", executionStateCacheItem.Id) + resp = append(resp, cache.ItemSyncResponse{ + ID: query.GetID(), + Item: query.GetItem(), + Action: cache.Unchanged, + }) + + continue + } + + logger.Debugf(ctx, "Sync loop - processing Presto job [%s] - cache key [%s]", + executionStateCacheItem.CommandId, executionStateCacheItem.Id) + + if InTerminalState(executionStateCacheItem.ExecutionState) { + logger.Debugf(ctx, "Sync loop - Presto id [%s] in terminal state [%s]", + executionStateCacheItem.CommandId, executionStateCacheItem.Id) + + resp = append(resp, cache.ItemSyncResponse{ + ID: query.GetID(), + Item: query.GetItem(), + Action: cache.Unchanged, + }) + + continue + } + + // Get an updated status from Presto + logger.Debugf(ctx, "Querying Presto for %s - %s", executionStateCacheItem.CommandId, executionStateCacheItem.Id) + commandStatus, err := p.prestoClient.GetCommandStatus(ctx, executionStateCacheItem.CommandId) + if err != nil { + logger.Errorf(ctx, "Error from Presto command %s", executionStateCacheItem.CommandId) + executionStateCacheItem.SyncFailureCount++ + // Make sure we don't return nil for the first argument, because that deletes it from the cache. + resp = append(resp, cache.ItemSyncResponse{ + ID: query.GetID(), + Item: executionStateCacheItem, + Action: cache.Update, + }) + + continue + } + + newExecutionPhase, err := PrestoStatusToExecutionPhase(commandStatus) + if err != nil { + return nil, err + } + + if newExecutionPhase > executionStateCacheItem.Phase { + logger.Infof(ctx, "Moving ExecutionPhase for %s %s from %s to %s", executionStateCacheItem.CommandId, + executionStateCacheItem.Id, executionStateCacheItem.Phase, newExecutionPhase) + + executionStateCacheItem.Phase = newExecutionPhase + + resp = append(resp, cache.ItemSyncResponse{ + ID: query.GetID(), + Item: executionStateCacheItem, + Action: cache.Update, + }) + } + } + + return resp, nil +} + +// We need some way to translate results we get from Presto, into a plugin phase +func PrestoStatusToExecutionPhase(s command.CommandStatus) (ExecutionPhase, error) { + switch s { + case client.PrestoStatusFinished: + return PhaseQuerySucceeded, nil + case client.PrestoStatusCancelled: + return PhaseQueryFailed, nil + case client.PrestoStatusFailed: + return PhaseQueryFailed, nil + case client.PrestoStatusQueued: + return PhaseSubmitted, nil + case client.PrestoStatusRunning: + return PhaseSubmitted, nil + case client.PrestoStatusUnknown: + return PhaseQueryFailed, errors.Errorf(BadPrestoReturnCodeError, "Presto returned status Unknown") + default: + return PhaseQueryFailed, errors.Errorf(BadPrestoReturnCodeError, "default fallthrough case") + } +} diff --git a/go/tasks/plugins/presto/executor.go b/go/tasks/plugins/presto/executor.go new file mode 100644 index 000000000..6e476908e --- /dev/null +++ b/go/tasks/plugins/presto/executor.go @@ -0,0 +1,166 @@ +package presto + +import ( + "context" + "github.com/lyft/flyteplugins/go/tasks/plugins/command" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + + "github.com/lyft/flytestdlib/cache" + + "github.com/lyft/flyteplugins/go/tasks/errors" + pluginMachinery "github.com/lyft/flyteplugins/go/tasks/pluginmachinery" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + "github.com/lyft/flytestdlib/logger" + "github.com/lyft/flytestdlib/promutils" +) + +// This is the name of this plugin effectively. In Flyte plugin configuration, use this string to enable this plugin. +const prestoExecutorId = "presto-executor" + +// Version of the custom state this plugin stores. Useful for backwards compatibility if you one day need to update +// the structure of the stored state +const pluginStateVersion = 0 + +const prestoTaskType = "presto" // This needs to match the type defined in Flytekit constants.py + +type PrestoExecutor struct { + id string + metrics PrestoExecutorMetrics + prestoClient command.CommandClient + executionsCache cache.AutoRefresh + cfg *config.Config +} + +func (p PrestoExecutor) GetID() string { + return p.id +} + +func (p PrestoExecutor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { + incomingState := ExecutionState{} + + // We assume here that the first time this function is called, the custom state we get back is whatever we passed in, + // namely the zero-value of our struct. + if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { + logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state when handling [%s] [%s]", + p.id, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) + return core.UnknownTransition, errors.Wrapf(errors.CorruptedPluginState, err, + "Failed to unmarshal custom state in Handle") + } + + // Do what needs to be done, and give this function everything it needs to do its job properly + outgoingState, transformError := HandleExecutionState(ctx, tCtx, incomingState, p.prestoClient, p.executionsCache, p.metrics) + + // Return if there was an error + if transformError != nil { + return core.UnknownTransition, transformError + } + + // If no error, then infer the new Phase from the various states + phaseInfo := MapExecutionStateToPhaseInfo(outgoingState) + + if err := tCtx.PluginStateWriter().Put(pluginStateVersion, outgoingState); err != nil { + return core.UnknownTransition, err + } + + return core.DoTransitionType(core.TransitionTypeBarrier, phaseInfo), nil +} + +func (p PrestoExecutor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error { + incomingState := ExecutionState{} + if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { + logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state in Finalize [%s] Err [%s]", + p.id, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) + return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to unmarshal custom state in Finalize") + } + + return Abort(ctx, incomingState, p.prestoClient) +} + +func (p PrestoExecutor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { + incomingState := ExecutionState{} + if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { + logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state in Finalize [%s] Err [%s]", + p.id, tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName(), err) + return errors.Wrapf(errors.CorruptedPluginState, err, "Failed to unmarshal custom state in Finalize") + } + + return Finalize(ctx, tCtx, incomingState) +} + +func (p PrestoExecutor) GetProperties() core.PluginProperties { + return core.PluginProperties{} +} + +func PrestoExecutorLoader(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { + cfg := config.GetPrestoConfig() + return InitializePrestoExecutor(ctx, iCtx, cfg, BuildResourceConfig(cfg), client.NewPrestoClient(cfg)) +} + +func BuildResourceConfig(cfg *config.Config) map[string]int { + resourceConfig := make(map[string]int, len(cfg.RoutingGroupConfigs)) + + for _, routingGroupCfg := range cfg.RoutingGroupConfigs { + resourceConfig[routingGroupCfg.Name] = routingGroupCfg.Limit + } + return resourceConfig +} + +func InitializePrestoExecutor( + ctx context.Context, + iCtx core.SetupContext, + cfg *config.Config, + resourceConfig map[string]int, + prestoClient command.CommandClient) (core.Plugin, error) { + logger.Infof(ctx, "Initializing a Presto executor with a resource config [%v]", resourceConfig) + q, err := NewPrestoExecutor(ctx, cfg, prestoClient, iCtx.MetricsScope()) + if err != nil { + logger.Errorf(ctx, "Failed to create a new PrestoExecutor due to error: [%v]", err) + return nil, err + } + + for routingGroupName, routingGroupLimit := range resourceConfig { + logger.Infof(ctx, "Registering resource quota for cluster [%v]", routingGroupName) + if err := iCtx.ResourceRegistrar().RegisterResourceQuota(ctx, core.ResourceNamespace(routingGroupName), routingGroupLimit); err != nil { + logger.Errorf(ctx, "Resource quota registration for [%v] failed due to error [%v]", routingGroupName, err) + return nil, err + } + } + + return q, nil +} + +func NewPrestoExecutor( + ctx context.Context, + cfg *config.Config, + prestoClient command.CommandClient, + scope promutils.Scope) (PrestoExecutor, error) { + executionsAutoRefreshCache, err := NewPrestoExecutionsCache(ctx, prestoClient, cfg, scope.NewSubScope(prestoTaskType)) + if err != nil { + logger.Errorf(ctx, "Failed to create AutoRefreshCache in PrestoExecutor Setup. Error: %v", err) + return PrestoExecutor{}, err + } + + err = executionsAutoRefreshCache.Start(ctx) + if err != nil { + logger.Errorf(ctx, "Failed to start AutoRefreshCache. Error: %v", err) + } + + return PrestoExecutor{ + id: prestoExecutorId, + cfg: cfg, + metrics: getPrestoExecutorMetrics(scope), + prestoClient: prestoClient, + executionsCache: executionsAutoRefreshCache, + }, nil +} + +func init() { + pluginMachinery.PluginRegistry().RegisterCorePlugin( + core.PluginEntry{ + ID: prestoExecutorId, + RegisteredTaskTypes: []core.TaskType{prestoTaskType}, + LoadPlugin: PrestoExecutorLoader, + IsDefault: false, + }) +} diff --git a/go/tasks/plugins/presto/executor_metrics.go b/go/tasks/plugins/presto/executor_metrics.go new file mode 100644 index 000000000..08fc483c5 --- /dev/null +++ b/go/tasks/plugins/presto/executor_metrics.go @@ -0,0 +1,33 @@ +package presto + +import ( + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "github.com/prometheus/client_golang/prometheus" +) + +type PrestoExecutorMetrics struct { + Scope promutils.Scope + ReleaseResourceFailed labeled.Counter + AllocationGranted labeled.Counter + AllocationNotGranted labeled.Counter + ResourceWaitTime prometheus.Summary +} + +var ( + tokenAgeObjectives = map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001, 1.0: 0.0} +) + +func getPrestoExecutorMetrics(scope promutils.Scope) PrestoExecutorMetrics { + return PrestoExecutorMetrics{ + Scope: scope, + ReleaseResourceFailed: labeled.NewCounter("released_resource_failed", + "Error releasing allocation token", scope), + AllocationGranted: labeled.NewCounter("allocation_granted", + "Allocation request granted", scope), + AllocationNotGranted: labeled.NewCounter("allocation_not_granted", + "Allocation request did not fail but not granted", scope), + ResourceWaitTime: scope.MustNewSummaryWithOptions("resource_wait_time", "Duration the execution has been waiting for a resource allocation token", + promutils.SummaryOptions{Objectives: tokenAgeObjectives}), + } +} From c0f71da371798aaf662ae2dd7543f320ca27c4ff Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Tue, 10 Mar 2020 20:30:02 -0700 Subject: [PATCH 02/26] mockery --- go/tasks/plugins/presto/client/presto_client.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go index bd1ac30b9..922e68ada 100644 --- a/go/tasks/plugins/presto/client/presto_client.go +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -26,6 +26,8 @@ const ( PrestoUserHeader = "X-Presto-User" ) +//go:generate mockery -all -case=snake + type prestoClient struct { client *http.Client environment *url.URL From 90478fddb24e2c3083167ba5e3eae6b89f640652 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Wed, 11 Mar 2020 13:33:54 -0700 Subject: [PATCH 03/26] add unit tests --- go/tasks/plugins/command/command_client.go | 2 + .../plugins/command/mocks/command_client.go | 128 ++++++ .../plugins/presto/client/presto_client.go | 2 - .../plugins/presto/config/config_flags.go | 51 ++ .../presto/config/config_flags_test.go | 234 ++++++++++ go/tasks/plugins/presto/execution_state.go | 9 +- .../plugins/presto/execution_state_test.go | 434 ++++++++++++++++++ go/tasks/plugins/presto/test_helpers.go | 112 +++++ 8 files changed, 965 insertions(+), 7 deletions(-) create mode 100644 go/tasks/plugins/command/mocks/command_client.go create mode 100755 go/tasks/plugins/presto/config/config_flags.go create mode 100755 go/tasks/plugins/presto/config/config_flags_test.go create mode 100644 go/tasks/plugins/presto/execution_state_test.go create mode 100644 go/tasks/plugins/presto/test_helpers.go diff --git a/go/tasks/plugins/command/command_client.go b/go/tasks/plugins/command/command_client.go index 9e22be992..abd000782 100644 --- a/go/tasks/plugins/command/command_client.go +++ b/go/tasks/plugins/command/command_client.go @@ -6,6 +6,8 @@ import ( type CommandStatus string +//go:generate mockery -all -case=snake + type CommandClient interface { ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) KillCommand(ctx context.Context, commandID string) error diff --git a/go/tasks/plugins/command/mocks/command_client.go b/go/tasks/plugins/command/mocks/command_client.go new file mode 100644 index 000000000..9e72c1dcb --- /dev/null +++ b/go/tasks/plugins/command/mocks/command_client.go @@ -0,0 +1,128 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + command "github.com/lyft/flyteplugins/go/tasks/plugins/command" + + mock "github.com/stretchr/testify/mock" +) + +// CommandClient is an autogenerated mock type for the CommandClient type +type CommandClient struct { + mock.Mock +} + +type CommandClient_ExecuteCommand struct { + *mock.Call +} + +func (_m CommandClient_ExecuteCommand) Return(_a0 interface{}, _a1 error) *CommandClient_ExecuteCommand { + return &CommandClient_ExecuteCommand{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *CommandClient) OnExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) *CommandClient_ExecuteCommand { + c := _m.On("ExecuteCommand", ctx, commandStr, extraArgs) + return &CommandClient_ExecuteCommand{Call: c} +} + +func (_m *CommandClient) OnExecuteCommandMatch(matchers ...interface{}) *CommandClient_ExecuteCommand { + c := _m.On("ExecuteCommand", matchers...) + return &CommandClient_ExecuteCommand{Call: c} +} + +// ExecuteCommand provides a mock function with given fields: ctx, commandStr, extraArgs +func (_m *CommandClient) ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) { + ret := _m.Called(ctx, commandStr, extraArgs) + + var r0 interface{} + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) interface{}); ok { + r0 = rf(ctx, commandStr, extraArgs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, interface{}) error); ok { + r1 = rf(ctx, commandStr, extraArgs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type CommandClient_GetCommandStatus struct { + *mock.Call +} + +func (_m CommandClient_GetCommandStatus) Return(_a0 command.CommandStatus, _a1 error) *CommandClient_GetCommandStatus { + return &CommandClient_GetCommandStatus{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *CommandClient) OnGetCommandStatus(ctx context.Context, commandID string) *CommandClient_GetCommandStatus { + c := _m.On("GetCommandStatus", ctx, commandID) + return &CommandClient_GetCommandStatus{Call: c} +} + +func (_m *CommandClient) OnGetCommandStatusMatch(matchers ...interface{}) *CommandClient_GetCommandStatus { + c := _m.On("GetCommandStatus", matchers...) + return &CommandClient_GetCommandStatus{Call: c} +} + +// GetCommandStatus provides a mock function with given fields: ctx, commandID +func (_m *CommandClient) GetCommandStatus(ctx context.Context, commandID string) (command.CommandStatus, error) { + ret := _m.Called(ctx, commandID) + + var r0 command.CommandStatus + if rf, ok := ret.Get(0).(func(context.Context, string) command.CommandStatus); ok { + r0 = rf(ctx, commandID) + } else { + r0 = ret.Get(0).(command.CommandStatus) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, commandID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type CommandClient_KillCommand struct { + *mock.Call +} + +func (_m CommandClient_KillCommand) Return(_a0 error) *CommandClient_KillCommand { + return &CommandClient_KillCommand{Call: _m.Call.Return(_a0)} +} + +func (_m *CommandClient) OnKillCommand(ctx context.Context, commandID string) *CommandClient_KillCommand { + c := _m.On("KillCommand", ctx, commandID) + return &CommandClient_KillCommand{Call: c} +} + +func (_m *CommandClient) OnKillCommandMatch(matchers ...interface{}) *CommandClient_KillCommand { + c := _m.On("KillCommand", matchers...) + return &CommandClient_KillCommand{Call: c} +} + +// KillCommand provides a mock function with given fields: ctx, commandID +func (_m *CommandClient) KillCommand(ctx context.Context, commandID string) error { + ret := _m.Called(ctx, commandID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, commandID) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go index 922e68ada..bd1ac30b9 100644 --- a/go/tasks/plugins/presto/client/presto_client.go +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -26,8 +26,6 @@ const ( PrestoUserHeader = "X-Presto-User" ) -//go:generate mockery -all -case=snake - type prestoClient struct { client *http.Client environment *url.URL diff --git a/go/tasks/plugins/presto/config/config_flags.go b/go/tasks/plugins/presto/config/config_flags.go new file mode 100755 index 000000000..ffabb5b00 --- /dev/null +++ b/go/tasks/plugins/presto/config/config_flags.go @@ -0,0 +1,51 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "reflect" + + "fmt" + + "github.com/spf13/pflag" +) + +// If v is a pointer, it will get its element value or the zero value of the element type. +// If v is not a pointer, it will return it as is. +func (Config) elemValueOrNil(v interface{}) interface{} { + if t := reflect.TypeOf(v); t.Kind() == reflect.Ptr { + if reflect.ValueOf(v).IsNil() { + return reflect.Zero(t.Elem()).Interface() + } else { + return reflect.ValueOf(v).Interface() + } + } else if v == nil { + return reflect.Zero(t).Interface() + } + + return v +} + +func (Config) mustMarshalJSON(v json.Marshaler) string { + raw, err := v.MarshalJSON() + if err != nil { + panic(err) + } + + return string(raw) +} + +// GetPFlagSet will return strongly types pflags for all fields in Config and its nested types. The format of the +// flags is json-name.json-sub-name... etc. +func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { + cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "endpoint"), defaultConfig.Environment.String(), "Endpoint for Presto to use") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultRoutingGroup"), defaultConfig.DefaultRoutingGroup, "Default Presto routing group") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "workers"), defaultConfig.Workers, "Number of parallel workers to refresh the cache") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "lruCacheSize"), defaultConfig.LruCacheSize, "Size of the AutoRefreshCache") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "awsS3ShardFormatter"), defaultConfig.AwsS3ShardFormatter, " S3 bucket prefix where Presto results will be stored") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "awsS3ShardStringLength"), defaultConfig.AwsS3ShardCount, " Number of characters for the S3 bucket shard prefix") + return cmdFlags +} diff --git a/go/tasks/plugins/presto/config/config_flags_test.go b/go/tasks/plugins/presto/config/config_flags_test.go new file mode 100755 index 000000000..6c6cf67a6 --- /dev/null +++ b/go/tasks/plugins/presto/config/config_flags_test.go @@ -0,0 +1,234 @@ +// Code generated by go generate; DO NOT EDIT. +// This file was generated by robots. + +package config + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "testing" + + "github.com/mitchellh/mapstructure" + "github.com/stretchr/testify/assert" +) + +var dereferencableKindsConfig = map[reflect.Kind]struct{}{ + reflect.Array: {}, reflect.Chan: {}, reflect.Map: {}, reflect.Ptr: {}, reflect.Slice: {}, +} + +// Checks if t is a kind that can be dereferenced to get its underlying type. +func canGetElementConfig(t reflect.Kind) bool { + _, exists := dereferencableKindsConfig[t] + return exists +} + +// This decoder hook tests types for json unmarshaling capability. If implemented, it uses json unmarshal to build the +// object. Otherwise, it'll just pass on the original data. +func jsonUnmarshalerHookConfig(_, to reflect.Type, data interface{}) (interface{}, error) { + unmarshalerType := reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + if to.Implements(unmarshalerType) || reflect.PtrTo(to).Implements(unmarshalerType) || + (canGetElementConfig(to.Kind()) && to.Elem().Implements(unmarshalerType)) { + + raw, err := json.Marshal(data) + if err != nil { + fmt.Printf("Failed to marshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + res := reflect.New(to).Interface() + err = json.Unmarshal(raw, &res) + if err != nil { + fmt.Printf("Failed to umarshal Data: %v. Error: %v. Skipping jsonUnmarshalHook", data, err) + return data, nil + } + + return res, nil + } + + return data, nil +} + +func decode_Config(input, result interface{}) error { + config := &mapstructure.DecoderConfig{ + TagName: "json", + WeaklyTypedInput: true, + Result: result, + DecodeHook: mapstructure.ComposeDecodeHookFunc( + mapstructure.StringToTimeDurationHookFunc(), + mapstructure.StringToSliceHookFunc(","), + jsonUnmarshalerHookConfig, + ), + } + + decoder, err := mapstructure.NewDecoder(config) + if err != nil { + return err + } + + return decoder.Decode(input) +} + +func join_Config(arr interface{}, sep string) string { + listValue := reflect.ValueOf(arr) + strs := make([]string, 0, listValue.Len()) + for i := 0; i < listValue.Len(); i++ { + strs = append(strs, fmt.Sprintf("%v", listValue.Index(i))) + } + + return strings.Join(strs, sep) +} + +func testDecodeJson_Config(t *testing.T, val, result interface{}) { + assert.NoError(t, decode_Config(val, result)) +} + +func testDecodeSlice_Config(t *testing.T, vStringSlice, result interface{}) { + assert.NoError(t, decode_Config(vStringSlice, result)) +} + +func TestConfig_GetPFlagSet(t *testing.T) { + val := Config{} + cmdFlags := val.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) +} + +func TestConfig_SetFlags(t *testing.T) { + actual := Config{} + cmdFlags := actual.GetPFlagSet("") + assert.True(t, cmdFlags.HasFlags()) + + t.Run("Test_endpoint", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("endpoint"); err == nil { + assert.Equal(t, string(defaultConfig.Environment.String()), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.Environment.String() + + cmdFlags.Set("endpoint", testValue) + if vString, err := cmdFlags.GetString("endpoint"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Environment) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_defaultRoutingGroup", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("defaultRoutingGroup"); err == nil { + assert.Equal(t, string(defaultConfig.DefaultRoutingGroup), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("defaultRoutingGroup", testValue) + if vString, err := cmdFlags.GetString("defaultRoutingGroup"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DefaultRoutingGroup) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_workers", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("workers"); err == nil { + assert.Equal(t, int(defaultConfig.Workers), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("workers", testValue) + if vInt, err := cmdFlags.GetInt("workers"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Workers) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_lruCacheSize", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("lruCacheSize"); err == nil { + assert.Equal(t, int(defaultConfig.LruCacheSize), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("lruCacheSize", testValue) + if vInt, err := cmdFlags.GetInt("lruCacheSize"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.LruCacheSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_awsS3ShardFormatter", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("awsS3ShardFormatter"); err == nil { + assert.Equal(t, string(defaultConfig.AwsS3ShardFormatter), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("awsS3ShardFormatter", testValue) + if vString, err := cmdFlags.GetString("awsS3ShardFormatter"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AwsS3ShardFormatter) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_awsS3ShardStringLength", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("awsS3ShardStringLength"); err == nil { + assert.Equal(t, int(defaultConfig.AwsS3ShardCount), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("awsS3ShardStringLength", testValue) + if vInt, err := cmdFlags.GetInt("awsS3ShardStringLength"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.AwsS3ShardCount) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) +} diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 230b1e798..2cc380baf 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -177,7 +177,7 @@ func composeResourceNamespaceWithRoutingGroup(ctx context.Context, tCtx core.Tas if err != nil { return "", err } - clusterPrimaryLabel := resolveRoutingGroup(ctx, routingGroup) + clusterPrimaryLabel := resolveRoutingGroup(ctx, routingGroup, config.GetPrestoConfig()) return core.ResourceNamespace(clusterPrimaryLabel), nil } @@ -221,9 +221,7 @@ func validatePrestoStatement(prestoJob plugins.PrestoQuery) error { return nil } -func resolveRoutingGroup(ctx context.Context, routingGroup string) string { - prestoCfg := config.GetPrestoConfig() - +func resolveRoutingGroup(ctx context.Context, routingGroup string, prestoCfg *config.Config) string { if routingGroup == "" { logger.Debugf(ctx, "Input routing group is an empty string; falling back to using the default routing group [%v]", prestoCfg.DefaultRoutingGroup) return prestoCfg.DefaultRoutingGroup @@ -269,6 +267,7 @@ func GetNextQuery( switch currentState.QueryCount { case 0: + prestoCfg := config.GetPrestoConfig() tempTableName := generateRandomString(32) routingGroup, catalog, schema, statement, err := GetQueryInfo(ctx, tCtx) if err != nil { @@ -280,7 +279,7 @@ func GetNextQuery( prestoQuery := PrestoQuery{ Statement: statement, ExtraArgs: client.PrestoExecuteArgs{ - RoutingGroup: resolveRoutingGroup(ctx, routingGroup), + RoutingGroup: resolveRoutingGroup(ctx, routingGroup, prestoCfg), Catalog: catalog, Schema: schema, }, diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go new file mode 100644 index 000000000..5d3a92eb1 --- /dev/null +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -0,0 +1,434 @@ +package presto + +import ( + "context" + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/command/mocks" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + mocks2 "github.com/lyft/flytestdlib/cache/mocks" + "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flytestdlib/promutils" + "github.com/lyft/flytestdlib/promutils/labeled" + "net/url" + "testing" + "time" + + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + pluginsCoreMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" +) + +func init() { + labeled.SetMetricKeys(contextutils.NamespaceKey) +} + +func TestInTerminalState(t *testing.T) { + var stateTests = []struct { + phase ExecutionPhase + isTerminal bool + }{ + {phase: PhaseNotStarted, isTerminal: false}, + {phase: PhaseQueued, isTerminal: false}, + {phase: PhaseSubmitted, isTerminal: false}, + {phase: PhaseQuerySucceeded, isTerminal: true}, + {phase: PhaseQueryFailed, isTerminal: true}, + } + + for _, tt := range stateTests { + t.Run(tt.phase.String(), func(t *testing.T) { + e := ExecutionState{Phase: tt.phase} + res := InTerminalState(e) + assert.Equal(t, tt.isTerminal, res) + }) + } +} + +func TestIsNotYetSubmitted(t *testing.T) { + var stateTests = []struct { + phase ExecutionPhase + isNotYetSubmitted bool + }{ + {phase: PhaseNotStarted, isNotYetSubmitted: true}, + {phase: PhaseQueued, isNotYetSubmitted: true}, + {phase: PhaseSubmitted, isNotYetSubmitted: false}, + {phase: PhaseQuerySucceeded, isNotYetSubmitted: false}, + {phase: PhaseQueryFailed, isNotYetSubmitted: false}, + } + + for _, tt := range stateTests { + t.Run(tt.phase.String(), func(t *testing.T) { + e := ExecutionState{Phase: tt.phase} + res := IsNotYetSubmitted(e) + assert.Equal(t, tt.isNotYetSubmitted, res) + }) + } +} + +func TestGetQueryInfo(t *testing.T) { + ctx := context.Background() + + taskTemplate := GetSingleHiveQueryTaskTemplate() + mockTaskReader := &mocks.TaskReader{} + mockTaskReader.On("Read", mock.Anything).Return(&taskTemplate, nil) + + mockTaskExecutionContext := mocks.TaskExecutionContext{} + mockTaskExecutionContext.On("TaskReader").Return(mockTaskReader) + + taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} + taskMetadata.On("GetNamespace").Return("myproject-staging") + taskMetadata.On("GetLabels").Return(map[string]string{"sample": "label"}) + mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata) + + routingGroup, catalog, schema, statement, err := GetQueryInfo(ctx, &mockTaskExecutionContext) + assert.NoError(t, err) + assert.Equal(t, "adhoc", routingGroup) + assert.Equal(t, "hive", catalog) + assert.Equal(t, "city", schema) + assert.Equal(t, "select * from hive.city.fact_airport_sessions limit 10", statement) +} + +func TestValidatePrestoStatement(t *testing.T) { + prestoQuery := plugins.PrestoQuery{ + RoutingGroup: "adhoc", + Catalog: "hive", + Schema: "city", + Statement: "", + } + err := validatePrestoStatement(prestoQuery) + assert.Error(t, err) +} + +func TestConstructTaskLog(t *testing.T) { + expected := "https://prestoproxy-internal.lyft.net:443" + u, err := url.Parse(expected) + assert.NoError(t, err) + taskLog := ConstructTaskLog(ExecutionState{CommandId: "123", URI: u.String()}) + assert.Equal(t, expected, taskLog.Uri) +} + +func TestConstructTaskInfo(t *testing.T) { + empty := ConstructTaskInfo(ExecutionState{}) + assert.Nil(t, empty) + + expected := "https://prestoproxy-internal.lyft.net:443" + u, err := url.Parse(expected) + assert.NoError(t, err) + + e := ExecutionState{ + Phase: PhaseQuerySucceeded, + CommandId: "123", + SyncFailureCount: 0, + URI: u.String(), + } + + taskInfo := ConstructTaskInfo(e) + assert.Equal(t, "https://prestoproxy-internal.lyft.net:443", taskInfo.Logs[0].Uri) +} + +func TestMapExecutionStateToPhaseInfo(t *testing.T) { + t.Run("NotStarted", func(t *testing.T) { + e := ExecutionState{ + Phase: PhaseNotStarted, + } + phaseInfo := MapExecutionStateToPhaseInfo(e) + assert.Equal(t, core.PhaseNotReady, phaseInfo.Phase()) + }) + + t.Run("Queued", func(t *testing.T) { + e := ExecutionState{ + Phase: PhaseQueued, + CreationFailureCount: 0, + } + phaseInfo := MapExecutionStateToPhaseInfo(e) + assert.Equal(t, core.PhaseQueued, phaseInfo.Phase()) + + e = ExecutionState{ + Phase: PhaseQueued, + CreationFailureCount: 100, + } + phaseInfo = MapExecutionStateToPhaseInfo(e) + assert.Equal(t, core.PhaseRetryableFailure, phaseInfo.Phase()) + + }) + + t.Run("Submitted", func(t *testing.T) { + e := ExecutionState{ + Phase: PhaseSubmitted, + } + phaseInfo := MapExecutionStateToPhaseInfo(e) + assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) + }) +} + +func TestGetAllocationToken(t *testing.T) { + ctx := context.Background() + + t.Run("allocation granted", func(t *testing.T) { + tCtx := GetMockTaskExecutionContext() + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(core.AllocationStatusGranted, nil) + + mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: time.Now()} + mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) + state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) + assert.NoError(t, err) + assert.Equal(t, PhaseQueued, state.Phase) + }) + + t.Run("exhausted", func(t *testing.T) { + tCtx := GetMockTaskExecutionContext() + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(core.AllocationStatusExhausted, nil) + + mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: time.Now()} + mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) + state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) + assert.NoError(t, err) + assert.Equal(t, PhaseNotStarted, state.Phase) + }) + + t.Run("namespace exhausted", func(t *testing.T) { + tCtx := GetMockTaskExecutionContext() + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(core.AllocationStatusNamespaceQuotaExceeded, nil) + + mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: time.Now()} + mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) + state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) + assert.NoError(t, err) + assert.Equal(t, PhaseNotStarted, state.Phase) + }) + + t.Run("Request start time, if empty in current state, should be set", func(t *testing.T) { + tCtx := GetMockTaskExecutionContext() + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(core.AllocationStatusNamespaceQuotaExceeded, nil) + + mockCurrentState := ExecutionState{} + mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) + state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) + assert.NoError(t, err) + assert.Equal(t, state.AllocationTokenRequestStartTime.IsZero(), false) + }) + + t.Run("Request start time, if already set in current state, should be maintained", func(t *testing.T) { + tCtx := GetMockTaskExecutionContext() + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("AllocateResource", mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return(core.AllocationStatusGranted, nil) + + startTime := time.Now() + mockCurrentState := ExecutionState{AllocationTokenRequestStartTime: startTime} + mockMetrics := getPrestoExecutorMetrics(promutils.NewTestScope()) + state, err := GetAllocationToken(ctx, tCtx, mockCurrentState, mockMetrics) + assert.NoError(t, err) + assert.Equal(t, state.AllocationTokenRequestStartTime.IsZero(), false) + assert.Equal(t, state.AllocationTokenRequestStartTime, startTime) + }) +} + +func TestAbort(t *testing.T) { + ctx := context.Background() + + t.Run("Terminate called when not in terminal state", func(t *testing.T) { + var x = false + //mockQubole := &quboleMocks.QuboleClient{} + mockPresto := &prestoMocks.CommandClient{} + //mockQubole.On("KillCommand", mock.Anything, mock.MatchedBy(func(commandId string) bool { + // return commandId == "123456" + //}), mock.Anything).Run(func(_ mock.Arguments) { + // x = true + //}).Return(nil) + mockPresto.On("KillCommand", mock.Anything, mock.MatchedBy(func(commandId string) bool { + return commandId == "123456" + }), mock.Anything).Run(func(_ mock.Arguments) { + x = true + }).Return(nil) + + err := Abort(ctx, ExecutionState{Phase: PhaseSubmitted, CommandId: "123456"}, mockPresto) + assert.NoError(t, err) + assert.True(t, x) + }) + + t.Run("Terminate not called when in terminal state", func(t *testing.T) { + var x = false + //mockQubole := &quboleMocks.QuboleClient{} + mockPresto := &prestoMocks.CommandClient{} + mockPresto.On("KillCommand", mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { + x = true + }).Return(nil) + + err := Abort(ctx, ExecutionState{Phase: PhaseQuerySucceeded, CommandId: "123456",}, mockPresto) + assert.NoError(t, err) + assert.False(t, x) + }) +} + +func TestFinalize(t *testing.T) { + // Test that Finalize releases resources + ctx := context.Background() + tCtx := GetMockTaskExecutionContext() + state := ExecutionState{} + var called = false + mockResourceManager := tCtx.ResourceManager() + x := mockResourceManager.(*mocks.ResourceManager) + x.On("ReleaseResource", mock.Anything, mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { + called = true + }).Return(nil) + + err := Finalize(ctx, tCtx, state) + assert.NoError(t, err) + assert.True(t, called) +} + +func TestMonitorQuery(t *testing.T) { + ctx := context.Background() + tCtx := GetMockTaskExecutionContext() + state := ExecutionState{ + Phase: PhaseSubmitted, + } + var getOrCreateCalled = false + mockCache := &mocks2.AutoRefresh{} + mockCache.OnGetOrCreateMatch("my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", mock.Anything).Return(ExecutionStateCacheItem{ + ExecutionState: ExecutionState{Phase: PhaseQuerySucceeded}, + Id: "my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", + }, nil).Run(func(_ mock.Arguments) { + getOrCreateCalled = true + }) + + newState, err := MonitorQuery(ctx, tCtx, state, mockCache) + assert.NoError(t, err) + assert.True(t, getOrCreateCalled) + assert.Equal(t, PhaseQuerySucceeded, newState.Phase) +} + +func TestKickOffQuery(t *testing.T) { + ctx := context.Background() + tCtx := GetMockTaskExecutionContext() + + var prestoCalled = false + + prestoExecuteResponse := client.PrestoExecuteResponse{ + Id: "1234567", + Status: client.PrestoStatusQueued, + } + mockPresto := &prestoMocks.CommandClient{} + mockPresto.OnExecuteCommandMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, + mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { + prestoCalled = true + }).Return(prestoExecuteResponse, nil) + var getOrCreateCalled = false + mockCache := &mocks2.AutoRefresh{} + mockCache.OnGetOrCreate(mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { + getOrCreateCalled = true + }).Return(ExecutionStateCacheItem{}, nil) + + state := ExecutionState{} + newState, err := KickOffQuery(ctx, tCtx, state, mockPresto, mockCache) + assert.NoError(t, err) + assert.Equal(t, PhaseSubmitted, newState.Phase) + assert.Equal(t, "1234567", newState.CommandId) + assert.True(t, getOrCreateCalled) + assert.True(t, prestoCalled) +} + +func createMockPrestoCfg() *config.Config { + return &config.Config{ + Environment: config.UrlMustParse("https://prestoproxy-internal.lyft.net:443"), + DefaultRoutingGroup: "adhoc", + Workers: 15, + LruCacheSize: 2000, + AwsS3ShardFormatter: "s3://lyft-modelbuilder/{}/", + AwsS3ShardCount: 2, + RoutingGroupConfigs: []config.RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, + } +} + +func Test_mapLabelToPrimaryLabel(t *testing.T) { + ctx := context.TODO() + mockPrestoCfg := createMockPrestoCfg() + + type args struct { + ctx context.Context + routingGroup string + prestoCfg *config.Config + } + tests := []struct { + name string + args args + want string + }{ + {name: "Routing group is found in configs", args: args{ctx: ctx, routingGroup: "etl", prestoCfg: mockPrestoCfg}, want: "etl"}, + {name: "Use routing group default when not found in configs", args: args{ctx: ctx, routingGroup: "test", prestoCfg: mockPrestoCfg}, want: mockPrestoCfg.DefaultRoutingGroup}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equal(t, tt.want, resolveRoutingGroup(tt.args.ctx, tt.args.routingGroup, tt.args.prestoCfg)) + }) + } +} + +func createMockTaskExecutionContextWithProjectDomain(project string, domain string) *mocks.TaskExecutionContext { + mockTaskExecutionContext := mocks.TaskExecutionContext{} + taskExecID := &pluginsCoreMocks.TaskExecutionID{} + taskExecID.OnGetID().Return(idlCore.TaskExecutionIdentifier{ + NodeExecutionId: &idlCore.NodeExecutionIdentifier{ExecutionId: &idlCore.WorkflowExecutionIdentifier{ + Project: project, + Domain: domain, + Name: "random name", + }}, + }) + + taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} + taskMetadata.OnGetTaskExecutionID().Return(taskExecID) + mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata) + return &mockTaskExecutionContext +} + +//func Test_getClusterPrimaryLabel(t *testing.T) { +// ctx := context.TODO() +// err := config.SetQuboleConfig(createMockPrestoCfg()) +// assert.Nil(t, err) +// +// type args struct { +// ctx context.Context +// tCtx core.TaskExecutionContext +// clusterLabelOverride string +// } +// tests := []struct { +// name string +// args args +// want string +// }{ +// {name: "Override is not empty + override has NO existing mapping + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain Z"), clusterLabelOverride: "AAAA"}, want: "primary B"}, +// {name: "Override is not empty + override has NO existing mapping + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain blah"), clusterLabelOverride: "blh"}, want: DefaultClusterPrimaryLabel}, +// {name: "Override is not empty + override has an existing mapping + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project blah", "domain blah"), clusterLabelOverride: "C-prod"}, want: "primary C"}, +// {name: "Override is not empty + override has an existing mapping + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain A"), clusterLabelOverride: "C-prod"}, want: "primary C"}, +// {name: "Override is empty + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain X"), clusterLabelOverride: ""}, want: "primary A"}, +// {name: "Override is empty + project-domain has an existing mapping2", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain Z"), clusterLabelOverride: ""}, want: "primary B"}, +// {name: "Override is empty + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain blah"), clusterLabelOverride: ""}, want: DefaultClusterPrimaryLabel}, +// {name: "Override is empty + project-domain has NO existing mapping2", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project blah", "domain X"), clusterLabelOverride: ""}, want: DefaultClusterPrimaryLabel}, +// } +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// if got := getClusterPrimaryLabel(tt.args.ctx, tt.args.tCtx, tt.args.clusterLabelOverride); got != tt.want { +// t.Errorf("getClusterPrimaryLabel() = %v, want %v", got, tt.want) +// } +// }) +// } +//} diff --git a/go/tasks/plugins/presto/test_helpers.go b/go/tasks/plugins/presto/test_helpers.go new file mode 100644 index 000000000..22509b1ed --- /dev/null +++ b/go/tasks/plugins/presto/test_helpers.go @@ -0,0 +1,112 @@ +package presto + +import ( + structpb "github.com/golang/protobuf/ptypes/struct" + idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" + coreMock "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + ioMock "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/io/mocks" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/utils" + "github.com/lyft/flytestdlib/storage" + "github.com/stretchr/testify/mock" + v1 "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/api/resource" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" +) + +func GetSingleHiveQueryTaskTemplate() idlCore.TaskTemplate { + prestoQuery := plugins.PrestoQuery{ + RoutingGroup: "adhoc", + Catalog: "hive", + Schema: "city", + Statement: "select * from hive.city.fact_airport_sessions limit 10", + } + stObj := &structpb.Struct{} + _ = utils.MarshalStruct(&prestoQuery, stObj) + tt := idlCore.TaskTemplate{ + Type: "presto", + Custom: stObj, + Id: &idlCore.Identifier{ + Name: "sample_presto_task_test_name", + Project: "flyteplugins", + Version: "1", + ResourceType: idlCore.ResourceType_TASK, + }, + } + + return tt +} + +var resourceRequirements = &v1.ResourceRequirements{ + Limits: v1.ResourceList{ + v1.ResourceCPU: resource.MustParse("1024m"), + v1.ResourceStorage: resource.MustParse("100M"), + }, +} + +func GetMockTaskExecutionMetadata() core.TaskExecutionMetadata { + taskMetadata := &coreMock.TaskExecutionMetadata{} + taskMetadata.On("GetNamespace").Return("test-namespace") + taskMetadata.On("GetAnnotations").Return(map[string]string{"annotation-1": "val1"}) + taskMetadata.On("GetLabels").Return(map[string]string{"label-1": "val1"}) + taskMetadata.On("GetOwnerReference").Return(metav1.OwnerReference{ + Kind: "node", + Name: "blah", + }) + taskMetadata.On("GetK8sServiceAccount").Return("service-account") + taskMetadata.On("GetOwnerID").Return(types.NamespacedName{ + Namespace: "test-namespace", + Name: "test-owner-name", + }) + + tID := &coreMock.TaskExecutionID{} + tID.On("GetID").Return(idlCore.TaskExecutionIdentifier{ + NodeExecutionId: &idlCore.NodeExecutionIdentifier{ + ExecutionId: &idlCore.WorkflowExecutionIdentifier{ + Name: "my_wf_exec_name", + Project: "my_wf_exec_project", + Domain: "my_wf_exec_domain", + }, + }, + }) + tID.On("GetGeneratedName").Return("my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name") + taskMetadata.On("GetTaskExecutionID").Return(tID) + + to := &coreMock.TaskOverrides{} + to.On("GetResources").Return(resourceRequirements) + taskMetadata.On("GetOverrides").Return(to) + + return taskMetadata +} + +func GetMockTaskExecutionContext() core.TaskExecutionContext { + tt := GetSingleHiveQueryTaskTemplate() + + dummyTaskMetadata := GetMockTaskExecutionMetadata() + taskCtx := &coreMock.TaskExecutionContext{} + inputReader := &ioMock.InputReader{} + inputReader.On("GetInputPath").Return(storage.DataReference("test-data-reference")) + inputReader.On("Get", mock.Anything).Return(&idlCore.LiteralMap{}, nil) + taskCtx.On("InputReader").Return(inputReader) + + outputReader := &ioMock.OutputWriter{} + outputReader.On("GetOutputPath").Return(storage.DataReference("/data/outputs.pb")) + outputReader.On("GetOutputPrefixPath").Return(storage.DataReference("/data/")) + taskCtx.On("OutputWriter").Return(outputReader) + + taskReader := &coreMock.TaskReader{} + taskReader.On("Read", mock.Anything).Return(&tt, nil) + taskCtx.On("TaskReader").Return(taskReader) + + resourceManager := &coreMock.ResourceManager{} + taskCtx.On("ResourceManager").Return(resourceManager) + + taskCtx.On("TaskExecutionMetadata").Return(dummyTaskMetadata) + mockSecretManager := &coreMock.SecretManager{} + mockSecretManager.On("Get", mock.Anything, mock.Anything).Return("fake key", nil) + taskCtx.On("SecretManager").Return(mockSecretManager) + + return taskCtx +} From 0211489768efc30bfc3f8c7df05ea3ac4632cd26 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Wed, 11 Mar 2020 14:49:22 -0700 Subject: [PATCH 04/26] more unit tests --- .../plugins/presto/executions_cache_test.go | 90 +++++++++++++++++++ 1 file changed, 90 insertions(+) create mode 100644 go/tasks/plugins/presto/executions_cache_test.go diff --git a/go/tasks/plugins/presto/executions_cache_test.go b/go/tasks/plugins/presto/executions_cache_test.go new file mode 100644 index 000000000..235dbabea --- /dev/null +++ b/go/tasks/plugins/presto/executions_cache_test.go @@ -0,0 +1,90 @@ +package presto + +import ( + "context" + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/command/mocks" + "testing" + + "github.com/lyft/flytestdlib/cache" + cacheMocks "github.com/lyft/flytestdlib/cache/mocks" + + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + + "github.com/lyft/flytestdlib/promutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { + ctx := context.Background() + + t.Run("terminal state return unchanged", func(t *testing.T) { + mockCache := &cacheMocks.AutoRefresh{} + mockPresto := &prestoMocks.CommandClient{} + testScope := promutils.NewTestScope() + + p := PrestoExecutionsCache{ + AutoRefresh: mockCache, + prestoClient: mockPresto, + scope: testScope, + cfg: config.GetPrestoConfig(), + } + + state := ExecutionState{ + Phase: PhaseQuerySucceeded, + } + cacheItem := ExecutionStateCacheItem{ + ExecutionState: state, + Id: "some-id", + } + + iw := &cacheMocks.ItemWrapper{} + iw.OnGetItem().Return(cacheItem) + iw.OnGetID().Return("some-id") + + newCacheItem, err := p.SyncPrestoQuery(ctx, []cache.ItemWrapper{iw}) + assert.NoError(t, err) + assert.Equal(t, cache.Unchanged, newCacheItem[0].Action) + assert.Equal(t, cacheItem, newCacheItem[0].Item) + }) + + t.Run("move to success", func(t *testing.T) { + mockCache := &cacheMocks.AutoRefresh{} + mockPresto := &prestoMocks.CommandClient{} + mockSecretManager := &mocks.SecretManager{} + mockSecretManager.OnGetMatch(mock.Anything, mock.Anything).Return("fake key", nil) + + testScope := promutils.NewTestScope() + + p := PrestoExecutionsCache{ + AutoRefresh: mockCache, + prestoClient: mockPresto, + scope: testScope, + cfg: config.GetPrestoConfig(), + } + + state := ExecutionState{ + CommandId: "123456", + Phase: PhaseSubmitted, + } + cacheItem := ExecutionStateCacheItem{ + ExecutionState: state, + Id: "some-id", + } + mockPresto.OnGetCommandStatusMatch(mock.Anything, mock.MatchedBy(func(commandId string) bool { + return commandId == state.CommandId + }), mock.Anything).Return(client.PrestoStatusFinished, nil) + + iw := &cacheMocks.ItemWrapper{} + iw.OnGetItem().Return(cacheItem) + iw.OnGetID().Return("some-id") + + newCacheItem, err := p.SyncPrestoQuery(ctx, []cache.ItemWrapper{iw}) + newExecutionState := newCacheItem[0].Item.(ExecutionStateCacheItem) + assert.NoError(t, err) + assert.Equal(t, cache.Update, newCacheItem[0].Action) + assert.Equal(t, PhaseQuerySucceeded, newExecutionState.Phase) + }) +} From 526c3ac5536c9e7d1d3033e8c41bcb2066195abd Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Wed, 11 Mar 2020 15:03:46 -0700 Subject: [PATCH 05/26] update to correct import --- go/tasks/plugins/presto/client/presto_status.go | 6 +++--- go/tasks/plugins/presto/executions_cache.go | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go/tasks/plugins/presto/client/presto_status.go b/go/tasks/plugins/presto/client/presto_status.go index eeff09361..fa46cfe97 100644 --- a/go/tasks/plugins/presto/client/presto_status.go +++ b/go/tasks/plugins/presto/client/presto_status.go @@ -15,7 +15,7 @@ const ( PrestoStatusQueued command.CommandStatus = "QUEUED" PrestoStatusRunning command.CommandStatus = "RUNNING" PrestoStatusFinished command.CommandStatus = "FINISHED" - PrestoStatusError command.CommandStatus = "FAILED" + PrestoStatusFailed command.CommandStatus = "FAILED" PrestoStatusCancelled command.CommandStatus = "CANCELLED" ) @@ -24,14 +24,14 @@ var PrestoStatuses = map[command.CommandStatus]struct{}{ PrestoStatusQueued: {}, PrestoStatusRunning: {}, PrestoStatusFinished: {}, - PrestoStatusError: {}, + PrestoStatusFailed: {}, PrestoStatusCancelled: {}, } func NewPrestoStatus(ctx context.Context, state string) command.CommandStatus { upperCased := strings.ToUpper(state) if strings.Contains(upperCased, "FAILED") { - return PrestoStatusError + return PrestoStatusFailed } else if _, ok := PrestoStatuses[command.CommandStatus(upperCased)]; ok { return command.CommandStatus(upperCased) } else { diff --git a/go/tasks/plugins/presto/executions_cache.go b/go/tasks/plugins/presto/executions_cache.go index 5f2a1e304..685565a40 100644 --- a/go/tasks/plugins/presto/executions_cache.go +++ b/go/tasks/plugins/presto/executions_cache.go @@ -12,7 +12,7 @@ import ( "github.com/lyft/flyteplugins/go/tasks/errors" stdErrors "github.com/lyft/flytestdlib/errors" - "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client3" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" "github.com/lyft/flytestdlib/logger" From ebcf04a6bd5a3eace86047adafcfe2882c2da4ad Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Wed, 11 Mar 2020 15:51:33 -0700 Subject: [PATCH 06/26] fix lint --- .../{command => cmd}/command_client.go | 2 +- .../{command => cmd}/mocks/command_client.go | 2 +- .../plugins/presto/client/presto_client.go | 42 ++++----- .../plugins/presto/client/presto_status.go | 23 +++-- go/tasks/plugins/presto/config/config.go | 4 +- go/tasks/plugins/presto/execution_state.go | 92 +++++++++---------- .../plugins/presto/execution_state_test.go | 69 ++------------ go/tasks/plugins/presto/executions_cache.go | 42 ++++----- .../plugins/presto/executions_cache_test.go | 14 +-- go/tasks/plugins/presto/executor.go | 42 ++++----- go/tasks/plugins/presto/executor_metrics.go | 6 +- 11 files changed, 142 insertions(+), 196 deletions(-) rename go/tasks/plugins/{command => cmd}/command_client.go (95%) rename go/tasks/plugins/{command => cmd}/mocks/command_client.go (98%) diff --git a/go/tasks/plugins/command/command_client.go b/go/tasks/plugins/cmd/command_client.go similarity index 95% rename from go/tasks/plugins/command/command_client.go rename to go/tasks/plugins/cmd/command_client.go index abd000782..7c9ab3bc6 100644 --- a/go/tasks/plugins/command/command_client.go +++ b/go/tasks/plugins/cmd/command_client.go @@ -1,4 +1,4 @@ -package command +package cmd import ( "context" diff --git a/go/tasks/plugins/command/mocks/command_client.go b/go/tasks/plugins/cmd/mocks/command_client.go similarity index 98% rename from go/tasks/plugins/command/mocks/command_client.go rename to go/tasks/plugins/cmd/mocks/command_client.go index 9e72c1dcb..46769f674 100644 --- a/go/tasks/plugins/command/mocks/command_client.go +++ b/go/tasks/plugins/cmd/mocks/command_client.go @@ -5,7 +5,7 @@ package mocks import ( context "context" - command "github.com/lyft/flyteplugins/go/tasks/plugins/command" + command "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" mock "github.com/stretchr/testify/mock" ) diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go index bd1ac30b9..8d800435b 100644 --- a/go/tasks/plugins/presto/client/presto_client.go +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -1,9 +1,8 @@ package client import ( - //"bytes" "context" - "github.com/lyft/flyteplugins/go/tasks/plugins/command" + "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" "net/http" "net/url" @@ -14,16 +13,15 @@ import ( const ( httpRequestTimeoutSecs = 30 - - AcceptHeaderKey = "Accept" - ContentTypeHeaderKey = "Content-Type" - ContentTypeJSON = "application/json" - ContentTypeTextPlain = "text/plain" - PrestoCatalogHeader = "X-Presto-Catalog" - PrestoRoutingGroupHeader = "X-Presto-Routing-Group" - PrestoSchemaHeader = "X-Presto-Schema" - PrestoSourceHeader = "X-Presto-Source" - PrestoUserHeader = "X-Presto-User" + //AcceptHeaderKey = "Accept" + //ContentTypeHeaderKey = "Content-Type" + //ContentTypeJSON = "application/json" + //ContentTypeTextPlain = "text/plain" + //PrestoCatalogHeader = "X-Presto-Catalog" + //PrestoRoutingGroupHeader = "X-Presto-Routing-Group" + //PrestoSchemaHeader = "X-Presto-Schema" + //PrestoSourceHeader = "X-Presto-Source" + //PrestoUserHeader = "X-Presto-User" ) type prestoClient struct { @@ -32,15 +30,15 @@ type prestoClient struct { } type PrestoExecuteArgs struct { - RoutingGroup string `json:"routing_group, omitempty"` - Catalog string `json:"catalog, omitempty"` - Schema string `json:"schema, omitempty"` - Source string `json:"source, omitempty"` + RoutingGroup string `json:"routing_group,omitempty"` + Catalog string `json:"catalog,omitempty"` + Schema string `json:"schema,omitempty"` + Source string `json:"source,omitempty"` } type PrestoExecuteResponse struct { - Id string - Status command.CommandStatus - NextUri string + ID string + Status cmd.CommandStatus + NextURI string } func (p *prestoClient) ExecuteCommand( @@ -55,11 +53,11 @@ func (p *prestoClient) KillCommand(ctx context.Context, commandID string) error return nil } -func (p *prestoClient) GetCommandStatus(ctx context.Context, commandId string) (command.CommandStatus, error) { - return PrestoStatusUnknown, nil +func (p *prestoClient) GetCommandStatus(ctx context.Context, commandID string) (cmd.CommandStatus, error) { + return NewPrestoStatus(ctx, "UNKNOWN"), nil } -func NewPrestoClient(cfg *config.Config) command.CommandClient { +func NewPrestoClient(cfg *config.Config) cmd.CommandClient { return &prestoClient{ client: &http.Client{Timeout: httpRequestTimeoutSecs * time.Second}, environment: cfg.Environment.ResolveReference(&cfg.Environment.URL), diff --git a/go/tasks/plugins/presto/client/presto_status.go b/go/tasks/plugins/presto/client/presto_status.go index fa46cfe97..3153c6fc7 100644 --- a/go/tasks/plugins/presto/client/presto_status.go +++ b/go/tasks/plugins/presto/client/presto_status.go @@ -2,24 +2,23 @@ package client import ( "context" - "github.com/lyft/flyteplugins/go/tasks/plugins/command" + "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" "github.com/lyft/flytestdlib/logger" "strings" ) // This type is meant only to encapsulate the response coming from Presto as a type, it is // not meant to be stored locally. - const ( - PrestoStatusUnknown command.CommandStatus = "UNKNOWN" - PrestoStatusQueued command.CommandStatus = "QUEUED" - PrestoStatusRunning command.CommandStatus = "RUNNING" - PrestoStatusFinished command.CommandStatus = "FINISHED" - PrestoStatusFailed command.CommandStatus = "FAILED" - PrestoStatusCancelled command.CommandStatus = "CANCELLED" + PrestoStatusUnknown cmd.CommandStatus = "UNKNOWN" + PrestoStatusQueued cmd.CommandStatus = "QUEUED" + PrestoStatusRunning cmd.CommandStatus = "RUNNING" + PrestoStatusFinished cmd.CommandStatus = "FINISHED" + PrestoStatusFailed cmd.CommandStatus = "FAILED" + PrestoStatusCancelled cmd.CommandStatus = "CANCELLED" ) -var PrestoStatuses = map[command.CommandStatus]struct{}{ +var PrestoStatuses = map[cmd.CommandStatus]struct{}{ PrestoStatusUnknown: {}, PrestoStatusQueued: {}, PrestoStatusRunning: {}, @@ -28,12 +27,12 @@ var PrestoStatuses = map[command.CommandStatus]struct{}{ PrestoStatusCancelled: {}, } -func NewPrestoStatus(ctx context.Context, state string) command.CommandStatus { +func NewPrestoStatus(ctx context.Context, state string) cmd.CommandStatus { upperCased := strings.ToUpper(state) if strings.Contains(upperCased, "FAILED") { return PrestoStatusFailed - } else if _, ok := PrestoStatuses[command.CommandStatus(upperCased)]; ok { - return command.CommandStatus(upperCased) + } else if _, ok := PrestoStatuses[cmd.CommandStatus(upperCased)]; ok { + return cmd.CommandStatus(upperCased) } else { logger.Warnf(ctx, "Invalid Presto Status found: %v", state) return PrestoStatusUnknown diff --git a/go/tasks/plugins/presto/config/config.go b/go/tasks/plugins/presto/config/config.go index 1b3e141d4..5190ebee3 100644 --- a/go/tasks/plugins/presto/config/config.go +++ b/go/tasks/plugins/presto/config/config.go @@ -14,7 +14,7 @@ import ( const prestoConfigSectionKey = "presto" -func UrlMustParse(s string) config.URL { +func URLMustParse(s string) config.URL { r, err := url.Parse(s) if err != nil { logger.Panicf(context.TODO(), "Bad Presto URL Specified as default, error: %s", err) @@ -34,7 +34,7 @@ type RoutingGroupConfig struct { var ( defaultConfig = Config{ - Environment: UrlMustParse("https://prestoproxy-internal.lyft.net:443"), + Environment: URLMustParse("https://prestoproxy-internal.lyft.net:443"), DefaultRoutingGroup: "adhoc", Workers: 15, LruCacheSize: 2000, diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 2cc380baf..a7d58ae4c 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -21,7 +21,7 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" "github.com/lyft/flytestdlib/logger" - "github.com/lyft/flyteplugins/go/tasks/plugins/command" + "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" ) type ExecutionPhase int @@ -54,11 +54,11 @@ type ExecutionState struct { Phase ExecutionPhase // This will store the command ID from Presto - CommandId string `json:"command_id,omitempty"` + CommandID string `json:"command_id,omitempty"` URI string `json:"uri,omitempty"` - CurrentPrestoQuery PrestoQuery `json:"current_presto_query, omitempty"` - QueryCount int `json:"query_count,omitempty"` + CurrentPrestoQuery Query `json:"current_presto_query,omitempty"` + QueryCount int `json:"query_count,omitempty"` // This number keeps track of the number of failures within the sync function. Without this, what happens in // the sync function is entirely opaque. Note that this field is completely orthogonal to Flyte system/node/task @@ -72,11 +72,11 @@ type ExecutionState struct { AllocationTokenRequestStartTime time.Time `json:"allocation_token_request_start_time,omitempty"` } -type PrestoQuery struct { - Statement string `json:"statement, omitempty"` - ExtraArgs client.PrestoExecuteArgs `json:"extra_args, omitempty"` - TempTableName string `json:"temp_table_name, omitempty"` - ExternalTableName string `json:"external_table_name, omitempty"` +type Query struct { + Statement string `json:"statement,omitempty"` + ExtraArgs client.PrestoExecuteArgs `json:"extra_args,omitempty"` + TempTableName string `json:"temp_table_name,omitempty"` + ExternalTableName string `json:"external_table_name,omitempty"` ExternalLocation string `json:"external_location"` } @@ -85,9 +85,9 @@ func HandleExecutionState( ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, - prestoClient command.CommandClient, + prestoClient cmd.CommandClient, executionsCache cache.AutoRefresh, - metrics PrestoExecutorMetrics) (ExecutionState, error) { + metrics ExecutorMetrics) (ExecutionState, error) { var transformError error var newState ExecutionState @@ -111,7 +111,7 @@ func HandleExecutionState( if currentState.QueryCount < 4 { // If there are still Presto statements to execute, increment the query count, reset the phase to get a new // allocation token, and continue executing the remaining statements - currentState.QueryCount += 1 + currentState.QueryCount++ currentState.Phase = PhaseQueued } newState = currentState @@ -129,25 +129,25 @@ func GetAllocationToken( ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, - metric PrestoExecutorMetrics) (ExecutionState, error) { + metric ExecutorMetrics) (ExecutionState, error) { newState := ExecutionState{} - uniqueId := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() routingGroup, err := composeResourceNamespaceWithRoutingGroup(ctx, tCtx) if err != nil { - return newState, errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when requesting allocation token %s", uniqueId) + return newState, errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when requesting allocation token %s", uniqueID) } resourceConstraintsSpec := createResourceConstraintsSpec(ctx, tCtx, routingGroup) - allocationStatus, err := tCtx.ResourceManager().AllocateResource(ctx, routingGroup, uniqueId, resourceConstraintsSpec) + allocationStatus, err := tCtx.ResourceManager().AllocateResource(ctx, routingGroup, uniqueID, resourceConstraintsSpec) if err != nil { logger.Errorf(ctx, "Resource manager failed for TaskExecId [%s] token [%s]. error %s", - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueId, err) - return newState, errors.Wrapf(errors.ResourceManagerFailure, err, "Error requesting allocation token %s", uniqueId) + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueID, err) + return newState, errors.Wrapf(errors.ResourceManagerFailure, err, "Error requesting allocation token %s", uniqueID) } - logger.Infof(ctx, "Allocation result for [%s] is [%s]", uniqueId, allocationStatus) + logger.Infof(ctx, "Allocation result for [%s] is [%s]", uniqueID, allocationStatus) // Emitting the duration this execution has been waiting for a token allocation if currentState.AllocationTokenRequestStartTime.IsZero() { @@ -166,7 +166,7 @@ func GetAllocationToken( newState.Phase = PhaseNotStarted } else { return newState, errors.Errorf(errors.ResourceManagerFailure, "Got bad allocation result [%s] for token [%s]", - allocationStatus, uniqueId) + allocationStatus, uniqueID) } return newState, nil @@ -263,7 +263,7 @@ func createResourceConstraintsSpec(ctx context.Context, _ core.TaskExecutionCont func GetNextQuery( ctx context.Context, tCtx core.TaskExecutionContext, - currentState ExecutionState) (PrestoQuery, error) { + currentState ExecutionState) (Query, error) { switch currentState.QueryCount { case 0: @@ -271,12 +271,12 @@ func GetNextQuery( tempTableName := generateRandomString(32) routingGroup, catalog, schema, statement, err := GetQueryInfo(ctx, tCtx) if err != nil { - return PrestoQuery{}, err + return Query{}, err } statement = fmt.Sprintf(`CREATE TABLE hive.flyte_temporary_tables.%s_temp AS %s`, tempTableName, statement) - prestoQuery := PrestoQuery{ + prestoQuery := Query{ Statement: statement, ExtraArgs: client.PrestoExecuteArgs{ RoutingGroup: resolveRoutingGroup(ctx, routingGroup, prestoCfg), @@ -365,10 +365,10 @@ func KickOffQuery( ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, - prestoClient command.CommandClient, + prestoClient cmd.CommandClient, cache cache.AutoRefresh) (ExecutionState, error) { - uniqueId := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() statement := currentState.CurrentPrestoQuery.Statement extraArgs := currentState.CurrentPrestoQuery.ExtraArgs @@ -377,29 +377,29 @@ func KickOffQuery( if err != nil { // If we failed, we'll keep the NotStarted state currentState.CreationFailureCount = currentState.CreationFailureCount + 1 - logger.Warnf(ctx, "Error creating Presto query for %s, failure counts %d. Error: %s", uniqueId, currentState.CreationFailureCount, err) + logger.Warnf(ctx, "Error creating Presto query for %s, failure counts %d. Error: %s", uniqueID, currentState.CreationFailureCount, err) } else { executeResponse := response.(client.PrestoExecuteResponse) // If we succeed, then store the command id returned from Presto, and update our state. Also, add to the // AutoRefreshCache so we start getting updates for its status. - commandId := executeResponse.Id - logger.Infof(ctx, "Created Presto Id [%s] for token %s", commandId, uniqueId) - currentState.CommandId = commandId + commandID := executeResponse.ID + logger.Infof(ctx, "Created Presto ID [%s] for token %s", commandID, uniqueID) + currentState.CommandID = commandID currentState.Phase = PhaseSubmitted - currentState.URI = executeResponse.NextUri + currentState.URI = executeResponse.NextURI executionStateCacheItem := ExecutionStateCacheItem{ ExecutionState: currentState, - Id: uniqueId, + Identifier: uniqueID, } // The first time we put it in the cache, we know it won't have succeeded so we don't need to look at it - _, err := cache.GetOrCreate(uniqueId, executionStateCacheItem) + _, err := cache.GetOrCreate(uniqueID, executionStateCacheItem) if err != nil { // This means that our cache has fundamentally broken... return a system error logger.Errorf(ctx, "Cache failed to GetOrCreate for execution [%s] cache key [%s], owner [%s]. Error %s", - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueId, + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueID, tCtx.TaskExecutionMetadata().GetOwnerReference(), err) return currentState, err } @@ -414,17 +414,17 @@ func MonitorQuery( currentState ExecutionState, cache cache.AutoRefresh) (ExecutionState, error) { - uniqueId := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() executionStateCacheItem := ExecutionStateCacheItem{ ExecutionState: currentState, - Id: uniqueId, + Identifier: uniqueID, } - cachedItem, err := cache.GetOrCreate(uniqueId, executionStateCacheItem) + cachedItem, err := cache.GetOrCreate(uniqueID, executionStateCacheItem) if err != nil { // This means that our cache has fundamentally broken... return a system error logger.Errorf(ctx, "Cache is broken on execution [%s] cache key [%s], owner [%s]. Error %s", - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueId, + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueID, tCtx.TaskExecutionMetadata().GetOwnerReference(), err) return currentState, errors.Wrapf(errors.CacheFailed, err, "Error when GetOrCreate while monitoring") } @@ -469,7 +469,7 @@ func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { logs := make([]*idlCore.TaskLog, 0, 1) t := time.Now() - if e.CommandId != "" { + if e.CommandID != "" { logs = append(logs, ConstructTaskLog(e)) return &core.TaskInfo{ Logs: logs, @@ -482,16 +482,16 @@ func ConstructTaskInfo(e ExecutionState) *core.TaskInfo { func ConstructTaskLog(e ExecutionState) *idlCore.TaskLog { return &idlCore.TaskLog{ - Name: fmt.Sprintf("Status: %s [%s]", e.Phase, e.CommandId), + Name: fmt.Sprintf("Status: %s [%s]", e.Phase, e.CommandID), MessageFormat: idlCore.TaskLog_UNKNOWN, Uri: e.URI, } } -func Abort(ctx context.Context, currentState ExecutionState, client command.CommandClient) error { +func Abort(ctx context.Context, currentState ExecutionState, client cmd.CommandClient) error { // Cancel Presto query if non-terminal state - if !InTerminalState(currentState) && currentState.CommandId != "" { - err := client.KillCommand(ctx, currentState.CommandId) + if !InTerminalState(currentState) && currentState.CommandID != "" { + err := client.KillCommand(ctx, currentState.CommandID) if err != nil { logger.Errorf(ctx, "Error terminating Presto command in Finalize [%s]", err) return err @@ -502,16 +502,16 @@ func Abort(ctx context.Context, currentState ExecutionState, client command.Comm func Finalize(ctx context.Context, tCtx core.TaskExecutionContext, _ ExecutionState) error { // Release allocation token - uniqueId := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() routingGroup, err := composeResourceNamespaceWithRoutingGroup(ctx, tCtx) if err != nil { - return errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when releasing allocation token %s", uniqueId) + return errors.Wrapf(errors.ResourceManagerFailure, err, "Error getting query info when releasing allocation token %s", uniqueID) } - err = tCtx.ResourceManager().ReleaseResource(ctx, routingGroup, uniqueId) + err = tCtx.ResourceManager().ReleaseResource(ctx, routingGroup, uniqueID) if err != nil { - logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", uniqueId, err) + logger.Errorf(ctx, "Error releasing allocation token [%s] in Finalize [%s]", uniqueID, err) return err } return nil diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index 5d3a92eb1..877a29fd7 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -2,9 +2,8 @@ package presto import ( "context" - idlCore "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" - prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/command/mocks" + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/cmd/mocks" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" mocks2 "github.com/lyft/flytestdlib/cache/mocks" @@ -108,7 +107,7 @@ func TestConstructTaskLog(t *testing.T) { expected := "https://prestoproxy-internal.lyft.net:443" u, err := url.Parse(expected) assert.NoError(t, err) - taskLog := ConstructTaskLog(ExecutionState{CommandId: "123", URI: u.String()}) + taskLog := ConstructTaskLog(ExecutionState{CommandID: "123", URI: u.String()}) assert.Equal(t, expected, taskLog.Uri) } @@ -122,7 +121,7 @@ func TestConstructTaskInfo(t *testing.T) { e := ExecutionState{ Phase: PhaseQuerySucceeded, - CommandId: "123", + CommandID: "123", SyncFailureCount: 0, URI: u.String(), } @@ -260,7 +259,7 @@ func TestAbort(t *testing.T) { x = true }).Return(nil) - err := Abort(ctx, ExecutionState{Phase: PhaseSubmitted, CommandId: "123456"}, mockPresto) + err := Abort(ctx, ExecutionState{Phase: PhaseSubmitted, CommandID: "123456"}, mockPresto) assert.NoError(t, err) assert.True(t, x) }) @@ -273,7 +272,7 @@ func TestAbort(t *testing.T) { x = true }).Return(nil) - err := Abort(ctx, ExecutionState{Phase: PhaseQuerySucceeded, CommandId: "123456",}, mockPresto) + err := Abort(ctx, ExecutionState{Phase: PhaseQuerySucceeded, CommandID: "123456",}, mockPresto) assert.NoError(t, err) assert.False(t, x) }) @@ -306,7 +305,7 @@ func TestMonitorQuery(t *testing.T) { mockCache := &mocks2.AutoRefresh{} mockCache.OnGetOrCreateMatch("my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", mock.Anything).Return(ExecutionStateCacheItem{ ExecutionState: ExecutionState{Phase: PhaseQuerySucceeded}, - Id: "my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", + Identifier: "my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", }, nil).Run(func(_ mock.Arguments) { getOrCreateCalled = true }) @@ -324,7 +323,7 @@ func TestKickOffQuery(t *testing.T) { var prestoCalled = false prestoExecuteResponse := client.PrestoExecuteResponse{ - Id: "1234567", + ID: "1234567", Status: client.PrestoStatusQueued, } mockPresto := &prestoMocks.CommandClient{} @@ -342,14 +341,14 @@ func TestKickOffQuery(t *testing.T) { newState, err := KickOffQuery(ctx, tCtx, state, mockPresto, mockCache) assert.NoError(t, err) assert.Equal(t, PhaseSubmitted, newState.Phase) - assert.Equal(t, "1234567", newState.CommandId) + assert.Equal(t, "1234567", newState.CommandID) assert.True(t, getOrCreateCalled) assert.True(t, prestoCalled) } func createMockPrestoCfg() *config.Config { return &config.Config{ - Environment: config.UrlMustParse("https://prestoproxy-internal.lyft.net:443"), + Environment: config.URLMustParse("https://prestoproxy-internal.lyft.net:443"), DefaultRoutingGroup: "adhoc", Workers: 15, LruCacheSize: 2000, @@ -382,53 +381,3 @@ func Test_mapLabelToPrimaryLabel(t *testing.T) { }) } } - -func createMockTaskExecutionContextWithProjectDomain(project string, domain string) *mocks.TaskExecutionContext { - mockTaskExecutionContext := mocks.TaskExecutionContext{} - taskExecID := &pluginsCoreMocks.TaskExecutionID{} - taskExecID.OnGetID().Return(idlCore.TaskExecutionIdentifier{ - NodeExecutionId: &idlCore.NodeExecutionIdentifier{ExecutionId: &idlCore.WorkflowExecutionIdentifier{ - Project: project, - Domain: domain, - Name: "random name", - }}, - }) - - taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} - taskMetadata.OnGetTaskExecutionID().Return(taskExecID) - mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata) - return &mockTaskExecutionContext -} - -//func Test_getClusterPrimaryLabel(t *testing.T) { -// ctx := context.TODO() -// err := config.SetQuboleConfig(createMockPrestoCfg()) -// assert.Nil(t, err) -// -// type args struct { -// ctx context.Context -// tCtx core.TaskExecutionContext -// clusterLabelOverride string -// } -// tests := []struct { -// name string -// args args -// want string -// }{ -// {name: "Override is not empty + override has NO existing mapping + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain Z"), clusterLabelOverride: "AAAA"}, want: "primary B"}, -// {name: "Override is not empty + override has NO existing mapping + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain blah"), clusterLabelOverride: "blh"}, want: DefaultClusterPrimaryLabel}, -// {name: "Override is not empty + override has an existing mapping + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project blah", "domain blah"), clusterLabelOverride: "C-prod"}, want: "primary C"}, -// {name: "Override is not empty + override has an existing mapping + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain A"), clusterLabelOverride: "C-prod"}, want: "primary C"}, -// {name: "Override is empty + project-domain has an existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain X"), clusterLabelOverride: ""}, want: "primary A"}, -// {name: "Override is empty + project-domain has an existing mapping2", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain Z"), clusterLabelOverride: ""}, want: "primary B"}, -// {name: "Override is empty + project-domain has NO existing mapping", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project A", "domain blah"), clusterLabelOverride: ""}, want: DefaultClusterPrimaryLabel}, -// {name: "Override is empty + project-domain has NO existing mapping2", args: args{ctx: ctx, tCtx: createMockTaskExecutionContextWithProjectDomain("project blah", "domain X"), clusterLabelOverride: ""}, want: DefaultClusterPrimaryLabel}, -// } -// for _, tt := range tests { -// t.Run(tt.name, func(t *testing.T) { -// if got := getClusterPrimaryLabel(tt.args.ctx, tt.args.tCtx, tt.args.clusterLabelOverride); got != tt.want { -// t.Errorf("getClusterPrimaryLabel() = %v, want %v", got, tt.want) -// } -// }) -// } -//} diff --git a/go/tasks/plugins/presto/executions_cache.go b/go/tasks/plugins/presto/executions_cache.go index 685565a40..3126df1c3 100644 --- a/go/tasks/plugins/presto/executions_cache.go +++ b/go/tasks/plugins/presto/executions_cache.go @@ -2,7 +2,7 @@ package presto import ( "context" - "github.com/lyft/flyteplugins/go/tasks/plugins/command" + "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" "time" "k8s.io/client-go/util/workqueue" @@ -25,27 +25,27 @@ const ( BadPrestoReturnCodeError stdErrors.ErrorCode = "PRESTO_RETURNED_UNKNOWN" ) -type PrestoExecutionsCache struct { +type ExecutionsCache struct { cache.AutoRefresh - prestoClient command.CommandClient + prestoClient cmd.CommandClient scope promutils.Scope cfg *config.Config } func NewPrestoExecutionsCache( ctx context.Context, - prestoClient command.CommandClient, + prestoClient cmd.CommandClient, cfg *config.Config, - scope promutils.Scope) (PrestoExecutionsCache, error) { + scope promutils.Scope) (ExecutionsCache, error) { - q := PrestoExecutionsCache{ + q := ExecutionsCache{ prestoClient: prestoClient, scope: scope, cfg: cfg, } autoRefreshCache, err := cache.NewAutoRefreshCache("presto", q.SyncPrestoQuery, workqueue.DefaultControllerRateLimiter(), ResyncDuration, cfg.Workers, cfg.LruCacheSize, scope) if err != nil { - logger.Errorf(ctx, "Could not create AutoRefreshCache in PrestoExecutor. [%s]", err) + logger.Errorf(ctx, "Could not create AutoRefreshCache in Executor. [%s]", err) return q, errors.Wrapf(errors.CacheFailed, err, "Error creating AutoRefreshCache") } q.AutoRefresh = autoRefreshCache @@ -58,16 +58,16 @@ type ExecutionStateCacheItem struct { // This ID is the cache key and so will need to be unique across all objects in the cache (it will probably be // unique across all of Flyte) and needs to be deterministic. // This will also be used as the allocation token for now. - Id string `json:"id"` + Identifier string `json:"id"` } func (e ExecutionStateCacheItem) ID() string { - return e.Id + return e.Identifier } // This basically grab an updated status from the Presto API and stores it in the cache // All other handling should be in the synchronous loop. -func (p *PrestoExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache.Batch) ( +func (p *ExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache.Batch) ( updatedBatch []cache.ItemSyncResponse, err error) { resp := make([]cache.ItemSyncResponse, 0, len(batch)) @@ -79,8 +79,8 @@ func (p *PrestoExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache return nil, errors.Errorf(errors.CacheFailed, "Failed to cast [%v]", batch[0].GetID()) } - if executionStateCacheItem.CommandId == "" { - logger.Warnf(ctx, "Sync loop - CommandID is blank for [%s] skipping", executionStateCacheItem.Id) + if executionStateCacheItem.CommandID == "" { + logger.Warnf(ctx, "Sync loop - CommandID is blank for [%s] skipping", executionStateCacheItem.Identifier) resp = append(resp, cache.ItemSyncResponse{ ID: query.GetID(), Item: query.GetItem(), @@ -91,11 +91,11 @@ func (p *PrestoExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache } logger.Debugf(ctx, "Sync loop - processing Presto job [%s] - cache key [%s]", - executionStateCacheItem.CommandId, executionStateCacheItem.Id) + executionStateCacheItem.CommandID, executionStateCacheItem.Identifier) if InTerminalState(executionStateCacheItem.ExecutionState) { logger.Debugf(ctx, "Sync loop - Presto id [%s] in terminal state [%s]", - executionStateCacheItem.CommandId, executionStateCacheItem.Id) + executionStateCacheItem.CommandID, executionStateCacheItem.Identifier) resp = append(resp, cache.ItemSyncResponse{ ID: query.GetID(), @@ -107,10 +107,10 @@ func (p *PrestoExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache } // Get an updated status from Presto - logger.Debugf(ctx, "Querying Presto for %s - %s", executionStateCacheItem.CommandId, executionStateCacheItem.Id) - commandStatus, err := p.prestoClient.GetCommandStatus(ctx, executionStateCacheItem.CommandId) + logger.Debugf(ctx, "Querying Presto for %s - %s", executionStateCacheItem.CommandID, executionStateCacheItem.Identifier) + commandStatus, err := p.prestoClient.GetCommandStatus(ctx, executionStateCacheItem.CommandID) if err != nil { - logger.Errorf(ctx, "Error from Presto command %s", executionStateCacheItem.CommandId) + logger.Errorf(ctx, "Error from Presto command %s", executionStateCacheItem.CommandID) executionStateCacheItem.SyncFailureCount++ // Make sure we don't return nil for the first argument, because that deletes it from the cache. resp = append(resp, cache.ItemSyncResponse{ @@ -122,14 +122,14 @@ func (p *PrestoExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache continue } - newExecutionPhase, err := PrestoStatusToExecutionPhase(commandStatus) + newExecutionPhase, err := StatusToExecutionPhase(commandStatus) if err != nil { return nil, err } if newExecutionPhase > executionStateCacheItem.Phase { - logger.Infof(ctx, "Moving ExecutionPhase for %s %s from %s to %s", executionStateCacheItem.CommandId, - executionStateCacheItem.Id, executionStateCacheItem.Phase, newExecutionPhase) + logger.Infof(ctx, "Moving ExecutionPhase for %s %s from %s to %s", executionStateCacheItem.CommandID, + executionStateCacheItem.Identifier, executionStateCacheItem.Phase, newExecutionPhase) executionStateCacheItem.Phase = newExecutionPhase @@ -145,7 +145,7 @@ func (p *PrestoExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache } // We need some way to translate results we get from Presto, into a plugin phase -func PrestoStatusToExecutionPhase(s command.CommandStatus) (ExecutionPhase, error) { +func StatusToExecutionPhase(s cmd.CommandStatus) (ExecutionPhase, error) { switch s { case client.PrestoStatusFinished: return PhaseQuerySucceeded, nil diff --git a/go/tasks/plugins/presto/executions_cache_test.go b/go/tasks/plugins/presto/executions_cache_test.go index 235dbabea..aeb61d664 100644 --- a/go/tasks/plugins/presto/executions_cache_test.go +++ b/go/tasks/plugins/presto/executions_cache_test.go @@ -2,7 +2,7 @@ package presto import ( "context" - prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/command/mocks" + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/cmd/mocks" "testing" "github.com/lyft/flytestdlib/cache" @@ -25,7 +25,7 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { mockPresto := &prestoMocks.CommandClient{} testScope := promutils.NewTestScope() - p := PrestoExecutionsCache{ + p := ExecutionsCache{ AutoRefresh: mockCache, prestoClient: mockPresto, scope: testScope, @@ -37,7 +37,7 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { } cacheItem := ExecutionStateCacheItem{ ExecutionState: state, - Id: "some-id", + Identifier: "some-id", } iw := &cacheMocks.ItemWrapper{} @@ -58,7 +58,7 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { testScope := promutils.NewTestScope() - p := PrestoExecutionsCache{ + p := ExecutionsCache{ AutoRefresh: mockCache, prestoClient: mockPresto, scope: testScope, @@ -66,15 +66,15 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { } state := ExecutionState{ - CommandId: "123456", + CommandID: "123456", Phase: PhaseSubmitted, } cacheItem := ExecutionStateCacheItem{ ExecutionState: state, - Id: "some-id", + Identifier: "some-id", } mockPresto.OnGetCommandStatusMatch(mock.Anything, mock.MatchedBy(func(commandId string) bool { - return commandId == state.CommandId + return commandId == state.CommandID }), mock.Anything).Return(client.PrestoStatusFinished, nil) iw := &cacheMocks.ItemWrapper{} diff --git a/go/tasks/plugins/presto/executor.go b/go/tasks/plugins/presto/executor.go index 6e476908e..38ad487af 100644 --- a/go/tasks/plugins/presto/executor.go +++ b/go/tasks/plugins/presto/executor.go @@ -2,7 +2,7 @@ package presto import ( "context" - "github.com/lyft/flyteplugins/go/tasks/plugins/command" + "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" "github.com/lyft/flytestdlib/cache" @@ -16,7 +16,7 @@ import ( ) // This is the name of this plugin effectively. In Flyte plugin configuration, use this string to enable this plugin. -const prestoExecutorId = "presto-executor" +const prestoExecutorID = "presto-executor" // Version of the custom state this plugin stores. Useful for backwards compatibility if you one day need to update // the structure of the stored state @@ -24,19 +24,19 @@ const pluginStateVersion = 0 const prestoTaskType = "presto" // This needs to match the type defined in Flytekit constants.py -type PrestoExecutor struct { +type Executor struct { id string - metrics PrestoExecutorMetrics - prestoClient command.CommandClient + metrics ExecutorMetrics + prestoClient cmd.CommandClient executionsCache cache.AutoRefresh cfg *config.Config } -func (p PrestoExecutor) GetID() string { +func (p Executor) GetID() string { return p.id } -func (p PrestoExecutor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { +func (p Executor) Handle(ctx context.Context, tCtx core.TaskExecutionContext) (core.Transition, error) { incomingState := ExecutionState{} // We assume here that the first time this function is called, the custom state we get back is whatever we passed in, @@ -66,7 +66,7 @@ func (p PrestoExecutor) Handle(ctx context.Context, tCtx core.TaskExecutionConte return core.DoTransitionType(core.TransitionTypeBarrier, phaseInfo), nil } -func (p PrestoExecutor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error { +func (p Executor) Abort(ctx context.Context, tCtx core.TaskExecutionContext) error { incomingState := ExecutionState{} if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state in Finalize [%s] Err [%s]", @@ -77,7 +77,7 @@ func (p PrestoExecutor) Abort(ctx context.Context, tCtx core.TaskExecutionContex return Abort(ctx, incomingState, p.prestoClient) } -func (p PrestoExecutor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { +func (p Executor) Finalize(ctx context.Context, tCtx core.TaskExecutionContext) error { incomingState := ExecutionState{} if _, err := tCtx.PluginStateReader().Get(&incomingState); err != nil { logger.Errorf(ctx, "Plugin %s failed to unmarshal custom state in Finalize [%s] Err [%s]", @@ -88,11 +88,11 @@ func (p PrestoExecutor) Finalize(ctx context.Context, tCtx core.TaskExecutionCon return Finalize(ctx, tCtx, incomingState) } -func (p PrestoExecutor) GetProperties() core.PluginProperties { +func (p Executor) GetProperties() core.PluginProperties { return core.PluginProperties{} } -func PrestoExecutorLoader(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { +func ExecutorLoader(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { cfg := config.GetPrestoConfig() return InitializePrestoExecutor(ctx, iCtx, cfg, BuildResourceConfig(cfg), client.NewPrestoClient(cfg)) } @@ -111,11 +111,11 @@ func InitializePrestoExecutor( iCtx core.SetupContext, cfg *config.Config, resourceConfig map[string]int, - prestoClient command.CommandClient) (core.Plugin, error) { + prestoClient cmd.CommandClient) (core.Plugin, error) { logger.Infof(ctx, "Initializing a Presto executor with a resource config [%v]", resourceConfig) q, err := NewPrestoExecutor(ctx, cfg, prestoClient, iCtx.MetricsScope()) if err != nil { - logger.Errorf(ctx, "Failed to create a new PrestoExecutor due to error: [%v]", err) + logger.Errorf(ctx, "Failed to create a new Executor due to error: [%v]", err) return nil, err } @@ -133,12 +133,12 @@ func InitializePrestoExecutor( func NewPrestoExecutor( ctx context.Context, cfg *config.Config, - prestoClient command.CommandClient, - scope promutils.Scope) (PrestoExecutor, error) { + prestoClient cmd.CommandClient, + scope promutils.Scope) (Executor, error) { executionsAutoRefreshCache, err := NewPrestoExecutionsCache(ctx, prestoClient, cfg, scope.NewSubScope(prestoTaskType)) if err != nil { - logger.Errorf(ctx, "Failed to create AutoRefreshCache in PrestoExecutor Setup. Error: %v", err) - return PrestoExecutor{}, err + logger.Errorf(ctx, "Failed to create AutoRefreshCache in Executor Setup. Error: %v", err) + return Executor{}, err } err = executionsAutoRefreshCache.Start(ctx) @@ -146,8 +146,8 @@ func NewPrestoExecutor( logger.Errorf(ctx, "Failed to start AutoRefreshCache. Error: %v", err) } - return PrestoExecutor{ - id: prestoExecutorId, + return Executor{ + id: prestoExecutorID, cfg: cfg, metrics: getPrestoExecutorMetrics(scope), prestoClient: prestoClient, @@ -158,9 +158,9 @@ func NewPrestoExecutor( func init() { pluginMachinery.PluginRegistry().RegisterCorePlugin( core.PluginEntry{ - ID: prestoExecutorId, + ID: prestoExecutorID, RegisteredTaskTypes: []core.TaskType{prestoTaskType}, - LoadPlugin: PrestoExecutorLoader, + LoadPlugin: ExecutorLoader, IsDefault: false, }) } diff --git a/go/tasks/plugins/presto/executor_metrics.go b/go/tasks/plugins/presto/executor_metrics.go index 08fc483c5..235a840eb 100644 --- a/go/tasks/plugins/presto/executor_metrics.go +++ b/go/tasks/plugins/presto/executor_metrics.go @@ -6,7 +6,7 @@ import ( "github.com/prometheus/client_golang/prometheus" ) -type PrestoExecutorMetrics struct { +type ExecutorMetrics struct { Scope promutils.Scope ReleaseResourceFailed labeled.Counter AllocationGranted labeled.Counter @@ -18,8 +18,8 @@ var ( tokenAgeObjectives = map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001, 1.0: 0.0} ) -func getPrestoExecutorMetrics(scope promutils.Scope) PrestoExecutorMetrics { - return PrestoExecutorMetrics{ +func getPrestoExecutorMetrics(scope promutils.Scope) ExecutorMetrics { + return ExecutorMetrics{ Scope: scope, ReleaseResourceFailed: labeled.NewCounter("released_resource_failed", "Error releasing allocation token", scope), From 39a3a3c78046a1fc0ef271fd304ba07c35160581 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Wed, 11 Mar 2020 17:30:28 -0700 Subject: [PATCH 07/26] linting --- go/tasks/plugins/cmd/mocks/command_client.go | 128 ------------------ .../plugins/presto/client/presto_client.go | 9 +- .../plugins/presto/client/presto_status.go | 25 ++-- go/tasks/plugins/presto/execution_state.go | 11 +- .../plugins/presto/execution_state_test.go | 26 ++-- go/tasks/plugins/presto/executions_cache.go | 9 +- .../plugins/presto/executions_cache_test.go | 7 +- go/tasks/plugins/presto/executor.go | 9 +- go/tasks/plugins/svc/mocks/service_client.go | 127 +++++++++++++++++ .../service_client.go} | 4 +- 10 files changed, 178 insertions(+), 177 deletions(-) delete mode 100644 go/tasks/plugins/cmd/mocks/command_client.go create mode 100644 go/tasks/plugins/svc/mocks/service_client.go rename go/tasks/plugins/{cmd/command_client.go => svc/service_client.go} (88%) diff --git a/go/tasks/plugins/cmd/mocks/command_client.go b/go/tasks/plugins/cmd/mocks/command_client.go deleted file mode 100644 index 46769f674..000000000 --- a/go/tasks/plugins/cmd/mocks/command_client.go +++ /dev/null @@ -1,128 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - context "context" - - command "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" - - mock "github.com/stretchr/testify/mock" -) - -// CommandClient is an autogenerated mock type for the CommandClient type -type CommandClient struct { - mock.Mock -} - -type CommandClient_ExecuteCommand struct { - *mock.Call -} - -func (_m CommandClient_ExecuteCommand) Return(_a0 interface{}, _a1 error) *CommandClient_ExecuteCommand { - return &CommandClient_ExecuteCommand{Call: _m.Call.Return(_a0, _a1)} -} - -func (_m *CommandClient) OnExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) *CommandClient_ExecuteCommand { - c := _m.On("ExecuteCommand", ctx, commandStr, extraArgs) - return &CommandClient_ExecuteCommand{Call: c} -} - -func (_m *CommandClient) OnExecuteCommandMatch(matchers ...interface{}) *CommandClient_ExecuteCommand { - c := _m.On("ExecuteCommand", matchers...) - return &CommandClient_ExecuteCommand{Call: c} -} - -// ExecuteCommand provides a mock function with given fields: ctx, commandStr, extraArgs -func (_m *CommandClient) ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) { - ret := _m.Called(ctx, commandStr, extraArgs) - - var r0 interface{} - if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) interface{}); ok { - r0 = rf(ctx, commandStr, extraArgs) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(interface{}) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, interface{}) error); ok { - r1 = rf(ctx, commandStr, extraArgs) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type CommandClient_GetCommandStatus struct { - *mock.Call -} - -func (_m CommandClient_GetCommandStatus) Return(_a0 command.CommandStatus, _a1 error) *CommandClient_GetCommandStatus { - return &CommandClient_GetCommandStatus{Call: _m.Call.Return(_a0, _a1)} -} - -func (_m *CommandClient) OnGetCommandStatus(ctx context.Context, commandID string) *CommandClient_GetCommandStatus { - c := _m.On("GetCommandStatus", ctx, commandID) - return &CommandClient_GetCommandStatus{Call: c} -} - -func (_m *CommandClient) OnGetCommandStatusMatch(matchers ...interface{}) *CommandClient_GetCommandStatus { - c := _m.On("GetCommandStatus", matchers...) - return &CommandClient_GetCommandStatus{Call: c} -} - -// GetCommandStatus provides a mock function with given fields: ctx, commandID -func (_m *CommandClient) GetCommandStatus(ctx context.Context, commandID string) (command.CommandStatus, error) { - ret := _m.Called(ctx, commandID) - - var r0 command.CommandStatus - if rf, ok := ret.Get(0).(func(context.Context, string) command.CommandStatus); ok { - r0 = rf(ctx, commandID) - } else { - r0 = ret.Get(0).(command.CommandStatus) - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, commandID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type CommandClient_KillCommand struct { - *mock.Call -} - -func (_m CommandClient_KillCommand) Return(_a0 error) *CommandClient_KillCommand { - return &CommandClient_KillCommand{Call: _m.Call.Return(_a0)} -} - -func (_m *CommandClient) OnKillCommand(ctx context.Context, commandID string) *CommandClient_KillCommand { - c := _m.On("KillCommand", ctx, commandID) - return &CommandClient_KillCommand{Call: c} -} - -func (_m *CommandClient) OnKillCommandMatch(matchers ...interface{}) *CommandClient_KillCommand { - c := _m.On("KillCommand", matchers...) - return &CommandClient_KillCommand{Call: c} -} - -// KillCommand provides a mock function with given fields: ctx, commandID -func (_m *CommandClient) KillCommand(ctx context.Context, commandID string) error { - ret := _m.Called(ctx, commandID) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = rf(ctx, commandID) - } else { - r0 = ret.Error(0) - } - - return r0 -} diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go index 8d800435b..f01af0805 100644 --- a/go/tasks/plugins/presto/client/presto_client.go +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -2,10 +2,11 @@ package client import ( "context" - "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" "net/http" "net/url" + "github.com/lyft/flyteplugins/go/tasks/plugins/svc" + "time" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" @@ -37,7 +38,7 @@ type PrestoExecuteArgs struct { } type PrestoExecuteResponse struct { ID string - Status cmd.CommandStatus + Status svc.CommandStatus NextURI string } @@ -53,11 +54,11 @@ func (p *prestoClient) KillCommand(ctx context.Context, commandID string) error return nil } -func (p *prestoClient) GetCommandStatus(ctx context.Context, commandID string) (cmd.CommandStatus, error) { +func (p *prestoClient) GetCommandStatus(ctx context.Context, commandID string) (svc.CommandStatus, error) { return NewPrestoStatus(ctx, "UNKNOWN"), nil } -func NewPrestoClient(cfg *config.Config) cmd.CommandClient { +func NewPrestoClient(cfg *config.Config) svc.ServiceClient { return &prestoClient{ client: &http.Client{Timeout: httpRequestTimeoutSecs * time.Second}, environment: cfg.Environment.ResolveReference(&cfg.Environment.URL), diff --git a/go/tasks/plugins/presto/client/presto_status.go b/go/tasks/plugins/presto/client/presto_status.go index 3153c6fc7..78eb6957f 100644 --- a/go/tasks/plugins/presto/client/presto_status.go +++ b/go/tasks/plugins/presto/client/presto_status.go @@ -2,23 +2,24 @@ package client import ( "context" - "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" - "github.com/lyft/flytestdlib/logger" "strings" + + "github.com/lyft/flyteplugins/go/tasks/plugins/svc" + "github.com/lyft/flytestdlib/logger" ) // This type is meant only to encapsulate the response coming from Presto as a type, it is // not meant to be stored locally. const ( - PrestoStatusUnknown cmd.CommandStatus = "UNKNOWN" - PrestoStatusQueued cmd.CommandStatus = "QUEUED" - PrestoStatusRunning cmd.CommandStatus = "RUNNING" - PrestoStatusFinished cmd.CommandStatus = "FINISHED" - PrestoStatusFailed cmd.CommandStatus = "FAILED" - PrestoStatusCancelled cmd.CommandStatus = "CANCELLED" + PrestoStatusUnknown svc.CommandStatus = "UNKNOWN" + PrestoStatusQueued svc.CommandStatus = "QUEUED" + PrestoStatusRunning svc.CommandStatus = "RUNNING" + PrestoStatusFinished svc.CommandStatus = "FINISHED" + PrestoStatusFailed svc.CommandStatus = "FAILED" + PrestoStatusCancelled svc.CommandStatus = "CANCELLED" ) -var PrestoStatuses = map[cmd.CommandStatus]struct{}{ +var PrestoStatuses = map[svc.CommandStatus]struct{}{ PrestoStatusUnknown: {}, PrestoStatusQueued: {}, PrestoStatusRunning: {}, @@ -27,12 +28,12 @@ var PrestoStatuses = map[cmd.CommandStatus]struct{}{ PrestoStatusCancelled: {}, } -func NewPrestoStatus(ctx context.Context, state string) cmd.CommandStatus { +func NewPrestoStatus(ctx context.Context, state string) svc.CommandStatus { upperCased := strings.ToUpper(state) if strings.Contains(upperCased, "FAILED") { return PrestoStatusFailed - } else if _, ok := PrestoStatuses[cmd.CommandStatus(upperCased)]; ok { - return cmd.CommandStatus(upperCased) + } else if _, ok := PrestoStatuses[svc.CommandStatus(upperCased)]; ok { + return svc.CommandStatus(upperCased) } else { logger.Warnf(ctx, "Invalid Presto Status found: %v", state) return PrestoStatusUnknown diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index a7d58ae4c..9a4802b66 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -4,9 +4,10 @@ import ( "context" "crypto/rand" "fmt" - "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" "strings" + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + "time" "github.com/lyft/flytestdlib/cache" @@ -21,7 +22,7 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" "github.com/lyft/flytestdlib/logger" - "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" + "github.com/lyft/flyteplugins/go/tasks/plugins/svc" ) type ExecutionPhase int @@ -85,7 +86,7 @@ func HandleExecutionState( ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, - prestoClient cmd.CommandClient, + prestoClient svc.ServiceClient, executionsCache cache.AutoRefresh, metrics ExecutorMetrics) (ExecutionState, error) { @@ -365,7 +366,7 @@ func KickOffQuery( ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, - prestoClient cmd.CommandClient, + prestoClient svc.ServiceClient, cache cache.AutoRefresh) (ExecutionState, error) { uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() @@ -488,7 +489,7 @@ func ConstructTaskLog(e ExecutionState) *idlCore.TaskLog { } } -func Abort(ctx context.Context, currentState ExecutionState, client cmd.CommandClient) error { +func Abort(ctx context.Context, currentState ExecutionState, client svc.ServiceClient) error { // Cancel Presto query if non-terminal state if !InTerminalState(currentState) && currentState.CommandID != "" { err := client.KillCommand(ctx, currentState.CommandID) diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index 877a29fd7..8beab5252 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -2,17 +2,18 @@ package presto import ( "context" + "net/url" + "testing" + "time" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" - prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/cmd/mocks" + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/svc/mocks" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" mocks2 "github.com/lyft/flytestdlib/cache/mocks" "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/promutils/labeled" - "net/url" - "testing" - "time" "github.com/lyft/flyteidl/gen/pb-go/flyteidl/plugins" @@ -246,13 +247,8 @@ func TestAbort(t *testing.T) { t.Run("Terminate called when not in terminal state", func(t *testing.T) { var x = false - //mockQubole := &quboleMocks.QuboleClient{} - mockPresto := &prestoMocks.CommandClient{} - //mockQubole.On("KillCommand", mock.Anything, mock.MatchedBy(func(commandId string) bool { - // return commandId == "123456" - //}), mock.Anything).Run(func(_ mock.Arguments) { - // x = true - //}).Return(nil) + + mockPresto := &prestoMocks.ServiceClient{} mockPresto.On("KillCommand", mock.Anything, mock.MatchedBy(func(commandId string) bool { return commandId == "123456" }), mock.Anything).Run(func(_ mock.Arguments) { @@ -266,13 +262,13 @@ func TestAbort(t *testing.T) { t.Run("Terminate not called when in terminal state", func(t *testing.T) { var x = false - //mockQubole := &quboleMocks.QuboleClient{} - mockPresto := &prestoMocks.CommandClient{} + + mockPresto := &prestoMocks.ServiceClient{} mockPresto.On("KillCommand", mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { x = true }).Return(nil) - err := Abort(ctx, ExecutionState{Phase: PhaseQuerySucceeded, CommandID: "123456",}, mockPresto) + err := Abort(ctx, ExecutionState{Phase: PhaseQuerySucceeded, CommandID: "123456"}, mockPresto) assert.NoError(t, err) assert.False(t, x) }) @@ -326,7 +322,7 @@ func TestKickOffQuery(t *testing.T) { ID: "1234567", Status: client.PrestoStatusQueued, } - mockPresto := &prestoMocks.CommandClient{} + mockPresto := &prestoMocks.ServiceClient{} mockPresto.OnExecuteCommandMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { prestoCalled = true diff --git a/go/tasks/plugins/presto/executions_cache.go b/go/tasks/plugins/presto/executions_cache.go index 3126df1c3..2190040c4 100644 --- a/go/tasks/plugins/presto/executions_cache.go +++ b/go/tasks/plugins/presto/executions_cache.go @@ -2,9 +2,10 @@ package presto import ( "context" - "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" "time" + "github.com/lyft/flyteplugins/go/tasks/plugins/svc" + "k8s.io/client-go/util/workqueue" "github.com/lyft/flytestdlib/cache" @@ -27,14 +28,14 @@ const ( type ExecutionsCache struct { cache.AutoRefresh - prestoClient cmd.CommandClient + prestoClient svc.ServiceClient scope promutils.Scope cfg *config.Config } func NewPrestoExecutionsCache( ctx context.Context, - prestoClient cmd.CommandClient, + prestoClient svc.ServiceClient, cfg *config.Config, scope promutils.Scope) (ExecutionsCache, error) { @@ -145,7 +146,7 @@ func (p *ExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache.Batch } // We need some way to translate results we get from Presto, into a plugin phase -func StatusToExecutionPhase(s cmd.CommandStatus) (ExecutionPhase, error) { +func StatusToExecutionPhase(s svc.CommandStatus) (ExecutionPhase, error) { switch s { case client.PrestoStatusFinished: return PhaseQuerySucceeded, nil diff --git a/go/tasks/plugins/presto/executions_cache_test.go b/go/tasks/plugins/presto/executions_cache_test.go index aeb61d664..71c6c83ca 100644 --- a/go/tasks/plugins/presto/executions_cache_test.go +++ b/go/tasks/plugins/presto/executions_cache_test.go @@ -2,9 +2,10 @@ package presto import ( "context" - prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/cmd/mocks" "testing" + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/svc/mocks" + "github.com/lyft/flytestdlib/cache" cacheMocks "github.com/lyft/flytestdlib/cache/mocks" @@ -22,7 +23,7 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { t.Run("terminal state return unchanged", func(t *testing.T) { mockCache := &cacheMocks.AutoRefresh{} - mockPresto := &prestoMocks.CommandClient{} + mockPresto := &prestoMocks.ServiceClient{} testScope := promutils.NewTestScope() p := ExecutionsCache{ @@ -52,7 +53,7 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { t.Run("move to success", func(t *testing.T) { mockCache := &cacheMocks.AutoRefresh{} - mockPresto := &prestoMocks.CommandClient{} + mockPresto := &prestoMocks.ServiceClient{} mockSecretManager := &mocks.SecretManager{} mockSecretManager.OnGetMatch(mock.Anything, mock.Anything).Return("fake key", nil) diff --git a/go/tasks/plugins/presto/executor.go b/go/tasks/plugins/presto/executor.go index 38ad487af..cdb6fcf30 100644 --- a/go/tasks/plugins/presto/executor.go +++ b/go/tasks/plugins/presto/executor.go @@ -2,7 +2,8 @@ package presto import ( "context" - "github.com/lyft/flyteplugins/go/tasks/plugins/cmd" + + "github.com/lyft/flyteplugins/go/tasks/plugins/svc" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" "github.com/lyft/flytestdlib/cache" @@ -27,7 +28,7 @@ const prestoTaskType = "presto" // This needs to match the type defined in Flyte type Executor struct { id string metrics ExecutorMetrics - prestoClient cmd.CommandClient + prestoClient svc.ServiceClient executionsCache cache.AutoRefresh cfg *config.Config } @@ -111,7 +112,7 @@ func InitializePrestoExecutor( iCtx core.SetupContext, cfg *config.Config, resourceConfig map[string]int, - prestoClient cmd.CommandClient) (core.Plugin, error) { + prestoClient svc.ServiceClient) (core.Plugin, error) { logger.Infof(ctx, "Initializing a Presto executor with a resource config [%v]", resourceConfig) q, err := NewPrestoExecutor(ctx, cfg, prestoClient, iCtx.MetricsScope()) if err != nil { @@ -133,7 +134,7 @@ func InitializePrestoExecutor( func NewPrestoExecutor( ctx context.Context, cfg *config.Config, - prestoClient cmd.CommandClient, + prestoClient svc.ServiceClient, scope promutils.Scope) (Executor, error) { executionsAutoRefreshCache, err := NewPrestoExecutionsCache(ctx, prestoClient, cfg, scope.NewSubScope(prestoTaskType)) if err != nil { diff --git a/go/tasks/plugins/svc/mocks/service_client.go b/go/tasks/plugins/svc/mocks/service_client.go new file mode 100644 index 000000000..19d351038 --- /dev/null +++ b/go/tasks/plugins/svc/mocks/service_client.go @@ -0,0 +1,127 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + svc "github.com/lyft/flyteplugins/go/tasks/plugins/svc" + mock "github.com/stretchr/testify/mock" +) + +// ServiceClient is an autogenerated mock type for the ServiceClient type +type ServiceClient struct { + mock.Mock +} + +type ServiceClient_ExecuteCommand struct { + *mock.Call +} + +func (_m ServiceClient_ExecuteCommand) Return(_a0 interface{}, _a1 error) *ServiceClient_ExecuteCommand { + return &ServiceClient_ExecuteCommand{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ServiceClient) OnExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) *ServiceClient_ExecuteCommand { + c := _m.On("ExecuteCommand", ctx, commandStr, extraArgs) + return &ServiceClient_ExecuteCommand{Call: c} +} + +func (_m *ServiceClient) OnExecuteCommandMatch(matchers ...interface{}) *ServiceClient_ExecuteCommand { + c := _m.On("ExecuteCommand", matchers...) + return &ServiceClient_ExecuteCommand{Call: c} +} + +// ExecuteCommand provides a mock function with given fields: ctx, commandStr, extraArgs +func (_m *ServiceClient) ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) { + ret := _m.Called(ctx, commandStr, extraArgs) + + var r0 interface{} + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) interface{}); ok { + r0 = rf(ctx, commandStr, extraArgs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, interface{}) error); ok { + r1 = rf(ctx, commandStr, extraArgs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ServiceClient_GetCommandStatus struct { + *mock.Call +} + +func (_m ServiceClient_GetCommandStatus) Return(_a0 svc.CommandStatus, _a1 error) *ServiceClient_GetCommandStatus { + return &ServiceClient_GetCommandStatus{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *ServiceClient) OnGetCommandStatus(ctx context.Context, commandID string) *ServiceClient_GetCommandStatus { + c := _m.On("GetCommandStatus", ctx, commandID) + return &ServiceClient_GetCommandStatus{Call: c} +} + +func (_m *ServiceClient) OnGetCommandStatusMatch(matchers ...interface{}) *ServiceClient_GetCommandStatus { + c := _m.On("GetCommandStatus", matchers...) + return &ServiceClient_GetCommandStatus{Call: c} +} + +// GetCommandStatus provides a mock function with given fields: ctx, commandID +func (_m *ServiceClient) GetCommandStatus(ctx context.Context, commandID string) (svc.CommandStatus, error) { + ret := _m.Called(ctx, commandID) + + var r0 svc.CommandStatus + if rf, ok := ret.Get(0).(func(context.Context, string) svc.CommandStatus); ok { + r0 = rf(ctx, commandID) + } else { + r0 = ret.Get(0).(svc.CommandStatus) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, commandID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type ServiceClient_KillCommand struct { + *mock.Call +} + +func (_m ServiceClient_KillCommand) Return(_a0 error) *ServiceClient_KillCommand { + return &ServiceClient_KillCommand{Call: _m.Call.Return(_a0)} +} + +func (_m *ServiceClient) OnKillCommand(ctx context.Context, commandID string) *ServiceClient_KillCommand { + c := _m.On("KillCommand", ctx, commandID) + return &ServiceClient_KillCommand{Call: c} +} + +func (_m *ServiceClient) OnKillCommandMatch(matchers ...interface{}) *ServiceClient_KillCommand { + c := _m.On("KillCommand", matchers...) + return &ServiceClient_KillCommand{Call: c} +} + +// KillCommand provides a mock function with given fields: ctx, commandID +func (_m *ServiceClient) KillCommand(ctx context.Context, commandID string) error { + ret := _m.Called(ctx, commandID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, commandID) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/go/tasks/plugins/cmd/command_client.go b/go/tasks/plugins/svc/service_client.go similarity index 88% rename from go/tasks/plugins/cmd/command_client.go rename to go/tasks/plugins/svc/service_client.go index 7c9ab3bc6..f99d7e660 100644 --- a/go/tasks/plugins/cmd/command_client.go +++ b/go/tasks/plugins/svc/service_client.go @@ -1,4 +1,4 @@ -package cmd +package svc import ( "context" @@ -8,7 +8,7 @@ type CommandStatus string //go:generate mockery -all -case=snake -type CommandClient interface { +type ServiceClient interface { ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) KillCommand(ctx context.Context, commandID string) error GetCommandStatus(ctx context.Context, commandID string) (CommandStatus, error) From 237922122cf0bba8c0d4be044f622a59993f2ebe Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Wed, 11 Mar 2020 17:43:48 -0700 Subject: [PATCH 08/26] last linting? --- go/tasks/plugins/presto/execution_state.go | 10 ++++------ go/tasks/plugins/presto/execution_state_test.go | 2 +- go/tasks/plugins/presto/executor.go | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 9a4802b66..939537227 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -295,7 +295,7 @@ func GetNextQuery( externalLocation := getExternalLocation(cfg.AwsS3ShardFormatter, cfg.AwsS3ShardCount) statement := fmt.Sprintf(` -CREATE TABLE hive.flyte_temporary_tables.%s (LIKE hive.flyte_temporary_tables.%s)" +CREATE TABLE hive.flyte_temporary_tables.%s (LIKE hive.flyte_temporary_tables.%s) WITH (format = 'PARQUET', external_location = '%s')`, currentState.CurrentPrestoQuery.ExternalTableName, currentState.CurrentPrestoQuery.TempTableName, @@ -305,13 +305,11 @@ WITH (format = 'PARQUET', external_location = '%s')`, return currentState.CurrentPrestoQuery, nil case 2: - statement := fmt.Sprintf(` + statement := ` INSERT INTO hive.flyte_temporary_tables.%s SELECT * -FROM hive.flyte_temporary_tables.%s`, - currentState.CurrentPrestoQuery.ExternalTableName, - currentState.CurrentPrestoQuery.TempTableName, - ) +FROM hive.flyte_temporary_tables.%s` + statement = fmt.Sprintf(statement, currentState.CurrentPrestoQuery.ExternalTableName, currentState.CurrentPrestoQuery.TempTableName) currentState.CurrentPrestoQuery.Statement = statement return currentState.CurrentPrestoQuery, nil diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index 8beab5252..8db0beb76 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -7,9 +7,9 @@ import ( "time" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" - prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/svc/mocks" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/svc/mocks" mocks2 "github.com/lyft/flytestdlib/cache/mocks" "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/promutils" diff --git a/go/tasks/plugins/presto/executor.go b/go/tasks/plugins/presto/executor.go index cdb6fcf30..4d2a54e54 100644 --- a/go/tasks/plugins/presto/executor.go +++ b/go/tasks/plugins/presto/executor.go @@ -3,8 +3,8 @@ package presto import ( "context" - "github.com/lyft/flyteplugins/go/tasks/plugins/svc" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + "github.com/lyft/flyteplugins/go/tasks/plugins/svc" "github.com/lyft/flytestdlib/cache" From c87dbd5da3e0146e8fa13cf7500eff26978af1f8 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Wed, 11 Mar 2020 18:12:19 -0700 Subject: [PATCH 09/26] minor changes --- go/tasks/plugins/presto/config/config.go | 2 +- go/tasks/plugins/presto/execution_state.go | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/go/tasks/plugins/presto/config/config.go b/go/tasks/plugins/presto/config/config.go index 5190ebee3..a0d550fea 100644 --- a/go/tasks/plugins/presto/config/config.go +++ b/go/tasks/plugins/presto/config/config.go @@ -40,7 +40,7 @@ var ( LruCacheSize: 2000, AwsS3ShardFormatter: "s3://lyft-modelbuilder/{}/", AwsS3ShardCount: 2, - RoutingGroupConfigs: []RoutingGroupConfig{{Name: "adhoc", Limit: 250}}, + RoutingGroupConfigs: []RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, } prestoConfigSection = pluginsConfig.MustRegisterSubSection(prestoConfigSectionKey, &defaultConfig) diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 939537227..07307b78b 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -446,7 +446,6 @@ func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { case PhaseNotStarted: phaseInfo = core.PhaseInfoNotReady(t, core.DefaultPhaseVersion, "Haven't received allocation token") case PhaseQueued: - // TODO: Turn into config if state.CreationFailureCount > 5 { phaseInfo = core.PhaseInfoRetryableFailure("PrestoFailure", "Too many creation attempts", nil) } else { From 7ba7cc7bf844836f559623015c072310b410f220 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Wed, 11 Mar 2020 18:16:31 -0700 Subject: [PATCH 10/26] expanded comment on state machine logic --- go/tasks/plugins/presto/execution_state.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 07307b78b..77fd4befc 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -110,8 +110,9 @@ func HandleExecutionState( case PhaseQuerySucceeded: if currentState.QueryCount < 4 { - // If there are still Presto statements to execute, increment the query count, reset the phase to get a new - // allocation token, and continue executing the remaining statements + // If there are still Presto statements to execute, increment the query count, reset the phase to 'queued' + // and continue executing the remaining statements. In this case, we won't request another allocation token + // as the 5 statements that get executed are all considered to be part of the same "query" currentState.QueryCount++ currentState.Phase = PhaseQueued } From 4baecf2ab706e4919f213f6d232a699de911b668 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Thu, 12 Mar 2020 16:48:01 -0700 Subject: [PATCH 11/26] PR feedback changes --- go/tasks/pluginmachinery/core/phase.go | 4 + .../presto/client/mocks/presto_client.go | 128 ++++++++++++++++++ .../plugins/presto/client/presto_client.go | 35 +++-- .../plugins/presto/client/presto_status.go | 23 ++-- go/tasks/plugins/presto/config/config.go | 33 +++-- .../plugins/presto/config/config_flags.go | 11 +- .../presto/config/config_flags_test.go | 122 +++++++++++++---- go/tasks/plugins/presto/execution_state.go | 67 +++------ .../plugins/presto/execution_state_test.go | 22 +-- go/tasks/plugins/presto/executions_cache.go | 10 +- .../plugins/presto/executions_cache_test.go | 6 +- go/tasks/plugins/presto/executor.go | 9 +- .../{test_helpers.go => helpers_test.go} | 0 go/tasks/plugins/svc/mocks/service_client.go | 127 ----------------- go/tasks/plugins/svc/service_client.go | 15 -- 15 files changed, 331 insertions(+), 281 deletions(-) create mode 100644 go/tasks/plugins/presto/client/mocks/presto_client.go rename go/tasks/plugins/presto/{test_helpers.go => helpers_test.go} (100%) delete mode 100644 go/tasks/plugins/svc/mocks/service_client.go delete mode 100644 go/tasks/plugins/svc/service_client.go diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go index 470cf1368..ca8fad8fe 100644 --- a/go/tasks/pluginmachinery/core/phase.go +++ b/go/tasks/pluginmachinery/core/phase.go @@ -180,6 +180,10 @@ func PhaseInfoSuccess(info *TaskInfo) PhaseInfo { return phaseInfo(PhaseSuccess, DefaultPhaseVersion, nil, info) } +func PhaseInfoSuccessWithVersion(version uint32, info *TaskInfo) PhaseInfo { + return phaseInfo(PhaseSuccess, version, nil, info) +} + func PhaseInfoFailure(code, reason string, info *TaskInfo) PhaseInfo { return PhaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason}, info) } diff --git a/go/tasks/plugins/presto/client/mocks/presto_client.go b/go/tasks/plugins/presto/client/mocks/presto_client.go new file mode 100644 index 000000000..3494940e5 --- /dev/null +++ b/go/tasks/plugins/presto/client/mocks/presto_client.go @@ -0,0 +1,128 @@ +// Code generated by mockery v1.0.1. DO NOT EDIT. + +package mocks + +import ( + context "context" + + client "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + + mock "github.com/stretchr/testify/mock" +) + +// PrestoClient is an autogenerated mock type for the PrestoClient type +type PrestoClient struct { + mock.Mock +} + +type PrestoClient_ExecuteCommand struct { + *mock.Call +} + +func (_m PrestoClient_ExecuteCommand) Return(_a0 interface{}, _a1 error) *PrestoClient_ExecuteCommand { + return &PrestoClient_ExecuteCommand{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *PrestoClient) OnExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) *PrestoClient_ExecuteCommand { + c := _m.On("ExecuteCommand", ctx, commandStr, extraArgs) + return &PrestoClient_ExecuteCommand{Call: c} +} + +func (_m *PrestoClient) OnExecuteCommandMatch(matchers ...interface{}) *PrestoClient_ExecuteCommand { + c := _m.On("ExecuteCommand", matchers...) + return &PrestoClient_ExecuteCommand{Call: c} +} + +// ExecuteCommand provides a mock function with given fields: ctx, commandStr, extraArgs +func (_m *PrestoClient) ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) { + ret := _m.Called(ctx, commandStr, extraArgs) + + var r0 interface{} + if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) interface{}); ok { + r0 = rf(ctx, commandStr, extraArgs) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, interface{}) error); ok { + r1 = rf(ctx, commandStr, extraArgs) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type PrestoClient_GetCommandStatus struct { + *mock.Call +} + +func (_m PrestoClient_GetCommandStatus) Return(_a0 client.PrestoStatus, _a1 error) *PrestoClient_GetCommandStatus { + return &PrestoClient_GetCommandStatus{Call: _m.Call.Return(_a0, _a1)} +} + +func (_m *PrestoClient) OnGetCommandStatus(ctx context.Context, commandID string) *PrestoClient_GetCommandStatus { + c := _m.On("GetCommandStatus", ctx, commandID) + return &PrestoClient_GetCommandStatus{Call: c} +} + +func (_m *PrestoClient) OnGetCommandStatusMatch(matchers ...interface{}) *PrestoClient_GetCommandStatus { + c := _m.On("GetCommandStatus", matchers...) + return &PrestoClient_GetCommandStatus{Call: c} +} + +// GetCommandStatus provides a mock function with given fields: ctx, commandID +func (_m *PrestoClient) GetCommandStatus(ctx context.Context, commandID string) (client.PrestoStatus, error) { + ret := _m.Called(ctx, commandID) + + var r0 client.PrestoStatus + if rf, ok := ret.Get(0).(func(context.Context, string) client.PrestoStatus); ok { + r0 = rf(ctx, commandID) + } else { + r0 = ret.Get(0).(client.PrestoStatus) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, commandID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type PrestoClient_KillCommand struct { + *mock.Call +} + +func (_m PrestoClient_KillCommand) Return(_a0 error) *PrestoClient_KillCommand { + return &PrestoClient_KillCommand{Call: _m.Call.Return(_a0)} +} + +func (_m *PrestoClient) OnKillCommand(ctx context.Context, commandID string) *PrestoClient_KillCommand { + c := _m.On("KillCommand", ctx, commandID) + return &PrestoClient_KillCommand{Call: c} +} + +func (_m *PrestoClient) OnKillCommandMatch(matchers ...interface{}) *PrestoClient_KillCommand { + c := _m.On("KillCommand", matchers...) + return &PrestoClient_KillCommand{Call: c} +} + +// KillCommand provides a mock function with given fields: ctx, commandID +func (_m *PrestoClient) KillCommand(ctx context.Context, commandID string) error { + ret := _m.Called(ctx, commandID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, commandID) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go index f01af0805..f1f29c538 100644 --- a/go/tasks/plugins/presto/client/presto_client.go +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -5,8 +5,6 @@ import ( "net/http" "net/url" - "github.com/lyft/flyteplugins/go/tasks/plugins/svc" - "time" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" @@ -14,35 +12,34 @@ import ( const ( httpRequestTimeoutSecs = 30 - //AcceptHeaderKey = "Accept" - //ContentTypeHeaderKey = "Content-Type" - //ContentTypeJSON = "application/json" - //ContentTypeTextPlain = "text/plain" - //PrestoCatalogHeader = "X-Presto-Catalog" - //PrestoRoutingGroupHeader = "X-Presto-Routing-Group" - //PrestoSchemaHeader = "X-Presto-Schema" - //PrestoSourceHeader = "X-Presto-Source" - //PrestoUserHeader = "X-Presto-User" ) -type prestoClient struct { +type noopPrestoClient struct { client *http.Client environment *url.URL } type PrestoExecuteArgs struct { - RoutingGroup string `json:"routing_group,omitempty"` + RoutingGroup string `json:"routingGroup,omitempty"` Catalog string `json:"catalog,omitempty"` Schema string `json:"schema,omitempty"` Source string `json:"source,omitempty"` } type PrestoExecuteResponse struct { ID string - Status svc.CommandStatus + Status PrestoStatus NextURI string } -func (p *prestoClient) ExecuteCommand( +//go:generate mockery -all -case=snake + +type PrestoClient interface { + ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) + KillCommand(ctx context.Context, commandID string) error + GetCommandStatus(ctx context.Context, commandID string) (PrestoStatus, error) +} + +func (p noopPrestoClient) ExecuteCommand( ctx context.Context, queryStr string, extraArgs interface{}) (interface{}, error) { @@ -50,16 +47,16 @@ func (p *prestoClient) ExecuteCommand( return PrestoExecuteResponse{}, nil } -func (p *prestoClient) KillCommand(ctx context.Context, commandID string) error { +func (p noopPrestoClient) KillCommand(ctx context.Context, commandID string) error { return nil } -func (p *prestoClient) GetCommandStatus(ctx context.Context, commandID string) (svc.CommandStatus, error) { +func (p noopPrestoClient) GetCommandStatus(ctx context.Context, commandID string) (PrestoStatus, error) { return NewPrestoStatus(ctx, "UNKNOWN"), nil } -func NewPrestoClient(cfg *config.Config) svc.ServiceClient { - return &prestoClient{ +func NewNoopPrestoClient(cfg *config.Config) PrestoClient { + return &noopPrestoClient{ client: &http.Client{Timeout: httpRequestTimeoutSecs * time.Second}, environment: cfg.Environment.ResolveReference(&cfg.Environment.URL), } diff --git a/go/tasks/plugins/presto/client/presto_status.go b/go/tasks/plugins/presto/client/presto_status.go index 78eb6957f..db09afab2 100644 --- a/go/tasks/plugins/presto/client/presto_status.go +++ b/go/tasks/plugins/presto/client/presto_status.go @@ -4,22 +4,23 @@ import ( "context" "strings" - "github.com/lyft/flyteplugins/go/tasks/plugins/svc" "github.com/lyft/flytestdlib/logger" ) +type PrestoStatus string + // This type is meant only to encapsulate the response coming from Presto as a type, it is // not meant to be stored locally. const ( - PrestoStatusUnknown svc.CommandStatus = "UNKNOWN" - PrestoStatusQueued svc.CommandStatus = "QUEUED" - PrestoStatusRunning svc.CommandStatus = "RUNNING" - PrestoStatusFinished svc.CommandStatus = "FINISHED" - PrestoStatusFailed svc.CommandStatus = "FAILED" - PrestoStatusCancelled svc.CommandStatus = "CANCELLED" + PrestoStatusUnknown PrestoStatus = "UNKNOWN" + PrestoStatusQueued PrestoStatus = "QUEUED" + PrestoStatusRunning PrestoStatus = "RUNNING" + PrestoStatusFinished PrestoStatus = "FINISHED" + PrestoStatusFailed PrestoStatus = "FAILED" + PrestoStatusCancelled PrestoStatus = "CANCELLED" ) -var PrestoStatuses = map[svc.CommandStatus]struct{}{ +var PrestoStatuses = map[PrestoStatus]struct{}{ PrestoStatusUnknown: {}, PrestoStatusQueued: {}, PrestoStatusRunning: {}, @@ -28,12 +29,12 @@ var PrestoStatuses = map[svc.CommandStatus]struct{}{ PrestoStatusCancelled: {}, } -func NewPrestoStatus(ctx context.Context, state string) svc.CommandStatus { +func NewPrestoStatus(ctx context.Context, state string) PrestoStatus { upperCased := strings.ToUpper(state) if strings.Contains(upperCased, "FAILED") { return PrestoStatusFailed - } else if _, ok := PrestoStatuses[svc.CommandStatus(upperCased)]; ok { - return svc.CommandStatus(upperCased) + } else if _, ok := PrestoStatuses[PrestoStatus(upperCased)]; ok { + return PrestoStatus(upperCased) } else { logger.Warnf(ctx, "Invalid Presto Status found: %v", state) return PrestoStatusUnknown diff --git a/go/tasks/plugins/presto/config/config.go b/go/tasks/plugins/presto/config/config.go index a0d550fea..b91eb22ec 100644 --- a/go/tasks/plugins/presto/config/config.go +++ b/go/tasks/plugins/presto/config/config.go @@ -5,6 +5,7 @@ package config import ( "context" "net/url" + "time" "github.com/lyft/flytestdlib/config" "github.com/lyft/flytestdlib/logger" @@ -26,21 +27,34 @@ func URLMustParse(s string) config.URL { } type RoutingGroupConfig struct { - Name string `json:"primaryLabel" pflag:",The name of a given Presto routing group"` + Name string `json:"name" pflag:",The name of a given Presto routing group"` Limit int `json:"limit" pflag:",Resource quota (in the number of outstanding requests) of the routing group"` ProjectScopeQuotaProportionCap float64 `json:"projectScopeQuotaProportionCap" pflag:",A floating point number between 0 and 1, specifying the maximum proportion of quotas allowed to allocate to a project in the routing group"` NamespaceScopeQuotaProportionCap float64 `json:"namespaceScopeQuotaProportionCap" pflag:",A floating point number between 0 and 1, specifying the maximum proportion of quotas allowed to allocate to a namespace in the routing group"` } +type RateLimiter struct { + Name string `json:"name" pflag:",The name of the rate limiter"` + SyncPeriod config.Duration `json:"syncPeriod" pflag:",The duration to wait before the cache is refreshed again"` + Workers int `json:"workers" pflag:",Number of parallel workers to refresh the cache"` + LruCacheSize int `json:"lruCacheSize" pflag:",Size of the cache"` + MetricScope string `json:"metricScope" pflag:",The prefix in Prometheus used to track metrics related to Presto"` +} + var ( defaultConfig = Config{ - Environment: URLMustParse("https://prestoproxy-internal.lyft.net:443"), + Environment: URLMustParse(""), DefaultRoutingGroup: "adhoc", - Workers: 15, - LruCacheSize: 2000, - AwsS3ShardFormatter: "s3://lyft-modelbuilder/{}/", + AwsS3ShardFormatter: "", AwsS3ShardCount: 2, RoutingGroupConfigs: []RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, + RateLimiter: RateLimiter{ + Name: "presto", + SyncPeriod: config.Duration{Duration: 3 * time.Second}, + Workers: 15, + LruCacheSize: 2000, + MetricScope: "presto", + }, } prestoConfigSection = pluginsConfig.MustRegisterSubSection(prestoConfigSectionKey, &defaultConfig) @@ -48,13 +62,12 @@ var ( // Presto plugin configs type Config struct { - Environment config.URL `json:"endpoint" pflag:",Endpoint for Presto to use"` + Environment config.URL `json:"environment" pflag:",Environment endpoint for Presto to use"` DefaultRoutingGroup string `json:"defaultRoutingGroup" pflag:",Default Presto routing group"` - Workers int `json:"workers" pflag:",Number of parallel workers to refresh the cache"` - LruCacheSize int `json:"lruCacheSize" pflag:",Size of the AutoRefreshCache"` AwsS3ShardFormatter string `json:"awsS3ShardFormatter" pflag:", S3 bucket prefix where Presto results will be stored"` - AwsS3ShardCount int `json:"awsS3ShardStringLength" pflag:", Number of characters for the S3 bucket shard prefix"` - RoutingGroupConfigs []RoutingGroupConfig `json:"clusterConfigs" pflag:"-,A list of cluster configs. Each of the configs corresponds to a service cluster"` + AwsS3ShardCount int `json:"awsS3ShardCount" pflag:", Number of characters for the S3 bucket shard prefix"` + RoutingGroupConfigs []RoutingGroupConfig `json:"routingGroupConfigs" pflag:"-,A list of cluster configs. Each of the configs corresponds to a service cluster"` + RateLimiter RateLimiter `json:"rateLimiter" pflag:"Rate limiter config"` } // Retrieves the current config value or default. diff --git a/go/tasks/plugins/presto/config/config_flags.go b/go/tasks/plugins/presto/config/config_flags.go index ffabb5b00..c83e30834 100755 --- a/go/tasks/plugins/presto/config/config_flags.go +++ b/go/tasks/plugins/presto/config/config_flags.go @@ -41,11 +41,14 @@ func (Config) mustMarshalJSON(v json.Marshaler) string { // flags is json-name.json-sub-name... etc. func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "endpoint"), defaultConfig.Environment.String(), "Endpoint for Presto to use") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "environment"), defaultConfig.Environment.String(), "Environment endpoint for Presto to use") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultRoutingGroup"), defaultConfig.DefaultRoutingGroup, "Default Presto routing group") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "workers"), defaultConfig.Workers, "Number of parallel workers to refresh the cache") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "lruCacheSize"), defaultConfig.LruCacheSize, "Size of the AutoRefreshCache") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "awsS3ShardFormatter"), defaultConfig.AwsS3ShardFormatter, " S3 bucket prefix where Presto results will be stored") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "awsS3ShardStringLength"), defaultConfig.AwsS3ShardCount, " Number of characters for the S3 bucket shard prefix") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "awsS3ShardCount"), defaultConfig.AwsS3ShardCount, " Number of characters for the S3 bucket shard prefix") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "rateLimiter.name"), defaultConfig.RateLimiter.Name, "The name of the rate limiter") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "rateLimiter.syncPeriod"), defaultConfig.RateLimiter.SyncPeriod.String(), "The duration to wait before the cache is refreshed again") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "rateLimiter.workers"), defaultConfig.RateLimiter.Workers, "Number of parallel workers to refresh the cache") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "rateLimiter.lruCacheSize"), defaultConfig.RateLimiter.LruCacheSize, "Size of the cache") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "rateLimiter.metricScope"), defaultConfig.RateLimiter.MetricScope, "The prefix in Prometheus used to track metrics related to Presto") return cmdFlags } diff --git a/go/tasks/plugins/presto/config/config_flags_test.go b/go/tasks/plugins/presto/config/config_flags_test.go index 6c6cf67a6..41bead819 100755 --- a/go/tasks/plugins/presto/config/config_flags_test.go +++ b/go/tasks/plugins/presto/config/config_flags_test.go @@ -99,10 +99,10 @@ func TestConfig_SetFlags(t *testing.T) { cmdFlags := actual.GetPFlagSet("") assert.True(t, cmdFlags.HasFlags()) - t.Run("Test_endpoint", func(t *testing.T) { + t.Run("Test_environment", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vString, err := cmdFlags.GetString("endpoint"); err == nil { + if vString, err := cmdFlags.GetString("environment"); err == nil { assert.Equal(t, string(defaultConfig.Environment.String()), vString) } else { assert.FailNow(t, err.Error()) @@ -112,8 +112,8 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := defaultConfig.Environment.String() - cmdFlags.Set("endpoint", testValue) - if vString, err := cmdFlags.GetString("endpoint"); err == nil { + cmdFlags.Set("environment", testValue) + if vString, err := cmdFlags.GetString("environment"); err == nil { testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.Environment) } else { @@ -143,11 +143,11 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_workers", func(t *testing.T) { + t.Run("Test_awsS3ShardFormatter", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("workers"); err == nil { - assert.Equal(t, int(defaultConfig.Workers), vInt) + if vString, err := cmdFlags.GetString("awsS3ShardFormatter"); err == nil { + assert.Equal(t, string(defaultConfig.AwsS3ShardFormatter), vString) } else { assert.FailNow(t, err.Error()) } @@ -156,20 +156,20 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("workers", testValue) - if vInt, err := cmdFlags.GetInt("workers"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.Workers) + cmdFlags.Set("awsS3ShardFormatter", testValue) + if vString, err := cmdFlags.GetString("awsS3ShardFormatter"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AwsS3ShardFormatter) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_lruCacheSize", func(t *testing.T) { + t.Run("Test_awsS3ShardCount", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("lruCacheSize"); err == nil { - assert.Equal(t, int(defaultConfig.LruCacheSize), vInt) + if vInt, err := cmdFlags.GetInt("awsS3ShardCount"); err == nil { + assert.Equal(t, int(defaultConfig.AwsS3ShardCount), vInt) } else { assert.FailNow(t, err.Error()) } @@ -178,20 +178,20 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("lruCacheSize", testValue) - if vInt, err := cmdFlags.GetInt("lruCacheSize"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.LruCacheSize) + cmdFlags.Set("awsS3ShardCount", testValue) + if vInt, err := cmdFlags.GetInt("awsS3ShardCount"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.AwsS3ShardCount) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_awsS3ShardFormatter", func(t *testing.T) { + t.Run("Test_rateLimiter.name", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vString, err := cmdFlags.GetString("awsS3ShardFormatter"); err == nil { - assert.Equal(t, string(defaultConfig.AwsS3ShardFormatter), vString) + if vString, err := cmdFlags.GetString("rateLimiter.name"); err == nil { + assert.Equal(t, string(defaultConfig.RateLimiter.Name), vString) } else { assert.FailNow(t, err.Error()) } @@ -200,20 +200,42 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("awsS3ShardFormatter", testValue) - if vString, err := cmdFlags.GetString("awsS3ShardFormatter"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AwsS3ShardFormatter) + cmdFlags.Set("rateLimiter.name", testValue) + if vString, err := cmdFlags.GetString("rateLimiter.name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RateLimiter.Name) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_awsS3ShardStringLength", func(t *testing.T) { + t.Run("Test_rateLimiter.syncPeriod", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("awsS3ShardStringLength"); err == nil { - assert.Equal(t, int(defaultConfig.AwsS3ShardCount), vInt) + if vString, err := cmdFlags.GetString("rateLimiter.syncPeriod"); err == nil { + assert.Equal(t, string(defaultConfig.RateLimiter.SyncPeriod.String()), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := defaultConfig.RateLimiter.SyncPeriod.String() + + cmdFlags.Set("rateLimiter.syncPeriod", testValue) + if vString, err := cmdFlags.GetString("rateLimiter.syncPeriod"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RateLimiter.SyncPeriod) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_rateLimiter.workers", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("rateLimiter.workers"); err == nil { + assert.Equal(t, int(defaultConfig.RateLimiter.Workers), vInt) } else { assert.FailNow(t, err.Error()) } @@ -222,9 +244,53 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("awsS3ShardStringLength", testValue) - if vInt, err := cmdFlags.GetInt("awsS3ShardStringLength"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.AwsS3ShardCount) + cmdFlags.Set("rateLimiter.workers", testValue) + if vInt, err := cmdFlags.GetInt("rateLimiter.workers"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.RateLimiter.Workers) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_rateLimiter.lruCacheSize", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vInt, err := cmdFlags.GetInt("rateLimiter.lruCacheSize"); err == nil { + assert.Equal(t, int(defaultConfig.RateLimiter.LruCacheSize), vInt) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("rateLimiter.lruCacheSize", testValue) + if vInt, err := cmdFlags.GetInt("rateLimiter.lruCacheSize"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.RateLimiter.LruCacheSize) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_rateLimiter.metricScope", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("rateLimiter.metricScope"); err == nil { + assert.Equal(t, string(defaultConfig.RateLimiter.MetricScope), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("rateLimiter.metricScope", testValue) + if vString, err := cmdFlags.GetString("rateLimiter.metricScope"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RateLimiter.MetricScope) } else { assert.FailNow(t, err.Error()) diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 77fd4befc..5bfc65ed6 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -2,7 +2,9 @@ package presto import ( "context" - "crypto/rand" + + "k8s.io/apimachinery/pkg/util/rand" + "fmt" "strings" @@ -21,8 +23,6 @@ import ( "github.com/lyft/flyteplugins/go/tasks/errors" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" "github.com/lyft/flytestdlib/logger" - - "github.com/lyft/flyteplugins/go/tasks/plugins/svc" ) type ExecutionPhase int @@ -55,30 +55,30 @@ type ExecutionState struct { Phase ExecutionPhase // This will store the command ID from Presto - CommandID string `json:"command_id,omitempty"` + CommandID string `json:"commandId,omitempty"` URI string `json:"uri,omitempty"` - CurrentPrestoQuery Query `json:"current_presto_query,omitempty"` - QueryCount int `json:"query_count,omitempty"` + CurrentPrestoQuery Query `json:"currentPrestoQuery,omitempty"` + QueryCount int `json:"queryCount,omitempty"` // This number keeps track of the number of failures within the sync function. Without this, what happens in // the sync function is entirely opaque. Note that this field is completely orthogonal to Flyte system/node/task // level retries, just errors from hitting the Presto API, inside the sync loop - SyncFailureCount int `json:"sync_failure_count,omitempty"` + SyncFailureCount int `json:"syncFailureCount,omitempty"` // In kicking off the Presto command, this is the number of failures - CreationFailureCount int `json:"creation_failure_count,omitempty"` + CreationFailureCount int `json:"creationFailureCount,omitempty"` // The time the execution first requests for an allocation token - AllocationTokenRequestStartTime time.Time `json:"allocation_token_request_start_time,omitempty"` + AllocationTokenRequestStartTime time.Time `json:"allocationTokenRequestStartTime,omitempty"` } type Query struct { Statement string `json:"statement,omitempty"` - ExtraArgs client.PrestoExecuteArgs `json:"extra_args,omitempty"` - TempTableName string `json:"temp_table_name,omitempty"` - ExternalTableName string `json:"external_table_name,omitempty"` - ExternalLocation string `json:"external_location"` + ExtraArgs client.PrestoExecuteArgs `json:"extraArgs,omitempty"` + TempTableName string `json:"tempTableName,omitempty"` + ExternalTableName string `json:"externalTableName,omitempty"` + ExternalLocation string `json:"externalLocation"` } // This is the main state iteration @@ -86,7 +86,7 @@ func HandleExecutionState( ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, - prestoClient svc.ServiceClient, + prestoClient client.PrestoClient, executionsCache cache.AutoRefresh, metrics ExecutorMetrics) (ExecutionState, error) { @@ -270,7 +270,7 @@ func GetNextQuery( switch currentState.QueryCount { case 0: prestoCfg := config.GetPrestoConfig() - tempTableName := generateRandomString(32) + tempTableName := rand.String(32) routingGroup, catalog, schema, statement, err := GetQueryInfo(ctx, tCtx) if err != nil { return Query{}, err @@ -329,43 +329,20 @@ FROM hive.flyte_temporary_tables.%s` } } -func generateRandomString(length int) string { - const letters = "0123456789abcdefghijklmnopqrstuvwxyz" - bytes, err := generateRandomBytes(length) - if err != nil { - return "" - } - for i, b := range bytes { - bytes[i] = letters[b%byte(len(letters))] - } - return string(bytes) -} - -func generateRandomBytes(length int) ([]byte, error) { - b := make([]byte, length) - _, err := rand.Read(b) - // Note that err == nil only if we read len(b) bytes. - if err != nil { - return nil, err - } - - return b, nil -} - func getExternalLocation(shardFormatter string, shardLength int) string { shardCount := strings.Count(shardFormatter, "{}") for i := 0; i < shardCount; i++ { - shardFormatter = strings.Replace(shardFormatter, "{}", generateRandomString(shardLength), 1) + shardFormatter = strings.Replace(shardFormatter, "{}", rand.String(shardLength), 1) } - return shardFormatter + generateRandomString(32) + "/" + return shardFormatter + rand.String(32) + "/" } func KickOffQuery( ctx context.Context, tCtx core.TaskExecutionContext, currentState ExecutionState, - prestoClient svc.ServiceClient, + prestoClient client.PrestoClient, cache cache.AutoRefresh) (ExecutionState, error) { uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() @@ -450,13 +427,13 @@ func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { if state.CreationFailureCount > 5 { phaseInfo = core.PhaseInfoRetryableFailure("PrestoFailure", "Too many creation attempts", nil) } else { - phaseInfo = core.PhaseInfoQueued(t, uint32(state.CreationFailureCount), "Waiting for Presto launch") + phaseInfo = core.PhaseInfoQueued(t, uint32(state.QueryCount), "Waiting for Presto launch") } case PhaseSubmitted: - phaseInfo = core.PhaseInfoRunning(core.DefaultPhaseVersion, ConstructTaskInfo(state)) + phaseInfo = core.PhaseInfoRunning(uint32(state.QueryCount), ConstructTaskInfo(state)) case PhaseQuerySucceeded: - phaseInfo = core.PhaseInfoSuccess(ConstructTaskInfo(state)) + phaseInfo = core.PhaseInfoSuccessWithVersion(uint32(state.QueryCount), ConstructTaskInfo(state)) case PhaseQueryFailed: phaseInfo = core.PhaseInfoFailure(errors.DownstreamSystemError, "Query failed", ConstructTaskInfo(state)) @@ -487,7 +464,7 @@ func ConstructTaskLog(e ExecutionState) *idlCore.TaskLog { } } -func Abort(ctx context.Context, currentState ExecutionState, client svc.ServiceClient) error { +func Abort(ctx context.Context, currentState ExecutionState, client client.PrestoClient) error { // Cancel Presto query if non-terminal state if !InTerminalState(currentState) && currentState.CommandID != "" { err := client.KillCommand(ctx, currentState.CommandID) diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index 8db0beb76..4efdb5dcb 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -8,9 +8,10 @@ import ( "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client/mocks" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" - prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/svc/mocks" mocks2 "github.com/lyft/flytestdlib/cache/mocks" + stdConfig "github.com/lyft/flytestdlib/config" "github.com/lyft/flytestdlib/contextutils" "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/promutils/labeled" @@ -248,7 +249,7 @@ func TestAbort(t *testing.T) { t.Run("Terminate called when not in terminal state", func(t *testing.T) { var x = false - mockPresto := &prestoMocks.ServiceClient{} + mockPresto := &prestoMocks.PrestoClient{} mockPresto.On("KillCommand", mock.Anything, mock.MatchedBy(func(commandId string) bool { return commandId == "123456" }), mock.Anything).Run(func(_ mock.Arguments) { @@ -263,7 +264,7 @@ func TestAbort(t *testing.T) { t.Run("Terminate not called when in terminal state", func(t *testing.T) { var x = false - mockPresto := &prestoMocks.ServiceClient{} + mockPresto := &prestoMocks.PrestoClient{} mockPresto.On("KillCommand", mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { x = true }).Return(nil) @@ -322,7 +323,7 @@ func TestKickOffQuery(t *testing.T) { ID: "1234567", Status: client.PrestoStatusQueued, } - mockPresto := &prestoMocks.ServiceClient{} + mockPresto := &prestoMocks.PrestoClient{} mockPresto.OnExecuteCommandMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Run(func(_ mock.Arguments) { prestoCalled = true @@ -344,13 +345,18 @@ func TestKickOffQuery(t *testing.T) { func createMockPrestoCfg() *config.Config { return &config.Config{ - Environment: config.URLMustParse("https://prestoproxy-internal.lyft.net:443"), + Environment: config.URLMustParse(""), DefaultRoutingGroup: "adhoc", - Workers: 15, - LruCacheSize: 2000, - AwsS3ShardFormatter: "s3://lyft-modelbuilder/{}/", + AwsS3ShardFormatter: "", AwsS3ShardCount: 2, RoutingGroupConfigs: []config.RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, + RateLimiter: config.RateLimiter{ + Name: "presto", + SyncPeriod: stdConfig.Duration{Duration: 3 * time.Second}, + Workers: 15, + LruCacheSize: 2000, + MetricScope: "presto", + }, } } diff --git a/go/tasks/plugins/presto/executions_cache.go b/go/tasks/plugins/presto/executions_cache.go index 2190040c4..dcc10e304 100644 --- a/go/tasks/plugins/presto/executions_cache.go +++ b/go/tasks/plugins/presto/executions_cache.go @@ -4,8 +4,6 @@ import ( "context" "time" - "github.com/lyft/flyteplugins/go/tasks/plugins/svc" - "k8s.io/client-go/util/workqueue" "github.com/lyft/flytestdlib/cache" @@ -28,14 +26,14 @@ const ( type ExecutionsCache struct { cache.AutoRefresh - prestoClient svc.ServiceClient + prestoClient client.PrestoClient scope promutils.Scope cfg *config.Config } func NewPrestoExecutionsCache( ctx context.Context, - prestoClient svc.ServiceClient, + prestoClient client.PrestoClient, cfg *config.Config, scope promutils.Scope) (ExecutionsCache, error) { @@ -44,7 +42,7 @@ func NewPrestoExecutionsCache( scope: scope, cfg: cfg, } - autoRefreshCache, err := cache.NewAutoRefreshCache("presto", q.SyncPrestoQuery, workqueue.DefaultControllerRateLimiter(), ResyncDuration, cfg.Workers, cfg.LruCacheSize, scope) + autoRefreshCache, err := cache.NewAutoRefreshCache(cfg.RateLimiter.Name, q.SyncPrestoQuery, workqueue.DefaultControllerRateLimiter(), cfg.RateLimiter.SyncPeriod.Duration, cfg.RateLimiter.Workers, cfg.RateLimiter.LruCacheSize, scope.NewSubScope(cfg.RateLimiter.MetricScope)) if err != nil { logger.Errorf(ctx, "Could not create AutoRefreshCache in Executor. [%s]", err) return q, errors.Wrapf(errors.CacheFailed, err, "Error creating AutoRefreshCache") @@ -146,7 +144,7 @@ func (p *ExecutionsCache) SyncPrestoQuery(ctx context.Context, batch cache.Batch } // We need some way to translate results we get from Presto, into a plugin phase -func StatusToExecutionPhase(s svc.CommandStatus) (ExecutionPhase, error) { +func StatusToExecutionPhase(s client.PrestoStatus) (ExecutionPhase, error) { switch s { case client.PrestoStatusFinished: return PhaseQuerySucceeded, nil diff --git a/go/tasks/plugins/presto/executions_cache_test.go b/go/tasks/plugins/presto/executions_cache_test.go index 71c6c83ca..3f6114762 100644 --- a/go/tasks/plugins/presto/executions_cache_test.go +++ b/go/tasks/plugins/presto/executions_cache_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/svc/mocks" + prestoMocks "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client/mocks" "github.com/lyft/flytestdlib/cache" cacheMocks "github.com/lyft/flytestdlib/cache/mocks" @@ -23,7 +23,7 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { t.Run("terminal state return unchanged", func(t *testing.T) { mockCache := &cacheMocks.AutoRefresh{} - mockPresto := &prestoMocks.ServiceClient{} + mockPresto := &prestoMocks.PrestoClient{} testScope := promutils.NewTestScope() p := ExecutionsCache{ @@ -53,7 +53,7 @@ func TestPrestoExecutionsCache_SyncQuboleQuery(t *testing.T) { t.Run("move to success", func(t *testing.T) { mockCache := &cacheMocks.AutoRefresh{} - mockPresto := &prestoMocks.ServiceClient{} + mockPresto := &prestoMocks.PrestoClient{} mockSecretManager := &mocks.SecretManager{} mockSecretManager.OnGetMatch(mock.Anything, mock.Anything).Return("fake key", nil) diff --git a/go/tasks/plugins/presto/executor.go b/go/tasks/plugins/presto/executor.go index 4d2a54e54..011f70eec 100644 --- a/go/tasks/plugins/presto/executor.go +++ b/go/tasks/plugins/presto/executor.go @@ -4,7 +4,6 @@ import ( "context" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" - "github.com/lyft/flyteplugins/go/tasks/plugins/svc" "github.com/lyft/flytestdlib/cache" @@ -28,7 +27,7 @@ const prestoTaskType = "presto" // This needs to match the type defined in Flyte type Executor struct { id string metrics ExecutorMetrics - prestoClient svc.ServiceClient + prestoClient client.PrestoClient executionsCache cache.AutoRefresh cfg *config.Config } @@ -95,7 +94,7 @@ func (p Executor) GetProperties() core.PluginProperties { func ExecutorLoader(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { cfg := config.GetPrestoConfig() - return InitializePrestoExecutor(ctx, iCtx, cfg, BuildResourceConfig(cfg), client.NewPrestoClient(cfg)) + return InitializePrestoExecutor(ctx, iCtx, cfg, BuildResourceConfig(cfg), client.NewNoopPrestoClient(cfg)) } func BuildResourceConfig(cfg *config.Config) map[string]int { @@ -112,7 +111,7 @@ func InitializePrestoExecutor( iCtx core.SetupContext, cfg *config.Config, resourceConfig map[string]int, - prestoClient svc.ServiceClient) (core.Plugin, error) { + prestoClient client.PrestoClient) (core.Plugin, error) { logger.Infof(ctx, "Initializing a Presto executor with a resource config [%v]", resourceConfig) q, err := NewPrestoExecutor(ctx, cfg, prestoClient, iCtx.MetricsScope()) if err != nil { @@ -134,7 +133,7 @@ func InitializePrestoExecutor( func NewPrestoExecutor( ctx context.Context, cfg *config.Config, - prestoClient svc.ServiceClient, + prestoClient client.PrestoClient, scope promutils.Scope) (Executor, error) { executionsAutoRefreshCache, err := NewPrestoExecutionsCache(ctx, prestoClient, cfg, scope.NewSubScope(prestoTaskType)) if err != nil { diff --git a/go/tasks/plugins/presto/test_helpers.go b/go/tasks/plugins/presto/helpers_test.go similarity index 100% rename from go/tasks/plugins/presto/test_helpers.go rename to go/tasks/plugins/presto/helpers_test.go diff --git a/go/tasks/plugins/svc/mocks/service_client.go b/go/tasks/plugins/svc/mocks/service_client.go deleted file mode 100644 index 19d351038..000000000 --- a/go/tasks/plugins/svc/mocks/service_client.go +++ /dev/null @@ -1,127 +0,0 @@ -// Code generated by mockery v1.0.1. DO NOT EDIT. - -package mocks - -import ( - context "context" - - svc "github.com/lyft/flyteplugins/go/tasks/plugins/svc" - mock "github.com/stretchr/testify/mock" -) - -// ServiceClient is an autogenerated mock type for the ServiceClient type -type ServiceClient struct { - mock.Mock -} - -type ServiceClient_ExecuteCommand struct { - *mock.Call -} - -func (_m ServiceClient_ExecuteCommand) Return(_a0 interface{}, _a1 error) *ServiceClient_ExecuteCommand { - return &ServiceClient_ExecuteCommand{Call: _m.Call.Return(_a0, _a1)} -} - -func (_m *ServiceClient) OnExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) *ServiceClient_ExecuteCommand { - c := _m.On("ExecuteCommand", ctx, commandStr, extraArgs) - return &ServiceClient_ExecuteCommand{Call: c} -} - -func (_m *ServiceClient) OnExecuteCommandMatch(matchers ...interface{}) *ServiceClient_ExecuteCommand { - c := _m.On("ExecuteCommand", matchers...) - return &ServiceClient_ExecuteCommand{Call: c} -} - -// ExecuteCommand provides a mock function with given fields: ctx, commandStr, extraArgs -func (_m *ServiceClient) ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) { - ret := _m.Called(ctx, commandStr, extraArgs) - - var r0 interface{} - if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) interface{}); ok { - r0 = rf(ctx, commandStr, extraArgs) - } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(interface{}) - } - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, interface{}) error); ok { - r1 = rf(ctx, commandStr, extraArgs) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type ServiceClient_GetCommandStatus struct { - *mock.Call -} - -func (_m ServiceClient_GetCommandStatus) Return(_a0 svc.CommandStatus, _a1 error) *ServiceClient_GetCommandStatus { - return &ServiceClient_GetCommandStatus{Call: _m.Call.Return(_a0, _a1)} -} - -func (_m *ServiceClient) OnGetCommandStatus(ctx context.Context, commandID string) *ServiceClient_GetCommandStatus { - c := _m.On("GetCommandStatus", ctx, commandID) - return &ServiceClient_GetCommandStatus{Call: c} -} - -func (_m *ServiceClient) OnGetCommandStatusMatch(matchers ...interface{}) *ServiceClient_GetCommandStatus { - c := _m.On("GetCommandStatus", matchers...) - return &ServiceClient_GetCommandStatus{Call: c} -} - -// GetCommandStatus provides a mock function with given fields: ctx, commandID -func (_m *ServiceClient) GetCommandStatus(ctx context.Context, commandID string) (svc.CommandStatus, error) { - ret := _m.Called(ctx, commandID) - - var r0 svc.CommandStatus - if rf, ok := ret.Get(0).(func(context.Context, string) svc.CommandStatus); ok { - r0 = rf(ctx, commandID) - } else { - r0 = ret.Get(0).(svc.CommandStatus) - } - - var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { - r1 = rf(ctx, commandID) - } else { - r1 = ret.Error(1) - } - - return r0, r1 -} - -type ServiceClient_KillCommand struct { - *mock.Call -} - -func (_m ServiceClient_KillCommand) Return(_a0 error) *ServiceClient_KillCommand { - return &ServiceClient_KillCommand{Call: _m.Call.Return(_a0)} -} - -func (_m *ServiceClient) OnKillCommand(ctx context.Context, commandID string) *ServiceClient_KillCommand { - c := _m.On("KillCommand", ctx, commandID) - return &ServiceClient_KillCommand{Call: c} -} - -func (_m *ServiceClient) OnKillCommandMatch(matchers ...interface{}) *ServiceClient_KillCommand { - c := _m.On("KillCommand", matchers...) - return &ServiceClient_KillCommand{Call: c} -} - -// KillCommand provides a mock function with given fields: ctx, commandID -func (_m *ServiceClient) KillCommand(ctx context.Context, commandID string) error { - ret := _m.Called(ctx, commandID) - - var r0 error - if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = rf(ctx, commandID) - } else { - r0 = ret.Error(0) - } - - return r0 -} diff --git a/go/tasks/plugins/svc/service_client.go b/go/tasks/plugins/svc/service_client.go deleted file mode 100644 index f99d7e660..000000000 --- a/go/tasks/plugins/svc/service_client.go +++ /dev/null @@ -1,15 +0,0 @@ -package svc - -import ( - "context" -) - -type CommandStatus string - -//go:generate mockery -all -case=snake - -type ServiceClient interface { - ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) - KillCommand(ctx context.Context, commandID string) error - GetCommandStatus(ctx context.Context, commandID string) (CommandStatus, error) -} From cd07b27db3b10d199c94925d4268ceede031e063 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Thu, 12 Mar 2020 18:37:07 -0700 Subject: [PATCH 12/26] PR feedback changes 2 --- .../presto/client/mocks/presto_client.go | 18 +++++++++--------- .../plugins/presto/client/presto_client.go | 13 +++++++++++-- go/tasks/plugins/presto/execution_state.go | 17 +++++++++++------ 3 files changed, 31 insertions(+), 17 deletions(-) diff --git a/go/tasks/plugins/presto/client/mocks/presto_client.go b/go/tasks/plugins/presto/client/mocks/presto_client.go index 3494940e5..c26487085 100644 --- a/go/tasks/plugins/presto/client/mocks/presto_client.go +++ b/go/tasks/plugins/presto/client/mocks/presto_client.go @@ -23,8 +23,8 @@ func (_m PrestoClient_ExecuteCommand) Return(_a0 interface{}, _a1 error) *Presto return &PrestoClient_ExecuteCommand{Call: _m.Call.Return(_a0, _a1)} } -func (_m *PrestoClient) OnExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) *PrestoClient_ExecuteCommand { - c := _m.On("ExecuteCommand", ctx, commandStr, extraArgs) +func (_m *PrestoClient) OnExecuteCommand(ctx context.Context, commandStr string, executeArgs client.PrestoExecuteArgs) *PrestoClient_ExecuteCommand { + c := _m.On("ExecuteCommand", ctx, commandStr, executeArgs) return &PrestoClient_ExecuteCommand{Call: c} } @@ -33,13 +33,13 @@ func (_m *PrestoClient) OnExecuteCommandMatch(matchers ...interface{}) *PrestoCl return &PrestoClient_ExecuteCommand{Call: c} } -// ExecuteCommand provides a mock function with given fields: ctx, commandStr, extraArgs -func (_m *PrestoClient) ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) { - ret := _m.Called(ctx, commandStr, extraArgs) +// ExecuteCommand provides a mock function with given fields: ctx, commandStr, executeArgs +func (_m *PrestoClient) ExecuteCommand(ctx context.Context, commandStr string, executeArgs client.PrestoExecuteArgs) (interface{}, error) { + ret := _m.Called(ctx, commandStr, executeArgs) var r0 interface{} - if rf, ok := ret.Get(0).(func(context.Context, string, interface{}) interface{}); ok { - r0 = rf(ctx, commandStr, extraArgs) + if rf, ok := ret.Get(0).(func(context.Context, string, client.PrestoExecuteArgs) interface{}); ok { + r0 = rf(ctx, commandStr, executeArgs) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(interface{}) @@ -47,8 +47,8 @@ func (_m *PrestoClient) ExecuteCommand(ctx context.Context, commandStr string, e } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, string, interface{}) error); ok { - r1 = rf(ctx, commandStr, extraArgs) + if rf, ok := ret.Get(1).(func(context.Context, string, client.PrestoExecuteArgs) error); ok { + r1 = rf(ctx, commandStr, executeArgs) } else { r1 = ret.Error(1) } diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go index f1f29c538..1ac26da46 100644 --- a/go/tasks/plugins/presto/client/presto_client.go +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -19,12 +19,15 @@ type noopPrestoClient struct { environment *url.URL } +// Contains information needed to execute a Presto query type PrestoExecuteArgs struct { RoutingGroup string `json:"routingGroup,omitempty"` Catalog string `json:"catalog,omitempty"` Schema string `json:"schema,omitempty"` Source string `json:"source,omitempty"` } + +// Representation of a response after submitting a query to Presto type PrestoExecuteResponse struct { ID string Status PrestoStatus @@ -33,16 +36,22 @@ type PrestoExecuteResponse struct { //go:generate mockery -all -case=snake +// Interface to interact with PrestoClient for Presto tasks type PrestoClient interface { - ExecuteCommand(ctx context.Context, commandStr string, extraArgs interface{}) (interface{}, error) + // Submits a query to Presto + ExecuteCommand(ctx context.Context, commandStr string, executeArgs PrestoExecuteArgs) (interface{}, error) + + // Cancels a currently running Presto query KillCommand(ctx context.Context, commandID string) error + + // Gets the status of a Presto query GetCommandStatus(ctx context.Context, commandID string) (PrestoStatus, error) } func (p noopPrestoClient) ExecuteCommand( ctx context.Context, queryStr string, - extraArgs interface{}) (interface{}, error) { + executeArgs PrestoExecuteArgs) (interface{}, error) { return PrestoExecuteResponse{}, nil } diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 5bfc65ed6..4d3e6c2be 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -56,10 +56,15 @@ type ExecutionState struct { // This will store the command ID from Presto CommandID string `json:"commandId,omitempty"` - URI string `json:"uri,omitempty"` + // This will have the nextUri from Presto which is used to advance the query forward + URI string `json:"uri,omitempty"` + + // This is the current Presto query (out of 5) needed to complete a Presto task CurrentPrestoQuery Query `json:"currentPrestoQuery,omitempty"` - QueryCount int `json:"queryCount,omitempty"` + + // Keeps track of which Presto query we are on. Its values range from 0-4 for the 5 queries that are needed + QueryCount int `json:"queryCount,omitempty"` // This number keeps track of the number of failures within the sync function. Without this, what happens in // the sync function is entirely opaque. Note that this field is completely orthogonal to Flyte system/node/task @@ -75,7 +80,7 @@ type ExecutionState struct { type Query struct { Statement string `json:"statement,omitempty"` - ExtraArgs client.PrestoExecuteArgs `json:"extraArgs,omitempty"` + ExecuteArgs client.PrestoExecuteArgs `json:"executeArgs,omitempty"` TempTableName string `json:"tempTableName,omitempty"` ExternalTableName string `json:"externalTableName,omitempty"` ExternalLocation string `json:"externalLocation"` @@ -280,7 +285,7 @@ func GetNextQuery( prestoQuery := Query{ Statement: statement, - ExtraArgs: client.PrestoExecuteArgs{ + ExecuteArgs: client.PrestoExecuteArgs{ RoutingGroup: resolveRoutingGroup(ctx, routingGroup, prestoCfg), Catalog: catalog, Schema: schema, @@ -348,9 +353,9 @@ func KickOffQuery( uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() statement := currentState.CurrentPrestoQuery.Statement - extraArgs := currentState.CurrentPrestoQuery.ExtraArgs + executeArgs := currentState.CurrentPrestoQuery.ExecuteArgs - response, err := prestoClient.ExecuteCommand(ctx, statement, extraArgs) + response, err := prestoClient.ExecuteCommand(ctx, statement, executeArgs) if err != nil { // If we failed, we'll keep the NotStarted state currentState.CreationFailureCount = currentState.CreationFailureCount + 1 From 574a3cc71a994232b143b833995e1924b63ff244 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Fri, 13 Mar 2020 10:34:36 -0700 Subject: [PATCH 13/26] PR feedback changes 3 --- go/tasks/pluginmachinery/core/phase.go | 4 ---- go/tasks/plugins/presto/execution_state.go | 8 +++++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/go/tasks/pluginmachinery/core/phase.go b/go/tasks/pluginmachinery/core/phase.go index ca8fad8fe..470cf1368 100644 --- a/go/tasks/pluginmachinery/core/phase.go +++ b/go/tasks/pluginmachinery/core/phase.go @@ -180,10 +180,6 @@ func PhaseInfoSuccess(info *TaskInfo) PhaseInfo { return phaseInfo(PhaseSuccess, DefaultPhaseVersion, nil, info) } -func PhaseInfoSuccessWithVersion(version uint32, info *TaskInfo) PhaseInfo { - return phaseInfo(PhaseSuccess, version, nil, info) -} - func PhaseInfoFailure(code, reason string, info *TaskInfo) PhaseInfo { return PhaseInfoFailed(PhasePermanentFailure, &core.ExecutionError{Code: code, Message: reason}, info) } diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 4d3e6c2be..aace8ad82 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -436,10 +436,12 @@ func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { } case PhaseSubmitted: phaseInfo = core.PhaseInfoRunning(uint32(state.QueryCount), ConstructTaskInfo(state)) - case PhaseQuerySucceeded: - phaseInfo = core.PhaseInfoSuccessWithVersion(uint32(state.QueryCount), ConstructTaskInfo(state)) - + if state.QueryCount < 4 { + phaseInfo = core.PhaseInfoRunning(uint32(state.QueryCount), ConstructTaskInfo(state)) + } else { + phaseInfo = core.PhaseInfoSuccess(ConstructTaskInfo(state)) + } case PhaseQueryFailed: phaseInfo = core.PhaseInfoFailure(errors.DownstreamSystemError, "Query failed", ConstructTaskInfo(state)) } From 352ac582f84b57b843585a0f0eaa29606470760f Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Sun, 15 Mar 2020 16:43:07 -0700 Subject: [PATCH 14/26] PR feedback changes 4 --- go/tasks/pluginmachinery/core/exec_context.go | 3 ++ .../presto/client/mocks/presto_client.go | 12 +++-- .../presto/client/noop_presto_client.go | 43 ++++++++++++++++++ .../plugins/presto/client/presto_client.go | 44 +------------------ go/tasks/plugins/presto/config/config.go | 4 -- .../plugins/presto/config/config_flags.go | 2 - .../presto/config/config_flags_test.go | 44 ------------------- go/tasks/plugins/presto/execution_state.go | 28 ++++-------- .../plugins/presto/execution_state_test.go | 4 +- 9 files changed, 62 insertions(+), 122 deletions(-) create mode 100644 go/tasks/plugins/presto/client/noop_presto_client.go diff --git a/go/tasks/pluginmachinery/core/exec_context.go b/go/tasks/pluginmachinery/core/exec_context.go index 4212320c2..0545f008f 100644 --- a/go/tasks/pluginmachinery/core/exec_context.go +++ b/go/tasks/pluginmachinery/core/exec_context.go @@ -57,6 +57,9 @@ type TaskExecutionContext interface { // Returns a handle to the Task events recorder, which get stored in the Admin. EventsRecorder() EventsRecorder + + //// Returns a reference to a data location + //DataLocation() io.OutputWriter } // A simple fire-and-forget func diff --git a/go/tasks/plugins/presto/client/mocks/presto_client.go b/go/tasks/plugins/presto/client/mocks/presto_client.go index c26487085..ad732e91b 100644 --- a/go/tasks/plugins/presto/client/mocks/presto_client.go +++ b/go/tasks/plugins/presto/client/mocks/presto_client.go @@ -19,7 +19,7 @@ type PrestoClient_ExecuteCommand struct { *mock.Call } -func (_m PrestoClient_ExecuteCommand) Return(_a0 interface{}, _a1 error) *PrestoClient_ExecuteCommand { +func (_m PrestoClient_ExecuteCommand) Return(_a0 client.PrestoExecuteResponse, _a1 error) *PrestoClient_ExecuteCommand { return &PrestoClient_ExecuteCommand{Call: _m.Call.Return(_a0, _a1)} } @@ -34,16 +34,14 @@ func (_m *PrestoClient) OnExecuteCommandMatch(matchers ...interface{}) *PrestoCl } // ExecuteCommand provides a mock function with given fields: ctx, commandStr, executeArgs -func (_m *PrestoClient) ExecuteCommand(ctx context.Context, commandStr string, executeArgs client.PrestoExecuteArgs) (interface{}, error) { +func (_m *PrestoClient) ExecuteCommand(ctx context.Context, commandStr string, executeArgs client.PrestoExecuteArgs) (client.PrestoExecuteResponse, error) { ret := _m.Called(ctx, commandStr, executeArgs) - var r0 interface{} - if rf, ok := ret.Get(0).(func(context.Context, string, client.PrestoExecuteArgs) interface{}); ok { + var r0 client.PrestoExecuteResponse + if rf, ok := ret.Get(0).(func(context.Context, string, client.PrestoExecuteArgs) client.PrestoExecuteResponse); ok { r0 = rf(ctx, commandStr, executeArgs) } else { - if ret.Get(0) != nil { - r0 = ret.Get(0).(interface{}) - } + r0 = ret.Get(0).(client.PrestoExecuteResponse) } var r1 error diff --git a/go/tasks/plugins/presto/client/noop_presto_client.go b/go/tasks/plugins/presto/client/noop_presto_client.go new file mode 100644 index 000000000..98facaf3d --- /dev/null +++ b/go/tasks/plugins/presto/client/noop_presto_client.go @@ -0,0 +1,43 @@ +package client + +import ( + "context" + "net/http" + "net/url" + + "time" + + "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" +) + +const ( + httpRequestTimeoutSecs = 30 +) + +type noopPrestoClient struct { + client *http.Client + environment *url.URL +} + +func (p noopPrestoClient) ExecuteCommand( + ctx context.Context, + queryStr string, + executeArgs PrestoExecuteArgs) (PrestoExecuteResponse, error) { + + return PrestoExecuteResponse{}, nil +} + +func (p noopPrestoClient) KillCommand(ctx context.Context, commandID string) error { + return nil +} + +func (p noopPrestoClient) GetCommandStatus(ctx context.Context, commandID string) (PrestoStatus, error) { + return NewPrestoStatus(ctx, "UNKNOWN"), nil +} + +func NewNoopPrestoClient(cfg *config.Config) PrestoClient { + return &noopPrestoClient{ + client: &http.Client{Timeout: httpRequestTimeoutSecs * time.Second}, + environment: cfg.Environment.ResolveReference(&cfg.Environment.URL), + } +} diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go index 1ac26da46..79523634e 100644 --- a/go/tasks/plugins/presto/client/presto_client.go +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -1,23 +1,6 @@ package client -import ( - "context" - "net/http" - "net/url" - - "time" - - "github.com/lyft/flyteplugins/go/tasks/plugins/presto/config" -) - -const ( - httpRequestTimeoutSecs = 30 -) - -type noopPrestoClient struct { - client *http.Client - environment *url.URL -} +import "context" // Contains information needed to execute a Presto query type PrestoExecuteArgs struct { @@ -39,7 +22,7 @@ type PrestoExecuteResponse struct { // Interface to interact with PrestoClient for Presto tasks type PrestoClient interface { // Submits a query to Presto - ExecuteCommand(ctx context.Context, commandStr string, executeArgs PrestoExecuteArgs) (interface{}, error) + ExecuteCommand(ctx context.Context, commandStr string, executeArgs PrestoExecuteArgs) (PrestoExecuteResponse, error) // Cancels a currently running Presto query KillCommand(ctx context.Context, commandID string) error @@ -47,26 +30,3 @@ type PrestoClient interface { // Gets the status of a Presto query GetCommandStatus(ctx context.Context, commandID string) (PrestoStatus, error) } - -func (p noopPrestoClient) ExecuteCommand( - ctx context.Context, - queryStr string, - executeArgs PrestoExecuteArgs) (interface{}, error) { - - return PrestoExecuteResponse{}, nil -} - -func (p noopPrestoClient) KillCommand(ctx context.Context, commandID string) error { - return nil -} - -func (p noopPrestoClient) GetCommandStatus(ctx context.Context, commandID string) (PrestoStatus, error) { - return NewPrestoStatus(ctx, "UNKNOWN"), nil -} - -func NewNoopPrestoClient(cfg *config.Config) PrestoClient { - return &noopPrestoClient{ - client: &http.Client{Timeout: httpRequestTimeoutSecs * time.Second}, - environment: cfg.Environment.ResolveReference(&cfg.Environment.URL), - } -} diff --git a/go/tasks/plugins/presto/config/config.go b/go/tasks/plugins/presto/config/config.go index b91eb22ec..c2a6bd743 100644 --- a/go/tasks/plugins/presto/config/config.go +++ b/go/tasks/plugins/presto/config/config.go @@ -45,8 +45,6 @@ var ( defaultConfig = Config{ Environment: URLMustParse(""), DefaultRoutingGroup: "adhoc", - AwsS3ShardFormatter: "", - AwsS3ShardCount: 2, RoutingGroupConfigs: []RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, RateLimiter: RateLimiter{ Name: "presto", @@ -64,8 +62,6 @@ var ( type Config struct { Environment config.URL `json:"environment" pflag:",Environment endpoint for Presto to use"` DefaultRoutingGroup string `json:"defaultRoutingGroup" pflag:",Default Presto routing group"` - AwsS3ShardFormatter string `json:"awsS3ShardFormatter" pflag:", S3 bucket prefix where Presto results will be stored"` - AwsS3ShardCount int `json:"awsS3ShardCount" pflag:", Number of characters for the S3 bucket shard prefix"` RoutingGroupConfigs []RoutingGroupConfig `json:"routingGroupConfigs" pflag:"-,A list of cluster configs. Each of the configs corresponds to a service cluster"` RateLimiter RateLimiter `json:"rateLimiter" pflag:"Rate limiter config"` } diff --git a/go/tasks/plugins/presto/config/config_flags.go b/go/tasks/plugins/presto/config/config_flags.go index c83e30834..66601f842 100755 --- a/go/tasks/plugins/presto/config/config_flags.go +++ b/go/tasks/plugins/presto/config/config_flags.go @@ -43,8 +43,6 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) cmdFlags.String(fmt.Sprintf("%v%v", prefix, "environment"), defaultConfig.Environment.String(), "Environment endpoint for Presto to use") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultRoutingGroup"), defaultConfig.DefaultRoutingGroup, "Default Presto routing group") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "awsS3ShardFormatter"), defaultConfig.AwsS3ShardFormatter, " S3 bucket prefix where Presto results will be stored") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "awsS3ShardCount"), defaultConfig.AwsS3ShardCount, " Number of characters for the S3 bucket shard prefix") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "rateLimiter.name"), defaultConfig.RateLimiter.Name, "The name of the rate limiter") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "rateLimiter.syncPeriod"), defaultConfig.RateLimiter.SyncPeriod.String(), "The duration to wait before the cache is refreshed again") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "rateLimiter.workers"), defaultConfig.RateLimiter.Workers, "Number of parallel workers to refresh the cache") diff --git a/go/tasks/plugins/presto/config/config_flags_test.go b/go/tasks/plugins/presto/config/config_flags_test.go index 41bead819..f714e29b8 100755 --- a/go/tasks/plugins/presto/config/config_flags_test.go +++ b/go/tasks/plugins/presto/config/config_flags_test.go @@ -143,50 +143,6 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_awsS3ShardFormatter", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("awsS3ShardFormatter"); err == nil { - assert.Equal(t, string(defaultConfig.AwsS3ShardFormatter), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) - - t.Run("Override", func(t *testing.T) { - testValue := "1" - - cmdFlags.Set("awsS3ShardFormatter", testValue) - if vString, err := cmdFlags.GetString("awsS3ShardFormatter"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.AwsS3ShardFormatter) - - } else { - assert.FailNow(t, err.Error()) - } - }) - }) - t.Run("Test_awsS3ShardCount", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("awsS3ShardCount"); err == nil { - assert.Equal(t, int(defaultConfig.AwsS3ShardCount), vInt) - } else { - assert.FailNow(t, err.Error()) - } - }) - - t.Run("Override", func(t *testing.T) { - testValue := "1" - - cmdFlags.Set("awsS3ShardCount", testValue) - if vInt, err := cmdFlags.GetInt("awsS3ShardCount"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.AwsS3ShardCount) - - } else { - assert.FailNow(t, err.Error()) - } - }) - }) t.Run("Test_rateLimiter.name", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index aace8ad82..9a00c8366 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -6,7 +6,6 @@ import ( "k8s.io/apimachinery/pkg/util/rand" "fmt" - "strings" "github.com/lyft/flyteplugins/go/tasks/plugins/presto/client" @@ -297,9 +296,7 @@ func GetNextQuery( return prestoQuery, nil case 1: - cfg := config.GetPrestoConfig() - externalLocation := getExternalLocation(cfg.AwsS3ShardFormatter, cfg.AwsS3ShardCount) - + externalLocation := "TODO - use sandbox ref" statement := fmt.Sprintf(` CREATE TABLE hive.flyte_temporary_tables.%s (LIKE hive.flyte_temporary_tables.%s) WITH (format = 'PARQUET', external_location = '%s')`, @@ -334,15 +331,6 @@ FROM hive.flyte_temporary_tables.%s` } } -func getExternalLocation(shardFormatter string, shardLength int) string { - shardCount := strings.Count(shardFormatter, "{}") - for i := 0; i < shardCount; i++ { - shardFormatter = strings.Replace(shardFormatter, "{}", rand.String(shardLength), 1) - } - - return shardFormatter + rand.String(32) + "/" -} - func KickOffQuery( ctx context.Context, tCtx core.TaskExecutionContext, @@ -361,15 +349,13 @@ func KickOffQuery( currentState.CreationFailureCount = currentState.CreationFailureCount + 1 logger.Warnf(ctx, "Error creating Presto query for %s, failure counts %d. Error: %s", uniqueID, currentState.CreationFailureCount, err) } else { - executeResponse := response.(client.PrestoExecuteResponse) - // If we succeed, then store the command id returned from Presto, and update our state. Also, add to the // AutoRefreshCache so we start getting updates for its status. - commandID := executeResponse.ID + commandID := response.ID logger.Infof(ctx, "Created Presto ID [%s] for token %s", commandID, uniqueID) currentState.CommandID = commandID currentState.Phase = PhaseSubmitted - currentState.URI = executeResponse.NextURI + currentState.URI = response.NextURI executionStateCacheItem := ExecutionStateCacheItem{ ExecutionState: currentState, @@ -421,6 +407,8 @@ func MonitorQuery( return cachedExecutionState.ExecutionState, nil } +// The 'PhaseInfoRunning' occurs 15 times (3 for each of the 5 Presto queries that get run for every Presto task) which +// are differentiated by the version (1-15) func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { var phaseInfo core.PhaseInfo t := time.Now() @@ -432,13 +420,13 @@ func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { if state.CreationFailureCount > 5 { phaseInfo = core.PhaseInfoRetryableFailure("PrestoFailure", "Too many creation attempts", nil) } else { - phaseInfo = core.PhaseInfoQueued(t, uint32(state.QueryCount), "Waiting for Presto launch") + phaseInfo = core.PhaseInfoRunning(uint32(3*state.QueryCount+1), ConstructTaskInfo(state)) } case PhaseSubmitted: - phaseInfo = core.PhaseInfoRunning(uint32(state.QueryCount), ConstructTaskInfo(state)) + phaseInfo = core.PhaseInfoRunning(uint32(3*state.QueryCount+2), ConstructTaskInfo(state)) case PhaseQuerySucceeded: if state.QueryCount < 4 { - phaseInfo = core.PhaseInfoRunning(uint32(state.QueryCount), ConstructTaskInfo(state)) + phaseInfo = core.PhaseInfoRunning(uint32(3*state.QueryCount+3), ConstructTaskInfo(state)) } else { phaseInfo = core.PhaseInfoSuccess(ConstructTaskInfo(state)) } diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index 4efdb5dcb..44b99dabc 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -147,7 +147,7 @@ func TestMapExecutionStateToPhaseInfo(t *testing.T) { CreationFailureCount: 0, } phaseInfo := MapExecutionStateToPhaseInfo(e) - assert.Equal(t, core.PhaseQueued, phaseInfo.Phase()) + assert.Equal(t, core.PhaseRunning, phaseInfo.Phase()) e = ExecutionState{ Phase: PhaseQueued, @@ -347,8 +347,6 @@ func createMockPrestoCfg() *config.Config { return &config.Config{ Environment: config.URLMustParse(""), DefaultRoutingGroup: "adhoc", - AwsS3ShardFormatter: "", - AwsS3ShardCount: 2, RoutingGroupConfigs: []config.RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, RateLimiter: config.RateLimiter{ Name: "presto", From f48673f36255bba96ebef6d431a695717c44138b Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Sun, 15 Mar 2020 17:18:55 -0700 Subject: [PATCH 15/26] add user to execute args --- go/tasks/plugins/presto/client/presto_client.go | 1 + go/tasks/plugins/presto/execution_state.go | 13 +++++++++++++ 2 files changed, 14 insertions(+) diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go index 79523634e..7019fef42 100644 --- a/go/tasks/plugins/presto/client/presto_client.go +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -8,6 +8,7 @@ type PrestoExecuteArgs struct { Catalog string `json:"catalog,omitempty"` Schema string `json:"schema,omitempty"` Source string `json:"source,omitempty"` + User string `json:"user,omitempty"` } // Representation of a response after submitting a query to Presto diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 9a00c8366..62253aa88 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -3,6 +3,8 @@ package presto import ( "context" + "github.com/lyft/flytestdlib/contextutils" + "k8s.io/apimachinery/pkg/util/rand" "fmt" @@ -32,6 +34,8 @@ const ( PhaseSubmitted // Sent off to Presto PhaseQuerySucceeded PhaseQueryFailed + + PrincipalContextKey contextutils.Key = "principal" ) func (p ExecutionPhase) String() string { @@ -288,6 +292,7 @@ func GetNextQuery( RoutingGroup: resolveRoutingGroup(ctx, routingGroup, prestoCfg), Catalog: catalog, Schema: schema, + User: getUser(ctx), }, TempTableName: tempTableName + "_temp", ExternalTableName: tempTableName + "_external", @@ -331,6 +336,14 @@ FROM hive.flyte_temporary_tables.%s` } } +func getUser(ctx context.Context) string { + principalContextUser := ctx.Value(PrincipalContextKey) + if principalContextUser != nil { + return fmt.Sprintf("%v", principalContextUser) + } + return "" +} + func KickOffQuery( ctx context.Context, tCtx core.TaskExecutionContext, From 4dff686ef8b6518203e332780e818b05ec3b450a Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Sun, 15 Mar 2020 22:37:22 -0700 Subject: [PATCH 16/26] resource reg --- go/tasks/plugins/presto/executor.go | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/go/tasks/plugins/presto/executor.go b/go/tasks/plugins/presto/executor.go index 011f70eec..e52ec8842 100644 --- a/go/tasks/plugins/presto/executor.go +++ b/go/tasks/plugins/presto/executor.go @@ -94,35 +94,25 @@ func (p Executor) GetProperties() core.PluginProperties { func ExecutorLoader(ctx context.Context, iCtx core.SetupContext) (core.Plugin, error) { cfg := config.GetPrestoConfig() - return InitializePrestoExecutor(ctx, iCtx, cfg, BuildResourceConfig(cfg), client.NewNoopPrestoClient(cfg)) -} - -func BuildResourceConfig(cfg *config.Config) map[string]int { - resourceConfig := make(map[string]int, len(cfg.RoutingGroupConfigs)) - - for _, routingGroupCfg := range cfg.RoutingGroupConfigs { - resourceConfig[routingGroupCfg.Name] = routingGroupCfg.Limit - } - return resourceConfig + return InitializePrestoExecutor(ctx, iCtx, cfg, client.NewNoopPrestoClient(cfg)) } func InitializePrestoExecutor( ctx context.Context, iCtx core.SetupContext, cfg *config.Config, - resourceConfig map[string]int, prestoClient client.PrestoClient) (core.Plugin, error) { - logger.Infof(ctx, "Initializing a Presto executor with a resource config [%v]", resourceConfig) + logger.Infof(ctx, "Initializing a Presto executo") q, err := NewPrestoExecutor(ctx, cfg, prestoClient, iCtx.MetricsScope()) if err != nil { logger.Errorf(ctx, "Failed to create a new Executor due to error: [%v]", err) return nil, err } - for routingGroupName, routingGroupLimit := range resourceConfig { - logger.Infof(ctx, "Registering resource quota for cluster [%v]", routingGroupName) - if err := iCtx.ResourceRegistrar().RegisterResourceQuota(ctx, core.ResourceNamespace(routingGroupName), routingGroupLimit); err != nil { - logger.Errorf(ctx, "Resource quota registration for [%v] failed due to error [%v]", routingGroupName, err) + for _, routingGroup := range cfg.RoutingGroupConfigs { + logger.Infof(ctx, "Registering resource quota for routing group [%v]", routingGroup.Name) + if err := iCtx.ResourceRegistrar().RegisterResourceQuota(ctx, core.ResourceNamespace(routingGroup.Name), routingGroup.Limit); err != nil { + logger.Errorf(ctx, "Resource quota registration for [%v] failed due to error [%v]", routingGroup.Name, err) return nil, err } } From 2898c2c322ce4924aab079b944926c32edeb6dbe Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Mon, 16 Mar 2020 08:37:50 -0700 Subject: [PATCH 17/26] update status --- go/tasks/plugins/presto/client/presto_status.go | 4 ++-- go/tasks/plugins/presto/execution_state_test.go | 2 +- go/tasks/plugins/presto/executions_cache.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/go/tasks/plugins/presto/client/presto_status.go b/go/tasks/plugins/presto/client/presto_status.go index db09afab2..a71964e0e 100644 --- a/go/tasks/plugins/presto/client/presto_status.go +++ b/go/tasks/plugins/presto/client/presto_status.go @@ -13,7 +13,7 @@ type PrestoStatus string // not meant to be stored locally. const ( PrestoStatusUnknown PrestoStatus = "UNKNOWN" - PrestoStatusQueued PrestoStatus = "QUEUED" + PrestoStatusWaiting PrestoStatus = "WAITING" PrestoStatusRunning PrestoStatus = "RUNNING" PrestoStatusFinished PrestoStatus = "FINISHED" PrestoStatusFailed PrestoStatus = "FAILED" @@ -22,7 +22,7 @@ const ( var PrestoStatuses = map[PrestoStatus]struct{}{ PrestoStatusUnknown: {}, - PrestoStatusQueued: {}, + PrestoStatusWaiting: {}, PrestoStatusRunning: {}, PrestoStatusFinished: {}, PrestoStatusFailed: {}, diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index 44b99dabc..01616f6d2 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -321,7 +321,7 @@ func TestKickOffQuery(t *testing.T) { prestoExecuteResponse := client.PrestoExecuteResponse{ ID: "1234567", - Status: client.PrestoStatusQueued, + Status: client.PrestoStatusWaiting, } mockPresto := &prestoMocks.PrestoClient{} mockPresto.OnExecuteCommandMatch(mock.Anything, mock.Anything, mock.Anything, mock.Anything, diff --git a/go/tasks/plugins/presto/executions_cache.go b/go/tasks/plugins/presto/executions_cache.go index dcc10e304..47ac57804 100644 --- a/go/tasks/plugins/presto/executions_cache.go +++ b/go/tasks/plugins/presto/executions_cache.go @@ -152,7 +152,7 @@ func StatusToExecutionPhase(s client.PrestoStatus) (ExecutionPhase, error) { return PhaseQueryFailed, nil case client.PrestoStatusFailed: return PhaseQueryFailed, nil - case client.PrestoStatusQueued: + case client.PrestoStatusWaiting: return PhaseSubmitted, nil case client.PrestoStatusRunning: return PhaseSubmitted, nil From 23b3eb246d9c7c49fe3f284ebe21a6609273d6fa Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Wed, 18 Mar 2020 15:18:40 -0700 Subject: [PATCH 18/26] e2e chages --- .../plugins/presto/client/presto_client.go | 2 + .../plugins/presto/client/presto_status.go | 2 - go/tasks/plugins/presto/config/config.go | 2 + .../plugins/presto/config/config_flags.go | 1 + .../presto/config/config_flags_test.go | 22 +++++ go/tasks/plugins/presto/execution_state.go | 88 ++++++++++++++----- .../plugins/presto/execution_state_test.go | 2 +- go/tasks/plugins/presto/executions_cache.go | 3 - go/tasks/plugins/presto/executor.go | 6 +- go/tasks/plugins/presto/executor_metrics.go | 15 ++-- 10 files changed, 105 insertions(+), 38 deletions(-) diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go index 7019fef42..34bb32c4d 100644 --- a/go/tasks/plugins/presto/client/presto_client.go +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -2,6 +2,8 @@ package client import "context" +type PrestoStatus string + // Contains information needed to execute a Presto query type PrestoExecuteArgs struct { RoutingGroup string `json:"routingGroup,omitempty"` diff --git a/go/tasks/plugins/presto/client/presto_status.go b/go/tasks/plugins/presto/client/presto_status.go index a71964e0e..8979b953e 100644 --- a/go/tasks/plugins/presto/client/presto_status.go +++ b/go/tasks/plugins/presto/client/presto_status.go @@ -7,8 +7,6 @@ import ( "github.com/lyft/flytestdlib/logger" ) -type PrestoStatus string - // This type is meant only to encapsulate the response coming from Presto as a type, it is // not meant to be stored locally. const ( diff --git a/go/tasks/plugins/presto/config/config.go b/go/tasks/plugins/presto/config/config.go index c2a6bd743..37cbfe5ba 100644 --- a/go/tasks/plugins/presto/config/config.go +++ b/go/tasks/plugins/presto/config/config.go @@ -45,6 +45,7 @@ var ( defaultConfig = Config{ Environment: URLMustParse(""), DefaultRoutingGroup: "adhoc", + DefaultUser: "flyte-default-user@lyft.com", RoutingGroupConfigs: []RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, RateLimiter: RateLimiter{ Name: "presto", @@ -62,6 +63,7 @@ var ( type Config struct { Environment config.URL `json:"environment" pflag:",Environment endpoint for Presto to use"` DefaultRoutingGroup string `json:"defaultRoutingGroup" pflag:",Default Presto routing group"` + DefaultUser string `json:"defaultUser" pflag:",Default Presto user"` RoutingGroupConfigs []RoutingGroupConfig `json:"routingGroupConfigs" pflag:"-,A list of cluster configs. Each of the configs corresponds to a service cluster"` RateLimiter RateLimiter `json:"rateLimiter" pflag:"Rate limiter config"` } diff --git a/go/tasks/plugins/presto/config/config_flags.go b/go/tasks/plugins/presto/config/config_flags.go index 66601f842..1a0a2e44d 100755 --- a/go/tasks/plugins/presto/config/config_flags.go +++ b/go/tasks/plugins/presto/config/config_flags.go @@ -43,6 +43,7 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags := pflag.NewFlagSet("Config", pflag.ExitOnError) cmdFlags.String(fmt.Sprintf("%v%v", prefix, "environment"), defaultConfig.Environment.String(), "Environment endpoint for Presto to use") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultRoutingGroup"), defaultConfig.DefaultRoutingGroup, "Default Presto routing group") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultUser"), defaultConfig.DefaultUser, "Default Presto user") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "rateLimiter.name"), defaultConfig.RateLimiter.Name, "The name of the rate limiter") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "rateLimiter.syncPeriod"), defaultConfig.RateLimiter.SyncPeriod.String(), "The duration to wait before the cache is refreshed again") cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "rateLimiter.workers"), defaultConfig.RateLimiter.Workers, "Number of parallel workers to refresh the cache") diff --git a/go/tasks/plugins/presto/config/config_flags_test.go b/go/tasks/plugins/presto/config/config_flags_test.go index f714e29b8..21df51a13 100755 --- a/go/tasks/plugins/presto/config/config_flags_test.go +++ b/go/tasks/plugins/presto/config/config_flags_test.go @@ -143,6 +143,28 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_defaultUser", func(t *testing.T) { + t.Run("DefaultValue", func(t *testing.T) { + // Test that default value is set properly + if vString, err := cmdFlags.GetString("defaultUser"); err == nil { + assert.Equal(t, string(defaultConfig.DefaultUser), vString) + } else { + assert.FailNow(t, err.Error()) + } + }) + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("defaultUser", testValue) + if vString, err := cmdFlags.GetString("defaultUser"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.DefaultUser) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) t.Run("Test_rateLimiter.name", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 62253aa88..4b900c9d9 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -2,8 +2,9 @@ package presto import ( "context" + "strings" - "github.com/lyft/flytestdlib/contextutils" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" "k8s.io/apimachinery/pkg/util/rand" @@ -24,6 +25,8 @@ import ( "github.com/lyft/flyteplugins/go/tasks/errors" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core" "github.com/lyft/flytestdlib/logger" + + pb "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" ) type ExecutionPhase int @@ -34,8 +37,6 @@ const ( PhaseSubmitted // Sent off to Presto PhaseQuerySucceeded PhaseQueryFailed - - PrincipalContextKey contextutils.Key = "principal" ) func (p ExecutionPhase) String() string { @@ -66,6 +67,9 @@ type ExecutionState struct { // This is the current Presto query (out of 5) needed to complete a Presto task CurrentPrestoQuery Query `json:"currentPrestoQuery,omitempty"` + // This is an id to keep track of the current query. Every query's id should be unique for caching purposes + CurrentPrestoQueryUUID string `json:"currentPrestoQueryUUID,omitempty"` + // Keeps track of which Presto query we are on. Its values range from 0-4 for the 5 queries that are needed QueryCount int `json:"queryCount,omitempty"` @@ -121,11 +125,12 @@ func HandleExecutionState( // If there are still Presto statements to execute, increment the query count, reset the phase to 'queued' // and continue executing the remaining statements. In this case, we won't request another allocation token // as the 5 statements that get executed are all considered to be part of the same "query" - currentState.QueryCount++ currentState.Phase = PhaseQueued + } else { + transformError = writeOutput(ctx, tCtx, currentState.CurrentPrestoQuery.ExternalLocation) } + currentState.QueryCount++ newState = currentState - transformError = nil case PhaseQueryFailed: newState = currentState @@ -284,7 +289,7 @@ func GetNextQuery( return Query{}, err } - statement = fmt.Sprintf(`CREATE TABLE hive.flyte_temporary_tables.%s_temp AS %s`, tempTableName, statement) + statement = fmt.Sprintf(`CREATE TABLE hive.flyte_temporary_tables."%s_temp" AS %s`, tempTableName, statement) prestoQuery := Query{ Statement: statement, @@ -292,7 +297,8 @@ func GetNextQuery( RoutingGroup: resolveRoutingGroup(ctx, routingGroup, prestoCfg), Catalog: catalog, Schema: schema, - User: getUser(ctx), + Source: "flyte", + User: getUser(ctx, prestoCfg.DefaultUser), }, TempTableName: tempTableName + "_temp", ExternalTableName: tempTableName + "_external", @@ -301,33 +307,35 @@ func GetNextQuery( return prestoQuery, nil case 1: - externalLocation := "TODO - use sandbox ref" + // TODO + externalLocation := getExternalLocation("s3://lyft-modelbuilder/{}/", 2) statement := fmt.Sprintf(` -CREATE TABLE hive.flyte_temporary_tables.%s (LIKE hive.flyte_temporary_tables.%s) +CREATE TABLE hive.flyte_temporary_tables."%s" (LIKE hive.flyte_temporary_tables."%s") WITH (format = 'PARQUET', external_location = '%s')`, currentState.CurrentPrestoQuery.ExternalTableName, currentState.CurrentPrestoQuery.TempTableName, externalLocation, ) currentState.CurrentPrestoQuery.Statement = statement + currentState.CurrentPrestoQuery.ExternalLocation = externalLocation return currentState.CurrentPrestoQuery, nil case 2: statement := ` -INSERT INTO hive.flyte_temporary_tables.%s +INSERT INTO hive.flyte_temporary_tables."%s" SELECT * -FROM hive.flyte_temporary_tables.%s` +FROM hive.flyte_temporary_tables."%s"` statement = fmt.Sprintf(statement, currentState.CurrentPrestoQuery.ExternalTableName, currentState.CurrentPrestoQuery.TempTableName) currentState.CurrentPrestoQuery.Statement = statement return currentState.CurrentPrestoQuery, nil case 3: - statement := fmt.Sprintf(`DROP TABLE %s`, currentState.CurrentPrestoQuery.TempTableName) + statement := fmt.Sprintf(`DROP TABLE hive.flyte_temporary_tables."%s"`, currentState.CurrentPrestoQuery.TempTableName) currentState.CurrentPrestoQuery.Statement = statement return currentState.CurrentPrestoQuery, nil case 4: - statement := fmt.Sprintf(`DROP TABLE %s`, currentState.CurrentPrestoQuery.ExternalTableName) + statement := fmt.Sprintf(`DROP TABLE hive.flyte_temporary_tables."%s"`, currentState.CurrentPrestoQuery.ExternalTableName) currentState.CurrentPrestoQuery.Statement = statement return currentState.CurrentPrestoQuery, nil @@ -336,12 +344,21 @@ FROM hive.flyte_temporary_tables.%s` } } -func getUser(ctx context.Context) string { - principalContextUser := ctx.Value(PrincipalContextKey) +func getExternalLocation(shardFormatter string, shardLength int) string { + shardCount := strings.Count(shardFormatter, "{}") + for i := 0; i < shardCount; i++ { + shardFormatter = strings.Replace(shardFormatter, "{}", rand.String(shardLength), 1) + } + + return shardFormatter + rand.String(32) + "/" +} + +func getUser(ctx context.Context, defaultUser string) string { + principalContextUser := ctx.Value("principal") if principalContextUser != nil { return fmt.Sprintf("%v", principalContextUser) } - return "" + return defaultUser } func KickOffQuery( @@ -351,7 +368,7 @@ func KickOffQuery( prestoClient client.PrestoClient, cache cache.AutoRefresh) (ExecutionState, error) { - uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + "_" + rand.String(32) statement := currentState.CurrentPrestoQuery.Statement executeArgs := currentState.CurrentPrestoQuery.ExecuteArgs @@ -369,6 +386,7 @@ func KickOffQuery( currentState.CommandID = commandID currentState.Phase = PhaseSubmitted currentState.URI = response.NextURI + currentState.CurrentPrestoQueryUUID = uniqueID executionStateCacheItem := ExecutionStateCacheItem{ ExecutionState: currentState, @@ -395,17 +413,17 @@ func MonitorQuery( currentState ExecutionState, cache cache.AutoRefresh) (ExecutionState, error) { - uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + uniqueQueryID := currentState.CurrentPrestoQueryUUID executionStateCacheItem := ExecutionStateCacheItem{ ExecutionState: currentState, - Identifier: uniqueID, + Identifier: uniqueQueryID, } - cachedItem, err := cache.GetOrCreate(uniqueID, executionStateCacheItem) + cachedItem, err := cache.GetOrCreate(uniqueQueryID, executionStateCacheItem) if err != nil { // This means that our cache has fundamentally broken... return a system error logger.Errorf(ctx, "Cache is broken on execution [%s] cache key [%s], owner [%s]. Error %s", - tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueID, + tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetID(), uniqueQueryID, tCtx.TaskExecutionMetadata().GetOwnerReference(), err) return currentState, errors.Wrapf(errors.CacheFailed, err, "Error when GetOrCreate while monitoring") } @@ -420,6 +438,32 @@ func MonitorQuery( return cachedExecutionState.ExecutionState, nil } +func writeOutput(ctx context.Context, tCtx core.TaskExecutionContext, externalLocation string) error { + taskTemplate, err := tCtx.TaskReader().Read(ctx) + if err != nil { + return err + } + + results := taskTemplate.Interface.Outputs.Variables["results"] + + return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader( + &pb.LiteralMap{ + Literals: map[string]*pb.Literal{ + "results": { + Value: &pb.Literal_Scalar{ + Scalar: &pb.Scalar{Value: &pb.Scalar_Schema{ + Schema: &pb.Schema{ + Uri: externalLocation, + Type: results.GetType().GetSchema(), + }, + }, + }, + }, + }, + }, + }, nil)) +} + // The 'PhaseInfoRunning' occurs 15 times (3 for each of the 5 Presto queries that get run for every Presto task) which // are differentiated by the version (1-15) func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { @@ -438,7 +482,7 @@ func MapExecutionStateToPhaseInfo(state ExecutionState) core.PhaseInfo { case PhaseSubmitted: phaseInfo = core.PhaseInfoRunning(uint32(3*state.QueryCount+2), ConstructTaskInfo(state)) case PhaseQuerySucceeded: - if state.QueryCount < 4 { + if state.QueryCount < 5 { phaseInfo = core.PhaseInfoRunning(uint32(3*state.QueryCount+3), ConstructTaskInfo(state)) } else { phaseInfo = core.PhaseInfoSuccess(ConstructTaskInfo(state)) diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index 01616f6d2..d36579434 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -300,7 +300,7 @@ func TestMonitorQuery(t *testing.T) { } var getOrCreateCalled = false mockCache := &mocks2.AutoRefresh{} - mockCache.OnGetOrCreateMatch("my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", mock.Anything).Return(ExecutionStateCacheItem{ + mockCache.OnGetOrCreateMatch(mock.AnythingOfType("string"), mock.Anything).Return(ExecutionStateCacheItem{ ExecutionState: ExecutionState{Phase: PhaseQuerySucceeded}, Identifier: "my_wf_exec_project:my_wf_exec_domain:my_wf_exec_name", }, nil).Run(func(_ mock.Arguments) { diff --git a/go/tasks/plugins/presto/executions_cache.go b/go/tasks/plugins/presto/executions_cache.go index 47ac57804..e56560314 100644 --- a/go/tasks/plugins/presto/executions_cache.go +++ b/go/tasks/plugins/presto/executions_cache.go @@ -2,7 +2,6 @@ package presto import ( "context" - "time" "k8s.io/client-go/util/workqueue" @@ -18,8 +17,6 @@ import ( "github.com/lyft/flytestdlib/promutils" ) -const ResyncDuration = 3 * time.Second - const ( BadPrestoReturnCodeError stdErrors.ErrorCode = "PRESTO_RETURNED_UNKNOWN" ) diff --git a/go/tasks/plugins/presto/executor.go b/go/tasks/plugins/presto/executor.go index e52ec8842..745ac7817 100644 --- a/go/tasks/plugins/presto/executor.go +++ b/go/tasks/plugins/presto/executor.go @@ -16,7 +16,7 @@ import ( ) // This is the name of this plugin effectively. In Flyte plugin configuration, use this string to enable this plugin. -const prestoExecutorID = "presto-executor" +const prestoPluginID = "presto" // Version of the custom state this plugin stores. Useful for backwards compatibility if you one day need to update // the structure of the stored state @@ -137,7 +137,7 @@ func NewPrestoExecutor( } return Executor{ - id: prestoExecutorID, + id: prestoPluginID, cfg: cfg, metrics: getPrestoExecutorMetrics(scope), prestoClient: prestoClient, @@ -148,7 +148,7 @@ func NewPrestoExecutor( func init() { pluginMachinery.PluginRegistry().RegisterCorePlugin( core.PluginEntry{ - ID: prestoExecutorID, + ID: prestoPluginID, RegisteredTaskTypes: []core.TaskType{prestoTaskType}, LoadPlugin: ExecutorLoader, IsDefault: false, diff --git a/go/tasks/plugins/presto/executor_metrics.go b/go/tasks/plugins/presto/executor_metrics.go index 235a840eb..bb1fca77b 100644 --- a/go/tasks/plugins/presto/executor_metrics.go +++ b/go/tasks/plugins/presto/executor_metrics.go @@ -21,13 +21,14 @@ var ( func getPrestoExecutorMetrics(scope promutils.Scope) ExecutorMetrics { return ExecutorMetrics{ Scope: scope, - ReleaseResourceFailed: labeled.NewCounter("released_resource_failed", - "Error releasing allocation token", scope), - AllocationGranted: labeled.NewCounter("allocation_granted", - "Allocation request granted", scope), - AllocationNotGranted: labeled.NewCounter("allocation_not_granted", - "Allocation request did not fail but not granted", scope), - ResourceWaitTime: scope.MustNewSummaryWithOptions("resource_wait_time", "Duration the execution has been waiting for a resource allocation token", + ReleaseResourceFailed: labeled.NewCounter("presto_released_resource_failed", + "Error releasing allocation token for Presto", scope), + AllocationGranted: labeled.NewCounter("presto_allocation_granted", + "Allocation request granted for Presto", scope), + AllocationNotGranted: labeled.NewCounter("presto_allocation_not_granted", + "Allocation request did not fail but not granted for Presto", scope), + ResourceWaitTime: scope.MustNewSummaryWithOptions("presto_resource_wait_time", + "Duration the execution has been waiting for a resource allocation token for Presto", promutils.SummaryOptions{Objectives: tokenAgeObjectives}), } } From a3810acb0c28da7a94d10618bc0e13081034ba8e Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Wed, 18 Mar 2020 15:51:36 -0700 Subject: [PATCH 19/26] update sync period --- go/tasks/plugins/presto/config/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go/tasks/plugins/presto/config/config.go b/go/tasks/plugins/presto/config/config.go index 37cbfe5ba..ecc092bd2 100644 --- a/go/tasks/plugins/presto/config/config.go +++ b/go/tasks/plugins/presto/config/config.go @@ -49,7 +49,7 @@ var ( RoutingGroupConfigs: []RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, RateLimiter: RateLimiter{ Name: "presto", - SyncPeriod: config.Duration{Duration: 3 * time.Second}, + SyncPeriod: config.Duration{Duration: 5 * time.Second}, Workers: 15, LruCacheSize: 2000, MetricScope: "presto", From 8a1954afb637d79f2f16fed1fcc66e0be6ad5d56 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Thu, 19 Mar 2020 18:02:41 -0700 Subject: [PATCH 20/26] add input interpolator --- go/tasks/plugins/presto/execution_state.go | 12 ++ .../presto/utils/input_interpolator.go | 137 ++++++++++++++++++ 2 files changed, 149 insertions(+) create mode 100644 go/tasks/plugins/presto/utils/input_interpolator.go diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 4b900c9d9..f1fad5796 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -4,6 +4,8 @@ import ( "context" "strings" + presto "github.com/lyft/flyteplugins/go/tasks/plugins/presto/utils" + "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" "k8s.io/apimachinery/pkg/util/rand" @@ -289,6 +291,16 @@ func GetNextQuery( return Query{}, err } + inputs, err := tCtx.InputReader().Get(ctx) + if err != nil { + return Query{}, err + } + + statement, routingGroup, catalog, schema, err = presto.InterpolateInputs(ctx, *inputs, statement, routingGroup, catalog, schema) + if err != nil { + return Query{}, err + } + statement = fmt.Sprintf(`CREATE TABLE hive.flyte_temporary_tables."%s_temp" AS %s`, tempTableName, statement) prestoQuery := Query{ diff --git a/go/tasks/plugins/presto/utils/input_interpolator.go b/go/tasks/plugins/presto/utils/input_interpolator.go new file mode 100644 index 000000000..f9d27f8c9 --- /dev/null +++ b/go/tasks/plugins/presto/utils/input_interpolator.go @@ -0,0 +1,137 @@ +package presto + +import ( + "context" + "fmt" + "reflect" + "regexp" + "strings" + + "github.com/golang/protobuf/ptypes" + pb "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" + "github.com/lyft/flytestdlib/logger" +) + +// Matches any pair of open/close mustaches that contains a variable inside (e.g. `{{ abc }}`) +var inputMustacheRegex = regexp.MustCompile(`{{\s*[^\s]+\s*}}`) + +// Matches the variable content inside a pair of double mustaches except for spaces and the mustaches themselves +var inputVarNameRegex = regexp.MustCompile(`([^{{\s}}]+)`) + +// This will interpolate any input variable mustaches that were assigned to the statement, routingGroup, catalog, and +// schema with the values of the input variables provided for the user. +// +// For example, if we have a Presto task like: +// +// presto_task = SdkPrestoTask( +// query="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ds}}' LIMIT 10", +// output_schema=schema.schema_type, +// routing_group="{{ routing_group }}", +// catalog="hive", +// schema="city", +// task_inputs=inputs(ds=Types.String, routing_group=Types.String), +// ) +// +// Then this function takes care of replace '{{ds}}' and '{{ routing_group }}' with the appropriate input values. +func InterpolateInputs( + ctx context.Context, + inputs pb.LiteralMap, + statement string, + routingGroup string, + catalog string, + schema string) (string, string, string, string, error) { + + inputsAsStrings, err := literalMapToStringMap(ctx, inputs) + if err != nil { + return "", "", "", "", err + } + + // Remove implicit inputs from the rest of the inputs used for interpolation + delete(inputsAsStrings, "implicit_routing_group") + delete(inputsAsStrings, "implicit_catalog") + delete(inputsAsStrings, "implicit_schema") + + statement = interpolate(inputsAsStrings, statement) + routingGroup = interpolate(inputsAsStrings, routingGroup) + catalog = interpolate(inputsAsStrings, catalog) + schema = interpolate(inputsAsStrings, schema) + + return statement, routingGroup, catalog, schema, nil +} + +func interpolate(inputs map[string]string, s string) string { + mustacheInputs := inputMustacheRegex.FindAllString(s, -1) + for _, inputMustache := range mustacheInputs { + inputVarName := inputVarNameRegex.FindString(inputMustache) + inputVarReplacement := inputs[inputVarName] + s = strings.Replace(s, inputMustache, inputVarReplacement, -1) + } + + return s +} + +func literalMapToStringMap(ctx context.Context, literalMap pb.LiteralMap) (map[string]string, error) { + stringMap := map[string]string{} + + for k, v := range literalMap.Literals { + serializedLiteral, err := serializeLiteral(ctx, v) + if err != nil { + return nil, err + } + stringMap[k] = serializedLiteral + } + + return stringMap, nil +} + +func serializeLiteral(ctx context.Context, l *pb.Literal) (string, error) { + switch o := l.Value.(type) { + case *pb.Literal_Collection: + res := make([]string, 0, len(o.Collection.Literals)) + for _, sub := range o.Collection.Literals { + s, err := serializeLiteral(ctx, sub) + if err != nil { + return "", err + } + + res = append(res, s) + } + + return fmt.Sprintf("[%v]", strings.Join(res, ",")), nil + case *pb.Literal_Scalar: + return serializeLiteralScalar(o.Scalar) + default: + logger.Debugf(ctx, "received unexpected primitive type") + return "", fmt.Errorf("received an unexpected primitive type [%v]", reflect.TypeOf(l.Value)) + } +} + +func serializeLiteralScalar(l *pb.Scalar) (string, error) { + switch o := l.Value.(type) { + case *pb.Scalar_Primitive: + return serializePrimitive(o.Primitive) + case *pb.Scalar_Blob: + return o.Blob.Uri, nil + default: + return "", fmt.Errorf("received an unexpected scalar type [%v]", reflect.TypeOf(l.Value)) + } +} + +func serializePrimitive(p *pb.Primitive) (string, error) { + switch o := p.Value.(type) { + case *pb.Primitive_Integer: + return fmt.Sprintf("%v", o.Integer), nil + case *pb.Primitive_Boolean: + return fmt.Sprintf("%v", o.Boolean), nil + case *pb.Primitive_Datetime: + return ptypes.TimestampString(o.Datetime), nil + case *pb.Primitive_Duration: + return o.Duration.String(), nil + case *pb.Primitive_FloatValue: + return fmt.Sprintf("%v", o.FloatValue), nil + case *pb.Primitive_StringValue: + return o.StringValue, nil + default: + return "", fmt.Errorf("received an unexpected primitive type [%v]", reflect.TypeOf(p.Value)) + } +} From 859d50f0ff6815530a01d1a64b375dc49e3e6637 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Thu, 19 Mar 2020 18:10:27 -0700 Subject: [PATCH 21/26] changes --- go/tasks/pluginmachinery/core/exec_context.go | 3 --- go/tasks/plugins/presto/execution_state.go | 20 +++++++++---------- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/go/tasks/pluginmachinery/core/exec_context.go b/go/tasks/pluginmachinery/core/exec_context.go index 0545f008f..4212320c2 100644 --- a/go/tasks/pluginmachinery/core/exec_context.go +++ b/go/tasks/pluginmachinery/core/exec_context.go @@ -57,9 +57,6 @@ type TaskExecutionContext interface { // Returns a handle to the Task events recorder, which get stored in the Admin. EventsRecorder() EventsRecorder - - //// Returns a reference to a data location - //DataLocation() io.OutputWriter } // A simple fire-and-forget func diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index f1fad5796..ba15ea126 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -226,6 +226,16 @@ func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) ( schema = prestoQuery.Schema statement = prestoQuery.Statement + inputs, err := tCtx.InputReader().Get(ctx) + if err != nil { + return "", "", "", "", err + } + + statement, routingGroup, catalog, schema, err = presto.InterpolateInputs(ctx, *inputs, statement, routingGroup, catalog, schema) + if err != nil { + return "", "", "", "", err + } + logger.Debugf(ctx, "QueryInfo: query: [%v], routingGroup: [%v], catalog: [%v], schema: [%v]", statement, routingGroup, catalog, schema) return } @@ -291,16 +301,6 @@ func GetNextQuery( return Query{}, err } - inputs, err := tCtx.InputReader().Get(ctx) - if err != nil { - return Query{}, err - } - - statement, routingGroup, catalog, schema, err = presto.InterpolateInputs(ctx, *inputs, statement, routingGroup, catalog, schema) - if err != nil { - return Query{}, err - } - statement = fmt.Sprintf(`CREATE TABLE hive.flyte_temporary_tables."%s_temp" AS %s`, tempTableName, statement) prestoQuery := Query{ From 177e7f9fac970ad35c0c77da17f0e262385ebf59 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Thu, 19 Mar 2020 22:17:02 -0700 Subject: [PATCH 22/26] changes 2 --- go/tasks/plugins/presto/execution_state.go | 23 ++++++++---------- .../plugins/presto/execution_state_test.go | 24 ------------------- go/tasks/plugins/presto/helpers_test.go | 4 ++-- .../presto/utils/input_interpolator.go | 6 ++--- 4 files changed, 15 insertions(+), 42 deletions(-) diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index ba15ea126..459913583 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -200,13 +200,7 @@ func composeResourceNamespaceWithRoutingGroup(ctx context.Context, tCtx core.Tas // This function is the link between the output written by the SDK, and the execution side. It extracts the query // out of the task template. -func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) ( - routingGroup string, - catalog string, - schema string, - statement string, - err error) { - +func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) (string, string, string, string, error) { taskTemplate, err := tCtx.TaskReader().Read(ctx) if err != nil { return "", "", "", "", err @@ -221,23 +215,23 @@ func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) ( return "", "", "", "", err } - routingGroup = prestoQuery.RoutingGroup - catalog = prestoQuery.Catalog - schema = prestoQuery.Schema - statement = prestoQuery.Statement + routingGroup := prestoQuery.RoutingGroup + catalog := prestoQuery.Catalog + schema := prestoQuery.Schema + statement := prestoQuery.Statement inputs, err := tCtx.InputReader().Get(ctx) if err != nil { return "", "", "", "", err } - statement, routingGroup, catalog, schema, err = presto.InterpolateInputs(ctx, *inputs, statement, routingGroup, catalog, schema) + routingGroup, catalog, schema, statement, err = presto.InterpolateInputs(ctx, *inputs, routingGroup, catalog, schema, statement) if err != nil { return "", "", "", "", err } logger.Debugf(ctx, "QueryInfo: query: [%v], routingGroup: [%v], catalog: [%v], schema: [%v]", statement, routingGroup, catalog, schema) - return + return routingGroup, catalog, schema, statement, err } func validatePrestoStatement(prestoJob plugins.PrestoQuery) error { @@ -380,6 +374,9 @@ func KickOffQuery( prestoClient client.PrestoClient, cache cache.AutoRefresh) (ExecutionState, error) { + // For the caching id, we can't rely simply on the task execution id since we have to run 5 consecutive queries and + // the ids used for each of these has to be unique. Because of this, we append a random postfix to the task + // execution id. uniqueID := tCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() + "_" + rand.String(32) statement := currentState.CurrentPrestoQuery.Statement diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index d36579434..2ca5ed2b1 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -22,7 +22,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" - pluginsCoreMocks "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/core/mocks" ) func init() { @@ -71,29 +70,6 @@ func TestIsNotYetSubmitted(t *testing.T) { } } -func TestGetQueryInfo(t *testing.T) { - ctx := context.Background() - - taskTemplate := GetSingleHiveQueryTaskTemplate() - mockTaskReader := &mocks.TaskReader{} - mockTaskReader.On("Read", mock.Anything).Return(&taskTemplate, nil) - - mockTaskExecutionContext := mocks.TaskExecutionContext{} - mockTaskExecutionContext.On("TaskReader").Return(mockTaskReader) - - taskMetadata := &pluginsCoreMocks.TaskExecutionMetadata{} - taskMetadata.On("GetNamespace").Return("myproject-staging") - taskMetadata.On("GetLabels").Return(map[string]string{"sample": "label"}) - mockTaskExecutionContext.On("TaskExecutionMetadata").Return(taskMetadata) - - routingGroup, catalog, schema, statement, err := GetQueryInfo(ctx, &mockTaskExecutionContext) - assert.NoError(t, err) - assert.Equal(t, "adhoc", routingGroup) - assert.Equal(t, "hive", catalog) - assert.Equal(t, "city", schema) - assert.Equal(t, "select * from hive.city.fact_airport_sessions limit 10", statement) -} - func TestValidatePrestoStatement(t *testing.T) { prestoQuery := plugins.PrestoQuery{ RoutingGroup: "adhoc", diff --git a/go/tasks/plugins/presto/helpers_test.go b/go/tasks/plugins/presto/helpers_test.go index 22509b1ed..300b78e14 100644 --- a/go/tasks/plugins/presto/helpers_test.go +++ b/go/tasks/plugins/presto/helpers_test.go @@ -16,7 +16,7 @@ import ( "k8s.io/apimachinery/pkg/types" ) -func GetSingleHiveQueryTaskTemplate() idlCore.TaskTemplate { +func GetPrestoQueryTaskTemplate() idlCore.TaskTemplate { prestoQuery := plugins.PrestoQuery{ RoutingGroup: "adhoc", Catalog: "hive", @@ -82,7 +82,7 @@ func GetMockTaskExecutionMetadata() core.TaskExecutionMetadata { } func GetMockTaskExecutionContext() core.TaskExecutionContext { - tt := GetSingleHiveQueryTaskTemplate() + tt := GetPrestoQueryTaskTemplate() dummyTaskMetadata := GetMockTaskExecutionMetadata() taskCtx := &coreMock.TaskExecutionContext{} diff --git a/go/tasks/plugins/presto/utils/input_interpolator.go b/go/tasks/plugins/presto/utils/input_interpolator.go index f9d27f8c9..8e5ea9bcd 100644 --- a/go/tasks/plugins/presto/utils/input_interpolator.go +++ b/go/tasks/plugins/presto/utils/input_interpolator.go @@ -36,10 +36,10 @@ var inputVarNameRegex = regexp.MustCompile(`([^{{\s}}]+)`) func InterpolateInputs( ctx context.Context, inputs pb.LiteralMap, - statement string, routingGroup string, catalog string, - schema string) (string, string, string, string, error) { + schema string, + statement string) (string, string, string, string, error) { inputsAsStrings, err := literalMapToStringMap(ctx, inputs) if err != nil { @@ -56,7 +56,7 @@ func InterpolateInputs( catalog = interpolate(inputsAsStrings, catalog) schema = interpolate(inputsAsStrings, schema) - return statement, routingGroup, catalog, schema, nil + return routingGroup, catalog, schema, statement, nil } func interpolate(inputs map[string]string, s string) string { From 76d0e965889564938a821517fc88d790c348e982 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Fri, 20 Mar 2020 12:00:16 -0700 Subject: [PATCH 23/26] prefix implicit inputs with __ --- go/tasks/plugins/presto/utils/input_interpolator.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/go/tasks/plugins/presto/utils/input_interpolator.go b/go/tasks/plugins/presto/utils/input_interpolator.go index 8e5ea9bcd..c3b039706 100644 --- a/go/tasks/plugins/presto/utils/input_interpolator.go +++ b/go/tasks/plugins/presto/utils/input_interpolator.go @@ -47,9 +47,9 @@ func InterpolateInputs( } // Remove implicit inputs from the rest of the inputs used for interpolation - delete(inputsAsStrings, "implicit_routing_group") - delete(inputsAsStrings, "implicit_catalog") - delete(inputsAsStrings, "implicit_schema") + delete(inputsAsStrings, "__implicit_routing_group") + delete(inputsAsStrings, "__implicit_catalog") + delete(inputsAsStrings, "__implicit_schema") statement = interpolate(inputsAsStrings, statement) routingGroup = interpolate(inputsAsStrings, routingGroup) From 7a5f2d60fbdf7d9948d1e3e96f1f0bed5a978750 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Fri, 20 Mar 2020 12:00:51 -0700 Subject: [PATCH 24/26] comments --- go/tasks/plugins/presto/client/presto_status.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/go/tasks/plugins/presto/client/presto_status.go b/go/tasks/plugins/presto/client/presto_status.go index 8979b953e..d5e772f6c 100644 --- a/go/tasks/plugins/presto/client/presto_status.go +++ b/go/tasks/plugins/presto/client/presto_status.go @@ -29,6 +29,9 @@ var PrestoStatuses = map[PrestoStatus]struct{}{ func NewPrestoStatus(ctx context.Context, state string) PrestoStatus { upperCased := strings.ToUpper(state) + + // Presto has different failure modes so this maps them all to a single Failure on the + // Flyte side if strings.Contains(upperCased, "FAILED") { return PrestoStatusFailed } else if _, ok := PrestoStatuses[PrestoStatus(upperCased)]; ok { From 78ac9ff993f7b60eac250755aabe19dfdfb48264 Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Mon, 23 Mar 2020 17:29:05 -0700 Subject: [PATCH 25/26] more feedback --- .../plugins/presto/client/presto_client.go | 6 +- go/tasks/plugins/presto/config/config.go | 12 +- .../plugins/presto/config/config_flags.go | 9 +- .../presto/config/config_flags_test.go | 72 ++++----- go/tasks/plugins/presto/execution_state.go | 22 ++- .../plugins/presto/execution_state_test.go | 3 +- go/tasks/plugins/presto/executions_cache.go | 2 +- go/tasks/plugins/presto/helpers_test.go | 1 + .../presto/utils/input_interpolator.go | 137 ------------------ 9 files changed, 50 insertions(+), 214 deletions(-) delete mode 100644 go/tasks/plugins/presto/utils/input_interpolator.go diff --git a/go/tasks/plugins/presto/client/presto_client.go b/go/tasks/plugins/presto/client/presto_client.go index 34bb32c4d..fb8812e2c 100644 --- a/go/tasks/plugins/presto/client/presto_client.go +++ b/go/tasks/plugins/presto/client/presto_client.go @@ -15,9 +15,9 @@ type PrestoExecuteArgs struct { // Representation of a response after submitting a query to Presto type PrestoExecuteResponse struct { - ID string - Status PrestoStatus - NextURI string + ID string `json:"id,omitempty"` + Status PrestoStatus `json:"status,omitempty"` + NextURI string `json:"nextUri,omitempty"` } //go:generate mockery -all -case=snake diff --git a/go/tasks/plugins/presto/config/config.go b/go/tasks/plugins/presto/config/config.go index ecc092bd2..21f2ff917 100644 --- a/go/tasks/plugins/presto/config/config.go +++ b/go/tasks/plugins/presto/config/config.go @@ -33,26 +33,24 @@ type RoutingGroupConfig struct { NamespaceScopeQuotaProportionCap float64 `json:"namespaceScopeQuotaProportionCap" pflag:",A floating point number between 0 and 1, specifying the maximum proportion of quotas allowed to allocate to a namespace in the routing group"` } -type RateLimiter struct { +type RefreshCacheConfig struct { Name string `json:"name" pflag:",The name of the rate limiter"` SyncPeriod config.Duration `json:"syncPeriod" pflag:",The duration to wait before the cache is refreshed again"` Workers int `json:"workers" pflag:",Number of parallel workers to refresh the cache"` LruCacheSize int `json:"lruCacheSize" pflag:",Size of the cache"` - MetricScope string `json:"metricScope" pflag:",The prefix in Prometheus used to track metrics related to Presto"` } var ( defaultConfig = Config{ Environment: URLMustParse(""), DefaultRoutingGroup: "adhoc", - DefaultUser: "flyte-default-user@lyft.com", + DefaultUser: "flyte-default-user", RoutingGroupConfigs: []RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, - RateLimiter: RateLimiter{ + RefreshCacheConfig: RefreshCacheConfig{ Name: "presto", SyncPeriod: config.Duration{Duration: 5 * time.Second}, Workers: 15, - LruCacheSize: 2000, - MetricScope: "presto", + LruCacheSize: 10000, }, } @@ -65,7 +63,7 @@ type Config struct { DefaultRoutingGroup string `json:"defaultRoutingGroup" pflag:",Default Presto routing group"` DefaultUser string `json:"defaultUser" pflag:",Default Presto user"` RoutingGroupConfigs []RoutingGroupConfig `json:"routingGroupConfigs" pflag:"-,A list of cluster configs. Each of the configs corresponds to a service cluster"` - RateLimiter RateLimiter `json:"rateLimiter" pflag:"Rate limiter config"` + RefreshCacheConfig RefreshCacheConfig `json:"refreshCacheConfig" pflag:"Rate limiter config"` } // Retrieves the current config value or default. diff --git a/go/tasks/plugins/presto/config/config_flags.go b/go/tasks/plugins/presto/config/config_flags.go index 1a0a2e44d..c4200c63c 100755 --- a/go/tasks/plugins/presto/config/config_flags.go +++ b/go/tasks/plugins/presto/config/config_flags.go @@ -44,10 +44,9 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.String(fmt.Sprintf("%v%v", prefix, "environment"), defaultConfig.Environment.String(), "Environment endpoint for Presto to use") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultRoutingGroup"), defaultConfig.DefaultRoutingGroup, "Default Presto routing group") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "defaultUser"), defaultConfig.DefaultUser, "Default Presto user") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "rateLimiter.name"), defaultConfig.RateLimiter.Name, "The name of the rate limiter") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "rateLimiter.syncPeriod"), defaultConfig.RateLimiter.SyncPeriod.String(), "The duration to wait before the cache is refreshed again") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "rateLimiter.workers"), defaultConfig.RateLimiter.Workers, "Number of parallel workers to refresh the cache") - cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "rateLimiter.lruCacheSize"), defaultConfig.RateLimiter.LruCacheSize, "Size of the cache") - cmdFlags.String(fmt.Sprintf("%v%v", prefix, "rateLimiter.metricScope"), defaultConfig.RateLimiter.MetricScope, "The prefix in Prometheus used to track metrics related to Presto") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.name"), defaultConfig.RefreshCacheConfig.Name, "The name of the rate limiter") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.syncPeriod"), defaultConfig.RefreshCacheConfig.SyncPeriod.String(), "The duration to wait before the cache is refreshed again") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.workers"), defaultConfig.RefreshCacheConfig.Workers, "Number of parallel workers to refresh the cache") + cmdFlags.Int(fmt.Sprintf("%v%v", prefix, "refreshCacheConfig.lruCacheSize"), defaultConfig.RefreshCacheConfig.LruCacheSize, "Size of the cache") return cmdFlags } diff --git a/go/tasks/plugins/presto/config/config_flags_test.go b/go/tasks/plugins/presto/config/config_flags_test.go index 21df51a13..00820c7be 100755 --- a/go/tasks/plugins/presto/config/config_flags_test.go +++ b/go/tasks/plugins/presto/config/config_flags_test.go @@ -165,11 +165,11 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) - t.Run("Test_rateLimiter.name", func(t *testing.T) { + t.Run("Test_refreshCacheConfig.name", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vString, err := cmdFlags.GetString("rateLimiter.name"); err == nil { - assert.Equal(t, string(defaultConfig.RateLimiter.Name), vString) + if vString, err := cmdFlags.GetString("refreshCacheConfig.name"); err == nil { + assert.Equal(t, string(defaultConfig.RefreshCacheConfig.Name), vString) } else { assert.FailNow(t, err.Error()) } @@ -178,42 +178,42 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("rateLimiter.name", testValue) - if vString, err := cmdFlags.GetString("rateLimiter.name"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RateLimiter.Name) + cmdFlags.Set("refreshCacheConfig.name", testValue) + if vString, err := cmdFlags.GetString("refreshCacheConfig.name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RefreshCacheConfig.Name) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_rateLimiter.syncPeriod", func(t *testing.T) { + t.Run("Test_refreshCacheConfig.syncPeriod", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vString, err := cmdFlags.GetString("rateLimiter.syncPeriod"); err == nil { - assert.Equal(t, string(defaultConfig.RateLimiter.SyncPeriod.String()), vString) + if vString, err := cmdFlags.GetString("refreshCacheConfig.syncPeriod"); err == nil { + assert.Equal(t, string(defaultConfig.RefreshCacheConfig.SyncPeriod.String()), vString) } else { assert.FailNow(t, err.Error()) } }) t.Run("Override", func(t *testing.T) { - testValue := defaultConfig.RateLimiter.SyncPeriod.String() + testValue := defaultConfig.RefreshCacheConfig.SyncPeriod.String() - cmdFlags.Set("rateLimiter.syncPeriod", testValue) - if vString, err := cmdFlags.GetString("rateLimiter.syncPeriod"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RateLimiter.SyncPeriod) + cmdFlags.Set("refreshCacheConfig.syncPeriod", testValue) + if vString, err := cmdFlags.GetString("refreshCacheConfig.syncPeriod"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RefreshCacheConfig.SyncPeriod) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_rateLimiter.workers", func(t *testing.T) { + t.Run("Test_refreshCacheConfig.workers", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("rateLimiter.workers"); err == nil { - assert.Equal(t, int(defaultConfig.RateLimiter.Workers), vInt) + if vInt, err := cmdFlags.GetInt("refreshCacheConfig.workers"); err == nil { + assert.Equal(t, int(defaultConfig.RefreshCacheConfig.Workers), vInt) } else { assert.FailNow(t, err.Error()) } @@ -222,20 +222,20 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("rateLimiter.workers", testValue) - if vInt, err := cmdFlags.GetInt("rateLimiter.workers"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.RateLimiter.Workers) + cmdFlags.Set("refreshCacheConfig.workers", testValue) + if vInt, err := cmdFlags.GetInt("refreshCacheConfig.workers"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.RefreshCacheConfig.Workers) } else { assert.FailNow(t, err.Error()) } }) }) - t.Run("Test_rateLimiter.lruCacheSize", func(t *testing.T) { + t.Run("Test_refreshCacheConfig.lruCacheSize", func(t *testing.T) { t.Run("DefaultValue", func(t *testing.T) { // Test that default value is set properly - if vInt, err := cmdFlags.GetInt("rateLimiter.lruCacheSize"); err == nil { - assert.Equal(t, int(defaultConfig.RateLimiter.LruCacheSize), vInt) + if vInt, err := cmdFlags.GetInt("refreshCacheConfig.lruCacheSize"); err == nil { + assert.Equal(t, int(defaultConfig.RefreshCacheConfig.LruCacheSize), vInt) } else { assert.FailNow(t, err.Error()) } @@ -244,31 +244,9 @@ func TestConfig_SetFlags(t *testing.T) { t.Run("Override", func(t *testing.T) { testValue := "1" - cmdFlags.Set("rateLimiter.lruCacheSize", testValue) - if vInt, err := cmdFlags.GetInt("rateLimiter.lruCacheSize"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.RateLimiter.LruCacheSize) - - } else { - assert.FailNow(t, err.Error()) - } - }) - }) - t.Run("Test_rateLimiter.metricScope", func(t *testing.T) { - t.Run("DefaultValue", func(t *testing.T) { - // Test that default value is set properly - if vString, err := cmdFlags.GetString("rateLimiter.metricScope"); err == nil { - assert.Equal(t, string(defaultConfig.RateLimiter.MetricScope), vString) - } else { - assert.FailNow(t, err.Error()) - } - }) - - t.Run("Override", func(t *testing.T) { - testValue := "1" - - cmdFlags.Set("rateLimiter.metricScope", testValue) - if vString, err := cmdFlags.GetString("rateLimiter.metricScope"); err == nil { - testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RateLimiter.MetricScope) + cmdFlags.Set("refreshCacheConfig.lruCacheSize", testValue) + if vInt, err := cmdFlags.GetInt("refreshCacheConfig.lruCacheSize"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vInt), &actual.RefreshCacheConfig.LruCacheSize) } else { assert.FailNow(t, err.Error()) diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 459913583..9dd0052c3 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -4,8 +4,6 @@ import ( "context" "strings" - presto "github.com/lyft/flyteplugins/go/tasks/plugins/presto/utils" - "github.com/lyft/flyteplugins/go/tasks/pluginmachinery/ioutils" "k8s.io/apimachinery/pkg/util/rand" @@ -215,20 +213,20 @@ func GetQueryInfo(ctx context.Context, tCtx core.TaskExecutionContext) (string, return "", "", "", "", err } - routingGroup := prestoQuery.RoutingGroup - catalog := prestoQuery.Catalog - schema := prestoQuery.Schema - statement := prestoQuery.Statement - - inputs, err := tCtx.InputReader().Get(ctx) + outputs, err := utils.ReplaceTemplateCommandArgs(ctx, []string{ + prestoQuery.RoutingGroup, + prestoQuery.Catalog, + prestoQuery.Schema, + prestoQuery.Statement, + }, tCtx.InputReader(), tCtx.OutputWriter()) if err != nil { return "", "", "", "", err } - routingGroup, catalog, schema, statement, err = presto.InterpolateInputs(ctx, *inputs, routingGroup, catalog, schema, statement) - if err != nil { - return "", "", "", "", err - } + routingGroup := outputs[0] + catalog := outputs[1] + schema := outputs[2] + statement := outputs[3] logger.Debugf(ctx, "QueryInfo: query: [%v], routingGroup: [%v], catalog: [%v], schema: [%v]", statement, routingGroup, catalog, schema) return routingGroup, catalog, schema, statement, err diff --git a/go/tasks/plugins/presto/execution_state_test.go b/go/tasks/plugins/presto/execution_state_test.go index 2ca5ed2b1..2f0facb68 100644 --- a/go/tasks/plugins/presto/execution_state_test.go +++ b/go/tasks/plugins/presto/execution_state_test.go @@ -324,12 +324,11 @@ func createMockPrestoCfg() *config.Config { Environment: config.URLMustParse(""), DefaultRoutingGroup: "adhoc", RoutingGroupConfigs: []config.RoutingGroupConfig{{Name: "adhoc", Limit: 250}, {Name: "etl", Limit: 100}}, - RateLimiter: config.RateLimiter{ + RefreshCacheConfig: config.RefreshCacheConfig{ Name: "presto", SyncPeriod: stdConfig.Duration{Duration: 3 * time.Second}, Workers: 15, LruCacheSize: 2000, - MetricScope: "presto", }, } } diff --git a/go/tasks/plugins/presto/executions_cache.go b/go/tasks/plugins/presto/executions_cache.go index e56560314..a26c6b3a2 100644 --- a/go/tasks/plugins/presto/executions_cache.go +++ b/go/tasks/plugins/presto/executions_cache.go @@ -39,7 +39,7 @@ func NewPrestoExecutionsCache( scope: scope, cfg: cfg, } - autoRefreshCache, err := cache.NewAutoRefreshCache(cfg.RateLimiter.Name, q.SyncPrestoQuery, workqueue.DefaultControllerRateLimiter(), cfg.RateLimiter.SyncPeriod.Duration, cfg.RateLimiter.Workers, cfg.RateLimiter.LruCacheSize, scope.NewSubScope(cfg.RateLimiter.MetricScope)) + autoRefreshCache, err := cache.NewAutoRefreshCache(cfg.RefreshCacheConfig.Name, q.SyncPrestoQuery, workqueue.DefaultControllerRateLimiter(), cfg.RefreshCacheConfig.SyncPeriod.Duration, cfg.RefreshCacheConfig.Workers, cfg.RefreshCacheConfig.LruCacheSize, scope) if err != nil { logger.Errorf(ctx, "Could not create AutoRefreshCache in Executor. [%s]", err) return q, errors.Wrapf(errors.CacheFailed, err, "Error creating AutoRefreshCache") diff --git a/go/tasks/plugins/presto/helpers_test.go b/go/tasks/plugins/presto/helpers_test.go index 300b78e14..f5bc6a235 100644 --- a/go/tasks/plugins/presto/helpers_test.go +++ b/go/tasks/plugins/presto/helpers_test.go @@ -89,6 +89,7 @@ func GetMockTaskExecutionContext() core.TaskExecutionContext { inputReader := &ioMock.InputReader{} inputReader.On("GetInputPath").Return(storage.DataReference("test-data-reference")) inputReader.On("Get", mock.Anything).Return(&idlCore.LiteralMap{}, nil) + inputReader.On("GetInputPrefixPath").Return(storage.DataReference("/data")) taskCtx.On("InputReader").Return(inputReader) outputReader := &ioMock.OutputWriter{} diff --git a/go/tasks/plugins/presto/utils/input_interpolator.go b/go/tasks/plugins/presto/utils/input_interpolator.go deleted file mode 100644 index c3b039706..000000000 --- a/go/tasks/plugins/presto/utils/input_interpolator.go +++ /dev/null @@ -1,137 +0,0 @@ -package presto - -import ( - "context" - "fmt" - "reflect" - "regexp" - "strings" - - "github.com/golang/protobuf/ptypes" - pb "github.com/lyft/flyteidl/gen/pb-go/flyteidl/core" - "github.com/lyft/flytestdlib/logger" -) - -// Matches any pair of open/close mustaches that contains a variable inside (e.g. `{{ abc }}`) -var inputMustacheRegex = regexp.MustCompile(`{{\s*[^\s]+\s*}}`) - -// Matches the variable content inside a pair of double mustaches except for spaces and the mustaches themselves -var inputVarNameRegex = regexp.MustCompile(`([^{{\s}}]+)`) - -// This will interpolate any input variable mustaches that were assigned to the statement, routingGroup, catalog, and -// schema with the values of the input variables provided for the user. -// -// For example, if we have a Presto task like: -// -// presto_task = SdkPrestoTask( -// query="SELECT * FROM hive.city.fact_airport_sessions WHERE ds = '{{ds}}' LIMIT 10", -// output_schema=schema.schema_type, -// routing_group="{{ routing_group }}", -// catalog="hive", -// schema="city", -// task_inputs=inputs(ds=Types.String, routing_group=Types.String), -// ) -// -// Then this function takes care of replace '{{ds}}' and '{{ routing_group }}' with the appropriate input values. -func InterpolateInputs( - ctx context.Context, - inputs pb.LiteralMap, - routingGroup string, - catalog string, - schema string, - statement string) (string, string, string, string, error) { - - inputsAsStrings, err := literalMapToStringMap(ctx, inputs) - if err != nil { - return "", "", "", "", err - } - - // Remove implicit inputs from the rest of the inputs used for interpolation - delete(inputsAsStrings, "__implicit_routing_group") - delete(inputsAsStrings, "__implicit_catalog") - delete(inputsAsStrings, "__implicit_schema") - - statement = interpolate(inputsAsStrings, statement) - routingGroup = interpolate(inputsAsStrings, routingGroup) - catalog = interpolate(inputsAsStrings, catalog) - schema = interpolate(inputsAsStrings, schema) - - return routingGroup, catalog, schema, statement, nil -} - -func interpolate(inputs map[string]string, s string) string { - mustacheInputs := inputMustacheRegex.FindAllString(s, -1) - for _, inputMustache := range mustacheInputs { - inputVarName := inputVarNameRegex.FindString(inputMustache) - inputVarReplacement := inputs[inputVarName] - s = strings.Replace(s, inputMustache, inputVarReplacement, -1) - } - - return s -} - -func literalMapToStringMap(ctx context.Context, literalMap pb.LiteralMap) (map[string]string, error) { - stringMap := map[string]string{} - - for k, v := range literalMap.Literals { - serializedLiteral, err := serializeLiteral(ctx, v) - if err != nil { - return nil, err - } - stringMap[k] = serializedLiteral - } - - return stringMap, nil -} - -func serializeLiteral(ctx context.Context, l *pb.Literal) (string, error) { - switch o := l.Value.(type) { - case *pb.Literal_Collection: - res := make([]string, 0, len(o.Collection.Literals)) - for _, sub := range o.Collection.Literals { - s, err := serializeLiteral(ctx, sub) - if err != nil { - return "", err - } - - res = append(res, s) - } - - return fmt.Sprintf("[%v]", strings.Join(res, ",")), nil - case *pb.Literal_Scalar: - return serializeLiteralScalar(o.Scalar) - default: - logger.Debugf(ctx, "received unexpected primitive type") - return "", fmt.Errorf("received an unexpected primitive type [%v]", reflect.TypeOf(l.Value)) - } -} - -func serializeLiteralScalar(l *pb.Scalar) (string, error) { - switch o := l.Value.(type) { - case *pb.Scalar_Primitive: - return serializePrimitive(o.Primitive) - case *pb.Scalar_Blob: - return o.Blob.Uri, nil - default: - return "", fmt.Errorf("received an unexpected scalar type [%v]", reflect.TypeOf(l.Value)) - } -} - -func serializePrimitive(p *pb.Primitive) (string, error) { - switch o := p.Value.(type) { - case *pb.Primitive_Integer: - return fmt.Sprintf("%v", o.Integer), nil - case *pb.Primitive_Boolean: - return fmt.Sprintf("%v", o.Boolean), nil - case *pb.Primitive_Datetime: - return ptypes.TimestampString(o.Datetime), nil - case *pb.Primitive_Duration: - return o.Duration.String(), nil - case *pb.Primitive_FloatValue: - return fmt.Sprintf("%v", o.FloatValue), nil - case *pb.Primitive_StringValue: - return o.StringValue, nil - default: - return "", fmt.Errorf("received an unexpected primitive type [%v]", reflect.TypeOf(p.Value)) - } -} From 3e87ce4ef7e3a7d7ca8baf5d7e5a7c4a5185478d Mon Sep 17 00:00:00 2001 From: lu4nm3 Date: Tue, 24 Mar 2020 13:13:30 -0700 Subject: [PATCH 26/26] edit metrics --- go/tasks/plugins/presto/execution_state.go | 2 -- go/tasks/plugins/presto/executor_metrics.go | 9 --------- 2 files changed, 11 deletions(-) diff --git a/go/tasks/plugins/presto/execution_state.go b/go/tasks/plugins/presto/execution_state.go index 9dd0052c3..6217b4b21 100644 --- a/go/tasks/plugins/presto/execution_state.go +++ b/go/tasks/plugins/presto/execution_state.go @@ -170,8 +170,6 @@ func GetAllocationToken( } else { newState.AllocationTokenRequestStartTime = currentState.AllocationTokenRequestStartTime } - waitTime := time.Since(newState.AllocationTokenRequestStartTime) - metric.ResourceWaitTime.Observe(waitTime.Seconds()) if allocationStatus == core.AllocationStatusGranted { newState.Phase = PhaseQueued diff --git a/go/tasks/plugins/presto/executor_metrics.go b/go/tasks/plugins/presto/executor_metrics.go index bb1fca77b..69538a757 100644 --- a/go/tasks/plugins/presto/executor_metrics.go +++ b/go/tasks/plugins/presto/executor_metrics.go @@ -3,7 +3,6 @@ package presto import ( "github.com/lyft/flytestdlib/promutils" "github.com/lyft/flytestdlib/promutils/labeled" - "github.com/prometheus/client_golang/prometheus" ) type ExecutorMetrics struct { @@ -11,13 +10,8 @@ type ExecutorMetrics struct { ReleaseResourceFailed labeled.Counter AllocationGranted labeled.Counter AllocationNotGranted labeled.Counter - ResourceWaitTime prometheus.Summary } -var ( - tokenAgeObjectives = map[float64]float64{0.5: 0.05, 0.9: 0.01, 0.99: 0.001, 1.0: 0.0} -) - func getPrestoExecutorMetrics(scope promutils.Scope) ExecutorMetrics { return ExecutorMetrics{ Scope: scope, @@ -27,8 +21,5 @@ func getPrestoExecutorMetrics(scope promutils.Scope) ExecutorMetrics { "Allocation request granted for Presto", scope), AllocationNotGranted: labeled.NewCounter("presto_allocation_not_granted", "Allocation request did not fail but not granted for Presto", scope), - ResourceWaitTime: scope.MustNewSummaryWithOptions("presto_resource_wait_time", - "Duration the execution has been waiting for a resource allocation token for Presto", - promutils.SummaryOptions{Objectives: tokenAgeObjectives}), } }