From 201cc13d39a61372010e5a9cf081ab267a7bf91b Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 7 Oct 2024 13:46:03 +0200 Subject: [PATCH 01/17] refactor API client --- databricks/sdk/_base_client.py | 324 +++++++++++++++++++++++++++++++++ databricks/sdk/core.py | 318 +++----------------------------- tests/fixture_server.py | 33 ++++ tests/test_base_client.py | 282 ++++++++++++++++++++++++++++ tests/test_core.py | 279 +--------------------------- 5 files changed, 670 insertions(+), 566 deletions(-) create mode 100644 databricks/sdk/_base_client.py create mode 100644 tests/fixture_server.py create mode 100644 tests/test_base_client.py diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py new file mode 100644 index 00000000..47d6340e --- /dev/null +++ b/databricks/sdk/_base_client.py @@ -0,0 +1,324 @@ +import logging +from datetime import timedelta +from types import TracebackType +from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, + Optional, Type, Union) +import urllib.parse + +import requests +import requests.adapters + +from . import useragent +from .casing import Casing +from .clock import Clock, RealClock +from .errors import DatabricksError, _ErrorCustomizer, _Parser +from .logger import RoundTrip +from .retries import retried + +logger = logging.getLogger('databricks.sdk') + + +class _BaseClient: + + def __init__(self, + debug_truncate_bytes: int = None, + retry_timeout_seconds: int = None, + user_agent_base: str = None, + header_factory: Callable[[], dict] = None, + max_connection_pools: int = None, + max_connections_per_pool: int = None, + pool_block: bool = True, + http_timeout_seconds: float = None, + extra_error_customizers: List[_ErrorCustomizer] = None, + debug_headers: bool = False, + clock: Clock = None): + """ + :param debug_truncate_bytes: + :param retry_timeout_seconds: + :param user_agent_base: + :param header_factory: A function that returns a dictionary of headers to include in the request. + :param max_connection_pools: Number of urllib3 connection pools to cache before discarding the least + recently used pool. Python requests default value is 10. + :param max_connections_per_pool: The maximum number of connections to save in the pool. Improves performance + in multithreaded situations. For now, we're setting it to the same value as connection_pool_size. + :param pool_block: If pool_block is False, then more connections will are created, but not saved after the + first use. Blocks when no free connections are available. urllib3 ensures that no more than + pool_maxsize connections are used at a time. Prevents platform from flooding. By default, requests library + doesn't block. + :param http_timeout_seconds: + :param extra_error_customizers: + :param debug_headers: Whether to include debug headers in the request log. + :param clock: Clock object to use for time-related operations. + """ + + self._debug_truncate_bytes = debug_truncate_bytes or 96 + self._debug_headers = debug_headers + self._retry_timeout_seconds = retry_timeout_seconds or 300 + self._user_agent_base = user_agent_base or useragent.to_string() + self._header_factory = header_factory + self._clock = clock or RealClock() + self._session = requests.Session() + self._session.auth = self._authenticate + + # We don't use `max_retries` from HTTPAdapter to align with a more production-ready + # retry strategy established in the Databricks SDK for Go. See _is_retryable and + # @retried for more details. 
+ http_adapter = requests.adapters.HTTPAdapter(pool_connections=max_connections_per_pool or 20, + pool_maxsize=max_connection_pools or 20, + pool_block=pool_block) + self._session.mount("https://", http_adapter) + + # Default to 60 seconds + self._http_timeout_seconds = http_timeout_seconds or 60 + + self._error_parser = _Parser(extra_error_customizers=extra_error_customizers) + + def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + if self._header_factory: + headers = self._header_factory() + for k, v in headers.items(): + r.headers[k] = v + return r + + @staticmethod + def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]: + # Convert True -> "true" for Databricks APIs to understand booleans. + # See: https://github.com/databricks/databricks-sdk-py/issues/142 + if query is None: + return None + with_fixed_bools = {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()} + + # Query parameters may be nested, e.g. + # {'filter_by': {'user_ids': [123, 456]}} + # The HTTP-compatible representation of this is + # filter_by.user_ids=123&filter_by.user_ids=456 + # To achieve this, we convert the above dictionary to + # {'filter_by.user_ids': [123, 456]} + # See the following for more information: + # https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule + def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: + for k1, v1 in d.items(): + if isinstance(v1, dict): + v1 = dict(flatten_dict(v1)) + for k2, v2 in v1.items(): + yield f"{k1}.{k2}", v2 + else: + yield k1, v1 + + flattened = dict(flatten_dict(with_fixed_bools)) + return flattened + + def do(self, + method: str, + url: str, + query: dict = None, + headers: dict = None, + body: dict = None, + raw: bool = False, + files=None, + data=None, + auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, + response_headers: List[str] = None) -> Union[dict, list, BinaryIO]: + if headers is None: + headers = {} + headers['User-Agent'] = self._user_agent_base + retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds), + is_retryable=self._is_retryable, + clock=self._clock) + response = retryable(self._perform)(method, + url, + query=query, + headers=headers, + body=body, + raw=raw, + files=files, + data=data, + auth=auth) + + resp = dict() + for header in response_headers if response_headers else []: + resp[header] = response.headers.get(Casing.to_header_case(header)) + if raw: + resp["contents"] = _StreamingResponse(response) + return resp + if not len(response.content): + return resp + + json_response = response.json() + if json_response is None: + return resp + + if isinstance(json_response, list): + return json_response + + return {**resp, **json_response} + + @staticmethod + def _is_retryable(err: BaseException) -> Optional[str]: + # this method is Databricks-specific port of urllib3 retries + # (see https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py) + # and Databricks SDK for Go retries + # (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go) + from urllib3.exceptions import ProxyError + if isinstance(err, ProxyError): + err = err.original_error + if isinstance(err, requests.ConnectionError): + # corresponds to `connection reset by peer` and `connection refused` errors from Go, + # which are generally related to the temporary glitches in the networking stack, + # also caused by endpoint protection software, like ZScaler, to drop 
connections while + # not yet authenticated. + # + # return a simple string for debug log readability, as `raise TimeoutError(...) from err` + # will bubble up the original exception in case we reach max retries. + return f'cannot connect' + if isinstance(err, requests.Timeout): + # corresponds to `TLS handshake timeout` and `i/o timeout` in Go. + # + # return a simple string for debug log readability, as `raise TimeoutError(...) from err` + # will bubble up the original exception in case we reach max retries. + return f'timeout' + if isinstance(err, DatabricksError): + message = str(err) + transient_error_string_matches = [ + "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException", + "does not have any associated worker environments", "There is no worker environment with id", + "Unknown worker environment", "ClusterNotReadyException", "Unexpected error", + "Please try again later or try a faster operation.", + "RPC token bucket limit has been exceeded", + ] + for substring in transient_error_string_matches: + if substring not in message: + continue + return f'matched {substring}' + return None + + def _perform(self, + method: str, + url: str, + query: dict = None, + headers: dict = None, + body: dict = None, + raw: bool = False, + files=None, + data=None, + auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None): + response = self._session.request(method, + url, + params=self._fix_query_string(query), + json=body, + headers=headers, + files=files, + data=data, + auth=auth, + stream=raw, + timeout=self._http_timeout_seconds) + self._record_request_log(response, raw=raw or data is not None or files is not None) + error = self._error_parser.get_api_error(response) + if error is not None: + raise error from None + return response + + def _record_request_log(self, response: requests.Response, raw: bool = False) -> None: + if not logger.isEnabledFor(logging.DEBUG): + return + logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate()) + + +class _StreamingResponse(BinaryIO): + _response: requests.Response + _buffer: bytes + _content: Union[Iterator[bytes], None] + _chunk_size: Union[int, None] + _closed: bool = False + + def fileno(self) -> int: + pass + + def flush(self) -> int: + pass + + def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None): + self._response = response + self._buffer = b'' + self._content = None + self._chunk_size = chunk_size + + def _open(self) -> None: + if self._closed: + raise ValueError("I/O operation on closed file") + if not self._content: + self._content = self._response.iter_content(chunk_size=self._chunk_size) + + def __enter__(self) -> BinaryIO: + self._open() + return self + + def set_chunk_size(self, chunk_size: Union[int, None]) -> None: + self._chunk_size = chunk_size + + def close(self) -> None: + self._response.close() + self._closed = True + + def isatty(self) -> bool: + return False + + def read(self, n: int = -1) -> bytes: + self._open() + read_everything = n < 0 + remaining_bytes = n + res = b'' + while remaining_bytes > 0 or read_everything: + if len(self._buffer) == 0: + try: + self._buffer = next(self._content) + except StopIteration: + break + bytes_available = len(self._buffer) + to_read = bytes_available if read_everything else min(remaining_bytes, bytes_available) + res += self._buffer[:to_read] + self._buffer = self._buffer[to_read:] + remaining_bytes -= to_read + return res + + def readable(self) -> bool: + return self._content is not None 
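+
+    # The remaining file-like interface is largely unsupported: the wrapped HTTP body is
+    # a forward-only byte stream, so line-oriented reads, seeking, and writing raise
+    # NotImplementedError. Callers should use read() or iterate over the stream instead.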
+ + def readline(self, __limit: int = ...) -> bytes: + raise NotImplementedError() + + def readlines(self, __hint: int = ...) -> List[bytes]: + raise NotImplementedError() + + def seek(self, __offset: int, __whence: int = ...) -> int: + raise NotImplementedError() + + def seekable(self) -> bool: + return False + + def tell(self) -> int: + raise NotImplementedError() + + def truncate(self, __size: Union[int, None] = ...) -> int: + raise NotImplementedError() + + def writable(self) -> bool: + return False + + def write(self, s: Union[bytes, bytearray]) -> int: + raise NotImplementedError() + + def writelines(self, lines: Iterable[bytes]) -> None: + raise NotImplementedError() + + def __next__(self) -> bytes: + return self.read(1) + + def __iter__(self) -> Iterator[bytes]: + return self._content + + def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None], + traceback: Union[TracebackType, None]) -> None: + self._content = None + self._buffer = b'' + self.close() diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index 77e8c9aa..c9e49dc8 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -1,19 +1,13 @@ import re -from datetime import timedelta -from types import TracebackType -from typing import Any, BinaryIO, Iterator, Type +from typing import BinaryIO from urllib.parse import urlencode -from requests.adapters import HTTPAdapter - -from .casing import Casing +from ._base_client import _BaseClient from .config import * # To preserve backwards compatibility (as these definitions were previously in this module) from .credentials_provider import * -from .errors import DatabricksError, _ErrorCustomizer, _Parser -from .logger import RoundTrip +from .errors import DatabricksError, _ErrorCustomizer from .oauth import retrieve_token -from .retries import retried __all__ = ['Config', 'DatabricksError'] @@ -24,54 +18,21 @@ OIDC_TOKEN_PATH = "/oidc/v1/token" -class ApiClient: - _cfg: Config - _RETRY_AFTER_DEFAULT: int = 1 - - def __init__(self, cfg: Config = None): - if cfg is None: - cfg = Config() +class ApiClient: + def __init__(self, cfg: Config): self._cfg = cfg - # See https://github.com/databricks/databricks-sdk-go/blob/main/client/client.go#L34-L35 - self._debug_truncate_bytes = cfg.debug_truncate_bytes if cfg.debug_truncate_bytes else 96 - self._retry_timeout_seconds = cfg.retry_timeout_seconds if cfg.retry_timeout_seconds else 300 - self._user_agent_base = cfg.user_agent - self._session = requests.Session() - self._session.auth = self._authenticate - - # Number of urllib3 connection pools to cache before discarding the least - # recently used pool. Python requests default value is 10. - pool_connections = cfg.max_connection_pools - if pool_connections is None: - pool_connections = 20 - - # The maximum number of connections to save in the pool. Improves performance - # in multithreaded situations. For now, we're setting it to the same value - # as connection_pool_size. - pool_maxsize = cfg.max_connections_per_pool - if cfg.max_connections_per_pool is None: - pool_maxsize = pool_connections - - # If pool_block is False, then more connections will are created, - # but not saved after the first use. Blocks when no free connections are available. - # urllib3 ensures that no more than pool_maxsize connections are used at a time. - # Prevents platform from flooding. By default, requests library doesn't block. 
- pool_block = True - - # We don't use `max_retries` from HTTPAdapter to align with a more production-ready - # retry strategy established in the Databricks SDK for Go. See _is_retryable and - # @retried for more details. - http_adapter = HTTPAdapter(pool_connections=pool_connections, - pool_maxsize=pool_maxsize, - pool_block=pool_block) - self._session.mount("https://", http_adapter) - - # Default to 60 seconds - self._http_timeout_seconds = cfg.http_timeout_seconds if cfg.http_timeout_seconds else 60 - - self._error_parser = _Parser(extra_error_customizers=[_AddDebugErrorCustomizer(cfg)]) + self._api_client = _BaseClient(debug_truncate_bytes=cfg.debug_truncate_bytes, + retry_timeout_seconds=cfg.retry_timeout_seconds, + user_agent_base=cfg.user_agent, + header_factory=cfg.authenticate, + max_connection_pools=cfg.max_connection_pools, + max_connections_per_pool=cfg.max_connections_per_pool, + pool_block=True, + http_timeout_seconds=cfg.http_timeout_seconds, + extra_error_customizers=[_AddDebugErrorCustomizer(cfg)], + clock=cfg.clock) @property def account_id(self) -> str: @@ -81,40 +42,6 @@ def account_id(self) -> str: def is_account_client(self) -> bool: return self._cfg.is_account_client - def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest: - headers = self._cfg.authenticate() - for k, v in headers.items(): - r.headers[k] = v - return r - - @staticmethod - def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]: - # Convert True -> "true" for Databricks APIs to understand booleans. - # See: https://github.com/databricks/databricks-sdk-py/issues/142 - if query is None: - return None - with_fixed_bools = {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()} - - # Query parameters may be nested, e.g. 
- # {'filter_by': {'user_ids': [123, 456]}} - # The HTTP-compatible representation of this is - # filter_by.user_ids=123&filter_by.user_ids=456 - # To achieve this, we convert the above dictionary to - # {'filter_by.user_ids': [123, 456]} - # See the following for more information: - # https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule - def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: - for k1, v1 in d.items(): - if isinstance(v1, dict): - v1 = dict(flatten_dict(v1)) - for k2, v2 in v1.items(): - yield f"{k1}.{k2}", v2 - else: - yield k1, v1 - - flattened = dict(flatten_dict(with_fixed_bools)) - return flattened - def get_oauth_token(self, auth_details: str) -> Token: if not self._cfg.auth_type: self._cfg.authenticate() @@ -142,115 +69,22 @@ def do(self, files=None, data=None, auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, - response_headers: List[str] = None) -> Union[dict, BinaryIO]: - if headers is None: - headers = {} + response_headers: List[str] = None) -> Union[dict, list, BinaryIO]: if url is None: # Remove extra `/` from path for Files API # Once we've fixed the OpenAPI spec, we can remove this path = re.sub('^/api/2.0/fs/files//', '/api/2.0/fs/files/', path) url = f"{self._cfg.host}{path}" - headers['User-Agent'] = self._user_agent_base - retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds), - is_retryable=self._is_retryable, - clock=self._cfg.clock) - response = retryable(self._perform)(method, - url, - query=query, - headers=headers, - body=body, - raw=raw, - files=files, - data=data, - auth=auth) - - resp = dict() - for header in response_headers if response_headers else []: - resp[header] = response.headers.get(Casing.to_header_case(header)) - if raw: - resp["contents"] = StreamingResponse(response) - return resp - if not len(response.content): - return resp - - jsonResponse = response.json() - if jsonResponse is None: - return resp - - if isinstance(jsonResponse, list): - return jsonResponse - - return {**resp, **jsonResponse} - - @staticmethod - def _is_retryable(err: BaseException) -> Optional[str]: - # this method is Databricks-specific port of urllib3 retries - # (see https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py) - # and Databricks SDK for Go retries - # (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go) - from urllib3.exceptions import ProxyError - if isinstance(err, ProxyError): - err = err.original_error - if isinstance(err, requests.ConnectionError): - # corresponds to `connection reset by peer` and `connection refused` errors from Go, - # which are generally related to the temporary glitches in the networking stack, - # also caused by endpoint protection software, like ZScaler, to drop connections while - # not yet authenticated. - # - # return a simple string for debug log readability, as `raise TimeoutError(...) from err` - # will bubble up the original exception in case we reach max retries. - return f'cannot connect' - if isinstance(err, requests.Timeout): - # corresponds to `TLS handshake timeout` and `i/o timeout` in Go. - # - # return a simple string for debug log readability, as `raise TimeoutError(...) from err` - # will bubble up the original exception in case we reach max retries. 
- return f'timeout' - if isinstance(err, DatabricksError): - message = str(err) - transient_error_string_matches = [ - "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException", - "does not have any associated worker environments", "There is no worker environment with id", - "Unknown worker environment", "ClusterNotReadyException", "Unexpected error", - "Please try again later or try a faster operation.", - "RPC token bucket limit has been exceeded", - ] - for substring in transient_error_string_matches: - if substring not in message: - continue - return f'matched {substring}' - return None - - def _perform(self, - method: str, - url: str, - query: dict = None, - headers: dict = None, - body: dict = None, - raw: bool = False, - files=None, - data=None, - auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None): - response = self._session.request(method, - url, - params=self._fix_query_string(query), - json=body, - headers=headers, - files=files, - data=data, - auth=auth, - stream=raw, - timeout=self._http_timeout_seconds) - self._record_request_log(response, raw=raw or data is not None or files is not None) - error = self._error_parser.get_api_error(response) - if error is not None: - raise error from None - return response - - def _record_request_log(self, response: requests.Response, raw: bool = False) -> None: - if not logger.isEnabledFor(logging.DEBUG): - return - logger.debug(RoundTrip(response, self._cfg.debug_headers, self._debug_truncate_bytes, raw).generate()) + return self._api_client.do(method=method, + url=url, + query=query, + headers=headers, + body=body, + raw=raw, + files=files, + data=data, + auth=auth, + response_headers=response_headers) class _AddDebugErrorCustomizer(_ErrorCustomizer): @@ -264,103 +98,3 @@ def customize_error(self, response: requests.Response, kwargs: dict): if response.status_code in (401, 403): message = kwargs.get('message', 'request failed') kwargs['message'] = self._cfg.wrap_debug_info(message) - - -class StreamingResponse(BinaryIO): - _response: requests.Response - _buffer: bytes - _content: Union[Iterator[bytes], None] - _chunk_size: Union[int, None] - _closed: bool = False - - def fileno(self) -> int: - pass - - def flush(self) -> int: - pass - - def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None): - self._response = response - self._buffer = b'' - self._content = None - self._chunk_size = chunk_size - - def _open(self) -> None: - if self._closed: - raise ValueError("I/O operation on closed file") - if not self._content: - self._content = self._response.iter_content(chunk_size=self._chunk_size) - - def __enter__(self) -> BinaryIO: - self._open() - return self - - def set_chunk_size(self, chunk_size: Union[int, None]) -> None: - self._chunk_size = chunk_size - - def close(self) -> None: - self._response.close() - self._closed = True - - def isatty(self) -> bool: - return False - - def read(self, n: int = -1) -> bytes: - self._open() - read_everything = n < 0 - remaining_bytes = n - res = b'' - while remaining_bytes > 0 or read_everything: - if len(self._buffer) == 0: - try: - self._buffer = next(self._content) - except StopIteration: - break - bytes_available = len(self._buffer) - to_read = bytes_available if read_everything else min(remaining_bytes, bytes_available) - res += self._buffer[:to_read] - self._buffer = self._buffer[to_read:] - remaining_bytes -= to_read - return res - - def readable(self) -> bool: - return self._content is not None - - def readline(self, __limit: int 
= ...) -> bytes: - raise NotImplementedError() - - def readlines(self, __hint: int = ...) -> List[bytes]: - raise NotImplementedError() - - def seek(self, __offset: int, __whence: int = ...) -> int: - raise NotImplementedError() - - def seekable(self) -> bool: - return False - - def tell(self) -> int: - raise NotImplementedError() - - def truncate(self, __size: Union[int, None] = ...) -> int: - raise NotImplementedError() - - def writable(self) -> bool: - return False - - def write(self, s: Union[bytes, bytearray]) -> int: - raise NotImplementedError() - - def writelines(self, lines: Iterable[bytes]) -> None: - raise NotImplementedError() - - def __next__(self) -> bytes: - return self.read(1) - - def __iter__(self) -> Iterator[bytes]: - return self._content - - def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None], - traceback: Union[TracebackType, None]) -> None: - self._content = None - self._buffer = b'' - self.close() diff --git a/tests/fixture_server.py b/tests/fixture_server.py new file mode 100644 index 00000000..04190414 --- /dev/null +++ b/tests/fixture_server.py @@ -0,0 +1,33 @@ +import contextlib +import functools +import typing +from http.server import BaseHTTPRequestHandler + + +@contextlib.contextmanager +def http_fixture_server(handler: typing.Callable[[BaseHTTPRequestHandler], None]): + from http.server import HTTPServer + from threading import Thread + + class _handler(BaseHTTPRequestHandler): + + def __init__(self, handler: typing.Callable[[BaseHTTPRequestHandler], None], *args): + self._handler = handler + super().__init__(*args) + + def __getattr__(self, item): + if 'do_' != item[0:3]: + raise AttributeError(f'method {item} not found') + return functools.partial(self._handler, self) + + handler_factory = functools.partial(_handler, handler) + srv = HTTPServer(('localhost', 0), handler_factory) + t = Thread(target=srv.serve_forever) + try: + t.daemon = True + t.start() + yield 'http://{0}:{1}'.format(*srv.server_address) + finally: + srv.shutdown() + + diff --git a/tests/test_base_client.py b/tests/test_base_client.py new file mode 100644 index 00000000..4cba10db --- /dev/null +++ b/tests/test_base_client.py @@ -0,0 +1,282 @@ +from http.server import BaseHTTPRequestHandler +from typing import List, Iterator + +import pytest +import requests + +from databricks.sdk._base_client import _BaseClient, _StreamingResponse +from databricks.sdk import errors, useragent +from databricks.sdk.core import DatabricksError + +from .clock import FakeClock +from .fixture_server import http_fixture_server + + +class DummyResponse(requests.Response): + _content: Iterator[bytes] + _closed: bool = False + + def __init__(self, content: List[bytes]) -> None: + super().__init__() + self._content = iter(content) + + def iter_content(self, chunk_size: int = 1, decode_unicode=False) -> Iterator[bytes]: + return self._content + + def close(self): + self._closed = True + + def isClosed(self): + return self._closed + + +def test_streaming_response_read(config): + content = b"some initial binary data: \x00\x01" + response = _StreamingResponse(DummyResponse([content])) + assert response.read() == content + + +def test_streaming_response_read_partial(config): + content = b"some initial binary data: \x00\x01" + response = _StreamingResponse(DummyResponse([content])) + assert response.read(8) == b"some ini" + + +def test_streaming_response_read_full(config): + content = b"some initial binary data: \x00\x01" + response = _StreamingResponse(DummyResponse([content, 
content])) + assert response.read() == content + content + + +def test_streaming_response_read_closes(config): + content = b"some initial binary data: \x00\x01" + dummy_response = DummyResponse([content]) + with _StreamingResponse(dummy_response) as response: + assert response.read() == content + assert dummy_response.isClosed() + + +@pytest.mark.parametrize('status_code,headers,body,expected_error', [ + (400, {}, { + "message": + "errorMessage", + "details": [{ + "type": DatabricksError._error_info_type, + "reason": "error reason", + "domain": "error domain", + "metadata": { + "etag": "error etag" + }, + }, { + "type": "wrong type", + "reason": "wrong reason", + "domain": "wrong domain", + "metadata": { + "etag": "wrong etag" + } + }], + }, + errors.BadRequest('errorMessage', + details=[{ + 'type': DatabricksError._error_info_type, + 'reason': 'error reason', + 'domain': 'error domain', + 'metadata': { + 'etag': 'error etag' + }, + }])), + (401, {}, { + 'error_code': 'UNAUTHORIZED', + 'message': 'errorMessage', + }, + errors.Unauthenticated('errorMessage', error_code='UNAUTHORIZED')), + (403, {}, { + 'error_code': 'FORBIDDEN', + 'message': 'errorMessage', + }, + errors.PermissionDenied('errorMessage', error_code='FORBIDDEN')), + (429, {}, { + 'error_code': 'TOO_MANY_REQUESTS', + 'message': 'errorMessage', + }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=1)), + (429, { + 'Retry-After': '100' + }, { + 'error_code': 'TOO_MANY_REQUESTS', + 'message': 'errorMessage', + }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)), + (503, {}, { + 'error_code': 'TEMPORARILY_UNAVAILABLE', + 'message': 'errorMessage', + }, errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', + retry_after_secs=1)), + (503, { + 'Retry-After': '100' + }, { + 'error_code': 'TEMPORARILY_UNAVAILABLE', + 'message': 'errorMessage', + }, + errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', + retry_after_secs=100)), + (404, {}, { + 'scimType': 'scim type', + 'detail': 'detail', + 'status': 'status', + }, errors.NotFound('scim type detail', error_code='SCIM_status')), +]) +def test_error(requests_mock, status_code, headers, body, expected_error): + client = _BaseClient(clock=FakeClock()) + requests_mock.get("/test", json=body, status_code=status_code, headers=headers) + with pytest.raises(DatabricksError) as raised: + client._perform("GET", "https://localhost/test", headers={"test": "test"}) + actual = raised.value + assert isinstance(actual, type(expected_error)) + assert str(actual) == str(expected_error) + assert actual.error_code == expected_error.error_code + assert actual.retry_after_secs == expected_error.retry_after_secs + expected_error_infos, actual_error_infos = expected_error.get_error_info(), actual.get_error_info() + assert len(expected_error_infos) == len(actual_error_infos) + for expected, actual in zip(expected_error_infos, actual_error_infos): + assert expected.type == actual.type + assert expected.reason == actual.reason + assert expected.domain == actual.domain + assert expected.metadata == actual.metadata + + +def test_api_client_do_custom_headers(requests_mock): + client = _BaseClient() + requests_mock.get("/test", + json={"well": "done"}, + request_headers={ + "test": "test", + "User-Agent": useragent.to_string() + }) + res = client.do("GET", "https://localhost/test", headers={"test": "test"}) + assert res == {"well": "done"} + + 
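+# A minimal sketch (hypothetical test, assuming the requests_mock fixture can serve the
+# mocked body as streamed bytes): with raw=True, _BaseClient.do() returns the response
+# body as a streaming object under the "contents" key, so callers can read it
+# incrementally instead of buffering the whole payload in memory.
+def test_api_client_do_raw_streaming_sketch(requests_mock):
+    client = _BaseClient(clock=FakeClock())
+    requests_mock.get("/test", content=b"some streamed bytes")
+    res = client.do("GET", "https://localhost/test", raw=True)
+    # The returned streaming object is a context manager that closes the HTTP response.
+    with res["contents"] as body:
+        assert body.read() == b"some streamed bytes"
+
+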
+@pytest.mark.parametrize('status_code,include_retry_after', + ((429, False), (429, True), (503, False), (503, True))) +def test_http_retry_after(status_code, include_retry_after): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + if len(requests) == 0: + h.send_response(status_code) + if include_retry_after: + h.send_header('Retry-After', '1') + h.send_header('Content-Type', 'application/json') + h.end_headers() + else: + h.send_response(200) + h.send_header('Content-Type', 'application/json') + h.end_headers() + h.wfile.write(b'{"foo": 1}') + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(clock=FakeClock()) + res = api_client.do('GET', f'{host}/foo') + assert 'foo' in res + + assert len(requests) == 2 + + +def test_http_retry_after_wrong_format(): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + if len(requests) == 0: + h.send_response(429) + h.send_header('Retry-After', '1.58') + h.end_headers() + else: + h.send_response(200) + h.send_header('Content-Type', 'application/json') + h.end_headers() + h.wfile.write(b'{"foo": 1}') + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(clock=FakeClock()) + res = api_client.do('GET', f'{host}/foo') + assert 'foo' in res + + assert len(requests) == 2 + + +def test_http_retried_exceed_limit(): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + h.send_response(429) + h.send_header('Retry-After', '1') + h.end_headers() + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(retry_timeout_seconds=1, clock=FakeClock()) + with pytest.raises(TimeoutError): + res = api_client.do('GET', f'{host}/foo') + + assert len(requests) == 1 + + +def test_http_retried_on_match(): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + if len(requests) == 0: + h.send_response(400) + h.end_headers() + h.wfile.write(b'{"error_code": "abc", "message": "... 
ClusterNotReadyException ..."}') + else: + h.send_response(200) + h.end_headers() + h.wfile.write(b'{"foo": 1}') + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(clock=FakeClock()) + res = api_client.do('GET', f'{host}/foo') + assert 'foo' in res + + assert len(requests) == 2 + + +def test_http_not_retried_on_normal_errors(): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + if len(requests) == 0: + h.send_response(400) + h.end_headers() + h.wfile.write(b'{"error_code": "abc", "message": "something not found"}') + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(clock=FakeClock()) + with pytest.raises(DatabricksError): + api_client.do('GET', f'{host}/foo') + + assert len(requests) == 1 + + +def test_http_retried_on_connection_error(): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + if len(requests) > 0: + h.send_response(200) + h.end_headers() + h.wfile.write(b'{"foo": 1}') + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(clock=FakeClock()) + res = api_client.do('GET', f'{host}/foo') + assert 'foo' in res + + assert len(requests) == 2 + + diff --git a/tests/test_core.py b/tests/test_core.py index d54563d4..b61cfa01 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -8,14 +8,11 @@ import typing from datetime import datetime from http.server import BaseHTTPRequestHandler -from typing import Iterator, List import pytest -import requests from databricks.sdk import WorkspaceClient, errors -from databricks.sdk.core import (ApiClient, Config, DatabricksError, - StreamingResponse) +from databricks.sdk.core import ApiClient, Config, DatabricksError from databricks.sdk.credentials_provider import (CliTokenSource, CredentialsProvider, CredentialsStrategy, @@ -28,7 +25,7 @@ from databricks.sdk.service.iam import AccessControlRequest from databricks.sdk.version import __version__ -from .clock import FakeClock +from .fixture_server import http_fixture_server from .conftest import noop_credentials @@ -80,32 +77,6 @@ def write_small_dummy_executable(path: pathlib.Path): return cli -def test_streaming_response_read(config): - content = b"some initial binary data: \x00\x01" - response = StreamingResponse(DummyResponse([content])) - assert response.read() == content - - -def test_streaming_response_read_partial(config): - content = b"some initial binary data: \x00\x01" - response = StreamingResponse(DummyResponse([content])) - assert response.read(8) == b"some ini" - - -def test_streaming_response_read_full(config): - content = b"some initial binary data: \x00\x01" - response = StreamingResponse(DummyResponse([content, content])) - assert response.read() == content + content - - -def test_streaming_response_read_closes(config): - content = b"some initial binary data: \x00\x01" - dummy_response = DummyResponse([content]) - with StreamingResponse(dummy_response) as response: - assert response.read() == content - assert dummy_response.isClosed() - - def write_large_dummy_executable(path: pathlib.Path): cli = path.joinpath('databricks') @@ -290,36 +261,6 @@ def test_config_parsing_non_string_env_vars(monkeypatch): assert c.debug_truncate_bytes == 100 -class DummyResponse(requests.Response): - _content: Iterator[bytes] - _closed: bool = False - - def __init__(self, content: List[bytes]) -> None: - super().__init__() - self._content = iter(content) - - def iter_content(self, chunk_size: int = 1, decode_unicode=False) -> 
Iterator[bytes]: - return self._content - - def close(self): - self._closed = True - - def isClosed(self): - return self._closed - - -def test_api_client_do_custom_headers(config, requests_mock): - client = ApiClient(config) - requests_mock.get("/test", - json={"well": "done"}, - request_headers={ - "test": "test", - "User-Agent": config.user_agent - }) - res = client.do("GET", "/test", headers={"test": "test"}) - assert res == {"well": "done"} - - def test_access_control_list(config, requests_mock): requests_mock.post("http://localhost/api/2.1/jobs/create", request_headers={"User-Agent": config.user_agent}) @@ -360,80 +301,22 @@ def test_deletes(config, requests_mock): @pytest.mark.parametrize('status_code,headers,body,expected_error', [ - (400, {}, { - "message": - "errorMessage", - "details": [{ - "type": DatabricksError._error_info_type, - "reason": "error reason", - "domain": "error domain", - "metadata": { - "etag": "error etag" - }, - }, { - "type": "wrong type", - "reason": "wrong reason", - "domain": "wrong domain", - "metadata": { - "etag": "wrong etag" - } - }], - }, - errors.BadRequest('errorMessage', - details=[{ - 'type': DatabricksError._error_info_type, - 'reason': 'error reason', - 'domain': 'error domain', - 'metadata': { - 'etag': 'error etag' - }, - }])), (401, {}, { 'error_code': 'UNAUTHORIZED', 'message': 'errorMessage', }, - errors.Unauthenticated('errorMessage. Config: host=http://localhost, auth_type=noop', - error_code='UNAUTHORIZED')), + errors.Unauthenticated('errorMessage. Config: host=http://localhost, auth_type=noop', error_code='UNAUTHORIZED')), (403, {}, { 'error_code': 'FORBIDDEN', 'message': 'errorMessage', }, - errors.PermissionDenied('errorMessage. Config: host=http://localhost, auth_type=noop', - error_code='FORBIDDEN')), - (429, {}, { - 'error_code': 'TOO_MANY_REQUESTS', - 'message': 'errorMessage', - }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=1)), - (429, { - 'Retry-After': '100' - }, { - 'error_code': 'TOO_MANY_REQUESTS', - 'message': 'errorMessage', - }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)), - (503, {}, { - 'error_code': 'TEMPORARILY_UNAVAILABLE', - 'message': 'errorMessage', - }, errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', - retry_after_secs=1)), - (503, { - 'Retry-After': '100' - }, { - 'error_code': 'TEMPORARILY_UNAVAILABLE', - 'message': 'errorMessage', - }, - errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', - retry_after_secs=100)), - (404, {}, { - 'scimType': 'scim type', - 'detail': 'detail', - 'status': 'status', - }, errors.NotFound('scim type detail', error_code='SCIM_status')), + errors.PermissionDenied('errorMessage. 
Config: host=http://localhost, auth_type=noop', error_code='FORBIDDEN')), ]) def test_error(config, requests_mock, status_code, headers, body, expected_error): client = ApiClient(config) requests_mock.get("/test", json=body, status_code=status_code, headers=headers) with pytest.raises(DatabricksError) as raised: - client._perform("GET", "http://localhost/test", headers={"test": "test"}) + client.do("GET", "/test", headers={"test": "test"}) actual = raised.value assert isinstance(actual, type(expected_error)) assert str(actual) == str(expected_error) @@ -448,158 +331,6 @@ def test_error(config, requests_mock, status_code, headers, body, expected_error assert expected.metadata == actual.metadata -@contextlib.contextmanager -def http_fixture_server(handler: typing.Callable[[BaseHTTPRequestHandler], None]): - from http.server import HTTPServer - from threading import Thread - - class _handler(BaseHTTPRequestHandler): - - def __init__(self, handler: typing.Callable[[BaseHTTPRequestHandler], None], *args): - self._handler = handler - super().__init__(*args) - - def __getattr__(self, item): - if 'do_' != item[0:3]: - raise AttributeError(f'method {item} not found') - return functools.partial(self._handler, self) - - handler_factory = functools.partial(_handler, handler) - srv = HTTPServer(('localhost', 0), handler_factory) - t = Thread(target=srv.serve_forever) - try: - t.daemon = True - t.start() - yield 'http://{0}:{1}'.format(*srv.server_address) - finally: - srv.shutdown() - - -@pytest.mark.parametrize('status_code,include_retry_after', - ((429, False), (429, True), (503, False), (503, True))) -def test_http_retry_after(status_code, include_retry_after): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - if len(requests) == 0: - h.send_response(status_code) - if include_retry_after: - h.send_header('Retry-After', '1') - h.send_header('Content-Type', 'application/json') - h.end_headers() - else: - h.send_response(200) - h.send_header('Content-Type', 'application/json') - h.end_headers() - h.wfile.write(b'{"foo": 1}') - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) - res = api_client.do('GET', '/foo') - assert 'foo' in res - - assert len(requests) == 2 - - -def test_http_retry_after_wrong_format(): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - if len(requests) == 0: - h.send_response(429) - h.send_header('Retry-After', '1.58') - h.end_headers() - else: - h.send_response(200) - h.send_header('Content-Type', 'application/json') - h.end_headers() - h.wfile.write(b'{"foo": 1}') - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) - res = api_client.do('GET', '/foo') - assert 'foo' in res - - assert len(requests) == 2 - - -def test_http_retried_exceed_limit(): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - h.send_response(429) - h.send_header('Retry-After', '1') - h.end_headers() - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', retry_timeout_seconds=1, clock=FakeClock())) - with pytest.raises(TimeoutError): - api_client.do('GET', '/foo') - - assert len(requests) == 1 - - -def test_http_retried_on_match(): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - if len(requests) == 0: - h.send_response(400) - h.end_headers() - h.wfile.write(b'{"error_code": "abc", "message": "... 
ClusterNotReadyException ..."}') - else: - h.send_response(200) - h.end_headers() - h.wfile.write(b'{"foo": 1}') - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) - res = api_client.do('GET', '/foo') - assert 'foo' in res - - assert len(requests) == 2 - - -def test_http_not_retried_on_normal_errors(): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - if len(requests) == 0: - h.send_response(400) - h.end_headers() - h.wfile.write(b'{"error_code": "abc", "message": "something not found"}') - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) - with pytest.raises(DatabricksError): - api_client.do('GET', '/foo') - - assert len(requests) == 1 - - -def test_http_retried_on_connection_error(): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - if len(requests) > 0: - h.send_response(200) - h.end_headers() - h.wfile.write(b'{"foo": 1}') - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) - res = api_client.do('GET', '/foo') - assert 'foo' in res - - assert len(requests) == 2 - - def test_github_oidc_flow_works_with_azure(monkeypatch): def inner(h: BaseHTTPRequestHandler): From 384834a77e30213663eb800a47eb94125bf4548f Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 7 Oct 2024 13:48:35 +0200 Subject: [PATCH 02/17] rebase --- databricks/sdk/_base_client.py | 19 +++ databricks/sdk/config.py | 49 ++---- databricks/sdk/credentials_provider.py | 15 +- databricks/sdk/oauth.py | 213 ++++++++++++++++++------- examples/flask_app_with_oauth.py | 46 +++--- tests/test_oauth.py | 131 +++++++++++---- 6 files changed, 318 insertions(+), 155 deletions(-) diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py index 47d6340e..7734824b 100644 --- a/databricks/sdk/_base_client.py +++ b/databricks/sdk/_base_client.py @@ -18,6 +18,25 @@ logger = logging.getLogger('databricks.sdk') +def fix_host_if_needed(host: Optional[str]) -> Optional[str]: + if not host: + return host + + # Add a default scheme if it's missing + if '://' not in host: + host = 'https://' + host + + o = urllib.parse.urlparse(host) + # remove trailing slash + path = o.path.rstrip('/') + # remove port if 443 + netloc = o.netloc + if o.port == 443: + netloc = netloc.split(':')[0] + + return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) + + class _BaseClient: def __init__(self, diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 5cae1b2b..65bf3225 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -14,7 +14,10 @@ from .credentials_provider import CredentialsStrategy, DefaultCredentials from .environments import (ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname) -from .oauth import OidcEndpoints, Token +from .oauth import (OidcEndpoints, Token, get_account_endpoints, + get_azure_entra_id_workspace_endpoints, + get_workspace_endpoints) +from ._base_client import fix_host_if_needed logger = logging.getLogger('databricks.sdk') @@ -118,7 +121,9 @@ def __init__(self, self._set_inner_config(kwargs) self._load_from_env() self._known_file_config_loader() - self._fix_host_if_needed() + updated_host = fix_host_if_needed(self.host) + if updated_host: + self.host = updated_host self._validate() self.init_auth() 
self._init_product(product, product_version) @@ -250,28 +255,14 @@ def with_user_agent_extra(self, key: str, value: str) -> 'Config': @property def oidc_endpoints(self) -> Optional[OidcEndpoints]: - self._fix_host_if_needed() + self.host = fix_host_if_needed(self.host) if not self.host: return None if self.is_azure and self.azure_client_id: - # Retrieve authorize endpoint to retrieve token endpoint after - res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) - real_auth_url = res.headers.get('location') - if not real_auth_url: - return None - return OidcEndpoints(authorization_endpoint=real_auth_url, - token_endpoint=real_auth_url.replace('/authorize', '/token')) + return get_azure_entra_id_workspace_endpoints(self.host) if self.is_account_client and self.account_id: - prefix = f'{self.host}/oidc/accounts/{self.account_id}' - return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize', - token_endpoint=f'{prefix}/v1/token') - oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server' - res = requests.get(oidc) - if res.status_code != 200: - return None - auth_metadata = res.json() - return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'), - token_endpoint=auth_metadata.get('token_endpoint')) + return get_account_endpoints(self.host, self.account_id) + return get_workspace_endpoints(self.host) def debug_string(self) -> str: """ Returns log-friendly representation of configured attributes """ @@ -345,24 +336,6 @@ def attributes(cls) -> Iterable[ConfigAttribute]: cls._attributes = attrs return cls._attributes - def _fix_host_if_needed(self): - if not self.host: - return - - # Add a default scheme if it's missing - if '://' not in self.host: - self.host = 'https://' + self.host - - o = urllib.parse.urlparse(self.host) - # remove trailing slash - path = o.path.rstrip('/') - # remove port if 443 - netloc = o.netloc - if o.port == 443: - netloc = netloc.split(':')[0] - - self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) - def load_azure_tenant_id(self): """[Internal] Load the Azure tenant ID from the Azure Databricks login page. diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 8c1655af..ef4f48ad 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -197,19 +197,24 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: client_id = '6128a518-99a9-425b-8333-4cc94f04cacd' else: raise ValueError(f'local browser SSO is not supported') - oauth_client = OAuthClient(host=cfg.host, - client_id=client_id, - redirect_url='http://localhost:8020', - client_secret=cfg.client_secret) # Load cached credentials from disk if they exist. # Note that these are local to the Python SDK and not reused by other SDKs. - token_cache = TokenCache(oauth_client) + oidc_endpoints = cfg.oidc_endpoints + token_cache = TokenCache(host=cfg.host, + oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=cfg.client_secret, + redirect_url='http://localhost:8020') credentials = token_cache.load() if credentials: # Force a refresh in case the loaded credentials are expired. 
credentials.token() else: + oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, + client_id=client_id, + redirect_url='http://localhost:8020', + client_secret=cfg.client_secret) consent = oauth_client.initiate_consent() if not consent: return None diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index e9a3afb9..c4279dea 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -17,6 +17,8 @@ import requests import requests.auth +from ._base_client import _BaseClient, fix_host_if_needed + # Error code for PKCE flow in Azure Active Directory, that gets additional retry. # See https://stackoverflow.com/a/75466778/277035 for more info NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327' @@ -46,8 +48,24 @@ def __call__(self, r): @dataclass class OidcEndpoints: + """ + The endpoints used for OAuth-based authentication in Databricks. + """ + authorization_endpoint: str # ../v1/authorize + """The authorization endpoint for the OAuth flow. The user-agent should be directed to this endpoint in order for + the user to login and authorize the client for user-to-machine (U2M) flows.""" + token_endpoint: str # ../v1/token + """The token endpoint for the OAuth flow.""" + + @staticmethod + def from_dict(d: dict) -> 'OidcEndpoints': + return OidcEndpoints(authorization_endpoint=d.get('authorization_endpoint'), + token_endpoint=d.get('token_endpoint')) + + def as_dict(self) -> dict: + return {'authorization_endpoint': self.authorization_endpoint, 'token_endpoint': self.token_endpoint} @dataclass @@ -220,18 +238,76 @@ def do_GET(self): self.wfile.write(b'You can close this tab.') +def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a given account. + :param host: The Databricks account host. + :param account_id: The account ID. + :return: The account's OIDC endpoints. + """ + host = fix_host_if_needed(host) + oidc = f'{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server' + resp = client.do('GET', oidc) + return OidcEndpoints.from_dict(resp) + + +def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a given workspace. + :param host: The Databricks workspace host. + :return: The workspace's OIDC endpoints. + """ + host = fix_host_if_needed(host) + oidc = f'{host}/.well-known/oauth-authorization-server' + resp = client.do('GET', oidc) + return OidcEndpoints.from_dict(resp) + + +def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]: + """ + Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks + using an application registered in Azure Entra ID. + :param host: The Databricks workspace host. + :return: The OIDC endpoints for the workspace's Azure Entra ID tenant. 
+ """ + # In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint + host = fix_host_if_needed(host) + res = requests.get(f'{host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) + real_auth_url = res.headers.get('location') + if not real_auth_url: + return None + return OidcEndpoints(authorization_endpoint=real_auth_url, + token_endpoint=real_auth_url.replace('/authorize', '/token')) + + class SessionCredentials(Refreshable): - def __init__(self, client: 'OAuthClient', token: Token): - self._client = client + def __init__(self, + token: Token, + oidc_endpoints: OidcEndpoints, + client_id: str, + client_secret: str = None, + redirect_url: str = None): + self._oidc_endpoints = oidc_endpoints + self._client_id = client_id + self._client_secret = client_secret + self._redirect_url = redirect_url super().__init__(token) def as_dict(self) -> dict: return {'token': self._token.as_dict()} @staticmethod - def from_dict(client: 'OAuthClient', raw: dict) -> 'SessionCredentials': - return SessionCredentials(client=client, token=Token.from_dict(raw['token'])) + def from_dict(raw: dict, + oidc_endpoints: OidcEndpoints, + client_id: str, + client_secret: str = None, + redirect_url: str = None) -> 'SessionCredentials': + return SessionCredentials(token=Token.from_dict(raw['token']), + oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) def auth_type(self): """Implementing CredentialsProvider protocol""" @@ -252,13 +328,13 @@ def refresh(self) -> Token: raise ValueError('oauth2: token expired and refresh token is not set') params = {'grant_type': 'refresh_token', 'refresh_token': refresh_token} headers = {} - if 'microsoft' in self._client.token_url: + if 'microsoft' in self._oidc_endpoints.token_endpoint: # Tokens issued for the 'Single-Page Application' client-type may # only be redeemed via cross-origin requests - headers = {'Origin': self._client.redirect_url} - return retrieve_token(client_id=self._client.client_id, - client_secret=self._client.client_secret, - token_url=self._client.token_url, + headers = {'Origin': self._redirect_url} + return retrieve_token(client_id=self._client_id, + client_secret=self._client_secret, + token_url=self._oidc_endpoints.token_endpoint, params=params, use_params=True, headers=headers) @@ -266,27 +342,45 @@ def refresh(self) -> Token: class Consent: - def __init__(self, client: 'OAuthClient', state: str, verifier: str, auth_url: str = None) -> None: - self.auth_url = auth_url - + def __init__(self, + state: str, + verifier: str, + oidc_endpoints: OidcEndpoints, + redirect_url: str, + client_id: str, + client_secret: str = None) -> None: self._verifier = verifier self._state = state - self._client = client + self._oidc_endpoints = oidc_endpoints + self._redirect_url = redirect_url + self._client_id = client_id + self._client_secret = client_secret def as_dict(self) -> dict: - return {'state': self._state, 'verifier': self._verifier} + return { + 'state': self._state, + 'verifier': self._verifier, + 'redirect_url': self._redirect_url, + 'oidc_endpoints': self._oidc_endpoints.as_dict(), + 'client_id': self._client_id, + } @staticmethod - def from_dict(client: 'OAuthClient', raw: dict) -> 'Consent': - return Consent(client, raw['state'], raw['verifier']) + def from_dict(raw: dict, client_secret: str = None) -> 'Consent': + return Consent(raw['state'], + raw['verifier'], + oidc_endpoints=OidcEndpoints.from_dict(raw['oidc_endpoints']), + redirect_url=raw['redirect_url'], + 
client_id=raw['client_id'], + client_secret=client_secret) def launch_external_browser(self) -> SessionCredentials: - redirect_url = urllib.parse.urlparse(self._client.redirect_url) + redirect_url = urllib.parse.urlparse(self._redirect_url) if redirect_url.hostname not in ('localhost', '127.0.0.1'): raise ValueError(f'cannot listen on {redirect_url.hostname}') feedback = [] - logger.info(f'Opening {self.auth_url} in a browser') - webbrowser.open_new(self.auth_url) + logger.info(f'Opening {self._oidc_endpoints.authorization_endpoint} in a browser') + webbrowser.open_new(self._oidc_endpoints.authorization_endpoint) port = redirect_url.port handler_factory = functools.partial(_OAuthCallback, feedback) with HTTPServer(("localhost", port), handler_factory) as httpd: @@ -308,7 +402,7 @@ def exchange(self, code: str, state: str) -> SessionCredentials: if self._state != state: raise ValueError('state mismatch') params = { - 'redirect_uri': self._client.redirect_url, + 'redirect_uri': self._redirect_url, 'grant_type': 'authorization_code', 'code_verifier': self._verifier, 'code': code @@ -316,19 +410,20 @@ def exchange(self, code: str, state: str) -> SessionCredentials: headers = {} while True: try: - token = retrieve_token(client_id=self._client.client_id, - client_secret=self._client.client_secret, - token_url=self._client.token_url, + token = retrieve_token(client_id=self._client_id, + client_secret=self._client_secret, + token_url=self._oidc_endpoints.token_endpoint, params=params, headers=headers, use_params=True) - return SessionCredentials(self._client, token) + return SessionCredentials(token, self._oidc_endpoints, self._client_id, self._client_secret, + self._redirect_url) except ValueError as e: if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e): # Retry in cases of 'Single-Page Application' client-type with # 'Origin' header equal to client's redirect URL. - headers['Origin'] = self._client.redirect_url - msg = f'Retrying OAuth token exchange with {self._client.redirect_url} origin' + headers['Origin'] = self._redirect_url + msg = f'Retrying OAuth token exchange with {self._redirect_url} origin' logger.debug(msg) continue raise e @@ -354,37 +449,19 @@ class OAuthClient: """ def __init__(self, - host: str, - client_id: str, + oidc_endpoints: OidcEndpoints, redirect_url: str, - *, + client_id: str, scopes: List[str] = None, client_secret: str = None): - # TODO: is it a circular dependency?.. 
- from .core import Config - from .credentials_provider import credentials_strategy - @credentials_strategy('noop', []) - def noop_credentials(_: any): - return lambda: {} - - config = Config(host=host, credentials_strategy=noop_credentials) if not scopes: scopes = ['all-apis'] - oidc = config.oidc_endpoints - if not oidc: - raise ValueError(f'{host} does not support OAuth') - self.host = host self.redirect_url = redirect_url - self.client_id = client_id - self.client_secret = client_secret - self.token_url = oidc.token_endpoint - self.is_aws = config.is_aws - self.is_azure = config.is_azure - self.is_gcp = config.is_gcp - - self._auth_url = oidc.authorization_endpoint + self._client_id = client_id + self._client_secret = client_secret + self._oidc_endpoints = oidc_endpoints self._scopes = scopes def initiate_consent(self) -> Consent: @@ -397,18 +474,23 @@ def initiate_consent(self) -> Consent: params = { 'response_type': 'code', - 'client_id': self.client_id, + 'client_id': self._client_id, 'redirect_uri': self.redirect_url, 'scope': ' '.join(self._scopes), 'state': state, 'code_challenge': challenge, 'code_challenge_method': 'S256' } - url = f'{self._auth_url}?{urllib.parse.urlencode(params)}' - return Consent(self, state, verifier, auth_url=url) + f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}' + return Consent(state, + verifier, + oidc_endpoints=self._oidc_endpoints, + redirect_url=self.redirect_url, + client_id=self._client_id, + client_secret=self._client_secret) def __repr__(self) -> str: - return f'' + return f'' @dataclass @@ -448,17 +530,28 @@ def refresh(self) -> Token: use_header=self.use_header) -class TokenCache(): +class TokenCache: BASE_PATH = "~/.config/databricks-sdk-py/oauth" - def __init__(self, client: OAuthClient) -> None: - self.client = client + def __init__(self, + host: str, + oidc_endpoints: OidcEndpoints, + client_id: str, + redirect_url: str = None, + client_secret: str = None, + scopes: list[str] = None) -> None: + self._host = host + self._client_id = client_id + self._oidc_endpoints = oidc_endpoints + self._redirect_url = redirect_url + self._client_secret = client_secret + self._scopes = scopes or [] @property def filename(self) -> str: # Include host, client_id, and scopes in the cache filename to make it unique. hash = hashlib.sha256() - for chunk in [self.client.host, self.client.client_id, ",".join(self.client._scopes), ]: + for chunk in [self._host, self._client_id, ",".join(self._scopes), ]: hash.update(chunk.encode('utf-8')) return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json")) @@ -472,7 +565,11 @@ def load(self) -> Optional[SessionCredentials]: try: with open(self.filename, 'r') as f: raw = json.load(f) - return SessionCredentials.from_dict(self.client, raw) + return SessionCredentials.from_dict(raw, + oidc_endpoints=self._oidc_endpoints, + client_id=self._client_id, + client_secret=self._client_secret, + redirect_url=self._redirect_url) except Exception: return None diff --git a/examples/flask_app_with_oauth.py b/examples/flask_app_with_oauth.py index 4128de5c..bd4d7d7e 100755 --- a/examples/flask_app_with_oauth.py +++ b/examples/flask_app_with_oauth.py @@ -31,7 +31,7 @@ import logging import sys -from databricks.sdk.oauth import OAuthClient +from databricks.sdk.oauth import OAuthClient, OidcEndpoints, get_workspace_endpoints APP_NAME = "flask-demo" all_clusters_template = """
@@ -44,7 +44,7 @@
""" -def create_flask_app(oauth_client: OAuthClient): +def create_flask_app(host: str, oidc_endpoints: OidcEndpoints, client_id: str, client_secret: str, redirect_url: str): """The create_flask_app function creates a Flask app that is enabled with OAuth. It initializes the app and web session secret keys with a randomly generated token. It defines two routes for @@ -64,7 +64,7 @@ def callback(): the callback parameters, and redirects the user to the index page.""" from databricks.sdk.oauth import Consent - consent = Consent.from_dict(oauth_client, session["consent"]) + consent = Consent.from_dict(session["consent"], client_secret=client_secret) session["creds"] = consent.exchange_callback_parameters(request.args).as_dict() return redirect(url_for("index")) @@ -73,17 +73,25 @@ def index(): """The index page checks if the user has already authenticated and retrieves the user's credentials using the Databricks SDK WorkspaceClient. It then renders the template with the clusters' list.""" if "creds" not in session: + oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) consent = oauth_client.initiate_consent() session["consent"] = consent.as_dict() - return redirect(consent.auth_url) + return redirect(oidc_endpoints.authorization_endpoint) from databricks.sdk import WorkspaceClient from databricks.sdk.oauth import SessionCredentials - credentials_provider = SessionCredentials.from_dict(oauth_client, session["creds"]) - workspace_client = WorkspaceClient(host=oauth_client.host, + credentials_strategy = SessionCredentials.from_dict(session["creds"], + oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) + workspace_client = WorkspaceClient(host=host, product=APP_NAME, - credentials_provider=credentials_provider, + credentials_strategy=credentials_strategy, ) return render_template_string(all_clusters_template, w=workspace_client) @@ -110,22 +118,6 @@ def register_custom_app(args: argparse.Namespace) -> tuple[str, str]: return custom_app.client_id, custom_app.client_secret -def init_oauth_config(args) -> OAuthClient: - """Creates Databricks SDK configuration for OAuth""" - oauth_client = OAuthClient(host=args.host, - client_id=args.client_id, - client_secret=args.client_secret, - redirect_url=f"http://localhost:{args.port}/callback", - scopes=["all-apis"], - ) - if not oauth_client.client_id: - client_id, client_secret = register_custom_app(args) - oauth_client.client_id = client_id - oauth_client.client_secret = client_secret - - return oauth_client - - def parse_arguments() -> argparse.Namespace: """Parses arguments for this demo""" parser = argparse.ArgumentParser(prog=APP_NAME, description=__doc__.strip()) @@ -145,8 +137,12 @@ def parse_arguments() -> argparse.Namespace: logging.getLogger("databricks.sdk").setLevel(logging.DEBUG) args = parse_arguments() - oauth_cfg = init_oauth_config(args) - app = create_flask_app(oauth_cfg) + oidc_endpoints = get_workspace_endpoints(args.host) + client_id, client_secret = args.client_id, args.client_secret + if not client_id: + client_id, client_secret = register_custom_app(args) + redirect_url=f"http://localhost:{args.port}/callback" + app = create_flask_app(args.host, oidc_endpoints, client_id, client_secret, redirect_url) app.run( host="localhost", diff --git a/tests/test_oauth.py b/tests/test_oauth.py index ce2d514f..2f4ba223 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -1,29 +1,102 @@ 
-from databricks.sdk.core import Config -from databricks.sdk.oauth import OAuthClient, OidcEndpoints, TokenCache - - -def test_token_cache_unique_filename_by_host(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(client_id="abc", redirect_url="http://localhost:8020") - c1 = OAuthClient(host="http://localhost:", **common_args) - c2 = OAuthClient(host="https://bar.cloud.databricks.com", **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename - - -def test_token_cache_unique_filename_by_client_id(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(host="http://localhost:", redirect_url="http://localhost:8020") - c1 = OAuthClient(client_id="abc", **common_args) - c2 = OAuthClient(client_id="def", **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename - - -def test_token_cache_unique_filename_by_scopes(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(host="http://localhost:", client_id="abc", redirect_url="http://localhost:8020") - c1 = OAuthClient(scopes=["foo"], **common_args) - c2 = OAuthClient(scopes=["bar"], **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename +from databricks.sdk.oauth import OidcEndpoints, TokenCache, get_workspace_endpoints, get_azure_entra_id_workspace_endpoints, get_account_endpoints +from databricks.sdk._base_client import _BaseClient +from .clock import FakeClock + + + +def test_token_cache_unique_filename_by_host(): + common_args = dict( + client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(host="http://localhost:", **common_args).filename != TokenCache("https://bar.cloud.databricks.com", **common_args).filename + + +def test_token_cache_unique_filename_by_client_id(): + common_args = dict( + host="http://localhost:", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(client_id="abc", **common_args).filename != TokenCache(client_id="def", **common_args).filename + + +def test_token_cache_unique_filename_by_scopes(): + common_args = dict( + host="http://localhost:", + client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(scopes=["foo"], **common_args).filename != TokenCache(scopes=["bar"], **common_args).filename + + +def test_account_oidc_endpoints(requests_mock): + requests_mock.get("https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + json={"authorization_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token"}) + client = _BaseClient(clock=FakeClock()) + endpoints = get_account_endpoints("accounts.cloud.databricks.com", "abc-123", client=client) + assert endpoints == OidcEndpoints( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token") + +def test_account_oidc_endpoints_retry_on_429(requests_mock): + request_count = 0 + + def nth_request(n): + def 
observe_request(_request): + nonlocal request_count + is_match = request_count == n + if is_match: + request_count += 1 + return is_match + return observe_request + + requests_mock.get("https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + additional_matcher=nth_request(0), + status_code=429) + requests_mock.get("https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + additional_matcher=nth_request(1), + json={"authorization_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token"}) + client = _BaseClient(clock=FakeClock()) + endpoints = get_account_endpoints("accounts.cloud.databricks.com", "abc-123", client=client) + assert endpoints == OidcEndpoints( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token") + + +def test_workspace_oidc_endpoints(requests_mock): + requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", + json={"authorization_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token"}) + client = _BaseClient(clock=FakeClock()) + endpoints = get_workspace_endpoints("my-workspace.cloud.databricks.com", client=client) + assert endpoints == OidcEndpoints( + "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "https://my-workspace.cloud.databricks.com/oidc/oauth/token") + + +def test_workspace_oidc_endpoints_retry_on_429(requests_mock): + request_count = 0 + + def nth_request(n): + def observe_request(_request): + nonlocal request_count + is_match = request_count == n + if is_match: + request_count += 1 + return is_match + return observe_request + + requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", + additional_matcher=nth_request(0), + status_code=429) + requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", + additional_matcher=nth_request(1), + json={"authorization_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token"}) + client = _BaseClient(clock=FakeClock()) + endpoints = get_workspace_endpoints("my-workspace.cloud.databricks.com", client=client) + assert endpoints == OidcEndpoints( + "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "https://my-workspace.cloud.databricks.com/oidc/oauth/token") From 33c7df707f2b21fcc1aa9eba6d6b2eff59caddde Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 7 Oct 2024 13:50:33 +0200 Subject: [PATCH 03/17] fmt --- databricks/sdk/_base_client.py | 1 - databricks/sdk/core.py | 1 - tests/fixture_server.py | 2 -- tests/test_base_client.py | 28 ++++++++++++---------------- tests/test_core.py | 25 ++++++++++++------------- 5 files changed, 24 insertions(+), 33 deletions(-) diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py index 47d6340e..62c2974e 100644 --- a/databricks/sdk/_base_client.py +++ b/databricks/sdk/_base_client.py @@ -3,7 +3,6 @@ from types import TracebackType from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, Optional, Type, Union) -import urllib.parse import requests import 
requests.adapters diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index c9e49dc8..eab22cd7 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -18,7 +18,6 @@ OIDC_TOKEN_PATH = "/oidc/v1/token" - class ApiClient: def __init__(self, cfg: Config): diff --git a/tests/fixture_server.py b/tests/fixture_server.py index 04190414..e15f9cf2 100644 --- a/tests/fixture_server.py +++ b/tests/fixture_server.py @@ -29,5 +29,3 @@ def __getattr__(self, item): yield 'http://{0}:{1}'.format(*srv.server_address) finally: srv.shutdown() - - diff --git a/tests/test_base_client.py b/tests/test_base_client.py index 4cba10db..e9e7324a 100644 --- a/tests/test_base_client.py +++ b/tests/test_base_client.py @@ -1,11 +1,11 @@ from http.server import BaseHTTPRequestHandler -from typing import List, Iterator +from typing import Iterator, List import pytest import requests -from databricks.sdk._base_client import _BaseClient, _StreamingResponse from databricks.sdk import errors, useragent +from databricks.sdk._base_client import _BaseClient, _StreamingResponse from databricks.sdk.core import DatabricksError from .clock import FakeClock @@ -59,7 +59,7 @@ def test_streaming_response_read_closes(config): @pytest.mark.parametrize('status_code,headers,body,expected_error', [ (400, {}, { "message": - "errorMessage", + "errorMessage", "details": [{ "type": DatabricksError._error_info_type, "reason": "error reason", @@ -88,13 +88,11 @@ def test_streaming_response_read_closes(config): (401, {}, { 'error_code': 'UNAUTHORIZED', 'message': 'errorMessage', - }, - errors.Unauthenticated('errorMessage', error_code='UNAUTHORIZED')), + }, errors.Unauthenticated('errorMessage', error_code='UNAUTHORIZED')), (403, {}, { 'error_code': 'FORBIDDEN', 'message': 'errorMessage', - }, - errors.PermissionDenied('errorMessage', error_code='FORBIDDEN')), + }, errors.PermissionDenied('errorMessage', error_code='FORBIDDEN')), (429, {}, { 'error_code': 'TOO_MANY_REQUESTS', 'message': 'errorMessage', @@ -102,9 +100,9 @@ def test_streaming_response_read_closes(config): (429, { 'Retry-After': '100' }, { - 'error_code': 'TOO_MANY_REQUESTS', - 'message': 'errorMessage', - }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)), + 'error_code': 'TOO_MANY_REQUESTS', + 'message': 'errorMessage', + }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)), (503, {}, { 'error_code': 'TEMPORARILY_UNAVAILABLE', 'message': 'errorMessage', @@ -113,9 +111,9 @@ def test_streaming_response_read_closes(config): (503, { 'Retry-After': '100' }, { - 'error_code': 'TEMPORARILY_UNAVAILABLE', - 'message': 'errorMessage', - }, + 'error_code': 'TEMPORARILY_UNAVAILABLE', + 'message': 'errorMessage', + }, errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', retry_after_secs=100)), (404, {}, { @@ -217,7 +215,7 @@ def inner(h: BaseHTTPRequestHandler): with http_fixture_server(inner) as host: api_client = _BaseClient(retry_timeout_seconds=1, clock=FakeClock()) with pytest.raises(TimeoutError): - res = api_client.do('GET', f'{host}/foo') + api_client.do('GET', f'{host}/foo') assert len(requests) == 1 @@ -278,5 +276,3 @@ def inner(h: BaseHTTPRequestHandler): assert 'foo' in res assert len(requests) == 2 - - diff --git a/tests/test_core.py b/tests/test_core.py index b61cfa01..16a4c2ad 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,11 +1,8 @@ -import contextlib -import functools import os import pathlib import platform import random 
import string -import typing from datetime import datetime from http.server import BaseHTTPRequestHandler @@ -25,8 +22,8 @@ from databricks.sdk.service.iam import AccessControlRequest from databricks.sdk.version import __version__ -from .fixture_server import http_fixture_server from .conftest import noop_credentials +from .fixture_server import http_fixture_server def test_parse_dsn(): @@ -300,18 +297,20 @@ def test_deletes(config, requests_mock): assert res is None -@pytest.mark.parametrize('status_code,headers,body,expected_error', [ - (401, {}, { +@pytest.mark.parametrize( + 'status_code,headers,body,expected_error', + [(401, {}, { 'error_code': 'UNAUTHORIZED', 'message': 'errorMessage', }, - errors.Unauthenticated('errorMessage. Config: host=http://localhost, auth_type=noop', error_code='UNAUTHORIZED')), - (403, {}, { - 'error_code': 'FORBIDDEN', - 'message': 'errorMessage', - }, - errors.PermissionDenied('errorMessage. Config: host=http://localhost, auth_type=noop', error_code='FORBIDDEN')), -]) + errors.Unauthenticated('errorMessage. Config: host=http://localhost, auth_type=noop', + error_code='UNAUTHORIZED')), + (403, {}, { + 'error_code': 'FORBIDDEN', + 'message': 'errorMessage', + }, + errors.PermissionDenied('errorMessage. Config: host=http://localhost, auth_type=noop', + error_code='FORBIDDEN')), ]) def test_error(config, requests_mock, status_code, headers, body, expected_error): client = ApiClient(config) requests_mock.get("/test", json=body, status_code=status_code, headers=headers) From 9b778fa721444e301e3eaa0efb607bc42f1de351 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 7 Oct 2024 13:58:00 +0200 Subject: [PATCH 04/17] fmt and fix impor --- databricks/sdk/_base_client.py | 1 + databricks/sdk/config.py | 2 +- tests/test_oauth.py | 97 +++++++++++++++++++++------------- 3 files changed, 61 insertions(+), 39 deletions(-) diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py index 8ba1d6b6..631472f6 100644 --- a/databricks/sdk/_base_client.py +++ b/databricks/sdk/_base_client.py @@ -6,6 +6,7 @@ import requests import requests.adapters +import urllib.parse from . import useragent from .casing import Casing diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 65bf3225..24f3c711 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -10,6 +10,7 @@ import requests from . 
import useragent +from ._base_client import fix_host_if_needed from .clock import Clock, RealClock from .credentials_provider import CredentialsStrategy, DefaultCredentials from .environments import (ALL_ENVS, AzureEnvironment, Cloud, @@ -17,7 +18,6 @@ from .oauth import (OidcEndpoints, Token, get_account_endpoints, get_azure_entra_id_workspace_endpoints, get_workspace_endpoints) -from ._base_client import fix_host_if_needed logger = logging.getLogger('databricks.sdk') diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 2f4ba223..dc5d36e1 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -1,63 +1,78 @@ -from databricks.sdk.oauth import OidcEndpoints, TokenCache, get_workspace_endpoints, get_azure_entra_id_workspace_endpoints, get_account_endpoints from databricks.sdk._base_client import _BaseClient -from .clock import FakeClock +from databricks.sdk.oauth import (OidcEndpoints, TokenCache, + get_account_endpoints, + get_workspace_endpoints) +from .clock import FakeClock def test_token_cache_unique_filename_by_host(): - common_args = dict( - client_id="abc", - redirect_url="http://localhost:8020", - oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - assert TokenCache(host="http://localhost:", **common_args).filename != TokenCache("https://bar.cloud.databricks.com", **common_args).filename + common_args = dict(client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(host="http://localhost:", + **common_args).filename != TokenCache("https://bar.cloud.databricks.com", + **common_args).filename def test_token_cache_unique_filename_by_client_id(): - common_args = dict( - host="http://localhost:", - redirect_url="http://localhost:8020", - oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - assert TokenCache(client_id="abc", **common_args).filename != TokenCache(client_id="def", **common_args).filename + common_args = dict(host="http://localhost:", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(client_id="abc", **common_args).filename != TokenCache(client_id="def", + **common_args).filename def test_token_cache_unique_filename_by_scopes(): - common_args = dict( - host="http://localhost:", - client_id="abc", - redirect_url="http://localhost:8020", - oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - assert TokenCache(scopes=["foo"], **common_args).filename != TokenCache(scopes=["bar"], **common_args).filename + common_args = dict(host="http://localhost:", + client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(scopes=["foo"], **common_args).filename != TokenCache(scopes=["bar"], + **common_args).filename def test_account_oidc_endpoints(requests_mock): - requests_mock.get("https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", - json={"authorization_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", - "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token"}) + requests_mock.get( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + json={ + "authorization_endpoint": + 
"https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token" + }) client = _BaseClient(clock=FakeClock()) endpoints = get_account_endpoints("accounts.cloud.databricks.com", "abc-123", client=client) assert endpoints == OidcEndpoints( "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token") + def test_account_oidc_endpoints_retry_on_429(requests_mock): request_count = 0 def nth_request(n): + def observe_request(_request): nonlocal request_count is_match = request_count == n if is_match: request_count += 1 return is_match + return observe_request - requests_mock.get("https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", - additional_matcher=nth_request(0), - status_code=429) - requests_mock.get("https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", - additional_matcher=nth_request(1), - json={"authorization_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", - "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token"}) + requests_mock.get( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + additional_matcher=nth_request(0), + status_code=429) + requests_mock.get( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + additional_matcher=nth_request(1), + json={ + "authorization_endpoint": + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token" + }) client = _BaseClient(clock=FakeClock()) endpoints = get_account_endpoints("accounts.cloud.databricks.com", "abc-123", client=client) assert endpoints == OidcEndpoints( @@ -67,25 +82,29 @@ def observe_request(_request): def test_workspace_oidc_endpoints(requests_mock): requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", - json={"authorization_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", - "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token"}) + json={ + "authorization_endpoint": + "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token" + }) client = _BaseClient(clock=FakeClock()) endpoints = get_workspace_endpoints("my-workspace.cloud.databricks.com", client=client) - assert endpoints == OidcEndpoints( - "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", - "https://my-workspace.cloud.databricks.com/oidc/oauth/token") + assert endpoints == OidcEndpoints("https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "https://my-workspace.cloud.databricks.com/oidc/oauth/token") def test_workspace_oidc_endpoints_retry_on_429(requests_mock): request_count = 0 def nth_request(n): + def observe_request(_request): nonlocal request_count is_match = request_count == n if is_match: request_count += 1 return is_match + return observe_request requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", @@ -93,10 +112,12 @@ def observe_request(_request): status_code=429) 
requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", additional_matcher=nth_request(1), - json={"authorization_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", - "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token"}) + json={ + "authorization_endpoint": + "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token" + }) client = _BaseClient(clock=FakeClock()) endpoints = get_workspace_endpoints("my-workspace.cloud.databricks.com", client=client) - assert endpoints == OidcEndpoints( - "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", - "https://my-workspace.cloud.databricks.com/oidc/oauth/token") + assert endpoints == OidcEndpoints("https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "https://my-workspace.cloud.databricks.com/oidc/oauth/token") From 4970d7078dc1dded89c1dbe348da8d367d4a5581 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 7 Oct 2024 14:00:05 +0200 Subject: [PATCH 05/17] fix python 3.7 --- databricks/sdk/oauth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index c4279dea..f6413068 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -539,7 +539,7 @@ def __init__(self, client_id: str, redirect_url: str = None, client_secret: str = None, - scopes: list[str] = None) -> None: + scopes: List[str] = None) -> None: self._host = host self._client_id = client_id self._oidc_endpoints = oidc_endpoints From 89f4de902e50fb5cdf3cd7116f8df2d55869e026 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 7 Oct 2024 14:04:45 +0200 Subject: [PATCH 06/17] fmt --- databricks/sdk/_base_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py index 631472f6..16b64309 100644 --- a/databricks/sdk/_base_client.py +++ b/databricks/sdk/_base_client.py @@ -1,4 +1,5 @@ import logging +import urllib.parse from datetime import timedelta from types import TracebackType from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, @@ -6,7 +7,6 @@ import requests import requests.adapters -import urllib.parse from . 
import useragent from .casing import Casing From 62c4c81e9adcf90869d3617b4904fe2b1212671f Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 11 Oct 2024 10:39:47 +0200 Subject: [PATCH 07/17] some tests and an example --- databricks/sdk/credentials_provider.py | 19 ++++++------- databricks/sdk/oauth.py | 39 +++++++++++++++----------- examples/external_browser_auth.py | 17 +++++++++++ 3 files changed, 48 insertions(+), 27 deletions(-) create mode 100644 examples/external_browser_auth.py diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 410fd130..841bc77f 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -187,17 +187,16 @@ def token() -> Token: def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: if cfg.auth_type != 'external-browser': return None + client_id, client_secret = None, None if cfg.client_id: client_id = cfg.client_id - elif cfg.is_aws: + client_secret = cfg.client_secret + elif cfg.azure_client_id: + client_id = cfg.azure_client + client_secret = cfg.azure_client_secret + + if not client_id: client_id = 'databricks-cli' - elif cfg.is_azure: - # Use Azure AD app for cases when Azure CLI is not available on the machine. - # App has to be registered as Single-page multi-tenant to support PKCE - # TODO: temporary app ID, change it later. - client_id = '6128a518-99a9-425b-8333-4cc94f04cacd' - else: - raise ValueError(f'local browser SSO is not supported') # Load cached credentials from disk if they exist. # Note that these are local to the Python SDK and not reused by other SDKs. @@ -205,7 +204,7 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: token_cache = TokenCache(host=cfg.host, oidc_endpoints=oidc_endpoints, client_id=client_id, - client_secret=cfg.client_secret, + client_secret=client_secret, redirect_url='http://localhost:8020') credentials = token_cache.load() if credentials: @@ -215,7 +214,7 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, client_id=client_id, redirect_url='http://localhost:8020', - client_secret=cfg.client_secret) + client_secret=client_secret) consent = oauth_client.initiate_consent() if not consent: return None diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index f6413068..3a1a1fd2 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -258,7 +258,7 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O :return: The workspace's OIDC endpoints. 
""" host = fix_host_if_needed(host) - oidc = f'{host}/.well-known/oauth-authorization-server' + oidc = f'{host}/oidc/.well-known/oauth-authorization-server' resp = client.do('GET', oidc) return OidcEndpoints.from_dict(resp) @@ -284,11 +284,11 @@ class SessionCredentials(Refreshable): def __init__(self, token: Token, - oidc_endpoints: OidcEndpoints, + token_endpoint: str, client_id: str, client_secret: str = None, redirect_url: str = None): - self._oidc_endpoints = oidc_endpoints + self._token_endpoint = token_endpoint self._client_id = client_id self._client_secret = client_secret self._redirect_url = redirect_url @@ -299,12 +299,12 @@ def as_dict(self) -> dict: @staticmethod def from_dict(raw: dict, - oidc_endpoints: OidcEndpoints, + token_endpoint: str, client_id: str, client_secret: str = None, redirect_url: str = None) -> 'SessionCredentials': return SessionCredentials(token=Token.from_dict(raw['token']), - oidc_endpoints=oidc_endpoints, + token_endpoint=token_endpoint, client_id=client_id, client_secret=client_secret, redirect_url=redirect_url) @@ -328,13 +328,13 @@ def refresh(self) -> Token: raise ValueError('oauth2: token expired and refresh token is not set') params = {'grant_type': 'refresh_token', 'refresh_token': refresh_token} headers = {} - if 'microsoft' in self._oidc_endpoints.token_endpoint: + if 'microsoft' in self._token_endpoint: # Tokens issued for the 'Single-Page Application' client-type may # only be redeemed via cross-origin requests headers = {'Origin': self._redirect_url} return retrieve_token(client_id=self._client_id, client_secret=self._client_secret, - token_url=self._oidc_endpoints.token_endpoint, + token_url=self._token_endpoint, params=params, use_params=True, headers=headers) @@ -345,14 +345,16 @@ class Consent: def __init__(self, state: str, verifier: str, - oidc_endpoints: OidcEndpoints, + authorization_url: str, redirect_url: str, + token_endpoint: str, client_id: str, client_secret: str = None) -> None: self._verifier = verifier self._state = state - self._oidc_endpoints = oidc_endpoints + self._authorization_url = authorization_url self._redirect_url = redirect_url + self._token_endpoint = token_endpoint self._client_id = client_id self._client_secret = client_secret @@ -360,8 +362,9 @@ def as_dict(self) -> dict: return { 'state': self._state, 'verifier': self._verifier, + 'authorization_url': self._authorization_url, 'redirect_url': self._redirect_url, - 'oidc_endpoints': self._oidc_endpoints.as_dict(), + 'token_endpoint': self._token_endpoint, 'client_id': self._client_id, } @@ -369,8 +372,9 @@ def as_dict(self) -> dict: def from_dict(raw: dict, client_secret: str = None) -> 'Consent': return Consent(raw['state'], raw['verifier'], - oidc_endpoints=OidcEndpoints.from_dict(raw['oidc_endpoints']), + authorization_url=raw['authorization_url'], redirect_url=raw['redirect_url'], + token_endpoint=raw['token_endpoint'], client_id=raw['client_id'], client_secret=client_secret) @@ -379,8 +383,8 @@ def launch_external_browser(self) -> SessionCredentials: if redirect_url.hostname not in ('localhost', '127.0.0.1'): raise ValueError(f'cannot listen on {redirect_url.hostname}') feedback = [] - logger.info(f'Opening {self._oidc_endpoints.authorization_endpoint} in a browser') - webbrowser.open_new(self._oidc_endpoints.authorization_endpoint) + logger.info(f'Opening {self._authorization_url} in a browser') + webbrowser.open_new(self._authorization_url) port = redirect_url.port handler_factory = functools.partial(_OAuthCallback, feedback) with 
HTTPServer(("localhost", port), handler_factory) as httpd: @@ -412,11 +416,11 @@ def exchange(self, code: str, state: str) -> SessionCredentials: try: token = retrieve_token(client_id=self._client_id, client_secret=self._client_secret, - token_url=self._oidc_endpoints.token_endpoint, + token_url=self._token_endpoint, params=params, headers=headers, use_params=True) - return SessionCredentials(token, self._oidc_endpoints, self._client_id, self._client_secret, + return SessionCredentials(token, self._token_endpoint, self._client_id, self._client_secret, self._redirect_url) except ValueError as e: if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e): @@ -481,11 +485,12 @@ def initiate_consent(self) -> Consent: 'code_challenge': challenge, 'code_challenge_method': 'S256' } - f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}' + auth_url = f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}' return Consent(state, verifier, - oidc_endpoints=self._oidc_endpoints, + authorization_url=auth_url, redirect_url=self.redirect_url, + token_endpoint=self._oidc_endpoints.token_endpoint, client_id=self._client_id, client_secret=self._client_secret) diff --git a/examples/external_browser_auth.py b/examples/external_browser_auth.py new file mode 100644 index 00000000..21395933 --- /dev/null +++ b/examples/external_browser_auth.py @@ -0,0 +1,17 @@ +from databricks.sdk import WorkspaceClient +import logging + +logging.basicConfig(level=logging.DEBUG) + + +def run(): + w = WorkspaceClient( + host=input("Enter Databricks host: "), + auth_type="external-browser", + ) + me = w.current_user.me() + print(me) + + +if __name__ == "__main__": + run() From ff329e71eb3bd844f8a97b523ec64dcb303191d3 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 11 Oct 2024 11:12:43 +0200 Subject: [PATCH 08/17] tweaks and testing --- databricks/sdk/oauth.py | 4 ++++ examples/external_browser_auth.py | 19 ++++++++++++++++--- examples/flask_app_with_oauth.py | 26 +++++++++++++++----------- 3 files changed, 35 insertions(+), 14 deletions(-) diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 3a1a1fd2..c78cd339 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -368,6 +368,10 @@ def as_dict(self) -> dict: 'client_id': self._client_id, } + @property + def authorization_url(self) -> str: + return self._authorization_url + @staticmethod def from_dict(raw: dict, client_secret: str = None) -> 'Consent': return Consent(raw['state'], diff --git a/examples/external_browser_auth.py b/examples/external_browser_auth.py index 21395933..a7d6e5f4 100644 --- a/examples/external_browser_auth.py +++ b/examples/external_browser_auth.py @@ -1,12 +1,18 @@ from databricks.sdk import WorkspaceClient +import argparse import logging +from typing import Optional logging.basicConfig(level=logging.DEBUG) -def run(): +def run(host: str, client_id: Optional[str], client_secret: Optional[str], azure_client_id: Optional[str], azure_client_secret: Optional[str]): w = WorkspaceClient( - host=input("Enter Databricks host: "), + host=host, + client_id=client_id, + client_secret=client_secret, + azure_client_id=azure_client_id, + azure_client_secret=azure_client_secret, auth_type="external-browser", ) me = w.current_user.me() @@ -14,4 +20,11 @@ def run(): if __name__ == "__main__": - run() + parser = argparse.ArgumentParser() + parser.add_argument("--host", help="Databricks host", required=True) + parser.add_argument("--client_id", help="Databricks client_id", default=None) + 
parser.add_argument("--azure_client_id", help="Databricks azure_client_id", default=None) + parser.add_argument("--client_secret", help="Databricks client_secret", default=None) + parser.add_argument("--azure_client_secret", help="Databricks azure_client_secret", default=None) + namespace = parser.parse_args() + run(namespace.host, namespace.client_id, namespace.client_secret, namespace.azure_client_id, namespace.azure_client_secret) diff --git a/examples/flask_app_with_oauth.py b/examples/flask_app_with_oauth.py index bd4d7d7e..f4db23b2 100755 --- a/examples/flask_app_with_oauth.py +++ b/examples/flask_app_with_oauth.py @@ -32,19 +32,20 @@ import sys from databricks.sdk.oauth import OAuthClient, OidcEndpoints, get_workspace_endpoints +from databricks.sdk.service.compute import ListClustersFilterBy, State APP_NAME = "flask-demo" all_clusters_template = """""" -def create_flask_app(host: str, oidc_endpoints: OidcEndpoints, client_id: str, client_secret: str, redirect_url: str): +def create_flask_app(workspace_host: str, client_id: str, client_secret: str): """The create_flask_app function creates a Flask app that is enabled with OAuth. It initializes the app and web session secret keys with a randomly generated token. It defines two routes for @@ -72,6 +73,9 @@ def callback(): def index(): """The index page checks if the user has already authenticated and retrieves the user's credentials using the Databricks SDK WorkspaceClient. It then renders the template with the clusters' list.""" + oidc_endpoints = get_workspace_endpoints(workspace_host) + port = request.environ.get("SERVER_PORT") + redirect_url=f"http://localhost:{port}/callback" if "creds" not in session: oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, client_id=client_id, @@ -79,22 +83,24 @@ def index(): redirect_url=redirect_url) consent = oauth_client.initiate_consent() session["consent"] = consent.as_dict() - return redirect(oidc_endpoints.authorization_endpoint) + return redirect(consent.authorization_url) from databricks.sdk import WorkspaceClient from databricks.sdk.oauth import SessionCredentials credentials_strategy = SessionCredentials.from_dict(session["creds"], - oidc_endpoints=oidc_endpoints, + token_endpoint=oidc_endpoints.token_endpoint, client_id=client_id, client_secret=client_secret, redirect_url=redirect_url) - workspace_client = WorkspaceClient(host=host, + workspace_client = WorkspaceClient(host=workspace_host, product=APP_NAME, credentials_strategy=credentials_strategy, ) - - return render_template_string(all_clusters_template, w=workspace_client) + clusters = workspace_client.clusters.list( + filter_by=ListClustersFilterBy(cluster_states=[State.RUNNING, State.PENDING]) + ) + return render_template_string(all_clusters_template, workspace_host=workspace_host, clusters=clusters) return app @@ -137,12 +143,10 @@ def parse_arguments() -> argparse.Namespace: logging.getLogger("databricks.sdk").setLevel(logging.DEBUG) args = parse_arguments() - oidc_endpoints = get_workspace_endpoints(args.host) client_id, client_secret = args.client_id, args.client_secret if not client_id: client_id, client_secret = register_custom_app(args) - redirect_url=f"http://localhost:{args.port}/callback" - app = create_flask_app(args.host, oidc_endpoints, client_id, client_secret, redirect_url) + app = create_flask_app(args.host, client_id, client_secret) app.run( host="localhost", From d3504340983abb3245e634a38bd077d97802f512 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 11 Oct 2024 11:21:46 +0200 Subject: [PATCH 09/17] 
Better examples --- examples/external_browser_auth.py | 35 ++++++++++++++++++++++++++++++- examples/flask_app_with_oauth.py | 8 +++++-- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/examples/external_browser_auth.py b/examples/external_browser_auth.py index a7d6e5f4..0609fe85 100644 --- a/examples/external_browser_auth.py +++ b/examples/external_browser_auth.py @@ -19,6 +19,29 @@ def run(host: str, client_id: Optional[str], client_secret: Optional[str], azure print(me) +def register_custom_app() -> tuple[str, str]: + """Creates new Custom OAuth App in Databricks Account""" + logging.info("No OAuth custom app client/secret provided, creating new app") + + from databricks.sdk import AccountClient + + account_client = AccountClient() + + custom_app = account_client.custom_app_integration.create( + name="external-browser-demo", + redirect_urls=[ + f"http://localhost:8020", + ], + confidential=True, + scopes=["all-apis"], + ) + logging.info(f"Created new custom app: " + f"--client_id {custom_app.client_id} " + f"--client_secret {custom_app.client_secret}") + + return custom_app.client_id, custom_app.client_secret + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", help="Databricks host", required=True) @@ -26,5 +49,15 @@ def run(host: str, client_id: Optional[str], client_secret: Optional[str], azure parser.add_argument("--azure_client_id", help="Databricks azure_client_id", default=None) parser.add_argument("--client_secret", help="Databricks client_secret", default=None) parser.add_argument("--azure_client_secret", help="Databricks azure_client_secret", default=None) + parser.add_argument("--register-custom-app", action="store_true", help="Register a new custom app") namespace = parser.parse_args() - run(namespace.host, namespace.client_id, namespace.client_secret, namespace.azure_client_id, namespace.azure_client_secret) + if namespace.register_custom_app and (namespace.client_id is not None or namespace.azure_client_id is not None): + raise ValueError("Cannot register custom app and provide --client_id/--azure_client_id at the same time") + if not namespace.register_custom_app and namespace.client_id is None and namespace.azure_client_secret is None: + raise ValueError("Must provide --client_id/--azure_client_id or register a custom app") + if namespace.register_custom_app: + client_id, client_secret = register_custom_app() + else: + client_id, client_secret = namespace.client_id, namespace.client_secret + + run(namespace.host, client_id, client_secret, namespace.azure_client_id, namespace.azure_client_secret) diff --git a/examples/flask_app_with_oauth.py b/examples/flask_app_with_oauth.py index f4db23b2..7c18eadc 100755 --- a/examples/flask_app_with_oauth.py +++ b/examples/flask_app_with_oauth.py @@ -31,7 +31,7 @@ import logging import sys -from databricks.sdk.oauth import OAuthClient, OidcEndpoints, get_workspace_endpoints +from databricks.sdk.oauth import OAuthClient, get_workspace_endpoints from databricks.sdk.service.compute import ListClustersFilterBy, State APP_NAME = "flask-demo" @@ -114,7 +114,11 @@ def register_custom_app(args: argparse.Namespace) -> tuple[str, str]: account_client = AccountClient(profile=args.profile) custom_app = account_client.custom_app_integration.create( - name=APP_NAME, redirect_urls=[f"http://localhost:{args.port}/callback"], confidential=True, + name=APP_NAME, + redirect_urls=[ + f"http://localhost:{args.port}/callback", + ], + confidential=True, scopes=["all-apis"], ) logging.info(f"Created new 
custom app: " From 05d69b95afca99d64ab247513176ef283624a042 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 11 Oct 2024 11:27:22 +0200 Subject: [PATCH 10/17] test confidential and non confidential apps --- examples/external_browser_auth.py | 47 ++++++++++++++++++------------- 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/examples/external_browser_auth.py b/examples/external_browser_auth.py index 0609fe85..061ff60c 100644 --- a/examples/external_browser_auth.py +++ b/examples/external_browser_auth.py @@ -1,25 +1,11 @@ from databricks.sdk import WorkspaceClient import argparse import logging -from typing import Optional logging.basicConfig(level=logging.DEBUG) -def run(host: str, client_id: Optional[str], client_secret: Optional[str], azure_client_id: Optional[str], azure_client_secret: Optional[str]): - w = WorkspaceClient( - host=host, - client_id=client_id, - client_secret=client_secret, - azure_client_id=azure_client_id, - azure_client_secret=azure_client_secret, - auth_type="external-browser", - ) - me = w.current_user.me() - print(me) - - -def register_custom_app() -> tuple[str, str]: +def register_custom_app(confidential: bool) -> tuple[str, str]: """Creates new Custom OAuth App in Databricks Account""" logging.info("No OAuth custom app client/secret provided, creating new app") @@ -32,16 +18,24 @@ def register_custom_app() -> tuple[str, str]: redirect_urls=[ f"http://localhost:8020", ], - confidential=True, + confidential=confidential, scopes=["all-apis"], ) logging.info(f"Created new custom app: " f"--client_id {custom_app.client_id} " - f"--client_secret {custom_app.client_secret}") + f"{'--client_secret ' + custom_app.client_secret if confidential else ''}") return custom_app.client_id, custom_app.client_secret +def delete_custom_app(client_id: str): + """Creates new Custom OAuth App in Databricks Account""" + logging.info(f"Deleting custom app {client_id}") + from databricks.sdk import AccountClient + account_client = AccountClient() + account_client.custom_app_integration.delete(client_id) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--host", help="Databricks host", required=True) @@ -50,14 +44,29 @@ def register_custom_app() -> tuple[str, str]: parser.add_argument("--client_secret", help="Databricks client_secret", default=None) parser.add_argument("--azure_client_secret", help="Databricks azure_client_secret", default=None) parser.add_argument("--register-custom-app", action="store_true", help="Register a new custom app") + parser.add_argument("--register-custom-app-confidential", action="store_true", help="Register a new custom app") namespace = parser.parse_args() if namespace.register_custom_app and (namespace.client_id is not None or namespace.azure_client_id is not None): raise ValueError("Cannot register custom app and provide --client_id/--azure_client_id at the same time") if not namespace.register_custom_app and namespace.client_id is None and namespace.azure_client_secret is None: raise ValueError("Must provide --client_id/--azure_client_id or register a custom app") if namespace.register_custom_app: - client_id, client_secret = register_custom_app() + client_id, client_secret = register_custom_app(namespace.register_custom_app_confidential) else: client_id, client_secret = namespace.client_id, namespace.client_secret - run(namespace.host, client_id, client_secret, namespace.azure_client_id, namespace.azure_client_secret) + w = WorkspaceClient( + host=namespace.host, + client_id=client_id, + 
client_secret=client_secret, + azure_client_id=namespace.azure_client_id, + azure_client_secret=namespace.azure_client_secret, + auth_type="external-browser", + ) + me = w.current_user.me() + print(me) + + if namespace.register_custom_app: + delete_custom_app(client_id) + + From a3879bf12f714b5cb41854a06cdcf7e2d5586f53 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 11 Oct 2024 11:28:46 +0200 Subject: [PATCH 11/17] private --- databricks/sdk/_base_client.py | 2 +- databricks/sdk/config.py | 4 ++-- databricks/sdk/oauth.py | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py index 16b64309..95ce39cb 100644 --- a/databricks/sdk/_base_client.py +++ b/databricks/sdk/_base_client.py @@ -18,7 +18,7 @@ logger = logging.getLogger('databricks.sdk') -def fix_host_if_needed(host: Optional[str]) -> Optional[str]: +def _fix_host_if_needed(host: Optional[str]) -> Optional[str]: if not host: return host diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 24f3c711..125f81fd 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -10,7 +10,7 @@ import requests from . import useragent -from ._base_client import fix_host_if_needed +from ._base_client import _fix_host_if_needed from .clock import Clock, RealClock from .credentials_provider import CredentialsStrategy, DefaultCredentials from .environments import (ALL_ENVS, AzureEnvironment, Cloud, @@ -121,7 +121,7 @@ def __init__(self, self._set_inner_config(kwargs) self._load_from_env() self._known_file_config_loader() - updated_host = fix_host_if_needed(self.host) + updated_host = _fix_host_if_needed(self.host) if updated_host: self.host = updated_host self._validate() diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index c78cd339..904a7fcf 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -17,7 +17,7 @@ import requests import requests.auth -from ._base_client import _BaseClient, fix_host_if_needed +from ._base_client import _BaseClient, _fix_host_if_needed # Error code for PKCE flow in Azure Active Directory, that gets additional retry. # See https://stackoverflow.com/a/75466778/277035 for more info @@ -245,7 +245,7 @@ def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _Bas :param account_id: The account ID. :return: The account's OIDC endpoints. """ - host = fix_host_if_needed(host) + host = _fix_host_if_needed(host) oidc = f'{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server' resp = client.do('GET', oidc) return OidcEndpoints.from_dict(resp) @@ -257,7 +257,7 @@ def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> O :param host: The Databricks workspace host. :return: The workspace's OIDC endpoints. """ - host = fix_host_if_needed(host) + host = _fix_host_if_needed(host) oidc = f'{host}/oidc/.well-known/oauth-authorization-server' resp = client.do('GET', oidc) return OidcEndpoints.from_dict(resp) @@ -271,7 +271,7 @@ def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints] :return: The OIDC endpoints for the workspace's Azure Entra ID tenant. 
""" # In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint - host = fix_host_if_needed(host) + host = _fix_host_if_needed(host) res = requests.get(f'{host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) real_auth_url = res.headers.get('location') if not real_auth_url: From 9056d7343f7d9e4cf1098753bd980a9e09b70469 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 11 Oct 2024 11:30:21 +0200 Subject: [PATCH 12/17] fix --- databricks/sdk/config.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 125f81fd..b4efdf60 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -121,9 +121,7 @@ def __init__(self, self._set_inner_config(kwargs) self._load_from_env() self._known_file_config_loader() - updated_host = _fix_host_if_needed(self.host) - if updated_host: - self.host = updated_host + self._fix_host_if_needed() self._validate() self.init_auth() self._init_product(product, product_version) @@ -255,7 +253,7 @@ def with_user_agent_extra(self, key: str, value: str) -> 'Config': @property def oidc_endpoints(self) -> Optional[OidcEndpoints]: - self.host = fix_host_if_needed(self.host) + self._fix_host_if_needed() if not self.host: return None if self.is_azure and self.azure_client_id: @@ -336,6 +334,11 @@ def attributes(cls) -> Iterable[ConfigAttribute]: cls._attributes = attrs return cls._attributes + def _fix_host_if_needed(self): + updated_host = _fix_host_if_needed(self.host) + if updated_host: + self.host = updated_host + def load_azure_tenant_id(self): """[Internal] Load the Azure tenant ID from the Azure Databricks login page. From a12726a9bca0ae3c96088b386b3a988110044e9b Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 11 Oct 2024 11:33:44 +0200 Subject: [PATCH 13/17] fix --- databricks/sdk/oauth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 904a7fcf..ebdd55df 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -575,7 +575,7 @@ def load(self) -> Optional[SessionCredentials]: with open(self.filename, 'r') as f: raw = json.load(f) return SessionCredentials.from_dict(raw, - oidc_endpoints=self._oidc_endpoints, + token_endpoint=self._oidc_endpoints.token_endpoint, client_id=self._client_id, client_secret=self._client_secret, redirect_url=self._redirect_url) From 1430e8ab79b3c42be8cbaf21a95096b0f169c9ef Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 11 Oct 2024 11:35:00 +0200 Subject: [PATCH 14/17] tweak --- databricks/sdk/credentials_provider.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 841bc77f..a79151b5 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -201,11 +201,12 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: # Load cached credentials from disk if they exist. # Note that these are local to the Python SDK and not reused by other SDKs. oidc_endpoints = cfg.oidc_endpoints + redirect_url = 'http://localhost:8020' token_cache = TokenCache(host=cfg.host, oidc_endpoints=oidc_endpoints, client_id=client_id, client_secret=client_secret, - redirect_url='http://localhost:8020') + redirect_url=redirect_url) credentials = token_cache.load() if credentials: # Force a refresh in case the loaded credentials are expired. 
@@ -213,7 +214,7 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: else: oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, client_id=client_id, - redirect_url='http://localhost:8020', + redirect_url=redirect_url, client_secret=client_secret) consent = oauth_client.initiate_consent() if not consent: From c8d63e140648e436c27c3dba9e6a06c2effc2b54 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 11 Oct 2024 13:33:33 +0200 Subject: [PATCH 15/17] fix test --- tests/test_oauth.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_oauth.py b/tests/test_oauth.py index dc5d36e1..a637a550 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -48,6 +48,9 @@ def test_account_oidc_endpoints(requests_mock): def test_account_oidc_endpoints_retry_on_429(requests_mock): + # It doesn't seem possible to use requests_mock to return different responses for the same request, e.g. when + # simulating a transient failure. Instead, the nth_request matcher increments a test-wide counter and only matches + # the nth request. request_count = 0 def nth_request(n): @@ -81,7 +84,7 @@ def observe_request(_request): def test_workspace_oidc_endpoints(requests_mock): - requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", + requests_mock.get("https://my-workspace.cloud.databricks.com/oidc/.well-known/oauth-authorization-server", json={ "authorization_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", @@ -107,10 +110,10 @@ def observe_request(_request): return observe_request - requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", + requests_mock.get("https://my-workspace.cloud.databricks.com/oidc/.well-known/oauth-authorization-server", additional_matcher=nth_request(0), status_code=429) - requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", + requests_mock.get("https://my-workspace.cloud.databricks.com/oidc/.well-known/oauth-authorization-server", additional_matcher=nth_request(1), json={ "authorization_endpoint": From 7dee280c040d496e52609d8f7c5049ebd73c6c0a Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 18 Oct 2024 15:20:38 +0200 Subject: [PATCH 16/17] add from_host --- databricks/sdk/oauth.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index ebdd55df..30b8d636 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -472,6 +472,29 @@ def __init__(self, self._oidc_endpoints = oidc_endpoints self._scopes = scopes + @staticmethod + def from_host(self, + host: str, + client_id: str, + redirect_url: str, + *, + scopes: List[str] = None, + client_secret: str = None) -> 'OAuthClient': + from .core import Config + from .credentials_provider import credentials_strategy + + @credentials_strategy('noop', []) + def noop_credentials(_: any): + return lambda: {} + + config = Config(host=host, credentials_strategy=noop_credentials) + if not scopes: + scopes = ['all-apis'] + oidc = config.oidc_endpoints + if not oidc: + raise ValueError(f'{host} does not support OAuth') + return OAuthClient(oidc, redirect_url, client_id, scopes, client_secret) + def initiate_consent(self) -> Consent: state = secrets.token_urlsafe(16) From e9333efbcc37ff5b951b2765166887872d4ae850 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Fri, 18 Oct 2024 15:45:36 +0200 Subject: [PATCH 17/17] fix --- 
databricks/sdk/oauth.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 30b8d636..6cac45af 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -473,8 +473,7 @@ def __init__(self, self._scopes = scopes @staticmethod - def from_host(self, - host: str, + def from_host(host: str, client_id: str, redirect_url: str, *,
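
For reference, a minimal usage sketch of the OAuthClient.from_host helper introduced in the last two patches, wired into the same browser consent flow the examples above use. The host, client ID and redirect URL below are placeholder values, and a confidential OAuth app would also pass client_secret.

# Minimal sketch: resolve OIDC endpoints from a workspace host and run the
# browser-based consent flow. Host, client ID and redirect URL are placeholders.
from databricks.sdk import WorkspaceClient
from databricks.sdk.oauth import OAuthClient

oauth_client = OAuthClient.from_host(
    host="https://my-workspace.cloud.databricks.com",  # placeholder workspace URL
    client_id="my-custom-oauth-app",                   # placeholder OAuth app client ID
    redirect_url="http://localhost:8020",              # must be localhost: a local callback server is started
)
consent = oauth_client.initiate_consent()
creds = consent.launch_external_browser()  # opens the authorization URL and waits for the callback

# The resulting SessionCredentials can then back a WorkspaceClient,
# as in the Flask example above.
w = WorkspaceClient(host="https://my-workspace.cloud.databricks.com",
                    credentials_strategy=creds)
print(w.current_user.me())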