diff --git a/src/astro_databricks/constants.py b/src/astro_databricks/constants.py
deleted file mode 100644
index 13d81c9..0000000
--- a/src/astro_databricks/constants.py
+++ /dev/null
@@ -1,3 +0,0 @@
-import os
-
-JOBS_API_VERSION = os.getenv("JOBS_API_VERSION", "2.1")
diff --git a/src/astro_databricks/operators/common.py b/src/astro_databricks/operators/common.py
index ccc77aa..4a7f5a9 100644
--- a/src/astro_databricks/operators/common.py
+++ b/src/astro_databricks/operators/common.py
@@ -12,7 +12,6 @@
 from databricks_cli.runs.api import RunsApi
 from databricks_cli.sdk.api_client import ApiClient

-from astro_databricks.constants import JOBS_API_VERSION
 from astro_databricks.operators.workflow import (
     DatabricksMetaData,
     DatabricksWorkflowTaskGroup,
@@ -21,6 +20,7 @@
     DatabricksJobRepairSingleFailedLink,
     DatabricksJobRunLink,
 )
+from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION


 class DatabricksTaskOperator(BaseOperator):
@@ -190,24 +190,24 @@ def monitor_databricks_job(self):
         api_client = self._get_api_client()
         runs_api = RunsApi(api_client)
         current_task = self._get_current_databricks_task(runs_api)
-        url = runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
-            "run_page_url"
-        ]
+        url = runs_api.get_run(
+            self.databricks_run_id, version=DATABRICKS_JOBS_API_VERSION
+        )["run_page_url"]
         self.log.info(f"Check the job run in Databricks: {url}")
         self._wait_for_pending_task(current_task, runs_api)
         self._wait_for_running_task(current_task, runs_api)
         self._wait_for_terminating_task(current_task, runs_api)
         final_state = runs_api.get_run(
-            current_task["run_id"], version=JOBS_API_VERSION
+            current_task["run_id"], version=DATABRICKS_JOBS_API_VERSION
         )["state"]
         self._handle_final_state(final_state)

     def _get_current_databricks_task(self, runs_api):
         return {
             x["task_key"]: x
-            for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
-                "tasks"
-            ]
+            for x in runs_api.get_run(
+                self.databricks_run_id, version=DATABRICKS_JOBS_API_VERSION
+            )["tasks"]
         }[self._get_databricks_task_id(self.task_id)]

     def _handle_final_state(self, final_state):
@@ -223,9 +223,9 @@ def _handle_final_state(self, final_state):
         )

     def _get_lifestyle_state(self, current_task, runs_api):
-        return runs_api.get_run(current_task["run_id"], version=JOBS_API_VERSION)[
-            "state"
-        ]["life_cycle_state"]
+        return runs_api.get_run(
+            current_task["run_id"], version=DATABRICKS_JOBS_API_VERSION
+        )["state"]["life_cycle_state"]

     def _wait_on_state(self, current_task, runs_api, state):
         while self._get_lifestyle_state(current_task, runs_api) == state:
diff --git a/src/astro_databricks/operators/notebook.py b/src/astro_databricks/operators/notebook.py
index 87767bf..efb7341 100644
--- a/src/astro_databricks/operators/notebook.py
+++ b/src/astro_databricks/operators/notebook.py
@@ -13,7 +13,7 @@
 from databricks_cli.runs.api import RunsApi
 from databricks_cli.sdk.api_client import ApiClient

