From 1f9c6b87197e21921bdbfa8a5e85cff1c85981d0 Mon Sep 17 00:00:00 2001
From: subham611
Date: Wed, 3 Apr 2024 16:13:06 +0530
Subject: [PATCH 1/5] Adds cancel previous run parameter

---
 airflow/providers/databricks/operators/databricks.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 0eedc444fb97b..457c37107bd8b 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -651,6 +651,7 @@ class DatabricksRunNowOperator(BaseOperator):
         - ``spark_submit_params``
         - ``idempotency_token``
         - ``repair_run``
+        - ``cancel_previous_runs``
 
     :param job_id: the job_id of the existing Databricks job.
         This field will be templated.
@@ -740,6 +741,7 @@ class DatabricksRunNowOperator(BaseOperator):
     :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
     :param deferrable: Run operator in the deferrable mode.
     :param repair_run: Repair the databricks run in case of failure.
+    :param cancel_previous_runs: Cancel all existing runs of the job before submitting a new one.
     """
 
     # Used in airflow.models.BaseOperator
@@ -771,6 +773,7 @@ def __init__(
         wait_for_termination: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         repair_run: bool = False,
+        cancel_previous_runs: bool = False,
         **kwargs,
     ) -> None:
         """Create a new ``DatabricksRunNowOperator``."""
@@ -784,6 +787,7 @@ def __init__(
         self.wait_for_termination = wait_for_termination
         self.deferrable = deferrable
         self.repair_run = repair_run
+        self.cancel_previous_runs = cancel_previous_runs
 
         if job_id is not None:
             self.json["job_id"] = job_id
@@ -830,6 +834,8 @@ def execute(self, context: Context):
                 raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found")
             self.json["job_id"] = job_id
             del self.json["job_name"]
+        if self.cancel_previous_runs:
+            hook.cancel_all_runs(job_id)
 
         self.run_id = hook.run_now(self.json)
         if self.deferrable:

From 705f3be156a08900dc3dabebc67f411f2f12033a Mon Sep 17 00:00:00 2001
From: subham611
Date: Wed, 3 Apr 2024 20:50:42 +0530
Subject: [PATCH 2/5] Adds unit tests

---
 .../databricks/operators/test_databricks.py | 64 +++++++++++++++++++
 1 file changed, 64 insertions(+)

diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 6797377161962..0720af8293f64 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -1397,6 +1397,70 @@ def test_exec_failure_if_job_id_not_found(self, db_mock_class):
 
         db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)
 
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_cancel_previous_runs(self, db_mock_class):
+        run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=True, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = 1
+
+        assert op.cancel_previous_runs
+
+        op.execute(None)
+
+        expected = utils.normalise_json_content(
+            {
+                "notebook_params": NOTEBOOK_PARAMS,
+                "notebook_task": NOTEBOOK_TASK,
+                "jar_params": JAR_PARAMS,
+                "job_id": JOB_ID,
+            }
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksRunNowOperator",
+        )
+
+        db_mock.cancel_all_runs.assert_called_once_with(JOB_ID)
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_not_called()
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_no_cancel_previous_runs(self, db_mock_class):
+        run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=False, json=run)
+        db_mock = db_mock_class.return_value
+        db_mock.run_now.return_value = 1
+
+        assert not op.cancel_previous_runs
+
+        op.execute(None)
+
+        expected = utils.normalise_json_content(
+            {
+                "notebook_params": NOTEBOOK_PARAMS,
+                "notebook_task": NOTEBOOK_TASK,
+                "jar_params": JAR_PARAMS,
+                "job_id": JOB_ID,
+            }
+        )
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksRunNowOperator",
+        )
+
+        db_mock.cancel_all_runs.assert_not_called()
+        db_mock.run_now.assert_called_once_with(expected)
+        db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
+        db_mock.get_run.assert_not_called()
+
 
 class TestDatabricksRunNowDeferrableOperator:
     @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")

From 79ef3c170a05ec1c1d63e86d449ee7eb5e7be026 Mon Sep 17 00:00:00 2001
From: subham611
Date: Thu, 4 Apr 2024 08:04:35 +0530
Subject: [PATCH 3/5] Refactor UT

---
 .../operators/run_now.rst | 1 +
 .../databricks/operators/test_databricks.py | 46 +++++++++----------
 2 files changed, 24 insertions(+), 23 deletions(-)

diff --git a/docs/apache-airflow-providers-databricks/operators/run_now.rst b/docs/apache-airflow-providers-databricks/operators/run_now.rst
index a4b00d9005c81..a9da7512e26e8 100644
--- a/docs/apache-airflow-providers-databricks/operators/run_now.rst
+++ b/docs/apache-airflow-providers-databricks/operators/run_now.rst
@@ -45,6 +45,7 @@ All other parameters are optional and described in documentation for ``Databrick
 * ``python_named_parameters``
 * ``jar_params``
 * ``spark_submit_params``
+* ``cancel_previous_runs``
 
 DatabricksRunNowDeferrableOperator
 ==================================
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 0720af8293f64..8cec35638da1d 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -737,7 +737,7 @@ def test_exec_success(self, db_mock_class):
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
+        db_mock.submit_run.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         op.execute(None)
@@ -767,7 +767,7 @@ def test_exec_pipeline_name(self, db_mock_class):
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
         db_mock.find_pipeline_id_by_name.return_value = PIPELINE_ID_TASK["pipeline_id"]
-        db_mock.submit_run.return_value = 1
+        db_mock.submit_run.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         op.execute(None)
@@ -798,7 +798,7 @@ def test_exec_failure(self, db_mock_class):
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
+        db_mock.submit_run.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         with pytest.raises(AirflowException):
@@ -845,7 +845,7 @@ def test_wait_for_termination(self, db_mock_class):
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
+        db_mock.submit_run.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         assert op.wait_for_termination
@@ -875,7 +875,7 @@ def test_no_wait_for_termination(self, db_mock_class):
         }
         op = DatabricksSubmitRunOperator(task_id=TASK_ID, wait_for_termination=False, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
+        db_mock.submit_run.return_value = RUN_ID
 
         assert not op.wait_for_termination
 
@@ -909,7 +909,7 @@ def test_execute_task_deferred(self, db_mock_class):
         }
         op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
+        db_mock.submit_run.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")
 
         with pytest.raises(TaskDeferred) as exc:
@@ -971,7 +971,7 @@ def test_execute_complete_failure(self, db_mock_class):
             op.execute_complete(context=None, event=event)
 
         db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
+        db_mock.submit_run.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         with pytest.raises(AirflowException, match=f"Job run failed with terminal state: {run_state_failed}"):
@@ -993,7 +993,7 @@ def test_databricks_submit_run_deferrable_operator_failed_before_defer(self, moc
         }
         op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
+        db_mock.submit_run.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         op.execute(None)
@@ -1023,7 +1023,7 @@ def test_databricks_submit_run_deferrable_operator_success_before_defer(self, mo
         }
         op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
+        db_mock.submit_run.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         op.execute(None)
@@ -1147,7 +1147,7 @@ def test_exec_success(self, db_mock_class):
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         op.execute(None)
@@ -1181,7 +1181,7 @@ def test_exec_failure(self, db_mock_class):
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         with pytest.raises(AirflowException):
@@ -1215,7 +1215,7 @@ def test_exec_failure_with_message(self, db_mock_class):
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
         db_mock.get_run = mock_dict(
             {
                 "job_id": JOB_ID,
@@ -1279,7 +1279,7 @@ def test_wait_for_termination(self, db_mock_class):
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         assert op.wait_for_termination
@@ -1311,7 +1311,7 @@ def test_no_wait_for_termination(self, db_mock_class):
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, wait_for_termination=False, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
 
         assert not op.wait_for_termination
 
@@ -1357,7 +1357,7 @@ def test_exec_with_job_name(self, db_mock_class):
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_name=JOB_NAME, json=run)
         db_mock = db_mock_class.return_value
         db_mock.find_job_id_by_name.return_value = JOB_ID
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         op.execute(None)
@@ -1402,7 +1402,7 @@ def test_cancel_previous_runs(self, db_mock_class):
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=True, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
 
         assert op.cancel_previous_runs
 
@@ -1434,7 +1434,7 @@ def test_no_cancel_previous_runs(self, db_mock_class):
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
         op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=False, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
 
         assert not op.cancel_previous_runs
 
@@ -1471,7 +1471,7 @@ def test_execute_task_deferred(self, db_mock_class):
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
         op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")
 
         with pytest.raises(TaskDeferred) as exc:
@@ -1534,7 +1534,7 @@ def test_execute_complete_failure(self, db_mock_class):
             op.execute_complete(context=None, event=event)
 
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         with pytest.raises(AirflowException, match=f"Job run failed with terminal state: {run_state_failed}"):
@@ -1563,7 +1563,7 @@ def test_execute_complete_failure_and_repair_run(
             op.execute_complete(context=None, event=event)
 
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
         db_mock.get_latest_repair_id.assert_called_once()
         db_mock.repair_run.assert_called_once()
@@ -1585,7 +1585,7 @@ def test_operator_failed_before_defer(self, mock_defer, db_mock_class):
         }
         op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.submit_run.return_value = 1
+        db_mock.submit_run.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         op.execute(None)
@@ -1612,7 +1612,7 @@ def test_databricks_run_now_deferrable_operator_failed_before_defer(self, mock_d
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
         op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
 
         op.execute(None)
@@ -1645,7 +1645,7 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
         op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
         db_mock = db_mock_class.return_value
-        db_mock.run_now.return_value = 1
+        db_mock.run_now.return_value = RUN_ID
         db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
 
         op.execute(None)

From 381fe56791182b41068cda10876310ea412180f2 Mon Sep 17 00:00:00 2001
From: subham611
Date: Thu, 4 Apr 2024 08:49:11 +0530
Subject: [PATCH 4/5] Fix unit test

---
 airflow/providers/databricks/operators/databricks.py | 5 +++--
 tests/providers/databricks/operators/test_databricks.py | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 457c37107bd8b..247d810a6bf3d 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -834,8 +834,9 @@ def execute(self, context: Context):
                 raise AirflowException(f"Job ID for job name {self.json['job_name']} can not be found")
             self.json["job_id"] = job_id
             del self.json["job_name"]
-        if self.cancel_previous_runs:
-            hook.cancel_all_runs(job_id)
+
+        if self.cancel_previous_runs and self.json["job_id"] is not None:
+            hook.cancel_all_runs(self.json["job_id"])
 
         self.run_id = hook.run_now(self.json)
         if self.deferrable:
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 8cec35638da1d..ca4dbcdcb3ae2 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -1400,7 +1400,7 @@ def test_exec_failure_if_job_id_not_found(self, db_mock_class):
     @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_cancel_previous_runs(self, db_mock_class):
         run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
-        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=True, json=run)
+        op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=True, wait_for_termination=False, json=run)
         db_mock = db_mock_class.return_value
         db_mock.run_now.return_value = RUN_ID
 
@@ -1432,7 +1432,7 @@ def test_cancel_previous_runs(self, db_mock_class):
 
     @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_no_cancel_previous_runs(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} - op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=False, json=run) + op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=False, wait_for_termination=False, json=run) db_mock = db_mock_class.return_value db_mock.run_now.return_value = RUN_ID From 2501e3a4fa00b23b096fcd66ae64dddf8ab32354 Mon Sep 17 00:00:00 2001 From: subham611 Date: Thu, 4 Apr 2024 09:24:34 +0530 Subject: [PATCH 5/5] Fix formatting --- .../operators/run_now.rst | 2 ++ tests/providers/databricks/operators/test_databricks.py | 8 ++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/apache-airflow-providers-databricks/operators/run_now.rst b/docs/apache-airflow-providers-databricks/operators/run_now.rst index a9da7512e26e8..facf47e7d6c56 100644 --- a/docs/apache-airflow-providers-databricks/operators/run_now.rst +++ b/docs/apache-airflow-providers-databricks/operators/run_now.rst @@ -45,6 +45,8 @@ All other parameters are optional and described in documentation for ``Databrick * ``python_named_parameters`` * ``jar_params`` * ``spark_submit_params`` +* ``idempotency_token`` +* ``repair_run`` * ``cancel_previous_runs`` DatabricksRunNowDeferrableOperator diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index ca4dbcdcb3ae2..f2a3441f435cf 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -1400,7 +1400,9 @@ def test_exec_failure_if_job_id_not_found(self, db_mock_class): @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_cancel_previous_runs(self, db_mock_class): run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} - op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=True, wait_for_termination=False, json=run) + op = DatabricksRunNowOperator( + task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=True, wait_for_termination=False, json=run + ) db_mock = db_mock_class.return_value db_mock.run_now.return_value = RUN_ID @@ -1432,7 +1434,9 @@ def test_cancel_previous_runs(self, db_mock_class): @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") def test_no_cancel_previous_runs(self, db_mock_class): run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} - op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=False, wait_for_termination=False, json=run) + op = DatabricksRunNowOperator( + task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=False, wait_for_termination=False, json=run + ) db_mock = db_mock_class.return_value db_mock.run_now.return_value = RUN_ID