diff --git a/sdk/python/kfp/dsl/__init__.py b/sdk/python/kfp/dsl/__init__.py
index 1bf2721e28f..06fae50887b 100644
--- a/sdk/python/kfp/dsl/__init__.py
+++ b/sdk/python/kfp/dsl/__init__.py
@@ -17,4 +17,5 @@
 from ._pipeline import Pipeline, pipeline, get_pipeline_conf
 from ._container_op import ContainerOp
 from ._ops_group import OpsGroup, ExitHandler, Condition
-from ._python_component import python_component
\ No newline at end of file
+from ._component import python_component
+#TODO: expose the component decorator when ready
diff --git a/sdk/python/kfp/dsl/_python_component.py b/sdk/python/kfp/dsl/_component.py
similarity index 50%
rename from sdk/python/kfp/dsl/_python_component.py
rename to sdk/python/kfp/dsl/_component.py
index bb14f37feae..001ed0a6d48 100644
--- a/sdk/python/kfp/dsl/_python_component.py
+++ b/sdk/python/kfp/dsl/_component.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from ._metadata import ComponentMeta, ParameterMeta, TypeMeta, _annotation_to_typemeta
+
 def python_component(name, description=None, base_image=None, target_component_file: str = None):
   """Decorator for Python component functions.
   This decorator adds the metadata to the function object itself.
@@ -47,3 +49,60 @@ def _python_component(func):
     return func
 
   return _python_component
+
+def component(func):
+  """Decorator for component functions that use ContainerOp.
+  This is useful to enable type checking in the DSL compiler
+
+  Args:
+    func: function that builds and returns the ContainerOp.
+  Returns:
+    A wrapper with the same call signature that returns whatever func returns.
+
+  Usage:
+  ```python
+  @dsl.component
+  def foobar(model: TFModel(), step: MLStep()):
+    return dsl.ContainerOp()
+  ```
+  """
+  import functools
+
+  # Keep func's __name__ and docstring visible on the decorated component.
+  @functools.wraps(func)
+  def _component(*args, **kargs):
+    import inspect
+    fullargspec = inspect.getfullargspec(func)
+    annotations = fullargspec.annotations
+
+    # defaults: pair the trailing positional args with their default values
+    arg_defaults = {}
+    if fullargspec.defaults:
+      for arg, default in zip(reversed(fullargspec.args), reversed(fullargspec.defaults)):
+        arg_defaults[arg] = default
+
+    # Construct the ComponentMeta
+    component_meta = ComponentMeta(name=func.__name__, description='')
+    # Inputs
+    for arg in fullargspec.args:
+      arg_type = TypeMeta()
+      arg_default = arg_defaults.get(arg, '')
+      if arg in annotations:
+        arg_type = _annotation_to_typemeta(annotations[arg])
+      component_meta.inputs.append(ParameterMeta(name=arg, description='', param_type=arg_type, default=arg_default))
+    # Outputs: a missing return annotation means no declared outputs (was a KeyError).
+    for output, annotation in annotations.get('return', {}).items():
+      arg_type = _annotation_to_typemeta(annotation)
+      component_meta.outputs.append(ParameterMeta(name=output, description='', param_type=arg_type))
+
+    #TODO: add descriptions to the metadata
+    #docstring parser:
+    #    https://github.com/rr-/docstring_parser
+    #    https://github.com/terrencepreilly/darglint/blob/master/darglint/parse.py
+
+    print(component_meta.serialize())
+    #TODO: parse the metadata to the ContainerOp.
+    container_op = func(*args, **kargs)
+    return container_op
+
+  return _component
diff --git a/sdk/python/kfp/dsl/_metadata.py b/sdk/python/kfp/dsl/_metadata.py
index c47dda86f3c..411c26cfe39 100644
--- a/sdk/python/kfp/dsl/_metadata.py
+++ b/sdk/python/kfp/dsl/_metadata.py
@@ -14,7 +14,7 @@
 
 from typing import Dict, List
 from abc import ABCMeta, abstractmethod
-from ._types import _check_valid_type_dict
+from ._types import BaseType, _check_valid_type_dict, _str_to_dict, _instance_to_dict
 
 class BaseMeta(object):
   __metaclass__ = ABCMeta
@@ -104,4 +104,25 @@ def to_dict(self):
     return {'name': self.name,
             'description': self.description,
             'inputs': [ input.to_dict() for input in self.inputs ]
-            }
\ No newline at end of file
+            }
+
+def _annotation_to_typemeta(annotation):
+  '''_annotation_to_typemeta converts an annotation to an instance of TypeMeta
+  Args:
+    annotation(BaseType/str/dict): input/output annotations
+  Returns:
+    TypeMeta
+  Raises:
+    ValueError: if a dict annotation is not a valid type dictionary.
+  '''
+  if isinstance(annotation, BaseType):
+    arg_type = TypeMeta.from_dict(_instance_to_dict(annotation))
+  elif isinstance(annotation, str):
+    arg_type = TypeMeta.from_dict(_str_to_dict(annotation))
+  elif isinstance(annotation, dict):
+    if not _check_valid_type_dict(annotation):
+      raise ValueError('Annotation ' + str(annotation) + ' is not a valid type dictionary.')
+    arg_type = TypeMeta.from_dict(annotation)
+  else:
+    return TypeMeta()
+  return arg_type
diff --git a/sdk/python/kfp/dsl/_pipeline.py b/sdk/python/kfp/dsl/_pipeline.py
index e3ad49a13f6..d3e757ec704 100644
--- a/sdk/python/kfp/dsl/_pipeline.py
+++ b/sdk/python/kfp/dsl/_pipeline.py
@@ -14,6 +14,7 @@
 
 
 from . import _container_op
+from ._metadata import PipelineMeta, ParameterMeta, TypeMeta, _annotation_to_typemeta
 from . import _ops_group
 import sys
 
@@ -32,6 +33,26 @@ def my_pipeline(a: PipelineParam, b: PipelineParam):
   ```
   """
   def _pipeline(func):
