WithParams #2044

Merged: 20 commits merged on Sep 17, 2019
Changes from 13 commits
137 changes: 73 additions & 64 deletions sdk/python/kfp/compiler/compiler.py
@@ -17,7 +17,7 @@
import inspect
import tarfile
import zipfile
from typing import Set, List, Text, Dict
from typing import Set, List, Text, Dict, Tuple, Any, Union, Optional

import yaml
from kfp.dsl import _container_op, _for_loop
@@ -339,17 +339,16 @@ def _get_dependencies(self, pipeline, root_group, op_groups, opsgroups_groups, o
upstream_op_names.add(param.op_name)
upstream_op_names |= set(op.dependent_names)

for op_name in upstream_op_names:
for upstream_op_name in upstream_op_names:
# the dependent op could be either a BaseOp or an opsgroup
if op_name in pipeline.ops:
upstream_op = pipeline.ops[op_name]
elif op_name in opsgroups:
upstream_op = opsgroups[op_name]
if upstream_op_name in pipeline.ops:
upstream_op = pipeline.ops[upstream_op_name]
elif upstream_op_name in opsgroups:
upstream_op = opsgroups[upstream_op_name]
else:
raise ValueError('compiler cannot find the ' + op_name)
raise ValueError('compiler cannot find the ' + upstream_op_name)

upstream_groups, downstream_groups = \
self._get_uncommon_ancestors(op_groups, opsgroups_groups, upstream_op, op)
upstream_groups, downstream_groups = self._get_uncommon_ancestors(op_groups, opsgroups_groups, upstream_op, op)
dependencies[downstream_groups[0]].add(upstream_groups[0])

# Generate dependencies based on the recursive opsgroups
@@ -460,62 +459,32 @@ def _group_to_dag_template(self, group, inputs, outputs, dependencies):

# Generate arguments section for this task.
if inputs.get(sub_group.name, None):
arguments = []
for param_name, dependent_name in inputs[sub_group.name]:
if dependent_name:
# The value comes from an upstream sibling.
# Special handling for recursive subgroup: argument name comes from the existing opsgroup
if is_recursive_subgroup:
for index, input in enumerate(sub_group.inputs):
if param_name == self._pipelineparam_full_name(input):
break
referenced_input = sub_group.recursive_ref.inputs[index]
full_name = self._pipelineparam_full_name(referenced_input)
arguments.append({
'name': full_name,
'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
})
else:
arguments.append({
'name': param_name,
'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)
})
task['arguments'] = {'parameters': self.get_arguments_for_sub_group(sub_group, is_recursive_subgroup, inputs)}

# additional task modifications for withItems and withParam
if isinstance(sub_group, dsl.ParallelFor):
if sub_group.items_is_pipeline_param:
# these loop args are a 'withParam' rather than 'withItems'.
# i.e., rather than a static list, they are either the output of another task or were input
# as global pipeline parameters

pipeline_param = sub_group.loop_args
if pipeline_param.op_name is None:
withparam_value = '{{workflow.parameters.%s}}' % pipeline_param.name
else:
# The value comes from its parent.
# Special handling for recursive subgroup: argument name comes from the existing opsgroup
if is_recursive_subgroup:
for index, input in enumerate(sub_group.inputs):
if param_name == self._pipelineparam_full_name(input):
break
referenced_input = sub_group.recursive_ref.inputs[index]
full_name = self._pipelineparam_full_name(referenced_input)
arguments.append({
'name': full_name,
'value': '{{inputs.parameters.%s}}' % param_name
})
else:
if isinstance(sub_group, dsl.ParallelFor):
if sub_group.loop_args.name in param_name:
if _for_loop.LoopArgumentVariable.name_is_loop_arguments_variable(param_name):
subvar_name = _for_loop.LoopArgumentVariable.get_subvar_name(param_name)
value = '{{item.%s}}' % subvar_name
elif _for_loop.LoopArguments.name_is_loop_arguments(param_name):
value = '{{item}}'
else:
raise ValueError("Failed to match loop args with param. param_name: {}, ".format(param_name) +
"sub_group.loop_args.name: {}.".format(sub_group.loop_args.name))
else:
value = '{{inputs.parameters.%s}}' % param_name
task['withItems'] = sub_group.loop_args.to_list_for_task_yaml()
else:
value = '{{inputs.parameters.%s}}' % param_name
arguments.append({
'name': param_name,
'value': value,
})
arguments.sort(key=lambda x: x['name'])
task['arguments'] = {'parameters': arguments}

param_name = '%s-%s' % (pipeline_param.op_name, pipeline_param.name)
withparam_value = '{{tasks.%s.outputs.parameters.%s}}' % (pipeline_param.op_name, param_name)

# these loop args are the output of another task
if 'dependencies' not in task or task['dependencies'] is None:
task['dependencies'] = []
if pipeline_param.op_name not in task['dependencies']:
task['dependencies'].append(pipeline_param.op_name)

task['withParam'] = withparam_value
else:
task['withItems'] = sub_group.loop_args.to_list_for_task_yaml()

if isinstance(sub_group, dsl.ContainerOp) and sub_group.artifact_arguments:
artifact_argument_structs = []
for input_name, argument in sub_group.artifact_arguments.items():
@@ -532,6 +501,46 @@ def _group_to_dag_template(self, group, inputs, outputs, dependencies):
template['dag'] = {'tasks': tasks}
return template

def get_arguments_for_sub_group(
self,
sub_group: Union[OpsGroup, dsl._container_op.BaseOp],
is_recursive_subgroup: Optional[bool],
inputs: Dict[Text, Tuple[Text, Text]],
):
arguments = []
for param_name, dependent_name in inputs[sub_group.name]:
if is_recursive_subgroup:
for index, input in enumerate(sub_group.inputs):
if param_name == self._pipelineparam_full_name(input):
break
referenced_input = sub_group.recursive_ref.inputs[index]
argument_name = self._pipelineparam_full_name(referenced_input)
else:
argument_name = param_name

# default argument_value + special cases
argument_value = '{{inputs.parameters.%s}}' % param_name
if isinstance(sub_group, dsl.ParallelFor):
if sub_group.loop_args.name in param_name:
if _for_loop.LoopArgumentVariable.name_is_loop_arguments_variable(param_name):
subvar_name = _for_loop.LoopArgumentVariable.get_subvar_name(param_name)
argument_value = '{{item.%s}}' % subvar_name
elif _for_loop.LoopArguments.name_is_loop_arguments(param_name) or sub_group.items_is_pipeline_param:
argument_value = '{{item}}'
else:
raise ValueError("Failed to match loop args with parameter. param_name: {}, ".format(param_name))
elif dependent_name:
argument_value = '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name)

arguments.append({
'name': argument_name,
'value': argument_value,
})

arguments.sort(key=lambda x: x['name'])

return arguments

def _create_dag_templates(self, pipeline, op_transformers=None, op_to_templates_handler=None):
"""Create all groups and ops templates in the pipeline.

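Note: per the withParam branch above, the compiler is expected to emit loop-group task entries of roughly the following shapes. This is a sketch with hypothetical op and parameter names, written as the Python dicts the compiler builds before YAML serialization.

# Loop items come from a global pipeline parameter (pipeline_param.op_name is None):
loop_task_global = {
    'name': 'for-loop-1',      # placeholder; the real name derives from the generated group name
    'template': 'for-loop-1',
    'withParam': '{{workflow.parameters.my-list-param}}',
}

# Loop items come from another task's output parameter:
loop_task_from_output = {
    'name': 'for-loop-2',      # placeholder
    'template': 'for-loop-2',
    'dependencies': ['produce-list'],  # appended so Argo runs the producer first
    'withParam': '{{tasks.produce-list.outputs.parameters.produce-list-out}}',
}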
2 changes: 1 addition & 1 deletion sdk/python/kfp/dsl/_container_op.py
@@ -970,7 +970,7 @@ def __init__(
container_kwargs: Dict = None,
artifact_argument_paths: List[InputArgumentPath] = None,
file_outputs: Dict[str, str] = None,
output_artifact_paths : Dict[str, str]=None,
output_artifact_paths: Dict[str, str]=None,
artifact_location: V1alpha1ArtifactLocation=None,
is_exit_handler=False,
pvolumes: Dict[str, V1Volume] = None,
30 changes: 22 additions & 8 deletions sdk/python/kfp/dsl/_for_loop.py
@@ -1,5 +1,5 @@
import re
from typing import List, Union, Dict, Text, Any, Tuple
from typing import List, Union, Dict, Text, Any, Tuple, Optional

from kfp import dsl

@@ -19,7 +19,7 @@ class LoopArguments(dsl.PipelineParam):
def _subvar_name_is_legal(cls, proposed_variable_name: Text):
return re.match(cls.LEGAL_SUBVAR_NAME_REGEX, proposed_variable_name) is not None

def __init__(self, items: ItemList, code: Text):
def __init__(self, items: Union[ItemList, dsl.PipelineParam], code: Text, name_override: Optional[Text]=None, op_name: Optional[Text]=None):
"""_LoopArguments represent the set of items to loop over in a ParallelFor loop. This class shoudn't be
instantiated by the user but rather is created by _ops_group.ParallelFor.

@@ -29,12 +29,15 @@ def __init__(self, items: ItemList, code: Text):
code: A unique code used to identify these loop arguments. Should match the code for the ParallelFor
ops_group which created these _LoopArguments. This prevents parameter name collisions.
"""
super().__init__(name=self._make_name(code))
if name_override is None:
super().__init__(name=self._make_name(code))
else:
super().__init__(name=name_override, op_name=op_name)

if not isinstance(items, (list, tuple)):
raise TypeError("Expected list or tuple, got {}.".format(type(items)))
if not isinstance(items, (list, tuple, dsl.PipelineParam)):
raise TypeError("Expected list, tuple, or PipelineParam, got {}.".format(type(items)))

if isinstance(items[0], dict):
if isinstance(items, list) and isinstance(items[0], dict):
subvar_names = set(items[0].keys())
Contributor comment: How does Argo resolve {{item.a}} when the key a is missing?

for item in items:
if not set(item.keys()) == subvar_names:
@@ -48,10 +51,21 @@ def __init__(self, items: ItemList, code: Text):
"name.".format(subvar_name))
setattr(self, subvar_name, LoopArgumentVariable(self.name, subvar_name))

self.items = items
self.items_or_pipeline_param = items
self.referenced_subvar_names = []

def __getattr__(self, item):
# this is being overridden so that we can access subvariables of the LoopArguments (i.e.: item.a) without
# knowing the subvariable names ahead of time
self.referenced_subvar_names.append(item)
return LoopArgumentVariable(self.name, item)

def to_list_for_task_yaml(self):
return self.items
if isinstance(self.items_or_pipeline_param, list):
return self.items_or_pipeline_param
else:
raise ValueError("You should only call this method on loop args which have list items, "
"not pipeline param items.")

@classmethod
def _make_name(cls, code: Text):
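Note: the __getattr__ override above exists so callers can reference item keys that are not known at authoring time, which is exactly the withParam case where the list only materializes at runtime. A toy illustration (not part of the PR; LoopArguments is normally created by ParallelFor rather than instantiated directly):

from kfp.dsl import _for_loop

args = _for_loop.LoopArguments(items=[{'a': 1, 'b': 2}], code='1')

known = args.a    # 'a' appears in the items, so it was bound as an attribute in __init__
unknown = args.c  # 'c' was never declared; __getattr__ records the access and
                  # still returns a LoopArgumentVariable, so '{{item.c}}' can be
                  # rendered even though the key is only known at runtime
print(args.referenced_subvar_names)  # ['c']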
15 changes: 10 additions & 5 deletions sdk/python/kfp/dsl/_ops_group.py
@@ -14,11 +14,11 @@
from typing import Union
import uuid

from kfp.dsl import _for_loop
from kfp.dsl import _for_loop, _pipeline_param

from . import _container_op
from . import _pipeline
from ._pipeline_param import ConditionOperator


class OpsGroup(object):
"""Represents a logical group of ops and group of OpsGroups.
@@ -93,6 +93,7 @@ def after(self, dependency):
self.dependencies.append(dependency)
return self


class ExitHandler(OpsGroup):
"""Represents an exit handler that is invoked upon exiting a group of ops.

@@ -168,13 +169,17 @@ class ParallelFor(OpsGroup):
def _get_unique_id_code():
return uuid.uuid4().hex[:_for_loop.LoopArguments.NUM_CODE_CHARS]

def __init__(self, loop_args: _for_loop.ItemList):
# random code to id this loop
def __init__(self, loop_args: Union[_for_loop.ItemList, _pipeline_param.PipelineParam]):
self.items_is_pipeline_param = isinstance(loop_args, _pipeline_param.PipelineParam)

# use a random code to uniquely identify this loop
code = self._get_unique_id_code()
group_name = 'for-loop-{}'.format(code)
super().__init__(self.TYPE_NAME, name=group_name)

if not isinstance(loop_args, _for_loop.LoopArguments):
self.items_is_pipeline_param = isinstance(loop_args, _pipeline_param.PipelineParam)
if not self.items_is_pipeline_param and not isinstance(loop_args, _for_loop.LoopArguments):
# we were passed a raw list, wrap it in loop args
loop_args = _for_loop.LoopArguments(loop_args, code)

self.loop_args = loop_args
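Note: since ParallelFor now also accepts a PipelineParam, a loop can be driven by a pipeline-level argument instead of a static list. A minimal sketch with hypothetical names, assuming the argument holds a JSON-encoded list at runtime:

import kfp.dsl as dsl


@dsl.pipeline(name='withparam-global-sketch')
def pipeline(loopidy_doop=[3, 5, 7, 9]):
    # Inside the pipeline function, loopidy_doop is a PipelineParam, so
    # items_is_pipeline_param is True and the loop compiles to an Argo
    # withParam over the workflow parameter rather than a static withItems list.
    with dsl.ParallelFor(loopidy_doop) as item:
        dsl.ContainerOp(
            name='my-in-cop1',
            image='library/bash:4.4.23',
            command=['sh', '-c'],
            arguments=['echo global op1 item: %s' % item],
        )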
8 changes: 7 additions & 1 deletion sdk/python/tests/compiler/compiler_tests.py
@@ -670,7 +670,6 @@ def init_container_pipeline():
init_container = init_containers[0]
self.assertEqual(init_container, {'image':'alpine:latest', 'command': ['echo', 'bye'], 'name': 'echo'})


def test_delete_resource_op(self):
"""Test a pipeline with a delete resource operation."""
from kubernetes import client as k8s
@@ -703,6 +702,13 @@ def some_pipeline():
self.assertIsNone(delete_op_template.get("failureCondition"))
self.assertDictEqual(delete_op_template.get("outputs"), {})

def test_withparam_global(self):
self._test_py_compile_yaml('withparam_global')

def test_withparam_output(self):
self._test_py_compile_yaml('withparam_output')

def test_py_input_artifact_raw_value(self):
"""Test pipeline input_artifact_raw_value."""
self._test_py_compile_yaml('input_artifact_raw_value')

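Note: the two new tests compile testdata samples (withparam_global.py and withparam_output.py) that are not shown in this diff. A hedged sketch of what the output-driven sample might look like; op names, images, and commands are illustrative:

import kfp.dsl as dsl


@dsl.pipeline(name='my-pipeline')
def pipeline():
    # Writes a JSON list to its single file output, which becomes an output parameter.
    op0 = dsl.ContainerOp(
        name='my-out-cop0',
        image='python:alpine3.6',
        command=['sh', '-c'],
        arguments=['python -c "import json; json.dump([i for i in range(20, 31)], '
                   'open(\'/tmp/out.json\', \'w\'))"'],
        file_outputs={'out': '/tmp/out.json'},
    )

    # op0.output is a PipelineParam with op_name set, so the loop group compiles
    # to withParam: '{{tasks.my-out-cop0.outputs.parameters.my-out-cop0-out}}'
    # and gains a dependency on my-out-cop0.
    with dsl.ParallelFor(op0.output) as item:
        dsl.ContainerOp(
            name='my-in-cop1',
            image='library/bash:4.4.23',
            command=['sh', '-c'],
            arguments=['echo do output op1 item: %s' % item],
        )


if __name__ == '__main__':
    import kfp.compiler as compiler
    compiler.Compiler().compile(pipeline, __file__ + '.yaml')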