This repository was archived by the owner on Sep 4, 2024. It is now read-only.

Support for existing_cluster_id in DatabricksNotebookOperator (#73)
When tasks are launched with `DatabricksNotebookOperator` from within a TaskGroup
using the `DatabricksWorkflowTaskGroup`, we currently do not support using `existing_cluster_id`
for those notebook tasks. This PR addresses the issue by adding support for
`existing_cluster_id` in such cases, while continuing to support the existing
`job_cluster_key` approach, so users can combine both within a single workflow.


closes: #70
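
To illustrate the resulting usage, here is a minimal sketch modeled on the new test added below; the DAG name, connection id, cluster spec, notebook paths, and the interactive cluster id are placeholders, not values from the repo:

from pendulum import datetime

from airflow.models import DAG
from astro_databricks.operators.notebook import DatabricksNotebookOperator
from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup

with DAG("example_mixed_clusters", start_date=datetime(2024, 1, 1), schedule=None) as dag:
    workflow = DatabricksWorkflowTaskGroup(
        group_id="mixed_clusters_workflow",
        databricks_conn_id="databricks_default",  # placeholder connection id
        job_clusters=[
            {
                "job_cluster_key": "job_cluster_1",
                # placeholder cluster spec
                "new_cluster": {"spark_version": "13.3.x-scala2.12", "num_workers": 1},
            }
        ],
    )
    with workflow:
        # Runs on the job cluster declared on the task group.
        notebook_on_job_cluster = DatabricksNotebookOperator(
            task_id="notebook_on_job_cluster",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/notebook_1",
            source="WORKSPACE",
            job_cluster_key="job_cluster_1",
        )
        # Runs on an already-running interactive cluster instead.
        notebook_on_existing_cluster = DatabricksNotebookOperator(
            task_id="notebook_on_existing_cluster",
            databricks_conn_id="databricks_default",
            notebook_path="/Shared/notebook_2",
            source="WORKSPACE",
            existing_cluster_id="1234-567890-abcde123",  # placeholder cluster id
        )
        notebook_on_job_cluster >> notebook_on_existing_cluster

If both `existing_cluster_id` and `job_cluster_key` are set on the same operator, the task now raises a `ValueError` rather than silently picking one.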
Hang1225 authored Apr 3, 2024
1 parent f7d4e6a commit d72aeb3
Showing 2 changed files with 69 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/astro_databricks/operators/notebook.py
@@ -207,9 +207,17 @@ def convert_to_databricks_workflow_task(
                for t in self.upstream_task_ids
                if t in relevant_upstreams
            ],
            "job_cluster_key": self.job_cluster_key,
            **base_task_json,
        }

        if self.existing_cluster_id and self.job_cluster_key:
            raise ValueError("Both existing_cluster_id and job_cluster_key are set. Only one cluster can be set per task.")

        if self.existing_cluster_id:
            result["existing_cluster_id"] = self.existing_cluster_id
        elif self.job_cluster_key:
            result["job_cluster_key"] = self.job_cluster_key

        return result

    def _get_databricks_task_id(self, task_id: str):
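
For reference, here is a rough sketch (not part of the diff) of the per-task entries that `convert_to_databricks_workflow_task` now emits in the workflow JSON; the variable names are illustrative, the `foo` values come from the test below, and the base task fields are abbreviated:

# Illustrative only: real entries also carry the fields merged in from base_task_json
# (notebook settings, libraries, timeout_seconds, ...).
task_with_job_cluster = {
    "task_key": "unit_test_dag__test_workflow__notebook_1",
    "depends_on": [],
    "job_cluster_key": "foo",
    # ...base task fields...
}

task_with_existing_cluster = {
    "task_key": "unit_test_dag__test_workflow__notebook_2",
    "depends_on": ["unit_test_dag__test_workflow__notebook_1"],
    "existing_cluster_id": "foo",
    # ...base task fields...
}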
60 changes: 60 additions & 0 deletions tests/databricks/test_workflow.py
@@ -4,6 +4,7 @@
from unittest import mock

import pytest
import copy
from airflow.exceptions import AirflowException
from airflow.utils.task_group import TaskGroup
from astro_databricks.operators.notebook import DatabricksNotebookOperator
@@ -50,6 +51,10 @@
"timeout_seconds": 0,
}

expected_workflow_json_existing_cluster_id = copy.deepcopy(expected_workflow_json)
# remove job_cluster_key and add existing_cluster_id
expected_workflow_json_existing_cluster_id["tasks"][1].pop("job_cluster_key")
expected_workflow_json_existing_cluster_id["tasks"][1]["existing_cluster_id"] = "foo"

@mock.patch("astro_databricks.operators.workflow.DatabricksHook")
@mock.patch("astro_databricks.operators.workflow.ApiClient")
@@ -374,3 +379,58 @@ def test_create_workflow_with_nested_task_groups(
== "unit_test_dag__test_workflow__middle_task_group__inner_task_group__inner_notebook"
)
assert outer_notebook_json["libraries"] == [{"pypi": {"package": "mlflow==2.4.0"}}]

@mock.patch("astro_databricks.operators.workflow.DatabricksHook")
@mock.patch("astro_databricks.operators.workflow.ApiClient")
@mock.patch("astro_databricks.operators.workflow.JobsApi")
@mock.patch(
"astro_databricks.operators.workflow.RunsApi.get_run",
return_value={"state": {"life_cycle_state": "RUNNING"}},
)
def test_create_workflow_from_notebooks_with_different_clusters(
    mock_run_api, mock_jobs_api, mock_api, mock_hook, dag
):
    mock_jobs_api.return_value.create_job.return_value = {"job_id": 1}
    with dag:
        task_group = DatabricksWorkflowTaskGroup(
            group_id="test_workflow",
            databricks_conn_id="foo",
            job_clusters=[{"job_cluster_key": "foo"}],
            notebook_params={"notebook_path": "/foo/bar"},
            notebook_packages=[{"tg_index": {"package": "tg_package"}}],
        )
        with task_group:
            notebook_1 = DatabricksNotebookOperator(
                task_id="notebook_1",
                databricks_conn_id="foo",
                notebook_path="/foo/bar",
                notebook_packages=[{"nb_index": {"package": "nb_package"}}],
                source="WORKSPACE",
                job_cluster_key="foo",
            )
            notebook_2 = DatabricksNotebookOperator(
                task_id="notebook_2",
                databricks_conn_id="foo",
                notebook_path="/foo/bar",
                source="WORKSPACE",
                existing_cluster_id="foo",
                notebook_params={
                    "foo": "bar",
                },
            )
            notebook_1 >> notebook_2

    assert len(task_group.children) == 3
    task_group.children["test_workflow.launch"].execute(context={})
    mock_jobs_api.return_value.create_job.assert_called_once_with(
        json=expected_workflow_json_existing_cluster_id,
        version=DATABRICKS_JOBS_API_VERSION,
    )
    mock_jobs_api.return_value.run_now.assert_called_once_with(
        job_id=1,
        jar_params=[],
        notebook_params={"notebook_path": "/foo/bar"},
        python_params=[],
        spark_submit_params=[],
        version=DATABRICKS_JOBS_API_VERSION,
    )
