diff --git a/sdk/python/kfp/components/_structures.py b/sdk/python/kfp/components/_structures.py index c3315d12820..f1874a6e170 100644 --- a/sdk/python/kfp/components/_structures.py +++ b/sdk/python/kfp/components/_structures.py @@ -489,21 +489,57 @@ def __init__(self, super().__init__(locals()) +class RetryStrategySpec(ModelBase): + _serialized_names = { + 'max_retries': 'maxRetries', + } + + def __init__(self, + max_retries: int, + ): + super().__init__(locals()) + + +class KubernetesExecutionOptionsSpec(ModelBase): + _serialized_names = { + 'main_container': 'mainContainer', + 'pod_spec': 'podSpec', + } + + def __init__(self, + metadata: Optional[v1.ObjectMetaArgoSubset] = None, + main_container: Optional[v1.Container] = None, + pod_spec: Optional[v1.PodSpecArgoSubset] = None, + ): + super().__init__(locals()) + + +class ExecutionOptionsSpec(ModelBase): + _serialized_names = { + 'retry_strategy': 'retryStrategy', + 'kubernetes_options': 'kubernetesOptions', + } + + def __init__(self, + retry_strategy: Optional[RetryStrategySpec] = None, + kubernetes_options: Optional[KubernetesExecutionOptionsSpec] = None, + ): + super().__init__(locals()) + + class TaskSpec(ModelBase): '''Task specification. Task is a "configured" component - a component supplied with arguments and other applied configuration changes.''' _serialized_names = { 'component_ref': 'componentRef', 'is_enabled': 'isEnabled', - 'k8s_container_options': 'k8sContainerOptions', - 'k8s_pod_options': 'k8sPodOptions', + 'execution_options': 'executionOptions' } def __init__(self, component_ref: ComponentReference, arguments: Optional[Mapping[str, ArgumentType]] = None, is_enabled: Optional[PredicateType] = None, - k8s_container_options: Optional[v1.Container] = None, - k8s_pod_options: Optional[v1.PodArgoSubset] = None, + execution_options: Optional[ExecutionOptionsSpec] = None, ): super().__init__(locals()) #TODO: If component_ref is resolved to component spec, then check that the arguments correspond to the inputs diff --git a/sdk/python/tests/components/test_graph_components.py b/sdk/python/tests/components/test_graph_components.py index c2505d5b785..99a65bb2e19 100644 --- a/sdk/python/tests/components/test_graph_components.py +++ b/sdk/python/tests/components/test_graph_components.py @@ -135,16 +135,18 @@ def test_handle_parsing_task_container_spec_options(self): tasks: task 1: componentRef: {name: Comp 1} - k8sContainerOptions: - resources: - requests: - memory: 1024Mi - cpu: 200m + executionOptions: + kubernetesOptions: + mainContainer: + resources: + requests: + memory: 1024Mi + cpu: 200m ''' struct = load_yaml(component_text) component_spec = ComponentSpec.from_dict(struct) - self.assertEqual(component_spec.implementation.graph.tasks['task 1'].k8s_container_options.resources.requests['memory'], '1024Mi') + self.assertEqual(component_spec.implementation.graph.tasks['task 1'].execution_options.kubernetes_options.main_container.resources.requests['memory'], '1024Mi') def test_handle_parsing_task_volumes_and_mounts(self): @@ -154,20 +156,21 @@ def test_handle_parsing_task_volumes_and_mounts(self): tasks: task 1: componentRef: {name: Comp 1} - k8sContainerOptions: - volumeMounts: - - name: workdir - mountPath: /mnt/vol - k8sPodOptions: - spec: - volumes: - - name: workdir - emptyDir: {} + executionOptions: + kubernetesOptions: + mainContainer: + volumeMounts: + - name: workdir + mountPath: /mnt/vol + podSpec: + volumes: + - name: workdir + emptyDir: {} ''' struct = load_yaml(component_text) component_spec = ComponentSpec.from_dict(struct) - self.assertEqual(component_spec.implementation.graph.tasks['task 1'].k8s_pod_options.spec.volumes[0].name, 'workdir') - self.assertTrue(component_spec.implementation.graph.tasks['task 1'].k8s_pod_options.spec.volumes[0].empty_dir is not None) + self.assertEqual(component_spec.implementation.graph.tasks['task 1'].execution_options.kubernetes_options.pod_spec.volumes[0].name, 'workdir') + self.assertIsNotNone(component_spec.implementation.graph.tasks['task 1'].execution_options.kubernetes_options.pod_spec.volumes[0].empty_dir) def test_load_graph_component(self): component_text = '''\