Skip to content

Commit

Permalink
Update connection object to cached_property in DatabricksHook (
Browse files Browse the repository at this point in the history
  • Loading branch information
josh-fell authored Dec 27, 2021
1 parent f99f2c3 commit e7659d0
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions airflow/providers/databricks/hooks/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
operators talk to the ``api/2.0/jobs/runs/submit``
`endpoint <https://docs.databricks.com/api/latest/jobs.html#runs-submit>`_.
"""
import sys
import time
from time import sleep
from typing import Dict
Expand All @@ -34,6 +35,12 @@
from airflow import __version__
from airflow.exceptions import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection

if sys.version_info >= (3, 8):
from functools import cached_property
else:
from cached_property import cached_property

RESTART_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/restart")
START_CLUSTER_ENDPOINT = ("POST", "api/2.0/clusters/start")
Expand Down Expand Up @@ -143,11 +150,10 @@ def __init__(
self.retry_delay = retry_delay
self.aad_tokens: Dict[str, dict] = {}
self.aad_timeout_seconds = 10
self.databricks_conn = self.get_connection(self.databricks_conn_id)
if 'host' in self.databricks_conn.extra_dejson:
self.host = self._parse_host(self.databricks_conn.extra_dejson['host'])
else:
self.host = self._parse_host(self.databricks_conn.host)

@cached_property
def databricks_conn(self) -> Connection:
return self.get_connection(self.databricks_conn_id)

@staticmethod
def _parse_host(host: str) -> str:
Expand Down Expand Up @@ -305,6 +311,11 @@ def _do_api_call(self, endpoint_info, json):
"""
method, endpoint = endpoint_info

if 'host' in self.databricks_conn.extra_dejson:
self.host = self._parse_host(self.databricks_conn.extra_dejson['host'])
else:
self.host = self._parse_host(self.databricks_conn.host)

url = f'https://{self.host}/{endpoint}'

aad_headers = self._get_aad_headers()
Expand Down

0 comments on commit e7659d0

Please sign in to comment.