Skip to content

Commit

Permalink
SDK - Got rid of the global variable collecting all created pipelines
Browse files Browse the repository at this point in the history
This list was only used by the command-line compiler.
The command-line compiler can still collect the created pipelines by registering a handler function in `_pipeline_decorator_handlers`.
  • Loading branch information
Ark-kun committed Apr 15, 2019
1 parent 71325c3 commit 5313bfd
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 24 deletions.
26 changes: 19 additions & 7 deletions sdk/python/kfp/compiler/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def parse_arguments():
return args


def _compile_pipeline_function(function_name, output_path, type_check):

pipeline_funcs = dsl.Pipeline.get_pipeline_functions()
def _compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check):
if len(pipeline_funcs) == 0:
raise ValueError('A function with @dsl.pipeline decorator is required in the py file.')

Expand All @@ -72,13 +70,26 @@ def _compile_pipeline_function(function_name, output_path, type_check):
kfp.compiler.Compiler().compile(pipeline_func, output_path, type_check)


class PipelineCollectorContext():
def __enter__(self):
pipeline_funcs = []
def add_pipeline(func):
pipeline_funcs.append(func)
dsl._pipeline._pipeline_decorator_handlers.append(add_pipeline)
return pipeline_funcs

def __exit__(self, *args):
dsl._pipeline._pipeline_decorator_handlers.pop()


def compile_package(package_path, namespace, function_name, output_path, type_check):
tmpdir = tempfile.mkdtemp()
sys.path.insert(0, tmpdir)
try:
subprocess.check_call(['python3', '-m', 'pip', 'install', package_path, '-t', tmpdir])
__import__(namespace)
_compile_pipeline_function(function_name, output_path, type_check)
with PipelineCollectorContext() as pipeline_funcs:
__import__(namespace)
_compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check)
finally:
del sys.path[0]
shutil.rmtree(tmpdir)
Expand All @@ -88,8 +99,9 @@ def compile_pyfile(pyfile, function_name, output_path, type_check):
sys.path.insert(0, os.path.dirname(pyfile))
try:
filename = os.path.basename(pyfile)
__import__(os.path.splitext(filename)[0])
_compile_pipeline_function(function_name, output_path, type_check)
with PipelineCollectorContext() as pipeline_funcs:
__import__(os.path.splitext(filename)[0])
_compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check)
finally:
del sys.path[0]

Expand Down
27 changes: 10 additions & 17 deletions sdk/python/kfp/dsl/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import sys


# Contains a stack of handler functions.
# The last handler in the list is called whenever the @pipeline decorator is used
_pipeline_decorator_handlers = []


def pipeline(name, description):
"""Decorator of pipeline functions.
Expand All @@ -35,8 +40,11 @@ def my_pipeline(a: PipelineParam, b: PipelineParam):
def _pipeline(func):
func._pipeline_name = name
func._pipeline_description = description
Pipeline._add_pipeline_to_global_list(func)
return func

if _pipeline_decorator_handlers:
return _pipeline_decorator_handlers[-1](func)
else:
return func

return _pipeline

Expand Down Expand Up @@ -82,31 +90,16 @@ class Pipeline():
# _default_pipeline is set when it (usually a compiler) runs "with Pipeline()"
_default_pipeline = None

# All pipeline functions with @pipeline decorator that are imported.
# Each key is a pipeline function. Each value is a (name, description).
_pipeline_functions = []

@staticmethod
def get_default_pipeline():
"""Get default pipeline. """
return Pipeline._default_pipeline

@staticmethod
def get_pipeline_functions():
"""Get all imported pipeline functions (decorated with @pipeline)."""
return Pipeline._pipeline_functions

@staticmethod
def _add_pipeline_to_global_list(func):
"""Add a pipeline function (decorated with @pipeline)."""
Pipeline._pipeline_functions.append(func)

@staticmethod
def add_pipeline(name, description, func):
"""Add a pipeline function with the specified name and description."""
# Applying the @pipeline decorator to the pipeline function
func = pipeline(name=name, description=description)(func)
Pipeline._add_pipeline_to_global_list(pipeline_meta, func)

def __init__(self, name: str):
"""Create a new instance of Pipeline.
Expand Down

0 comments on commit 5313bfd

Please sign in to comment.