Skip to content

Commit

Permalink
fix(sdk.v2): Fix a couple of ParallelFor related bugs. Fixes #6383, f…
Browse files Browse the repository at this point in the history
…ixes #6628 (#6643)

* fix a couple of loop related bugs

* add release note
  • Loading branch information
chensun authored Sep 29, 2021
1 parent 5c89d51 commit b466f59
Show file tree
Hide file tree
Showing 16 changed files with 1,833 additions and 1,102 deletions.
2 changes: 2 additions & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

## Bug Fixes and Other Changes

* Fix a couple of bugs that affect nested loops and conditions in v2. [\#6643](https://github.com/kubeflow/pipelines/pull/6643)

## Documentation Updates

# 1.8.3
Expand Down
75 changes: 51 additions & 24 deletions sdk/python/kfp/dsl/_for_loop.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import re
from typing import List, Union, Dict, Text, Any, Tuple, Optional
from typing import Any, Dict, List, Optional, Tuple, Union

from kfp import dsl
from kfp.dsl import _pipeline_param

ItemList = List[Union[int, float, str, Dict[Text, Any]]]
ItemList = List[Union[int, float, str, Dict[str, Any]]]


class LoopArguments(dsl.PipelineParam):
Expand All @@ -20,15 +21,15 @@ class LoopArguments(dsl.PipelineParam):
LEGAL_SUBVAR_NAME_REGEX = re.compile(r'[a-zA-Z_][0-9a-zA-Z_]*')

@classmethod
def _subvar_name_is_legal(cls, proposed_variable_name: Text):
def _subvar_name_is_legal(cls, proposed_variable_name: str):
return re.match(cls.LEGAL_SUBVAR_NAME_REGEX,
proposed_variable_name) is not None

def __init__(self,
items: Union[ItemList, dsl.PipelineParam],
code: Text,
name_override: Optional[Text] = None,
op_name: Optional[Text] = None,
code: str,
name_override: Optional[str] = None,
op_name: Optional[str] = None,
*args,
**kwargs):
"""LoopArguments represent the set of items to loop over in a
Expand All @@ -43,7 +44,7 @@ def __init__(self,
variable name.
code: A unique code used to identify these loop arguments. Should
match the code for the ParallelFor ops_group which created these
_LoopArguments. This prevents parameter name collisions.
LoopArguments. This prevents parameter name collisions.
"""
if name_override is None:
super().__init__(name=self._make_name(code), *args, **kwargs)
Expand Down Expand Up @@ -76,7 +77,11 @@ def __init__(self,
setattr(
self, subvar_name,
LoopArgumentVariable(
self.name, subvar_name, loop_args_op_name=self.op_name))
loop_args_name=self.name,
this_variable_name=subvar_name,
loop_args_op_name=self.op_name,
loop_args=self,
))

self.items_or_pipeline_param = items
self.referenced_subvar_names = []
Expand All @@ -97,7 +102,11 @@ def __getattr__(self, item):
# of time
self.referenced_subvar_names.append(item)
return LoopArgumentVariable(
self.name, item, loop_args_op_name=self.op_name)
loop_args_name=self.name,
this_variable_name=item,
loop_args_op_name=self.op_name,
loop_args=self,
)

def to_list_for_task_yaml(self):
if isinstance(self.items_or_pipeline_param, (list, tuple)):
Expand All @@ -108,15 +117,15 @@ def to_list_for_task_yaml(self):
'not pipeline param items.')

@classmethod
def _make_name(cls, code: Text):
def _make_name(cls, code: str):
"""Make a name for this parameter.
Code is a unique code used to identify these loop arguments.
"""
return '{}-{}'.format(cls.LOOP_ITEM_PARAM_NAME_BASE, code)

