Skip to content

Commit

Permalink
Fix example DAG for MLEngine in backport package (#7813)
Browse files Browse the repository at this point in the history
  • Loading branch information
mik-laj authored Mar 25, 2020
1 parent a15026f commit 6df6b8c
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions backport_packages/setup_backport_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def run(self):
def change_import_paths_to_deprecated():
from bowler import LN, TOKEN, Capture, Filename, Query
from fissix.pytree import Leaf
from fissix.fixer_util import KeywordArg, Name, Comma
from bowler import BowlerTool

def remove_tags_modifier(node: LN, capture: Capture, filename: Filename) -> None:
for node in capture['function_arguments'][0].post_order():
Expand All @@ -161,6 +163,14 @@ def remove_super_init_call(node: LN, capture: Capture, filename: Filename) -> No
if any(c.value for c in ch.parent.post_order() if isinstance(c, Leaf)):
ch.parent.remove()

def add_provide_context_to_python_operator(node: LN, capture: Capture, filename: Filename) -> None:
fn_args = capture['function_arguments'][0]
fn_args.append_child(Comma())

provide_context_arg = KeywordArg(Name('provide_context'), Name('True'))
provide_context_arg.prefix = fn_args.children[0].prefix
fn_args.append_child(provide_context_arg)

changes = [
("airflow.operators.bash", "airflow.operators.bash_operator"),
("airflow.operators.python", "airflow.operators.python_operator"),
Expand All @@ -174,9 +184,9 @@ def remove_super_init_call(node: LN, capture: Capture, filename: Filename) -> No
# Move and refactor imports for Dataflow
copyfile(
os.path.join(dirname(__file__), os.pardir, "airflow", "utils", "python_virtualenv.py"),
os.path.join(dirname(__file__), "airflow", "providers",
"google", "cloud", "utils", "python_virtualenv.py"
)
os.path.join(
dirname(__file__), "airflow", "providers", "google", "cloud", "utils", "python_virtualenv.py"
)
)
(
qry
Expand All @@ -185,9 +195,9 @@ def remove_super_init_call(node: LN, capture: Capture, filename: Filename) -> No
)
copyfile(
os.path.join(dirname(__file__), os.pardir, "airflow", "utils", "process_utils.py"),
os.path.join(dirname(__file__), "airflow", "providers",
"google", "cloud", "utils", "process_utils.py"
)
os.path.join(
dirname(__file__), "airflow", "providers", "google", "cloud", "utils", "process_utils.py"
)
)
(
qry
Expand All @@ -212,6 +222,13 @@ def remove_super_init_call(node: LN, capture: Capture, filename: Filename) -> No
# Fix super().__init__() call in hooks
qry.select_subclass("BaseHook").modify(remove_super_init_call)

(
qry.select_function("PythonOperator")
.is_call()
.is_filename(include=r"mlengine_operator_utils.py$")
.modify(add_provide_context_to_python_operator)
)

qry.execute(write=True, silent=False, interactive=False)


Expand Down

0 comments on commit 6df6b8c

Please sign in to comment.