Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK - Got rid of the global variable collecting all created pipelines #1167

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions sdk/python/kfp/compiler/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def parse_arguments():
return args


def _compile_pipeline_function(function_name, output_path, type_check):

pipeline_funcs = dsl.Pipeline.get_pipeline_functions()
def _compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check):
if len(pipeline_funcs) == 0:
raise ValueError('A function with @dsl.pipeline decorator is required in the py file.')

Expand All @@ -72,13 +70,27 @@ def _compile_pipeline_function(function_name, output_path, type_check):
kfp.compiler.Compiler().compile(pipeline_func, output_path, type_check)


class PipelineCollectorContext():
def __enter__(self):
pipeline_funcs = []
def add_pipeline(func):
pipeline_funcs.append(func)
return func
dsl._pipeline._pipeline_decorator_handlers.append(add_pipeline)
return pipeline_funcs

def __exit__(self, *args):
dsl._pipeline._pipeline_decorator_handlers.pop()


def compile_package(package_path, namespace, function_name, output_path, type_check):
tmpdir = tempfile.mkdtemp()
sys.path.insert(0, tmpdir)
try:
subprocess.check_call(['python3', '-m', 'pip', 'install', package_path, '-t', tmpdir])
__import__(namespace)
_compile_pipeline_function(function_name, output_path, type_check)
with PipelineCollectorContext() as pipeline_funcs:
__import__(namespace)
_compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check)
finally:
del sys.path[0]
shutil.rmtree(tmpdir)
Expand All @@ -88,8 +100,9 @@ def compile_pyfile(pyfile, function_name, output_path, type_check):
sys.path.insert(0, os.path.dirname(pyfile))
try:
filename = os.path.basename(pyfile)
__import__(os.path.splitext(filename)[0])
_compile_pipeline_function(function_name, output_path, type_check)
with PipelineCollectorContext() as pipeline_funcs:
__import__(os.path.splitext(filename)[0])
_compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check)
finally:
del sys.path[0]

Expand Down
27 changes: 10 additions & 17 deletions sdk/python/kfp/dsl/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import sys


# Contains a stack of handler functions.
# The last handler in the list is called whenever the @pipeline decorator is used
_pipeline_decorator_handlers = []


def pipeline(name, description):
"""Decorator of pipeline functions.

Expand All @@ -35,8 +40,11 @@ def my_pipeline(a: PipelineParam, b: PipelineParam):
def _pipeline(func):
func._pipeline_name = name
func._pipeline_description = description
Pipeline._add_pipeline_to_global_list(func)
return func

if _pipeline_decorator_handlers:
return _pipeline_decorator_handlers[-1](func) or func
else:
return func

return _pipeline

Expand Down Expand Up @@ -82,31 +90,16 @@ class Pipeline():
# _default_pipeline is set when it (usually a compiler) runs "with Pipeline()"
_default_pipeline = None

# All pipeline functions with @pipeline decorator that are imported.
# Each key is a pipeline function. Each value is a (name, description).
_pipeline_functions = []

@staticmethod
def get_default_pipeline():
"""Get default pipeline. """
return Pipeline._default_pipeline

@staticmethod
def get_pipeline_functions():
"""Get all imported pipeline functions (decorated with @pipeline)."""
return Pipeline._pipeline_functions

@staticmethod
def _add_pipeline_to_global_list(func):
"""Add a pipeline function (decorated with @pipeline)."""
Pipeline._pipeline_functions.append(func)

@staticmethod
def add_pipeline(name, description, func):
"""Add a pipeline function with the specified name and description."""
# Applying the @pipeline decorator to the pipeline function
func = pipeline(name=name, description=description)(func)
Pipeline._add_pipeline_to_global_list(pipeline_meta, func)

def __init__(self, name: str):
"""Create a new instance of Pipeline.
Expand Down