From bd6e798c93b6959041c490bd4d6cab3c788197a4 Mon Sep 17 00:00:00 2001 From: Gaurav Miglani Date: Thu, 18 Apr 2024 16:59:04 +0530 Subject: [PATCH 1/7] [FEAT] added notebook error in databricks deferrable handler --- .../databricks/operators/databricks.py | 42 +++++++++++-------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index eab772d233b13..7ace0608cf25d 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -69,19 +69,7 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None: return if run_state.result_state == "FAILED": - task_run_id = None - if "tasks" in run_info: - for task in run_info["tasks"]: - if task.get("state", {}).get("result_state", "") == "FAILED": - task_run_id = task["run_id"] - if task_run_id is not None: - run_output = hook.get_run_output(task_run_id) - if "error" in run_output: - notebook_error = run_output["error"] - else: - notebook_error = run_state.state_message - else: - notebook_error = run_state.state_message + notebook_error = _get_databricks_notebook_error(run_info, hook, run_state) error_message = ( f"{operator.task_id} failed with terminal state: {run_state} " f"and with the error {notebook_error}" @@ -156,17 +144,37 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex log.info("%s completed successfully.", operator.task_id) -def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) -> None: +def _get_databricks_notebook_error(run_info: dict, hook: DatabricksHook, run_state: RunState): + task_run_id = None + if "tasks" in run_info: + for task in run_info["tasks"]: + if task.get("state", {}).get("result_state", "") == "FAILED": + task_run_id = task["run_id"] + if task_run_id is not None: + run_output = hook.get_run_output(task_run_id) + if "error" in run_output: + notebook_error = run_output["error"] + else: + notebook_error = run_state.state_message + else: + notebook_error = run_state.state_message + return notebook_error + + +def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger, hook: DatabricksHook) -> None: validate_trigger_event(event) run_state = RunState.from_json(event["run_state"]) run_page_url = event["run_page_url"] + run_id = event["run_id"] log.info("View run status, Spark UI, and logs at %s", run_page_url) if run_state.is_successful: log.info("Job run completed successfully.") return + run_info = hook.get_run(run_id) + notebook_error = _get_databricks_notebook_error(run_info, hook, run_state) + error_message = f"Job run failed with terminal state: {run_state} and with the error {notebook_error}" - error_message = f"Job run failed with terminal state: {run_state}" if event["repair_run"]: log.warning( "%s but since repair run is set, repairing the run with all failed tasks", @@ -573,7 +581,7 @@ def on_kill(self): self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id) def execute_complete(self, context: dict | None, event: dict): - _handle_deferrable_databricks_operator_completion(event, self.log) + _handle_deferrable_databricks_operator_completion(event, self.log, self._hook) @deprecated( @@ -850,7 +858,7 @@ def execute(self, context: Context): def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: if event: - _handle_deferrable_databricks_operator_completion(event, 
self.log) + _handle_deferrable_databricks_operator_completion(event, self.log, self._hook) if event["repair_run"]: self.repair_run = False self.run_id = event["run_id"] From 3b9715f14298cc4ff837aaa5ffc16d1ea6db5951 Mon Sep 17 00:00:00 2001 From: Gaurav Miglani Date: Fri, 19 Apr 2024 03:35:30 +0530 Subject: [PATCH 2/7] [CHORE] added return type --- airflow/providers/databricks/operators/databricks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 7ace0608cf25d..321131b68bf07 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -144,7 +144,7 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex log.info("%s completed successfully.", operator.task_id) -def _get_databricks_notebook_error(run_info: dict, hook: DatabricksHook, run_state: RunState): +def _get_databricks_notebook_error(run_info: dict, hook: DatabricksHook, run_state: RunState) -> str: task_run_id = None if "tasks" in run_info: for task in run_info["tasks"]: From ab91c889f91c639386ef059a813cc93c1bdedd31 Mon Sep 17 00:00:00 2001 From: Gaurav Miglani Date: Mon, 22 Apr 2024 02:14:54 +0530 Subject: [PATCH 3/7] [CHORE] implement async get_run_output and handling notebook error --- .../providers/databricks/hooks/databricks.py | 11 +++++ .../databricks/operators/databricks.py | 43 +++++++++++-------- .../databricks/hooks/test_databricks.py | 17 ++++++++ 3 files changed, 54 insertions(+), 17 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 1a0ab8e8c6ba8..2cad06aa59e86 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -491,6 +491,17 @@ def get_run_output(self, run_id: int) -> dict: run_output = self._do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json) return run_output + async def a_get_run_output(self, run_id: int) -> dict: + """ + Async version of `get_run_output()`. + + :param run_id: id of the run + :return: output of the run + """ + json = {"run_id": run_id} + run_output = await self._do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json) + return run_output + def cancel_run(self, run_id: int) -> None: """ Cancel the run. 
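A minimal usage sketch of the new `a_get_run_output()` (illustrative only, not part of the patch): it mirrors the `async with self.hook:` pattern the tests in this series use to manage the hook's aiohttp session, and assumes a configured default Databricks connection; the run id is a placeholder.

from __future__ import annotations

import asyncio

from airflow.providers.databricks.hooks.databricks import DatabricksHook


async def fetch_notebook_error(run_id: int) -> str | None:
    hook = DatabricksHook()
    # Entering the hook opens the shared aiohttp session the async calls use.
    async with hook:
        run_output = await hook.a_get_run_output(run_id)
    # The runs get-output response carries an "error" field only when the
    # task failed, so fall back to None for successful runs.
    return run_output.get("error")


print(asyncio.run(fetch_notebook_error(42)))  # 42 is a placeholder run id
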
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 321131b68bf07..6ba1d3a3c99b7 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -69,7 +69,19 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None: return if run_state.result_state == "FAILED": - notebook_error = _get_databricks_notebook_error(run_info, hook, run_state) + task_run_id = None + if "tasks" in run_info: + for task in run_info["tasks"]: + if task.get("state", {}).get("result_state", "") == "FAILED": + task_run_id = task["run_id"] + if task_run_id is not None: + run_output = hook.get_run_output(task_run_id) + if "error" in run_output: + notebook_error = run_output["error"] + else: + notebook_error = run_state.state_message + else: + notebook_error = run_state.state_message error_message = ( f"{operator.task_id} failed with terminal state: {run_state} " f"and with the error {notebook_error}" @@ -144,35 +156,32 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex log.info("%s completed successfully.", operator.task_id) -def _get_databricks_notebook_error(run_info: dict, hook: DatabricksHook, run_state: RunState) -> str: +def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger, hook: DatabricksHook) -> None: + validate_trigger_event(event) + run_state = RunState.from_json(event["run_state"]) + run_page_url = event["run_page_url"] + run_id = event["run_id"] + log.info("View run status, Spark UI, and logs at %s", run_page_url) + + if run_state.is_successful: + log.info("Job run completed successfully.") + return + + run_info = await hook.a_get_run(run_id) task_run_id = None if "tasks" in run_info: for task in run_info["tasks"]: if task.get("state", {}).get("result_state", "") == "FAILED": task_run_id = task["run_id"] if task_run_id is not None: - run_output = hook.get_run_output(task_run_id) + run_output = await hook.a_get_run_output(task_run_id) if "error" in run_output: notebook_error = run_output["error"] else: notebook_error = run_state.state_message else: notebook_error = run_state.state_message - return notebook_error - -def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger, hook: DatabricksHook) -> None: - validate_trigger_event(event) - run_state = RunState.from_json(event["run_state"]) - run_page_url = event["run_page_url"] - run_id = event["run_id"] - log.info("View run status, Spark UI, and logs at %s", run_page_url) - - if run_state.is_successful: - log.info("Job run completed successfully.") - return - run_info = hook.get_run(run_id) - notebook_error = _get_databricks_notebook_error(run_info, hook, run_state) error_message = f"Job run failed with terminal state: {run_state} and with the error {notebook_error}" if event["repair_run"]: diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 64d4de1d37766..514af27546bcd 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -1603,6 +1603,23 @@ async def test_get_cluster_state(self, mock_get): timeout=self.hook.timeout_seconds, ) + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks_base.aiohttp.ClientSession.get") + async def test_get_run_output(self, mock_get): + mock_get.return_value.__aenter__.return_value.json = 
AsyncMock(return_value=GET_RUN_OUTPUT_RESPONSE) + async with self.hook: + run_output = await self.hook.a_get_run_output(RUN_ID) + run_output_error = run_output.get("error") + + assert run_output_error == ERROR_MESSAGE + mock_get.assert_called_once_with( + get_run_output_endpoint(HOST), + json=None, + auth=aiohttp.BasicAuth(LOGIN, PASSWORD), + headers=self.hook.user_agent_header, + timeout=self.hook.timeout_seconds, + ) + @pytest.mark.db_test class TestDatabricksHookAsyncAadToken: From 1b08b98b5584cd4c9a4825e3c024a537197e7793 Mon Sep 17 00:00:00 2001 From: Gaurav Miglani Date: Mon, 22 Apr 2024 16:04:54 +0530 Subject: [PATCH 4/7] [CHORE] added login in run method --- .../databricks/operators/databricks.py | 23 ++++--------------- .../databricks/triggers/databricks.py | 17 ++++++++++++++ .../providers/databricks/utils/databricks.py | 2 +- .../databricks/utils/test_databricks.py | 1 + 4 files changed, 23 insertions(+), 20 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 6ba1d3a3c99b7..967d3f4c3bc5a 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -156,32 +156,17 @@ def _handle_deferrable_databricks_operator_execution(operator, hook, log, contex log.info("%s completed successfully.", operator.task_id) -def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger, hook: DatabricksHook) -> None: +def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) -> None: validate_trigger_event(event) run_state = RunState.from_json(event["run_state"]) run_page_url = event["run_page_url"] - run_id = event["run_id"] + notebook_error = event["notebook_error"] log.info("View run status, Spark UI, and logs at %s", run_page_url) if run_state.is_successful: log.info("Job run completed successfully.") return - run_info = await hook.a_get_run(run_id) - task_run_id = None - if "tasks" in run_info: - for task in run_info["tasks"]: - if task.get("state", {}).get("result_state", "") == "FAILED": - task_run_id = task["run_id"] - if task_run_id is not None: - run_output = await hook.a_get_run_output(task_run_id) - if "error" in run_output: - notebook_error = run_output["error"] - else: - notebook_error = run_state.state_message - else: - notebook_error = run_state.state_message - error_message = f"Job run failed with terminal state: {run_state} and with the error {notebook_error}" if event["repair_run"]: @@ -590,7 +575,7 @@ def on_kill(self): self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id) def execute_complete(self, context: dict | None, event: dict): - _handle_deferrable_databricks_operator_completion(event, self.log, self._hook) + _handle_deferrable_databricks_operator_completion(event, self.log) @deprecated( @@ -867,7 +852,7 @@ def execute(self, context: Context): def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: if event: - _handle_deferrable_databricks_operator_completion(event, self.log, self._hook) + _handle_deferrable_databricks_operator_completion(event, self.log) if event["repair_run"]: self.repair_run = False self.run_id = event["run_id"] diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py index 4c1eecb85f7fd..70a6df9529a3f 100644 --- a/airflow/providers/databricks/triggers/databricks.py +++ b/airflow/providers/databricks/triggers/databricks.py @@ 
-84,13 +84,30 @@ async def run(self): async with self.hook: while True: run_state = await self.hook.a_get_run_state(self.run_id) + notebook_error = None if run_state.is_terminal: + if run_state.result_state == "FAILED": + run_info = await self.hook.a_get_run(self.run_id) + task_run_id = None + if "tasks" in run_info: + for task in run_info["tasks"]: + if task.get("state", {}).get("result_state", "") == "FAILED": + task_run_id = task["run_id"] + if task_run_id is not None: + run_output = await self.hook.a_get_run_output(task_run_id) + if "error" in run_output: + notebook_error = run_output["error"] + else: + notebook_error = run_state.state_message + else: + notebook_error = run_state.state_message yield TriggerEvent( { "run_id": self.run_id, "run_page_url": self.run_page_url, "run_state": run_state.to_json(), "repair_run": self.repair_run, + "notebook_error": notebook_error } ) return diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py index 0635017b28f80..d28ecae3637eb 100644 --- a/airflow/providers/databricks/utils/databricks.py +++ b/airflow/providers/databricks/utils/databricks.py @@ -55,7 +55,7 @@ def validate_trigger_event(event: dict): See: :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger`. """ - keys_to_check = ["run_id", "run_page_url", "run_state"] + keys_to_check = ["run_id", "run_page_url", "run_state", "notebook_error"] for key in keys_to_check: if key not in event: raise AirflowException(f"Could not find `{key}` in the event: {event}") diff --git a/tests/providers/databricks/utils/test_databricks.py b/tests/providers/databricks/utils/test_databricks.py index 7619bcb8ad07f..00bbd43419c00 100644 --- a/tests/providers/databricks/utils/test_databricks.py +++ b/tests/providers/databricks/utils/test_databricks.py @@ -53,6 +53,7 @@ def test_validate_trigger_event_success(self): "run_id": RUN_ID, "run_page_url": RUN_PAGE_URL, "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(), + "notebook_error": None } assert validate_trigger_event(event) is None From f406db494240154e833435bf4b3dc9cea402711e Mon Sep 17 00:00:00 2001 From: Gaurav Miglani Date: Mon, 22 Apr 2024 19:24:18 +0530 Subject: [PATCH 5/7] [FIX] tests --- .../providers/databricks/hooks/databricks.py | 2 +- .../databricks/triggers/databricks.py | 2 +- .../databricks/hooks/test_databricks.py | 2 +- .../databricks/operators/test_databricks.py | 5 ++ .../databricks/triggers/test_databricks.py | 72 ++++++++++++++++++- .../databricks/utils/test_databricks.py | 2 +- 6 files changed, 79 insertions(+), 6 deletions(-) diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py index 2cad06aa59e86..710074d239af0 100644 --- a/airflow/providers/databricks/hooks/databricks.py +++ b/airflow/providers/databricks/hooks/databricks.py @@ -499,7 +499,7 @@ async def a_get_run_output(self, run_id: int) -> dict: :return: output of the run """ json = {"run_id": run_id} - run_output = await self._do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json) + run_output = await self._a_do_api_call(OUTPUT_RUNS_JOB_ENDPOINT, json) return run_output def cancel_run(self, run_id: int) -> None: diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py index 70a6df9529a3f..82d83e1218d15 100644 --- a/airflow/providers/databricks/triggers/databricks.py +++ b/airflow/providers/databricks/triggers/databricks.py @@ -107,7 +107,7 @@ async def run(self): 
"run_page_url": self.run_page_url, "run_state": run_state.to_json(), "repair_run": self.repair_run, - "notebook_error": notebook_error + "notebook_error": notebook_error, } ) return diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py index 514af27546bcd..0f1d2c242e456 100644 --- a/tests/providers/databricks/hooks/test_databricks.py +++ b/tests/providers/databricks/hooks/test_databricks.py @@ -1614,7 +1614,7 @@ async def test_get_run_output(self, mock_get): assert run_output_error == ERROR_MESSAGE mock_get.assert_called_once_with( get_run_output_endpoint(HOST), - json=None, + json={"run_id": RUN_ID}, auth=aiohttp.BasicAuth(LOGIN, PASSWORD), headers=self.hook.user_agent_header, timeout=self.hook.timeout_seconds, diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index 46e14a917ab4e..dce0f88876a84 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -1013,6 +1013,7 @@ def test_execute_complete_success(self): "run_id": RUN_ID, "run_page_url": RUN_PAGE_URL, "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(), + "notebook_error": None, } op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) @@ -1033,6 +1034,7 @@ def test_execute_complete_failure(self, db_mock_class): "run_page_url": RUN_PAGE_URL, "run_state": run_state_failed.to_json(), "repair_run": False, + "notebook_error": None, } op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) @@ -1583,6 +1585,7 @@ def test_execute_complete_success(self): "run_page_url": RUN_PAGE_URL, "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(), "repair_run": False, + "notebook_error": None, } op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) @@ -1600,6 +1603,7 @@ def test_execute_complete_failure(self, db_mock_class): "run_page_url": RUN_PAGE_URL, "run_state": run_state_failed.to_json(), "repair_run": False, + "notebook_error": None, } op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) @@ -1630,6 +1634,7 @@ def test_execute_complete_failure_and_repair_run( "run_page_url": RUN_PAGE_URL, "run_state": run_state_failed.to_json(), "repair_run": True, + "notebook_error": None, } op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) diff --git a/tests/providers/databricks/triggers/test_databricks.py b/tests/providers/databricks/triggers/test_databricks.py index a0313a31f4fa0..6d541d4c486d6 100644 --- a/tests/providers/databricks/triggers/test_databricks.py +++ b/tests/providers/databricks/triggers/test_databricks.py @@ -38,13 +38,17 @@ RETRY_DELAY = 10 RETRY_LIMIT = 3 RUN_ID = 1 +TASK_RUN_ID = 11 JOB_ID = 42 RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1" +ERROR_MESSAGE = "error message from databricks API" +GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}} RUN_LIFE_CYCLE_STATES = ["PENDING", "RUNNING", "TERMINATING", "TERMINATED", "SKIPPED", "INTERNAL_ERROR"] LIFE_CYCLE_STATE_PENDING = "PENDING" LIFE_CYCLE_STATE_TERMINATED = "TERMINATED" +LIFE_CYCLE_STATE_INTERNAL_ERROR = "INTERNAL_ERROR" STATE_MESSAGE = "Waiting for cluster" @@ -66,6 +70,25 @@ "result_state": "SUCCESS", }, } +GET_RUN_RESPONSE_TERMINATED_WITH_FAILED = { + "job_id": JOB_ID, + "run_page_url": RUN_PAGE_URL, + "state": { + "life_cycle_state": LIFE_CYCLE_STATE_INTERNAL_ERROR, + 
"state_message": None, + "result_state": "FAILED", + }, + "tasks": [ + { + "run_id": TASK_RUN_ID, + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "FAILED", + "state_message": "Workload failed, see run output for details", + }, + } + ], +} class TestDatabricksExecutionTrigger: @@ -101,15 +124,21 @@ def test_serialize(self): ) @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") - async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_url): + async def test_run_return_success( + self, mock_get_run_state, mock_get_run_page_url, mock_get_run, mock_get_run_output + ): mock_get_run_page_url.return_value = RUN_PAGE_URL mock_get_run_state.return_value = RunState( life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="SUCCESS", ) + mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED + mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE trigger_event = self.trigger.run() async for event in trigger_event: @@ -121,13 +150,49 @@ async def test_run_return_success(self, mock_get_run_state, mock_get_run_page_ur ).to_json(), "run_page_url": RUN_PAGE_URL, "repair_run": False, + "notebook_error": None, + } + ) + + @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_page_url") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") + async def test_run_return_failure( + self, mock_get_run_state, mock_get_run_page_url, mock_get_run, mock_get_run_output + ): + mock_get_run_page_url.return_value = RUN_PAGE_URL + mock_get_run_state.return_value = RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, + state_message="", + result_state="FAILED", + ) + mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE + mock_get_run.return_value = GET_RUN_RESPONSE_TERMINATED_WITH_FAILED + + trigger_event = self.trigger.run() + async for event in trigger_event: + assert event == TriggerEvent( + { + "run_id": RUN_ID, + "run_state": RunState( + life_cycle_state=LIFE_CYCLE_STATE_TERMINATED, state_message="", result_state="FAILED" + ).to_json(), + "run_page_url": RUN_PAGE_URL, + "repair_run": False, + "notebook_error": ERROR_MESSAGE, } ) @pytest.mark.asyncio + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") + @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") @mock.patch("airflow.providers.databricks.triggers.databricks.asyncio.sleep") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_state") - async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep): + async def test_sleep_between_retries( + self, mock_get_run_state, mock_sleep, mock_get_run, mock_get_run_output + ): mock_get_run_state.side_effect = [ RunState( life_cycle_state=LIFE_CYCLE_STATE_PENDING, @@ -140,6 +205,8 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep): result_state="SUCCESS", ), ] + mock_get_run.return_value = 
GET_RUN_RESPONSE_TERMINATED + mock_get_run_output.return_value = GET_RUN_OUTPUT_RESPONSE trigger_event = self.trigger.run() async for event in trigger_event: @@ -151,6 +218,7 @@ async def test_sleep_between_retries(self, mock_get_run_state, mock_sleep): ).to_json(), "run_page_url": RUN_PAGE_URL, "repair_run": False, + "notebook_error": None, } ) mock_sleep.assert_called_once() diff --git a/tests/providers/databricks/utils/test_databricks.py b/tests/providers/databricks/utils/test_databricks.py index 00bbd43419c00..df85fdd9efb13 100644 --- a/tests/providers/databricks/utils/test_databricks.py +++ b/tests/providers/databricks/utils/test_databricks.py @@ -53,7 +53,7 @@ def test_validate_trigger_event_success(self): "run_id": RUN_ID, "run_page_url": RUN_PAGE_URL, "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(), - "notebook_error": None + "notebook_error": None, } assert validate_trigger_event(event) is None From 805544c2936a49f572bc6cfa32ba9115ad0bd4b1 Mon Sep 17 00:00:00 2001 From: Gaurav Miglani Date: Thu, 25 Apr 2024 22:20:29 +0530 Subject: [PATCH 6/7] [CHORE] refactor review changes --- .../databricks/operators/databricks.py | 7 +-- .../databricks/triggers/databricks.py | 62 +++++++++---------- 2 files changed, 34 insertions(+), 35 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 967d3f4c3bc5a..c174d729d8ffe 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -70,10 +70,9 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None: if run_state.result_state == "FAILED": task_run_id = None - if "tasks" in run_info: - for task in run_info["tasks"]: - if task.get("state", {}).get("result_state", "") == "FAILED": - task_run_id = task["run_id"] + for task in run_info.get("tasks", []): + if task.get("state", {}).get("result_state", "") == "FAILED": + task_run_id = task["run_id"] if task_run_id is not None: run_output = hook.get_run_output(task_run_id) if "error" in run_output: diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py index 82d83e1218d15..a0842349ef2c8 100644 --- a/airflow/providers/databricks/triggers/databricks.py +++ b/airflow/providers/databricks/triggers/databricks.py @@ -85,37 +85,37 @@ async def run(self): while True: run_state = await self.hook.a_get_run_state(self.run_id) notebook_error = None - if run_state.is_terminal: - if run_state.result_state == "FAILED": - run_info = await self.hook.a_get_run(self.run_id) - task_run_id = None - if "tasks" in run_info: - for task in run_info["tasks"]: - if task.get("state", {}).get("result_state", "") == "FAILED": - task_run_id = task["run_id"] - if task_run_id is not None: - run_output = await self.hook.a_get_run_output(task_run_id) - if "error" in run_output: - notebook_error = run_output["error"] - else: - notebook_error = run_state.state_message - else: - notebook_error = run_state.state_message - yield TriggerEvent( - { - "run_id": self.run_id, - "run_page_url": self.run_page_url, - "run_state": run_state.to_json(), - "repair_run": self.repair_run, - "notebook_error": notebook_error, - } + if not run_state.is_terminal: + self.log.info( + "run-id %s in run state %s. sleeping for %s seconds", + self.run_id, + run_state, + self.polling_period_seconds, ) - return + await asyncio.sleep(self.polling_period_seconds) + continue - self.log.info( - "run-id %s in run state %s. 
sleeping for %s seconds", - self.run_id, - run_state, - self.polling_period_seconds, + if run_state.result_state == "FAILED": + run_info = await self.hook.a_get_run(self.run_id) + task_run_id = None + for task in run_info.get("tasks", []): + if task.get("state", {}).get("result_state", "") == "FAILED": + task_run_id = task["run_id"] + if task_run_id is not None: + run_output = await self.hook.a_get_run_output(task_run_id) + if "error" in run_output: + notebook_error = run_output["error"] + else: + notebook_error = run_state.state_message + else: + notebook_error = run_state.state_message + yield TriggerEvent( + { + "run_id": self.run_id, + "run_page_url": self.run_page_url, + "run_state": run_state.to_json(), + "repair_run": self.repair_run, + "notebook_error": notebook_error, + } ) - await asyncio.sleep(self.polling_period_seconds) + return From 3f6d898df0dc5b485c04865e36fe035ce4e4b138 Mon Sep 17 00:00:00 2001 From: Gaurav Miglani Date: Tue, 30 Apr 2024 15:59:35 +0530 Subject: [PATCH 7/7] [CHORE] review changes, getting all failed task errors --- .../databricks/operators/databricks.py | 4 +- .../databricks/triggers/databricks.py | 20 +++++----- .../providers/databricks/utils/databricks.py | 2 +- .../databricks/operators/test_databricks.py | 10 ++--- .../databricks/triggers/test_databricks.py | 39 ++++++++++++++++--- .../databricks/utils/test_databricks.py | 2 +- 6 files changed, 51 insertions(+), 26 deletions(-) diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index 6f2b334938914..c38b0683c37b3 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -159,14 +159,14 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger) validate_trigger_event(event) run_state = RunState.from_json(event["run_state"]) run_page_url = event["run_page_url"] - notebook_error = event["notebook_error"] + errors = event["errors"] log.info("View run status, Spark UI, and logs at %s", run_page_url) if run_state.is_successful: log.info("Job run completed successfully.") return - error_message = f"Job run failed with terminal state: {run_state} and with the error {notebook_error}" + error_message = f"Job run failed with terminal state: {run_state} and with the errors {errors}" if event["repair_run"]: log.warning( diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py index a0842349ef2c8..d20202fdca7f8 100644 --- a/airflow/providers/databricks/triggers/databricks.py +++ b/airflow/providers/databricks/triggers/databricks.py @@ -84,7 +84,6 @@ async def run(self): async with self.hook: while True: run_state = await self.hook.a_get_run_state(self.run_id) - notebook_error = None if not run_state.is_terminal: self.log.info( "run-id %s in run state %s. 
sleeping for %s seconds", @@ -95,27 +94,26 @@ async def run(self): await asyncio.sleep(self.polling_period_seconds) continue + failed_tasks = [] if run_state.result_state == "FAILED": run_info = await self.hook.a_get_run(self.run_id) - task_run_id = None for task in run_info.get("tasks", []): if task.get("state", {}).get("result_state", "") == "FAILED": task_run_id = task["run_id"] - if task_run_id is not None: - run_output = await self.hook.a_get_run_output(task_run_id) - if "error" in run_output: - notebook_error = run_output["error"] - else: - notebook_error = run_state.state_message - else: - notebook_error = run_state.state_message + task_key = task["task_key"] + run_output = await self.hook.a_get_run_output(task_run_id) + if "error" in run_output: + error = run_output["error"] + else: + error = run_state.state_message + failed_tasks.append({"task_key": task_key, "run_id": task_run_id, "error": error}) yield TriggerEvent( { "run_id": self.run_id, "run_page_url": self.run_page_url, "run_state": run_state.to_json(), "repair_run": self.repair_run, - "notebook_error": notebook_error, + "errors": failed_tasks, } ) return diff --git a/airflow/providers/databricks/utils/databricks.py b/airflow/providers/databricks/utils/databricks.py index d28ecae3637eb..88d622c3bc1fb 100644 --- a/airflow/providers/databricks/utils/databricks.py +++ b/airflow/providers/databricks/utils/databricks.py @@ -55,7 +55,7 @@ def validate_trigger_event(event: dict): See: :class:`~airflow.providers.databricks.triggers.databricks.DatabricksExecutionTrigger`. """ - keys_to_check = ["run_id", "run_page_url", "run_state", "notebook_error"] + keys_to_check = ["run_id", "run_page_url", "run_state", "errors"] for key in keys_to_check: if key not in event: raise AirflowException(f"Could not find `{key}` in the event: {event}") diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index f080586060d0e..e6cb240dfc9f3 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -1024,7 +1024,7 @@ def test_execute_complete_success(self): "run_id": RUN_ID, "run_page_url": RUN_PAGE_URL, "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(), - "notebook_error": None, + "errors": [], } op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) @@ -1045,7 +1045,7 @@ def test_execute_complete_failure(self, db_mock_class): "run_page_url": RUN_PAGE_URL, "run_state": run_state_failed.to_json(), "repair_run": False, - "notebook_error": None, + "errors": [], } op = DatabricksSubmitRunDeferrableOperator(task_id=TASK_ID, json=run) @@ -1596,7 +1596,7 @@ def test_execute_complete_success(self): "run_page_url": RUN_PAGE_URL, "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(), "repair_run": False, - "notebook_error": None, + "errors": [], } op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) @@ -1614,7 +1614,7 @@ def test_execute_complete_failure(self, db_mock_class): "run_page_url": RUN_PAGE_URL, "run_state": run_state_failed.to_json(), "repair_run": False, - "notebook_error": None, + "errors": [], } op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) @@ -1645,7 +1645,7 @@ def test_execute_complete_failure_and_repair_run( "run_page_url": RUN_PAGE_URL, "run_state": run_state_failed.to_json(), "repair_run": True, - "notebook_error": None, + "errors": [], } op = DatabricksRunNowDeferrableOperator(task_id=TASK_ID, 
job_id=JOB_ID, json=run) diff --git a/tests/providers/databricks/triggers/test_databricks.py b/tests/providers/databricks/triggers/test_databricks.py index 6d541d4c486d6..b4bcbc133c0ce 100644 --- a/tests/providers/databricks/triggers/test_databricks.py +++ b/tests/providers/databricks/triggers/test_databricks.py @@ -38,7 +38,12 @@ RETRY_DELAY = 10 RETRY_LIMIT = 3 RUN_ID = 1 -TASK_RUN_ID = 11 +TASK_RUN_ID1 = 11 +TASK_RUN_ID1_KEY = "first_task" +TASK_RUN_ID2 = 22 +TASK_RUN_ID2_KEY = "second_task" +TASK_RUN_ID3 = 33 +TASK_RUN_ID3_KEY = "third_task" JOB_ID = 42 RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1" ERROR_MESSAGE = "error message from databricks API" @@ -80,13 +85,32 @@ }, "tasks": [ { - "run_id": TASK_RUN_ID, + "run_id": TASK_RUN_ID1, + "task_key": TASK_RUN_ID1_KEY, "state": { "life_cycle_state": "TERMINATED", "result_state": "FAILED", "state_message": "Workload failed, see run output for details", }, - } + }, + { + "run_id": TASK_RUN_ID2, + "task_key": TASK_RUN_ID2_KEY, + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "SUCCESS", + "state_message": None, + }, + }, + { + "run_id": TASK_RUN_ID3, + "task_key": TASK_RUN_ID3_KEY, + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "FAILED", + "state_message": "Workload failed, see run output for details", + }, + }, ], } @@ -150,7 +174,7 @@ async def test_run_return_success( ).to_json(), "run_page_url": RUN_PAGE_URL, "repair_run": False, - "notebook_error": None, + "errors": [], } ) @@ -181,7 +205,10 @@ async def test_run_return_failure( ).to_json(), "run_page_url": RUN_PAGE_URL, "repair_run": False, - "notebook_error": ERROR_MESSAGE, + "errors": [ + {"task_key": TASK_RUN_ID1_KEY, "run_id": TASK_RUN_ID1, "error": ERROR_MESSAGE}, + {"task_key": TASK_RUN_ID3_KEY, "run_id": TASK_RUN_ID3, "error": ERROR_MESSAGE}, + ], } ) @@ -218,7 +245,7 @@ async def test_sleep_between_retries( ).to_json(), "run_page_url": RUN_PAGE_URL, "repair_run": False, - "notebook_error": None, + "errors": [], } ) mock_sleep.assert_called_once() diff --git a/tests/providers/databricks/utils/test_databricks.py b/tests/providers/databricks/utils/test_databricks.py index df85fdd9efb13..8c6ce8ce4ba59 100644 --- a/tests/providers/databricks/utils/test_databricks.py +++ b/tests/providers/databricks/utils/test_databricks.py @@ -53,7 +53,7 @@ def test_validate_trigger_event_success(self): "run_id": RUN_ID, "run_page_url": RUN_PAGE_URL, "run_state": RunState("TERMINATED", "SUCCESS", "").to_json(), - "notebook_error": None, + "errors": [], } assert validate_trigger_event(event) is None
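Taken together, the series replaces the bare terminal-state message with per-task error detail. Below is a sketch (illustrative only, not part of the patches) of the event payload the trigger yields after patch 7, and of how the deferrable completion handler reacts to it; the values are placeholders lifted from the tests above.

from airflow.exceptions import AirflowException
from airflow.providers.databricks.hooks.databricks import RunState
from airflow.providers.databricks.utils.databricks import validate_trigger_event

event = {
    "run_id": 1,
    "run_page_url": "https://XX.cloud.databricks.com/#jobs/1/runs/1",
    "run_state": RunState("TERMINATED", "FAILED", "").to_json(),
    "repair_run": False,
    # One {"task_key", "run_id", "error"} entry per FAILED task, collected by
    # the trigger's run() loop via a_get_run() and a_get_run_output().
    "errors": [
        {"task_key": "first_task", "run_id": 11, "error": "error message from databricks API"},
        {"task_key": "third_task", "run_id": 33, "error": "error message from databricks API"},
    ],
}

validate_trigger_event(event)  # raises AirflowException if a required key is missing

run_state = RunState.from_json(event["run_state"])
if not run_state.is_successful:
    # Mirrors the message built in _handle_deferrable_databricks_operator_completion().
    raise AirflowException(
        f"Job run failed with terminal state: {run_state} and with the errors {event['errors']}"
    )

Collecting a list here, rather than keeping only the last failed task as the first patch's loop did, means a multi-task job run surfaces every failing task's error instead of whichever task happened to be iterated last.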