-from astro_databricks.constants import JOBS_API_VERSION
+from astro_databricks import settings
 from astro_databricks.operators.workflow import (
     DatabricksMetaData,
     DatabricksWorkflowTaskGroup,
@@ -221,24 +221,24 @@ def monitor_databricks_job(self):
         api_client = self._get_api_client()
         runs_api = RunsApi(api_client)
         current_task = self._get_current_databricks_task(runs_api)
-        url = runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
-            "run_page_url"
-        ]
+        url = runs_api.get_run(
+            self.databricks_run_id, version=settings.DATABRICKS_JOBS_API_VERSION
+        )["run_page_url"]
         self.log.info(f"Check the job run in Databricks: {url}")
         self._wait_for_pending_task(current_task, runs_api)
         self._wait_for_running_task(current_task, runs_api)
         self._wait_for_terminating_task(current_task, runs_api)
         final_state = runs_api.get_run(
-            current_task["run_id"], version=JOBS_API_VERSION
+            current_task["run_id"], version=settings.DATABRICKS_JOBS_API_VERSION
         )["state"]
         self._handle_final_state(final_state)

     def _get_current_databricks_task(self, runs_api):
         return {
             x["task_key"]: x
-            for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
-                "tasks"
-            ]
+            for x in runs_api.get_run(
+                self.databricks_run_id, version=settings.DATABRICKS_JOBS_API_VERSION
+            )["tasks"]
         }[self._get_databricks_task_id(self.task_id)]

     def _handle_final_state(self, final_state):
@@ -254,9 +254,9 @@ def _handle_final_state(self, final_state):
         )

     def _get_lifestyle_state(self, current_task, runs_api):
-        return runs_api.get_run(current_task["run_id"], version=JOBS_API_VERSION)[
-            "state"
-        ]["life_cycle_state"]
+        return runs_api.get_run(
+            current_task["run_id"], version=settings.DATABRICKS_JOBS_API_VERSION
+        )["state"]["life_cycle_state"]

     def _wait_on_state(self, current_task, runs_api, state):
         while self._get_lifestyle_state(current_task, runs_api) == state:
@@ -300,7 +300,9 @@ def launch_notebook_job(self):
         else:
             raise ValueError("Must specify either existing_cluster_id or new_cluster")
         runs_api = RunsApi(api_client)
-        run = runs_api.submit_run(run_json, version=JOBS_API_VERSION)
+        run = runs_api.submit_run(
+            run_json, version=settings.DATABRICKS_JOBS_API_VERSION
+        )
         self.databricks_run_id = run["run_id"]
         return run

diff --git a/src/astro_databricks/operators/workflow.py b/src/astro_databricks/operators/workflow.py
index ede9a14..c659783 100644
--- a/src/astro_databricks/operators/workflow.py
+++ b/src/astro_databricks/operators/workflow.py
@@ -26,11 +26,11 @@
 from databricks_cli.sdk.api_client import ApiClient
 from mergedeep import merge

-from astro_databricks.constants import JOBS_API_VERSION
 from astro_databricks.plugins.plugin import (
     DatabricksJobRepairAllFailedLink,
     DatabricksJobRunLink,
 )
+from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION


 @define
@@ -41,7 +41,7 @@ class DatabricksMetaData:


 def _get_job_by_name(job_name: str, jobs_api: JobsApi) -> dict | None:
-    jobs = jobs_api.list_jobs(version=JOBS_API_VERSION).get("jobs", [])
+    jobs = jobs_api.list_jobs(version=DATABRICKS_JOBS_API_VERSION).get("jobs", [])
     for job in jobs:
         if job.get("settings", {}).get("name") == job_name:
             return job
@@ -177,14 +177,14 @@ def execute(self, context: Context) -> Any:
             jobs_api.reset_job(
                 json={"job_id": job_id, "new_settings": current_job_spec},
-                version=JOBS_API_VERSION,
+                version=DATABRICKS_JOBS_API_VERSION,
             )
         else:
             self.log.info(
                 "Creating new job with spec %s", json.dumps(current_job_spec, indent=4)
             )
             job_id = jobs_api.create_job(
-                json=current_job_spec, version=JOBS_API_VERSION
+                json=current_job_spec, version=DATABRICKS_JOBS_API_VERSION
             )["job_id"]

         run_id = jobs_api.run_now(
@@ -193,14 +193,16 @@ def execute(self, context: Context) -> Any:
             job_id=job_id,
             jar_params=self.task_group.jar_params,
             notebook_params=self.notebook_params,
             python_params=self.task_group.python_params,
             spark_submit_params=self.task_group.spark_submit_params,
-            version=JOBS_API_VERSION,
+            version=DATABRICKS_JOBS_API_VERSION,
         )["run_id"]
         self.databricks_run_id = run_id
         runs_api = RunsApi(api_client)
