Improved telemetry for Databricks provider
alexott authored and potiuk committed Jul 25, 2022
1 parent d7f4ee1 commit 83d2881
Showing 10 changed files with 124 additions and 66 deletions.
airflow/providers/databricks/hooks/databricks.py (3 changes: 2 additions & 1 deletion)

@@ -126,8 +126,9 @@ def __init__(
         retry_limit: int = 3,
         retry_delay: float = 1.0,
         retry_args: Optional[Dict[Any, Any]] = None,
+        caller: str = "DatabricksHook",
     ) -> None:
-        super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args)
+        super().__init__(databricks_conn_id, timeout_seconds, retry_limit, retry_delay, retry_args, caller)

     def run_now(self, json: dict) -> int:
         """
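Each hook subclass now identifies itself through the new `caller` argument, which defaults to the subclass name and is passed through to `BaseDatabricksHook`, where it ends up in the user-agent string attached to every API request. A minimal sketch of how calling code could tag itself (the operator name here is hypothetical):

    from airflow.providers.databricks.hooks.databricks import DatabricksHook

    # Hypothetical caller: requests made through this hook will carry
    # "MyCustomOperator" in their user-agent for telemetry attribution.
    hook = DatabricksHook("databricks_default", caller="MyCustomOperator")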
airflow/providers/databricks/hooks/databricks_base.py (34 changes: 25 additions & 9 deletions)

@@ -46,8 +46,7 @@
 from airflow.exceptions import AirflowException
 from airflow.hooks.base import BaseHook
 from airflow.models import Connection
-
-USER_AGENT_HEADER = {'user-agent': f'airflow-{__version__}'}
+from airflow.providers_manager import ProvidersManager

 # https://docs.microsoft.com/en-us/azure/databricks/dev-tools/api/latest/aad/service-prin-aad-token#--get-an-azure-active-directory-access-token
 # https://docs.microsoft.com/en-us/graph/deployments#app-registration-and-token-service-root-endpoints
@@ -96,6 +95,7 @@ def __init__(
         retry_limit: int = 3,
         retry_delay: float = 1.0,
         retry_args: Optional[Dict[Any, Any]] = None,
+        caller: str = "Unknown",
     ) -> None:
         super().__init__()
         self.databricks_conn_id = databricks_conn_id
@@ -106,6 +106,7 @@ def __init__(
         self.retry_delay = retry_delay
         self.aad_tokens: Dict[str, dict] = {}
         self.aad_timeout_seconds = 10
+        self.caller = caller

         def my_after_func(retry_state):
             self._log_request_error(retry_state.attempt_number, retry_state.outcome)
@@ -129,6 +130,21 @@ def databricks_conn(self) -> Connection:
     def get_conn(self) -> Connection:
         return self.databricks_conn

+    @cached_property
+    def user_agent_header(self) -> Dict[str, str]:
+        return {'user-agent': self.user_agent_value}
+
+    @cached_property
+    def user_agent_value(self) -> str:
+        manager = ProvidersManager()
+        package_name = manager.hooks[BaseDatabricksHook.conn_type].package_name  # type: ignore[union-attr]
+        provider = manager.providers[package_name]
+        version = provider.version
+        if provider.is_source:
+            version += "-source"
+
+        return f'Airflow/{__version__} Databricks/{version} ({self.caller})'
+
     @cached_property
     def host(self) -> str:
         if 'host' in self.databricks_conn.extra_dejson:
@@ -209,7 +225,7 @@ def _get_aad_token(self, resource: str) -> str:
             resp = requests.get(
                 AZURE_METADATA_SERVICE_TOKEN_URL,
                 params=params,
-                headers={**USER_AGENT_HEADER, "Metadata": "true"},
+                headers={**self.user_agent_header, "Metadata": "true"},
                 timeout=self.aad_timeout_seconds,
             )
         else:
@@ -227,7 +243,7 @@ def _get_aad_token(self, resource: str) -> str:
                 AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
                 data=data,
                 headers={
-                    **USER_AGENT_HEADER,
+                    **self.user_agent_header,
                     'Content-Type': 'application/x-www-form-urlencoded',
                 },
                 timeout=self.aad_timeout_seconds,
@@ -274,7 +290,7 @@ async def _a_get_aad_token(self, resource: str) -> str:
             async with self._session.get(
                 url=AZURE_METADATA_SERVICE_TOKEN_URL,
                 params=params,
-                headers={**USER_AGENT_HEADER, "Metadata": "true"},
+                headers={**self.user_agent_header, "Metadata": "true"},
                 timeout=self.aad_timeout_seconds,
             ) as resp:
                 resp.raise_for_status()
@@ -294,7 +310,7 @@ async def _a_get_aad_token(self, resource: str) -> str:
                 url=AZURE_TOKEN_SERVICE_URL.format(azure_ad_endpoint, tenant_id),
                 data=data,
                 headers={
-                    **USER_AGENT_HEADER,
+                    **self.user_agent_header,
                     'Content-Type': 'application/x-www-form-urlencoded',
                 },
                 timeout=self.aad_timeout_seconds,
@@ -467,7 +483,7 @@ def _do_api_call(
         url = f'https://{self.host}/{endpoint}'

         aad_headers = self._get_aad_headers()
-        headers = {**USER_AGENT_HEADER.copy(), **aad_headers}
+        headers = {**self.user_agent_header, **aad_headers}

         auth: AuthBase
         token = self._get_token()
@@ -525,7 +541,7 @@ async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Di
         url = f'https://{self.host}/{endpoint}'

         aad_headers = await self._a_get_aad_headers()
-        headers = {**USER_AGENT_HEADER.copy(), **aad_headers}
+        headers = {**self.user_agent_header, **aad_headers}

         auth: aiohttp.BasicAuth
         token = await self._a_get_token()
@@ -551,7 +567,7 @@ async def _a_do_api_call(self, endpoint_info: Tuple[str, str], json: Optional[Di
                 url,
                 json=json,
                 auth=auth,
-                headers={**headers, **USER_AGENT_HEADER},
+                headers={**headers, **self.user_agent_header},
                 timeout=self.timeout_seconds,
             ) as response:
                 response.raise_for_status()
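The static `USER_AGENT_HEADER` constant is replaced by two cached properties that combine the Airflow core version, the provider version resolved via `ProvidersManager`, and the caller name. A sketch of the value they produce, with illustrative version numbers:

    # Reconstruction of user_agent_value; the version numbers are illustrative.
    airflow_version = "2.3.3"    # airflow.__version__
    provider_version = "3.1.0"   # from ProvidersManager; "-source" is
                                 # appended when the provider runs from source
    caller = "DatabricksSubmitRunOperator"

    user_agent_value = f"Airflow/{airflow_version} Databricks/{provider_version} ({caller})"
    # -> 'Airflow/2.3.3 Databricks/3.1.0 (DatabricksSubmitRunOperator)'
    user_agent_header = {"user-agent": user_agent_value}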
airflow/providers/databricks/hooks/databricks_sql.py (7 changes: 3 additions & 4 deletions)

@@ -22,13 +22,11 @@
 from databricks import sql  # type: ignore[attr-defined]
 from databricks.sql.client import Connection  # type: ignore[attr-defined]

-from airflow import __version__
 from airflow.exceptions import AirflowException
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook

 LIST_SQL_ENDPOINTS_ENDPOINT = ('GET', 'api/2.0/sql/endpoints')
-USER_AGENT_STRING = f'airflow-{__version__}'


 class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
@@ -62,9 +60,10 @@ def __init__(
         http_headers: Optional[List[Tuple[str, str]]] = None,
         catalog: Optional[str] = None,
         schema: Optional[str] = None,
+        caller: str = "DatabricksSqlHook",
         **kwargs,
     ) -> None:
-        super().__init__(databricks_conn_id)
+        super().__init__(databricks_conn_id, caller=caller)
         self._sql_conn = None
         self._token: Optional[str] = None
         self._http_path = http_path
@@ -132,7 +131,7 @@ def get_conn(self) -> Connection:
             catalog=self.catalog,
             session_configuration=self.session_config,
             http_headers=self.http_headers,
-            _user_agent_entry=USER_AGENT_STRING,
+            _user_agent_entry=self.user_agent_value,
             **self._get_extra_config(),
             **self.additional_params,
         )
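The SQL hook now feeds the same computed string to the databricks-sql-connector via its `_user_agent_entry` argument, replacing the old module-level `USER_AGENT_STRING`. A rough sketch of the underlying connector call, with placeholder connection details:

    from databricks import sql

    # Sketch of what get_conn() effectively passes down; the hostname,
    # HTTP path, and token below are placeholders, not real values.
    conn = sql.connect(
        server_hostname="adb-1234567890123456.7.azuredatabricks.net",
        http_path="/sql/1.0/endpoints/placeholder",
        access_token="dapi-placeholder",
        _user_agent_entry="Airflow/2.3.3 Databricks/3.1.0 (DatabricksSqlHook)",
    )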
airflow/providers/databricks/operators/databricks.py (44 changes: 26 additions & 18 deletions)

@@ -382,14 +382,18 @@ def __init__(
         # This variable will be used in case our task gets killed.
         self.run_id: Optional[int] = None
         self.do_xcom_push = do_xcom_push
-
-    def _get_hook(self) -> DatabricksHook:
-        return DatabricksHook(
-            self.databricks_conn_id,
-            retry_limit=self.databricks_retry_limit,
-            retry_delay=self.databricks_retry_delay,
-            retry_args=self.databricks_retry_args,
-        )
+        self.hook = None
+
+    def _get_hook(self, caller="DatabricksSubmitRunOperator") -> DatabricksHook:
+        if not self.hook:
+            self.hook = DatabricksHook(
+                self.databricks_conn_id,
+                retry_limit=self.databricks_retry_limit,
+                retry_delay=self.databricks_retry_delay,
+                retry_args=self.databricks_retry_args,
+                caller=caller,
+            )
+        return self.hook

     def execute(self, context: 'Context'):
         hook = self._get_hook()
@@ -411,7 +415,7 @@ class DatabricksSubmitRunDeferrableOperator(DatabricksSubmitRunOperator):
     """Deferrable version of ``DatabricksSubmitRunOperator``"""

     def execute(self, context):
-        hook = self._get_hook()
+        hook = self._get_hook(caller="DatabricksSubmitRunDeferrableOperator")
         self.run_id = hook.submit_run(self.json)
         _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)

@@ -634,14 +638,18 @@ def __init__(
         # This variable will be used in case our task gets killed.
         self.run_id: Optional[int] = None
         self.do_xcom_push = do_xcom_push
-
-    def _get_hook(self) -> DatabricksHook:
-        return DatabricksHook(
-            self.databricks_conn_id,
-            retry_limit=self.databricks_retry_limit,
-            retry_delay=self.databricks_retry_delay,
-            retry_args=self.databricks_retry_args,
-        )
+        self.hook = None
+
+    def _get_hook(self, caller="DatabricksRunNowOperator") -> DatabricksHook:
+        if not self.hook:
+            self.hook = DatabricksHook(
+                self.databricks_conn_id,
+                retry_limit=self.databricks_retry_limit,
+                retry_delay=self.databricks_retry_delay,
+                retry_args=self.databricks_retry_args,
+                caller=caller,
+            )
+        return self.hook

     def execute(self, context: 'Context'):
         hook = self._get_hook()
@@ -669,7 +677,7 @@ class DatabricksRunNowDeferrableOperator(DatabricksRunNowOperator):
     """Deferrable version of ``DatabricksRunNowOperator``"""

     def execute(self, context):
-        hook = self._get_hook()
+        hook = self._get_hook(caller="DatabricksRunNowDeferrableOperator")
         self.run_id = hook.run_now(self.json)
         _handle_deferrable_databricks_operator_execution(self, hook, self.log, context)

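Both job operators now create their hook lazily and memoize it, so repeated `_get_hook()` calls reuse a single instance, and each deferrable subclass reports its own class name as the caller. A condensed sketch of the pattern, stripped of the retry arguments the real operators forward:

    from airflow.providers.databricks.hooks.databricks import DatabricksHook

    # Condensed sketch of the memoized-hook pattern; not the full operator.
    class HookCachingSketch:
        def __init__(self, databricks_conn_id: str = "databricks_default"):
            self.databricks_conn_id = databricks_conn_id
            self.hook = None  # built on the first _get_hook() call, then reused

        def _get_hook(self, caller: str = "DatabricksSubmitRunOperator") -> DatabricksHook:
            if not self.hook:
                self.hook = DatabricksHook(self.databricks_conn_id, caller=caller)
            return self.hook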
airflow/providers/databricks/operators/databricks_repos.py (3 changes: 3 additions & 0 deletions)

@@ -121,6 +121,7 @@ def _get_hook(self) -> DatabricksHook:
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
             retry_delay=self.databricks_retry_delay,
+            caller="DatabricksReposCreateOperator",
         )

     def execute(self, context: 'Context'):
@@ -218,6 +219,7 @@ def _get_hook(self) -> DatabricksHook:
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
             retry_delay=self.databricks_retry_delay,
+            caller="DatabricksReposUpdateOperator",
         )

     def execute(self, context: 'Context'):
@@ -283,6 +285,7 @@ def _get_hook(self) -> DatabricksHook:
             self.databricks_conn_id,
             retry_limit=self.databricks_retry_limit,
             retry_delay=self.databricks_retry_delay,
+            caller="DatabricksReposDeleteOperator",
         )

     def execute(self, context: 'Context'):
airflow/providers/databricks/operators/databricks_sql.py (2 changes: 2 additions & 0 deletions)

@@ -114,6 +114,7 @@ def _get_hook(self) -> DatabricksSqlHook:
             http_headers=self.http_headers,
             catalog=self.catalog,
             schema=self.schema,
+            caller="DatabricksSqlOperator",
             **self.client_parameters,
         )

@@ -279,6 +280,7 @@ def _get_hook(self) -> DatabricksSqlHook:
             http_headers=self._http_headers,
             catalog=self._catalog,
             schema=self._schema,
+            caller="DatabricksCopyIntoOperator",
             **self._client_parameters,
         )
(Diffs for the remaining 4 changed files are not shown.)