diff --git a/sdk/python/kfp/components/_python_op.py b/sdk/python/kfp/components/_python_op.py index a53f6c1cd83..f10700dff08 100644 --- a/sdk/python/kfp/components/_python_op.py +++ b/sdk/python/kfp/components/_python_op.py @@ -575,6 +575,17 @@ def get_serializer_and_register_definitions(type_name) -> str: '_output_files = _parsed_args.pop("_output_paths", [])', ]) + # Putting singular return values in a list to be "zipped" with the serializers and output paths + outputs_to_list_code = '' + return_ann = inspect.signature(func).return_annotation + if ( # The return type is singular, not sequence + return_ann is not None + and return_ann != inspect.Parameter.empty + and not isinstance(return_ann, dict) + and not hasattr(return_ann, '_fields') # namedtuple + ): + outputs_to_list_code = '_outputs = [_outputs]' + output_serialization_code = ''.join(' {},\n'.format(s) for s in output_serialization_expression_strings) full_source = \ @@ -589,8 +600,7 @@ def get_serializer_and_register_definitions(type_name) -> str: _outputs = {func_name}(**_parsed_args) -if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] +{outputs_to_list_code} _output_serializers = [ {output_serialization_code} @@ -611,6 +621,7 @@ def get_serializer_and_register_definitions(type_name) -> str: extra_code=extra_code, arg_parse_code='\n'.join(arg_parse_code_lines), output_serialization_code=output_serialization_code, + outputs_to_list_code=outputs_to_list_code, ) #Removing consecutive blank lines diff --git a/sdk/python/tests/compiler/testdata/parallelfor_item_argument_resolving.yaml b/sdk/python/tests/compiler/testdata/parallelfor_item_argument_resolving.yaml index 847b54249a8..47e49a54654 100644 --- a/sdk/python/tests/compiler/testdata/parallelfor_item_argument_resolving.yaml +++ b/sdk/python/tests/compiler/testdata/parallelfor_item_argument_resolving.yaml @@ -31,9 +31,6 @@ spec: _outputs = consume(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] - _output_serializers = [ ] @@ -76,9 +73,6 @@ spec: _outputs = consume(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] - _output_serializers = [ ] @@ -121,9 +115,6 @@ spec: _outputs = consume(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] - _output_serializers = [ ] @@ -166,9 +157,6 @@ spec: _outputs = consume(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] - _output_serializers = [ ] @@ -211,9 +199,6 @@ spec: _outputs = consume(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] - _output_serializers = [ ] @@ -256,9 +241,6 @@ spec: _outputs = consume(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] - _output_serializers = [ ] @@ -301,9 +283,6 @@ spec: _outputs = consume(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] - _output_serializers = [ ] @@ -507,8 +486,7 @@ spec: _outputs = produce_list_of_dicts(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] + _outputs = [_outputs] _output_serializers = [ _serialize_json, @@ -570,8 +548,7 @@ spec: _outputs = produce_list_of_ints(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] + _outputs = [_outputs] _output_serializers = [ _serialize_json, @@ -633,8 +610,7 @@ spec: _outputs = produce_list_of_strings(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] + _outputs = [_outputs] _output_serializers = [ _serialize_json, @@ -690,8 +666,7 @@ spec: _outputs = produce_str(**_parsed_args) - if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str): - _outputs = [_outputs] + _outputs = [_outputs] _output_serializers = [ _serialize_str, diff --git a/sdk/python/tests/components/test_python_op.py b/sdk/python/tests/components/test_python_op.py index 32acdf1e5aa..224dd0bdb1b 100644 --- a/sdk/python/tests/components/test_python_op.py +++ b/sdk/python/tests/components/test_python_op.py @@ -554,7 +554,7 @@ def assert_values_are_same( def test_handling_list_dict_output_values(self): def produce_list() -> list: - return (["string", 1, 2.2, True, False, None, [3, 4], {'s': 5}], ) + return ["string", 1, 2.2, True, False, None, [3, 4], {'s': 5}] # ! JSON map keys are always strings. Python converts all keys to strings without warnings task_factory = comp.func_to_container_op(produce_list)