-        url = runs_api.get_run(run_id, version=JOBS_API_VERSION).get("run_page_url")
+        url = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION).get(
+            "run_page_url"
+        )
         self.log.info(f"Check the job run in Databricks: {url}")
-        state = runs_api.get_run(run_id, version=JOBS_API_VERSION)["state"][
+        state = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION)["state"][
             "life_cycle_state"
         ]
         self.log.info(f"Job state: {state}")
@@ -213,9 +215,9 @@ def execute(self, context: Context) -> Any:
         while state in ("PENDING", "BLOCKED"):
             self.log.info(f"Job {state}")
             time.sleep(5)
-            state = runs_api.get_run(run_id, version=JOBS_API_VERSION)["state"][
-                "life_cycle_state"
-            ]
+            state = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION)[
+                "state"
+            ]["life_cycle_state"]

         return {
             "databricks_conn_id": self.databricks_conn_id,
diff --git a/src/astro_databricks/settings.py b/src/astro_databricks/settings.py
new file mode 100644
index 0000000..0e93c75
--- /dev/null
+++ b/src/astro_databricks/settings.py
@@ -0,0 +1,3 @@
+import os
+
+DATABRICKS_JOBS_API_VERSION = os.getenv("DATABRICKS_JOBS_API_VERSION", "2.1")
diff --git a/tests/databricks/test_notebook.py b/tests/databricks/test_notebook.py
index 02f4f4a..d905afc 100644
--- a/tests/databricks/test_notebook.py
+++ b/tests/databricks/test_notebook.py
@@ -1,13 +1,14 @@
+import os
 from unittest import mock
 from unittest.mock import MagicMock

 import pytest
 from airflow.exceptions import AirflowException
-from astro_databricks.constants import JOBS_API_VERSION
 from astro_databricks.operators.notebook import DatabricksNotebookOperator
 from astro_databricks.operators.workflow import (
     DatabricksWorkflowTaskGroup,
 )
+from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION


 @pytest.fixture
@@ -111,6 +112,7 @@ def test_databricks_notebook_operator_with_taskgroup(
     mock_monitor.assert_called_once()


+@pytest.mark.parametrize("api_version", ["3.2", "2.1"])
 @mock.patch(
     "astro_databricks.operators.notebook.DatabricksNotebookOperator.monitor_databricks_job"
 )
@@ -122,24 +124,36 @@ def test_databricks_notebook_operator_with_taskgroup(
 )
 @mock.patch("astro_databricks.operators.notebook.RunsApi")
 def test_databricks_notebook_operator_without_taskgroup_new_cluster(
-    mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag
+    mock_runs_api,
+    mock_api_client,
+    mock_get_databricks_task_id,
+    mock_monitor,
+    dag,
+    api_version,
 ):
     mock_get_databricks_task_id.return_value = "1234"
     mock_runs_api.return_value = mock.MagicMock()
-    with dag:
-        DatabricksNotebookOperator(
-            task_id="notebook",
-            databricks_conn_id="foo",
-            notebook_path="/foo/bar",
-            source="WORKSPACE",
-            job_cluster_key="foo",
-            notebook_params={
-                "foo": "bar",
-            },
-            notebook_packages=[{"nb_index": {"package": "nb_package"}}],
-            new_cluster={"foo": "bar"},
-        )
-    dag.test()
+    with mock.patch.dict(os.environ, {"DATABRICKS_JOBS_API_VERSION": api_version}):
+        import importlib
+
+        import astro_databricks
+
+        importlib.reload(astro_databricks.settings)
+
+        with dag:
+            DatabricksNotebookOperator(
+                task_id="notebook",
+                databricks_conn_id="foo",
+                notebook_path="/foo/bar",
+                source="WORKSPACE",
+                job_cluster_key="foo",
+                notebook_params={
+                    "foo": "bar",
+                },
+                notebook_packages=[{"nb_index": {"package": "nb_package"}}],
+                new_cluster={"foo": "bar"},
+            )
+        dag.test()
     mock_runs_api.return_value.submit_run.assert_called_once_with(
         {
             "run_name": "1234",
@@ -153,7 +167,7 @@ def test_databricks_notebook_operator_without_taskgroup_new_cluster(
             "timeout_seconds": 0,
             "email_notifications": {},
         },
-        version=JOBS_API_VERSION,
+        version=api_version,
     )
     mock_monitor.assert_called_once()

@@ -200,7 +214,7 @@ def test_databricks_notebook_operator_without_taskgroup_existing_cluster(
             "timeout_seconds": 0,
             "email_notifications": {},
         },
-        version=JOBS_API_VERSION,
+        version=DATABRICKS_JOBS_API_VERSION,
     )
     mock_monitor.assert_called_once()

@@ -297,7 +311,7 @@ def test_wait_for_pending_task(mock_sleep, mock_runs_api, databricks_notebook_op
         {"state": {"life_cycle_state": "RUNNING"}},
     ]
     databricks_notebook_operator._wait_for_pending_task(current_task, mock_runs_api)
-    mock_runs_api.get_run.assert_called_with("123", version=JOBS_API_VERSION)
+    mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION)
     assert mock_runs_api.get_run.call_count == 2
     mock_runs_api.reset_mock()

@@ -314,7 +328,7 @@ def test_wait_for_terminating_task(
         {"state": {"life_cycle_state": "TERMINATED"}},
     ]
     databricks_notebook_operator._wait_for_terminating_task(current_task, mock_runs_api)
-    mock_runs_api.get_run.assert_called_with("123", version=JOBS_API_VERSION)
+    mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION)
     assert mock_runs_api.get_run.call_count == 3
     mock_runs_api.reset_mock()

