From 6ae9aeec3f866ecfce83945ad1e831be9f8e5ebb Mon Sep 17 00:00:00 2001 From: Deji Ibrahim <31637316+dejii@users.noreply.github.com> Date: Wed, 26 May 2021 15:52:19 +0100 Subject: [PATCH] pass wait_for_done parameter down to _DataflowJobsController (#15541) --- airflow/providers/google/cloud/hooks/dataflow.py | 3 +++ tests/providers/google/cloud/hooks/test_dataflow.py | 7 +++++++ 2 files changed, 10 insertions(+) diff --git a/airflow/providers/google/cloud/hooks/dataflow.py b/airflow/providers/google/cloud/hooks/dataflow.py index ebdbfa33e30db..e416393d1867e 100644 --- a/airflow/providers/google/cloud/hooks/dataflow.py +++ b/airflow/providers/google/cloud/hooks/dataflow.py @@ -728,6 +728,7 @@ def start_template_dataflow( num_retries=self.num_retries, drain_pipeline=self.drain_pipeline, cancel_timeout=self.cancel_timeout, + wait_until_finished=self.wait_until_finished, ) jobs_controller.wait_for_done() return response["job"] @@ -774,6 +775,7 @@ def start_flex_template( poll_sleep=self.poll_sleep, num_retries=self.num_retries, cancel_timeout=self.cancel_timeout, + wait_until_finished=self.wait_until_finished, ) jobs_controller.wait_for_done() @@ -1030,6 +1032,7 @@ def start_sql_job( poll_sleep=self.poll_sleep, num_retries=self.num_retries, drain_pipeline=self.drain_pipeline, + wait_until_finished=self.wait_until_finished, ) jobs_controller.wait_for_done() diff --git a/tests/providers/google/cloud/hooks/test_dataflow.py b/tests/providers/google/cloud/hooks/test_dataflow.py index 0314e5cbe3504..ee75eed9e3956 100644 --- a/tests/providers/google/cloud/hooks/test_dataflow.py +++ b/tests/providers/google/cloud/hooks/test_dataflow.py @@ -835,6 +835,7 @@ def test_start_template_dataflow(self, mock_conn, mock_controller, mock_uuid): location=DEFAULT_DATAFLOW_LOCATION, drain_pipeline=False, cancel_timeout=DEFAULT_CANCEL_TIMEOUT, + wait_until_finished=None, ) mock_controller.return_value.wait_for_done.assert_called_once() @@ -873,6 +874,7 @@ def test_start_template_dataflow_with_custom_region_as_variable( location=TEST_LOCATION, drain_pipeline=False, cancel_timeout=DEFAULT_CANCEL_TIMEOUT, + wait_until_finished=None, ) mock_controller.return_value.wait_for_done.assert_called_once() @@ -913,6 +915,7 @@ def test_start_template_dataflow_with_custom_region_as_parameter( location=TEST_LOCATION, drain_pipeline=False, cancel_timeout=DEFAULT_CANCEL_TIMEOUT, + wait_until_finished=None, ) mock_controller.return_value.wait_for_done.assert_called_once() @@ -957,6 +960,7 @@ def test_start_template_dataflow_with_runtime_env(self, mock_conn, mock_dataflow project_number=TEST_PROJECT, drain_pipeline=False, cancel_timeout=DEFAULT_CANCEL_TIMEOUT, + wait_until_finished=None, ) mock_uuid.assert_called_once_with() @@ -1005,6 +1009,7 @@ def test_start_template_dataflow_update_runtime_env(self, mock_conn, mock_datafl project_number=TEST_PROJECT, drain_pipeline=False, cancel_timeout=DEFAULT_CANCEL_TIMEOUT, + wait_until_finished=None, ) mock_uuid.assert_called_once_with() @@ -1037,6 +1042,7 @@ def test_start_flex_template(self, mock_conn, mock_controller): poll_sleep=self.dataflow_hook.poll_sleep, num_retries=self.dataflow_hook.num_retries, cancel_timeout=DEFAULT_CANCEL_TIMEOUT, + wait_until_finished=self.dataflow_hook.wait_until_finished, ) mock_controller.return_value.get_jobs.wait_for_done.assrt_called_once_with() mock_controller.return_value.get_jobs.assrt_called_once_with() @@ -1110,6 +1116,7 @@ def test_start_sql_job_failed_to_run( project_number=TEST_PROJECT, num_retries=5, drain_pipeline=False, + wait_until_finished=None, ) mock_controller.return_value.wait_for_done.assert_called_once() assert result == test_job