Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cancel_previous_runs to DatabricksRunNowOperator #38702

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,7 @@ class DatabricksRunNowOperator(BaseOperator):
- ``spark_submit_params``
- ``idempotency_token``
- ``repair_run``
- ``cancel_previous_runs``

:param job_id: the job_id of the existing Databricks job.
This field will be templated.
Expand Down Expand Up @@ -740,6 +741,7 @@ class DatabricksRunNowOperator(BaseOperator):
:param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
:param deferrable: Run operator in the deferrable mode.
:param repair_run: Repair the databricks run in case of failure.
:param cancel_previous_runs: Cancel all existing running jobs before submitting new one.
"""

# Used in airflow.models.BaseOperator
Expand Down Expand Up @@ -771,6 +773,7 @@ def __init__(
wait_for_termination: bool = True,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
repair_run: bool = False,
cancel_previous_runs: bool = False,
**kwargs,
) -> None:
"""Create a new ``DatabricksRunNowOperator``."""
Expand All @@ -784,6 +787,7 @@ def __init__(
self.wait_for_termination = wait_for_termination
self.deferrable = deferrable
self.repair_run = repair_run
self.cancel_previous_runs = cancel_previous_runs

if job_id is not None:
self.json["job_id"] = job_id
Expand Down Expand Up @@ -831,6 +835,9 @@ def execute(self, context: Context):
self.json["job_id"] = job_id
del self.json["job_name"]

if self.cancel_previous_runs and self.json["job_id"] is not None:
hook.cancel_all_runs(self.json["job_id"])

self.run_id = hook.run_now(self.json)
if self.deferrable:
_handle_deferrable_databricks_operator_execution(self, hook, self.log, context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ All other parameters are optional and described in documentation for ``Databrick
* ``python_named_parameters``
* ``jar_params``
* ``spark_submit_params``
* ``idempotency_token``
* ``repair_run``
* ``cancel_previous_runs``

DatabricksRunNowDeferrableOperator
==================================
Expand Down
110 changes: 89 additions & 21 deletions tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,7 @@ def test_exec_success(self, db_mock_class):
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)
Expand Down Expand Up @@ -767,7 +767,7 @@ def test_exec_pipeline_name(self, db_mock_class):
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.find_pipeline_id_by_name.return_value = PIPELINE_ID_TASK["pipeline_id"]
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)
Expand Down Expand Up @@ -798,7 +798,7 @@ def test_exec_failure(self, db_mock_class):
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException):
Expand Down Expand Up @@ -845,7 +845,7 @@ def test_wait_for_termination(self, db_mock_class):
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

assert op.wait_for_termination
Expand Down Expand Up @@ -875,7 +875,7 @@ def test_no_wait_for_termination(self, db_mock_class):
}
op = DatabricksSubmitRunOperator(task_id=TASK_ID, wait_for_termination=False, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID

assert not op.wait_for_termination

Expand Down Expand Up @@ -909,7 +909,7 @@ def test_execute_task_deferred(self, db_mock_class):
}
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")

with pytest.raises(TaskDeferred) as exc:
Expand Down Expand Up @@ -971,7 +971,7 @@ def test_execute_complete_failure(self, db_mock_class):
op.execute_complete(context=None, event=event)

db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException, match=f"Job run failed with terminal state: {run_state_failed}"):
Expand All @@ -993,7 +993,7 @@ def test_databricks_submit_run_deferrable_operator_failed_before_defer(self, moc
}
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
op.execute(None)

Expand Down Expand Up @@ -1023,7 +1023,7 @@ def test_databricks_submit_run_deferrable_operator_success_before_defer(self, mo
}
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")
op.execute(None)

Expand Down Expand Up @@ -1147,7 +1147,7 @@ def test_exec_success(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)
Expand Down Expand Up @@ -1181,7 +1181,7 @@ def test_exec_failure(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException):
Expand Down Expand Up @@ -1215,7 +1215,7 @@ def test_exec_failure_with_message(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = mock_dict(
{
"job_id": JOB_ID,
Expand Down Expand Up @@ -1279,7 +1279,7 @@ def test_wait_for_termination(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

assert op.wait_for_termination
Expand Down Expand Up @@ -1311,7 +1311,7 @@ def test_no_wait_for_termination(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, wait_for_termination=False, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID

assert not op.wait_for_termination

Expand Down Expand Up @@ -1357,7 +1357,7 @@ def test_exec_with_job_name(self, db_mock_class):
op = DatabricksRunNowOperator(task_id=TASK_ID, job_name=JOB_NAME, json=run)
db_mock = db_mock_class.return_value
db_mock.find_job_id_by_name.return_value = JOB_ID
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)
Expand Down Expand Up @@ -1397,6 +1397,74 @@ def test_exec_failure_if_job_id_not_found(self, db_mock_class):

db_mock.find_job_id_by_name.assert_called_once_with(JOB_NAME)

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_cancel_previous_runs(self, db_mock_class):
    """When cancel_previous_runs=True, all active runs of the job are cancelled before run_now is called."""
    run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
    op = DatabricksRunNowOperator(
        task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=True, wait_for_termination=False, json=run
    )
    db_mock = db_mock_class.return_value
    db_mock.run_now.return_value = RUN_ID

    assert op.cancel_previous_runs

    op.execute(None)

    # The payload sent to run_now: the user-supplied json merged with the resolved job_id.
    expected = utils.normalise_json_content(
        {
            "notebook_params": NOTEBOOK_PARAMS,
            "notebook_task": NOTEBOOK_TASK,
            "jar_params": JAR_PARAMS,
            "job_id": JOB_ID,
        }
    )
    db_mock_class.assert_called_once_with(
        DEFAULT_CONN_ID,
        retry_limit=op.databricks_retry_limit,
        retry_delay=op.databricks_retry_delay,
        retry_args=None,
        caller="DatabricksRunNowOperator",
    )

    # Exactly one cancel_all_runs call, targeted at the job being triggered.
    db_mock.cancel_all_runs.assert_called_once_with(JOB_ID)
    db_mock.run_now.assert_called_once_with(expected)
    db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
    # wait_for_termination=False, so the operator must not poll run state.
    db_mock.get_run.assert_not_called()

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_no_cancel_previous_runs(self, db_mock_class):
    """When cancel_previous_runs=False (the default behavior), cancel_all_runs is never invoked."""
    run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
    op = DatabricksRunNowOperator(
        task_id=TASK_ID, job_id=JOB_ID, cancel_previous_runs=False, wait_for_termination=False, json=run
    )
    db_mock = db_mock_class.return_value
    db_mock.run_now.return_value = RUN_ID

    assert not op.cancel_previous_runs

    op.execute(None)

    # The payload sent to run_now: the user-supplied json merged with the resolved job_id.
    expected = utils.normalise_json_content(
        {
            "notebook_params": NOTEBOOK_PARAMS,
            "notebook_task": NOTEBOOK_TASK,
            "jar_params": JAR_PARAMS,
            "job_id": JOB_ID,
        }
    )
    db_mock_class.assert_called_once_with(
        DEFAULT_CONN_ID,
        retry_limit=op.databricks_retry_limit,
        retry_delay=op.databricks_retry_delay,
        retry_args=None,
        caller="DatabricksRunNowOperator",
    )

    # The flag is off, so no runs may be cancelled before triggering.
    db_mock.cancel_all_runs.assert_not_called()
    db_mock.run_now.assert_called_once_with(expected)
    db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
    # wait_for_termination=False, so the operator must not poll run state.
    db_mock.get_run.assert_not_called()


class TestDatabricksRunNowDeferrableOperator:
@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
Expand All @@ -1407,7 +1475,7 @@ def test_execute_task_deferred(self, db_mock_class):
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("RUNNING", "RUNNING")

with pytest.raises(TaskDeferred) as exc:
Expand Down Expand Up @@ -1470,7 +1538,7 @@ def test_execute_complete_failure(self, db_mock_class):
op.execute_complete(context=None, event=event)

db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

with pytest.raises(AirflowException, match=f"Job run failed with terminal state: {run_state_failed}"):
Expand Down Expand Up @@ -1499,7 +1567,7 @@ def test_execute_complete_failure_and_repair_run(
op.execute_complete(context=None, event=event)

db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
db_mock.get_latest_repair_id.assert_called_once()
db_mock.repair_run.assert_called_once()
Expand All @@ -1521,7 +1589,7 @@ def test_operator_failed_before_defer(self, mock_defer, db_mock_class):
}
op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.submit_run.return_value = 1
db_mock.submit_run.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")
op.execute(None)

Expand All @@ -1548,7 +1616,7 @@ def test_databricks_run_now_deferrable_operator_failed_before_defer(self, mock_d
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "FAILED")

op.execute(None)
Expand Down Expand Up @@ -1581,7 +1649,7 @@ def test_databricks_run_now_deferrable_operator_success_before_defer(self, mock_
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = 1
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = make_run_with_state_mock("TERMINATED", "SUCCESS")

op.execute(None)
Expand Down