@@ -329,7 +343,7 @@ def test_wait_for_running_task(mock_sleep, mock_runs_api, databricks_notebook_op
         {"state": {"life_cycle_state": "TERMINATED"}},
     ]
     databricks_notebook_operator._wait_for_running_task(current_task, mock_runs_api)
-    mock_runs_api.get_run.assert_called_with("123", version=JOBS_API_VERSION)
+    mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION)
     assert mock_runs_api.get_run.call_count == 3
     mock_runs_api.reset_mock()

@@ -383,7 +397,8 @@ def test_monitor_databricks_job_success(
     databricks_notebook_operator.databricks_run_id = "1"
     databricks_notebook_operator.monitor_databricks_job()
     mock_runs_api.return_value.get_run.assert_called_with(
-        databricks_notebook_operator.databricks_run_id, version=JOBS_API_VERSION
+        databricks_notebook_operator.databricks_run_id,
+        version=DATABRICKS_JOBS_API_VERSION,
     )
     assert (
         "Check the job run in Databricks: https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1"
diff --git a/tests/databricks/test_workflow.py b/tests/databricks/test_workflow.py
index cca566b..053fdfd 100644
--- a/tests/databricks/test_workflow.py
+++ b/tests/databricks/test_workflow.py
@@ -6,9 +6,9 @@
 import pytest
 from airflow.exceptions import AirflowException
 from airflow.utils.task_group import TaskGroup
-from astro_databricks.constants import JOBS_API_VERSION
 from astro_databricks.operators.notebook import DatabricksNotebookOperator
 from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup
+from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION

 expected_workflow_json = {
     "name": "unit_test_dag.test_workflow",
@@ -143,7 +143,7 @@ def test_create_workflow_from_notebooks_with_create(
     task_group.children["test_workflow.launch"].execute(context={})
     mock_jobs_api.return_value.create_job.assert_called_once_with(
         json=expected_workflow_json,
-        version=JOBS_API_VERSION,
+        version=DATABRICKS_JOBS_API_VERSION,
     )
     mock_jobs_api.return_value.run_now.assert_called_once_with(
         job_id=1,
@@ -151,7 +151,7 @@
         notebook_params={"notebook_path": "/foo/bar"},
         python_params=[],
         spark_submit_params=[],
-        version=JOBS_API_VERSION,
+        version=DATABRICKS_JOBS_API_VERSION,
     )