Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SDK - Got rid of the global variable collecting all created pipelines #1167

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions sdk/python/kfp/compiler/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,7 @@ def parse_arguments():
return args


def _compile_pipeline_function(function_name, output_path, type_check):

pipeline_funcs = dsl.Pipeline.get_pipeline_functions()
def _compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check):
if len(pipeline_funcs) == 0:
raise ValueError('A function with @dsl.pipeline decorator is required in the py file.')

Expand All @@ -72,13 +70,28 @@ def _compile_pipeline_function(function_name, output_path, type_check):
kfp.compiler.Compiler().compile(pipeline_func, output_path, type_check)


class PipelineCollectorContext():
def __enter__(self):
pipeline_funcs = []
def add_pipeline(func):
pipeline_funcs.append(func)
return func
self.old_handler = dsl._pipeline._pipeline_decorator_handler
dsl._pipeline._pipeline_decorator_handler = add_pipeline
return pipeline_funcs

def __exit__(self, *args):
dsl._pipeline._pipeline_decorator_handler = self.old_handler


def compile_package(package_path, namespace, function_name, output_path, type_check):
tmpdir = tempfile.mkdtemp()
sys.path.insert(0, tmpdir)
try:
subprocess.check_call(['python3', '-m', 'pip', 'install', package_path, '-t', tmpdir])
__import__(namespace)
_compile_pipeline_function(function_name, output_path, type_check)
with PipelineCollectorContext() as pipeline_funcs:
__import__(namespace)
_compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check)
finally:
del sys.path[0]
shutil.rmtree(tmpdir)
Expand All @@ -88,8 +101,9 @@ def compile_pyfile(pyfile, function_name, output_path, type_check):
sys.path.insert(0, os.path.dirname(pyfile))
try:
filename = os.path.basename(pyfile)
__import__(os.path.splitext(filename)[0])
_compile_pipeline_function(function_name, output_path, type_check)
with PipelineCollectorContext() as pipeline_funcs:
__import__(os.path.splitext(filename)[0])
_compile_pipeline_function(pipeline_funcs, function_name, output_path, type_check)
finally:
del sys.path[0]

Expand Down
27 changes: 10 additions & 17 deletions sdk/python/kfp/dsl/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
import sys


# This handler is called whenever the @pipeline decorator is applied.
# It can be used by command-line DSL compiler to inject code that runs for every pipeline definition.
_pipeline_decorator_handler = None


def pipeline(name, description):
"""Decorator of pipeline functions.

Expand All @@ -35,8 +40,11 @@ def my_pipeline(a: PipelineParam, b: PipelineParam):
def _pipeline(func):
func._pipeline_name = name
func._pipeline_description = description
Pipeline._add_pipeline_to_global_list(func)
return func

if _pipeline_decorator_handler:
return _pipeline_decorator_handler(func) or func
else:
return func

return _pipeline

Expand Down Expand Up @@ -82,31 +90,16 @@ class Pipeline():
# _default_pipeline is set when it (usually a compiler) runs "with Pipeline()"
_default_pipeline = None

# All pipeline functions with @pipeline decorator that are imported.
# Each key is a pipeline function. Each value is a (name, description).
_pipeline_functions = []

@staticmethod
def get_default_pipeline():
"""Get default pipeline. """
return Pipeline._default_pipeline

@staticmethod
def get_pipeline_functions():
"""Get all imported pipeline functions (decorated with @pipeline)."""
return Pipeline._pipeline_functions

@staticmethod
def _add_pipeline_to_global_list(func):
"""Add a pipeline function (decorated with @pipeline)."""
Pipeline._pipeline_functions.append(func)

@staticmethod
def add_pipeline(name, description, func):
"""Add a pipeline function with the specified name and description."""
# Applying the @pipeline decorator to the pipeline function
func = pipeline(name=name, description=description)(func)
Pipeline._add_pipeline_to_global_list(pipeline_meta, func)

def __init__(self, name: str):
"""Create a new instance of Pipeline.
Expand Down