Skip to content

Commit

Permalink
Get all failed tasks' errors when an exception is raised in DatabricksCre…
Browse files Browse the repository at this point in the history
…ateJobsOperator (#39354)
  • Loading branch information
gaurav7261 authored May 3, 2024
1 parent 61d1c95 commit 2d103e1
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 13 deletions.
21 changes: 10 additions & 11 deletions airflow/providers/databricks/operators/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,22 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
log.info("%s completed successfully.", operator.task_id)
log.info("View run status, Spark UI, and logs at %s", run_page_url)
return

if run_state.result_state == "FAILED":
task_run_id = None
failed_tasks = []
for task in run_info.get("tasks", []):
if task.get("state", {}).get("result_state", "") == "FAILED":
task_run_id = task["run_id"]
if task_run_id is not None:
run_output = hook.get_run_output(task_run_id)
if "error" in run_output:
notebook_error = run_output["error"]
else:
notebook_error = run_state.state_message
else:
notebook_error = run_state.state_message
task_key = task["task_key"]
run_output = hook.get_run_output(task_run_id)
if "error" in run_output:
error = run_output["error"]
else:
error = run_state.state_message
failed_tasks.append({"task_key": task_key, "run_id": task_run_id, "error": error})

error_message = (
f"{operator.task_id} failed with terminal state: {run_state} "
f"and with the error {notebook_error}"
f"and with the errors {failed_tasks}"
)
else:
error_message = (
Expand Down
73 changes: 71 additions & 2 deletions tests/providers/databricks/operators/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,7 @@ def test_exec_failure_with_message(self, db_mock_class):
"tasks": [
{
"run_id": 2,
"task_key": "first_task",
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
Expand All @@ -1321,10 +1322,76 @@ def test_exec_failure_with_message(self, db_mock_class):
)
db_mock.get_run_output = mock_dict({"error": "Exception: Something went wrong..."})

with pytest.raises(AirflowException) as exc_info:
with pytest.raises(AirflowException, match="Exception: Something went wrong"):
op.execute(None)

assert exc_info.value.args[0].endswith(" Exception: Something went wrong...")
expected = utils.normalise_json_content(
{
"notebook_params": NOTEBOOK_PARAMS,
"notebook_task": NOTEBOOK_TASK,
"jar_params": JAR_PARAMS,
"job_id": JOB_ID,
}
)
db_mock_class.assert_called_once_with(
DEFAULT_CONN_ID,
retry_limit=op.databricks_retry_limit,
retry_delay=op.databricks_retry_delay,
retry_args=None,
caller="DatabricksRunNowOperator",
)
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
assert RUN_ID == op.run_id

@mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
def test_exec_multiple_failures_with_message(self, db_mock_class):
"""
Test the execute function in case where the run failed.
"""
run = {"notebook_params": NOTEBOOK_PARAMS, "notebook_task": NOTEBOOK_TASK, "jar_params": JAR_PARAMS}
op = DatabricksRunNowOperator(task_id=TASK_ID, job_id=JOB_ID, json=run)
db_mock = db_mock_class.return_value
db_mock.run_now.return_value = RUN_ID
db_mock.get_run = mock_dict(
{
"job_id": JOB_ID,
"run_id": 1,
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "failed",
},
"tasks": [
{
"run_id": 2,
"task_key": "first_task",
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "failed",
},
},
{
"run_id": 3,
"task_key": "second_task",
"state": {
"life_cycle_state": "TERMINATED",
"result_state": "FAILED",
"state_message": "failed",
},
},
],
}
)
db_mock.get_run_output = mock_dict({"error": "Exception: Something went wrong..."})

with pytest.raises(
AirflowException,
match="(?=.*Exception: Something went wrong.*)(?=.*Exception: Something went wrong.*)",
):
op.execute(None)

expected = utils.normalise_json_content(
{
Expand All @@ -1341,6 +1408,8 @@ def test_exec_failure_with_message(self, db_mock_class):
retry_args=None,
caller="DatabricksRunNowOperator",
)
db_mock.get_run_output.assert_called()
assert db_mock.get_run_output.call_count == 2
db_mock.run_now.assert_called_once_with(expected)
db_mock.get_run_page_url.assert_called_once_with(RUN_ID)
db_mock.get_run.assert_called_once_with(RUN_ID)
Expand Down

0 comments on commit 2d103e1

Please sign in to comment.