diff --git a/sdk/python/kfp/compiler/_component_builder.py b/sdk/python/kfp/compiler/_component_builder.py index b911596b1f2..5c04b420042 100644 --- a/sdk/python/kfp/compiler/_component_builder.py +++ b/sdk/python/kfp/compiler/_component_builder.py @@ -411,7 +411,7 @@ def _generate_pythonop(component_func, target_image, target_component_file=None) The returned value is in fact a function, which should generates a container_op instance. """ from ..components._python_op import _python_function_name_to_component_name - from ..components._structures import InputSpec, OutputSpec, ImplementationSpec, ContainerSpec, ComponentSpec + from ..components._structures import InputSpec, InputValuePlaceholder, OutputPathPlaceholder, OutputSpec, ContainerImplementation, ContainerSpec, ComponentSpec #Component name and description are derived from the function's name and docstribng, but can be overridden by @python_component function decorator @@ -428,11 +428,11 @@ def _generate_pythonop(component_func, target_image, target_component_file=None) description=component_description, inputs=[InputSpec(name=input_name, type='str') for input_name in input_names], #TODO: Chnage type to actual type outputs=[OutputSpec(name=output_name)], - implementation=ImplementationSpec( + implementation=ContainerImplementation( container=ContainerSpec( image=target_image, #command=['python3', program_file], #TODO: Include the command line - args=[{'value': input_name} for input_name in input_names] + [{'output': output_name}], + args=[InputValuePlaceholder(input_name) for input_name in input_names] + [OutputPathPlaceholder(output_name)], ) ) ) diff --git a/sdk/python/kfp/components/_components.py b/sdk/python/kfp/components/_components.py index cefeb418eeb..85e9e773224 100644 --- a/sdk/python/kfp/components/_components.py +++ b/sdk/python/kfp/components/_components.py @@ -23,6 +23,7 @@ from collections import OrderedDict from ._yaml_utils import load_yaml from ._structures import ComponentSpec +from 
._structures import * _default_component_name = 'Component' @@ -238,86 +239,71 @@ def expand_command_part(arg): #input values with original names return None if isinstance(arg, (str, int, float, bool)): return str(arg) - elif isinstance(arg, dict): - if len(arg) != 1: - raise ValueError('Failed to parse argument dict: "{}"'.format(arg)) - (func_name, func_argument) = list(arg.items())[0] - func_name=func_name.lower() - - if func_name == 'value': - assert isinstance(func_argument, str) - port_name = func_argument - input_value = pythonic_input_argument_values[input_name_to_pythonic[port_name]] - if input_value is not None: - return str(input_value) - else: - input_spec = inputs_dict[port_name] - if input_spec.optional: - #Even when we support default values there is no need to check for a default here. - #In current execution flow (called by python task factory), the missing argument would be replaced with the default value by python itself. - return None - else: - raise ValueError('No value provided for input {}'.format(port_name)) - - elif func_name == 'file': - assert isinstance(func_argument, str) - port_name = func_argument - input_filename = _generate_input_file_name(port_name) - input_key = input_name_to_kubernetes[port_name] - input_value = pythonic_input_argument_values[input_name_to_pythonic[port_name]] - if input_value is not None: - return input_filename - else: - input_spec = inputs_dict[port_name] - if input_spec.optional: - #Even when we support default values there is no need to check for a default here. - #In current execution flow (called by python task factory), the missing argument would be replaced with the default value by python itself. 
- return None - else: - raise ValueError('No value provided for input {}'.format(port_name)) - - elif func_name == 'output': - assert isinstance(func_argument, str) - port_name = func_argument - output_filename = _generate_output_file_name(port_name) - output_key = output_name_to_kubernetes[port_name] - if output_key in file_outputs: - if file_outputs[output_key] != output_filename: - raise ValueError('Conflicting output files specified for port {}: {} and {}'.format(port_name, file_outputs[output_key], output_filename)) + + if isinstance(arg, InputValuePlaceholder): + port_name = arg.input_name + input_value = pythonic_input_argument_values[input_name_to_pythonic[port_name]] + if input_value is not None: + return str(input_value) + else: + input_spec = inputs_dict[port_name] + if input_spec.optional: + #Even when we support default values there is no need to check for a default here. + #In current execution flow (called by python task factory), the missing argument would be replaced with the default value by python itself. 
+ return None else: - file_outputs[output_key] = output_filename - - return output_filename - - elif func_name == 'concat': - assert isinstance(func_argument, list) - items_to_concatenate = func_argument - expanded_argument_strings = expand_argument_list(items_to_concatenate) - return ''.join(expanded_argument_strings) - - elif func_name == 'if': - assert isinstance(func_argument, dict) - condition_node = func_argument['cond'] - then_node = func_argument['then'] - else_node = func_argument.get('else', None) - condition_result = expand_command_part(condition_node) - from distutils.util import strtobool - condition_result_bool = condition_result and strtobool(condition_result) #Python gotcha: bool('False') == True; Need to use strtobool; Also need to handle None and [] - result_node = then_node if condition_result_bool else else_node - if result_node is None: - return [] - if isinstance(result_node, list): - expanded_result = expand_argument_list(result_node) + raise ValueError('No value provided for input {}'.format(port_name)) + + if isinstance(arg, InputPathPlaceholder): + port_name = arg.input_name + input_filename = _generate_input_file_name(port_name) + input_key = input_name_to_kubernetes[port_name] + input_value = pythonic_input_argument_values[input_name_to_pythonic[port_name]] + if input_value is not None: + return input_filename + else: + input_spec = inputs_dict[port_name] + if input_spec.optional: + #Even when we support default values there is no need to check for a default here. + #In current execution flow (called by python task factory), the missing argument would be replaced with the default value by python itself. 
+ return None else: - expanded_result = expand_command_part(result_node) - return expanded_result - - elif func_name == 'ispresent': - assert isinstance(func_argument, str) - input_name = func_argument - pythonic_input_name = input_name_to_pythonic[input_name] - argument_is_present = pythonic_input_argument_values[pythonic_input_name] is not None - return str(argument_is_present) + raise ValueError('No value provided for input {}'.format(port_name)) + + elif isinstance(arg, OutputPathPlaceholder): + port_name = arg.output_name + output_filename = _generate_output_file_name(port_name) + output_key = output_name_to_kubernetes[port_name] + if output_key in file_outputs: + if file_outputs[output_key] != output_filename: + raise ValueError('Conflicting output files specified for port {}: {} and {}'.format(port_name, file_outputs[output_key], output_filename)) + else: + file_outputs[output_key] = output_filename + + return output_filename + + elif isinstance(arg, ConcatPlaceholder): + expanded_argument_strings = expand_argument_list(arg.items) + return ''.join(expanded_argument_strings) + + elif isinstance(arg, IfPlaceholder): + arg = arg.if_structure + condition_result = expand_command_part(arg.condition) + from distutils.util import strtobool + condition_result_bool = condition_result and strtobool(condition_result) #Python gotcha: bool('False') == True; Need to use strtobool; Also need to handle None and [] + result_node = arg.then_value if condition_result_bool else arg.else_value + if result_node is None: + return [] + if isinstance(result_node, list): + expanded_result = expand_argument_list(result_node) + else: + expanded_result = expand_command_part(result_node) + return expanded_result + + elif isinstance(arg, IsPresentPlaceholder): + pythonic_input_name = input_name_to_pythonic[arg.input_name] + argument_is_present = pythonic_input_argument_values[pythonic_input_name] is not None + return str(argument_is_present) else: raise TypeError('Unrecognized argument 
type: {}'.format(arg)) diff --git a/sdk/python/kfp/components/_python_op.py b/sdk/python/kfp/components/_python_op.py index 44fed3d46f4..3035d94ab18 100644 --- a/sdk/python/kfp/components/_python_op.py +++ b/sdk/python/kfp/components/_python_op.py @@ -19,7 +19,7 @@ from ._yaml_utils import dump_yaml from ._components import _create_task_factory_from_component_spec -from ._structures import InputSpec, OutputSpec, ImplementationSpec, ContainerSpec, ComponentSpec +from ._structures import * from pathlib import Path from typing import TypeVar, Generic @@ -79,73 +79,50 @@ def _func_to_component_spec(func, extra_code='', base_image=_default_base_image) extra_output_names = [] arguments = [] - def annotation_to_argument_kind_and_type_name(annotation): + def annotation_to_type_struct(annotation): if not annotation or annotation == inspect.Parameter.empty: - return ('value', None) - if hasattr(annotation, '__origin__'): #Generic type - type_name = annotation.__origin__.__name__ - type_args = annotation.__args__ - #if len(type_args) != 1: - # raise TypeError('Unsupported generic type {}'.format(type_name)) - inner_type = type_args[0] - if type_name == InputFile.__name__: - return ('file', inner_type.__name__) - elif type_name == OutputFile.__name__: - return ('output', inner_type.__name__) + return None if isinstance(annotation, type): - return ('value', annotation.__name__) + return str(annotation.__name__) else: - #!!! It's important to preserve string anotations as strings. Annotations that are neither types nor strings are converted to strings. - #Materializer adds double quotes to the types it does not recognize. - fix it to not quote strings. - #We need two kind of strings: we can use any type name for component YAML, but for generated Python code we must use valid python type annotations. 
- return ('value', "'" + str(annotation) + "'") + return str(annotation) for parameter in parameters: - annotation = parameter.annotation - - (argument_kind, parameter_type_name) = annotation_to_argument_kind_and_type_name(annotation) - - parameter_to_type_name[parameter.name] = parameter_type_name - + type_struct = annotation_to_type_struct(parameter.annotation) + parameter_to_type_name[parameter.name] = str(type_struct) #TODO: Humanize the input/output names - arguments.append({argument_kind: parameter.name}) - - parameter_spec = OrderedDict([('name', parameter.name)]) - if parameter_type_name: - parameter_spec['type'] = parameter_type_name - if argument_kind == 'value' or argument_kind == 'file': - inputs.append(parameter_spec) - elif argument_kind == 'output': - outputs.append(parameter_spec) - else: - #Cannot happen - raise ValueError('Unrecognized argument kind {}.'.format(argument_kind)) + arguments.append(InputValuePlaceholder(parameter.name)) + + input_spec = InputSpec( + name=parameter.name, + type=type_struct, + ) + inputs.append(input_spec) #Analyzing the return type annotations. 
return_ann = signature.return_annotation if hasattr(return_ann, '_fields'): #NamedTuple for field_name in return_ann._fields: - output_spec = OrderedDict([('name', field_name)]) + type_struct = None if hasattr(return_ann, '_field_types'): - output_type = return_ann._field_types.get(field_name, None) - if isinstance(output_type, type): - output_type_name = output_type.__name__ - else: - output_type_name = str(output_type) - - if output_type: - output_spec['type'] = output_type_name + type_struct = annotation_to_type_struct(return_ann._field_types.get(field_name, None)) + + output_spec = OutputSpec( + name=field_name, + type=type_struct, + ) outputs.append(output_spec) extra_output_names.append(field_name) - arguments.append({'output': field_name}) - else: - output_spec = OrderedDict([('name', single_output_name_const)]) - (_, output_type_name) = annotation_to_argument_kind_and_type_name(signature.return_annotation) - if output_type_name: - output_spec['type'] = output_type_name + arguments.append(OutputPathPlaceholder(field_name)) + elif signature.return_annotation is not None and signature.return_annotation != inspect.Parameter.empty: + type_struct = annotation_to_type_struct(signature.return_annotation) + output_spec = OutputSpec( + name=single_output_name_const, + type=type_struct, + ) outputs.append(output_spec) extra_output_names.append(single_output_pythonic_name_const) - arguments.append({'output': single_output_name_const}) + arguments.append(OutputPathPlaceholder(single_output_name_const)) func_name=func.__name__ @@ -226,9 +203,9 @@ def annotation_to_argument_kind_and_type_name(annotation): component_spec = ComponentSpec( name=component_name, description=description, - inputs=[InputSpec.from_struct(input) for input in inputs], - outputs=[OutputSpec.from_struct(output) for output in outputs], - implementation=ImplementationSpec( + inputs=inputs, + outputs=outputs, + implementation=ContainerImplementation( container=ContainerSpec( image=base_image, 
command=['python3', '-c', full_source], diff --git a/sdk/python/kfp/components/_structures.py b/sdk/python/kfp/components/_structures.py index c23f2efdf6b..a70874f628b 100644 --- a/sdk/python/kfp/components/_structures.py +++ b/sdk/python/kfp/components/_structures.py @@ -13,473 +13,543 @@ # limitations under the License. __all__ = [ - 'InputOrOutputSpec', 'InputSpec', 'OutputSpec', + + 'InputValuePlaceholder', + 'InputPathPlaceholder', + 'OutputPathPlaceholder', + 'ConcatPlaceholder', + 'IsPresentPlaceholder', + 'IfPlaceholder', + 'ContainerSpec', - 'GraphInputReferenceSpec', - 'TaskOutputReferenceSpec', - 'DataValueOrReferenceSpec', - 'TaskSpec', - 'GraphSpec', - 'ImplementationSpec', + 'ContainerImplementation', + 'SourceSpec', + 'ComponentSpec', -] + 'ComponentReference', + + 'GraphInputArgument', + 'TaskOutputReference', + 'TaskOutputArgument', + + 'EqualsPredicate', + 'NotEqualsPredicate', + 'GreaterThanPredicate', + 'GreaterThanOrEqualPredicate', + 'LessThenPredicate', + 'LessThenOrEqualPredicate', + 'NotPredicate', + 'AndPredicate', + 'OrPredicate', + + 'TaskSpec', + + 'GraphSpec', + 'GraphImplementation', + + 'PipelineRunSpec', +] -import copy from collections import OrderedDict -from typing import Union, List, Mapping, Tuple - - -class InputOrOutputSpec: - def __init__(self, name:str, type:str=None, description:str=None, optional:bool=False, pattern:str=None): - if not isinstance(name, str): - raise ValueError('name must be a string') - self.name = name - self.type = type - self.description = description - self.optional = optional - self.pattern = pattern - - @classmethod - def from_struct(cls, struct:Union[Tuple[str, Mapping],Mapping[str,Mapping],str]): - #if not isinstance(struct, tuple) and not isinstance(struct, dict) and not isinstance(struct, str): - # raise ValueError('InputOrOutputSpec.from_struct only supports tuples, dicts and strings') - - #We support two different serialization variants: - #1: {name: {'type': type}, ...} - #2: [{'name': name, 
'type': type}, ...] - #1st one looks nicer, but we must take care to preserve the port ordering (prior to version 3.6 Python's dict does not preserve the order of elements). - if isinstance(struct, tuple): #(name: {'type': type}) - assert(len(struct) == 2) - (name, spec_dict) = struct - elif isinstance(struct, dict): #{'name': name, 'type': type} - spec_dict = copy.deepcopy(struct) - name = spec_dict.pop('name') - elif isinstance(struct, str): - name = struct - spec_dict = {} - else: - raise ValueError('InputOrOutputSpec.from_struct only supports tuples, dicts and strings') - #port_spec = InputOrOutputSpec(name) - port_spec = cls(name) - - if 'type' in spec_dict: - port_spec.type = spec_dict.pop('type') - check_instance_type(port_spec.type, [str, list, dict]) #TODO: Check format further - - if 'description' in spec_dict: - port_spec.description = str(spec_dict.pop('description')) - if 'optional' in spec_dict: - port_spec.optional = bool(spec_dict.pop('optional')) - - if 'pattern' in spec_dict: - port_spec.pattern = str(spec_dict.pop('pattern')) +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union - if spec_dict: - raise ValueError('Found unrecognized properties: {}'.format(spec_dict)) - - return port_spec - - def to_struct(self): - struct = OrderedDict() - struct['name'] = self.name - if self.type: - struct['type'] = self.type - if self.description: - struct['description'] = self.description - if self.optional: - struct['optional'] = self.optional - if self.pattern: - struct['pattern'] = self.pattern - - return struct - - def __repr__(self): - return self.__class__.__name__ + '.from_struct(' + str(self.to_struct()) + ')' +from .modelbase import ModelBase +from .structures.kubernetes import v1 -class InputSpec(InputOrOutputSpec): - pass +PrimitiveTypes = Union[str, int, float, bool] +PrimitiveTypesIncludingNone = Optional[PrimitiveTypes] -class OutputSpec(InputOrOutputSpec): - pass +class InputSpec(ModelBase): + '''Describes the component input 
specification''' + def __init__(self, + name: str, + type: Optional[Union[str, Dict, List]] = None, + description: Optional[str] = None, + default: Optional[PrimitiveTypes] = None, + optional: Optional[bool] = False, + ): + super().__init__(locals()) -class ContainerSpec: - def __init__(self, image:str, command:List=None, args:List=None, file_outputs:Mapping[str,str]=None): - if not isinstance(image, str): - raise ValueError('image must be a string') - self.image = image - self.command = command - self.args = args - self.file_outputs = file_outputs - - @staticmethod - def from_struct(spec_dict:Mapping): - spec_dict = copy.deepcopy(spec_dict) - - image = spec_dict.pop('image') - - container_spec = ContainerSpec(image) - - if 'command' in spec_dict: - container_spec.command = list(spec_dict.pop('command')) - if 'args' in spec_dict: - container_spec.args = list(spec_dict.pop('args')) - if 'fileOutputs' in spec_dict: - container_spec.file_outputs = dict(spec_dict.pop('fileOutputs')) - - if spec_dict: - raise ValueError('Found unrecognized properties: {}'.format(spec_dict)) - - return container_spec - - def to_struct(self): - struct = OrderedDict() - if self.image: - struct['image'] = self.image - if self.command: - struct['command'] = self.command - if self.args: - struct['args'] = self.args - if self.file_outputs: - struct['fileOutputs'] = self.file_outputs - - return struct - - def __repr__(self): - return self.__class__.__name__ + '.from_struct(' + str(self.to_struct()) + ')' +class OutputSpec(ModelBase): + '''Describes the component output specification''' + def __init__(self, + name: str, + type: Optional[Union[str, Dict, List]] = None, + description: Optional[str] = None, + ): + super().__init__(locals()) -def check_instance_type(obj, accepted_types:List[type]): - if not [accepted_type for accepted_type in accepted_types if isinstance(obj, accepted_type)]: - raise ValueError('Encountered object of wrong type. Accepted types: {}. 
Actual type: {}'.format(accepted_types, obj.__class__.__name__)) +class InputValuePlaceholder(ModelBase): #Non-standard attr names + '''Represents the command-line argument placeholder that will be replaced at run-time by the input argument value.''' + _serialized_names = { + #'input_name': 'inputValue', + 'input_name': 'value', #TODO: Rename to inputValue + } -class GraphInputReferenceSpec: - def __init__(self, input_name:str): - self.input_name = input_name + def __init__(self, + input_name: str, + ): + super().__init__(locals()) - @staticmethod - def from_struct(struct:Union[Tuple[str, str],List[str]]): - if len(struct) != 2 or struct[0].lower() != 'GraphInput'.lower(): - raise ValueError('Error parsing GraphInputReferenceSpec: "{}". The correct format is [GraphInput, Input name].'.format(struct)) - input_name = struct[1] - check_instance_type(input_name, [str]) - - return GraphInputReferenceSpec(input_name) - def to_struct(self): - return ['GraphInput', self.input_name] - - def __repr__(self): - return self.__class__.__name__ + '.from_struct(' + str(self.to_struct()) + ')' +class InputPathPlaceholder(ModelBase): #Non-standard attr names + '''Represents the command-line argument placeholder that will be replaced at run-time by a local file path pointing to a file containing the input argument value.''' + _serialized_names = { + #'input_name': 'inputPath', + 'input_name': 'file', #TODO: Rename to inputPath + } + def __init__(self, + input_name: str, + ): + super().__init__(locals()) -class TaskOutputReferenceSpec: - def __init__(self, task_id:str, output_name:str): - self.task_id = task_id - self.output_name = output_name - @staticmethod - def from_struct(struct:Union[Tuple[str, str, str],List[str]]): - if len(struct) != 3 or struct[0].lower() != 'TaskOutput'.lower(): - raise ValueError('Error parsing TaskOutputReferenceSpec: "{}". 
The correct format is [TaskOutput, Task ID, Output name].'.format(struct)) - task_id = struct[1] - output_name = struct[2] - - check_instance_type(task_id, [str]) - check_instance_type(output_name, [str]) - - return TaskOutputReferenceSpec(task_id, output_name) +class OutputPathPlaceholder(ModelBase): #Non-standard attr names + '''Represents the command-line argument placeholder that will be replaced at run-time by a local file path pointing to a file where the program should write its output data.''' + _serialized_names = { + #'output_name': 'outputPath', + 'output_name': 'output', #TODO: Rename to outputPath + } - def to_struct(self): - return ['TaskOutput', self.task_id, self.output_name] - - def __repr__(self): - return self.__class__.__name__ + '.from_struct(' + str(self.to_struct()) + ')' - - -class DataValueOrReferenceSpec: - def __init__(self, constant_value:str=None, graph_input:str=None, task_output:Tuple[str, str]=None): - if constant_value != None: - if graph_input != None or task_output != None: - raise ValueError('Specify only one argument.') - self.constant_value = constant_value - if graph_input != None: - if constant_value != None or task_output != None: - raise ValueError('Specify only one argument.') - self.graph_input = GraphInputReferenceSpec(graph_input) - if task_output != None: - if constant_value != None or graph_input != None: - raise ValueError('Specify only one argument.') - (task_id, output_name) = task_output - self.task_output = TaskOutputReferenceSpec(task_id, output_name) + def __init__(self, + output_name: str, + ): + super().__init__(locals()) - @staticmethod - def from_struct(struct:Union[str,Tuple[str, str],Tuple[str, str],List[str]]): - if isinstance(struct, tuple) or isinstance(struct, list): - kind = struct[0] - if kind.lower() == 'GraphInput'.lower(): - return DataValueOrReferenceSpec(graph_input=GraphInputReferenceSpec.from_struct(struct)) - elif kind.lower() == 'TaskOutput'.lower(): - return 
DataValueOrReferenceSpec(task_output=TaskOutputReferenceSpec.from_struct(struct)) - else: - raise ValueError('Found unknown input value spec: {}'.format(struct)) - return DataValueOrReferenceSpec(constant_value=str(struct)) - - def to_struct(self): - if self.constant_value != None: - return self.constant_value - if self.graph_input != None: - return self.graph_input.to_struct() - if self.task_output != None: - return self.task_output.to_struct() - raise AssertionError('Invalid internal state') - - def __repr__(self): - return self.__class__.__name__ + '.from_struct(' + str(self.to_struct()) + ')' +CommandlineArgumentType = Optional[Union[ + PrimitiveTypes, InputValuePlaceholder, InputPathPlaceholder, OutputPathPlaceholder, + 'ConcatPlaceholder', + 'IfPlaceholder', +]] -class TaskSpec: - def __init__(self, componet_id:str, inputValues:Mapping[str,DataValueOrReferenceSpec]=None, enabled:DataValueOrReferenceSpec=None): - if componet_id == None: - raise ValueError('componetId is required') - self.componet_id = componet_id - self.input_values = inputValues - self.enabled = enabled - @staticmethod - def from_struct(struct:Mapping): - struct = copy.deepcopy(struct) - - componet_id = str(struct.pop('componetId')) - spec = TaskSpec(componet_id) - - if 'inputValues' in struct: - spec.input_values = OrderedDict([(name, DataValueOrReferenceSpec.from_struct(value)) for name, value in struct.pop('inputValues').items()]) - if 'enabled' in struct: - spec.enabled = DataValueOrReferenceSpec.from_struct(struct.pop('enabled')) +class ConcatPlaceholder(ModelBase): #Non-standard attr names + '''Represents the command-line argument placeholder that will be replaced at run-time by the concatenated values of its items.''' + _serialized_names = { + 'items': 'concat', + } - if struct: - raise ValueError('Found unrecognized properties: {}'.format(struct)) - - return spec + def __init__(self, + items: List[CommandlineArgumentType], + ): + super().__init__(locals()) - def to_struct(self): - 
struct = OrderedDict() - - struct['componetId'] = self.componet_id - if self.input_values: - struct['inputValues'] = self.input_values.to_struct() - if self.enabled: - struct['enabled'] = self.enabled.to_struct() - - return struct - - def __repr__(self): - return self.__class__.__name__ + '.from_struct(' + str(self.to_struct()) + ')' +class IsPresentPlaceholder(ModelBase): #Non-standard attr names + '''Represents the command-line argument placeholder that will be replaced at run-time by a boolean value specifying whether the caller has passed an argument for the specified optional input.''' + _serialized_names = { + 'input_name': 'isPresent', + } -class GraphSpec: - def __init__(self, tasks:Mapping[str,TaskSpec], outputValues:Mapping[str,DataValueOrReferenceSpec]=None): - self.tasks = tasks - self.outputValues = outputValues - - @staticmethod - def from_struct(struct:Mapping): - struct = copy.deepcopy(struct) - - tasks_dict = struct.pop('tasks') - tasks = OrderedDict([(task_id, TaskSpec.from_struct(task_struct)) for task_id, task_struct in tasks_dict.items()]) + def __init__(self, + input_name: str, + ): + super().__init__(locals()) - obj = GraphSpec(tasks) - if 'outputValues' in struct: - outputValues_dict = struct.pop('outputValues') - obj.outputValues = OrderedDict([(name, DataValueOrReferenceSpec.from_struct(value_struct)) for name, value_struct in outputValues_dict.items()]) +IfConditionArgumentType = Union[bool, str, IsPresentPlaceholder, InputValuePlaceholder] - if struct: - raise ValueError('Found unrecognized properties: {}'.format(struct)) - - return obj - - def to_struct(self): - struct = OrderedDict() - if self.tasks: - struct['tasks'] = OrderedDict([(name, task_spec.to_struct()) for name, task_spec in self.tasks.items()]) - if self.outputValues: - struct['outputValues'] = OrderedDict([(name, value_spec.to_struct()) for name, value_spec in self.outputValues.items()]) - - return struct - - def __repr__(self): - return self.__class__.__name__ + 
'.from_struct(' + str(self.to_struct()) + ')' +class IfPlaceholderStructure(ModelBase): #Non-standard attr names + '''Used by the IfPlaceholder - the command-line argument placeholder that will be replaced at run-time by the expanded value of either "then_value" or "else_value" depending on the submission-time resolved value of the "cond" predicate.''' + _serialized_names = { + 'condition': 'cond', + 'then_value': 'then', + 'else_value': 'else', + } -class ImplementationSpec: - def __init__(self, container=None, graph=None): - if not container and not graph: - raise ValueError('Implementation is required') - if container and graph: - raise ValueError('Only one implementation can be specified') + def __init__(self, + condition: IfConditionArgumentType, + then_value: Union[CommandlineArgumentType, List[CommandlineArgumentType]], + else_value: Optional[Union[CommandlineArgumentType, List[CommandlineArgumentType]]] = None, + ): + super().__init__(locals()) - self.container = container - self.graph = graph - - @staticmethod - def from_struct(spec_dict:Mapping): - if len(spec_dict) != 1: - raise ValueError('There must be exactly one implementation') - - for name, value in spec_dict.items(): - if name == 'container': - return ImplementationSpec(ContainerSpec.from_struct(value)) - elif name == 'graph': - return ImplementationSpec(GraphSpec.from_struct(value)) - else: - raise ValueError('Unknown implementation type {}'.format(name)) - - def to_struct(self): - struct = {} - if self.container: - struct['container'] = self.container.to_struct() - if self.graph: - struct['graph'] = self.graph.to_struct() - - return struct - - def __repr__(self): - return self.__class__.__name__ + '.from_struct(' + str(self.to_struct()) + ')' +class IfPlaceholder(ModelBase): #Non-standard attr names + '''Represents the command-line argument placeholder that will be replaced at run-time by the expanded value of either "then_value" or "else_value" depending on the submission-time resolved value 
of the "cond" predicate.''' + _serialized_names = { + 'if_structure': 'if', + } -class SourceSpec: - def __init__(self, url:str=None): - self.url = url - - @staticmethod - def from_struct(struct:Mapping): - struct = copy.deepcopy(struct) - spec = SourceSpec() - - if 'url' in struct: - spec.url = struct.pop('url') - check_instance_type(spec.url, [str]) + def __init__(self, + if_structure: IfPlaceholderStructure, + ): + super().__init__(locals()) - if struct: - raise ValueError('Found unrecognized properties: {}'.format(struct)) - - return spec - def to_struct(self): - struct = OrderedDict() - if self.url: - struct['url'] = self.url - return struct - - def __repr__(self): - return self.__class__.__name__ + '.from_struct(' + str(self.to_struct()) + ')' +class ContainerSpec(ModelBase): + '''Describes the container component implementation.''' + _serialized_names = { + 'file_outputs': 'fileOutputs', #TODO: rename to something like legacy_unconfigurable_output_paths + } + + def __init__(self, + image: str, + command: Optional[List[CommandlineArgumentType]] = None, + args: Optional[List[CommandlineArgumentType]] = None, + env: Optional[Mapping[str, str]] = None, + file_outputs: Optional[Mapping[str, str]] = None, #TODO: rename to something like legacy_unconfigurable_output_paths + ): + super().__init__(locals()) + + +class ContainerImplementation(ModelBase): + '''Represents the container component implementation.''' + def __init__(self, + container: ContainerSpec, + ): + super().__init__(locals()) -class ComponentSpec: +ImplementationType = Union[ContainerImplementation, 'GraphImplementation'] + + +class SourceSpec(ModelBase): + '''Specifies the location of the component source code.''' + def __init__(self, + url: str = None + ): + super().__init__(locals()) + + +class ComponentSpec(ModelBase): + '''Component specification. 
Describes the metadata (name, description, source), the interface (inputs and outputs) and the implementation of the component.''' def __init__( self, - implementation:ImplementationSpec, - name:str=None, - description:str=None, - source:Mapping=None, - inputs:List[InputSpec]=None, - outputs:List[OutputSpec]=None, - version:str='google.com/cloud/pipelines/component/v1', + implementation: ImplementationType, + name: Optional[str] = None, #? Move to metadata? + description: Optional[str] = None, #? Move to metadata? + source: Optional[SourceSpec] = None, #? Move to metadata? + inputs: Optional[List[InputSpec]] = None, + outputs: Optional[List[OutputSpec]] = None, + version: Optional[str] = 'google.com/cloud/pipelines/component/v1', + #tags: Optional[Set[str]] = None, ): - if not implementation: - raise ValueError('Implementation is required') - self.version = version - self.name = name - self.description = description - self.source = source - self.inputs = inputs - self.outputs = outputs - self.implementation = implementation - - @staticmethod - def _ports_collection_from_struct(struct_collection:Union[Mapping[str,Mapping],List[Mapping]]): - if isinstance(struct_collection, dict): - port_struct_iterator = struct_collection.items() - elif isinstance(struct_collection, list): - port_struct_iterator = struct_collection - else: - check_instance_type(struct_collection, [dict, list]) - #raise ValueError('Unknown inputs collection type: {}'.format(struct.__class__.__name__)) - return list([InputSpec.from_struct(port_struct) for port_struct in port_struct_iterator]) - - @staticmethod - def from_struct(struct:Mapping): - struct = copy.deepcopy(struct) - - implementation_struct = struct.pop('implementation') - implementation_spec = ImplementationSpec.from_struct(implementation_struct) - - spec = ComponentSpec(implementation_spec) - - if 'version' in struct: - spec.version = struct.pop('version') - if 'name' in struct: - spec.name = struct.pop('name') - if 'description' in 
struct: - spec.description = struct.pop('description') - if 'source' in struct: - spec.source = SourceSpec.from_struct(struct.pop('source')) - if 'inputs' in struct: - inputs_struct = struct.pop('inputs') - if isinstance(inputs_struct, dict): - input_iterator = inputs_struct.items() - elif isinstance(inputs_struct, list): - input_iterator = inputs_struct - else: - check_instance_type(inputs_struct, [dict, list]) - #raise ValueError('Unknown inputs collection type: {}'.format(inputs_struct.__class__.__name__)) - spec.inputs = [InputSpec.from_struct(input_struct) for input_struct in input_iterator] - if 'outputs' in struct: - outputs_struct = struct.pop('outputs') - if isinstance(outputs_struct, dict): - output_iterator = outputs_struct.items() - elif isinstance(outputs_struct, list): - output_iterator = outputs_struct - else: - check_instance_type(outputs_struct, [dict, list]) - #raise ValueError('Unknown outputs collection type: {}'.format(outputs_struct.__class__.__name__)) - spec.outputs = [OutputSpec.from_struct(output_struct) for output_struct in output_iterator] - - if struct: - raise ValueError('Found unrecognized properties: {}'.format(struct)) - - return spec + super().__init__(locals()) + self._post_init() - def to_struct(self): - struct = OrderedDict() - - if self.version: - struct['version'] = self.version - if self.name: - struct['name'] = self.name - if self.description: - struct['description'] = self.description - if self.source: - struct['source'] = self.source.to_struct() + def _post_init(self): + #Checking input names for uniqueness + self._inputs_dict = {} if self.inputs: - struct['inputs'] = [input.to_struct() for input in self.inputs] + for input in self.inputs: + if input.name in self._inputs_dict: + raise ValueError('Non-unique input name "{}"'.format(input.name)) + self._inputs_dict[input.name] = input + + #Checking output names for uniqueness + self._outputs_dict = {} if self.outputs: - struct['outputs'] = [output.to_struct() for output in 
self.outputs] - if self.implementation: - struct['implementation'] = self.implementation.to_struct() - - return struct + for output in self.outputs: + if output.name in self._outputs_dict: + raise ValueError('Non-unique output name "{}"'.format(output.name)) + self._outputs_dict[output.name] = output + + if isinstance(self.implementation, ContainerImplementation): + container = self.implementation.container + + if container.file_outputs: + for output_name, path in container.file_outputs.items(): + if output_name not in self._outputs_dict: + raise TypeError('Unconfigurable output entry "{}" references non-existing output.'.format({output_name: path})) + + def verify_arg(arg): + if arg is None: + pass + elif isinstance(arg, (str, int, float, bool)): + pass + elif isinstance(arg, list): + for arg2 in arg: + verify_arg(arg2) + elif isinstance(arg, (InputValuePlaceholder, InputPathPlaceholder, IsPresentPlaceholder)): + if arg.input_name not in self._inputs_dict: + raise TypeError('Argument "{}" references non-existing input.'.format(arg)) + elif isinstance(arg, OutputPathPlaceholder): + if arg.output_name not in self._outputs_dict: + raise TypeError('Argument "{}" references non-existing output.'.format(arg)) + elif isinstance(arg, ConcatPlaceholder): + for arg2 in arg.items: + verify_arg(arg2) + elif isinstance(arg, IfPlaceholder): + verify_arg(arg.if_structure.condition) + verify_arg(arg.if_structure.then_value) + verify_arg(arg.if_structure.else_value) + else: + raise TypeError('Unexpected argument "{}"'.format(arg)) + + verify_arg(container.command) + verify_arg(container.args) + + if isinstance(self.implementation, GraphImplementation): + graph = self.implementation.graph + + if graph.output_values is not None: + for output_name, argument in graph.output_values.items(): + if output_name not in self._outputs_dict: + raise TypeError('Graph output argument entry "{}" references non-existing output.'.format({output_name: argument})) + + if graph.tasks is not None: + 
for task in graph.tasks.values(): + if task.arguments is not None: + for argument in task.arguments.values(): + if isinstance(argument, GraphInputArgument) and argument.input_name not in self._inputs_dict: + raise TypeError('Argument "{}" references non-existing input.'.format(argument)) + + +class ComponentReference(ModelBase): + '''Component reference. Contains information that can be used to locate and load a component by name, digest or URL''' + def __init__(self, + name: Optional[str] = None, + digest: Optional[str] = None, + tag: Optional[str] = None, + url: Optional[str] = None, + ): + super().__init__(locals()) + self._post_init() + + def _post_init(self) -> None: + if not any([self.name, self.digest, self.tag, self.url]): + raise TypeError('Need at least one argument.') + + +class GraphInputArgument(ModelBase): + '''Represents the component argument value that comes from the graph component input.''' + _serialized_names = { + 'input_name': 'graphInput', + } + + def __init__(self, + input_name: str, + ): + super().__init__(locals()) + + +class TaskOutputReference(ModelBase): + '''References the output of some task (the scope is a single graph).''' + _serialized_names = { + 'task_id': 'taskId', + 'output_name': 'outputName', + } + + def __init__(self, + task_id: str, + output_name: str, + ): + super().__init__(locals()) + + +class TaskOutputArgument(ModelBase): #Has additional constructor for convenience + '''Represents the component argument value that comes from the output of another task.''' + _serialized_names = { + 'task_output': 'taskOutput', + } + + def __init__(self, + task_output: TaskOutputReference, + ): + super().__init__(locals()) + + @staticmethod + def construct( + task_id: str, + output_name: str, + ) -> 'TaskOutputArgument': + return TaskOutputArgument(TaskOutputReference( + task_id=task_id, + output_name=output_name, + )) + + +ArgumentType = Union[PrimitiveTypes, GraphInputArgument, TaskOutputArgument] + + +class TwoOperands(ModelBase): + 
def __init__(self, + op1: ArgumentType, + op2: ArgumentType, + ): + super().__init__(locals()) + + +class BinaryPredicate(ModelBase): #abstract base type + def __init__(self, + operands: TwoOperands + ): + super().__init__(locals()) + + +class EqualsPredicate(BinaryPredicate): + '''Represents the "equals" comparison predicate.''' + _serialized_names = {'operands': '=='} + + +class NotEqualsPredicate(BinaryPredicate): + '''Represents the "not equals" comparison predicate.''' + _serialized_names = {'operands': '!='} + + +class GreaterThanPredicate(BinaryPredicate): + '''Represents the "greater than" comparison predicate.''' + _serialized_names = {'operands': '>'} + + +class GreaterThanOrEqualPredicate(BinaryPredicate): + '''Represents the "greater than or equal" comparison predicate.''' + _serialized_names = {'operands': '>='} + + +class LessThenPredicate(BinaryPredicate): + '''Represents the "less than" comparison predicate.''' + _serialized_names = {'operands': '<'} + + +class LessThenOrEqualPredicate(BinaryPredicate): + '''Represents the "less than or equal" comparison predicate.''' + _serialized_names = { 'operands': '<='} + + +PredicateType = Union[ + ArgumentType, + EqualsPredicate, NotEqualsPredicate, GreaterThanPredicate, GreaterThanOrEqualPredicate, LessThenPredicate, LessThenOrEqualPredicate, + 'NotPredicate', 'AndPredicate', 'OrPredicate', +] + + +class TwoBooleanOperands(ModelBase): + def __init__(self, + op1: PredicateType, + op2: PredicateType, + ): + super().__init__(locals()) + + +class NotPredicate(ModelBase): + '''Represents the "not" logical operation.''' + _serialized_names = {'operand': 'not'} + + def __init__(self, + operand: PredicateType + ): + super().__init__(locals()) + + +class AndPredicate(ModelBase): + '''Represents the "and" logical operation.''' + _serialized_names = {'operands': 'and'} + + def __init__(self, + operands: TwoBooleanOperands + ) : + super().__init__(locals()) + +class OrPredicate(ModelBase): + '''Represents the "or" 
logical operation.''' + _serialized_names = {'operands': 'or'} + + def __init__(self, + operands: TwoBooleanOperands + ): + super().__init__(locals()) + + +class TaskSpec(ModelBase): + '''Task specification. Task is a "configured" component - a component supplied with arguments and other applied configuration changes.''' + _serialized_names = { + 'component_ref': 'componentRef', + 'is_enabled': 'isEnabled', + 'k8s_container_options': 'k8sContainerOptions', + 'k8s_pod_options': 'k8sPodOptions', + } + + def __init__(self, + component_ref: ComponentReference, + arguments: Optional[Mapping[str, ArgumentType]] = None, + is_enabled: Optional[PredicateType] = None, + k8s_container_options: Optional[v1.Container] = None, + k8s_pod_options: Optional[v1.PodArgoSubset] = None, + ): + super().__init__(locals()) + + +class GraphSpec(ModelBase): + '''Describes the graph component implementation. It represents a graph of component tasks connected to the upstream sources of data using the argument specifications. 
It also describes the sources of graph output values.''' + _serialized_names = { + 'output_values': 'outputValues', + } + + def __init__(self, + tasks: Mapping[str, TaskSpec], + output_values: Mapping[str, ArgumentType] = None, + ): + super().__init__(locals()) + self._post_init() - def __repr__(self): - return self.__class__.__name__ + '.from_struct(' + str(self.to_struct()) + ')' + def _post_init(self): + #Checking task output references and preparing the dependency table + task_dependencies = {} + for task_id, task in self.tasks.items(): + dependencies = set() + task_dependencies[task_id] = dependencies + if task.arguments is not None: + for argument in task.arguments.values(): + if isinstance(argument, TaskOutputArgument): + dependencies.add(argument.task_output.task_id) + if argument.task_output.task_id not in self.tasks: + raise TypeError('Argument "{}" references non-existing task.'.format(argument)) + + #Topologically sorting tasks to detect cycles + task_dependents = {k: set() for k in task_dependencies.keys()} + for task_id, dependencies in task_dependencies.items(): + for dependency in dependencies: + task_dependents[dependency].add(task_id) + task_number_of_remaining_dependencies = {k: len(v) for k, v in task_dependencies.items()} + sorted_tasks = OrderedDict() + def process_task(task_id): + if task_number_of_remaining_dependencies[task_id] == 0 and task_id not in sorted_tasks: + sorted_tasks[task_id] = self.tasks[task_id] + for dependent_task in task_dependents[task_id]: + task_number_of_remaining_dependencies[dependent_task] = task_number_of_remaining_dependencies[dependent_task] - 1 + process_task(dependent_task) + for task_id in task_dependencies.keys(): + process_task(task_id) + if len(sorted_tasks) != len(task_dependencies): + tasks_with_unsatisfied_dependencies = {k: v for k, v in task_number_of_remaining_dependencies.items() if v > 0} + task_wth_minimal_number_of_unsatisfied_dependencies = min(tasks_with_unsatisfied_dependencies.keys(), 
key=lambda task_id: tasks_with_unsatisfied_dependencies[task_id]) + raise ValueError('Task "{}" has cyclical dependency.'.format(task_wth_minimal_number_of_unsatisfied_dependencies)) + + self._toposorted_tasks = sorted_tasks + + +class GraphImplementation(ModelBase): + '''Represents the graph component implementation.''' + def __init__(self, + graph: GraphSpec, + ): + super().__init__(locals()) + + +class PipelineRunSpec(ModelBase): + '''The object that can be sent to the backend to start a new Run.''' + _serialized_names = { + 'root_task': 'rootTask', + #'on_exit_task': 'onExitTask', + } + + def __init__(self, + root_task: TaskSpec, + #on_exit_task: Optional[TaskSpec] = None, + ): + super().__init__(locals()) diff --git a/sdk/python/kfp/components/modelbase.py b/sdk/python/kfp/components/modelbase.py new file mode 100644 index 00000000000..dae2ce685b9 --- /dev/null +++ b/sdk/python/kfp/components/modelbase.py @@ -0,0 +1,284 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +__all__ = [ + 'ModelBase', +] + +import inspect +from collections import abc, OrderedDict +from typing import Any, Callable, Dict, List, Mapping, MutableMapping, MutableSequence, Optional, Sequence, Tuple, Type, TypeVar, Union, cast, get_type_hints + + +T = TypeVar('T') + + +def verify_object_against_type(x: Any, typ: Type[T]) -> T: + '''Verifies that the object is compatible to the specified type (types from the typing package can be used).''' + #TODO: Merge with parse_object_from_struct_based_on_type which has almost the same code + if typ is type(None): + if x is None: + return x + else: + raise TypeError('Error: Object "{}" is not None.'.format(x)) + + if typ is Any or type(typ) is TypeVar: + return x + + try: #isinstance can fail for generics + if isinstance(x, typ): + return cast(typ, x) + except: + pass + + if hasattr(typ, '__origin__'): #Handling generic types + if typ.__origin__ is Union: #Optional == Union + exception_map = {} + possible_types = typ.__args__ + if type(None) in possible_types and x is None: #Shortcut for Optional[] tests. Can be removed, but the exceptions will be more noisy. 
+ return x + for possible_type in possible_types: + try: + verify_object_against_type(x, possible_type) + return x + except Exception as ex: + exception_map[possible_type] = ex + pass + #exception_lines = ['Exception for type {}: {}.'.format(t, e) for t, e in exception_map.items()] + exception_lines = [str(e) for t, e in exception_map.items()] + exception_lines.append('Error: Object "{}" is incompatible with type "{}".'.format(x, typ)) + raise TypeError('\n'.join(exception_lines)) + + #not Union => not None + if x is None: + raise TypeError('Error: None object is incompatible with type {}'.format(typ)) + + #assert isinstance(x, typ.__origin__) + generic_type = typ.__origin__ or getattr(typ, '__extra__', None) #In python <3.7 typing.List.__origin__ == None; Python 3.7 has working __origin__, but no __extra__ TODO: Remove the __extra__ once we move to Python 3.7 + if generic_type in [list, List, abc.Sequence, abc.MutableSequence, Sequence, MutableSequence] and type(x) is not str: #! str is also Sequence + if not isinstance(x, generic_type): + raise TypeError('Error: Object "{}" is incompatible with type "{}"'.format(x, typ)) + type_args = typ.__args__ if typ.__args__ is not None else (Any, Any) #Workaround for Python <3.7 (where Mapping.__args__ is None) + inner_type = type_args[0] + for item in x: + verify_object_against_type(item, inner_type) + return x + + elif generic_type in [dict, Dict, abc.Mapping, abc.MutableMapping, Mapping, MutableMapping, OrderedDict]: + if not isinstance(x, generic_type): + raise TypeError('Error: Object "{}" is incompatible with type "{}"'.format(x, typ)) + type_args = typ.__args__ if typ.__args__ is not None else (Any, Any) #Workaround for Python <3.7 (where Mapping.__args__ is None) + inner_key_type = type_args[0] + inner_value_type = type_args[1] + for k, v in x.items(): + verify_object_against_type(k, inner_key_type) + verify_object_against_type(v, inner_value_type) + return x + + else: + raise TypeError('Error: Unsupported generic 
type "{}". type.__origin__ or type.__extra__ == "{}"'.format(typ, generic_type)) + + raise TypeError('Error: Object "{}" is incompatible with type "{}"'.format(x, typ)) + + +def parse_object_from_struct_based_on_type(struct: Any, typ: Type[T]) -> T: + '''Constructs an object from structure (usually dict) based on type. Supports list and dict types from the typing package plus Optional[] and Union[] types. + If some type is a class that has .from_struct class method, that method is used for object construction. + ''' + if typ is type(None): + if struct is None: + return None + else: + raise TypeError('Error: Structure "{}" is not None.'.format(struct)) + + if typ is Any or type(typ) is TypeVar: + return struct + + try: #isinstance can fail for generics + #if (isinstance(struct, typ) + # and not (typ is Sequence and type(struct) is str) #! str is also Sequence + # and not (typ is int and type(struct) is bool) #! bool is int + #): + if type(struct) is typ: + return struct + except: + pass + + if hasattr(typ, 'from_struct'): + try: #More informative errors + return typ.from_struct(struct) + except Exception as ex: + raise TypeError('Error: {}.from_struct(struct={}) failed with exception:\n{}'.format(typ.__name__, struct, str(ex))) + if hasattr(typ, '__origin__'): #Handling generic types + if typ.__origin__ is Union: #Optional == Union + results = {} + exception_map = {} + possible_types = list(typ.__args__) + #if type(None) in possible_types and struct is None: #Shortcut for Optional[] tests. Can be removed, but the exceptions will be more noisy. + # return None + #Hack for Python <3.7 which for some reason "simplifies" Union[bool, int, ...] to just Union[int, ...] 
+ if int in possible_types: + possible_types = possible_types + [bool] + for possible_type in possible_types: + try: + obj = parse_object_from_struct_based_on_type(struct, possible_type) + results[possible_type] = obj + except Exception as ex: + exception_map[possible_type] = ex + pass + + #Single successful parsing. + if len(results) == 1: + return list(results.values())[0] + + if len(results) > 1: + raise TypeError('Error: Structure "{}" is ambiguous. It can be parsed to multiple types: {}.'.format(struct, list(results.keys()))) + + exception_lines = [str(e) for t, e in exception_map.items()] + exception_lines.append('Error: Structure "{}" is incompatible with type "{}" - none of the types in Union are compatible.'.format(struct, typ)) + raise TypeError('\n'.join(exception_lines)) + #not Union => not None + if struct is None: + raise TypeError('Error: None structure is incompatible with type {}'.format(typ)) + + #assert isinstance(x, typ.__origin__) + generic_type = typ.__origin__ or getattr(typ, '__extra__', None) #In python <3.7 typing.List.__origin__ == None; Python 3.7 has working __origin__, but no __extra__ TODO: Remove the __extra__ once we move to Python 3.7 + if generic_type in [list, List, abc.Sequence, abc.MutableSequence, Sequence, MutableSequence] and type(struct) is not str: #! 
str is also Sequence + if not isinstance(struct, generic_type): + raise TypeError('Error: Structure "{}" is incompatible with type "{}" - it does not have list type.'.format(struct, typ)) + type_args = typ.__args__ if typ.__args__ is not None else (Any, Any) #Workaround for Python <3.7 (where Mapping.__args__ is None) + inner_type = type_args[0] + return [parse_object_from_struct_based_on_type(item, inner_type) for item in struct] + + elif generic_type in [dict, Dict, abc.Mapping, abc.MutableMapping, Mapping, MutableMapping, OrderedDict]: #in Python <3.7 there is a difference between abc.Mapping and typing.Mapping + if not isinstance(struct, generic_type): + raise TypeError('Error: Structure "{}" is incompatible with type "{}" - it does not have dict type.'.format(struct, typ)) + type_args = typ.__args__ if typ.__args__ is not None else (Any, Any) #Workaround for Python <3.7 (where Mapping.__args__ is None) + inner_key_type = type_args[0] + inner_value_type = type_args[1] + return {parse_object_from_struct_based_on_type(k, inner_key_type): parse_object_from_struct_based_on_type(v, inner_value_type) for k, v in struct.items()} + + else: + raise TypeError('Error: Unsupported generic type "{}". type.__origin__ or type.__extra__ == "{}"'.format(typ, generic_type)) + + raise TypeError('Error: Structure "{}" is incompatible with type "{}". Structure is not the instance of the type, the type does not have .from_struct method and is not generic.'.format(struct, typ)) + + +def convert_object_to_struct(obj, serialized_names: Mapping[str, str] = {}): + '''Converts an object to structure (usually a dict). + Serializes all properties that do not start with underscores. + If the type of some property is a class that has .to_struct class method, that method is used for conversion. + Used by the ModelBase class. 
+ ''' + signature = inspect.signature(obj.__init__) #Needed for default values + result = {} + for python_name, value in obj.__dict__.items(): #TODO: Should we take the fields from the constructor parameters instead? #TODO: Make it possible to specify the field ordering + if python_name.startswith('_'): + continue + attr_name = serialized_names.get(python_name, python_name) + if hasattr(value, "to_struct"): + result[attr_name] = value.to_struct() + elif isinstance(value, list): + result[attr_name] = [(x.to_struct() if hasattr(x, 'to_struct') else x) for x in value] + elif isinstance(value, dict): + result[attr_name] = {k: (v.to_struct() if hasattr(v, 'to_struct') else v) for k, v in value.items()} + else: + param = signature.parameters.get(python_name, None) + if param is None or param.default == inspect.Parameter.empty or value != param.default: + result[attr_name] = value + + return result + + +def parse_object_from_struct_based_on_class_init(cls : Type[T], struct: Mapping, serialized_names: Mapping[str, str] = {}) -> T: + '''Constructs an object of specified class from structure (usually dict) using the class.__init__ method. + Converts all constructor arguments to appropriate types based on the __init__ type hints. + Used by the ModelBase class. + + Arguments: + + serialized_names: specifies the mapping between __init__ parameter names and the structure key names for cases where these names are different (due to language syntax clashes or style differences). + ''' + parameter_types = get_type_hints(cls.__init__) #Properlty resolves forward references + + serialized_names_to_pythonic = {v: k for k, v in serialized_names.items()} + #If a pythonic name has a different original name, we forbid the pythonic name in the structure. 
Otherwise, this function would accept "python-styled" structures that should be invalid + forbidden_struct_keys = set(serialized_names_to_pythonic.values()).difference(serialized_names_to_pythonic.keys()) + args = {} + for original_name, value in struct.items(): + if original_name in forbidden_struct_keys: + raise ValueError('Use "{}" key instead of pythonic key "{}" in the structure: {}.'.format(serialized_names[original_name], original_name, struct)) + python_name = serialized_names_to_pythonic.get(original_name, original_name) + param_type = parameter_types.get(python_name, None) + if param_type is not None: + args[python_name] = parse_object_from_struct_based_on_type(value, param_type) + else: + args[python_name] = value + + return cls(**args) + + +class ModelBase: + '''Base class for types that can be converted to JSON-like dict structures or constructed from such structures. + The object fields, their types and default values are taken from the __init__ method arguments. + Override the _serialized_names mapping to control the key names of the serialized structures. + + The derived class objects will have the .from_struct and .to_struct methods for conversion to or from structure. The base class constructor accepts the arguments map, checks the argument types and sets the object field values. + + Example derived class: + + class TaskSpec(ModelBase): + _serialized_names = { + 'component_ref': 'componentRef', + 'is_enabled': 'isEnabled', + } + + def __init__(self, + component_ref: ComponentReference, + arguments: Optional[Mapping[str, ArgumentType]] = None, + is_enabled: Optional[Union[ArgumentType, EqualsPredicate, NotEqualsPredicate]] = None, #Optional property with default value + ): + super().__init__(locals()) #Calling the ModelBase constructor to check the argument types and set the object field values. 
+ + task_spec = TaskSpec.from_struct("{'componentRef': {...}, 'isEnabled: {'and': {...}}}") # = instance of TaskSpec + task_struct = task_spec.to_struct() #= "{'componentRef': {...}, 'isEnabled: {'and': {...}}}" + ''' + _serialized_names = {} + def __init__(self, args): + parameter_types = get_type_hints(self.__class__.__init__) + field_values = {k: v for k, v in args.items() if k != 'self' and not k.startswith('_')} + for k, v in field_values.items(): + parameter_type = parameter_types.get(k, None) + if parameter_type is not None: + verify_object_against_type(v, parameter_type) + self.__dict__.update(field_values) + + @classmethod + def from_struct(cls: Type[T], struct: Mapping) -> T: + return parse_object_from_struct_based_on_class_init(cls, struct, serialized_names=cls._serialized_names) + + def to_struct(self) -> Mapping: + return convert_object_to_struct(self, serialized_names=self._serialized_names) + + def _get_field_names(self): + return list(inspect.signature(self.__init__).parameters) + + def __repr__(self): + return self.__class__.__name__ + '(' + ', '.join(param + '=' + repr(getattr(self, param)) for param in self._get_field_names()) + ')' + + def __eq__(self, other): + return self.__class__ == other.__class__ and {k: getattr(self, k) for k in self._get_field_names()} == {k: getattr(self, k) for k in other._get_field_names()} + + def __ne__(self, other): + return not self == other \ No newline at end of file diff --git a/sdk/python/kfp/components/structures/__init__.py b/sdk/python/kfp/components/structures/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/kfp/components/structures/kubernetes/__init__.py b/sdk/python/kfp/components/structures/kubernetes/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/sdk/python/kfp/components/structures/kubernetes/v1.py b/sdk/python/kfp/components/structures/kubernetes/v1.py new file mode 100644 index 00000000000..9f61ec64140 --- /dev/null +++ 
b/sdk/python/kfp/components/structures/kubernetes/v1.py @@ -0,0 +1,455 @@ +__all__ = [ + 'Container', + 'PodArgoSubset', +] + + +from collections import OrderedDict + +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union + +from ...modelbase import ModelBase + + +class EnvVar(ModelBase): + _serialized_names = { + 'value_from': 'valueFrom', + } + def __init__(self, + name: str, + value: Optional[str] = None, + #value_from: Optional[EnvVarSource] = None, #TODO: Add if needed + ): + super().__init__(locals()) + + +class ExecAction(ModelBase): + def __init__(self, + command: List[str], + ): + super().__init__(locals()) + + +class Handler(ModelBase): + _serialized_names = { + 'http_get': 'httpGet', + 'tcp_socket': 'tcpSocket', + } + def __init__(self, + exec: Optional[ExecAction] = None, + #http_get: Optional[HTTPGetAction] = None, #TODO: Add if needed + #tcp_socket: Optional[TCPSocketAction] = None, #TODO: Add if needed + ): + super().__init__(locals()) + + +class Lifecycle(ModelBase): + _serialized_names = { + 'post_start': 'postStart', + 'pre_stop': 'preStop', + } + def __init__(self, + post_start: Optional[Handler] = None, + pre_stop: Optional[Handler] = None, + ): + super().__init__(locals()) + + +class VolumeMount(ModelBase): + _serialized_names = { + 'mount_path': 'mountPath', + 'mount_propagation': 'mountPropagation', + 'read_only': 'readOnly', + 'sub_path': 'subPath', + } + def __init__(self, + name: str, + mount_path: str, + mount_propagation: Optional[str] = None, + read_only: Optional[bool] = None, + sub_path: Optional[str] = None, + ): + super().__init__(locals()) + + +class ResourceRequirements(ModelBase): + def __init__(self, + limits: Optional[Dict[str, str]] = None, + requests: Optional[Dict[str, str]] = None, + ): + super().__init__(locals()) + + +class ContainerPort(ModelBase): + _serialized_names = { + 'container_port': 'containerPort', + 'host_ip': 'hostIP', + 'host_port': 'hostPort', + } + def __init__(self, + container_port: int, 
+ host_ip: Optional[str] = None, + host_port: Optional[int] = None, + name: Optional[str] = None, + protocol: Optional[str] = None, + ): + super().__init__(locals()) + + +class VolumeDevice(ModelBase): + _serialized_names = { + 'device_path': 'devicePath', + } + def __init__(self, + device_path: str, + name: str, + ): + super().__init__(locals()) + + +class Probe(ModelBase): + _serialized_names = { + 'failure_threshold': 'failureThreshold', + 'http_get': 'httpGet', + 'initial_delay_seconds': 'initialDelaySeconds', + 'period_seconds': 'periodSeconds', + 'success_threshold': 'successThreshold', + 'tcp_socket': 'tcpSocket', + 'timeout_seconds': 'timeoutSeconds' + } + def __init__(self, + exec: Optional[ExecAction] = None, + failure_threshold: Optional[int] = None, + #http_get: Optional[HTTPGetAction] = None, #TODO: Add if needed + initial_delay_seconds: Optional[int] = None, + period_seconds: Optional[int] = None, + success_threshold: Optional[int] = None, + #tcp_socket: Optional[TCPSocketAction] = None, #TODO: Add if needed + timeout_seconds: Optional[int] = None, + ): + super().__init__(locals()) + + +class SecurityContext(ModelBase): + _serialized_names = { + 'allow_privilege_escalation': 'allowPrivilegeEscalation', + 'capabilities': 'capabilities', + 'privileged': 'privileged', + 'read_only_root_filesystem': 'readOnlyRootFilesystem', + 'run_as_group': 'runAsGroup', + 'run_as_non_root': 'runAsNonRoot', + 'run_as_user': 'runAsUser', + 'se_linux_options': 'seLinuxOptions' + } + def __init__(self, + allow_privilege_escalation: Optional[bool] = None, + #capabilities: Optional[Capabilities] = None, #TODO: Add if needed + privileged: Optional[bool] = None, + read_only_root_filesystem: Optional[bool] = None, + run_as_group: Optional[int] = None, + run_as_non_root: Optional[bool] = None, + run_as_user: Optional[int] = None, + #se_linux_options: Optional[SELinuxOptions] = None, #TODO: Add if needed + ): + super().__init__(locals()) + + +class Container(ModelBase): + 
_serialized_names = { + 'env_from': 'envFrom', + 'image_pull_policy': 'imagePullPolicy', + 'liveness_probe': 'livenessProbe', + 'readiness_probe': 'readinessProbe', + 'security_context': 'securityContext', + 'stdin_once': 'stdinOnce', + 'termination_message_path': 'terminationMessagePath', + 'termination_message_policy': 'terminationMessagePolicy', + 'volume_devices': 'volumeDevices', + 'volume_mounts': 'volumeMounts', + 'working_dir': 'workingDir', + } + + def __init__(self, + #Better to set at Component level + image: Optional[str] = None, + command: Optional[List[str]] = None, + args: Optional[List[str]] = None, + env: Optional[List[EnvVar]] = None, + + working_dir: Optional[str] = None, #Not really needed: container should have proper working dir set up + + lifecycle: Optional[Lifecycle] = None, #Can be used to specify pre-exit commands to run TODO: Probably support at Component level. + + #Better to set at Task level + volume_mounts: Optional[List[VolumeMount]] = None, + resources: Optional[ResourceRequirements] = None, + + #Might not be used a lot + ports: Optional[List[ContainerPort]] = None, + #env_from: Optional[List[EnvFromSource]] = None, #TODO: Add if needed + volume_devices: Optional[List[VolumeDevice]] = None, + + #Probably not needed + name: Optional[str] = None, #Required by k8s schema, but not Argo. 
+ image_pull_policy: Optional[str] = None, + liveness_probe: Optional[Probe] = None, + readiness_probe: Optional[Probe] = None, + security_context: Optional[SecurityContext] = None, + stdin: Optional[bool] = None, + stdin_once: Optional[bool] = None, + termination_message_path: Optional[str] = None, + termination_message_policy: Optional[str] = None, + tty: Optional[bool] = None, + ): + super().__init__(locals()) + + +#class NodeAffinity(ModelBase): +# _serialized_names = { +# 'preferred_during_scheduling_ignored_during_execution': 'preferredDuringSchedulingIgnoredDuringExecution', +# 'required_during_scheduling_ignored_during_execution': 'requiredDuringSchedulingIgnoredDuringExecution', +# } +# def __init__(self, +# preferred_during_scheduling_ignored_during_execution: Optional[List[PreferredSchedulingTerm]] = None, +# required_during_scheduling_ignored_during_execution: Optional[NodeSelector] = None, +# ): +# super().__init__(locals()) + + +#class Affinity(ModelBase): +# _serialized_names = { +# 'node_affinity': 'nodeAffinity', +# 'pod_affinity': 'podAffinity', +# 'pod_anti_affinity': 'podAntiAffinity', +# } +# def __init__(self, +# node_affinity: Optional[NodeAffinity] = None, +# #pod_affinity: Optional[PodAffinity] = None, #TODO: Add if needed +# #pod_anti_affinity: Optional[PodAntiAffinity] = None, #TODO: Add if needed +# ): +# super().__init__(locals()) + + +class Toleration(ModelBase): + _serialized_names = { + 'toleration_seconds': 'tolerationSeconds', + } + def __init__(self, + effect: Optional[str] = None, + key: Optional[str] = None, + operator: Optional[str] = None, + toleration_seconds: Optional[int] = None, + value: Optional[str] = None, + ): + super().__init__(locals()) + + +class KeyToPath(ModelBase): + def __init__(self, + key: str, + path: str, + mode: Optional[int] = None, + ): + super().__init__(locals()) + + +class SecretVolumeSource(ModelBase): + _serialized_names = { + 'default_mode': 'defaultMode', + 'secret_name': 'secretName' + } + def 
__init__(self, + default_mode: Optional[int] = None, + items: Optional[List[KeyToPath]] = None, + optional: Optional[bool] = None, + secret_name: Optional[str] = None, + ): + super().__init__(locals()) + + +class NFSVolumeSource(ModelBase): + _serialized_names = { + 'read_only': 'readOnly', + } + def __init__(self, + path: str, + server: str, + read_only: Optional[bool] = None, + ): + super().__init__(locals()) + + +class PersistentVolumeClaimVolumeSource(ModelBase): + _serialized_names = { + 'claim_name': 'claimName', + 'read_only': 'readOnly' + } + def __init__(self, + claim_name: str, + read_only: Optional[bool] = None, + ): + super().__init__(locals()) + + +class Volume(ModelBase): + _serialized_names = { + 'aws_elastic_block_store': 'awsElasticBlockStore', + 'azure_disk': 'azureDisk', + 'azure_file': 'azureFile', + 'cephfs': 'cephfs', + 'cinder': 'cinder', + 'config_map': 'configMap', + 'downward_api': 'downwardAPI', + 'empty_dir': 'emptyDir', + 'fc': 'fc', + 'flex_volume': 'flexVolume', + 'flocker': 'flocker', + 'gce_persistent_disk': 'gcePersistentDisk', + 'git_repo': 'gitRepo', + 'glusterfs': 'glusterfs', + 'host_path': 'hostPath', + 'iscsi': 'iscsi', + 'name': 'name', + 'nfs': 'nfs', + 'persistent_volume_claim': 'persistentVolumeClaim', + 'photon_persistent_disk': 'photonPersistentDisk', + 'portworx_volume': 'portworxVolume', + 'projected': 'projected', + 'quobyte': 'quobyte', + 'rbd': 'rbd', + 'scale_io': 'scaleIO', + 'secret': 'secret', + 'storageos': 'storageos', + 'vsphere_volume': 'vsphereVolume' + } + + def __init__(self, + name: str, + secret: Optional[SecretVolumeSource] = None, + nfs: Optional[NFSVolumeSource] = None, + persistent_volume_claim: Optional[PersistentVolumeClaimVolumeSource] = None, + + #No validation for these volume types + aws_elastic_block_store: Optional[Mapping] = None, #AWSElasticBlockStoreVolumeSource, + azure_disk: Optional[Mapping] = None, #AzureDiskVolumeSource, + azure_file: Optional[Mapping] = None, 
#AzureFileVolumeSource, + cephfs: Optional[Mapping] = None, #CephFSVolumeSource, + cinder: Optional[Mapping] = None, #CinderVolumeSource, + config_map: Optional[Mapping] = None, #ConfigMapVolumeSource, + downward_api: Optional[Mapping] = None, #DownwardAPIVolumeSource, + empty_dir: Optional[Mapping] = None, #EmptyDirVolumeSource, + fc: Optional[Mapping] = None, #FCVolumeSource, + flex_volume: Optional[Mapping] = None, #FlexVolumeSource, + flocker: Optional[Mapping] = None, #FlockerVolumeSource, + gce_persistent_disk: Optional[Mapping] = None, #GCEPersistentDiskVolumeSource, + git_repo: Optional[Mapping] = None, #GitRepoVolumeSource, + glusterfs: Optional[Mapping] = None, #GlusterfsVolumeSource, + host_path: Optional[Mapping] = None, #HostPathVolumeSource, + iscsi: Optional[Mapping] = None, #ISCSIVolumeSource, + photon_persistent_disk: Optional[Mapping] = None, #PhotonPersistentDiskVolumeSource, + portworx_volume: Optional[Mapping] = None, #PortworxVolumeSource, + projected: Optional[Mapping] = None, #ProjectedVolumeSource, + quobyte: Optional[Mapping] = None, #QuobyteVolumeSource, + rbd: Optional[Mapping] = None, #RBDVolumeSource, + scale_io: Optional[Mapping] = None, #ScaleIOVolumeSource, + storageos: Optional[Mapping] = None, #StorageOSVolumeSource, + vsphere_volume: Optional[Mapping] = None, #VsphereVirtualDiskVolumeSource, + ): + super().__init__(locals()) + + +class PodSpecArgoSubset(ModelBase): + _serialized_names = { + 'active_deadline_seconds': 'activeDeadlineSeconds', + 'affinity': 'affinity', + #'automount_service_account_token': 'automountServiceAccountToken', + #'containers': 'containers', + #'dns_config': 'dnsConfig', + #'dns_policy': 'dnsPolicy', + #'host_aliases': 'hostAliases', + #'host_ipc': 'hostIPC', + #'host_network': 'hostNetwork', + #'host_pid': 'hostPID', + #'hostname': 'hostname', + #'image_pull_secrets': 'imagePullSecrets', + #'init_containers': 'initContainers', + #'node_name': 'nodeName', + 'node_selector': 'nodeSelector', + #'priority': 
'priority', + #'priority_class_name': 'priorityClassName', + #'readiness_gates': 'readinessGates', + #'restart_policy': 'restartPolicy', + #'scheduler_name': 'schedulerName', + #'security_context': 'securityContext', + #'service_account': 'serviceAccount', + #'service_account_name': 'serviceAccountName', + #'share_process_namespace': 'shareProcessNamespace', + #'subdomain': 'subdomain', + #'termination_grace_period_seconds': 'terminationGracePeriodSeconds', + 'tolerations': 'tolerations', + 'volumes': 'volumes', + } + def __init__(self, + active_deadline_seconds: Optional[int] = None, + affinity: Optional[Mapping] = None, #Affinity, #No validation + #automount_service_account_token: Optional[bool] = None, #Not supported by Argo + #containers: Optional[List[Container]] = None, #Not supported by Argo + #dns_config: Optional[PodDNSConfig] = None, #Not supported by Argo + #dns_policy: Optional[str] = None, #Not supported by Argo + #host_aliases: Optional[List[HostAlias]] = None, #Not supported by Argo + #host_ipc: Optional[bool] = None, #Not supported by Argo + #host_network: Optional[bool] = None, #Not supported by Argo + #host_pid: Optional[bool] = None, #Not supported by Argo + #hostname: Optional[str] = None, #Not supported by Argo + #image_pull_secrets: Optional[List[LocalObjectReference]] = None, #Not supported by Argo + #init_containers: Optional[List[Container]] = None, #Not supported by Argo + #node_name: Optional[str] = None, #Not supported by Argo + node_selector: Optional[Dict[str, str]] = None, + #priority: Optional[int] = None, #Not supported by Argo + #priority_class_name: Optional[str] = None, #Not supported by Argo + #readiness_gates: Optional[List[PodReadinessGate]] = None, #Not supported by Argo + #restart_policy: Optional[str] = None, #Not supported by Argo + #scheduler_name: Optional[str] = None, #Not supported by Argo + #security_context: Optional[PodSecurityContext] = None, #Not supported by Argo + #service_account: Optional[str] = None, #Not 
supported by Argo + #service_account_name: Optional[str] = None, #Not supported by Argo + #share_process_namespace: Optional[bool] = None, #Not supported by Argo + #subdomain: Optional[str] = None, #Not supported by Argo + #termination_grace_period_seconds: Optional[int] = None, #Not supported by Argo + tolerations: Optional[List[Toleration]] = None, + volumes: Optional[List[Volume]] = None, #Argo only supports volumes at the Workflow level + + #+Argo features: + #+Metadata: ArgoMetadata? (Argo version) + #+RetryStrategy: ArgoRetryStrategy ~= k8s.JobSpec.backoffLimit + #+Parallelism: int + ): + super().__init__(locals()) + + +class ObjectMetaArgoSubset(ModelBase): + def __init__(self, + annotations: Optional[Dict[str, str]] = None, + labels: Optional[Dict[str, str]] = None, + ): + super().__init__(locals()) + + +class PodArgoSubset(ModelBase): + _serialized_names = { + 'api_version': 'apiVersion', + 'kind': 'kind', + 'metadata': 'metadata', + 'spec': 'spec', + 'status': 'status', + } + def __init__(self, + #api_version: Optional[str] = None, + #kind: Optional[str] = None, + #metadata: Optional[ObjectMeta] = None, + metadata: Optional[ObjectMetaArgoSubset] = None, + #spec: Optional[PodSpec] = None, + spec: Optional[PodSpecArgoSubset] = None, + #status: Optional[PodStatus] = None, + ): + super().__init__(locals()) diff --git a/sdk/python/setup.py b/sdk/python/setup.py index a6ae7f9b651..4cabf8574d7 100644 --- a/sdk/python/setup.py +++ b/sdk/python/setup.py @@ -32,6 +32,8 @@ 'kfp', 'kfp.compiler', 'kfp.components', + 'kfp.components.structures', + 'kfp.components.structures.kubernetes', 'kfp.dsl', 'kfp.notebook', 'kfp_experiment', diff --git a/sdk/python/tests/components/test_components.py b/sdk/python/tests/components/test_components.py index 29adc49a155..51967028b0f 100644 --- a/sdk/python/tests/components/test_components.py +++ b/sdk/python/tests/components/test_components.py @@ -17,6 +17,7 @@ import unittest from pathlib import Path +sys.path.insert(0, __file__ + 
'/../../../') import kfp.components as comp from kfp.components._yaml_utils import load_yaml @@ -41,6 +42,7 @@ def test_load_component_from_file(self): assert task1.arguments[0] == str(arg1) assert task1.arguments[1] == str(arg2) + @unittest.skip @unittest.expectedFailure #The repo is non-public and will change soon. TODO: Update the URL and enable the test once we move to a public repo def test_load_component_from_url(self): url = 'https://raw.githubusercontent.com/kubeflow/pipelines/638045974d688b473cda9f4516a2cf1d7d1e02dd/sdk/python/tests/components/test_data/python_add.component.yaml' @@ -73,7 +75,7 @@ def test_loading_minimal_component(self): task1 = task_factory1() assert task1.image == component_dict['implementation']['container']['image'] - @unittest.expectedFailure #TODO: Check this in the ComponentSpec class, not during materialization. + @unittest.expectedFailure def test_fail_on_duplicate_input_names(self): component_text = '''\ inputs: @@ -85,7 +87,6 @@ def test_fail_on_duplicate_input_names(self): ''' task_factory1 = comp.load_component_from_text(component_text) - @unittest.skip #TODO: Fix in the ComponentSpec class @unittest.expectedFailure def test_fail_on_duplicate_output_names(self): component_text = '''\ @@ -176,7 +177,6 @@ def test_handle_duplicate_input_output_names(self): ''' task_factory1 = comp.load_component_from_text(component_text) - @unittest.skip #TODO: FIX: @unittest.expectedFailure def test_fail_on_unknown_value_argument(self): component_text = '''\ @@ -317,6 +317,8 @@ def test_automatic_output_resolving(self): task1 = task_factory1() self.assertEqual(len(task1.arguments), 2) + self.assertEqual(task1.arguments[0], '--output-data') + self.assertTrue(task1.arguments[1].startswith('/')) def test_optional_inputs_reordering(self): '''Tests optional input reordering. 
diff --git a/sdk/python/tests/components/test_graph_components.py b/sdk/python/tests/components/test_graph_components.py new file mode 100644 index 00000000000..848d9dad7b9 --- /dev/null +++ b/sdk/python/tests/components/test_graph_components.py @@ -0,0 +1,176 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import unittest +from pathlib import Path + +sys.path.insert(0, __file__ + '/../../../') + +import kfp.components as comp +from kfp.components._structures import ComponentReference, ComponentSpec, ContainerSpec, GraphInputArgument, GraphSpec, InputSpec, InputValuePlaceholder, GraphImplementation, OutputPathPlaceholder, OutputSpec, TaskOutputArgument, TaskSpec + +from kfp.components._yaml_utils import load_yaml + +class GraphComponentTestCase(unittest.TestCase): + def test_handle_constructing_graph_component(self): + task1 = TaskSpec(component_ref=ComponentReference(name='comp 1'), arguments={'in1 1': 11}) + task2 = TaskSpec(component_ref=ComponentReference(name='comp 2'), arguments={'in2 1': 21, 'in2 2': TaskOutputArgument.construct(task_id='task 1', output_name='out1 1')}) + task3 = TaskSpec(component_ref=ComponentReference(name='comp 3'), arguments={'in3 1': TaskOutputArgument.construct(task_id='task 2', output_name='out2 1'), 'in3 2': GraphInputArgument(input_name='graph in 1')}) + + graph_component1 = ComponentSpec( + inputs=[ + InputSpec(name='graph in 1'), + InputSpec(name='graph in 2'), + ], + 
outputs=[ + OutputSpec(name='graph out 1'), + OutputSpec(name='graph out 2'), + ], + implementation=GraphImplementation(graph=GraphSpec( + tasks={ + 'task 1': task1, + 'task 2': task2, + 'task 3': task3, + }, + output_values={ + 'graph out 1': TaskOutputArgument.construct(task_id='task 3', output_name='out3 1'), + 'graph out 2': TaskOutputArgument.construct(task_id='task 1', output_name='out1 2'), + } + )) + ) + + def test_handle_parsing_graph_component(self): + component_text = '''\ +inputs: +- {name: graph in 1} +- {name: graph in 2} +outputs: +- {name: graph out 1} +- {name: graph out 2} +implementation: + graph: + tasks: + task 1: + componentRef: {name: Comp 1} + arguments: + in1 1: 11 + task 2: + componentRef: {name: Comp 2} + arguments: + in2 1: 21 + in2 2: {taskOutput: {taskId: task 1, outputName: out1 1}} + task 3: + componentRef: {name: Comp 3} + arguments: + in3 1: {taskOutput: {taskId: task 2, outputName: out2 1}} + in3 2: {graphInput: graph in 1} + outputValues: + graph out 1: {taskOutput: {taskId: task 3, outputName: out3 1}} + graph out 2: {taskOutput: {taskId: task 1, outputName: out1 2}} +''' + struct = load_yaml(component_text) + ComponentSpec.from_struct(struct) + + @unittest.expectedFailure + def test_fail_on_cyclic_references(self): + component_text = '''\ +implementation: + graph: + tasks: + task 1: + componentRef: {name: Comp 1} + arguments: + in1 1: {taskOutput: {taskId: task 2, outputName: out2 1}} + task 2: + componentRef: {name: Comp 2} + arguments: + in2 1: {taskOutput: {taskId: task 1, outputName: out1 1}} +''' + struct = load_yaml(component_text) + ComponentSpec.from_struct(struct) + + def test_handle_parsing_predicates(self): + component_text = '''\ +implementation: + graph: + tasks: + task 1: + componentRef: {name: Comp 1} + task 2: + componentRef: {name: Comp 2} + arguments: + in2 1: 21 + in2 2: {taskOutput: {taskId: task 1, outputName: out1 1}} + isEnabled: + not: + and: + op1: + '>': + op1: {taskOutput: {taskId: task 1, outputName: 
out1 1}} + op2: 0 + op2: + '==': + op1: {taskOutput: {taskId: task 1, outputName: out1 2}} + op2: 'head' +''' + struct = load_yaml(component_text) + ComponentSpec.from_struct(struct) + + def test_handle_parsing_task_container_spec_options(self): + component_text = '''\ +implementation: + graph: + tasks: + task 1: + componentRef: {name: Comp 1} + k8sContainerOptions: + resources: + requests: + memory: 1024Mi + cpu: 200m + +''' + struct = load_yaml(component_text) + component_spec = ComponentSpec.from_struct(struct) + self.assertEqual(component_spec.implementation.graph.tasks['task 1'].k8s_container_options.resources.requests['memory'], '1024Mi') + + + def test_handle_parsing_task_volumes_and_mounts(self): + component_text = '''\ +implementation: + graph: + tasks: + task 1: + componentRef: {name: Comp 1} + k8sContainerOptions: + volumeMounts: + - name: workdir + mountPath: /mnt/vol + k8sPodOptions: + spec: + volumes: + - name: workdir + emptyDir: {} +''' + struct = load_yaml(component_text) + component_spec = ComponentSpec.from_struct(struct) + self.assertEqual(component_spec.implementation.graph.tasks['task 1'].k8s_pod_options.spec.volumes[0].name, 'workdir') + self.assertTrue(component_spec.implementation.graph.tasks['task 1'].k8s_pod_options.spec.volumes[0].empty_dir is not None) + +#TODO: Test task name conversion to Argo-compatible names + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/python/tests/components/test_structure_model_base.py b/sdk/python/tests/components/test_structure_model_base.py new file mode 100644 index 00000000000..1ece9dc6d1f --- /dev/null +++ b/sdk/python/tests/components/test_structure_model_base.py @@ -0,0 +1,234 @@ +# Copyright 2018 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import unittest +from pathlib import Path + +from typing import List, Dict, Union, Optional +from kfp.components.modelbase import ModelBase + +class TestModel1(ModelBase): + _serialized_names = { + 'prop_1': 'prop1', + 'prop_2': 'prop 2', + 'prop_3': '@@', + } + + def __init__(self, + prop_0: str, + prop_1: Optional[str] = None, + prop_2: Union[int, str, bool] = '', + prop_3: 'TestModel1' = None, + prop_4: Optional[Dict[str, 'TestModel1']] = None, + prop_5: Optional[Union['TestModel1', List['TestModel1']]] = None, + ): + #print(locals()) + super().__init__(locals()) + + +class StructureModelBaseTestCase(unittest.TestCase): + def test_handle_type_check_for_simple_builtin(self): + self.assertEqual(TestModel1(prop_0='value 0').prop_0, 'value 0') + + with self.assertRaises(TypeError): + TestModel1(prop_0=1) + + with self.assertRaises(TypeError): + TestModel1(prop_0=None) + + with self.assertRaises(TypeError): + TestModel1(prop_0=TestModel1(prop_0='value 0')) + + def test_handle_type_check_for_optional_builtin(self): + self.assertEqual(TestModel1(prop_0='', prop_1='value 1').prop_1, 'value 1') + self.assertEqual(TestModel1(prop_0='', prop_1=None).prop_1, None) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_1=1) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_1=TestModel1(prop_0='', prop_1='value 1')) + + def test_handle_type_check_for_union_builtin(self): + self.assertEqual(TestModel1(prop_0='', prop_2='value 2').prop_2, 'value 2') + self.assertEqual(TestModel1(prop_0='', prop_2=22).prop_2, 22) + 
self.assertEqual(TestModel1(prop_0='', prop_2=True).prop_2, True) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_2=None) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_2=22.22) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_2=TestModel1(prop_0='', prop_2='value 2')) + + def test_handle_type_check_for_class(self): + val3 = TestModel1(prop_0='value 0') + self.assertEqual(TestModel1(prop_0='', prop_3=val3).prop_3, val3) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_3=1) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_3='value 3') + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_3=[val3]) + + def test_handle_type_check_for_dict_class(self): + val4 = TestModel1(prop_0='value 0') + self.assertEqual(TestModel1(prop_0='', prop_4={'key 4': val4}).prop_4['key 4'], val4) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_4=1) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_4='value 4') + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_4=[val4]) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_4={42: val4}) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_4={'key 4': [val4]}) + + def test_handle_type_check_for_union_dict_class(self): + val5 = TestModel1(prop_0='value 0') + self.assertEqual(TestModel1(prop_0='', prop_5=val5).prop_5, val5) + self.assertEqual(TestModel1(prop_0='', prop_5=[val5]).prop_5[0], val5) + self.assertEqual(TestModel1(prop_0='', prop_5=None).prop_5, None) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_5=1) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_5='value 5') + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_5={'key 5': 'value 5'}) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', prop_5={42: val5}) + + with self.assertRaises(TypeError): + TestModel1(prop_0='', 
prop_5={'key 5': [val5]}) + + def test_handle_from_to_struct_for_simple_builtin(self): + struct0 = {'prop_0': 'value 0'} + obj0 = TestModel1.from_struct(struct0) + self.assertEqual(obj0.prop_0, 'value 0') + self.assertDictEqual(obj0.to_struct(), struct0) + + with self.assertRaises(AttributeError): #TypeError: + TestModel1.from_struct(None) + + with self.assertRaises(AttributeError): #TypeError: + TestModel1.from_struct('') + + with self.assertRaises(TypeError): + TestModel1.from_struct({}) + + with self.assertRaises(TypeError): + TestModel1.from_struct({'prop0': 'value 0'}) + + def test_handle_from_to_struct_for_optional_builtin(self): + struct11 = {'prop_0': '', 'prop1': 'value 1'} + obj11 = TestModel1.from_struct(struct11) + self.assertEqual(obj11.prop_1, struct11['prop1']) + self.assertDictEqual(obj11.to_struct(), struct11) + + struct12 = {'prop_0': '', 'prop1': None} + obj12 = TestModel1.from_struct(struct12) + self.assertEqual(obj12.prop_1, None) + self.assertDictEqual(obj12.to_struct(), {'prop_0': ''}) + + with self.assertRaises(TypeError): + TestModel1.from_struct({'prop_0': '', 'prop 1': ''}) + + with self.assertRaises(TypeError): + TestModel1.from_struct({'prop_0': '', 'prop1': 1}) + + def test_handle_from_to_struct_for_union_builtin(self): + struct21 = {'prop_0': '', 'prop 2': 'value 2'} + obj21 = TestModel1.from_struct(struct21) + self.assertEqual(obj21.prop_2, struct21['prop 2']) + self.assertDictEqual(obj21.to_struct(), struct21) + + struct22 = {'prop_0': '', 'prop 2': 22} + obj22 = TestModel1.from_struct(struct22) + self.assertEqual(obj22.prop_2, struct22['prop 2']) + self.assertDictEqual(obj22.to_struct(), struct22) + + struct23 = {'prop_0': '', 'prop 2': True} + obj23 = TestModel1.from_struct(struct23) + self.assertEqual(obj23.prop_2, struct23['prop 2']) + self.assertDictEqual(obj23.to_struct(), struct23) + + with self.assertRaises(TypeError): + TestModel1.from_struct({'prop_0': 'ZZZ', 'prop 2': None}) + + with self.assertRaises(TypeError): + 
TestModel1.from_struct({'prop_0': '', 'prop 2': 22.22}) + + def test_handle_from_to_struct_for_class(self): + val3 = TestModel1(prop_0='value 0') + + struct31 = {'prop_0': '', '@@': val3.to_struct()} #{'prop_0': '', '@@': TestModel1(prop_0='value 0')} is also valid for from_struct, but this cannot happen when parsing for real + obj31 = TestModel1.from_struct(struct31) + self.assertEqual(obj31.prop_3, val3) + self.assertDictEqual(obj31.to_struct(), struct31) + + with self.assertRaises(TypeError): + TestModel1.from_struct({'prop_0': '', '@@': 'value 3'}) + + def test_handle_from_to_struct_for_dict_class(self): + val4 = TestModel1(prop_0='value 0') + + struct41 = {'prop_0': '', 'prop_4': {'val 4': val4.to_struct()}} + obj41 = TestModel1.from_struct(struct41) + self.assertEqual(obj41.prop_4['val 4'], val4) + self.assertDictEqual(obj41.to_struct(), struct41) + + + with self.assertRaises(TypeError): + TestModel1.from_struct({'prop_0': '', 'prop_4': {44: val4.to_struct()}}) + + + def test_handle_from_to_struct_for_union_dict_class(self): + val5 = TestModel1(prop_0='value 0') + + struct51 = {'prop_0': '', 'prop_5': val5.to_struct()} + obj51 = TestModel1.from_struct(struct51) + self.assertEqual(obj51.prop_5, val5) + self.assertDictEqual(obj51.to_struct(), struct51) + + struct52 = {'prop_0': '', 'prop_5': [val5.to_struct()]} + obj52 = TestModel1.from_struct(struct52) + self.assertListEqual(obj52.prop_5, [val5]) + self.assertDictEqual(obj52.to_struct(), struct52) + + with self.assertRaises(TypeError): + TestModel1.from_struct({'prop_0': '', 'prop_5': {44: val5.to_struct()}}) + + with self.assertRaises(TypeError): + TestModel1.from_struct({'prop_0': '', 'prop_5': [val5.to_struct(), None]}) + + +if __name__ == '__main__': + unittest.main()