diff --git a/sdk/python/kfp/compiler/_component_builder.py b/sdk/python/kfp/compiler/_component_builder.py index ece5dd6362a..bc0b081615c 100644 --- a/sdk/python/kfp/compiler/_component_builder.py +++ b/sdk/python/kfp/compiler/_component_builder.py @@ -21,6 +21,7 @@ import shutil from collections import OrderedDict from pathlib import Path +from typing import Callable from ..components._components import _create_task_factory_from_component_spec from ._container_builder import ContainerBuilder @@ -407,7 +408,10 @@ def build_python_component(component_func, target_image, base_image=None, depend if base_image is None: base_image = getattr(component_func, '_component_base_image', None) if base_image is None: - raise ValueError('base_image must not be None') + from ..components._python_op import get_default_base_image + base_image = get_default_base_image() + if isinstance(base_image, Callable): + base_image = base_image() logging.info('Build an image that is based on ' + base_image + diff --git a/sdk/python/kfp/components/_python_op.py b/sdk/python/kfp/components/_python_op.py index 2771501d1bf..755a3414151 100644 --- a/sdk/python/kfp/components/_python_op.py +++ b/sdk/python/kfp/components/_python_op.py @@ -15,6 +15,8 @@ __all__ = [ 'func_to_container_op', 'func_to_component_text', + 'get_default_base_image', + 'set_default_base_image', ] from ._yaml_utils import dump_yaml @@ -24,7 +26,7 @@ import inspect from pathlib import Path import typing -from typing import TypeVar, Generic, List +from typing import Callable, Generic, List, TypeVar, Union T = TypeVar('T') @@ -42,6 +44,17 @@ class OutputFile(Generic[T], str): _default_base_image='tensorflow/tensorflow:1.13.2-py3' +def get_default_base_image() -> Union[str, Callable[[], str]]: + return _default_base_image + + +def set_default_base_image(image_or_factory: Union[str, Callable[[], str]]): + '''set_default_base_image sets the name of the container image that will be used for component creation when base_image is not specified. + Alternatively, the base image can also be set to a factory function that will be returning the image. + ''' + _default_base_image = image_or_factory + + def _python_function_name_to_component_name(name): import re return re.sub(' +', ' ', name.replace('_', ' ')).strip(' ').capitalize() @@ -207,7 +220,7 @@ def annotation_to_type_struct(annotation): return component_spec -def _func_to_component_spec(func, extra_code='', base_image=_default_base_image, modules_to_capture: List[str] = None, use_code_pickling=False) -> ComponentSpec: +def _func_to_component_spec(func, extra_code='', base_image : str = None, modules_to_capture: List[str] = None, use_code_pickling=False) -> ComponentSpec: '''Takes a self-contained python function and converts it to component Args: @@ -220,13 +233,15 @@ def _func_to_component_spec(func, extra_code='', base_image=_default_base_image, ''' decorator_base_image = getattr(func, '_component_base_image', None) if decorator_base_image is not None: - if base_image is not _default_base_image and decorator_base_image != base_image: + if base_image is not None and decorator_base_image != base_image: raise ValueError('base_image ({}) conflicts with the decorator-specified base image metadata ({})'.format(base_image, decorator_base_image)) else: base_image = decorator_base_image else: if base_image is None: - raise ValueError('base_image cannot be None') + base_image = _default_base_image + if isinstance(base_image, Callable): + base_image = base_image() component_spec = _extract_component_interface(func)