This repository was archived by the owner on Sep 4, 2024. It is now read-only.

Commit

Address follow up comments on PR #66
pankajkoti committed Mar 20, 2024
1 parent 0fce468 commit edef7b3
Showing 7 changed files with 49 additions and 46 deletions.
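
In short, the hard-coded `JOBS_API_VERSION` constant (from the now-deleted `constants.py`) is replaced by `DATABRICKS_JOBS_API_VERSION`, which the new `settings.py` reads from the environment. A minimal sketch of the resulting call pattern with the databricks-cli client; the host, token, and run ID below are placeholders, not values from this commit:

```python
from databricks_cli.runs.api import RunsApi
from databricks_cli.sdk.api_client import ApiClient

from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION  # "2.1" unless overridden

# Placeholder connection details, for illustration only.
api_client = ApiClient(host="https://example.cloud.databricks.com", token="REDACTED")
runs_api = RunsApi(api_client)

# Every Jobs/Runs API call touched by this commit now passes the configurable version explicitly.
run = runs_api.get_run("12345", version=DATABRICKS_JOBS_API_VERSION)
print(run["run_page_url"])
```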
3 changes: 0 additions & 3 deletions src/astro_databricks/constants.py

This file was deleted.

22 changes: 11 additions & 11 deletions src/astro_databricks/operators/common.py
@@ -12,7 +12,6 @@
from databricks_cli.runs.api import RunsApi
from databricks_cli.sdk.api_client import ApiClient

from astro_databricks.constants import JOBS_API_VERSION
from astro_databricks.operators.workflow import (
DatabricksMetaData,
DatabricksWorkflowTaskGroup,
@@ -21,6 +20,7 @@
DatabricksJobRepairSingleFailedLink,
DatabricksJobRunLink,
)
from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION


class DatabricksTaskOperator(BaseOperator):
@@ -190,24 +190,24 @@ def monitor_databricks_job(self):
api_client = self._get_api_client()
runs_api = RunsApi(api_client)
current_task = self._get_current_databricks_task(runs_api)
url = runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
"run_page_url"
]
url = runs_api.get_run(
self.databricks_run_id, version=DATABRICKS_JOBS_API_VERSION
)["run_page_url"]
self.log.info(f"Check the job run in Databricks: {url}")
self._wait_for_pending_task(current_task, runs_api)
self._wait_for_running_task(current_task, runs_api)
self._wait_for_terminating_task(current_task, runs_api)
final_state = runs_api.get_run(
current_task["run_id"], version=JOBS_API_VERSION
current_task["run_id"], version=DATABRICKS_JOBS_API_VERSION
)["state"]
self._handle_final_state(final_state)

def _get_current_databricks_task(self, runs_api):
return {
x["task_key"]: x
for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
"tasks"
]
for x in runs_api.get_run(
self.databricks_run_id, version=DATABRICKS_JOBS_API_VERSION
)["tasks"]
}[self._get_databricks_task_id(self.task_id)]

def _handle_final_state(self, final_state):
@@ -223,9 +223,9 @@ def _handle_final_state(self, final_state):
)

def _get_lifestyle_state(self, current_task, runs_api):
return runs_api.get_run(current_task["run_id"], version=JOBS_API_VERSION)[
"state"
]["life_cycle_state"]
return runs_api.get_run(
current_task["run_id"], version=DATABRICKS_JOBS_API_VERSION
)["state"]["life_cycle_state"]

def _wait_on_state(self, current_task, runs_api, state):
while self._get_lifestyle_state(current_task, runs_api) == state:
24 changes: 12 additions & 12 deletions src/astro_databricks/operators/notebook.py
@@ -13,7 +13,6 @@
from databricks_cli.runs.api import RunsApi
from databricks_cli.sdk.api_client import ApiClient

from astro_databricks.constants import JOBS_API_VERSION
from astro_databricks.operators.workflow import (
DatabricksMetaData,
DatabricksWorkflowTaskGroup,
@@ -22,6 +21,7 @@
DatabricksJobRepairSingleFailedLink,
DatabricksJobRunLink,
)
from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION


class DatabricksNotebookOperator(BaseOperator):
@@ -221,24 +221,24 @@ def monitor_databricks_job(self):
api_client = self._get_api_client()
runs_api = RunsApi(api_client)
current_task = self._get_current_databricks_task(runs_api)
url = runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
"run_page_url"
]
url = runs_api.get_run(
self.databricks_run_id, version=DATABRICKS_JOBS_API_VERSION
)["run_page_url"]
self.log.info(f"Check the job run in Databricks: {url}")
self._wait_for_pending_task(current_task, runs_api)
self._wait_for_running_task(current_task, runs_api)
self._wait_for_terminating_task(current_task, runs_api)
final_state = runs_api.get_run(
current_task["run_id"], version=JOBS_API_VERSION
current_task["run_id"], version=DATABRICKS_JOBS_API_VERSION
)["state"]
self._handle_final_state(final_state)

def _get_current_databricks_task(self, runs_api):
return {
x["task_key"]: x
for x in runs_api.get_run(self.databricks_run_id, version=JOBS_API_VERSION)[
"tasks"
]
for x in runs_api.get_run(
self.databricks_run_id, version=DATABRICKS_JOBS_API_VERSION
)["tasks"]
}[self._get_databricks_task_id(self.task_id)]

def _handle_final_state(self, final_state):
@@ -254,9 +254,9 @@ def _handle_final_state(self, final_state):
)

def _get_lifestyle_state(self, current_task, runs_api):
return runs_api.get_run(current_task["run_id"], version=JOBS_API_VERSION)[
"state"
]["life_cycle_state"]
return runs_api.get_run(
current_task["run_id"], version=DATABRICKS_JOBS_API_VERSION
)["state"]["life_cycle_state"]

def _wait_on_state(self, current_task, runs_api, state):
while self._get_lifestyle_state(current_task, runs_api) == state:
@@ -300,7 +300,7 @@ def launch_notebook_job(self):
else:
raise ValueError("Must specify either existing_cluster_id or new_cluster")
runs_api = RunsApi(api_client)
run = runs_api.submit_run(run_json, version=JOBS_API_VERSION)
run = runs_api.submit_run(run_json, version=DATABRICKS_JOBS_API_VERSION)
self.databricks_run_id = run["run_id"]
return run

22 changes: 12 additions & 10 deletions src/astro_databricks/operators/workflow.py
@@ -26,11 +26,11 @@
from databricks_cli.sdk.api_client import ApiClient
from mergedeep import merge

from astro_databricks.constants import JOBS_API_VERSION
from astro_databricks.plugins.plugin import (
DatabricksJobRepairAllFailedLink,
DatabricksJobRunLink,
)
from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION


@define
@@ -41,7 +41,7 @@ class DatabricksMetaData:


def _get_job_by_name(job_name: str, jobs_api: JobsApi) -> dict | None:
jobs = jobs_api.list_jobs(version=JOBS_API_VERSION).get("jobs", [])
jobs = jobs_api.list_jobs(version=DATABRICKS_JOBS_API_VERSION).get("jobs", [])
for job in jobs:
if job.get("settings", {}).get("name") == job_name:
return job
@@ -177,14 +177,14 @@ def execute(self, context: Context) -> Any:

jobs_api.reset_job(
json={"job_id": job_id, "new_settings": current_job_spec},
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)
else:
self.log.info(
"Creating new job with spec %s", json.dumps(current_job_spec, indent=4)
)
job_id = jobs_api.create_job(
json=current_job_spec, version=JOBS_API_VERSION
json=current_job_spec, version=DATABRICKS_JOBS_API_VERSION
)["job_id"]

run_id = jobs_api.run_now(
@@ -193,14 +193,16 @@
notebook_params=self.notebook_params,
python_params=self.task_group.python_params,
spark_submit_params=self.task_group.spark_submit_params,
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)["run_id"]
self.databricks_run_id = run_id

runs_api = RunsApi(api_client)
url = runs_api.get_run(run_id, version=JOBS_API_VERSION).get("run_page_url")
url = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION).get(
"run_page_url"
)
self.log.info(f"Check the job run in Databricks: {url}")
state = runs_api.get_run(run_id, version=JOBS_API_VERSION)["state"][
state = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION)["state"][
"life_cycle_state"
]
self.log.info(f"Job state: {state}")
@@ -213,9 +215,9 @@
while state in ("PENDING", "BLOCKED"):
self.log.info(f"Job {state}")
time.sleep(5)
state = runs_api.get_run(run_id, version=JOBS_API_VERSION)["state"][
"life_cycle_state"
]
state = runs_api.get_run(run_id, version=DATABRICKS_JOBS_API_VERSION)[
"state"
]["life_cycle_state"]

return {
"databricks_conn_id": self.databricks_conn_id,
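
The same setting also flows through every Jobs API call made by the workflow launch operator (list, reset or create, then run). A rough sketch of that create-or-update-then-run sequence under the new constant; the job spec and connection details are placeholders:

```python
from databricks_cli.jobs.api import JobsApi
from databricks_cli.sdk.api_client import ApiClient

from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION

api_client = ApiClient(host="https://example.cloud.databricks.com", token="REDACTED")
jobs_api = JobsApi(api_client)

job_spec = {"name": "unit_test_dag.test_workflow", "tasks": []}  # placeholder job spec

# Reuse an existing job with the same name if there is one, otherwise create it.
existing = next(
    (
        job
        for job in jobs_api.list_jobs(version=DATABRICKS_JOBS_API_VERSION).get("jobs", [])
        if job.get("settings", {}).get("name") == job_spec["name"]
    ),
    None,
)
if existing:
    job_id = existing["job_id"]
    jobs_api.reset_job(
        json={"job_id": job_id, "new_settings": job_spec},
        version=DATABRICKS_JOBS_API_VERSION,
    )
else:
    job_id = jobs_api.create_job(json=job_spec, version=DATABRICKS_JOBS_API_VERSION)["job_id"]

run_id = jobs_api.run_now(
    job_id=job_id,
    jar_params=[],
    notebook_params={},
    python_params=[],
    spark_submit_params=[],
    version=DATABRICKS_JOBS_API_VERSION,
)["run_id"]
print(f"Started run {run_id}")
```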
3 changes: 3 additions & 0 deletions src/astro_databricks/settings.py
@@ -0,0 +1,3 @@
import os

DATABRICKS_JOBS_API_VERSION = os.getenv("DATABRICKS_JOBS_API_VERSION", "2.1")
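
Because `settings.py` reads the variable once at import time, any override has to be in the environment before `astro_databricks.settings` is first imported. A minimal sketch, assuming you want to pin the older 2.0 Jobs API:

```python
import os

# Must run before anything imports astro_databricks.settings,
# e.g. in the scheduler/worker environment or at the very top of the DAG file.
os.environ["DATABRICKS_JOBS_API_VERSION"] = "2.0"

from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION

print(DATABRICKS_JOBS_API_VERSION)  # -> "2.0"
```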
15 changes: 8 additions & 7 deletions tests/databricks/test_notebook.py
@@ -3,11 +3,11 @@

import pytest
from airflow.exceptions import AirflowException
from astro_databricks.constants import JOBS_API_VERSION
from astro_databricks.operators.notebook import DatabricksNotebookOperator
from astro_databricks.operators.workflow import (
DatabricksWorkflowTaskGroup,
)
from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION


@pytest.fixture
@@ -153,7 +153,7 @@ def test_databricks_notebook_operator_without_taskgroup_new_cluster(
"timeout_seconds": 0,
"email_notifications": {},
},
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)
mock_monitor.assert_called_once()

@@ -200,7 +200,7 @@ def test_databricks_notebook_operator_without_taskgroup_existing_cluster(
"timeout_seconds": 0,
"email_notifications": {},
},
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)
mock_monitor.assert_called_once()

@@ -297,7 +297,7 @@ def test_wait_for_pending_task(mock_sleep, mock_runs_api, databricks_notebook_op
{"state": {"life_cycle_state": "RUNNING"}},
]
databricks_notebook_operator._wait_for_pending_task(current_task, mock_runs_api)
mock_runs_api.get_run.assert_called_with("123", version=JOBS_API_VERSION)
mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION)
assert mock_runs_api.get_run.call_count == 2
mock_runs_api.reset_mock()

@@ -314,7 +314,7 @@ def test_wait_for_terminating_task(
{"state": {"life_cycle_state": "TERMINATED"}},
]
databricks_notebook_operator._wait_for_terminating_task(current_task, mock_runs_api)
mock_runs_api.get_run.assert_called_with("123", version=JOBS_API_VERSION)
mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION)
assert mock_runs_api.get_run.call_count == 3
mock_runs_api.reset_mock()

@@ -329,7 +329,7 @@ def test_wait_for_running_task(mock_sleep, mock_runs_api, databricks_notebook_op
{"state": {"life_cycle_state": "TERMINATED"}},
]
databricks_notebook_operator._wait_for_running_task(current_task, mock_runs_api)
mock_runs_api.get_run.assert_called_with("123", version=JOBS_API_VERSION)
mock_runs_api.get_run.assert_called_with("123", version=DATABRICKS_JOBS_API_VERSION)
assert mock_runs_api.get_run.call_count == 3
mock_runs_api.reset_mock()

@@ -383,7 +383,8 @@ def test_monitor_databricks_job_success(
databricks_notebook_operator.databricks_run_id = "1"
databricks_notebook_operator.monitor_databricks_job()
mock_runs_api.return_value.get_run.assert_called_with(
databricks_notebook_operator.databricks_run_id, version=JOBS_API_VERSION
databricks_notebook_operator.databricks_run_id,
version=DATABRICKS_JOBS_API_VERSION,
)
assert (
"Check the job run in Databricks: https://databricks-instance-xyz.cloud.databricks.com/#job/1234/run/1"
6 changes: 3 additions & 3 deletions tests/databricks/test_workflow.py
@@ -6,9 +6,9 @@
import pytest
from airflow.exceptions import AirflowException
from airflow.utils.task_group import TaskGroup
from astro_databricks.constants import JOBS_API_VERSION
from astro_databricks.operators.notebook import DatabricksNotebookOperator
from astro_databricks.operators.workflow import DatabricksWorkflowTaskGroup
from astro_databricks.settings import DATABRICKS_JOBS_API_VERSION

expected_workflow_json = {
"name": "unit_test_dag.test_workflow",
@@ -143,15 +143,15 @@ def test_create_workflow_from_notebooks_with_create(
task_group.children["test_workflow.launch"].execute(context={})
mock_jobs_api.return_value.create_job.assert_called_once_with(
json=expected_workflow_json,
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)
mock_jobs_api.return_value.run_now.assert_called_once_with(
job_id=1,
jar_params=[],
notebook_params={"notebook_path": "/foo/bar"},
python_params=[],
spark_submit_params=[],
version=JOBS_API_VERSION,
version=DATABRICKS_JOBS_API_VERSION,
)


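
Because the operators and these tests import the same `DATABRICKS_JOBS_API_VERSION`, the assertions stay in sync with whatever the environment provides. A hypothetical pytest sketch (not part of this commit) showing the import-time binding when the variable is overridden:

```python
import importlib

import astro_databricks.settings as settings


def test_jobs_api_version_env_override(monkeypatch):
    # Hypothetical test: the env var is read at import time, so the settings
    # module must be reloaded for an override to take effect.
    monkeypatch.setenv("DATABRICKS_JOBS_API_VERSION", "2.0")
    importlib.reload(settings)
    assert settings.DATABRICKS_JOBS_API_VERSION == "2.0"

    # Restore the default so other tests see the unmodified value.
    monkeypatch.delenv("DATABRICKS_JOBS_API_VERSION")
    importlib.reload(settings)
    assert settings.DATABRICKS_JOBS_API_VERSION == "2.1"
```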
