Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK - Compiler - Fixed handling of PipelineParams in artifact arguments #2042

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion sdk/python/kfp/compiler/_op_to_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _parameters_to_json(params: List[dsl.PipelineParam]):
def _inputs_to_json(
inputs_params: List[dsl.PipelineParam],
input_artifact_paths: Dict[str, str] = None,
artifact_arguments: Dict[str, str] = None,
) -> Dict[str, Dict]:
"""Converts a list of PipelineParam into an argo `inputs` JSON obj."""
parameters = _parameters_to_json(inputs_params)
Expand All @@ -138,6 +139,8 @@ def _inputs_to_json(
artifacts = []
for name, path in (input_artifact_paths or {}).items():
artifact = {'name': name, 'path': path}
if name in artifact_arguments: # The arguments should be compiled as DAG task arguments, not template's default values, but in the current DSL-compiler implementation it's too hard to make that work when passing artifact references.
artifact['raw'] = {'data': str(artifact_arguments[name])}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I'm not familiar with this. Are all artifacts guaranteed to be serializable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea for the next PR.
Right now we only support two types of artifact arguments - constant string and PipelineParam reference to some task's output, which is also serializable.
In future I can put support for serializing complex types here so that it's on par with the component argument serialization support.

artifacts.append(artifact)
artifacts.sort(key=lambda x: x['name']) #Stabilizing the input artifact ordering

Expand Down Expand Up @@ -229,7 +232,8 @@ def _op_to_template(op: BaseOp):

# inputs
input_artifact_paths = processed_op.input_artifact_paths if isinstance(processed_op, dsl.ContainerOp) else None
inputs = _inputs_to_json(processed_op.inputs, input_artifact_paths)
artifact_arguments = processed_op.artifact_arguments if isinstance(processed_op, dsl.ContainerOp) else None
inputs = _inputs_to_json(processed_op.inputs, input_artifact_paths, artifact_arguments)
if inputs:
template['inputs'] = inputs

Expand Down
12 changes: 0 additions & 12 deletions sdk/python/kfp/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,18 +515,6 @@ def _group_to_dag_template(self, group, inputs, outputs, dependencies):
})
arguments.sort(key=lambda x: x['name'])
task['arguments'] = {'parameters': arguments}

if isinstance(sub_group, dsl.ContainerOp) and sub_group.artifact_arguments:
artifact_argument_structs = []
for input_name, argument in sub_group.artifact_arguments.items():
artifact_argument_dict = {'name': input_name}
if isinstance(argument, str):
artifact_argument_dict['raw'] = {'data': str(argument)}
else:
raise TypeError('Argument "{}" was passed to the artifact input "{}", but only constant strings are supported at this moment.'.format(str(argument), input_name))
artifact_argument_structs.append(artifact_argument_dict)
task.setdefault('arguments', {})['artifacts'] = artifact_argument_structs

tasks.append(task)
tasks.sort(key=lambda x: x['name'])
template['dag'] = {'tasks': tasks}
Expand Down
7 changes: 2 additions & 5 deletions sdk/python/kfp/dsl/_container_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,7 @@ def __init__(
"""

super().__init__(name=name, init_containers=init_containers, sidecars=sidecars, is_exit_handler=is_exit_handler)
self.attrs_with_pipelineparams = BaseOp.attrs_with_pipelineparams + ['_container', 'artifact_location'] #Copying the BaseOp class variable!
self.attrs_with_pipelineparams = BaseOp.attrs_with_pipelineparams + ['_container', 'artifact_location', 'artifact_arguments'] #Copying the BaseOp class variable!

input_artifact_paths = {}
artifact_arguments = {}
Expand All @@ -1025,10 +1025,7 @@ def resolve_artifact_argument(artarg):
input_name = getattr(artarg.input, 'name', artarg.input) or ('input-' + str(len(artifact_arguments)))
input_path = artarg.path or _generate_input_file_name(input_name)
input_artifact_paths[input_name] = input_path
if not isinstance(artarg.argument, str):
raise TypeError('Argument "{}" was passed to the artifact input "{}", but only constant strings are supported at this moment.'.format(str(artarg.argument), input_name))

artifact_arguments[input_name] = artarg.argument
artifact_arguments[input_name] = str(artarg.argument)
return input_path

for artarg in artifact_argument_paths or []:
Expand Down
36 changes: 12 additions & 24 deletions sdk/python/tests/compiler/testdata/input_artifact_raw_value.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ spec:
artifacts:
- name: text
path: /tmp/inputs/text/data
raw:
data: Constant artifact value
name: component-with-inline-input-artifact
outputs:
artifacts:
Expand All @@ -37,6 +39,8 @@ spec:
artifacts:
- name: text
path: /tmp/inputs/text/data
raw:
data: Constant artifact value
name: component-with-input-artifact
outputs:
artifacts:
Expand All @@ -55,6 +59,8 @@ spec:
artifacts:
- name: text
path: /tmp/inputs/text/data
raw:
data: hard-coded artifact value
name: component-with-input-artifact-2
outputs:
artifacts:
Expand All @@ -73,6 +79,8 @@ spec:
artifacts:
- name: text
path: /tmp/inputs/text/data
raw:
data: Text from a file with hard-coded artifact value
name: component-with-input-artifact-3
outputs:
artifacts:
Expand All @@ -84,32 +92,12 @@ spec:
path: /mlpipeline-metrics.json
- dag:
tasks:
- arguments:
artifacts:
- name: text
raw:
data: Constant artifact value
name: component-with-inline-input-artifact
- name: component-with-inline-input-artifact
template: component-with-inline-input-artifact
- arguments:
artifacts:
- name: text
raw:
data: Constant artifact value
name: component-with-input-artifact
- name: component-with-input-artifact
template: component-with-input-artifact
- arguments:
artifacts:
- name: text
raw:
data: hard-coded artifact value
name: component-with-input-artifact-2
- name: component-with-input-artifact-2
template: component-with-input-artifact-2
- arguments:
artifacts:
- name: text
raw:
data: Text from a file with hard-coded artifact value
name: component-with-input-artifact-3
- name: component-with-input-artifact-3
template: component-with-input-artifact-3
name: pipeline-with-artifact-input-raw-argument-value