Skip to content

Commit

Permalink
SDK - Got rid of the global variable collecting all created pipelines (
Browse files Browse the repository at this point in the history
…#1167)

* SDK - Got rid of the global variable collecting all created pipelines
This list was only used by the command-line compiler.
The command-line compiler can still collect the created pipelines by registering a handler function in `_pipeline_decorator_handlers`.

* Replaced handler stack with a single handler.
  • Loading branch information
Ark-kun authored and k8s-ci-robot committed Apr 19, 2019
1 parent 0b40672 commit ee119ec
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 24 deletions.
28 changes: 21 additions & 7 deletions sdk/python/kfp/compiler/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def parse_arguments():
return args


def _compile_pipeline_function(function_name, output_path, type_check):

pipeline_funcs = dsl.Pipeline.get_pipeline_functions()
def _compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check):
if len(pipeline_funcs) == 0:
raise ValueError('A function with @dsl.pipeline decorator is required in the py file.')

Expand All @@ -72,13 +70,28 @@ def _compile_pipeline_function(function_name, output_path, type_check):
kfp.compiler.Compiler().compile(pipeline_func, output_path, type_check)


class PipelineCollectorContext():
def __enter__(self):
pipeline_funcs = []
def add_pipeline(func):
pipeline_funcs.append(func)
return func
self.old_handler = dsl._pipeline._pipeline_decorator_handler
dsl._pipeline._pipeline_decorator_handler = add_pipeline
return pipeline_funcs

def __exit__(self, *args):
dsl._pipeline._pipeline_decorator_handler = self.old_handler


def compile_package(package_path, namespace, function_name, output_path, type_check):
tmpdir = tempfile.mkdtemp()
sys.path.insert(0, tmpdir)
try:
subprocess.check_call(['python3', '-m', 'pip', 'install', package_path, '-t', tmpdir])
__import__(namespace)
_compile_pipeline_function(function_name, output_path, type_check)
with PipelineCollectorContext() as pipeline_funcs:
__import__(namespace)
_compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check)
finally:
del sys.path[0]
shutil.rmtree(tmpdir)
Expand All @@ -88,8 +101,9 @@ def compile_pyfile(pyfile, function_name, output_path, type_check):
sys.path.insert(0, os.path.dirname(pyfile))
try:
filename = os.path.basename(pyfile)
__import__(os.path.splitext(filename)[0])
_compile_pipeline_function(function_name, output_path, type_check)
with PipelineCollectorContext() as pipeline_funcs:
__import__(os.path.splitext(filename)[0])
_compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check)
finally:
del sys.path[0]

Expand Down
27 changes: 10 additions & 17 deletions sdk/python/kfp/dsl/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import sys


# This handler is called whenever the @pipeline decorator is applied.
# It can be used by command-line DSL compiler to inject code that runs for every pipeline definition.
_pipeline_decorator_handler = None


def pipeline(name, description):
"""Decorator of pipeline functions.
Expand All @@ -35,8 +40,11 @@ def my_pipeline(a: PipelineParam, b: PipelineParam):
def _pipeline(func):
func._pipeline_name = name
func._pipeline_description = description
Pipeline._add_pipeline_to_global_list(func)
return func

if _pipeline_decorator_handler:
return _pipeline_decorator_handler(func) or func
else:
return func

return _pipeline

Expand Down Expand Up @@ -82,31 +90,16 @@ class Pipeline():
# _default_pipeline is set when it (usually a compiler) runs "with Pipeline()"
_default_pipeline = None

# All pipeline functions with @pipeline decorator that are imported.
# Each key is a pipeline function. Each value is a (name, description).
_pipeline_functions = []

@staticmethod
def get_default_pipeline():
"""Get default pipeline. """
return Pipeline._default_pipeline

@staticmethod
def get_pipeline_functions():
"""Get all imported pipeline functions (decorated with @pipeline)."""
return Pipeline._pipeline_functions

@staticmethod
def _add_pipeline_to_global_list(func):
"""Add a pipeline function (decorated with @pipeline)."""
Pipeline._pipeline_functions.append(func)

@staticmethod
def add_pipeline(name, description, func):
"""Add a pipeline function with the specified name and description."""
# Applying the @pipeline decorator to the pipeline function
func = pipeline(name=name, description=description)(func)
Pipeline._add_pipeline_to_global_list(pipeline_meta, func)

def __init__(self, name: str):
"""Create a new instance of Pipeline.
Expand Down

0 comments on commit ee119ec

Please sign in to comment.