
Add cancel_previous_runs to DatabricksRunNowOperator (#38702)
Co-authored-by: subham611 <subhamsinghal@sharechat.co>
SubhamSinghal and subham611 authored Apr 8, 2024
1 parent ef97ed2 commit 4e6d3fa
Showing 3 changed files with 99 additions and 21 deletions.
7 changes: 7 additions & 0 deletions airflow/providers/databricks/operators/databricks.py
@@ -651,6 +651,7 @@ class DatabricksRunNowOperator(BaseOperator):
- ``spark_submit_params``
- ``idempotency_token``
- ``repair_run``
- ``cancel_previous_runs``
:param job_id: the job_id of the existing Databricks job.
This field will be templated.
@@ -740,6 +741,7 @@ class DatabricksRunNowOperator(BaseOperator):
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param deferrable: Run operator in the deferrable mode.
:param repair_run: Repair the databricks run in case of failure.
:param cancel_previous_runs: Cancel all existing runs of the job before submitting a new one.
"""

# Used in airflow.models.BaseOperator
@@ -771,6 +773,7 @@ def __init__(
wait_for_termination: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
repair_run: bool = False,
cancel_previous_runs: bool = False,
**kwargs,
) -> None:
"""Create a new ``DatabricksRunNowOperator``."""
@@ -784,6 +787,7 @@ def __init__(
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.repair_run = repair_run
self.cancel_previous_runs = cancel_previous_runs

if job_id is not None:
self.json["job_id"] = job_id
@@ -831,6 +835,9 @@ def execute(self, context: Context):
self.json["job_id"] = job_id
del self.json["job_name"]

if self.cancel_previous_runs and self.json["job_id"] is not None:
hook.cancel_all_runs(self.json["job_id"])

self.run_id = hook.run_now(self.json)
if self.deferrable:
_handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
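
For context, the new branch in ``execute`` simply asks the hook to cancel every active run of the target job before triggering a fresh one. Below is a minimal sketch of what that amounts to at the REST level, assuming the Databricks Jobs API 2.1 ``runs/cancel-all`` endpoint and plain token authentication; the hook's real ``cancel_all_runs`` adds retries, connection resolution, and error handling.

import requests


def cancel_all_runs_sketch(host: str, token: str, job_id: int) -> None:
    # Hypothetical standalone helper for illustration only; the operator itself
    # delegates to DatabricksHook.cancel_all_runs(job_id).
    response = requests.post(
        f"https://{host}/api/2.1/jobs/runs/cancel-all",
        headers={"Authorization": f"Bearer {token}"},
        json={"job_id": job_id},  # cancels all active runs of this job
        timeout=30,
    )
    response.raise_for_status()
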
@@ -45,6 +45,9 @@ All other parameters are optional and described in documentation for ``DatabricksRunNowOperator``
* ``python_named_parameters``
* ``jar_params``
* ``spark_submit_params``
* ``idempotency_token``
* ``repair_run``
* ``cancel_previous_runs``
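
A minimal usage sketch of the new flag in a DAG; the DAG id, connection id, job id, and notebook parameters below are hypothetical placeholders.

import pendulum

from airflow import DAG
from airflow.providers.databricks.operators.databricks import DatabricksRunNowOperator

with DAG(
    dag_id="example_databricks_run_now_cancel_previous",  # hypothetical DAG id
    start_date=pendulum.datetime(2024, 4, 1, tz="UTC"),
    schedule=None,
    catchup=False,
):
    run_job = DatabricksRunNowOperator(
        task_id="run_databricks_job",
        databricks_conn_id="databricks_default",  # assumed connection id
        job_id=1234,  # hypothetical existing Databricks job
        notebook_params={"dry-run": "true"},
        cancel_previous_runs=True,  # cancel any active runs of the job before triggering a new one
        wait_for_termination=True,
    )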

DatabricksRunNowDeferrableOperator
==================================
110 changes: 89 additions & 21 deletions tests/providers/databricks/operators/test_databricks.py
@@ -737,7 +737,7 @@ def test_exec_success(self, db_mock_class):
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)
@@ -767,7 +767,7 @@ def test_exec_pipeline_name(self, db_mock_class):
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.find_pipeline_id_by_name.return_value = PIPELINE_ID_TASK["pipeline_id"]
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)
@@ -798,7 +798,7 @@ def test_exec_failure(self, db_mock_class):
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException):
@@ -845,7 +845,7 @@ def test_wait_for_termination(self, db_mock_class):
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

assert op.wait_for_termination
@@ -875,7 +875,7 @@ def test_no_wait_for_termination(self, db_mock_class):
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, wait_for_termination=False, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID

assert not op.wait_for_termination

@@ -909,7 +909,7 @@ def test_execute_task_deferred(self, db_mock_class):
}
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")

with pytest.raises(TaskDeferred) as exc:
@@ -971,7 +971,7 @@ def test_execute_complete_failure(self, db_mock_class):
op.execute_complete(context=None, event=event)

db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException, match=f"Job run failed with terminal state: {run_state_failed}"):
@@ -993,7 +993,7 @@ def test_databricks_submit_run_deferrable_operator_failed_before_defer(self, moc
}
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
op.execute(None)

@@ -1023,7 +1023,7 @@ def test_databricks_submit_run_deferrable_operator_success_before_defer(self, mo
}
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
op.execute(None)

@@ -1147,7 +1147,7 @@ def test_exec_success(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)
@@ -1181,7 +1181,7 @@ def test_exec_failure(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException):
@@ -1215,7 +1215,7 @@ def test_exec_failure_with_message(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = mock_dict(
{
"job_id": JOB_ID,
@@ -1279,7 +1279,7 @@ def test_wait_for_termination(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

assert op.wait_for_termination
@@ -1311,7 +1311,7 @@ def test_no_wait_for_termination(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, wait_for_termination=False, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID

assert not op.wait_for_termination

@@ -1357,7 +1357,7 @@ def test_exec_with_job_name(self, db_mock_class):
op = DatabricksRunNowOperator(task_id=TASK_ID, job_name=JOB_NAME, json=run)
db_mock = db_mock_class.return_value
db_mock.find_job_id_by_name.return_value = JOB_ID
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)
@@ -1397,6 +1397,74 @@ def test_exec_failure_if_job_id_not_found(self, db_mock_class):

db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_cancel_previous_runs(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(
task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=True, wait_for_termination=False, json=run
)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = RUN_ID

assert op.cancel_previous_runs

op.execute(None)

expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
"jar_params": JAR_PARAMS,
"job_id": JOB_ID,
}
)
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay,
retry_args=None,
caller="DatabricksRunNowOperator",
)

db_mock.cancel_all_runs.assert_called_once_with(JOB_ID)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_not_called()

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_no_cancel_previous_runs(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(
task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=False, wait_for_termination=False, json=run
)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = RUN_ID

assert not op.cancel_previous_runs

op.execute(None)

expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
"jar_params": JAR_PARAMS,
"job_id": JOB_ID,
}
)
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay,
retry_args=None,
caller="DatabricksRunNowOperator",
)

db_mock.cancel_all_runs.assert_not_called()
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_not_called()


class TestDatabricksRunNowDeferrableOperator:
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
@@ -1407,7 +1475,7 @@ def test_execute_task_deferred(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")

with pytest.raises(TaskDeferred) as exc:
@@ -1470,7 +1538,7 @@ def test_execute_complete_failure(self, db_mock_class):
op.execute_complete(context=None, event=event)

db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException, match=f"Job run failed with terminal state: {run_state_failed}"):
@@ -1499,7 +1567,7 @@ def test_execute_complete_failure_and_repair_run(
op.execute_complete(context=None, event=event)

db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
db_mock.get_latest_repair_id.assert_called_once()
db_mock.repair_run.assert_called_once()
@@ -1521,7 +1589,7 @@ def test_operator_failed_before_defer(self, mock_defer, db_mock_class):
}
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
op.execute(None)

@@ -1548,7 +1616,7 @@ def test_databricks_run_now_deferrable_operator_failed_before_defer(self, mock_d
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

op.execute(None)
@@ -1581,7 +1649,7 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)