@classmethod
def name_is_loop_argument(cls, param_name: Text) -> bool:
def name_is_loop_argument(cls, param_name: str) -> bool:
"""Return True if the given parameter name looks like a loop argument.
Either it came from a withItems loop item or withParams loop
Expand All @@ -126,17 +135,25 @@ def name_is_loop_argument(cls, param_name: Text) -> bool:
or cls.name_is_withparams_loop_argument(param_name)

@classmethod
def name_is_withitems_loop_argument(cls, param_name: Text) -> bool:
def name_is_withitems_loop_argument(cls, param_name: str) -> bool:
"""Return True if the given parameter name looks like it came from a
loop arguments parameter."""
return (cls.LOOP_ITEM_PARAM_NAME_BASE + '-') in param_name

@classmethod
def name_is_withparams_loop_argument(cls, param_name: Text) -> bool:
def name_is_withparams_loop_argument(cls, param_name: str) -> bool:
"""Return True if the given parameter name looks like it came from a
withParams loop item."""
return ('-' + cls.LOOP_ITEM_NAME_BASE) in param_name

@classmethod
def remove_loop_item_base_name(cls, param_name: str) -> str:
    """Strips the right-most loop item marker from a parameter name.

    The marker is '-' followed by ``cls.LOOP_ITEM_NAME_BASE``. Only the last
    occurrence is cut, so a name built by multi-level nested loops loses one
    level per call. Any text after that occurrence is dropped as well.

    Args:
      param_name: The parameter name to process.

    Returns:
      The part of ``param_name`` before the last marker, or ``param_name``
      unchanged when it contains no marker.
    """
    marker = '-' + cls.LOOP_ITEM_NAME_BASE
    # Search from the right, so that it handles multi-level nested args.
    head, found, _ = param_name.rpartition(marker)
    return head if found else param_name


class LoopArgumentVariable(dsl.PipelineParam):
"""Represents a subvariable for loop arguments.
Expand All @@ -148,31 +165,42 @@ class LoopArgumentVariable(dsl.PipelineParam):

def __init__(
self,
loop_args_name: Text,
this_variable_name: Text,
loop_args_op_name: Text,
loop_args_name: str,
this_variable_name: str,
loop_args_op_name: Optional[str],
# For backward compatibility, add loop_args as an optional argument.
# Ideally, this should replace loop_args_name and loop_args_op_name.
loop_args: Optional[LoopArguments] = None,
):
"""
If the user ran:
with dsl.ParallelFor([{'a': 1, 'b': 2}, {'a': 3, 'b': 4}]) as item:
...
Then there's be one _LoopArgumentsVariable for 'a' and another for 'b'.
Then there will be one LoopArgumentsVariable for 'a' and another for 'b'.
Args:
loop_args_name: the name of the _LoopArguments object that this is
loop_args_name: The name of the LoopArguments object that this is
a subvariable to.
this_variable_name: the name of this subvariable, which is the name
this_variable_name: The name of this subvariable, which is the name
of the dict key that spawned this subvariable.
loop_args_op_name: The name of the op that produced the loop arguments.
loop_args: Optional; The LoopArguments object this subvariable is based on.
"""
super().__init__(
name=self.get_name(
loop_args_name=loop_args_name,
this_variable_name=this_variable_name),
op_name=loop_args_op_name,
)
self.loop_args = loop_args

@property
def items_or_pipeline_param(
        self) -> Union[ItemList, _pipeline_param.PipelineParam]:
    """The items (or pipeline param) of the loop arguments behind this subvariable.

    Delegates to ``self.loop_args``. If ``loop_args`` is None, the attribute
    lookup on it raises AttributeError.
    """
    parent_loop_args = self.loop_args
    return parent_loop_args.items_or_pipeline_param

