diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index 9dc1df5afe675..458577f95ffda 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -44,6 +44,7 @@
 GET_RUN_ENDPOINT = ("GET", "api/2.1/jobs/runs/get")
 CANCEL_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/cancel")
 DELETE_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/delete")
+REPAIR_RUN_ENDPOINT = ("POST", "api/2.1/jobs/runs/repair")
 OUTPUT_RUNS_JOB_ENDPOINT = ("GET", "api/2.1/jobs/runs/get-output")
 
 INSTALL_LIBS_ENDPOINT = ("POST", "api/2.0/libraries/install")
@@ -361,6 +362,14 @@ def delete_run(self, run_id: int) -> None:
         json = {"run_id": run_id}
         self._do_api_call(DELETE_RUN_ENDPOINT, json)
 
+    def repair_run(self, json: dict) -> None:
+        """
+        Re-run one or more tasks.
+
+        :param json: json dictionary containing the repair run parameters.
+        """
+        self._do_api_call(REPAIR_RUN_ENDPOINT, json)
+
     def restart_cluster(self, json: dict) -> None:
         """
         Restarts the cluster.
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 5cbd2f186def2..1d6d862363e5a 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -150,6 +150,13 @@ def delete_run_endpoint(host):
     return f"https://{host}/api/2.1/jobs/runs/delete"
 
 
+def repair_run_endpoint(host):
+    """
+    Utility function to generate the repair run endpoint given the host.
+    """
+    return f"https://{host}/api/2.1/jobs/runs/repair"
+
+
 def start_cluster_endpoint(host):
     """
     Utility function to generate the get run endpoint given the host.
@@ -543,6 +550,35 @@ def test_delete_run(self, mock_requests):
             timeout=self.hook.timeout_seconds,
         )
 
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_repair_run(self, mock_requests):
+        mock_requests.post.return_value.json.return_value = {"repair_id": 734650698524280}
+        json = {
+            "run_id": 455644833,
+            "rerun_tasks": ["task0", "task1"],
+            "latest_repair_id": 734650698524280,
+            "rerun_all_failed_tasks": False,
+            "jar_params": ["john", "doe", "35"],
+            "notebook_params": {"name": "john doe", "age": "35"},
+            "python_params": ["john doe", "35"],
+            "spark_submit_params": ["--class", "org.apache.spark.examples.SparkPi"],
+            "python_named_params": {"name": "task", "data": "dbfs:/path/to/data.json"},
+            "pipeline_params": {"full_refresh": True},
+            "sql_params": {"name": "john doe", "age": "35"},
+            "dbt_commands": ["dbt deps", "dbt seed", "dbt run"],
+        }
+
+        self.hook.repair_run(json)
+
+        mock_requests.post.assert_called_once_with(
+            repair_run_endpoint(HOST),
+            json=json,
+            params=None,
+            auth=HTTPBasicAuth(LOGIN, PASSWORD),
+            headers=self.hook.user_agent_header,
+            timeout=self.hook.timeout_seconds,
+        )
+
     @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
     def test_start_cluster(self, mock_requests):
         mock_requests.codes.ok = 200
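
As a usage sketch (not part of the diff above): the new hook method forwards a caller-built payload straight to the 2.1 repair endpoint, so a minimal call could look like the following. The connection id and payload values are illustrative assumptions drawn from the test, not a prescribed interface.

from airflow.providers.databricks.hooks.databricks import DatabricksHook

# Hypothetical example: repair run 455644833 by re-running two named tasks.
# "databricks_default" is Airflow's default Databricks connection id; the
# run id and task keys are taken from the test fixture above.
hook = DatabricksHook(databricks_conn_id="databricks_default")
hook.repair_run(
    json={
        "run_id": 455644833,
        "rerun_tasks": ["task0", "task1"],
    }
)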