+    import inspect
+    fullargspec = inspect.getfullargspec(func)
+    args = fullargspec.args
+    annotations = fullargspec.annotations
+
+    # Construct the PipelineMeta
+    pipeline_meta = PipelineMeta(name=func.__name__, description='')
+    # Inputs
+    for arg in args:
+      arg_type = TypeMeta()
+      if arg in annotations:
+        arg_type = _annotation_to_typemeta(annotations[arg])
+      pipeline_meta.inputs.append(ParameterMeta(name=arg, description='', param_type=arg_type))
+
+    #TODO: add descriptions to the metadata
+    #docstring parser:
+    #    https://github.com/rr-/docstring_parser
+    #    https://github.com/terrencepreilly/darglint/blob/master/darglint/parse.py
+    #TODO: parse the metadata to the Pipeline.
+
     Pipeline.add_pipeline(name, description, func)
     return func
 
diff --git a/sdk/python/tests/dsl/component_tests.py b/sdk/python/tests/dsl/component_tests.py
new file mode 100644
index 00000000000..47516b1a20e
--- /dev/null
+++ b/sdk/python/tests/dsl/component_tests.py
@@ -0,0 +1,28 @@
+# Copyright 2018 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from kfp.dsl._component import component
+from kfp.dsl._types import GCSPath, Integer
+import unittest
+
+@component
+def componentA(a: {'Schema': {'file_type': 'csv'}}, b: '{"number": {"step": "large"}}' = 12, c: GCSPath(path_type='file', file_type='tsv') = 'gs://hello/world') -> {'model': Integer()}:
+  return 7
+
+class TestPythonComponent(unittest.TestCase):
+
+  def test_component(self):
+    """Test component decorator."""
+    self.assertEqual(7, componentA(1, 2, 3))
diff --git a/sdk/python/tests/dsl/main.py b/sdk/python/tests/dsl/main.py
index c04ba186004..e994f21d83e 100644
--- a/sdk/python/tests/dsl/main.py
+++ b/sdk/python/tests/dsl/main.py
@@ -21,6 +21,7 @@
 import container_op_tests
 import ops_group_tests
 import type_tests
+import component_tests
 import metadata_tests
 
 if __name__ == '__main__':
@@ -30,6 +31,7 @@
   suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(container_op_tests))
   suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(ops_group_tests))
   suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(type_tests))
+  suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(component_tests))
   suite.addTests(unittest.defaultTestLoader.loadTestsFromModule(metadata_tests))
   runner = unittest.TextTestRunner()
   if not runner.run(suite).wasSuccessful():
diff --git a/sdk/python/tests/dsl/pipeline_tests.py b/sdk/python/tests/dsl/pipeline_tests.py
index 7c57e4c01b4..5570ccb6f94 100644
--- a/sdk/python/tests/dsl/pipeline_tests.py
+++ b/sdk/python/tests/dsl/pipeline_tests.py
@@ -14,6 +14,7 @@
 
 
 from kfp.dsl import Pipeline, PipelineParam, ContainerOp, pipeline
+from kfp.dsl._types import GCSPath, Integer
 import unittest
 
 
@@ -55,3 +56,14 @@ def my_pipeline2():
 
     self.assertEqual(('p1', 'description1'), Pipeline.get_pipeline_functions()[my_pipeline1])
     self.assertEqual(('p2', 'description2'), Pipeline.get_pipeline_functions()[my_pipeline2])
+
+  def test_decorator_metadata(self):
+    """Test @pipeline decorator with metadata."""
+    @pipeline(
+      name='p1',
+      description='description1'
+    )
+    def my_pipeline1(a: {'Schema': {'file_type': 'csv'}}='good', b: Integer()=12):
+      pass
+
+    self.assertEqual(('p1', 'description1'), Pipeline.get_pipeline_functions()[my_pipeline1])