Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK - Components - Added output references to TaskSpec #1991

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion sdk/python/kfp/components/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def _create_task_factory_from_component_spec(component_spec:ComponentSpec, compo

if component_ref is None:
component_ref = ComponentReference(name=component_spec.name or component_filename or _default_component_name)
component_ref._component_spec = component_spec
component_ref.spec = component_spec

def create_task_from_component_and_arguments(pythonic_arguments):
arguments = {}
Expand Down Expand Up @@ -238,6 +238,8 @@ def create_task_from_component_and_arguments(pythonic_arguments):
component_ref=component_ref,
arguments=arguments,
)
task._init_outputs()

if _created_task_transformation_handler:
task = _created_task_transformation_handler[-1](task)
return task
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/kfp/components/_dsl_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

def create_container_op_from_task(task_spec: TaskSpec):
argument_values = task_spec.arguments
component_spec = task_spec.component_ref._component_spec
component_spec = task_spec.component_ref.spec

if not isinstance(component_spec.implementation, ContainerImplementation):
raise TypeError('Only container component tasks can be converted to ContainerOp')
Expand Down
23 changes: 21 additions & 2 deletions sdk/python/kfp/components/_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,12 +315,13 @@ def __init__(self,
digest: Optional[str] = None,
tag: Optional[str] = None,
url: Optional[str] = None,
spec: Optional[ComponentSpec] = None,
):
super().__init__(locals())
self._post_init()

def _post_init(self) -> None:
if not any([self.name, self.digest, self.tag, self.url]):
if not any([self.name, self.digest, self.tag, self.url, self.spec]):
raise TypeError('Need at least one argument.')


Expand All @@ -344,10 +345,13 @@ class TaskOutputReference(ModelBase):
}

def __init__(self,
task_id: str,
output_name: str,
task_id: Optional[str] = None, # Used for linking to the upstream task in serialized component file.
task: Optional['TaskSpec'] = None, # Used for linking to the upstream task in runtime since Task does not have an ID until inserted into a graph.
):
super().__init__(locals())
if self.task_id is None and self.task is None:
raise TypeError('task_id and task cannot be None at the same time.')


class TaskOutputArgument(ModelBase): #Has additional constructor for convenience
Expand Down Expand Up @@ -483,6 +487,21 @@ def __init__(self,
super().__init__(locals())
#TODO: If component_ref is resolved to component spec, then check that the arguments correspond to the inputs

def _init_outputs(self):
#Adding output references to the task
if self.component_ref.spec is None:
return
task_outputs = OrderedDict()
for output in self.component_ref.spec.outputs or []:
task_output_ref = TaskOutputReference(
output_name=output.name,
task=self,
)
task_output_arg = TaskOutputArgument(task_output=task_output_ref)
task_outputs[output.name] = task_output_arg

self.outputs = task_outputs


class GraphSpec(ModelBase):
'''Describes the graph component implementation. It represents a graph of component tasks connected to the upstream sources of data using the argument specifications. It also describes the sources of graph output values.'''
Expand Down
27 changes: 27 additions & 0 deletions sdk/python/tests/components/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import sys
import unittest
from contextlib import contextmanager
from pathlib import Path


Expand All @@ -23,6 +24,16 @@
from kfp.components._yaml_utils import load_yaml
from kfp.dsl.types import InconsistentTypeException


@contextmanager
def no_task_resolving_context():
old_handler = kfp.components._components._created_task_transformation_handler
try:
kfp.components._components._created_task_transformation_handler = None
yield None
finally:
kfp.components._components._created_task_transformation_handler = old_handler

class LoadComponentTestCase(unittest.TestCase):
def _test_load_component_from_file(self, component_path: str):
task_factory1 = comp.load_component_from_file(component_path)
Expand Down Expand Up @@ -561,6 +572,22 @@ def test_passing_component_metadata_to_container_op(self):
self.assertEqual(task1.pod_annotations['key1'], 'value1')
self.assertEqual(task1.pod_labels['key1'], 'value1')

def test_check_task_spec_outputs_dictionary(self):
component_text = '''\
outputs:
- {name: out 1}
- {name: out 2}
implementation:
container:
image: busybox
command: [touch, {outputPath: out 1}, {outputPath: out 2}]
'''
op = comp.load_component_from_text(component_text)
with no_task_resolving_context():
task = op()

self.assertEqual(list(task.outputs.keys()), ['out 1', 'out 2'])

def test_type_compatibility_check_for_simple_types(self):
component_a = '''\
outputs:
Expand Down