This repository was archived by the owner on Sep 4, 2024. It is now read-only.

Address follow up comments on PR #66 (#68)

Merged 2 commits on Mar 20, 2024
3 changes: 0 additions & 3 deletions src/astro_databricks/constants.py

This file was deleted.

22 changes: 11 additions & 11 deletions src/astro_databricks/operators/common.py
@@ -12,7 +12,6 @@
from databricks_cli.runs.api import RunsApi
from databricks_cli.sdk.api_client import ApiClient

from astro_databricks.constants import JOBS_API_VERSION
from astro_databricks.operators.workflow import (
DatabricksMetaData,
DatabricksWorkflowTaskGroup,
@@ -21,6 +20,7 @@
DatabricksJobRepairSingleFailedLink,
DatabricksJobRunLink,
)
from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION


class DatabricksTaskOperator(BaseOperator):
@@ -190,24 +190,24 @@ def monitor_databricks_job(self):
api_client = self._get_api_client()
runs_api = RunsApi(api_client)
current_task = self._get_current_databricks_task(runs_api)
url = runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
"run_page_url"
]
url = runs_api.get_run(
self.databricks_run_id, version=DATABRICKS_JOBS_API_VERSION
)["run_page_url"]
self.log.info(f"Check the job run in Databricks: {url}")
self._wait_for_pending_task(current_task, runs_api)
self._wait_for_running_task(current_task, runs_api)
self._wait_for_terminating_task(current_task, runs_api)
final_state = runs_api.get_run(
current_task["run_id"], version=JOBS_API_VERSION
current_task["run_id"], version=DATABRICKS_JOBS_API_VERSION
)["state"]
self._handle_final_state(final_state)

def _get_current_databricks_task(self, runs_api):
return {
x["task_key"]: x
for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
"tasks"
]
for x in runs_api.get_run(
self.databricks_run_id, version=DATABRICKS_JOBS_API_VERSION
)["tasks"]
}[self._get_databricks_task_id(self.task_id)]

def _handle_final_state(self, final_state):
@@ -223,9 +223,9 @@ def _handle_final_state(self, final_state):
)

def _get_lifestyle_state(self, current_task, runs_api):
return runs_api.get_run(current_task["run_id"], version=JOBS_API_VERSION)[
"state"
]["life_cycle_state"]
return runs_api.get_run(
current_task["run_id"], version=DATABRICKS_JOBS_API_VERSION
)["state"]["life_cycle_state"]

def _wait_on_state(self, current_task, runs_api, state):
while self._get_lifestyle_state(current_task, runs_api) == state:
26 changes: 14 additions & 12 deletions src/astro_databricks/operators/notebook.py
@@ -13,7 +13,7 @@
from databricks_cli.runs.api import RunsApi
from databricks_cli.sdk.api_client import ApiClient

from astro_databricks.constants import JOBS_API_VERSION
from astro_databricks import settings
from astro_databricks.operators.workflow import (
DatabricksMetaData,
DatabricksWorkflowTaskGroup,
@@ -221,24 +221,24 @@ def monitor_databricks_job(self):
api_client = self._get_api_client()
runs_api = RunsApi(api_client)
current_task = self._get_current_databricks_task(runs_api)
url = runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
"run_page_url"
]
url = runs_api.get_run(
self.databricks_run_id, version=settings.DATABRICKS_JOBS_API_VERSION
)["run_page_url"]
self.log.info(f"Check the job run in Databricks: {url}")
self._wait_for_pending_task(current_task, runs_api)
self._wait_for_running_task(current_task, runs_api)
self._wait_for_terminating_task(current_task, runs_api)
final_state = runs_api.get_run(
current_task["run_id"], version=JOBS_API_VERSION
current_task["run_id"], version=settings.DATABRICKS_JOBS_API_VERSION
)["state"]
self._handle_final_state(final_state)

def _get_current_databricks_task(self, runs_api):
return {
x["task_key"]: x
for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
"tasks"
]
for x in runs_api.get_run(
self.databricks_run_id, version=settings.DATABRICKS_JOBS_API_VERSION
)["tasks"]
}[self._get_databricks_task_id(self.task_id)]

def _handle_final_state(self, final_state):
@@ -254,9 +254,9 @@ def _handle_final_state(self, final_state):
)

def _get_lifestyle_state(self, current_task, runs_api):
return runs_api.get_run(current_task["run_id"], version=JOBS_API_VERSION)[
"state"
]["life_cycle_state"]
return runs_api.get_run(
current_task["run_id"], version=settings.DATABRICKS_JOBS_API_VERSION
)["state"]["life_cycle_state"]

def _wait_on_state(self, current_task, runs_api, state):
while self._get_lifestyle_state(current_task, runs_api) == state:
@@ -300,7 +300,9 @@ def launch_notebook_job(self):
else:
raise ValueError("Must specify either existing_cluster_id or new_cluster")
runs_api = RunsApi(api_client)
run = runs_api.submit_run(run_json, version=JOBS_API_VERSION)
run = runs_api.submit_run(
run_json, version=settings.DATABRICKS_JOBS_API_VERSION
)
self.databricks_run_id = run["run_id"]
return run

22 changes: 12 additions & 10 deletions src/astro_databricks/operators/workflow.py
@@ -26,11 +26,11 @@
from databricks_cli.sdk.api_client import ApiClient
from mergedeep import merge

from astro_databricks.constants import JOBS_API_VERSION
from astro_databricks.plugins.plugin import (
DatabricksJobRepairAllFailedLink,
DatabricksJobRunLink,
)
from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION


@define
@@ -41,7 +41,7 @@ class DatabricksMetaData:


def _get_job_by_name(job_name: str, jobs_api: JobsApi) -> dict | None:
jobs = jobs_api.list_jobs(version=JOBS_API_VERSION).get("jobs", [])
jobs = jobs_api.list_jobs(version=DATABRICKS_JOBS_API_VERSION).get("jobs", [])
for job in jobs:
if job.get("settings", {}).get("name") == job_name:
return job
@@ -177,14 +177,14 @@ def execute(self, context: Context) -> Any:

jobs_api.reset_job(
json={"job_id": job_id, "new_settings": current_job_spec},
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)
else:
self.log.info(
"Creating new job with spec %s", json.dumps(current_job_spec, indent=4)
)
job_id = jobs_api.create_job(
json=current_job_spec, version=JOBS_API_VERSION
json=current_job_spec, version=DATABRICKS_JOBS_API_VERSION
)["job_id"]

run_id = jobs_api.run_now(
@@ -193,14 +193,16 @@
notebook_params=self.notebook_params,
python_params=self.task_group.python_params,
spark_submit_params=self.task_group.spark_submit_params,
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)["run_id"]
self.databricks_run_id = run_id

runs_api = RunsApi(api_client)
url = runs_api.get_run(run_id, version=JOBS_API_VERSION).get("run_page_url")
url = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION).get(
"run_page_url"
)
self.log.info(f"Check the job run in Databricks: {url}")
state = runs_api.get_run(run_id, version=JOBS_API_VERSION)["state"][
state = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION)["state"][
"life_cycle_state"
]
self.log.info(f"Job state: {state}")
@@ -213,9 +215,9 @@
while state in ("PENDING", "BLOCKED"):
self.log.info(f"Job {state}")
time.sleep(5)
state = runs_api.get_run(run_id, version=JOBS_API_VERSION)["state"][
"life_cycle_state"
]
state = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION)[
"state"
]["life_cycle_state"]

return {
"databricks_conn_id": self.databricks_conn_id,
3 changes: 3 additions & 0 deletions src/astro_databricks/settings.py
@@ -0,0 +1,3 @@
import os

DATABRICKS_JOBS_API_VERSION = os.getenv("DATABRICKS_JOBS_API_VERSION", "2.1")
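For reference, this new setting reads the Jobs API version from the environment once, at import time, falling back to "2.1". A minimal sketch of overriding it (the "2.0" value below is only illustrative, not something this PR ships):

import os

# The override must be in place before astro_databricks.settings is first imported,
# because DATABRICKS_JOBS_API_VERSION is evaluated at module import time.
os.environ["DATABRICKS_JOBS_API_VERSION"] = "2.0"  # illustrative value

from astro_databricks import settings

print(settings.DATABRICKS_JOBS_API_VERSION)  # prints "2.0"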
59 changes: 37 additions & 22 deletions tests/databricks/test_notebook.py
@@ -1,13 +1,14 @@
import os
from unittest import mock
from unittest.mock import MagicMock

import pytest
from airflow.exceptions import AirflowException
from astro_databricks.constants import JOBS_API_VERSION
from astro_databricks.operators.notebook import DatabricksNotebookOperator
from astro_databricks.operators.workflow import (
DatabricksWorkflowTaskGroup,
)
from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION


@pytest.fixture
@@ -111,6 +112,7 @@ def test_databricks_notebook_operator_with_taskgroup(
mock_monitor.assert_called_once()


@pytest.mark.parametrize("api_version", ["3.2", "2.1"])
@mock.patch(
"astro_databricks.operators.notebook.DatabricksNotebookOperator.monitor_databricks_job"
)
@@ -122,24 +124,36 @@ def test_databricks_notebook_operator_with_taskgroup(
)
@mock.patch("astro_databricks.operators.notebook.RunsApi")
def test_databricks_notebook_operator_without_taskgroup_new_cluster(
mock_runs_api, mock_api_client, mock_get_databricks_task_id, mock_monitor, dag
mock_runs_api,
mock_api_client,
mock_get_databricks_task_id,
mock_monitor,
dag,
api_version,
):
mock_get_databricks_task_id.return_value = "1234"
mock_runs_api.return_value = mock.MagicMock()
with dag:
DatabricksNotebookOperator(
task_id="notebook",
databricks_conn_id="foo",
notebook_path="/foo/bar",
source="WORKSPACE",
job_cluster_key="foo",
notebook_params={
"foo": "bar",
},
notebook_packages=[{"nb_index": {"package": "nb_package"}}],
new_cluster={"foo": "bar"},
)
dag.test()
with mock.patch.dict(os.environ, {"DATABRICKS_JOBS_API_VERSION": api_version}):
import importlib

import astro_databricks

importlib.reload(astro_databricks.settings)

with dag:
DatabricksNotebookOperator(
task_id="notebook",
databricks_conn_id="foo",
notebook_path="/foo/bar",
source="WORKSPACE",
job_cluster_key="foo",
notebook_params={
"foo": "bar",
},
notebook_packages=[{"nb_index": {"package": "nb_package"}}],
new_cluster={"foo": "bar"},
)
dag.test()
mock_runs_api.return_value.submit_run.assert_called_once_with(
{
"run_name": "1234",
Expand All @@ -153,7 +167,7 @@ def test_databricks_notebook_operator_without_taskgroup_new_cluster(
"timeout_seconds": 0,
"email_notifications": {},
},
version=JOBS_API_VERSION,
version=api_version,
)
mock_monitor.assert_called_once()

@@ -200,7 +214,7 @@ def test_databricks_notebook_operator_without_taskgroup_existing_cluster(
"timeout_seconds": 0,
"email_notifications": {},
},
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)
mock_monitor.assert_called_once()

@@ -297,7 +311,7 @@ def test_wait_for_pending_task(mock_sleep, mock_runs_api, databricks_notebook_op
{"state": {"life_cycle_state": "RUNNING"}},
]
databricks_notebook_operator._wait_for_pending_task(current_task, mock_runs_api)
mock_runs_api.get_run.assert_called_with("123", version=JOBS_API_VERSION)
mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION)
assert mock_runs_api.get_run.call_count == 2
mock_runs_api.reset_mock()

@@ -314,7 +328,7 @@ def test_wait_for_terminating_task(
{"state": {"life_cycle_state": "TERMINATED"}},
]
databricks_notebook_operator._wait_for_terminating_task(current_task, mock_runs_api)
mock_runs_api.get_run.assert_called_with("123", version=JOBS_API_VERSION)
mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION)
assert mock_runs_api.get_run.call_count == 3
mock_runs_api.reset_mock()

@@ -329,7 +343,7 @@ def test_wait_for_running_task(mock_sleep, mock_runs_api, databricks_notebook_op
{"state": {"life_cycle_state": "TERMINATED"}},
]
databricks_notebook_operator._wait_for_running_task(current_task, mock_runs_api)
mock_runs_api.get_run.assert_called_with("123", version=JOBS_API_VERSION)
mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION)
assert mock_runs_api.get_run.call_count == 3
mock_runs_api.reset_mock()

@@ -383,7 +397,8 @@ def test_monitor_databricks_job_success(
databricks_notebook_operator.databricks_run_id = "1"
databricks_notebook_operator.monitor_databricks_job()
mock_runs_api.return_value.get_run.assert_called_with(
databricks_notebook_operator.databricks_run_id, version=JOBS_API_VERSION
databricks_notebook_operator.databricks_run_id,
version=DATABRICKS_JOBS_API_VERSION,
)
assert (
"Check the job run in Databricks: https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1"
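Because the version is a module-level constant, the parametrized test above patches the environment and then reloads astro_databricks.settings so the new value takes effect. A standalone sketch of that pattern (the "3.2" value mirrors the test's parametrization):

import importlib
import os
from unittest import mock

import astro_databricks.settings

# Patch the environment, then reload the settings module so the module-level
# constant is re-evaluated with the patched value.
with mock.patch.dict(os.environ, {"DATABRICKS_JOBS_API_VERSION": "3.2"}):
    importlib.reload(astro_databricks.settings)
    assert astro_databricks.settings.DATABRICKS_JOBS_API_VERSION == "3.2"

# Reload once more outside the patch so later code sees the default ("2.1").
importlib.reload(astro_databricks.settings)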
6 changes: 3 additions & 3 deletions tests/databricks/test_workflow.py
@@ -6,9 +6,9 @@
import pytest
from airflow.exceptions import AirflowException
from airflow.utils.task_group import TaskGroup
from astro_databricks.constants import JOBS_API_VERSION
from astro_databricks.operators.notebook import DatabricksNotebookOperator
from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup
from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION

expected_workflow_json = {
"name": "unit_test_dag.test_workflow",
@@ -143,15 +143,15 @@ def test_create_workflow_from_notebooks_with_create(
task_group.children["test_workflow.launch"].execute(context={})
mock_jobs_api.return_value.create_job.assert_called_once_with(
json=expected_workflow_json,
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)
mock_jobs_api.return_value.run_now.assert_called_once_with(
job_id=1,
jar_params=[],
notebook_params={"notebook_path": "/foo/bar"},
python_params=[],
spark_submit_params=[],
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)

