diff --git a/sdk/python/kfp/components/_components.py b/sdk/python/kfp/components/_components.py index 9613bcad4ae8..9d7f035a1fde 100644 --- a/sdk/python/kfp/components/_components.py +++ b/sdk/python/kfp/components/_components.py @@ -19,15 +19,16 @@ 'load_component_from_file', ] +import copy import sys from collections import OrderedDict +from typing import Any, List, Mapping, NamedTuple, Sequence, Union from ._naming import _sanitize_file_name, _sanitize_python_function_name, generate_unique_name_conversion_table from ._yaml_utils import load_yaml from ._structures import ComponentSpec from ._structures import * from ._data_passing import serialize_value, type_name_to_type -from kfp.dsl import PipelineParam -from kfp.dsl.types import verify_type_compatibility + _default_component_name = 'Component' @@ -170,13 +171,71 @@ def _generate_output_file_name(port_name): return _outputs_dir + '/' + _sanitize_file_name(port_name) + '/' + _single_io_file_name -#Holds the transformation functions that are called each time TaskSpec instance is created from a component. If there are multiple handlers, the last one is used. -_created_task_transformation_handler = [] +def _react_to_incompatible_reference_type( + input_type, + argument_type, + input_name: str, +): + """Raises error for the case when the argument type is incompatible with the input type.""" + message = 'Argument with type "{}" was passed to the input "{}" that has type "{}".'.format(argument_type, input_name, input_type) + raise TypeError(message) + + +def _create_task_spec_from_component_and_arguments( + component_spec: ComponentSpec, + arguments: Mapping[str, Any], + component_ref: ComponentReference = None, +) -> TaskSpec: + """Constructs a TaskSpec object from component reference and arguments. + The function also checks the arguments types and serializes them.""" + if component_ref is None: + component_ref = ComponentReference(spec=component_spec) + else: + component_ref = copy.copy(component_ref) + component_ref.spec = component_spec + + # Not checking for missing or extra arguments since the dynamic factory function checks that + task_arguments = {} + for input_name, argument_value in arguments.items(): + input_type = component_spec._inputs_dict[input_name].type + + if isinstance(argument_value, (GraphInputArgument, TaskOutputArgument)): + # argument_value is a reference + if isinstance(argument_value, GraphInputArgument): + reference_type = argument_value.graph_input.type + elif isinstance(argument_value, TaskOutputArgument): + reference_type = argument_value.task_output.type + else: + reference_type = None + + if reference_type and input_type and reference_type != input_type: + _react_to_incompatible_reference_type(input_type, reference_type, input_name) + + task_arguments[input_name] = argument_value + else: + # argument_value is a constant value + serialized_argument_value = serialize_value(argument_value, input_type) + task_arguments[input_name] = serialized_argument_value + + task = TaskSpec( + component_ref=component_ref, + arguments=task_arguments, + ) + task._init_outputs() + + return task + +_default_container_task_constructor = _create_task_spec_from_component_and_arguments -#TODO: Move to the dsl.Pipeline context class -from . import _dsl_bridge -_created_task_transformation_handler.append(_dsl_bridge.create_container_op_from_task) +# Holds the function that constructs a task object based on ComponentSpec, arguments and ComponentReference. +# Framework authors can override this constructor function to construct different framework-specific task-like objects. +# The task object should have the task.outputs dictionary with keys corresponding to the ComponentSpec outputs. +# The default constructor creates and instance of the TaskSpec class. +_container_task_constructor = _default_container_task_constructor + + +_always_expand_graph_components = False class _DefaultValue: @@ -210,44 +269,33 @@ def _create_task_factory_from_component_spec(component_spec:ComponentSpec, compo component_ref.spec = component_spec def create_task_from_component_and_arguments(pythonic_arguments): - arguments = {} - # Not checking for missing or extra arguments since the dynamic factory function checks that - for argument_name, argument_value in pythonic_arguments.items(): - if isinstance(argument_value, _DefaultValue): # Skipping passing arguments for optional values that have not been overridden. - continue - input_name = pythonic_name_to_input_name[argument_name] - input_type = component_spec._inputs_dict[input_name].type - - if isinstance(argument_value, (GraphInputArgument, TaskOutputArgument, PipelineParam)): - # argument_value is a reference - - if isinstance(argument_value, PipelineParam): - reference_type = argument_value.param_type - argument_value = str(argument_value) - elif isinstance(argument_value, TaskOutputArgument): - reference_type = argument_value.task_output.type - else: - reference_type = None - - verify_type_compatibility(reference_type, input_type, 'Incompatible argument passed to the input "{}" of component "{}": '.format(input_name, component_spec.name)) - - arguments[input_name] = argument_value - else: - # argument_value is a constant value - serialized_argument_value = serialize_value(argument_value, input_type) - arguments[input_name] = serialized_argument_value - - task = TaskSpec( - component_ref=component_ref, + arguments = { + pythonic_name_to_input_name[argument_name]: argument_value + for argument_name, argument_value in pythonic_arguments.items() + if not isinstance(argument_value, _DefaultValue) # Skipping passing arguments for optional values that have not been overridden. + } + + if ( + isinstance(component_spec.implementation, GraphImplementation) + and ( + # When the container task constructor is not overriden, we just construct TaskSpec for both container and graph tasks. + # If the container task constructor is overriden, we should expand the graph components so that the override is called for all sub-tasks. + _container_task_constructor != _default_container_task_constructor + or _always_expand_graph_components + ) + ): + return _resolve_graph_task( + component_spec=component_spec, + arguments=arguments, + component_ref=component_ref, + ) + + task = _container_task_constructor( + component_spec=component_spec, arguments=arguments, + component_ref=component_ref, ) - task._init_outputs() - - if isinstance(component_spec.implementation, GraphImplementation): - return _resolve_graph_task(task, component_spec) - if _created_task_transformation_handler: - task = _created_task_transformation_handler[-1](task) return task import inspect @@ -284,14 +332,161 @@ def component_default_to_func_default(component_default: str, is_optional: bool) return task_factory -def _resolve_graph_task(graph_task: TaskSpec, graph_component_spec: ComponentSpec) -> TaskSpec: +_ResolvedCommandLineAndPaths = NamedTuple( + '_ResolvedCommandLineAndPaths', + [ + ('command', Sequence[str]), + ('args', Sequence[str]), + ('input_paths', Mapping[str, str]), + ('output_paths', Mapping[str, str]), + ('inputs_consumed_by_value', Mapping[str, str]), + ], +) + + +def _resolve_command_line_and_paths( + component_spec: ComponentSpec, + arguments: Mapping[str, str], + input_path_generator=_generate_input_file_name, + output_path_generator=_generate_output_file_name, + argument_serializer=serialize_value, +) -> _ResolvedCommandLineAndPaths: + """Resolves the command line argument placeholders. Also produces the maps of the generated inpuit/output paths.""" + argument_values = arguments + + if not isinstance(component_spec.implementation, ContainerImplementation): + raise TypeError('Only container components have command line to resolve') + + inputs_dict = {input_spec.name: input_spec for input_spec in component_spec.inputs or []} + container_spec = component_spec.implementation.container + + output_paths = OrderedDict() #Preserving the order to make the kubernetes output names deterministic + unconfigurable_output_paths = container_spec.file_outputs or {} + for output in component_spec.outputs or []: + if output.name in unconfigurable_output_paths: + output_paths[output.name] = unconfigurable_output_paths[output.name] + + input_paths = OrderedDict() + inputs_consumed_by_value = {} + + def expand_command_part(arg) -> Union[str, List[str], None]: + if arg is None: + return None + if isinstance(arg, (str, int, float, bool)): + return str(arg) + + if isinstance(arg, InputValuePlaceholder): + input_name = arg.input_name + input_spec = inputs_dict[input_name] + input_value = argument_values.get(input_name, None) + if input_value is not None: + serialized_argument = argument_serializer(input_value, input_spec.type) + inputs_consumed_by_value[input_name] = serialized_argument + return serialized_argument + else: + if input_spec.optional: + return None + else: + raise ValueError('No value provided for input {}'.format(input_name)) + + if isinstance(arg, InputPathPlaceholder): + input_name = arg.input_name + input_value = argument_values.get(input_name, None) + if input_value is not None: + input_path = input_path_generator(input_name) + input_paths[input_name] = input_path + return input_path + else: + input_spec = inputs_dict[input_name] + if input_spec.optional: + #Even when we support default values there is no need to check for a default here. + #In current execution flow (called by python task factory), the missing argument would be replaced with the default value by python itself. + return None + else: + raise ValueError('No value provided for input {}'.format(input_name)) + + elif isinstance(arg, OutputPathPlaceholder): + output_name = arg.output_name + output_filename = output_path_generator(output_name) + if arg.output_name in output_paths: + if output_paths[output_name] != output_filename: + raise ValueError('Conflicting output files specified for port {}: {} and {}'.format(output_name, output_paths[output_name], output_filename)) + else: + output_paths[output_name] = output_filename + + return output_filename + + elif isinstance(arg, ConcatPlaceholder): + expanded_argument_strings = expand_argument_list(arg.items) + return ''.join(expanded_argument_strings) + + elif isinstance(arg, IfPlaceholder): + arg = arg.if_structure + condition_result = expand_command_part(arg.condition) + from distutils.util import strtobool + condition_result_bool = condition_result and strtobool(condition_result) #Python gotcha: bool('False') == True; Need to use strtobool; Also need to handle None and [] + result_node = arg.then_value if condition_result_bool else arg.else_value + if result_node is None: + return [] + if isinstance(result_node, list): + expanded_result = expand_argument_list(result_node) + else: + expanded_result = expand_command_part(result_node) + return expanded_result + + elif isinstance(arg, IsPresentPlaceholder): + argument_is_present = argument_values.get(arg.input_name, None) is not None + return str(argument_is_present) + else: + raise TypeError('Unrecognized argument type: {}'.format(arg)) + + def expand_argument_list(argument_list): + expanded_list = [] + if argument_list is not None: + for part in argument_list: + expanded_part = expand_command_part(part) + if expanded_part is not None: + if isinstance(expanded_part, list): + expanded_list.extend(expanded_part) + else: + expanded_list.append(str(expanded_part)) + return expanded_list + + expanded_command = expand_argument_list(container_spec.command) + expanded_args = expand_argument_list(container_spec.args) + + return _ResolvedCommandLineAndPaths( + command=expanded_command, + args=expanded_args, + input_paths=input_paths, + output_paths=output_paths, + inputs_consumed_by_value=inputs_consumed_by_value, + ) + + +_ResolvedGraphTask = NamedTuple( + '_ResolvedGraphTask', + [ + ('component_spec', ComponentSpec), + ('component_ref', ComponentReference), + ('outputs', Mapping[str, Any]), + ('task_arguments', Mapping[str, Any]), + ], +) + + +def _resolve_graph_task( + component_spec: ComponentSpec, + arguments: Mapping[str, Any], + component_ref: ComponentReference = None, +) -> TaskSpec: from ..components import ComponentStore component_store = ComponentStore.default_store - graph = graph_component_spec.implementation.graph + graph = component_spec.implementation.graph - graph_input_arguments = {input.name: input.default for input in graph_component_spec.inputs if input.default is not None} - graph_input_arguments.update(graph_task.arguments) + graph_input_arguments = {input.name: input.default for input in component_spec.inputs if input.default is not None} + graph_input_arguments.update(arguments) outputs_of_tasks = {} def resolve_argument(argument): @@ -326,7 +521,10 @@ def resolve_argument(argument): resolved_graph_outputs = OrderedDict([(output_name, resolve_argument(argument)) for output_name, argument in graph.output_values.items()]) # For resolved graph component tasks task.outputs point to the actual tasks that originally produced the output that is later returned from the graph - graph_task.output_references = graph_task.outputs - graph_task.outputs = resolved_graph_outputs - + graph_task = _ResolvedGraphTask( + component_ref=component_ref, + component_spec=component_spec, + outputs = resolved_graph_outputs, + task_arguments=arguments, + ) return graph_task diff --git a/sdk/python/kfp/components/_dsl_bridge.py b/sdk/python/kfp/components/_dsl_bridge.py index 6b86295a8deb..bee03a122a15 100644 --- a/sdk/python/kfp/components/_dsl_bridge.py +++ b/sdk/python/kfp/components/_dsl_bridge.py @@ -13,150 +13,59 @@ # limitations under the License. import copy -from collections import OrderedDict -from typing import Mapping -from ._structures import ContainerImplementation, ConcatPlaceholder, IfPlaceholder, InputValuePlaceholder, InputPathPlaceholder, IsPresentPlaceholder, OutputPathPlaceholder, TaskSpec -from ._components import _generate_input_file_name, _generate_output_file_name, _default_component_name - -def create_container_op_from_task(task_spec: TaskSpec): - argument_values = task_spec.arguments - component_spec = task_spec.component_ref.spec - - if not isinstance(component_spec.implementation, ContainerImplementation): - raise TypeError('Only container component tasks can be converted to ContainerOp') - - inputs_dict = {input_spec.name: input_spec for input_spec in component_spec.inputs or []} - container_spec = component_spec.implementation.container - - output_paths = OrderedDict() #Preserving the order to make the kubernetes output names deterministic - unconfigurable_output_paths = container_spec.file_outputs or {} - for output in component_spec.outputs or []: - if output.name in unconfigurable_output_paths: - output_paths[output.name] = unconfigurable_output_paths[output.name] - - input_paths = OrderedDict() - artifact_arguments = OrderedDict() - - def expand_command_part(arg): #input values with original names - #(Union[str,Mapping[str, Any]]) -> Union[str,List[str]] - if arg is None: - return None - if isinstance(arg, (str, int, float, bool)): - return str(arg) - - if isinstance(arg, InputValuePlaceholder): - input_name = arg.input_name - input_value = argument_values.get(input_name, None) - if input_value is not None: - return str(input_value) - else: - input_spec = inputs_dict[input_name] - if input_spec.optional: - return None - else: - raise ValueError('No value provided for input {}'.format(input_name)) - - if isinstance(arg, InputPathPlaceholder): - input_name = arg.input_name - input_value = argument_values.get(input_name, None) - if input_value is not None: - input_path = _generate_input_file_name(input_name) - input_paths[input_name] = input_path - artifact_arguments[input_name] = input_value - return input_path - else: - input_spec = inputs_dict[input_name] - if input_spec.optional: - #Even when we support default values there is no need to check for a default here. - #In current execution flow (called by python task factory), the missing argument would be replaced with the default value by python itself. - return None - else: - raise ValueError('No value provided for input {}'.format(input_name)) - - elif isinstance(arg, OutputPathPlaceholder): - output_name = arg.output_name - output_filename = _generate_output_file_name(output_name) - if arg.output_name in output_paths: - if output_paths[output_name] != output_filename: - raise ValueError('Conflicting output files specified for port {}: {} and {}'.format(output_name, output_paths[output_name], output_filename)) - else: - output_paths[output_name] = output_filename - - return output_filename - - elif isinstance(arg, ConcatPlaceholder): - expanded_argument_strings = expand_argument_list(arg.items) - return ''.join(expanded_argument_strings) - - elif isinstance(arg, IfPlaceholder): - arg = arg.if_structure - condition_result = expand_command_part(arg.condition) - from distutils.util import strtobool - condition_result_bool = condition_result and strtobool(condition_result) #Python gotcha: bool('False') == True; Need to use strtobool; Also need to handle None and [] - result_node = arg.then_value if condition_result_bool else arg.else_value - if result_node is None: - return [] - if isinstance(result_node, list): - expanded_result = expand_argument_list(result_node) - else: - expanded_result = expand_command_part(result_node) - return expanded_result - - elif isinstance(arg, IsPresentPlaceholder): - argument_is_present = argument_values.get(arg.input_name, None) is not None - return str(argument_is_present) - else: - raise TypeError('Unrecognized argument type: {}'.format(arg)) - - def expand_argument_list(argument_list): - expanded_list = [] - if argument_list is not None: - for part in argument_list: - expanded_part = expand_command_part(part) - if expanded_part is not None: - if isinstance(expanded_part, list): - expanded_list.extend(expanded_part) - else: - expanded_list.append(str(expanded_part)) - return expanded_list - - expanded_command = expand_argument_list(container_spec.command) - expanded_args = expand_argument_list(container_spec.args) - - return _task_object_factory( - name=component_spec.name or _default_component_name, - container_image=container_spec.image, - command=expanded_command, - arguments=expanded_args, - input_paths=input_paths, - output_paths=output_paths, - artifact_arguments=artifact_arguments, - env=container_spec.env, +from typing import Any, Mapping +from ._structures import ComponentSpec, ComponentReference +from ._components import _default_component_name, _resolve_command_line_and_paths +from .. import dsl + + +def _create_container_op_from_component_and_arguments( + component_spec: ComponentSpec, + arguments: Mapping[str, Any], + component_ref: ComponentReference = None, +) -> 'dsl.ContainerOp': + # Check types of the reference arguments and serialize PipelineParams + arguments = arguments.copy() + for input_name, argument_value in arguments.items(): + if isinstance(argument_value, dsl.PipelineParam): + input_type = component_spec._inputs_dict[input_name].type + reference_type = argument_value.param_type + dsl.types.verify_type_compatibility(reference_type, input_type, 'Incompatible argument passed to the input "{}" of component "{}": '.format(input_name, component_spec.name)) + + arguments[input_name] = str(argument_value) + + resolved_cmd = _resolve_command_line_and_paths( component_spec=component_spec, + arguments=arguments, ) - -def _create_container_op_from_resolved_task(name:str, container_image:str, command=None, arguments=None, input_paths=None, artifact_arguments=None, output_paths=None, env : Mapping[str, str]=None, component_spec=None): - from .. import dsl - #Renaming outputs to conform with ContainerOp/Argo from ._naming import _sanitize_python_function_name, generate_unique_name_conversion_table - output_names = (output_paths or {}).keys() - output_name_to_kubernetes = generate_unique_name_conversion_table(output_names, _sanitize_python_function_name) - output_paths_for_container_op = {output_name_to_kubernetes[name]: path for name, path in output_paths.items()} + output_names = (resolved_cmd.output_paths or {}).keys() + output_name_to_python = generate_unique_name_conversion_table(output_names, _sanitize_python_function_name) + output_paths_for_container_op = {output_name_to_python[name]: path for name, path in resolved_cmd.output_paths.items()} + + container_spec = component_spec.implementation.container task = dsl.ContainerOp( - name=name, - image=container_image, - command=command, - arguments=arguments, + name=component_spec.name or _default_component_name, + image=container_spec.image, + command=resolved_cmd.command, + arguments=resolved_cmd.args, file_outputs=output_paths_for_container_op, - artifact_argument_paths=[dsl.InputArgumentPath(argument=artifact_arguments[input_name], input=input_name, path=path) for input_name, path in input_paths.items()], + artifact_argument_paths=[ + dsl.InputArgumentPath( + argument=arguments[input_name], + input=input_name, + path=path, + ) + for input_name, path in resolved_cmd.input_paths.items() + ], ) # Fixing ContainerOp output types if component_spec.outputs: for output in component_spec.outputs: - pythonic_name = output_name_to_kubernetes[output.name] + pythonic_name = output_name_to_python[output.name] if pythonic_name in task.outputs: task.outputs[pythonic_name].param_type = output.type @@ -164,9 +73,9 @@ def _create_container_op_from_resolved_task(name:str, container_image:str, comma component_meta.implementation = None task._set_metadata(component_meta) - if env: + if container_spec.env: from kubernetes import client as k8s_client - for name, value in env.items(): + for name, value in container_spec.env.items(): task.container.add_env_variable(k8s_client.V1EnvVar(name=name, value=value)) if component_spec.metadata: @@ -176,6 +85,3 @@ def _create_container_op_from_resolved_task(name:str, container_image:str, comma task.add_pod_label(key, value) return task - - -_task_object_factory=_create_container_op_from_resolved_task diff --git a/sdk/python/kfp/components/_python_to_graph_component.py b/sdk/python/kfp/components/_python_to_graph_component.py index 8a100ec70e49..53d55d3ca21a 100644 --- a/sdk/python/kfp/components/_python_to_graph_component.py +++ b/sdk/python/kfp/components/_python_to_graph_component.py @@ -74,7 +74,18 @@ def create_graph_component_spec_from_pipeline_func(pipeline_func: Callable, embe task_map = OrderedDict() #Preserving task order - def task_construction_handler(task: TaskSpec): + from ._components import _create_task_spec_from_component_and_arguments + def task_construction_handler( + component_spec, + arguments, + component_ref, + ): + task = _create_task_spec_from_component_and_arguments( + component_spec=component_spec, + arguments=arguments, + component_ref=component_ref, + ) + #Rewriting task ids so that they're same every time task_id = task.component_ref.spec.name or "Task" task_id = _make_name_unique_by_adding_index(task_id, task_map.keys(), ' ') @@ -94,12 +105,14 @@ def task_construction_handler(task: TaskSpec): try: #Setting the handler to fix and catch the tasks. - _components._created_task_transformation_handler.append(task_construction_handler) + # FIX: The handler only hooks container component creation + old_handler = _components._container_task_constructor + _components._container_task_constructor = task_construction_handler #Calling the pipeline_func with GraphInputArgument instances as arguments pipeline_func_result = pipeline_func(**pipeline_func_args) finally: - _components._created_task_transformation_handler.pop() + _components._container_task_constructor = old_handler # Getting graph outputs diff --git a/sdk/python/kfp/dsl/_pipeline.py b/sdk/python/kfp/dsl/_pipeline.py index d92736851f00..d736a346191d 100644 --- a/sdk/python/kfp/dsl/_pipeline.py +++ b/sdk/python/kfp/dsl/_pipeline.py @@ -17,6 +17,7 @@ from . import _resource_op from . import _ops_group from ..components._naming import _make_name_unique_by_adding_index +from ..components import _dsl_bridge, _components import sys @@ -189,6 +190,8 @@ def __enter__(self): raise Exception('Nested pipelines are not allowed.') Pipeline._default_pipeline = self + self._old_container_task_constructor = _components._container_task_constructor + _components._container_task_constructor = _dsl_bridge._create_container_op_from_component_and_arguments def register_op_and_generate_id(op): return self.add_op(op, op.is_exit_handler) @@ -200,6 +203,7 @@ def register_op_and_generate_id(op): def __exit__(self, *args): Pipeline._default_pipeline = None _container_op._register_op_handler = self._old__register_op_handler + _components._container_task_constructor = self._old_container_task_constructor def add_op(self, op: _container_op.BaseOp, define_only: bool): """Add a new operator. diff --git a/sdk/python/tests/components/test_components.py b/sdk/python/tests/components/test_components.py index 7a1e213cebda..a6db9e121fd5 100644 --- a/sdk/python/tests/components/test_components.py +++ b/sdk/python/tests/components/test_components.py @@ -27,14 +27,21 @@ @contextmanager def no_task_resolving_context(): - old_handler = kfp.components._components._created_task_transformation_handler + old_handler = kfp.components._components._container_task_constructor try: - kfp.components._components._created_task_transformation_handler = None + kfp.components._components._container_task_constructor = kfp.components._components._create_task_spec_from_component_and_arguments yield None finally: - kfp.components._components._created_task_transformation_handler = old_handler + kfp.components._components._container_task_constructor = old_handler class LoadComponentTestCase(unittest.TestCase): + def setUp(self): + self.old_container_task_constructor = kfp.components._components._container_task_constructor + kfp.components._components._container_task_constructor = kfp.components._dsl_bridge._create_container_op_from_component_and_arguments + + def tearDown(self): + kfp.components._components._container_task_constructor = self.old_container_task_constructor + def _test_load_component_from_file(self, component_path: str): task_factory1 = comp.load_component_from_file(component_path) @@ -644,7 +651,7 @@ def test_check_type_validation_of_task_spec_outputs(self): consumer_op(producer_task.outputs['out1']) consumer_op(producer_task.outputs['out2'].without_type()) consumer_op(producer_task.outputs['out2'].with_type('Integer')) - with self.assertRaises(InconsistentTypeException): + with self.assertRaises(TypeError): consumer_op(producer_task.outputs['out2']) def test_type_compatibility_check_for_simple_types(self): diff --git a/sdk/python/tests/components/test_graph_components.py b/sdk/python/tests/components/test_graph_components.py index cd11eadf57ab..6312554d8290 100644 --- a/sdk/python/tests/components/test_graph_components.py +++ b/sdk/python/tests/components/test_graph_components.py @@ -325,7 +325,11 @@ def test_load_nested_graph_components(self): graph out 4: '42' ''' op = comp.load_component_from_text(component_text) - task = op('graph 1', 'graph 2') + old_value = comp._components._always_expand_graph_components = True + try: + task = op('graph 1', 'graph 2') + finally: + comp._components._always_expand_graph_components = old_value self.assertIn('out3_1', str(task.outputs['graph out 1'])) # Checks that the outputs coming from tasks in nested subgraphs are properly resolved. self.assertIn('out1_2', str(task.outputs['graph out 2'])) self.assertEqual(task.outputs['graph out 3'], 'graph 2') diff --git a/sdk/python/tests/components/test_python_op.py b/sdk/python/tests/components/test_python_op.py index 084580a0dc3c..e3967c14c046 100644 --- a/sdk/python/tests/components/test_python_op.py +++ b/sdk/python/tests/components/test_python_op.py @@ -21,6 +21,7 @@ import kfp import kfp.components as comp +from kfp.components._components import _resolve_command_line_and_paths def add_two_numbers(a: float, b: float) -> float: '''Returns sum of two arguments''' @@ -82,11 +83,15 @@ def helper_test_2_in_1_out_component_using_local_call(self, func, op, arguments= with tempfile.TemporaryDirectory() as temp_dir_name: with components_local_output_dir_context(temp_dir_name): task = op(arguments[0], arguments[1]) + resolved_cmd = _resolve_command_line_and_paths( + task.component_ref.spec, + task.arguments, + ) - full_command = task.command + task.arguments + full_command = resolved_cmd.command + resolved_cmd.args subprocess.run(full_command, check=True) - output_path = list(task.file_outputs.values())[0] + output_path = list(resolved_cmd.output_paths.values())[0] actual_str = Path(output_path).read_text() self.assertEqual(float(actual_str), float(expected_str)) @@ -102,12 +107,16 @@ def helper_test_2_in_2_out_component_using_local_call(self, func, op, output_nam with tempfile.TemporaryDirectory() as temp_dir_name: with components_local_output_dir_context(temp_dir_name): task = op(arg1, arg2) + resolved_cmd = _resolve_command_line_and_paths( + task.component_ref.spec, + task.arguments, + ) - full_command = task.command + task.arguments + full_command = resolved_cmd.command + resolved_cmd.args subprocess.run(full_command, check=True) - (output_path1, output_path2) = (task.file_outputs[output_names[0]], task.file_outputs[output_names[1]]) + (output_path1, output_path2) = (resolved_cmd.output_paths[output_names[0]], resolved_cmd.output_paths[output_names[1]]) actual1_str = Path(output_path1).read_text() actual2_str = Path(output_path2).read_text() @@ -127,11 +136,7 @@ def helper_test_component_against_func_using_local_call(self, func: Callable, op expected_output_values_list = [str(value) for value in expected_output_values_list] output_names = [output.name for output in op.component_spec.outputs] - from kfp.components._naming import generate_unique_name_conversion_table, _sanitize_python_function_name - output_name_to_pythonic = generate_unique_name_conversion_table(output_names, _sanitize_python_function_name) - pythonic_output_names = [output_name_to_pythonic[name] for name in output_names] - from collections import OrderedDict - expected_output_values_dict = OrderedDict(zip(pythonic_output_names, expected_output_values_list)) + expected_output_values_dict = dict(zip(output_names, expected_output_values_list)) self.helper_test_component_using_local_call(op, arguments, expected_output_values_dict) @@ -145,21 +150,25 @@ def helper_test_component_using_local_call(self, component_task_factory: Callabl outputs_path = Path(temp_dir_name) / 'outputs' with components_override_input_output_dirs_context(str(inputs_path), str(outputs_path)): task = component_task_factory(**arguments) + resolved_cmd = _resolve_command_line_and_paths( + task.component_ref.spec, + task.arguments, + ) # Preparing input files - for input_name, input_file_path in (task.input_artifact_paths or {}).items(): + for input_name, input_file_path in (resolved_cmd.input_paths or {}).items(): Path(input_file_path).parent.mkdir(parents=True, exist_ok=True) Path(input_file_path).write_text(str(arguments[input_name])) # Constructing the full command-line from resolved command+args - full_command = task.command + task.arguments + full_command = resolved_cmd.command + resolved_cmd.args # Executing the command-line locally subprocess.run(full_command, check=True) - actual_output_values_dict = {output_name: Path(output_path).read_text() for output_name, output_path in task.file_outputs.items()} + actual_output_values_dict = {output_name: Path(output_path).read_text() for output_name, output_path in resolved_cmd.output_paths.items()} - self.assertEqual(actual_output_values_dict, expected_output_values) + self.assertDictEqual(actual_output_values_dict, expected_output_values) def test_func_to_container_op_local_call(self): func = add_two_numbers @@ -507,7 +516,11 @@ def consume_list(list_param: list) -> int: import kfp task_factory = comp.func_to_container_op(consume_list) task = task_factory([1, 2, 3, kfp.dsl.PipelineParam("aaa"), 4, 5, 6]) - full_command_line = task.command + task.arguments + resolved_cmd = _resolve_command_line_and_paths( + task.component_ref.spec, + task.arguments, + ) + full_command_line = resolved_cmd.command + resolved_cmd.args for arg in full_command_line: self.assertNotIn('PipelineParam', arg) @@ -543,7 +556,7 @@ def produce_list() -> list: import json expected_output = json.dumps(["string", 1, 2.2, True, False, None, [3, 4], {'s': 5}]) - self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'output': expected_output}) + self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'Output': expected_output}) def test_input_path(self): @@ -557,7 +570,7 @@ def consume_file_path(number_file_path: InputPath(int)) -> int: self.assertEqual(task_factory.component_spec.inputs[0].type, 'Integer') - self.helper_test_component_using_local_call(task_factory, arguments={'number': "42"}, expected_output_values={'output': '42'}) + self.helper_test_component_using_local_call(task_factory, arguments={'number': "42"}, expected_output_values={'Output': '42'}) def test_input_text_file(self): @@ -571,7 +584,7 @@ def consume_file_path(number_file: InputTextFile(int)) -> int: self.assertEqual(task_factory.component_spec.inputs[0].type, 'Integer') - self.helper_test_component_using_local_call(task_factory, arguments={'number': "42"}, expected_output_values={'output': '42'}) + self.helper_test_component_using_local_call(task_factory, arguments={'number': "42"}, expected_output_values={'Output': '42'}) def test_input_binary_file(self): @@ -585,7 +598,7 @@ def consume_file_path(number_file: InputBinaryFile(int)) -> int: self.assertEqual(task_factory.component_spec.inputs[0].type, 'Integer') - self.helper_test_component_using_local_call(task_factory, arguments={'number': "42"}, expected_output_values={'output': '42'}) + self.helper_test_component_using_local_call(task_factory, arguments={'number': "42"}, expected_output_values={'Output': '42'}) def test_output_path(self): @@ -645,7 +658,7 @@ def write_to_file_path(number_file_path: OutputPath(int)) -> str: self.assertEqual(task_factory.component_spec.outputs[0].type, 'Integer') self.assertEqual(task_factory.component_spec.outputs[1].type, 'String') - self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'number': '42', 'output': 'Hello'}) + self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'number': '42', 'Output': 'Hello'}) def test_all_data_passing_ways(self):