diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py index c38b0683c37b3..0d819e1b709d5 100644 --- a/airflow/providers/databricks/operators/databricks.py +++ b/airflow/providers/databricks/operators/databricks.py @@ -67,23 +67,22 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None: log.info("%s completed successfully.", operator.task_id) log.info("View run status, Spark UI, and logs at %s", run_page_url) return - if run_state.result_state == "FAILED": - task_run_id = None + failed_tasks = [] for task in run_info.get("tasks", []): if task.get("state", {}).get("result_state", "") == "FAILED": task_run_id = task["run_id"] - if task_run_id is not None: - run_output = hook.get_run_output(task_run_id) - if "error" in run_output: - notebook_error = run_output["error"] - else: - notebook_error = run_state.state_message - else: - notebook_error = run_state.state_message + task_key = task["task_key"] + run_output = hook.get_run_output(task_run_id) + if "error" in run_output: + error = run_output["error"] + else: + error = run_state.state_message + failed_tasks.append({"task_key": task_key, "run_id": task_run_id, "error": error}) + error_message = ( f"{operator.task_id} failed with terminal state: {run_state} " - f"and with the error {notebook_error}" + f"and with the errors {failed_tasks}" ) else: error_message = ( diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py index e6cb240dfc9f3..64b9ba985cb19 100644 --- a/tests/providers/databricks/operators/test_databricks.py +++ b/tests/providers/databricks/operators/test_databricks.py @@ -1310,6 +1310,7 @@ def test_exec_failure_with_message(self, db_mock_class): "tasks": [ { "run_id": 2, + "task_key": "first_task", "state": { "life_cycle_state": "TERMINATED", "result_state": "FAILED", @@ -1321,10 +1322,76 @@ def test_exec_failure_with_message(self, db_mock_class): ) db_mock.get_run_output = mock_dict({"error": "Exception: Something went wrong..."}) - with pytest.raises(AirflowException) as exc_info: + with pytest.raises(AirflowException, match="Exception: Something went wrong"): op.execute(None) - assert exc_info.value.args[0].endswith(" Exception: Something went wrong...") + expected = utils.normalise_json_content( + { + "notebook_params": NOTEBOOK_PARAMS, + "notebook_task": NOTEBOOK_TASK, + "jar_params": JAR_PARAMS, + "job_id": JOB_ID, + } + ) + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksRunNowOperator", + ) + db_mock.run_now.assert_called_once_with(expected) + db_mock.get_run_page_url.assert_called_once_with(RUN_ID) + db_mock.get_run.assert_called_once_with(RUN_ID) + assert RUN_ID == op.run_id + + @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook") + def test_exec_multiple_failures_with_message(self, db_mock_class): + """ + Test the execute function in case where the run failed. + """ + run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS} + op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run) + db_mock = db_mock_class.return_value + db_mock.run_now.return_value = RUN_ID + db_mock.get_run = mock_dict( + { + "job_id": JOB_ID, + "run_id": 1, + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "FAILED", + "state_message": "failed", + }, + "tasks": [ + { + "run_id": 2, + "task_key": "first_task", + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "FAILED", + "state_message": "failed", + }, + }, + { + "run_id": 3, + "task_key": "second_task", + "state": { + "life_cycle_state": "TERMINATED", + "result_state": "FAILED", + "state_message": "failed", + }, + }, + ], + } + ) + db_mock.get_run_output = mock_dict({"error": "Exception: Something went wrong..."}) + + with pytest.raises( + AirflowException, + match="(?=.*Exception: Something went wrong.*)(?=.*Exception: Something went wrong.*)", + ): + op.execute(None) expected = utils.normalise_json_content( { @@ -1341,6 +1408,8 @@ def test_exec_failure_with_message(self, db_mock_class): retry_args=None, caller="DatabricksRunNowOperator", ) + db_mock.get_run_output.assert_called() + assert db_mock.get_run_output.call_count == 2 db_mock.run_now.assert_called_once_with(expected) db_mock.get_run_page_url.assert_called_once_with(RUN_ID) db_mock.get_run.assert_called_once_with(RUN_ID)