From 6df6b8c9cce14edd449b7446994f0127fb73e155 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20Bregu=C5=82a?= Date: Wed, 25 Mar 2020 10:15:51 +0100 Subject: [PATCH] Fix example DAG for MLEngine in backport package (#7813) --- backport_packages/setup_backport_packages.py | 29 ++++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/backport_packages/setup_backport_packages.py b/backport_packages/setup_backport_packages.py index 55222da1e97c1..a0cc4662a125a 100644 --- a/backport_packages/setup_backport_packages.py +++ b/backport_packages/setup_backport_packages.py @@ -143,6 +143,8 @@ def run(self): def change_import_paths_to_deprecated(): from bowler import LN, TOKEN, Capture, Filename, Query from fissix.pytree import Leaf + from fissix.fixer_util import KeywordArg, Name, Comma + from bowler import BowlerTool def remove_tags_modifier(node: LN, capture: Capture, filename: Filename) -> None: for node in capture['function_arguments'][0].post_order(): @@ -161,6 +163,14 @@ def remove_super_init_call(node: LN, capture: Capture, filename: Filename) -> No if any(c.value for c in ch.parent.post_order() if isinstance(c, Leaf)): ch.parent.remove() + def add_provide_context_to_python_operator(node: LN, capture: Capture, filename: Filename) -> None: + fn_args = capture['function_arguments'][0] + fn_args.append_child(Comma()) + + provide_context_arg = KeywordArg(Name('provide_context'), Name('True')) + provide_context_arg.prefix = fn_args.children[0].prefix + fn_args.append_child(provide_context_arg) + changes = [ ("airflow.operators.bash", "airflow.operators.bash_operator"), ("airflow.operators.python", "airflow.operators.python_operator"), @@ -174,9 +184,9 @@ def remove_super_init_call(node: LN, capture: Capture, filename: Filename) -> No # Move and refactor imports for Dataflow copyfile( os.path.join(dirname(__file__), os.pardir, "airflow", "utils", "python_virtualenv.py"), - os.path.join(dirname(__file__), "airflow", "providers", - "google", "cloud", "utils", "python_virtualenv.py" - ) + os.path.join( + dirname(__file__), "airflow", "providers", "google", "cloud", "utils", "python_virtualenv.py" + ) ) ( qry @@ -185,9 +195,9 @@ def remove_super_init_call(node: LN, capture: Capture, filename: Filename) -> No ) copyfile( os.path.join(dirname(__file__), os.pardir, "airflow", "utils", "process_utils.py"), - os.path.join(dirname(__file__), "airflow", "providers", - "google", "cloud", "utils", "process_utils.py" - ) + os.path.join( + dirname(__file__), "airflow", "providers", "google", "cloud", "utils", "process_utils.py" + ) ) ( qry @@ -212,6 +222,13 @@ def remove_super_init_call(node: LN, capture: Capture, filename: Filename) -> No # Fix super().__init__() call in hooks qry.select_subclass("BaseHook").modify(remove_super_init_call) + ( + qry.select_function("PythonOperator") + .is_call() + .is_filename(include=r"mlengine_operator_utils.py$") + .modify(add_provide_context_to_python_operator) + ) + qry.execute(write=True, silent=False, interactive=False)