
Commit

done?
benc-db committed Dec 5, 2024
1 parent 87fb170 commit cc76dd4
Showing 4 changed files with 121 additions and 276 deletions.
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/__version__.py
@@ -1 +1 @@
version = "1.10.0a1"
version = "1.10.0-alpha.1"
272 changes: 96 additions & 176 deletions dbt/adapters/databricks/api_client.py
@@ -1,77 +1,50 @@
import re
import time
from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from datetime import timedelta
from io import BytesIO
from typing import Any, Optional, Protocol, TypeVar

from dbt_common.exceptions import DbtRuntimeError
from requests import Response, Session
from requests.adapters import HTTPAdapter
from typing_extensions import Self
from urllib3.util.retry import Retry

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.compute import CommandStatus, ContextStatus, Language
from databricks.sdk.service.iam import User
from databricks.sdk.service.jobs import (
Continuous,
CronSchedule,
Format,
GitSource,
JobAccessControlRequest,
JobCluster,
JobDeployment,
JobEditMode,
JobEmailNotifications,
JobEnvironment,
JobNotificationSettings,
JobParameterDefinition,
JobRunAs,
JobSettings,
JobsHealthRules,
QueueSettings,
Run,
RunResultState,
SubmitTask,
Task,
TerminationTypeType,
TriggerSettings,
WebhookNotifications,
)
from databricks.sdk.service.workspace import ImportFormat
from databricks.sdk.service.workspace import Language as WorkspaceLanguage
from dbt.adapters.databricks import utils
from dbt.adapters.databricks.__version__ import version
from dbt.adapters.databricks.credentials import (
BearerAuth,
DatabricksCredentialManager,
DatabricksCredentials,
)
from dbt.adapters.databricks.logging import logger

DEFAULT_POLLING_INTERVAL = 10
SUBMISSION_LANGUAGE = "python"
USER_AGENT = f"dbt-databricks/{version}"


class PrefixSession:
def __init__(self, session: Session, host: str, api: str):
self.prefix = f"https://{host}{api}"
self.session = session

def get(
self, suffix: str = "", json: Optional[Any] = None, params: Optional[dict[str, Any]] = None
) -> Response:
return self.session.get(f"{self.prefix}{suffix}", json=json, params=params)

def post(
self, suffix: str = "", json: Optional[Any] = None, params: Optional[dict[str, Any]] = None
) -> Response:
return self.session.post(f"{self.prefix}{suffix}", json=json, params=params)

def put(
self, suffix: str = "", json: Optional[Any] = None, params: Optional[dict[str, Any]] = None
) -> Response:
return self.session.put(f"{self.prefix}{suffix}", json=json, params=params)


class DatabricksApi(ABC):
def __init__(self, session: Session, host: str, api: str):
self.session = PrefixSession(session, host, api)
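
For reference, a minimal sketch of how these Session-based helpers were consumed before this commit (the host and endpoint values are illustrative, not taken from the diff):

from requests import Session

session = Session()
# PrefixSession pins every request to one host and API root, so call
# sites only supply an endpoint suffix plus an optional payload.
jobs = PrefixSession(session, "my-workspace.cloud.databricks.com", "/api/2.1/jobs")
response = jobs.get("/list", params={"name": "my_workflow"})
response.raise_for_status()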


class CommandContextApi:
def __init__(self, wc: WorkspaceClient):
@@ -134,7 +107,7 @@ def get_folder(self, catalog: str, schema: str) -> str:
return folder


class WorkspaceApi(DatabricksApi):
class WorkspaceApi:
def __init__(self, wc: WorkspaceClient, folder_api: FolderApi):
self.wc = wc
self.folder_api = folder_api
@@ -155,42 +128,6 @@ def upload_notebook(self, path: str, compiled_code: str) -> None:
)


class PollableApi(DatabricksApi, ABC):
def __init__(self, session: Session, host: str, api: str, polling_interval: int, timeout: int):
super().__init__(session, host, api)
self.timeout = timeout
self.polling_interval = polling_interval

def _poll_api(
self,
url: str,
params: dict,
get_state_func: Callable[[Response], str],
terminal_states: set[str],
expected_end_state: str,
unexpected_end_state_func: Callable[[Response], None],
) -> Response:
state = None
start = time.time()
exceeded_timeout = False
while state not in terminal_states:
if time.time() - start > self.timeout:
exceeded_timeout = True
break
# should we do exponential backoff?
time.sleep(self.polling_interval)
response = self.session.get(url, params=params)
if response.status_code != 200:
raise DbtRuntimeError(f"Error polling for completion.\n {response.content!r}")
state = get_state_func(response)
if exceeded_timeout:
raise DbtRuntimeError("Python model run timed out")
if state != expected_end_state:
unexpected_end_state_func(response)

return response
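
The loop above is what the SDK's blocking wait helpers replace; a minimal sketch of the new pattern, assuming ambient credentials and an illustrative task spec (the cluster id and notebook path are hypothetical):

from databricks.sdk import WorkspaceClient
from databricks.sdk.service.jobs import SubmitTask

wc = WorkspaceClient()  # resolves auth from the environment
task_spec = {
    "task_key": "python_model",
    "existing_cluster_id": "1234-567890-abcde123",  # hypothetical
    "notebook_task": {"notebook_path": "/Shared/dbt_python_models/my_model"},
}
# submit_and_wait polls internally until the run reaches a terminal
# state, replacing the hand-rolled timeout/sleep loop deleted here.
run = wc.jobs.submit_and_wait(run_name="dbt run", tasks=[SubmitTask.from_dict(task_spec)])
print(run.state)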


@dataclass(frozen=True, eq=True, unsafe_hash=True)
class CommandExecution(object):
command_id: str
@@ -269,37 +206,13 @@ def __init__(self, wc: WorkspaceClient, timeout: int):
self.wc = wc
self.timeout = timedelta(seconds=timeout)

def convert_to_sdk_types(self, job_settings: dict[str, Any]) -> dict[str, Any]:
return {
"access_control_list": convert_sdk_list(
JobAccessControlRequest, job_settings.get("access_control_list")
),
"budget_policy_id": job_settings.get("budget_policy_id"),
"email_notifications": convert_sdk_element(
JobEmailNotifications, job_settings.get("email_notifications")
),
"environments": convert_sdk_list(JobEnvironment, job_settings.get("environments")),
"git_source": convert_sdk_element(GitSource, job_settings),
"health": convert_sdk_element(JobsHealthRules, job_settings.get("health")),
"idempotency_token": job_settings.get("idempotency_token"),
"notification_settings": convert_sdk_element(
JobNotificationSettings, job_settings.get("notification_settings")
),
"queue": convert_sdk_element(QueueSettings, job_settings.get("queue")),
"run_as": convert_sdk_element(JobRunAs, job_settings.get("run_as")),
"timeout_seconds": self.timeout.total_seconds(),
"webhook_notifications": convert_sdk_element(
WebhookNotifications, job_settings.get("webhook_notifications")
),
}

def submit(
self, run_name: str, job_spec: dict[str, Any], **additional_job_settings: dict[str, Any]
) -> int:
submit_response = self.wc.jobs.submit_and_wait(
run_name=run_name,
tasks=[SubmitTask.from_dict(job_spec)],
**self.convert_to_sdk_types(additional_job_settings),
**self._convert_to_sdk_types(additional_job_settings),
)

logger.debug(f"Job submission response={submit_response}")
Expand All @@ -321,6 +234,30 @@ def poll_for_completion(self, run_id: int) -> None:
logger.debug(f"Job run {run_id} failed.\n {run}")
self._get_exception(run, run_id)

def _convert_to_sdk_types(self, job_settings: dict[str, Any]) -> dict[str, Any]:
return {
"access_control_list": convert_sdk_list(
JobAccessControlRequest, job_settings.get("access_control_list")
),
"budget_policy_id": job_settings.get("budget_policy_id"),
"email_notifications": convert_sdk_element(
JobEmailNotifications, job_settings.get("email_notifications")
),
"environments": convert_sdk_list(JobEnvironment, job_settings.get("environments")),
"git_source": convert_sdk_element(GitSource, job_settings.get("git_source")),
"health": convert_sdk_element(JobsHealthRules, job_settings.get("health")),
"idempotency_token": job_settings.get("idempotency_token"),
"notification_settings": convert_sdk_element(
JobNotificationSettings, job_settings.get("notification_settings")
),
"queue": convert_sdk_element(QueueSettings, job_settings.get("queue")),
"run_as": convert_sdk_element(JobRunAs, job_settings.get("run_as")),
"timeout_seconds": self.timeout.total_seconds(),
"webhook_notifications": convert_sdk_element(
WebhookNotifications, job_settings.get("webhook_notifications")
),
}

def _get_exception(self, run: Run, run_id: int) -> None:
try:
run_id = utils.if_some(run.tasks, lambda x: x[0].run_id) or run_id
@@ -362,70 +299,80 @@ def set(self, job_id: str, access_control_list: list[dict[str, Any]]) -> None:
logger.debug(f"Workflow permissions update response={response}")

def get(self, job_id: str) -> dict[str, Any]:
response = self.session.get(f"/{job_id}")

if response.status_code != 200:
raise DbtRuntimeError(
f"Error fetching Databricks workflow permissions.\n {response.content!r}"
)
response = self.wc.jobs.get_permissions(job_id)

return response.json()
return response.as_dict()


class WorkflowJobApi(DatabricksApi):
def __init__(self, session: Session, host: str):
super().__init__(session, host, "/api/2.1/jobs")
class WorkflowJobApi:
def __init__(self, wc: WorkspaceClient, timeout: int):
self.wc = wc
self.timeout = timedelta(seconds=timeout)

def search_by_name(self, job_name: str) -> list[dict[str, Any]]:
response = self.session.get("/list", json={"name": job_name})

if response.status_code != 200:
raise DbtRuntimeError(f"Error fetching job by name.\n {response.content!r}")
return [j.as_dict() for j in self.wc.jobs.list(name=job_name)]

return response.json().get("jobs", [])

def create(self, job_spec: dict[str, Any]) -> str:
def create(self, job_spec: dict[str, Any]) -> int:
"""
:return: the job_id
"""
response = self.session.post("/create", json=job_spec)
response = self.wc.jobs.create(**self._convert_to_sdk_types(job_spec))

if response.status_code != 200:
raise DbtRuntimeError(f"Error creating Workflow.\n {response.content!r}")
job_id = response.job_id

job_id = response.json()["job_id"]
if not job_id:
raise DbtRuntimeError(f"Error creating Workflow.\n {response}")
logger.info(f"New workflow created with job id {job_id}")
return job_id

def update_job_settings(self, job_id: str, job_spec: dict[str, Any]) -> None:
request_body = {
"job_id": job_id,
"new_settings": job_spec,
}
logger.debug(f"Job settings: {request_body}")
response = self.session.post("/reset", json=request_body)

if response.status_code != 200:
raise DbtRuntimeError(f"Error updating Workflow.\n {response.content!r}")
return job_id

logger.debug(f"Workflow update response={response.json()}")
def update_job_settings(self, job_id: int, job_spec: dict[str, Any]) -> None:
self.wc.jobs.reset(job_id, JobSettings.from_dict(self._convert_to_sdk_types(job_spec)))

def run(self, job_id: str, enable_queueing: bool = True) -> str:
request_body = {
"job_id": job_id,
"queue": {
"enabled": enable_queueing,
},
}
response = self.session.post("/run-now", json=request_body)
def run(self, job_id: int) -> int:
response = self.wc.jobs.run_now_and_wait(job_id)

if response.status_code != 200:
raise DbtRuntimeError(f"Error triggering run for workflow.\n {response.content!r}")
if not response.run_id:
raise DbtRuntimeError(f"Error running Workflow.\n {response}")

response_json = response.json()
logger.info(f"Workflow trigger response={response_json}")
return response.run_id

return response_json["run_id"]
def _convert_to_sdk_types(self, job_settings: dict[str, Any]) -> dict[str, Any]:
return {
"access_control_list": convert_sdk_list(
JobAccessControlRequest, job_settings.get("access_control_list")
),
"budget_policy_id": job_settings.get("budget_policy_id"),
"continuous": convert_sdk_element(Continuous, job_settings.get("continuous")),
"deployment": convert_sdk_element(JobDeployment, job_settings.get("deployment")),
"description": job_settings.get("description"),
"edit_mode": convert_sdk_element(JobEditMode, job_settings.get("edit_mode")),
"email_notifications": convert_sdk_element(
JobEmailNotifications, job_settings.get("email_notifications")
),
"environments": convert_sdk_list(JobEnvironment, job_settings.get("environments")),
"format": convert_sdk_element(Format, job_settings.get("format")),
"git_source": convert_sdk_element(GitSource, job_settings.get("git_source")),
"health": convert_sdk_element(JobsHealthRules, job_settings.get("health")),
"idempotency_token": job_settings.get("idempotency_token"),
"job_clusters": convert_sdk_list(JobCluster, job_settings.get("job_clusters")),
"max_concurrent_runs": job_settings.get("max_concurrent_runs"),
"name": job_settings.get("name"),
"notification_settings": convert_sdk_element(
JobNotificationSettings, job_settings.get("notification_settings")
),
"parameters": convert_sdk_list(JobParameterDefinition, job_settings.get("parameters")),
"queue": convert_sdk_element(QueueSettings, job_settings.get("queue")),
"run_as": convert_sdk_element(JobRunAs, job_settings.get("run_as")),
"schedule": convert_sdk_element(CronSchedule, job_settings.get("schedule")),
"tags": job_settings.get("tags"),
"tasks": convert_sdk_list(Task, job_settings.get("tasks")),
"timeout_seconds": self.timeout.total_seconds(),
"trigger_settings": convert_sdk_element(TriggerSettings, job_settings.get("trigger")),
"webhook_notifications": convert_sdk_element(
WebhookNotifications, job_settings.get("webhook_notifications")
),
}
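
convert_sdk_element and convert_sdk_list are defined in a collapsed portion of this file; a plausible minimal reading, assuming they only wrap optional dicts in the generated SDK dataclasses via their from_dict constructors:

from typing import Any, Optional, TypeVar

T = TypeVar("T")

def convert_sdk_element(klass: type[T], value: Optional[dict[str, Any]]) -> Optional[T]:
    # Hypothetical sketch: pass None through, otherwise build the SDK type.
    return klass.from_dict(value) if value is not None else None  # type: ignore[attr-defined]

def convert_sdk_list(klass: type[T], values: Optional[list[dict[str, Any]]]) -> Optional[list[T]]:
    # Same idea applied element-wise for list-valued settings.
    return [klass.from_dict(v) for v in values] if values is not None else None  # type: ignore[attr-defined]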


class DatabricksApiClient:
@@ -434,9 +381,6 @@
def __init__(
self,
wc: WorkspaceClient,
session: Session,
host: str,
polling_interval: int,
timeout: int,
use_user_folder: bool,
):
@@ -448,9 +392,9 @@ def __init__(
self.folders = SharedFolderApi()
self.workspace = WorkspaceApi(wc, self.folders)
self.commands = CommandApi(wc, timeout)
self.job_runs = JobRunsApi(session, host, polling_interval, timeout)
self.workflows = WorkflowJobApi(session, host)
self.workflow_permissions = JobPermissionsApi(session, host)
self.job_runs = JobRunsApi(wc, timeout)
self.workflows = WorkflowJobApi(wc, timeout)
self.workflow_permissions = JobPermissionsApi(wc)

@classmethod
def create(
@@ -461,31 +405,7 @@
) -> "DatabricksApiClient":
if cls.instance:
return cls.instance
polling_interval = DEFAULT_POLLING_INTERVAL
retry_strategy = Retry(total=4, backoff_factor=0.5)
adapter = HTTPAdapter(max_retries=retry_strategy)
session = Session()
session.mount("https://", adapter)

invocation_env = credentials.get_invocation_env()
user_agent = USER_AGENT
if invocation_env:
user_agent = f"{user_agent} ({invocation_env})"

connection_parameters = credentials.connection_parameters.copy() # type: ignore[union-attr]

http_headers = credentials.get_all_http_headers(
connection_parameters.pop("http_headers", {})
)
header_factory = credentials.authenticate().credentials_provider() # type: ignore
session.auth = BearerAuth(header_factory)

session.headers.update({"User-Agent": user_agent, **http_headers})
host = credentials.host

assert host is not None, "Host must be set in the credentials"
wc = DatabricksCredentialManager.create_from(credentials).api_client
cls.instance = DatabricksApiClient(
wc, session, host, polling_interval, timeout, use_user_folder
)
cls.instance = DatabricksApiClient(wc, timeout, use_user_folder)
return cls.instance
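
With the Session plumbing gone, construction reduces to the SDK client; hypothetical usage, assuming create() still takes the credentials plus the timeout and use_user_folder arguments seen in __init__:

# credentials: a configured DatabricksCredentials instance (assumed)
client = DatabricksApiClient.create(credentials, timeout=600, use_user_folder=False)
# Every sub-API now shares the single WorkspaceClient:
matches = client.workflows.search_by_name("my_workflow")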