Skip to content

Commit

Permalink
SDK - Components - Added type to graph input references (#2451)
Browse files Browse the repository at this point in the history
This makes the graph input references consistent with task output references.
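This is a breaking change (a serialized `graphInput` argument is now a mapping with an `inputName` key instead of a plain input-name string), but the graph components are not exposed in the documentation or samples yet.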
  • Loading branch information
Ark-kun authored and k8s-ci-robot committed Oct 24, 2019
1 parent 3386a4b commit 681d873
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 26 deletions.
2 changes: 1 addition & 1 deletion sdk/python/kfp/components/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def resolve_argument(argument):
if isinstance(argument, (str, int, float, bool)):
return argument
elif isinstance(argument, GraphInputArgument):
return graph_input_arguments[argument.input_name]
return graph_input_arguments[argument.graph_input.input_name]
elif isinstance(argument, TaskOutputArgument):
upstream_task_output_ref = argument.task_output
upstream_task_outputs = outputs_of_tasks[upstream_task_output_ref.task_id]
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/kfp/components/_python_to_graph_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from typing import Callable

from . import _components
from ._structures import TaskSpec, ComponentSpec, OutputSpec, GraphInputArgument, TaskOutputArgument, GraphImplementation, GraphSpec
from ._structures import TaskSpec, ComponentSpec, OutputSpec, GraphInputReference, TaskOutputArgument, GraphImplementation, GraphSpec
from ._naming import _make_name_unique_by_adding_index
from ._python_op import _extract_component_interface

Expand Down Expand Up @@ -90,7 +90,7 @@ def task_construction_handler(task: TaskSpec):

# Preparing the pipeline_func arguments
# TODO: The key should be original parameter name if different
pipeline_func_args = {input.name: GraphInputArgument(input_name=input.name) for input in input_specs}
pipeline_func_args = {input.name: GraphInputReference(input_name=input.name).as_argument() for input in input_specs}

try:
#Setting the handler to fix and catch the tasks.
Expand Down
31 changes: 28 additions & 3 deletions sdk/python/kfp/components/_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

'ComponentReference',

'GraphInputReference',
'GraphInputArgument',
'TaskOutputReference',
'TaskOutputArgument',
Expand Down Expand Up @@ -306,7 +307,7 @@ def verify_arg(arg):
for task in graph.tasks.values():
if task.arguments is not None:
for argument in task.arguments.values():
if isinstance(argument, GraphInputArgument) and argument.input_name not in self._inputs_dict:
if isinstance(argument, GraphInputArgument) and argument.graph_input.input_name not in self._inputs_dict:
raise TypeError('Argument "{}" references non-existing input.'.format(argument))

def save(self, file_path: str):
Expand Down Expand Up @@ -334,14 +335,38 @@ def _post_init(self) -> None:
raise TypeError('Need at least one argument.')


class GraphInputReference(ModelBase):
    '''References the input of the graph (the scope is a single graph).

    A reference only names a graph input (and optionally carries a type);
    use `as_argument()` to wrap it into a `GraphInputArgument` that can be
    passed as a task argument.
    '''
    # Maps the Python attribute name to the key used in the serialized form,
    # e.g. `graphInput: {inputName: some_input}`.
    _serialized_names = {
        'input_name': 'inputName',
    }

    def __init__(self,
        input_name: str,
        type: Optional[TypeSpecType] = None, # Can be used to override the reference data type
    ):
        '''Creates a reference to the graph input called `input_name`.

        Args:
            input_name: Name of the graph input being referenced.
            type: Optional type spec that overrides the reference data type.
        '''
        # locals() is passed straight to the base class, so no extra local
        # variables may be introduced before this call.
        # NOTE(review): presumably ModelBase stores the constructor arguments
        # as attributes (with_type reads self.input_name) - confirm.
        super().__init__(locals())

    def as_argument(self) -> 'GraphInputArgument':
        '''Returns a `GraphInputArgument` wrapping this reference.'''
        return GraphInputArgument(graph_input=self)

    def with_type(self, type_spec: Optional[TypeSpecType]) -> 'GraphInputReference':
        '''Returns a new reference to the same input with the given type.

        The current object is not modified. Passing `None` produces an
        untyped reference (this is what `without_type()` does).
        '''
        return GraphInputReference(
            input_name=self.input_name,
            type=type_spec,
        )

    def without_type(self) -> 'GraphInputReference':
        '''Returns a new reference to the same input with no type.'''
        return self.with_type(None)

class GraphInputArgument(ModelBase):
'''Represents the component argument value that comes from the graph component input.'''
_serialized_names = {
'input_name': 'graphInput',
'graph_input': 'graphInput',
}

def __init__(self,
input_name: str,
graph_input: GraphInputReference,
):
super().__init__(locals())

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,14 @@ implementation:
url: https://raw.githubusercontent.com/kubeflow/pipelines/b3179d86b239a08bf4884b50dbf3a9151da96d66/components/gcp/automl/create_dataset_for_tables/component.yaml
arguments:
gcp_project_id:
graphInput: gcp_project_id
graphInput:
inputName: gcp_project_id
gcp_region:
graphInput: gcp_region
graphInput:
inputName: gcp_region
display_name:
graphInput: dataset_display_name
graphInput:
inputName: dataset_display_name
Automl import data from bigquery:
componentRef:
url: https://raw.githubusercontent.com/kubeflow/pipelines/b3179d86b239a08bf4884b50dbf3a9151da96d66/components/gcp/automl/import_data_from_bigquery/component.yaml
Expand All @@ -56,7 +59,8 @@ implementation:
taskId: Automl create dataset for tables
type: String
input_uri:
graphInput: dataset_bq_input_uri
graphInput:
inputName: dataset_bq_input_uri
Automl split dataset table column names:
componentRef:
url: https://raw.githubusercontent.com/kubeflow/pipelines/b3179d86b239a08bf4884b50dbf3a9151da96d66/components/gcp/automl/split_dataset_table_column_names/component.yaml
Expand All @@ -67,18 +71,22 @@ implementation:
taskId: Automl import data from bigquery
type: String
target_column_name:
graphInput: target_column_name
graphInput:
inputName: target_column_name
table_index: '0'
Automl create model for tables:
componentRef:
url: https://raw.githubusercontent.com/kubeflow/pipelines/b3179d86b239a08bf4884b50dbf3a9151da96d66/components/gcp/automl/create_model_for_tables/component.yaml
arguments:
gcp_project_id:
graphInput: gcp_project_id
graphInput:
inputName: gcp_project_id
gcp_region:
graphInput: gcp_region
graphInput:
inputName: gcp_region
display_name:
graphInput: model_display_name
graphInput:
inputName: model_display_name
dataset_id:
taskOutput:
outputName: dataset_path
Expand All @@ -96,7 +104,8 @@ implementation:
type: JsonArray
optimization_objective: MAXIMIZE_AU_PRC
train_budget_milli_node_hours:
graphInput: train_budget_milli_node_hours
graphInput:
inputName: train_budget_milli_node_hours
Automl prediction service batch predict:
componentRef:
url: https://raw.githubusercontent.com/kubeflow/pipelines/b3179d86b239a08bf4884b50dbf3a9151da96d66/components/gcp/automl/prediction_service_batch_predict/component.yaml
Expand All @@ -107,9 +116,11 @@ implementation:
taskId: Automl create model for tables
type: String
gcs_output_uri_prefix:
graphInput: batch_predict_gcs_output_uri_prefix
graphInput:
inputName: batch_predict_gcs_output_uri_prefix
bq_input_uri:
graphInput: batch_predict_bq_input_uri
graphInput:
inputName: batch_predict_bq_input_uri
outputValues:
model_path:
taskOutput:
Expand Down
18 changes: 9 additions & 9 deletions sdk/python/tests/components/test_graph_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@


import kfp.components as comp
from kfp.components._structures import ComponentReference, ComponentSpec, ContainerSpec, GraphInputArgument, GraphSpec, InputSpec, InputValuePlaceholder, GraphImplementation, OutputPathPlaceholder, OutputSpec, TaskOutputArgument, TaskSpec
from kfp.components._structures import ComponentReference, ComponentSpec, ContainerSpec, GraphInputReference, GraphSpec, InputSpec, InputValuePlaceholder, GraphImplementation, OutputPathPlaceholder, OutputSpec, TaskOutputArgument, TaskSpec

from kfp.components._yaml_utils import load_yaml

class GraphComponentTestCase(unittest.TestCase):
def test_handle_constructing_graph_component(self):
task1 = TaskSpec(component_ref=ComponentReference(name='comp 1'), arguments={'in1 1': 11})
task2 = TaskSpec(component_ref=ComponentReference(name='comp 2'), arguments={'in2 1': 21, 'in2 2': TaskOutputArgument.construct(task_id='task 1', output_name='out1 1')})
task3 = TaskSpec(component_ref=ComponentReference(name='comp 3'), arguments={'in3 1': TaskOutputArgument.construct(task_id='task 2', output_name='out2 1'), 'in3 2': GraphInputArgument(input_name='graph in 1')})
task3 = TaskSpec(component_ref=ComponentReference(name='comp 3'), arguments={'in3 1': TaskOutputArgument.construct(task_id='task 2', output_name='out2 1'), 'in3 2': GraphInputReference(input_name='graph in 1').as_argument()})

graph_component1 = ComponentSpec(
inputs=[
Expand Down Expand Up @@ -75,7 +75,7 @@ def test_handle_parsing_graph_component(self):
componentRef: {name: Comp 3}
arguments:
in3 1: {taskOutput: {taskId: task 2, outputName: out2 1}}
in3 2: {graphInput: graph in 1}
in3 2: {graphInput: {inputName: graph in 1}}
outputValues:
graph out 1: {taskOutput: {taskId: task 3, outputName: out3 1}}
graph out 2: {taskOutput: {taskId: task 1, outputName: out1 2}}
Expand Down Expand Up @@ -231,11 +231,11 @@ def test_load_graph_component(self):
command: [sh, -c, 'cat "$0" "$1" > $2', {inputValue: in3_1}, {inputValue: in3_2}, {outputPath: out3_1}]
arguments:
in3_1: {taskOutput: {taskId: task 2, outputName: out2_1}}
in3_2: {graphInput: graph in 1}
in3_2: {graphInput: {inputName: graph in 1}}
outputValues:
graph out 1: {taskOutput: {taskId: task 3, outputName: out3_1}}
graph out 2: {taskOutput: {taskId: task 1, outputName: out1_2}}
graph out 3: {graphInput: graph in 2}
graph out 3: {graphInput: {inputName: graph in 2}}
graph out 4: '42'
'''
op = comp.load_component_from_text(component_text)
Expand Down Expand Up @@ -311,17 +311,17 @@ def test_load_nested_graph_components(self):
image: busybox
command: [sh, -c, 'cat "$0" "$1" > $2', {inputValue: in3_1}, {inputValue: in3_2}, {outputPath: out3_1}]
arguments:
in3_1: {graphInput: in3_1}
in3_2: {graphInput: in3_1}
in3_1: {graphInput: {inputName: in3_1}}
in3_2: {graphInput: {inputName: in3_1}}
outputValues:
out3_1: {taskOutput: {taskId: graph subtask, outputName: out3_1}}
arguments:
in3_1: {taskOutput: {taskId: task 2, outputName: out2_1}}
in3_2: {graphInput: graph in 1}
in3_2: {graphInput: {inputName: graph in 1}}
outputValues:
graph out 1: {taskOutput: {taskId: task 3, outputName: out3_1}}
graph out 2: {taskOutput: {taskId: task 1, outputName: out1_2}}
graph out 3: {graphInput: graph in 2}
graph out 3: {graphInput: {inputName: graph in 2}}
graph out 4: '42'
'''
op = comp.load_component_from_text(component_text)
Expand Down

0 comments on commit 681d873

Please sign in to comment.