@classmethod
def get_name(cls, loop_args_name: Text, this_variable_name: Text) -> Text:
def get_name(cls, loop_args_name: str, this_variable_name: str) -> str:
"""Get the name.
Args:
Expand All @@ -187,15 +215,14 @@ def get_name(cls, loop_args_name: Text, this_variable_name: Text) -> Text:
this_variable_name)

@classmethod
def name_is_loop_arguments_variable(cls, param_name: Text) -> bool:
def name_is_loop_arguments_variable(cls, param_name: str) -> bool:
"""Return True if the given parameter name looks like it came from a
LoopArgumentsVariable."""
return re.match('.+%s.+' % cls.SUBVAR_NAME_DELIMITER,
param_name) is not None

@classmethod
def parse_loop_args_name_and_this_var_name(cls,
t: Text) -> Tuple[Text, Text]:
def parse_loop_args_name_and_this_var_name(cls, t: str) -> Tuple[str, str]:
"""Get the loop arguments param name and this subvariable name from the
given parameter name."""
m = re.match(
Expand All @@ -208,7 +235,7 @@ def parse_loop_args_name_and_this_var_name(cls,
)['this_var_name']

@classmethod
def get_subvar_name(cls, t: Text) -> Text:
def get_subvar_name(cls, t: str) -> str:
"""Get the subvariable name from a given LoopArgumentsVariable
parameter name."""
out = cls.parse_loop_args_name_and_this_var_name(t)
Expand Down
22 changes: 12 additions & 10 deletions sdk/python/kfp/dsl/component_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
from typing import List, Optional, Tuple, Union

from kfp.components import _structures as structures
from kfp.dsl import _for_loop
from kfp.dsl import _pipeline_param
from kfp.dsl import dsl_utils
from kfp.dsl import _for_loop, _pipeline_param, dsl_utils
from kfp.pipeline_spec import pipeline_spec_pb2
from kfp.v2.components.types import type_utils

Expand Down Expand Up @@ -191,13 +189,16 @@ def build_task_inputs_spec(
"""
for param in pipeline_params or []:

param_name, subvar_name = _exclude_loop_arguments_variables(param)
param_full_name, subvar_name = _exclude_loop_arguments_variables(param)
input_name = additional_input_name_for_pipelineparam(param.full_name)

param_name = param.name
if subvar_name:
task_spec.inputs.parameters[
input_name].parameter_expression_selector = (
'parseJson(string_value)["{}"]'.format(subvar_name))
param_name = _for_loop.LoopArguments.remove_loop_item_base_name(
_exclude_loop_arguments_variables(param_name)[0])

if type_utils.is_parameter_type(param.param_type):
if param.op_name and dsl_utils.sanitize_task_name(
Expand All @@ -207,12 +208,13 @@ def build_task_inputs_spec(
dsl_utils.sanitize_task_name(param.op_name))
task_spec.inputs.parameters[
input_name].task_output_parameter.output_parameter_key = (
param.name)
param_name)
else:
task_spec.inputs.parameters[
input_name].component_input_parameter = (
param_name if is_parent_component_root else
additional_input_name_for_pipelineparam(param_name))
param_full_name if is_parent_component_root else
additional_input_name_for_pipelineparam(param_full_name)
)
else:
if param.op_name and dsl_utils.sanitize_task_name(
param.op_name) in tasks_in_current_dag:
Expand All @@ -221,11 +223,12 @@ def build_task_inputs_spec(
dsl_utils.sanitize_task_name(param.op_name))
task_spec.inputs.artifacts[
input_name].task_output_artifact.output_artifact_key = (
param.name)
param_name)
else:
task_spec.inputs.artifacts[
input_name].component_input_artifact = (
param_name if is_parent_component_root else input_name)
param_full_name
if is_parent_component_root else input_name)


def update_task_inputs_spec(
Expand Down Expand Up @@ -345,7 +348,6 @@ def update_task_inputs_spec(
component_input_parameter = (
additional_input_name_for_pipelineparam(
component_input_parameter))

assert component_input_parameter in parent_component_inputs.parameters, \
'component_input_parameter: {} not found. All inputs: {}'.format(
component_input_parameter, parent_component_inputs)
Expand Down
Loading

0 comments on commit b466f59

Please sign in to comment.